Skip to content

Commit bde701d

Browse files
authored
MNT Use _add_to_diagonal in GaussianMixture (scikit-learn#31607)
1 parent 8792943 commit bde701d

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

sklearn/mixture/_gaussian_mixture.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from ..externals import array_api_extra as xpx
1111
from ..utils import check_array
1212
from ..utils._array_api import (
13+
_add_to_diagonal,
1314
_cholesky,
1415
_linalg_solve,
1516
get_namespace,
@@ -192,8 +193,7 @@ def _estimate_gaussian_covariances_full(resp, X, nk, means, reg_covar, xp=None):
192193
for k in range(n_components):
193194
diff = X - means[k, :]
194195
covariances[k, :, :] = ((resp[:, k] * diff.T) @ diff) / nk[k]
195-
covariances_flat = xp.reshape(covariances[k, :, :], (-1,))
196-
covariances_flat[:: n_features + 1] += reg_covar
196+
_add_to_diagonal(covariances[k, :, :], reg_covar, xp)
197197
return covariances
198198

199199

@@ -222,8 +222,7 @@ def _estimate_gaussian_covariances_tied(resp, X, nk, means, reg_covar, xp=None):
222222
avg_means2 = nk * means.T @ means
223223
covariance = avg_X2 - avg_means2
224224
covariance /= xp.sum(nk)
225-
covariance_flat = xp.reshape(covariance, (-1,))
226-
covariance_flat[:: covariance.shape[0] + 1] += reg_covar
225+
_add_to_diagonal(covariance, reg_covar, xp)
227226
return covariance
228227

229228

0 commit comments

Comments
 (0)