Skip to content

Commit ee5d94e

Browse files
Alexlandeau, avigny, jeremiedbb, ogrisel
authored
ENH Add n_components="auto" to NMF to be inferred from custom init (scikit-learn#26634)
Co-authored-by: avigny <alexandre.vigny@dataiku.com> Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com> Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 2579841 commit ee5d94e

File tree

4 files changed

+218
-14
lines changed

4 files changed

+218
-14
lines changed

doc/whats_new/v1.4.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,13 @@ TODO: update at the time of the release.
5858
:meth:`base.OutlierMixin.fit_predict` now accept ``**kwargs`` which are
5959
passed to the ``fit`` method of the estimator. :pr:`26506` by `Adrin
6060
Jalali`_.
61+
62+
:mod:`sklearn.decomposition`
63+
............................
64+
65+
- |Enhancement| An "auto" option was added to the `n_components` parameter of
66+
:func:`decomposition.non_negative_factorization`, :class:`decomposition.NMF` and
67+
:class:`decomposition.MiniBatchNMF` to automatically infer the number of components from W or H shapes
68+
when using a custom initialization. The default value of this parameter will change
69+
from `None` to `auto` in version 1.6.
70+
:pr:`26634` by :user:`Alexandre Landeau <AlexL>` and :user:`Alexandre Vigny <avigny>`.

sklearn/decomposition/_nmf.py

Lines changed: 77 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from ..exceptions import ConvergenceWarning
2828
from ..utils import check_array, check_random_state, gen_batches, metadata_routing
2929
from ..utils._param_validation import (
30+
Hidden,
3031
Interval,
3132
StrOptions,
3233
validate_params,
@@ -69,14 +70,19 @@ def trace_dot(X, Y):
6970

7071
def _check_init(A, shape, whom):
7172
A = check_array(A)
72-
if np.shape(A) != shape:
73+
if shape[0] != "auto" and A.shape[0] != shape[0]:
7374
raise ValueError(
74-
"Array with wrong shape passed to %s. Expected %s, but got %s "
75-
% (whom, shape, np.shape(A))
75+
f"Array with wrong first dimension passed to {whom}. Expected {shape[0]}, "
76+
f"but got {A.shape[0]}."
77+
)
78+
if shape[1] != "auto" and A.shape[1] != shape[1]:
79+
raise ValueError(
80+
f"Array with wrong second dimension passed to {whom}. Expected {shape[1]}, "
81+
f"but got {A.shape[1]}."
7682
)
7783
check_non_negative(A, whom)
7884
if np.max(A) == 0:
79-
raise ValueError("Array passed to %s is full of zeros." % whom)
85+
raise ValueError(f"Array passed to {whom} is full of zeros.")
8086

8187

8288
def _beta_divergence(X, W, H, beta, square_root=False):
@@ -903,7 +909,7 @@ def non_negative_factorization(
903909
X,
904910
W=None,
905911
H=None,
906-
n_components=None,
912+
n_components="warn",
907913
*,
908914
init=None,
909915
update_H=True,
@@ -976,9 +982,14 @@ def non_negative_factorization(
976982
If `update_H=False`, it is used as a constant, to solve for W only.
977983
If `None`, uses the initialisation method specified in `init`.
978984
979-
n_components : int, default=None
985+
n_components : int or {'auto'} or None, default=None
980986
Number of components, if n_components is not set all features
981987
are kept.
988+
If `n_components='auto'`, the number of components is automatically inferred
989+
from `W` or `H` shapes.
990+
991+
.. versionchanged:: 1.4
992+
Added `'auto'` value.
982993
983994
init : {'random', 'nndsvd', 'nndsvda', 'nndsvdar', 'custom'}, default=None
984995
Method used to initialize the procedure.
@@ -1133,7 +1144,12 @@ class _BaseNMF(ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator,
11331144
__metadata_request__inverse_transform = {"W": metadata_routing.UNUSED}
11341145

11351146
_parameter_constraints: dict = {
1136-
"n_components": [Interval(Integral, 1, None, closed="left"), None],
1147+
"n_components": [
1148+
Interval(Integral, 1, None, closed="left"),
1149+
None,
1150+
StrOptions({"auto"}),
1151+
Hidden(StrOptions({"warn"})),
1152+
],
11371153
"init": [
11381154
StrOptions({"random", "nndsvd", "nndsvda", "nndsvdar", "custom"}),
11391155
None,
@@ -1153,7 +1169,7 @@ class _BaseNMF(ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator,
11531169

11541170
def __init__(
11551171
self,
1156-
n_components=None,
1172+
n_components="warn",
11571173
*,
11581174
init=None,
11591175
beta_loss="frobenius",
@@ -1179,6 +1195,16 @@ def __init__(
11791195
def _check_params(self, X):
11801196
# n_components
11811197
self._n_components = self.n_components
1198+
if self.n_components == "warn":
1199+
warnings.warn(
1200+
(
1201+
"The default value of `n_components` will change from `None` to"
1202+
" `'auto'` in 1.6. Set the value of `n_components` to `None`"
1203+
" explicitly to supress the warning."
1204+
),
1205+
FutureWarning,
1206+
)
1207+
self._n_components = None # Keeping the old default value
11821208
if self._n_components is None:
11831209
self._n_components = X.shape[1]
11841210

@@ -1188,32 +1214,61 @@ def _check_params(self, X):
11881214
def _check_w_h(self, X, W, H, update_H):
11891215
"""Check W and H, or initialize them."""
11901216
n_samples, n_features = X.shape
1217+
11911218
if self.init == "custom" and update_H:
11921219
_check_init(H, (self._n_components, n_features), "NMF (input H)")
11931220
_check_init(W, (n_samples, self._n_components), "NMF (input W)")
1221+
if self._n_components == "auto":
1222+
self._n_components = H.shape[0]
1223+
11941224
if H.dtype != X.dtype or W.dtype != X.dtype:
11951225
raise TypeError(
11961226
"H and W should have the same dtype as X. Got "
11971227
"H.dtype = {} and W.dtype = {}.".format(H.dtype, W.dtype)
11981228
)
1229+
11991230
elif not update_H:
1231+
if W is not None:
1232+
warnings.warn(
1233+
"When update_H=False, the provided initial W is not used.",
1234+
RuntimeWarning,
1235+
)
1236+
12001237
_check_init(H, (self._n_components, n_features), "NMF (input H)")
1238+
if self._n_components == "auto":
1239+
self._n_components = H.shape[0]
1240+
12011241
if H.dtype != X.dtype:
12021242
raise TypeError(
12031243
"H should have the same dtype as X. Got H.dtype = {}.".format(
12041244
H.dtype
12051245
)
12061246
)
1247+
12071248
# 'mu' solver should not be initialized by zeros
12081249
if self.solver == "mu":
12091250
avg = np.sqrt(X.mean() / self._n_components)
12101251
W = np.full((n_samples, self._n_components), avg, dtype=X.dtype)
12111252
else:
12121253
W = np.zeros((n_samples, self._n_components), dtype=X.dtype)
1254+
12131255
else:
1256+
if W is not None or H is not None:
1257+
warnings.warn(
1258+
(
1259+
"When init!='custom', provided W or H are ignored. Set "
1260+
" init='custom' to use them as initialization."
1261+
),
1262+
RuntimeWarning,
1263+
)
1264+
1265+
if self._n_components == "auto":
1266+
self._n_components = X.shape[1]
1267+
12141268
W, H = _initialize_nmf(
12151269
X, self._n_components, init=self.init, random_state=self.random_state
12161270
)
1271+
12171272
return W, H
12181273

12191274
def _compute_regularization(self, X):
@@ -1352,9 +1407,14 @@ class NMF(_BaseNMF):
13521407
13531408
Parameters
13541409
----------
1355-
n_components : int, default=None
1410+
n_components : int or {'auto'} or None, default=None
13561411
Number of components, if n_components is not set all features
13571412
are kept.
1413+
If `n_components='auto'`, the number of components is automatically inferred
1414+
from W or H shapes.
1415+
1416+
.. versionchanged:: 1.4
1417+
Added `'auto'` value.
13581418
13591419
init : {'random', 'nndsvd', 'nndsvda', 'nndsvdar', 'custom'}, default=None
13601420
Method used to initialize the procedure.
@@ -1517,7 +1577,7 @@ class NMF(_BaseNMF):
15171577

15181578
def __init__(
15191579
self,
1520-
n_components=None,
1580+
n_components="warn",
15211581
*,
15221582
init=None,
15231583
solver="cd",
@@ -1786,9 +1846,14 @@ class MiniBatchNMF(_BaseNMF):
17861846
17871847
Parameters
17881848
----------
1789-
n_components : int, default=None
1849+
n_components : int or {'auto'} or None, default=None
17901850
Number of components, if `n_components` is not set all features
17911851
are kept.
1852+
If `n_components='auto'`, the number of components is automatically inferred
1853+
from W or H shapes.
1854+
1855+
.. versionchanged:: 1.4
1856+
Added `'auto'` value.
17921857
17931858
init : {'random', 'nndsvd', 'nndsvda', 'nndsvdar', 'custom'}, default=None
17941859
Method used to initialize the procedure.
@@ -1953,7 +2018,7 @@ class MiniBatchNMF(_BaseNMF):
19532018

19542019
def __init__(
19552020
self,
1956-
n_components=None,
2021+
n_components="warn",
19572022
*,
19582023
init=None,
19592024
batch_size=1024,

0 commit comments

Comments
 (0)