5
5
import joblib
6
6
import numpy as np
7
7
import pytest
8
- from scipy .sparse import (
9
- bsr_matrix ,
10
- coo_matrix ,
11
- csc_matrix ,
12
- csr_matrix ,
13
- dia_matrix ,
14
- dok_matrix ,
15
- issparse ,
16
- lil_matrix ,
17
- )
8
+ from scipy .sparse import issparse
18
9
19
10
from sklearn import (
20
11
config_context ,
49
40
assert_array_equal ,
50
41
ignore_warnings ,
51
42
)
52
- from sklearn .utils .fixes import parse_version , sp_version
43
+ from sklearn .utils .fixes import (
44
+ BSR_CONTAINERS ,
45
+ COO_CONTAINERS ,
46
+ CSC_CONTAINERS ,
47
+ CSR_CONTAINERS ,
48
+ DIA_CONTAINERS ,
49
+ DOK_CONTAINERS ,
50
+ LIL_CONTAINERS ,
51
+ parse_version ,
52
+ sp_version ,
53
+ )
53
54
from sklearn .utils .validation import check_random_state
54
55
55
56
rng = np .random .RandomState (0 )
65
66
digits .data = digits .data [perm ]
66
67
digits .target = digits .target [perm ]
67
68
68
- SPARSE_TYPES = (bsr_matrix , coo_matrix , csc_matrix , csr_matrix , dok_matrix , lil_matrix )
69
+ SPARSE_TYPES = tuple (
70
+ BSR_CONTAINERS
71
+ + COO_CONTAINERS
72
+ + CSC_CONTAINERS
73
+ + CSR_CONTAINERS
74
+ + DOK_CONTAINERS
75
+ + LIL_CONTAINERS
76
+ )
69
77
SPARSE_OR_DENSE = SPARSE_TYPES + (np .asarray ,)
70
78
71
79
ALGORITHMS = ("ball_tree" , "brute" , "kd_tree" , "auto" )
@@ -460,35 +468,37 @@ def make_train_test(X_train, X_test):
460
468
check_precomputed (make_train_test , estimators )
461
469
462
470
463
- def test_is_sorted_by_data ():
471
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
472
+ def test_is_sorted_by_data (csr_container ):
464
473
# Test that _is_sorted_by_data works as expected. In CSR sparse matrix,
465
474
# entries in each row can be sorted by indices, by data, or unsorted.
466
475
# _is_sorted_by_data should return True when entries are sorted by data,
467
476
# and False in all other cases.
468
477
469
478
# Test with sorted 1D array
470
- X = csr_matrix (np .arange (10 ))
479
+ X = csr_container (np .arange (10 ))
471
480
assert _is_sorted_by_data (X )
472
481
# Test with unsorted 1D array
473
482
X [0 , 2 ] = 5
474
483
assert not _is_sorted_by_data (X )
475
484
476
485
# Test when the data is sorted in each sample, but not necessarily
477
486
# between samples
478
- X = csr_matrix ([[0 , 1 , 2 ], [3 , 0 , 0 ], [3 , 4 , 0 ], [1 , 0 , 2 ]])
487
+ X = csr_container ([[0 , 1 , 2 ], [3 , 0 , 0 ], [3 , 4 , 0 ], [1 , 0 , 2 ]])
479
488
assert _is_sorted_by_data (X )
480
489
481
490
# Test with duplicates entries in X.indptr
482
491
data , indices , indptr = [0 , 4 , 2 , 2 ], [0 , 1 , 1 , 1 ], [0 , 2 , 2 , 4 ]
483
- X = csr_matrix ((data , indices , indptr ), shape = (3 , 3 ))
492
+ X = csr_container ((data , indices , indptr ), shape = (3 , 3 ))
484
493
assert _is_sorted_by_data (X )
485
494
486
495
487
496
@pytest .mark .filterwarnings ("ignore:EfficiencyWarning" )
488
497
@pytest .mark .parametrize ("function" , [sort_graph_by_row_values , _check_precomputed ])
489
- def test_sort_graph_by_row_values (function ):
498
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
499
+ def test_sort_graph_by_row_values (function , csr_container ):
490
500
# Test that sort_graph_by_row_values returns a graph sorted by row values
491
- X = csr_matrix (np .abs (np .random .RandomState (42 ).randn (10 , 10 )))
501
+ X = csr_container (np .abs (np .random .RandomState (42 ).randn (10 , 10 )))
492
502
assert not _is_sorted_by_data (X )
493
503
Xt = function (X )
494
504
assert _is_sorted_by_data (Xt )
@@ -497,16 +507,17 @@ def test_sort_graph_by_row_values(function):
497
507
mask = np .random .RandomState (42 ).randint (2 , size = (10 , 10 ))
498
508
X = X .toarray ()
499
509
X [mask == 1 ] = 0
500
- X = csr_matrix (X )
510
+ X = csr_container (X )
501
511
assert not _is_sorted_by_data (X )
502
512
Xt = function (X )
503
513
assert _is_sorted_by_data (Xt )
504
514
505
515
506
516
@pytest .mark .filterwarnings ("ignore:EfficiencyWarning" )
507
- def test_sort_graph_by_row_values_copy ():
517
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
518
+ def test_sort_graph_by_row_values_copy (csr_container ):
508
519
# Test if the sorting is done inplace if X is CSR, so that Xt is X.
509
- X_ = csr_matrix (np .abs (np .random .RandomState (42 ).randn (10 , 10 )))
520
+ X_ = csr_container (np .abs (np .random .RandomState (42 ).randn (10 , 10 )))
510
521
assert not _is_sorted_by_data (X_ )
511
522
512
523
# sort_graph_by_row_values is done inplace if copy=False
@@ -531,9 +542,10 @@ def test_sort_graph_by_row_values_copy():
531
542
sort_graph_by_row_values (X .tocsc (), copy = False )
532
543
533
544
534
- def test_sort_graph_by_row_values_warning ():
545
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
546
+ def test_sort_graph_by_row_values_warning (csr_container ):
535
547
# Test that the parameter warn_when_not_sorted works as expected.
536
- X = csr_matrix (np .abs (np .random .RandomState (42 ).randn (10 , 10 )))
548
+ X = csr_container (np .abs (np .random .RandomState (42 ).randn (10 , 10 )))
537
549
assert not _is_sorted_by_data (X )
538
550
539
551
# warning
@@ -550,36 +562,39 @@ def test_sort_graph_by_row_values_warning():
550
562
sort_graph_by_row_values (X , copy = True , warn_when_not_sorted = False )
551
563
552
564
553
- @pytest .mark .parametrize ("format" , [dok_matrix , bsr_matrix , dia_matrix ])
554
- def test_sort_graph_by_row_values_bad_sparse_format (format ):
565
+ @pytest .mark .parametrize (
566
+ "sparse_container" , DOK_CONTAINERS + BSR_CONTAINERS + DIA_CONTAINERS
567
+ )
568
+ def test_sort_graph_by_row_values_bad_sparse_format (sparse_container ):
555
569
# Test that sort_graph_by_row_values and _check_precomputed error on bad formats
556
- X = format (np .abs (np .random .RandomState (42 ).randn (10 , 10 )))
570
+ X = sparse_container (np .abs (np .random .RandomState (42 ).randn (10 , 10 )))
557
571
with pytest .raises (TypeError , match = "format is not supported" ):
558
572
sort_graph_by_row_values (X )
559
573
with pytest .raises (TypeError , match = "format is not supported" ):
560
574
_check_precomputed (X )
561
575
562
576
563
577
@pytest .mark .filterwarnings ("ignore:EfficiencyWarning" )
564
- def test_precomputed_sparse_invalid ():
578
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
579
+ def test_precomputed_sparse_invalid (csr_container ):
565
580
dist = np .array ([[0.0 , 2.0 , 1.0 ], [2.0 , 0.0 , 3.0 ], [1.0 , 3.0 , 0.0 ]])
566
- dist_csr = csr_matrix (dist )
581
+ dist_csr = csr_container (dist )
567
582
neigh = neighbors .NearestNeighbors (n_neighbors = 1 , metric = "precomputed" )
568
583
neigh .fit (dist_csr )
569
584
neigh .kneighbors (None , n_neighbors = 1 )
570
585
neigh .kneighbors (np .array ([[0.0 , 0.0 , 0.0 ]]), n_neighbors = 2 )
571
586
572
587
# Ensures enough number of nearest neighbors
573
588
dist = np .array ([[0.0 , 2.0 , 0.0 ], [2.0 , 0.0 , 3.0 ], [0.0 , 3.0 , 0.0 ]])
574
- dist_csr = csr_matrix (dist )
589
+ dist_csr = csr_container (dist )
575
590
neigh .fit (dist_csr )
576
591
msg = "2 neighbors per samples are required, but some samples have only 1"
577
592
with pytest .raises (ValueError , match = msg ):
578
593
neigh .kneighbors (None , n_neighbors = 1 )
579
594
580
595
# Checks error with inconsistent distance matrix
581
596
dist = np .array ([[5.0 , 2.0 , 1.0 ], [- 2.0 , 0.0 , 3.0 ], [1.0 , 3.0 , 0.0 ]])
582
- dist_csr = csr_matrix (dist )
597
+ dist_csr = csr_container (dist )
583
598
msg = "Negative values in data passed to precomputed distance matrix."
584
599
with pytest .raises (ValueError , match = msg ):
585
600
neigh .kneighbors (dist_csr , n_neighbors = 1 )
@@ -995,12 +1010,13 @@ def test_radius_neighbors_boundary_handling():
995
1010
assert_array_equal (results [0 ], [0 , 1 ])
996
1011
997
1012
998
- def test_radius_neighbors_returns_array_of_objects ():
1013
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
1014
+ def test_radius_neighbors_returns_array_of_objects (csr_container ):
999
1015
# check that we can pass precomputed distances to
1000
1016
# NearestNeighbors.radius_neighbors()
1001
1017
# non-regression test for
1002
1018
# https://github.com/scikit-learn/scikit-learn/issues/16036
1003
- X = csr_matrix (np .ones ((4 , 4 )))
1019
+ X = csr_container (np .ones ((4 , 4 )))
1004
1020
X .setdiag ([0 , 0 , 0 , 0 ])
1005
1021
1006
1022
nbrs = neighbors .NearestNeighbors (
@@ -1371,7 +1387,7 @@ def test_kneighbors_regressor_sparse(
1371
1387
assert np .mean (knn .predict (X2 ).round () == y ) > 0.95
1372
1388
1373
1389
X2_pre = sparsev (pairwise_distances (X , metric = "euclidean" ))
1374
- if sparsev in { dok_matrix , bsr_matrix } :
1390
+ if sparsev in DOK_CONTAINERS + BSR_CONTAINERS :
1375
1391
msg = "not supported due to its handling of explicit zeros"
1376
1392
with pytest .raises (TypeError , match = msg ):
1377
1393
knn_pre .predict (X2_pre )
@@ -1453,12 +1469,13 @@ def test_kneighbors_graph():
1453
1469
1454
1470
@pytest .mark .parametrize ("n_neighbors" , [1 , 2 , 3 ])
1455
1471
@pytest .mark .parametrize ("mode" , ["connectivity" , "distance" ])
1456
- def test_kneighbors_graph_sparse (n_neighbors , mode , seed = 36 ):
1472
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
1473
+ def test_kneighbors_graph_sparse (n_neighbors , mode , csr_container , seed = 36 ):
1457
1474
# Test kneighbors_graph to build the k-Nearest Neighbor graph
1458
1475
# for sparse input.
1459
1476
rng = np .random .RandomState (seed )
1460
1477
X = rng .randn (10 , 10 )
1461
- Xcsr = csr_matrix (X )
1478
+ Xcsr = csr_container (X )
1462
1479
1463
1480
assert_allclose (
1464
1481
neighbors .kneighbors_graph (X , n_neighbors , mode = mode ).toarray (),
@@ -1481,12 +1498,13 @@ def test_radius_neighbors_graph():
1481
1498
1482
1499
@pytest .mark .parametrize ("n_neighbors" , [1 , 2 , 3 ])
1483
1500
@pytest .mark .parametrize ("mode" , ["connectivity" , "distance" ])
1484
- def test_radius_neighbors_graph_sparse (n_neighbors , mode , seed = 36 ):
1501
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
1502
+ def test_radius_neighbors_graph_sparse (n_neighbors , mode , csr_container , seed = 36 ):
1485
1503
# Test radius_neighbors_graph to build the Nearest Neighbor graph
1486
1504
# for sparse input.
1487
1505
rng = np .random .RandomState (seed )
1488
1506
X = rng .randn (10 , 10 )
1489
- Xcsr = csr_matrix (X )
1507
+ Xcsr = csr_container (X )
1490
1508
1491
1509
assert_allclose (
1492
1510
neighbors .radius_neighbors_graph (X , n_neighbors , mode = mode ).toarray (),
@@ -1503,11 +1521,12 @@ def test_radius_neighbors_graph_sparse(n_neighbors, mode, seed=36):
1503
1521
neighbors .RadiusNeighborsRegressor ,
1504
1522
],
1505
1523
)
1506
- def test_neighbors_validate_parameters (Estimator ):
1524
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
1525
+ def test_neighbors_validate_parameters (Estimator , csr_container ):
1507
1526
"""Additional parameter validation for *Neighbors* estimators not covered by common
1508
1527
validation."""
1509
1528
X = rng .random_sample ((10 , 2 ))
1510
- Xsparse = csr_matrix (X )
1529
+ Xsparse = csr_container (X )
1511
1530
X3 = rng .random_sample ((10 , 3 ))
1512
1531
y = np .ones (10 )
1513
1532
@@ -1759,13 +1778,14 @@ def custom_metric(x1, x2):
1759
1778
@pytest .mark .parametrize (
1760
1779
"metric" , neighbors .VALID_METRICS ["brute" ] + DISTANCE_METRIC_OBJS
1761
1780
)
1781
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
1762
1782
def test_valid_brute_metric_for_auto_algorithm (
1763
- global_dtype , metric , n_samples = 20 , n_features = 12
1783
+ global_dtype , metric , csr_container , n_samples = 20 , n_features = 12
1764
1784
):
1765
1785
metric = _parse_metric (metric , global_dtype )
1766
1786
1767
1787
X = rng .rand (n_samples , n_features ).astype (global_dtype , copy = False )
1768
- Xcsr = csr_matrix (X )
1788
+ Xcsr = csr_container (X )
1769
1789
1770
1790
metric_params_list = _generate_test_params_for (metric , n_features )
1771
1791
@@ -1811,7 +1831,8 @@ def test_metric_params_interface():
1811
1831
est .fit (X , y )
1812
1832
1813
1833
1814
- def test_predict_sparse_ball_kd_tree ():
1834
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
1835
+ def test_predict_sparse_ball_kd_tree (csr_container ):
1815
1836
rng = np .random .RandomState (0 )
1816
1837
X = rng .rand (5 , 5 )
1817
1838
y = rng .randint (0 , 2 , 5 )
@@ -1820,7 +1841,7 @@ def test_predict_sparse_ball_kd_tree():
1820
1841
for model in [nbrs1 , nbrs2 ]:
1821
1842
model .fit (X , y )
1822
1843
with pytest .raises (ValueError ):
1823
- model .predict (csr_matrix (X ))
1844
+ model .predict (csr_container (X ))
1824
1845
1825
1846
1826
1847
def test_non_euclidean_kneighbors ():
@@ -2073,16 +2094,17 @@ def test_dtype_convert():
2073
2094
assert_array_equal (result , y )
2074
2095
2075
2096
2076
- def test_sparse_metric_callable ():
2097
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
2098
+ def test_sparse_metric_callable (csr_container ):
2077
2099
def sparse_metric (x , y ): # Metric accepting sparse matrix input (only)
2078
2100
assert issparse (x ) and issparse (y )
2079
2101
return x .dot (y .T ).toarray ().item ()
2080
2102
2081
- X = csr_matrix (
2103
+ X = csr_container (
2082
2104
[[1 , 1 , 1 , 1 , 1 ], [1 , 0 , 1 , 0 , 1 ], [0 , 0 , 1 , 0 , 0 ]] # Population matrix
2083
2105
)
2084
2106
2085
- Y = csr_matrix ([[1 , 1 , 0 , 1 , 1 ], [1 , 0 , 0 , 0 , 1 ]]) # Query matrix
2107
+ Y = csr_container ([[1 , 1 , 0 , 1 , 1 ], [1 , 0 , 0 , 0 , 1 ]]) # Query matrix
2086
2108
2087
2109
nn = neighbors .NearestNeighbors (
2088
2110
algorithm = "brute" , n_neighbors = 2 , metric = sparse_metric
0 commit comments