
Commit cc526ee

lesteve, StefanieSenger, and ogrisel authored
FEA Add array API support for GaussianMixture (scikit-learn#30777)
Co-authored-by: Stefanie Senger <stefanie.senger@posteo.de>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent b39ab89 commit cc526ee

File tree

8 files changed: +543 -158 lines changed


doc/modules/array_api.rst

Lines changed: 2 additions & 0 deletions

@@ -117,6 +117,8 @@ Estimators
 - :class:`preprocessing.MaxAbsScaler`
 - :class:`preprocessing.MinMaxScaler`
 - :class:`preprocessing.Normalizer`
+- :class:`mixture.GaussianMixture` (with `init_params="random"` or
+  `init_params="random_from_data"` and `warm_start=False`)
 
 Meta-estimators
 ---------------
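
For context, a rough usage sketch of the newly supported configuration (assuming PyTorch and the array-api-compat package are installed; any array API compatible library works the same way):

import torch

from sklearn import config_context
from sklearn.mixture import GaussianMixture

X = torch.randn(500, 3)  # any array API compatible input

# Per the documentation above, only init_params="random" or
# "random_from_data" with warm_start=False (the default) is dispatched.
with config_context(array_api_dispatch=True):
    gmm = GaussianMixture(n_components=2, init_params="random", random_state=0)
    labels = gmm.fit_predict(X)  # labels stay in the input namespace
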
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+- :class:`sklearn.mixture.GaussianMixture` with
+  `init_params="random"` or `init_params="random_from_data"` and
+  `warm_start=False` now supports Array API compatible inputs.
+  By :user:`Stefanie Senger <StefanieSenger>` and :user:`Loïc Estève <lesteve>`

sklearn/mixture/_base.py

Lines changed: 86 additions & 45 deletions
@@ -5,17 +5,24 @@
 
 import warnings
 from abc import ABCMeta, abstractmethod
+from contextlib import nullcontext
 from numbers import Integral, Real
 from time import time
 
 import numpy as np
-from scipy.special import logsumexp
 
 from .. import cluster
 from ..base import BaseEstimator, DensityMixin, _fit_context
 from ..cluster import kmeans_plusplus
 from ..exceptions import ConvergenceWarning
 from ..utils import check_random_state
+from ..utils._array_api import (
+    _convert_to_numpy,
+    _is_numpy_namespace,
+    _logsumexp,
+    get_namespace,
+    get_namespace_and_device,
+)
 from ..utils._param_validation import Interval, StrOptions
 from ..utils.validation import check_is_fitted, validate_data
 
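
The helpers imported above all serve the same purpose: recover the array namespace (`xp`) of the input and route every subsequent operation through it. A minimal standalone sketch of that idea (hypothetical function names, not the actual sklearn helpers, which also handle device placement and the array-api-compat wrappers):

import numpy as np

def simple_get_namespace(x):
    # Standard-compliant arrays expose their namespace directly;
    # everything else falls back to NumPy.
    if hasattr(x, "__array_namespace__"):
        return x.__array_namespace__()
    return np

def row_softmax(x):
    xp = simple_get_namespace(x)
    # All math goes through xp, so the same code runs unchanged on NumPy
    # arrays and on any array type implementing __array_namespace__.
    e = xp.exp(x - xp.max(x, axis=1, keepdims=True))
    return e / xp.sum(e, axis=1, keepdims=True)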

@@ -31,7 +38,6 @@ def _check_shape(param, param_shape, name):
 
     name : str
     """
-    param = np.array(param)
     if param.shape != param_shape:
         raise ValueError(
             "The parameter '%s' should have the shape of %s, but got %s"
@@ -86,7 +92,7 @@ def __init__(
         self.verbose_interval = verbose_interval
 
     @abstractmethod
-    def _check_parameters(self, X):
+    def _check_parameters(self, X, xp=None):
         """Check initial parameters of the derived class.
 
         Parameters
@@ -95,7 +101,7 @@ def _check_parameters(self, X):
         """
         pass
 
-    def _initialize_parameters(self, X, random_state):
+    def _initialize_parameters(self, X, random_state, xp=None):
         """Initialize the model parameters.
 
         Parameters
@@ -106,6 +112,7 @@ def _initialize_parameters(self, X, random_state):
             A random number generator instance that controls the random seed
             used for the method chosen to initialize the parameters.
         """
+        xp, _, device = get_namespace_and_device(X, xp=xp)
         n_samples, _ = X.shape
 
         if self.init_params == "kmeans":
@@ -119,16 +126,25 @@ def _initialize_parameters(self, X, random_state):
             )
             resp[np.arange(n_samples), label] = 1
         elif self.init_params == "random":
-            resp = np.asarray(
-                random_state.uniform(size=(n_samples, self.n_components)), dtype=X.dtype
+            resp = xp.asarray(
+                random_state.uniform(size=(n_samples, self.n_components)),
+                dtype=X.dtype,
+                device=device,
             )
-            resp /= resp.sum(axis=1)[:, np.newaxis]
+            resp /= xp.sum(resp, axis=1)[:, xp.newaxis]
         elif self.init_params == "random_from_data":
-            resp = np.zeros((n_samples, self.n_components), dtype=X.dtype)
+            resp = xp.zeros(
+                (n_samples, self.n_components), dtype=X.dtype, device=device
+            )
             indices = random_state.choice(
                 n_samples, size=self.n_components, replace=False
             )
-            resp[indices, np.arange(self.n_components)] = 1
+            # TODO: when array API supports __setitem__ with fancy indexing we
+            # can use the previous code:
+            # resp[indices, xp.arange(self.n_components)] = 1
+            # Until then we use a for loop on one dimension.
+            for col, index in enumerate(indices):
+                resp[index, col] = 1
         elif self.init_params == "k-means++":
             resp = np.zeros((n_samples, self.n_components), dtype=X.dtype)
             _, indices = kmeans_plusplus(
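
The per-column loop above replaces `resp[indices, xp.arange(self.n_components)] = 1`, which relies on fancy-indexing assignment that the array API standard does not guarantee. A small standalone illustration of the same one-hot construction (a sketch with made-up values, assuming the `array_api_strict` package is installed):

import array_api_strict as xp

n_samples, n_components = 5, 3
indices = [4, 0, 2]  # one chosen sample per component

resp = xp.zeros((n_samples, n_components), dtype=xp.float64)
# Integer-array assignment is not portable, so set one entry at a time.
for col, index in enumerate(indices):
    resp[index, col] = 1
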
@@ -210,20 +226,21 @@ def fit_predict(self, X, y=None):
         labels : array, shape (n_samples,)
             Component labels.
         """
-        X = validate_data(self, X, dtype=[np.float64, np.float32], ensure_min_samples=2)
+        xp, _ = get_namespace(X)
+        X = validate_data(self, X, dtype=[xp.float64, xp.float32], ensure_min_samples=2)
         if X.shape[0] < self.n_components:
             raise ValueError(
                 "Expected n_samples >= n_components "
                 f"but got n_components = {self.n_components}, "
                 f"n_samples = {X.shape[0]}"
             )
-        self._check_parameters(X)
+        self._check_parameters(X, xp=xp)
 
         # if we enable warm_start, we will have a unique initialisation
         do_init = not (self.warm_start and hasattr(self, "converged_"))
         n_init = self.n_init if do_init else 1
 
-        max_lower_bound = -np.inf
+        max_lower_bound = -xp.inf
         best_lower_bounds = []
         self.converged_ = False
 
@@ -234,9 +251,9 @@ def fit_predict(self, X, y=None):
             self._print_verbose_msg_init_beg(init)
 
             if do_init:
-                self._initialize_parameters(X, random_state)
+                self._initialize_parameters(X, random_state, xp=xp)
 
-            lower_bound = -np.inf if do_init else self.lower_bound_
+            lower_bound = -xp.inf if do_init else self.lower_bound_
             current_lower_bounds = []
 
             if self.max_iter == 0:
@@ -247,8 +264,8 @@ def fit_predict(self, X, y=None):
                 for n_iter in range(1, self.max_iter + 1):
                     prev_lower_bound = lower_bound
 
-                    log_prob_norm, log_resp = self._e_step(X)
-                    self._m_step(X, log_resp)
+                    log_prob_norm, log_resp = self._e_step(X, xp=xp)
+                    self._m_step(X, log_resp, xp=xp)
                     lower_bound = self._compute_lower_bound(log_resp, log_prob_norm)
                     current_lower_bounds.append(lower_bound)
 
@@ -261,7 +278,7 @@ def fit_predict(self, X, y=None):
 
             self._print_verbose_msg_init_end(lower_bound, converged)
 
-            if lower_bound > max_lower_bound or max_lower_bound == -np.inf:
+            if lower_bound > max_lower_bound or max_lower_bound == -xp.inf:
                 max_lower_bound = lower_bound
                 best_params = self._get_parameters()
                 best_n_iter = n_iter
@@ -281,19 +298,19 @@ def fit_predict(self, X, y=None):
                 ConvergenceWarning,
             )
 
-        self._set_parameters(best_params)
+        self._set_parameters(best_params, xp=xp)
         self.n_iter_ = best_n_iter
         self.lower_bound_ = max_lower_bound
         self.lower_bounds_ = best_lower_bounds
 
         # Always do a final e-step to guarantee that the labels returned by
         # fit_predict(X) are always consistent with fit(X).predict(X)
         # for any value of max_iter and tol (and any random_state).
-        _, log_resp = self._e_step(X)
+        _, log_resp = self._e_step(X, xp=xp)
 
-        return log_resp.argmax(axis=1)
+        return xp.argmax(log_resp, axis=1)
 
-    def _e_step(self, X):
+    def _e_step(self, X, xp=None):
         """E step.
 
         Parameters
@@ -309,8 +326,9 @@ def _e_step(self, X):
             Logarithm of the posterior probabilities (or responsibilities) of
             the point of each sample in X.
         """
-        log_prob_norm, log_resp = self._estimate_log_prob_resp(X)
-        return np.mean(log_prob_norm), log_resp
+        xp, _ = get_namespace(X, xp=xp)
+        log_prob_norm, log_resp = self._estimate_log_prob_resp(X, xp=xp)
+        return xp.mean(log_prob_norm), log_resp
 
     @abstractmethod
     def _m_step(self, X, log_resp):
@@ -351,7 +369,7 @@ def score_samples(self, X):
         check_is_fitted(self)
         X = validate_data(self, X, reset=False)
 
-        return logsumexp(self._estimate_weighted_log_prob(X), axis=1)
+        return _logsumexp(self._estimate_weighted_log_prob(X), axis=1)
 
     def score(self, X, y=None):
         """Compute the per-sample average log-likelihood of the given data X.
@@ -370,7 +388,8 @@ def score(self, X, y=None):
         log_likelihood : float
             Log-likelihood of `X` under the Gaussian mixture model.
         """
-        return self.score_samples(X).mean()
+        xp, _ = get_namespace(X)
+        return float(xp.mean(self.score_samples(X)))
 
     def predict(self, X):
         """Predict the labels for the data samples in X using trained model.
@@ -387,8 +406,9 @@ def predict(self, X):
             Component labels.
         """
         check_is_fitted(self)
+        xp, _ = get_namespace(X)
         X = validate_data(self, X, reset=False)
-        return self._estimate_weighted_log_prob(X).argmax(axis=1)
+        return xp.argmax(self._estimate_weighted_log_prob(X), axis=1)
 
     def predict_proba(self, X):
         """Evaluate the components' density for each sample.
@@ -406,8 +426,9 @@ def predict_proba(self, X):
         """
         check_is_fitted(self)
         X = validate_data(self, X, reset=False)
-        _, log_resp = self._estimate_log_prob_resp(X)
-        return np.exp(log_resp)
+        xp, _ = get_namespace(X)
+        _, log_resp = self._estimate_log_prob_resp(X, xp=xp)
+        return xp.exp(log_resp)
 
     def sample(self, n_samples=1):
         """Generate random samples from the fitted Gaussian distribution.
@@ -426,6 +447,7 @@ def sample(self, n_samples=1):
             Component labels.
         """
         check_is_fitted(self)
+        xp, _, device_ = get_namespace_and_device(self.means_)
 
         if n_samples < 1:
             raise ValueError(
@@ -435,22 +457,30 @@ def sample(self, n_samples=1):
 
         _, n_features = self.means_.shape
         rng = check_random_state(self.random_state)
-        n_samples_comp = rng.multinomial(n_samples, self.weights_)
+        n_samples_comp = rng.multinomial(
+            n_samples, _convert_to_numpy(self.weights_, xp)
+        )
 
         if self.covariance_type == "full":
             X = np.vstack(
                 [
                     rng.multivariate_normal(mean, covariance, int(sample))
                     for (mean, covariance, sample) in zip(
-                        self.means_, self.covariances_, n_samples_comp
+                        _convert_to_numpy(self.means_, xp),
+                        _convert_to_numpy(self.covariances_, xp),
+                        n_samples_comp,
                     )
                 ]
             )
         elif self.covariance_type == "tied":
             X = np.vstack(
                 [
-                    rng.multivariate_normal(mean, self.covariances_, int(sample))
-                    for (mean, sample) in zip(self.means_, n_samples_comp)
+                    rng.multivariate_normal(
+                        mean, _convert_to_numpy(self.covariances_, xp), int(sample)
+                    )
+                    for (mean, sample) in zip(
+                        _convert_to_numpy(self.means_, xp), n_samples_comp
+                    )
                 ]
             )
         else:
@@ -460,18 +490,23 @@ def sample(self, n_samples=1):
                     + rng.standard_normal(size=(sample, n_features))
                     * np.sqrt(covariance)
                     for (mean, covariance, sample) in zip(
-                        self.means_, self.covariances_, n_samples_comp
+                        _convert_to_numpy(self.means_, xp),
+                        _convert_to_numpy(self.covariances_, xp),
+                        n_samples_comp,
                    )
                 ]
             )
 
-        y = np.concatenate(
-            [np.full(sample, j, dtype=int) for j, sample in enumerate(n_samples_comp)]
+        y = xp.concat(
+            [
+                xp.full(int(n_samples_comp[i]), i, dtype=xp.int64, device=device_)
+                for i in range(len(n_samples_comp))
+            ]
         )
 
-        return (X, y)
+        return xp.asarray(X, device=device_), y
 
-    def _estimate_weighted_log_prob(self, X):
+    def _estimate_weighted_log_prob(self, X, xp=None):
         """Estimate the weighted log-probabilities, log P(X | Z) + log weights.
 
         Parameters
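
`sample` keeps using NumPy's random number generation (the array API standard defines no RNG), so the fitted parameters are converted to NumPy with `_convert_to_numpy`, sampled from, and the result is moved back to the input namespace and device. A stripped-down sketch of that round trip (hypothetical function, assuming the parameters can be converted with `np.asarray`; sklearn's `_convert_to_numpy` also handles accelerator-to-host transfers):

import numpy as np

def numpy_roundtrip_sample(mean, covariance, n_samples, xp, device, seed=0):
    # Draw in NumPy because no portable array API RNG exists ...
    rng = np.random.default_rng(seed)
    X_np = rng.multivariate_normal(np.asarray(mean), np.asarray(covariance), n_samples)
    # ... then hand the result back in the caller's namespace, on its device.
    return xp.asarray(X_np, device=device)
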
@@ -482,10 +517,10 @@ def _estimate_weighted_log_prob(self, X):
         -------
         weighted_log_prob : array, shape (n_samples, n_component)
         """
-        return self._estimate_log_prob(X) + self._estimate_log_weights()
+        return self._estimate_log_prob(X, xp=xp) + self._estimate_log_weights(xp=xp)
 
     @abstractmethod
-    def _estimate_log_weights(self):
+    def _estimate_log_weights(self, xp=None):
         """Estimate log-weights in EM algorithm, E[ log pi ] in VB algorithm.
 
         Returns
@@ -495,7 +530,7 @@ def _estimate_log_weights(self):
         pass
 
     @abstractmethod
-    def _estimate_log_prob(self, X):
+    def _estimate_log_prob(self, X, xp=None):
         """Estimate the log-probabilities log P(X | Z).
 
         Compute the log-probabilities per each component for each sample.
@@ -510,7 +545,7 @@ def _estimate_log_prob(self, X):
         """
         pass
 
-    def _estimate_log_prob_resp(self, X):
+    def _estimate_log_prob_resp(self, X, xp=None):
         """Estimate log probabilities and responsibilities for each sample.
 
         Compute the log probabilities, weighted log probabilities per
@@ -529,11 +564,17 @@ def _estimate_log_prob_resp(self, X):
         log_responsibilities : array, shape (n_samples, n_components)
             logarithm of the responsibilities
         """
-        weighted_log_prob = self._estimate_weighted_log_prob(X)
-        log_prob_norm = logsumexp(weighted_log_prob, axis=1)
-        with np.errstate(under="ignore"):
+        xp, _ = get_namespace(X, xp=xp)
+        weighted_log_prob = self._estimate_weighted_log_prob(X, xp=xp)
+        log_prob_norm = _logsumexp(weighted_log_prob, axis=1, xp=xp)
+
+        # There is no errstate equivalent for warning/error management in array API
+        context_manager = (
+            np.errstate(under="ignore") if _is_numpy_namespace(xp) else nullcontext()
+        )
+        with context_manager:
             # ignore underflow
-            log_resp = weighted_log_prob - log_prob_norm[:, np.newaxis]
+            log_resp = weighted_log_prob - log_prob_norm[:, xp.newaxis]
         return log_prob_norm, log_resp
 
     def _print_verbose_msg_init_beg(self, n_init):
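
Finally, the `nullcontext()` fallback exists because `np.errstate` is NumPy-specific and the array API standard has no equivalent floating point warning control. The pattern of picking a context manager at run time, in isolation (a sketch with a made-up function name, not sklearn code):

from contextlib import nullcontext

import numpy as np

def subtract_ignoring_underflow(a, b, xp):
    # Only NumPy understands errstate; other namespaces get a no-op context.
    ctx = np.errstate(under="ignore") if xp is np else nullcontext()
    with ctx:
        return a - b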
