Skip to content

Commit e7ae63f

Browse files
glemaitreogrisel
andauthored
FIX avoid unnecessary if copy branch for sparse array/matrix (scikit-learn#27336)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent b0da1b7 commit e7ae63f

File tree

3 files changed

+39
-20
lines changed

3 files changed

+39
-20
lines changed

doc/whats_new/v1.4.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,12 @@ Changelog
300300
which can be used to check whether a given set of parameters would be consumed.
301301
:pr:`26831` by `Adrin Jalali`_.
302302

303+
- |Fix| :func:`sklearn.utils.check_array` now accepts both sparse matrices and sparse
304+
arrays from SciPy. The previous implementation would fail when `copy=True` because it
305+
called the NumPy-specific `np.may_share_memory`, which does not work with SciPy sparse
306+
arrays and does not return the correct result for SciPy sparse matrices.
307+
:pr:`27336` by :user:`Guillaume Lemaitre <glemaitre>`.
308+
303309
Code and Documentation Contributors
304310
-----------------------------------
305311

sklearn/utils/tests/test_validation.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,13 @@
4949
skip_if_array_api_compat_not_configured,
5050
)
5151
from sklearn.utils.estimator_checks import _NotAnArray
52-
from sklearn.utils.fixes import parse_version
52+
from sklearn.utils.fixes import (
53+
COO_CONTAINERS,
54+
CSC_CONTAINERS,
55+
CSR_CONTAINERS,
56+
DOK_CONTAINERS,
57+
parse_version,
58+
)
5359
from sklearn.utils.validation import (
5460
FLOAT_DTYPES,
5561
_allclose_dense_sparse,
@@ -356,13 +362,20 @@ def test_check_array():
356362
assert X is X_checked
357363

358364
# allowed sparse != None
359-
X_csc = sp.csc_matrix(X_C)
360-
X_coo = X_csc.tocoo()
361-
X_dok = X_csc.todok()
362-
X_int = X_csc.astype(int)
363-
X_float = X_csc.astype(float)
364365

365-
Xs = [X_csc, X_coo, X_dok, X_int, X_float]
366+
# try different types of sparse formats
367+
Xs = []
368+
Xs.extend(
369+
[
370+
sparse_container(X_C)
371+
for sparse_container in CSR_CONTAINERS
372+
+ CSC_CONTAINERS
373+
+ COO_CONTAINERS
374+
+ DOK_CONTAINERS
375+
]
376+
)
377+
Xs.extend([Xs[0].astype(np.int64), Xs[0].astype(np.float64)])
378+
366379
accept_sparses = [["csr", "coo"], ["coo", "dok"]]
367380
# scipy sparse matrices do not support the object dtype so
368381
# this dtype is skipped in this loop

sklearn/utils/validation.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -962,6 +962,19 @@ def is_sparse(dtype):
962962
allow_nan=force_all_finite == "allow-nan",
963963
)
964964

965+
if copy:
966+
if _is_numpy_namespace(xp):
967+
# only make a copy if `array` and `array_orig` may share memory
968+
if np.may_share_memory(array, array_orig):
969+
array = _asarray_with_order(
970+
array, dtype=dtype, order=order, copy=True, xp=xp
971+
)
972+
else:
973+
# always make a copy for non-numpy arrays
974+
array = _asarray_with_order(
975+
array, dtype=dtype, order=order, copy=True, xp=xp
976+
)
977+
965978
if ensure_min_samples > 0:
966979
n_samples = _num_samples(array)
967980
if n_samples < ensure_min_samples:
@@ -980,19 +993,6 @@ def is_sparse(dtype):
980993
% (n_features, array.shape, ensure_min_features, context)
981994
)
982995

983-
if copy:
984-
if _is_numpy_namespace(xp):
985-
# only make a copy if `array` and `array_orig` may share memory`
986-
if np.may_share_memory(array, array_orig):
987-
array = _asarray_with_order(
988-
array, dtype=dtype, order=order, copy=True, xp=xp
989-
)
990-
else:
991-
# always make a copy for non-numpy arrays
992-
array = _asarray_with_order(
993-
array, dtype=dtype, order=order, copy=True, xp=xp
994-
)
995-
996996
return array
997997

998998

0 commit comments

Comments
 (0)