Skip to content

Commit 4daff41

Browse files
OmarManzoorogrisel
andauthored
FIX GaussianMixture sample method to correctly handle mps (scikit-learn#31639)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent c92330f commit 4daff41

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

sklearn/mixture/_base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
_convert_to_numpy,
2121
_is_numpy_namespace,
2222
_logsumexp,
23+
_max_precision_float_dtype,
2324
get_namespace,
2425
get_namespace_and_device,
2526
)
@@ -504,7 +505,8 @@ def sample(self, n_samples=1):
504505
]
505506
)
506507

507-
return xp.asarray(X, device=device_), y
508+
max_float_dtype = _max_precision_float_dtype(xp=xp, device=device_)
509+
return xp.asarray(X, dtype=max_float_dtype, device=device_), y
508510

509511
def _estimate_weighted_log_prob(self, X, xp=None):
510512
"""Estimate the weighted log-probabilities, log P(X | Z) + log weights.

0 commit comments

Comments
 (0)