Skip to content

Commit 9f3ca07

Browse files
clane9, adrinjalali, jeremiedbb
authored
FIX Add input array check to randomized_svd and randomized_range_finder (scikit-learn#30819)
Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com> Co-authored-by: Jérémie du Boisberranger <jeremie@probabl.ai>
1 parent 5059058 commit 9f3ca07

File tree

10 files changed

+128
-20
lines changed

10 files changed

+128
-20
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
- :func:`sklearn.utils.extmath.randomized_svd` now supports Array API compatible inputs.
2+
By :user:`Connor Lane <clane9>` and :user:`Jérémie du Boisberranger <jeremiedbb>`.
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- :func:`utils.extmath.randomized_svd` and :func:`utils.extmath.randomized_range_finder`
2+
now validate their input array to fail early with an informative error message on
3+
invalid input.
4+
By :user:`Connor Lane <clane9>`.

sklearn/cluster/_bicluster.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from ..base import BaseEstimator, BiclusterMixin, _fit_context
1515
from ..utils import check_random_state, check_scalar
1616
from ..utils._param_validation import Interval, StrOptions
17-
from ..utils.extmath import make_nonnegative, randomized_svd, safe_sparse_dot
17+
from ..utils.extmath import _randomized_svd, make_nonnegative, safe_sparse_dot
1818
from ..utils.validation import assert_all_finite, validate_data
1919
from ._kmeans import KMeans, MiniBatchKMeans
2020

@@ -144,7 +144,7 @@ def _svd(self, array, n_components, n_discard):
144144
kwargs = {}
145145
if self.n_svd_vecs is not None:
146146
kwargs["n_oversamples"] = self.n_svd_vecs
147-
u, _, vt = randomized_svd(
147+
u, _, vt = _randomized_svd(
148148
array, n_components, random_state=self.random_state, **kwargs
149149
)
150150

sklearn/decomposition/_dict_learning.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from ..linear_model import Lars, Lasso, LassoLars, orthogonal_mp_gram
2222
from ..utils import check_array, check_random_state, gen_batches, gen_even_slices
2323
from ..utils._param_validation import Interval, StrOptions, validate_params
24-
from ..utils.extmath import randomized_svd, row_norms, svd_flip
24+
from ..utils.extmath import _randomized_svd, row_norms, svd_flip
2525
from ..utils.parallel import Parallel, delayed
2626
from ..utils.validation import check_is_fitted, validate_data
2727

@@ -2049,7 +2049,7 @@ def _initialize_dict(self, X, random_state):
20492049
dictionary = self.dict_init
20502050
else:
20512051
# Init V with SVD of X
2052-
_, S, dictionary = randomized_svd(
2052+
_, S, dictionary = _randomized_svd(
20532053
X, self._n_components, random_state=random_state
20542054
)
20552055
dictionary = S[:, np.newaxis] * dictionary

sklearn/decomposition/_factor_analysis.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from ..exceptions import ConvergenceWarning
3333
from ..utils import check_random_state
3434
from ..utils._param_validation import Interval, StrOptions
35-
from ..utils.extmath import fast_logdet, randomized_svd, squared_norm
35+
from ..utils.extmath import _randomized_svd, fast_logdet, squared_norm
3636
from ..utils.validation import check_is_fitted, validate_data
3737

3838

@@ -264,7 +264,7 @@ def my_svd(X):
264264
random_state = check_random_state(self.random_state)
265265

266266
def my_svd(X):
267-
_, s, Vt = randomized_svd(
267+
_, s, Vt = _randomized_svd(
268268
X,
269269
n_components,
270270
random_state=random_state,

sklearn/decomposition/_nmf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
StrOptions,
2929
validate_params,
3030
)
31-
from ..utils.extmath import randomized_svd, safe_sparse_dot, squared_norm
31+
from ..utils.extmath import _randomized_svd, safe_sparse_dot, squared_norm
3232
from ..utils.validation import (
3333
check_is_fitted,
3434
check_non_negative,
@@ -314,7 +314,7 @@ def _initialize_nmf(X, n_components, init=None, eps=1e-6, random_state=None):
314314
return W, H
315315

316316
# NNDSVD initialization
317-
U, S, V = randomized_svd(X, n_components, random_state=random_state)
317+
U, S, V = _randomized_svd(X, n_components, random_state=random_state)
318318
W = np.zeros_like(U)
319319
H = np.zeros_like(V)
320320

sklearn/decomposition/_pca.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from ..utils._arpack import _init_arpack_v0
1717
from ..utils._array_api import _convert_to_numpy, get_namespace
1818
from ..utils._param_validation import Interval, RealNotInt, StrOptions
19-
from ..utils.extmath import fast_logdet, randomized_svd, stable_cumsum, svd_flip
19+
from ..utils.extmath import _randomized_svd, fast_logdet, stable_cumsum, svd_flip
2020
from ..utils.sparsefuncs import _implicit_column_offset, mean_variance_axis
2121
from ..utils.validation import check_is_fitted, validate_data
2222
from ._base import _BasePCA
@@ -754,7 +754,7 @@ def _fit_truncated(self, X, n_components, xp):
754754

755755
elif svd_solver == "randomized":
756756
# sign flipping is done inside
757-
U, S, Vt = randomized_svd(
757+
U, S, Vt = _randomized_svd(
758758
X_centered,
759759
n_components=n_components,
760760
n_oversamples=self.n_oversamples,

sklearn/decomposition/_truncated_svd.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from ..utils import check_array, check_random_state
1919
from ..utils._arpack import _init_arpack_v0
2020
from ..utils._param_validation import Interval, StrOptions
21-
from ..utils.extmath import randomized_svd, safe_sparse_dot, svd_flip
21+
from ..utils.extmath import _randomized_svd, safe_sparse_dot, svd_flip
2222
from ..utils.sparsefuncs import mean_variance_axis
2323
from ..utils.validation import check_is_fitted, validate_data
2424

@@ -241,7 +241,7 @@ def fit_transform(self, X, y=None):
241241
f"n_components({self.n_components}) must be <="
242242
f" n_features({X.shape[1]})."
243243
)
244-
U, Sigma, VT = randomized_svd(
244+
U, Sigma, VT = _randomized_svd(
245245
X,
246246
self.n_components,
247247
n_iter=self.n_iter,

sklearn/utils/extmath.py

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def randomized_range_finder(
219219
220220
Parameters
221221
----------
222-
A : 2D array
222+
A : {array-like, sparse matrix} of shape (n_samples, n_features)
223223
The input data matrix.
224224
225225
size : int
@@ -246,9 +246,9 @@ def randomized_range_finder(
246246
247247
Returns
248248
-------
249-
Q : ndarray
250-
A (size x size) projection matrix, the range of which
251-
approximates well the range of the input matrix A.
249+
Q : ndarray of shape (n_samples, size)
250+
A projection matrix, the range of which approximates well the range of the
251+
input matrix A.
252252
253253
Notes
254254
-----
@@ -273,6 +273,21 @@ def randomized_range_finder(
273273
[-0.52..., 0.24...],
274274
[-0.82..., -0.38...]])
275275
"""
276+
A = check_array(A, accept_sparse=True)
277+
278+
return _randomized_range_finder(
279+
A,
280+
size=size,
281+
n_iter=n_iter,
282+
power_iteration_normalizer=power_iteration_normalizer,
283+
random_state=random_state,
284+
)
285+
286+
287+
def _randomized_range_finder(
288+
A, *, size, n_iter, power_iteration_normalizer="auto", random_state=None
289+
):
290+
"""Body of randomized_range_finder without input validation."""
276291
xp, is_array_api_compliant = get_namespace(A)
277292
random_state = check_random_state(random_state)
278293

@@ -344,7 +359,7 @@ def randomized_range_finder(
344359

345360
@validate_params(
346361
{
347-
"M": [np.ndarray, "sparse matrix"],
362+
"M": ["array-like", "sparse matrix"],
348363
"n_components": [Interval(Integral, 1, None, closed="left")],
349364
"n_oversamples": [Interval(Integral, 0, None, closed="left")],
350365
"n_iter": [Interval(Integral, 0, None, closed="left"), StrOptions({"auto"})],
@@ -381,7 +396,7 @@ def randomized_svd(
381396
382397
Parameters
383398
----------
384-
M : {ndarray, sparse matrix}
399+
M : {array-like, sparse matrix} of shape (n_samples, n_features)
385400
Matrix to decompose.
386401
387402
n_components : int
@@ -499,6 +514,35 @@ def randomized_svd(
499514
>>> U.shape, s.shape, Vh.shape
500515
((3, 2), (2,), (2, 4))
501516
"""
517+
M = check_array(M, accept_sparse=True)
518+
return _randomized_svd(
519+
M,
520+
n_components=n_components,
521+
n_oversamples=n_oversamples,
522+
n_iter=n_iter,
523+
power_iteration_normalizer=power_iteration_normalizer,
524+
transpose=transpose,
525+
flip_sign=flip_sign,
526+
random_state=random_state,
527+
svd_lapack_driver=svd_lapack_driver,
528+
)
529+
530+
531+
def _randomized_svd(
532+
M,
533+
n_components,
534+
*,
535+
n_oversamples=10,
536+
n_iter="auto",
537+
power_iteration_normalizer="auto",
538+
transpose="auto",
539+
flip_sign=True,
540+
random_state=None,
541+
svd_lapack_driver="gesdd",
542+
):
543+
"""Body of randomized_svd without input validation."""
544+
xp, is_array_api_compliant = get_namespace(M)
545+
502546
if sparse.issparse(M) and M.format in ("lil", "dok"):
503547
warnings.warn(
504548
"Calculating SVD of a {} is expensive. "
@@ -521,7 +565,7 @@ def randomized_svd(
521565
# this implementation is a bit faster with smaller shape[1]
522566
M = M.T
523567

524-
Q = randomized_range_finder(
568+
Q = _randomized_range_finder(
525569
M,
526570
size=n_random,
527571
n_iter=n_iter,
@@ -533,7 +577,6 @@ def randomized_svd(
533577
B = Q.T @ M
534578

535579
# compute the SVD on the thin matrix: (k + p) wide
536-
xp, is_array_api_compliant = get_namespace(B)
537580
if is_array_api_compliant:
538581
Uhat, s, Vt = xp.linalg.svd(B, full_matrices=False)
539582
else:

sklearn/utils/tests/test_extmath.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,18 @@
99
from scipy.linalg import eigh
1010
from scipy.sparse.linalg import eigsh
1111

12+
from sklearn import config_context
1213
from sklearn.datasets import make_low_rank_matrix, make_sparse_spd_matrix
1314
from sklearn.utils import gen_batches
1415
from sklearn.utils._arpack import _init_arpack_v0
16+
from sklearn.utils._array_api import (
17+
_convert_to_numpy,
18+
_get_namespace_device_dtype_ids,
19+
get_namespace,
20+
yield_namespace_device_dtype_combinations,
21+
)
1522
from sklearn.utils._testing import (
23+
_array_api_for_tests,
1624
assert_allclose,
1725
assert_allclose_dense_sparse,
1826
assert_almost_equal,
@@ -28,6 +36,7 @@
2836
_safe_accumulator_op,
2937
cartesian,
3038
density,
39+
randomized_range_finder,
3140
randomized_svd,
3241
row_norms,
3342
safe_sparse_dot,
@@ -1060,3 +1069,53 @@ def test_approximate_mode():
10601069
# 25% * 99.000 = 24.750
10611070
# 25% * 1.000 = 250
10621071
assert_array_equal(ret, [24750, 250])
1072+
1073+
1074+
@pytest.mark.parametrize(
1075+
"array_namespace, device, dtype",
1076+
yield_namespace_device_dtype_combinations(),
1077+
ids=_get_namespace_device_dtype_ids,
1078+
)
1079+
def test_randomized_svd_array_api_compliance(array_namespace, device, dtype):
1080+
xp = _array_api_for_tests(array_namespace, device)
1081+
1082+
rng = np.random.RandomState(0)
1083+
X = rng.normal(size=(30, 10)).astype(dtype)
1084+
X_xp = xp.asarray(X, device=device)
1085+
n_components = 5
1086+
atol = 1e-5 if dtype == "float32" else 0
1087+
1088+
with config_context(array_api_dispatch=True):
1089+
u_np, s_np, vt_np = randomized_svd(X, n_components, random_state=0)
1090+
u_xp, s_xp, vt_xp = randomized_svd(X_xp, n_components, random_state=0)
1091+
1092+
assert get_namespace(u_xp)[0].__name__ == xp.__name__
1093+
assert get_namespace(s_xp)[0].__name__ == xp.__name__
1094+
assert get_namespace(vt_xp)[0].__name__ == xp.__name__
1095+
1096+
assert_allclose(_convert_to_numpy(u_xp, xp), u_np, atol=atol)
1097+
assert_allclose(_convert_to_numpy(s_xp, xp), s_np, atol=atol)
1098+
assert_allclose(_convert_to_numpy(vt_xp, xp), vt_np, atol=atol)
1099+
1100+
1101+
@pytest.mark.parametrize(
1102+
"array_namespace, device, dtype",
1103+
yield_namespace_device_dtype_combinations(),
1104+
ids=_get_namespace_device_dtype_ids,
1105+
)
1106+
def test_randomized_range_finder_array_api_compliance(array_namespace, device, dtype):
1107+
xp = _array_api_for_tests(array_namespace, device)
1108+
1109+
rng = np.random.RandomState(0)
1110+
X = rng.normal(size=(30, 10)).astype(dtype)
1111+
X_xp = xp.asarray(X, device=device)
1112+
size = 5
1113+
n_iter = 10
1114+
atol = 1e-5 if dtype == "float32" else 0
1115+
1116+
with config_context(array_api_dispatch=True):
1117+
Q_np = randomized_range_finder(X, size=size, n_iter=n_iter, random_state=0)
1118+
Q_xp = randomized_range_finder(X_xp, size=size, n_iter=n_iter, random_state=0)
1119+
1120+
assert get_namespace(Q_xp)[0].__name__ == xp.__name__
1121+
assert_allclose(_convert_to_numpy(Q_xp, xp), Q_np, atol=atol)

0 commit comments

Comments
 (0)