 import warnings
 from abc import ABCMeta, abstractmethod
+from contextlib import nullcontext
 from numbers import Integral, Real
 from time import time

 import numpy as np
-from scipy.special import logsumexp

 from .. import cluster
 from ..base import BaseEstimator, DensityMixin, _fit_context
 from ..cluster import kmeans_plusplus
 from ..exceptions import ConvergenceWarning
 from ..utils import check_random_state
+from ..utils._array_api import (
+    _convert_to_numpy,
+    _is_numpy_namespace,
+    _logsumexp,
+    get_namespace,
+    get_namespace_and_device,
+)
 from ..utils._param_validation import Interval, StrOptions
 from ..utils.validation import check_is_fitted, validate_data
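Note on the import hunk: `scipy.special.logsumexp` is dropped in favour of the private, array API compatible `_logsumexp`, and the namespace helpers (`get_namespace`, `get_namespace_and_device`, `_convert_to_numpy`, `_is_numpy_namespace`) are pulled in from `sklearn.utils._array_api`. As a rough illustration of the dispatch idea behind these helpers (the `_namespace_of` function below is a simplified stand-in, not scikit-learn's implementation; in practice the real helpers go through array-api-compat so that torch and cupy arrays are covered too):

    import numpy as np

    def _namespace_of(x):
        # Simplified stand-in for get_namespace: NumPy inputs keep using the
        # numpy module, while objects implementing the array API standard
        # expose their own namespace via __array_namespace__.
        return np if isinstance(x, np.ndarray) else x.__array_namespace__()

    X = np.asarray([[0.0, 1.0], [2.0, 3.0]])
    xp = _namespace_of(X)
    row_sums = xp.sum(X, axis=1)  # the same call works for any conforming xp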
@@ -31,7 +38,6 @@ def _check_shape(param, param_shape, name):

     name : str
     """
-    param = np.array(param)
     if param.shape != param_shape:
         raise ValueError(
             "The parameter '%s' should have the shape of %s, but got %s"
@@ -86,7 +92,7 @@ def __init__(
         self.verbose_interval = verbose_interval

     @abstractmethod
-    def _check_parameters(self, X):
+    def _check_parameters(self, X, xp=None):
         """Check initial parameters of the derived class.

         Parameters
@@ -95,7 +101,7 @@ def _check_parameters(self, X):
         """
         pass

-    def _initialize_parameters(self, X, random_state):
+    def _initialize_parameters(self, X, random_state, xp=None):
         """Initialize the model parameters.

         Parameters
@@ -106,6 +112,7 @@ def _initialize_parameters(self, X, random_state):
             A random number generator instance that controls the random seed
             used for the method chosen to initialize the parameters.
         """
+        xp, _, device = get_namespace_and_device(X, xp=xp)
         n_samples, _ = X.shape

         if self.init_params == "kmeans":
@@ -119,16 +126,25 @@ def _initialize_parameters(self, X, random_state):
             )
             resp[np.arange(n_samples), label] = 1
         elif self.init_params == "random":
-            resp = np.asarray(
-                random_state.uniform(size=(n_samples, self.n_components)), dtype=X.dtype
+            resp = xp.asarray(
+                random_state.uniform(size=(n_samples, self.n_components)),
+                dtype=X.dtype,
+                device=device,
             )
-            resp /= resp.sum(axis=1)[:, np.newaxis]
+            resp /= xp.sum(resp, axis=1)[:, xp.newaxis]
         elif self.init_params == "random_from_data":
-            resp = np.zeros((n_samples, self.n_components), dtype=X.dtype)
+            resp = xp.zeros(
+                (n_samples, self.n_components), dtype=X.dtype, device=device
+            )
             indices = random_state.choice(
                 n_samples, size=self.n_components, replace=False
             )
-            resp[indices, np.arange(self.n_components)] = 1
+            # TODO: when array API supports __setitem__ with fancy indexing we
+            # can use the previous code:
+            # resp[indices, xp.arange(self.n_components)] = 1
+            # Until then we use a for loop on one dimension.
+            for col, index in enumerate(indices):
+                resp[index, col] = 1
         elif self.init_params == "k-means++":
             resp = np.zeros((n_samples, self.n_components), dtype=X.dtype)
             _, indices = kmeans_plusplus(
@@ -210,20 +226,21 @@ def fit_predict(self, X, y=None):
         labels : array, shape (n_samples,)
             Component labels.
         """
-        X = validate_data(self, X, dtype=[np.float64, np.float32], ensure_min_samples=2)
+        xp, _ = get_namespace(X)
+        X = validate_data(self, X, dtype=[xp.float64, xp.float32], ensure_min_samples=2)
         if X.shape[0] < self.n_components:
             raise ValueError(
                 "Expected n_samples >= n_components "
                 f"but got n_components = {self.n_components}, "
                 f"n_samples = {X.shape[0]}"
             )
-        self._check_parameters(X)
+        self._check_parameters(X, xp=xp)

         # if we enable warm_start, we will have a unique initialisation
         do_init = not (self.warm_start and hasattr(self, "converged_"))
         n_init = self.n_init if do_init else 1

-        max_lower_bound = -np.inf
+        max_lower_bound = -xp.inf
         best_lower_bounds = []
         self.converged_ = False

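With `get_namespace(X)` resolved at the top of `fit_predict`, the rest of the EM loop runs on whichever namespace the input lives in. A usage sketch of how this would be exercised through scikit-learn's existing `array_api_dispatch` switch, assuming PyTorch and array-api-compat are installed and that GaussianMixture is covered by the dispatch once this change is merged:

    import torch
    from sklearn import config_context
    from sklearn.mixture import GaussianMixture

    X = torch.randn(200, 2, dtype=torch.float64)
    with config_context(array_api_dispatch=True):
        gmm = GaussianMixture(n_components=3, random_state=0).fit(X)
        labels = gmm.predict(X)  # result stays in the input's namespace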
@@ -234,9 +251,9 @@ def fit_predict(self, X, y=None):
             self._print_verbose_msg_init_beg(init)

             if do_init:
-                self._initialize_parameters(X, random_state)
+                self._initialize_parameters(X, random_state, xp=xp)

-            lower_bound = -np.inf if do_init else self.lower_bound_
+            lower_bound = -xp.inf if do_init else self.lower_bound_
             current_lower_bounds = []

             if self.max_iter == 0:
@@ -247,8 +264,8 @@ def fit_predict(self, X, y=None):
                 for n_iter in range(1, self.max_iter + 1):
                     prev_lower_bound = lower_bound

-                    log_prob_norm, log_resp = self._e_step(X)
-                    self._m_step(X, log_resp)
+                    log_prob_norm, log_resp = self._e_step(X, xp=xp)
+                    self._m_step(X, log_resp, xp=xp)
                     lower_bound = self._compute_lower_bound(log_resp, log_prob_norm)
                     current_lower_bounds.append(lower_bound)
@@ -261,7 +278,7 @@ def fit_predict(self, X, y=None):

             self._print_verbose_msg_init_end(lower_bound, converged)

-            if lower_bound > max_lower_bound or max_lower_bound == -np.inf:
+            if lower_bound > max_lower_bound or max_lower_bound == -xp.inf:
                 max_lower_bound = lower_bound
                 best_params = self._get_parameters()
                 best_n_iter = n_iter
@@ -281,19 +298,19 @@ def fit_predict(self, X, y=None):
                 ConvergenceWarning,
             )

-        self._set_parameters(best_params)
+        self._set_parameters(best_params, xp=xp)
         self.n_iter_ = best_n_iter
         self.lower_bound_ = max_lower_bound
         self.lower_bounds_ = best_lower_bounds

         # Always do a final e-step to guarantee that the labels returned by
         # fit_predict(X) are always consistent with fit(X).predict(X)
         # for any value of max_iter and tol (and any random_state).
-        _, log_resp = self._e_step(X)
+        _, log_resp = self._e_step(X, xp=xp)

-        return log_resp.argmax(axis=1)
+        return xp.argmax(log_resp, axis=1)

-    def _e_step(self, X):
+    def _e_step(self, X, xp=None):
         """E step.

         Parameters
@@ -309,8 +326,9 @@ def _e_step(self, X):
             Logarithm of the posterior probabilities (or responsibilities) of
             the point of each sample in X.
         """
-        log_prob_norm, log_resp = self._estimate_log_prob_resp(X)
-        return np.mean(log_prob_norm), log_resp
+        xp, _ = get_namespace(X, xp=xp)
+        log_prob_norm, log_resp = self._estimate_log_prob_resp(X, xp=xp)
+        return xp.mean(log_prob_norm), log_resp

     @abstractmethod
     def _m_step(self, X, log_resp):
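For reference, the quantities `_e_step` returns are the standard Gaussian-mixture EM responsibilities and per-sample normalizers, independent of the namespace used; here pi_k, mu_k, Sigma_k denote the mixture weights, means and covariances:

    \log r_{nk} = \log \pi_k + \log \mathcal{N}(x_n \mid \mu_k, \Sigma_k)
                  - \log \sum_{j} \pi_j \, \mathcal{N}(x_n \mid \mu_j, \Sigma_j)

`log_resp` holds the matrix of log r_{nk}, and `log_prob_norm` averaged over samples (the logsumexp term above) is the per-sample log-likelihood used as the EM lower bound.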
@@ -351,7 +369,7 @@ def score_samples(self, X):
         check_is_fitted(self)
         X = validate_data(self, X, reset=False)

-        return logsumexp(self._estimate_weighted_log_prob(X), axis=1)
+        return _logsumexp(self._estimate_weighted_log_prob(X), axis=1)

     def score(self, X, y=None):
         """Compute the per-sample average log-likelihood of the given data X.
@@ -370,7 +388,8 @@ def score(self, X, y=None):
         log_likelihood : float
             Log-likelihood of `X` under the Gaussian mixture model.
         """
-        return self.score_samples(X).mean()
+        xp, _ = get_namespace(X)
+        return float(xp.mean(self.score_samples(X)))

     def predict(self, X):
         """Predict the labels for the data samples in X using trained model.
@@ -387,8 +406,9 @@ def predict(self, X):
             Component labels.
         """
         check_is_fitted(self)
+        xp, _ = get_namespace(X)
         X = validate_data(self, X, reset=False)
-        return self._estimate_weighted_log_prob(X).argmax(axis=1)
+        return xp.argmax(self._estimate_weighted_log_prob(X), axis=1)

     def predict_proba(self, X):
         """Evaluate the components' density for each sample.
@@ -406,8 +426,9 @@ def predict_proba(self, X):
         """
         check_is_fitted(self)
         X = validate_data(self, X, reset=False)
-        _, log_resp = self._estimate_log_prob_resp(X)
-        return np.exp(log_resp)
+        xp, _ = get_namespace(X)
+        _, log_resp = self._estimate_log_prob_resp(X, xp=xp)
+        return xp.exp(log_resp)

     def sample(self, n_samples=1):
         """Generate random samples from the fitted Gaussian distribution.
@@ -426,6 +447,7 @@ def sample(self, n_samples=1):
             Component labels.
         """
         check_is_fitted(self)
+        xp, _, device_ = get_namespace_and_device(self.means_)

         if n_samples < 1:
             raise ValueError(
@@ -435,22 +457,30 @@ def sample(self, n_samples=1):

         _, n_features = self.means_.shape
         rng = check_random_state(self.random_state)
-        n_samples_comp = rng.multinomial(n_samples, self.weights_)
+        n_samples_comp = rng.multinomial(
+            n_samples, _convert_to_numpy(self.weights_, xp)
+        )

         if self.covariance_type == "full":
             X = np.vstack(
                 [
                     rng.multivariate_normal(mean, covariance, int(sample))
                     for (mean, covariance, sample) in zip(
-                        self.means_, self.covariances_, n_samples_comp
+                        _convert_to_numpy(self.means_, xp),
+                        _convert_to_numpy(self.covariances_, xp),
+                        n_samples_comp,
                     )
                 ]
             )
         elif self.covariance_type == "tied":
             X = np.vstack(
                 [
-                    rng.multivariate_normal(mean, self.covariances_, int(sample))
-                    for (mean, sample) in zip(self.means_, n_samples_comp)
+                    rng.multivariate_normal(
+                        mean, _convert_to_numpy(self.covariances_, xp), int(sample)
+                    )
+                    for (mean, sample) in zip(
+                        _convert_to_numpy(self.means_, xp), n_samples_comp
+                    )
                 ]
             )
         else:
@@ -460,18 +490,23 @@ def sample(self, n_samples=1):
                     + rng.standard_normal(size=(sample, n_features))
                     * np.sqrt(covariance)
                     for (mean, covariance, sample) in zip(
-                        self.means_, self.covariances_, n_samples_comp
+                        _convert_to_numpy(self.means_, xp),
+                        _convert_to_numpy(self.covariances_, xp),
+                        n_samples_comp,
                     )
                 ]
             )

-        y = np.concatenate(
-            [np.full(sample, j, dtype=int) for j, sample in enumerate(n_samples_comp)]
+        y = xp.concat(
+            [
+                xp.full(int(n_samples_comp[i]), i, dtype=xp.int64, device=device_)
+                for i in range(len(n_samples_comp))
+            ]
         )

-        return (X, y)
+        return xp.asarray(X, device=device_), y

-    def _estimate_weighted_log_prob(self, X):
+    def _estimate_weighted_log_prob(self, X, xp=None):
         """Estimate the weighted log-probabilities, log P(X | Z) + log weights.

         Parameters
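The `sample()` changes above keep the random draws in NumPy (RandomState's `multinomial` and `multivariate_normal` have no array API counterpart), which is why the fitted parameters are converted with `_convert_to_numpy` and only the final result is moved back to the fitted namespace and device. A plain NumPy sketch of the sampling scheme itself, with made-up parameters and the "full" covariance case:

    import numpy as np

    rng = np.random.RandomState(0)
    weights = np.array([0.6, 0.4])
    means = np.array([[0.0, 0.0], [5.0, 5.0]])
    covariances = np.array([np.eye(2), 2.0 * np.eye(2)])

    # 1) split the requested draws across components according to the weights
    n_samples_comp = rng.multinomial(100, weights)
    # 2) draw each component's share from its own Gaussian
    X = np.vstack(
        [
            rng.multivariate_normal(mean, cov, int(n))
            for mean, cov, n in zip(means, covariances, n_samples_comp)
        ]
    )
    # 3) record which component produced each sample
    y = np.concatenate(
        [np.full(int(n), j, dtype=int) for j, n in enumerate(n_samples_comp)]
    )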
@@ -482,10 +517,10 @@ def _estimate_weighted_log_prob(self, X):
         -------
         weighted_log_prob : array, shape (n_samples, n_component)
         """
-        return self._estimate_log_prob(X) + self._estimate_log_weights()
+        return self._estimate_log_prob(X, xp=xp) + self._estimate_log_weights(xp=xp)

     @abstractmethod
-    def _estimate_log_weights(self):
+    def _estimate_log_weights(self, xp=None):
         """Estimate log-weights in EM algorithm, E[ log pi ] in VB algorithm.

         Returns
@@ -495,7 +530,7 @@ def _estimate_log_weights(self):
         pass

     @abstractmethod
-    def _estimate_log_prob(self, X):
+    def _estimate_log_prob(self, X, xp=None):
         """Estimate the log-probabilities log P(X | Z).

         Compute the log-probabilities per each component for each sample.
@@ -510,7 +545,7 @@ def _estimate_log_prob(self, X):
         """
         pass

-    def _estimate_log_prob_resp(self, X):
+    def _estimate_log_prob_resp(self, X, xp=None):
         """Estimate log probabilities and responsibilities for each sample.

         Compute the log probabilities, weighted log probabilities per
@@ -529,11 +564,17 @@ def _estimate_log_prob_resp(self, X):
         log_responsibilities : array, shape (n_samples, n_components)
             logarithm of the responsibilities
         """
-        weighted_log_prob = self._estimate_weighted_log_prob(X)
-        log_prob_norm = logsumexp(weighted_log_prob, axis=1)
-        with np.errstate(under="ignore"):
+        xp, _ = get_namespace(X, xp=xp)
+        weighted_log_prob = self._estimate_weighted_log_prob(X, xp=xp)
+        log_prob_norm = _logsumexp(weighted_log_prob, axis=1, xp=xp)
+
+        # There is no errstate equivalent for warning/error management in array API
+        context_manager = (
+            np.errstate(under="ignore") if _is_numpy_namespace(xp) else nullcontext()
+        )
+        with context_manager:
             # ignore underflow
-            log_resp = weighted_log_prob - log_prob_norm[:, np.newaxis]
+            log_resp = weighted_log_prob - log_prob_norm[:, xp.newaxis]
         return log_prob_norm, log_resp

     def _print_verbose_msg_init_beg(self, n_init):
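The `errstate` block in the hunk above only makes sense for NumPy inputs; other namespaces fall back to a no-op `nullcontext()`. A standalone sketch of that pattern (the `underflow_guard` helper below is illustrative, not part of the diff):

    from contextlib import nullcontext

    import numpy as np

    def underflow_guard(xp):
        # Silence underflow warnings only when xp is NumPy; other array API
        # namespaces have no errstate equivalent, so do nothing there.
        return np.errstate(under="ignore") if xp is np else nullcontext()

    xp = np
    with underflow_guard(xp):
        tiny = xp.exp(xp.asarray(-800.0))  # underflows quietly to 0.0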
0 commit comments