27
27
from ..exceptions import ConvergenceWarning
28
28
from ..utils import check_array , check_random_state , gen_batches , metadata_routing
29
29
from ..utils ._param_validation import (
30
+ Hidden ,
30
31
Interval ,
31
32
StrOptions ,
32
33
validate_params ,
@@ -69,14 +70,19 @@ def trace_dot(X, Y):
69
70
70
71
def _check_init (A , shape , whom ):
71
72
A = check_array (A )
72
- if np . shape ( A ) != shape :
73
+ if shape [ 0 ] != "auto" and A . shape [ 0 ] != shape [ 0 ] :
73
74
raise ValueError (
74
- "Array with wrong shape passed to %s. Expected %s, but got %s "
75
- % (whom , shape , np .shape (A ))
75
+ f"Array with wrong first dimension passed to { whom } . Expected { shape [0 ]} , "
76
+ f"but got { A .shape [0 ]} ."
77
+ )
78
+ if shape [1 ] != "auto" and A .shape [1 ] != shape [1 ]:
79
+ raise ValueError (
80
+ f"Array with wrong second dimension passed to { whom } . Expected { shape [1 ]} , "
81
+ f"but got { A .shape [1 ]} ."
76
82
)
77
83
check_non_negative (A , whom )
78
84
if np .max (A ) == 0 :
79
- raise ValueError ("Array passed to %s is full of zeros." % whom )
85
+ raise ValueError (f "Array passed to { whom } is full of zeros." )
80
86
81
87
82
88
def _beta_divergence (X , W , H , beta , square_root = False ):
@@ -903,7 +909,7 @@ def non_negative_factorization(
903
909
X ,
904
910
W = None ,
905
911
H = None ,
906
- n_components = None ,
912
+ n_components = "warn" ,
907
913
* ,
908
914
init = None ,
909
915
update_H = True ,
@@ -976,9 +982,14 @@ def non_negative_factorization(
976
982
If `update_H=False`, it is used as a constant, to solve for W only.
977
983
If `None`, uses the initialisation method specified in `init`.
978
984
979
- n_components : int, default=None
985
+ n_components : int or {'auto'} or None , default=None
980
986
Number of components, if n_components is not set all features
981
987
are kept.
988
+ If `n_components='auto'`, the number of components is automatically inferred
989
+ from `W` or `H` shapes.
990
+
991
+ .. versionchanged:: 1.4
992
+ Added `'auto'` value.
982
993
983
994
init : {'random', 'nndsvd', 'nndsvda', 'nndsvdar', 'custom'}, default=None
984
995
Method used to initialize the procedure.
@@ -1133,7 +1144,12 @@ class _BaseNMF(ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator,
1133
1144
__metadata_request__inverse_transform = {"W" : metadata_routing .UNUSED }
1134
1145
1135
1146
_parameter_constraints : dict = {
1136
- "n_components" : [Interval (Integral , 1 , None , closed = "left" ), None ],
1147
+ "n_components" : [
1148
+ Interval (Integral , 1 , None , closed = "left" ),
1149
+ None ,
1150
+ StrOptions ({"auto" }),
1151
+ Hidden (StrOptions ({"warn" })),
1152
+ ],
1137
1153
"init" : [
1138
1154
StrOptions ({"random" , "nndsvd" , "nndsvda" , "nndsvdar" , "custom" }),
1139
1155
None ,
@@ -1153,7 +1169,7 @@ class _BaseNMF(ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator,
1153
1169
1154
1170
def __init__ (
1155
1171
self ,
1156
- n_components = None ,
1172
+ n_components = "warn" ,
1157
1173
* ,
1158
1174
init = None ,
1159
1175
beta_loss = "frobenius" ,
@@ -1179,6 +1195,16 @@ def __init__(
1179
1195
def _check_params (self , X ):
1180
1196
# n_components
1181
1197
self ._n_components = self .n_components
1198
+ if self .n_components == "warn" :
1199
+ warnings .warn (
1200
+ (
1201
+ "The default value of `n_components` will change from `None` to"
1202
+ " `'auto'` in 1.6. Set the value of `n_components` to `None`"
1203
+ " explicitly to supress the warning."
1204
+ ),
1205
+ FutureWarning ,
1206
+ )
1207
+ self ._n_components = None # Keeping the old default value
1182
1208
if self ._n_components is None :
1183
1209
self ._n_components = X .shape [1 ]
1184
1210
@@ -1188,32 +1214,61 @@ def _check_params(self, X):
1188
1214
def _check_w_h (self , X , W , H , update_H ):
1189
1215
"""Check W and H, or initialize them."""
1190
1216
n_samples , n_features = X .shape
1217
+
1191
1218
if self .init == "custom" and update_H :
1192
1219
_check_init (H , (self ._n_components , n_features ), "NMF (input H)" )
1193
1220
_check_init (W , (n_samples , self ._n_components ), "NMF (input W)" )
1221
+ if self ._n_components == "auto" :
1222
+ self ._n_components = H .shape [0 ]
1223
+
1194
1224
if H .dtype != X .dtype or W .dtype != X .dtype :
1195
1225
raise TypeError (
1196
1226
"H and W should have the same dtype as X. Got "
1197
1227
"H.dtype = {} and W.dtype = {}." .format (H .dtype , W .dtype )
1198
1228
)
1229
+
1199
1230
elif not update_H :
1231
+ if W is not None :
1232
+ warnings .warn (
1233
+ "When update_H=False, the provided initial W is not used." ,
1234
+ RuntimeWarning ,
1235
+ )
1236
+
1200
1237
_check_init (H , (self ._n_components , n_features ), "NMF (input H)" )
1238
+ if self ._n_components == "auto" :
1239
+ self ._n_components = H .shape [0 ]
1240
+
1201
1241
if H .dtype != X .dtype :
1202
1242
raise TypeError (
1203
1243
"H should have the same dtype as X. Got H.dtype = {}." .format (
1204
1244
H .dtype
1205
1245
)
1206
1246
)
1247
+
1207
1248
# 'mu' solver should not be initialized by zeros
1208
1249
if self .solver == "mu" :
1209
1250
avg = np .sqrt (X .mean () / self ._n_components )
1210
1251
W = np .full ((n_samples , self ._n_components ), avg , dtype = X .dtype )
1211
1252
else :
1212
1253
W = np .zeros ((n_samples , self ._n_components ), dtype = X .dtype )
1254
+
1213
1255
else :
1256
+ if W is not None or H is not None :
1257
+ warnings .warn (
1258
+ (
1259
+ "When init!='custom', provided W or H are ignored. Set "
1260
+ " init='custom' to use them as initialization."
1261
+ ),
1262
+ RuntimeWarning ,
1263
+ )
1264
+
1265
+ if self ._n_components == "auto" :
1266
+ self ._n_components = X .shape [1 ]
1267
+
1214
1268
W , H = _initialize_nmf (
1215
1269
X , self ._n_components , init = self .init , random_state = self .random_state
1216
1270
)
1271
+
1217
1272
return W , H
1218
1273
1219
1274
def _compute_regularization (self , X ):
@@ -1352,9 +1407,14 @@ class NMF(_BaseNMF):
1352
1407
1353
1408
Parameters
1354
1409
----------
1355
- n_components : int, default=None
1410
+ n_components : int or {'auto'} or None , default=None
1356
1411
Number of components, if n_components is not set all features
1357
1412
are kept.
1413
+ If `n_components='auto'`, the number of components is automatically inferred
1414
+ from W or H shapes.
1415
+
1416
+ .. versionchanged:: 1.4
1417
+ Added `'auto'` value.
1358
1418
1359
1419
init : {'random', 'nndsvd', 'nndsvda', 'nndsvdar', 'custom'}, default=None
1360
1420
Method used to initialize the procedure.
@@ -1517,7 +1577,7 @@ class NMF(_BaseNMF):
1517
1577
1518
1578
def __init__ (
1519
1579
self ,
1520
- n_components = None ,
1580
+ n_components = "warn" ,
1521
1581
* ,
1522
1582
init = None ,
1523
1583
solver = "cd" ,
@@ -1786,9 +1846,14 @@ class MiniBatchNMF(_BaseNMF):
1786
1846
1787
1847
Parameters
1788
1848
----------
1789
- n_components : int, default=None
1849
+ n_components : int or {'auto'} or None , default=None
1790
1850
Number of components, if `n_components` is not set all features
1791
1851
are kept.
1852
+ If `n_components='auto'`, the number of components is automatically inferred
1853
+ from W or H shapes.
1854
+
1855
+ .. versionchanged:: 1.4
1856
+ Added `'auto'` value.
1792
1857
1793
1858
init : {'random', 'nndsvd', 'nndsvda', 'nndsvdar', 'custom'}, default=None
1794
1859
Method used to initialize the procedure.
@@ -1953,7 +2018,7 @@ class MiniBatchNMF(_BaseNMF):
1953
2018
1954
2019
def __init__ (
1955
2020
self ,
1956
- n_components = None ,
2021
+ n_components = "warn" ,
1957
2022
* ,
1958
2023
init = None ,
1959
2024
batch_size = 1024 ,
0 commit comments