Skip to content

Commit 367bc40

Browse files
mdhabertylerjereddy
authored andcommitted
MAINT: stats.Mixture: make return type consistent
1 parent 94a07f9 commit 367bc40

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

scipy/stats/_distribution_infrastructure.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4506,13 +4506,13 @@ def _sum(self, fun, *args):
45064506
out = self._full(0, *args)
45074507
for var, weight in zip(self._components, self._weights):
45084508
out += getattr(var, fun)(*args) * weight
4509-
return out
4509+
return out[()]
45104510

45114511
def _logsum(self, fun, *args):
45124512
out = self._full(-np.inf, *args)
45134513
for var, log_weight in zip(self._components, np.log(self._weights)):
45144514
np.logaddexp(out, getattr(var, fun)(*args) + log_weight, out=out)
4515-
return out
4515+
return out[()]
45164516

45174517
def support(self):
45184518
a = self._full(np.inf)
@@ -4588,7 +4588,7 @@ def _moment_raw(self, order):
45884588
out = self._full(0)
45894589
for var, weight in zip(self._components, self._weights):
45904590
out += var.moment(order, kind='raw') * weight
4591-
return out
4591+
return out[()]
45924592

45934593
def _moment_central(self, order):
45944594
order = int(order)
@@ -4599,7 +4599,7 @@ def _moment_central(self, order):
45994599
a, b = var.mean(), self.mean()
46004600
moment = var._moment_transform_center(order, moment_as, a, b)
46014601
out += moment * weight
4602-
return out
4602+
return out[()]
46034603

46044604
def _moment_standardized(self, order):
46054605
return self._moment_central(order) / self.standard_deviation()**order

scipy/stats/tests/test_continuous.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1794,12 +1794,18 @@ def test_input_validation(self):
17941794
with pytest.raises(ValueError, match=message):
17951795
Mixture([Normal(), Normal()], weights=[1.5, -0.5])
17961796

1797-
def test_basic(self):
1797+
@pytest.mark.parametrize('shape', [(), (10,)])
1798+
def test_basic(self, shape):
17981799
rng = np.random.default_rng(582348972387243524)
17991800
X = Mixture((Normal(mu=-0.25, sigma=1.1), Normal(mu=0.5, sigma=0.9)),
18001801
weights=(0.4, 0.6))
18011802
Y = MixedDist()
1802-
x = rng.random(10)
1803+
x = rng.random(shape)
1804+
1805+
def assert_allclose(res, ref, **kwargs):
1806+
if shape == ():
1807+
assert np.isscalar(res)
1808+
np.testing.assert_allclose(res, ref, **kwargs)
18031809

18041810
assert_allclose(X.logentropy(), Y.logentropy())
18051811
assert_allclose(X.entropy(), Y.entropy())

0 commit comments

Comments
 (0)