Skip to content

Commit d38a7e3

Browse files
authored
TST Extend tests for scipy.sparse/*array in sklearn/neighbors/tests/test_neighbors (scikit-learn#27250)
1 parent c1518a9 commit d38a7e3

File tree

4 files changed

+95
-49
lines changed

4 files changed

+95
-49
lines changed

doc/whats_new/v1.4.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,20 +77,32 @@ and classes are impacted:
7777

7878
**Functions:**
7979

80+
- :func:`cluster.compute_optics_graph` in :pr:`27250` by
81+
:user:`Yao Xiao <Charlie-XIAO>`;
8082
- :func:`decomposition.non_negative_factorization` in :pr:`27100` by
8183
:user:`Isaac Virshup <ivirshup>`;
84+
- :func:`manifold.trustworthiness` in :pr:`27250` by :user:`Yao Xiao <Charlie-XIAO>`;
8285
- :func:`feature_selection.f_regression` in :pr:`27239` by :user:`Yaroslav Korobko <Tialo>`;
86+
- :func:`metrics.pairwise_distances` in :pr:`27250` by :user:`Yao Xiao <Charlie-XIAO>`;
87+
- :func:`metrics.pairwise_distances_chunked` in :pr:`27250` by
88+
:user:`Yao Xiao <Charlie-XIAO>`;
89+
- :func:`metrics.pairwise.pairwise_kernels` in :pr:`27250` by
90+
:user:`Yao Xiao <Charlie-XIAO>`;
8391
- :func:`feature_selection.r_regression` in :pr:`27239` by :user:`Yaroslav Korobko <Tialo>`;
8492
- :func:`sklearn.utils.multiclass.type_of_target` in :pr:`27274` by
8593
:user:`Yao Xiao <Charlie-XIAO>`.
8694

8795
**Classes:**
8896

97+
- :class:`cluster.HDBSCAN` in :pr:`27250` by :user:`Yao Xiao <Charlie-XIAO>`;
98+
- :class:`cluster.OPTICS` in :pr:`27250` by :user:`Yao Xiao <Charlie-XIAO>`;
8999
- :class:`decomposition.NMF` in :pr:`27100` by :user:`Isaac Virshup <ivirshup>`;
90100
- :class:`decomposition.MiniBatchNMF` in :pr:`27100` by
91101
:user:`Isaac Virshup <ivirshup>`;
92102
- :class:`feature_extraction.text.TfidfTransformer` in :pr:`27219` by
93103
:user:`Yao Xiao <Charlie-XIAO>`;
104+
- :class:`manifold.Isomap` in :pr:`27250` by :user:`Yao Xiao <Charlie-XIAO>`;
105+
- :class:`manifold.TSNE` in :pr:`27250` by :user:`Yao Xiao <Charlie-XIAO>`;
94106
- :class:`impute.SimpleImputer` in :pr:`27277` by :user:`Yao Xiao <Charlie-XIAO>`;
95107
- :class:`impute.IterativeImputer` in :pr:`27277` by :user:`Yao Xiao <Charlie-XIAO>`;
96108
- :class:`impute.KNNImputer` in :pr:`27277` by :user:`Yao Xiao <Charlie-XIAO>`;

sklearn/metrics/pairwise.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1826,7 +1826,11 @@ def _pairwise_callable(X, Y, metric, force_all_finite=True, **kwds):
18261826
out = np.zeros((X.shape[0], Y.shape[0]), dtype="float")
18271827
iterator = itertools.combinations(range(X.shape[0]), 2)
18281828
for i, j in iterator:
1829-
out[i, j] = metric(X[i], Y[j], **kwds)
1829+
# scipy has not yet implemented 1D sparse slices; once implemented this can
1830+
# be removed and `arr[ind]` can be simply used.
1831+
x = X[[i], :] if issparse(X) else X[i]
1832+
y = Y[[j], :] if issparse(Y) else Y[j]
1833+
out[i, j] = metric(x, y, **kwds)
18301834

18311835
# Make symmetric
18321836
# NB: out += out.T will produce incorrect results
@@ -1835,15 +1839,21 @@ def _pairwise_callable(X, Y, metric, force_all_finite=True, **kwds):
18351839
# Calculate diagonal
18361840
# NB: nonzero diagonals are allowed for both metrics and kernels
18371841
for i in range(X.shape[0]):
1838-
x = X[i]
1842+
# scipy has not yet implemented 1D sparse slices; once implemented this can
1843+
# be removed and `arr[ind]` can be simply used.
1844+
x = X[[i], :] if issparse(X) else X[i]
18391845
out[i, i] = metric(x, x, **kwds)
18401846

18411847
else:
18421848
# Calculate all cells
18431849
out = np.empty((X.shape[0], Y.shape[0]), dtype="float")
18441850
iterator = itertools.product(range(X.shape[0]), range(Y.shape[0]))
18451851
for i, j in iterator:
1846-
out[i, j] = metric(X[i], Y[j], **kwds)
1852+
# scipy has not yet implemented 1D sparse slices; once implemented this can
1853+
# be removed and `arr[ind]` can be simply used.
1854+
x = X[[i], :] if issparse(X) else X[i]
1855+
y = Y[[j], :] if issparse(Y) else Y[j]
1856+
out[i, j] = metric(x, y, **kwds)
18471857

18481858
return out
18491859

sklearn/neighbors/tests/test_neighbors.py

Lines changed: 68 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,7 @@
55
import joblib
66
import numpy as np
77
import pytest
8-
from scipy.sparse import (
9-
bsr_matrix,
10-
coo_matrix,
11-
csc_matrix,
12-
csr_matrix,
13-
dia_matrix,
14-
dok_matrix,
15-
issparse,
16-
lil_matrix,
17-
)
8+
from scipy.sparse import issparse
189

1910
from sklearn import (
2011
config_context,
@@ -49,7 +40,17 @@
4940
assert_array_equal,
5041
ignore_warnings,
5142
)
52-
from sklearn.utils.fixes import parse_version, sp_version
43+
from sklearn.utils.fixes import (
44+
BSR_CONTAINERS,
45+
COO_CONTAINERS,
46+
CSC_CONTAINERS,
47+
CSR_CONTAINERS,
48+
DIA_CONTAINERS,
49+
DOK_CONTAINERS,
50+
LIL_CONTAINERS,
51+
parse_version,
52+
sp_version,
53+
)
5354
from sklearn.utils.validation import check_random_state
5455

5556
rng = np.random.RandomState(0)
@@ -65,7 +66,14 @@
6566
digits.data = digits.data[perm]
6667
digits.target = digits.target[perm]
6768

68-
SPARSE_TYPES = (bsr_matrix, coo_matrix, csc_matrix, csr_matrix, dok_matrix, lil_matrix)
69+
SPARSE_TYPES = tuple(
70+
BSR_CONTAINERS
71+
+ COO_CONTAINERS
72+
+ CSC_CONTAINERS
73+
+ CSR_CONTAINERS
74+
+ DOK_CONTAINERS
75+
+ LIL_CONTAINERS
76+
)
6977
SPARSE_OR_DENSE = SPARSE_TYPES + (np.asarray,)
7078

7179
ALGORITHMS = ("ball_tree", "brute", "kd_tree", "auto")
@@ -460,35 +468,37 @@ def make_train_test(X_train, X_test):
460468
check_precomputed(make_train_test, estimators)
461469

462470

463-
def test_is_sorted_by_data():
471+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
472+
def test_is_sorted_by_data(csr_container):
464473
# Test that _is_sorted_by_data works as expected. In CSR sparse matrix,
465474
# entries in each row can be sorted by indices, by data, or unsorted.
466475
# _is_sorted_by_data should return True when entries are sorted by data,
467476
# and False in all other cases.
468477

469478
# Test with sorted 1D array
470-
X = csr_matrix(np.arange(10))
479+
X = csr_container(np.arange(10))
471480
assert _is_sorted_by_data(X)
472481
# Test with unsorted 1D array
473482
X[0, 2] = 5
474483
assert not _is_sorted_by_data(X)
475484

476485
# Test when the data is sorted in each sample, but not necessarily
477486
# between samples
478-
X = csr_matrix([[0, 1, 2], [3, 0, 0], [3, 4, 0], [1, 0, 2]])
487+
X = csr_container([[0, 1, 2], [3, 0, 0], [3, 4, 0], [1, 0, 2]])
479488
assert _is_sorted_by_data(X)
480489

481490
# Test with duplicate entries in X.indptr
482491
data, indices, indptr = [0, 4, 2, 2], [0, 1, 1, 1], [0, 2, 2, 4]
483-
X = csr_matrix((data, indices, indptr), shape=(3, 3))
492+
X = csr_container((data, indices, indptr), shape=(3, 3))
484493
assert _is_sorted_by_data(X)
485494

486495

487496
@pytest.mark.filterwarnings("ignore:EfficiencyWarning")
488497
@pytest.mark.parametrize("function", [sort_graph_by_row_values, _check_precomputed])
489-
def test_sort_graph_by_row_values(function):
498+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
499+
def test_sort_graph_by_row_values(function, csr_container):
490500
# Test that sort_graph_by_row_values returns a graph sorted by row values
491-
X = csr_matrix(np.abs(np.random.RandomState(42).randn(10, 10)))
501+
X = csr_container(np.abs(np.random.RandomState(42).randn(10, 10)))
492502
assert not _is_sorted_by_data(X)
493503
Xt = function(X)
494504
assert _is_sorted_by_data(Xt)
@@ -497,16 +507,17 @@ def test_sort_graph_by_row_values(function):
497507
mask = np.random.RandomState(42).randint(2, size=(10, 10))
498508
X = X.toarray()
499509
X[mask == 1] = 0
500-
X = csr_matrix(X)
510+
X = csr_container(X)
501511
assert not _is_sorted_by_data(X)
502512
Xt = function(X)
503513
assert _is_sorted_by_data(Xt)
504514

505515

506516
@pytest.mark.filterwarnings("ignore:EfficiencyWarning")
507-
def test_sort_graph_by_row_values_copy():
517+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
518+
def test_sort_graph_by_row_values_copy(csr_container):
508519
# Test if the sorting is done inplace if X is CSR, so that Xt is X.
509-
X_ = csr_matrix(np.abs(np.random.RandomState(42).randn(10, 10)))
520+
X_ = csr_container(np.abs(np.random.RandomState(42).randn(10, 10)))
510521
assert not _is_sorted_by_data(X_)
511522

512523
# sort_graph_by_row_values is done inplace if copy=False
@@ -531,9 +542,10 @@ def test_sort_graph_by_row_values_copy():
531542
sort_graph_by_row_values(X.tocsc(), copy=False)
532543

533544

534-
def test_sort_graph_by_row_values_warning():
545+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
546+
def test_sort_graph_by_row_values_warning(csr_container):
535547
# Test that the parameter warn_when_not_sorted works as expected.
536-
X = csr_matrix(np.abs(np.random.RandomState(42).randn(10, 10)))
548+
X = csr_container(np.abs(np.random.RandomState(42).randn(10, 10)))
537549
assert not _is_sorted_by_data(X)
538550

539551
# warning
@@ -550,36 +562,39 @@ def test_sort_graph_by_row_values_warning():
550562
sort_graph_by_row_values(X, copy=True, warn_when_not_sorted=False)
551563

552564

553-
@pytest.mark.parametrize("format", [dok_matrix, bsr_matrix, dia_matrix])
554-
def test_sort_graph_by_row_values_bad_sparse_format(format):
565+
@pytest.mark.parametrize(
566+
"sparse_container", DOK_CONTAINERS + BSR_CONTAINERS + DIA_CONTAINERS
567+
)
568+
def test_sort_graph_by_row_values_bad_sparse_format(sparse_container):
555569
# Test that sort_graph_by_row_values and _check_precomputed error on bad formats
556-
X = format(np.abs(np.random.RandomState(42).randn(10, 10)))
570+
X = sparse_container(np.abs(np.random.RandomState(42).randn(10, 10)))
557571
with pytest.raises(TypeError, match="format is not supported"):
558572
sort_graph_by_row_values(X)
559573
with pytest.raises(TypeError, match="format is not supported"):
560574
_check_precomputed(X)
561575

562576

563577
@pytest.mark.filterwarnings("ignore:EfficiencyWarning")
564-
def test_precomputed_sparse_invalid():
578+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
579+
def test_precomputed_sparse_invalid(csr_container):
565580
dist = np.array([[0.0, 2.0, 1.0], [2.0, 0.0, 3.0], [1.0, 3.0, 0.0]])
566-
dist_csr = csr_matrix(dist)
581+
dist_csr = csr_container(dist)
567582
neigh = neighbors.NearestNeighbors(n_neighbors=1, metric="precomputed")
568583
neigh.fit(dist_csr)
569584
neigh.kneighbors(None, n_neighbors=1)
570585
neigh.kneighbors(np.array([[0.0, 0.0, 0.0]]), n_neighbors=2)
571586

572587
# Ensures enough number of nearest neighbors
573588
dist = np.array([[0.0, 2.0, 0.0], [2.0, 0.0, 3.0], [0.0, 3.0, 0.0]])
574-
dist_csr = csr_matrix(dist)
589+
dist_csr = csr_container(dist)
575590
neigh.fit(dist_csr)
576591
msg = "2 neighbors per samples are required, but some samples have only 1"
577592
with pytest.raises(ValueError, match=msg):
578593
neigh.kneighbors(None, n_neighbors=1)
579594

580595
# Checks error with inconsistent distance matrix
581596
dist = np.array([[5.0, 2.0, 1.0], [-2.0, 0.0, 3.0], [1.0, 3.0, 0.0]])
582-
dist_csr = csr_matrix(dist)
597+
dist_csr = csr_container(dist)
583598
msg = "Negative values in data passed to precomputed distance matrix."
584599
with pytest.raises(ValueError, match=msg):
585600
neigh.kneighbors(dist_csr, n_neighbors=1)
@@ -995,12 +1010,13 @@ def test_radius_neighbors_boundary_handling():
9951010
assert_array_equal(results[0], [0, 1])
9961011

9971012

998-
def test_radius_neighbors_returns_array_of_objects():
1013+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
1014+
def test_radius_neighbors_returns_array_of_objects(csr_container):
9991015
# check that we can pass precomputed distances to
10001016
# NearestNeighbors.radius_neighbors()
10011017
# non-regression test for
10021018
# https://github.com/scikit-learn/scikit-learn/issues/16036
1003-
X = csr_matrix(np.ones((4, 4)))
1019+
X = csr_container(np.ones((4, 4)))
10041020
X.setdiag([0, 0, 0, 0])
10051021

10061022
nbrs = neighbors.NearestNeighbors(
@@ -1371,7 +1387,7 @@ def test_kneighbors_regressor_sparse(
13711387
assert np.mean(knn.predict(X2).round() == y) > 0.95
13721388

13731389
X2_pre = sparsev(pairwise_distances(X, metric="euclidean"))
1374-
if sparsev in {dok_matrix, bsr_matrix}:
1390+
if sparsev in DOK_CONTAINERS + BSR_CONTAINERS:
13751391
msg = "not supported due to its handling of explicit zeros"
13761392
with pytest.raises(TypeError, match=msg):
13771393
knn_pre.predict(X2_pre)
@@ -1453,12 +1469,13 @@ def test_kneighbors_graph():
14531469

14541470
@pytest.mark.parametrize("n_neighbors", [1, 2, 3])
14551471
@pytest.mark.parametrize("mode", ["connectivity", "distance"])
1456-
def test_kneighbors_graph_sparse(n_neighbors, mode, seed=36):
1472+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
1473+
def test_kneighbors_graph_sparse(n_neighbors, mode, csr_container, seed=36):
14571474
# Test kneighbors_graph to build the k-Nearest Neighbor graph
14581475
# for sparse input.
14591476
rng = np.random.RandomState(seed)
14601477
X = rng.randn(10, 10)
1461-
Xcsr = csr_matrix(X)
1478+
Xcsr = csr_container(X)
14621479

14631480
assert_allclose(
14641481
neighbors.kneighbors_graph(X, n_neighbors, mode=mode).toarray(),
@@ -1481,12 +1498,13 @@ def test_radius_neighbors_graph():
14811498

14821499
@pytest.mark.parametrize("n_neighbors", [1, 2, 3])
14831500
@pytest.mark.parametrize("mode", ["connectivity", "distance"])
1484-
def test_radius_neighbors_graph_sparse(n_neighbors, mode, seed=36):
1501+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
1502+
def test_radius_neighbors_graph_sparse(n_neighbors, mode, csr_container, seed=36):
14851503
# Test radius_neighbors_graph to build the Nearest Neighbor graph
14861504
# for sparse input.
14871505
rng = np.random.RandomState(seed)
14881506
X = rng.randn(10, 10)
1489-
Xcsr = csr_matrix(X)
1507+
Xcsr = csr_container(X)
14901508

14911509
assert_allclose(
14921510
neighbors.radius_neighbors_graph(X, n_neighbors, mode=mode).toarray(),
@@ -1503,11 +1521,12 @@ def test_radius_neighbors_graph_sparse(n_neighbors, mode, seed=36):
15031521
neighbors.RadiusNeighborsRegressor,
15041522
],
15051523
)
1506-
def test_neighbors_validate_parameters(Estimator):
1524+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
1525+
def test_neighbors_validate_parameters(Estimator, csr_container):
15071526
"""Additional parameter validation for *Neighbors* estimators not covered by common
15081527
validation."""
15091528
X = rng.random_sample((10, 2))
1510-
Xsparse = csr_matrix(X)
1529+
Xsparse = csr_container(X)
15111530
X3 = rng.random_sample((10, 3))
15121531
y = np.ones(10)
15131532

@@ -1759,13 +1778,14 @@ def custom_metric(x1, x2):
17591778
@pytest.mark.parametrize(
17601779
"metric", neighbors.VALID_METRICS["brute"] + DISTANCE_METRIC_OBJS
17611780
)
1781+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
17621782
def test_valid_brute_metric_for_auto_algorithm(
1763-
global_dtype, metric, n_samples=20, n_features=12
1783+
global_dtype, metric, csr_container, n_samples=20, n_features=12
17641784
):
17651785
metric = _parse_metric(metric, global_dtype)
17661786

17671787
X = rng.rand(n_samples, n_features).astype(global_dtype, copy=False)
1768-
Xcsr = csr_matrix(X)
1788+
Xcsr = csr_container(X)
17691789

17701790
metric_params_list = _generate_test_params_for(metric, n_features)
17711791

@@ -1811,7 +1831,8 @@ def test_metric_params_interface():
18111831
est.fit(X, y)
18121832

18131833

1814-
def test_predict_sparse_ball_kd_tree():
1834+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
1835+
def test_predict_sparse_ball_kd_tree(csr_container):
18151836
rng = np.random.RandomState(0)
18161837
X = rng.rand(5, 5)
18171838
y = rng.randint(0, 2, 5)
@@ -1820,7 +1841,7 @@ def test_predict_sparse_ball_kd_tree():
18201841
for model in [nbrs1, nbrs2]:
18211842
model.fit(X, y)
18221843
with pytest.raises(ValueError):
1823-
model.predict(csr_matrix(X))
1844+
model.predict(csr_container(X))
18241845

18251846

18261847
def test_non_euclidean_kneighbors():
@@ -2073,16 +2094,17 @@ def test_dtype_convert():
20732094
assert_array_equal(result, y)
20742095

20752096

2076-
def test_sparse_metric_callable():
2097+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
2098+
def test_sparse_metric_callable(csr_container):
20772099
def sparse_metric(x, y): # Metric accepting sparse matrix input (only)
20782100
assert issparse(x) and issparse(y)
20792101
return x.dot(y.T).toarray().item()
20802102

2081-
X = csr_matrix(
2103+
X = csr_container(
20822104
[[1, 1, 1, 1, 1], [1, 0, 1, 0, 1], [0, 0, 1, 0, 0]] # Population matrix
20832105
)
20842106

2085-
Y = csr_matrix([[1, 1, 0, 1, 1], [1, 0, 0, 0, 1]]) # Query matrix
2107+
Y = csr_container([[1, 1, 0, 1, 1], [1, 0, 0, 0, 1]]) # Query matrix
20862108

20872109
nn = neighbors.NearestNeighbors(
20882110
algorithm="brute", n_neighbors=2, metric=sparse_metric

sklearn/utils/fixes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
LIL_CONTAINERS = [scipy.sparse.lil_matrix]
3838
DOK_CONTAINERS = [scipy.sparse.dok_matrix]
3939
BSR_CONTAINERS = [scipy.sparse.bsr_matrix]
40+
DIA_CONTAINERS = [scipy.sparse.dia_matrix]
4041

4142
if parse_version(scipy.__version__) >= parse_version("1.8"):
4243
# Sparse Arrays have been added in SciPy 1.8
@@ -49,6 +50,7 @@
4950
LIL_CONTAINERS.append(scipy.sparse.lil_array)
5051
DOK_CONTAINERS.append(scipy.sparse.dok_array)
5152
BSR_CONTAINERS.append(scipy.sparse.bsr_array)
53+
DIA_CONTAINERS.append(scipy.sparse.dia_array)
5254

5355
try:
5456
from scipy.optimize._linesearch import line_search_wolfe1, line_search_wolfe2

0 commit comments

Comments
 (0)