Skip to content

Commit 8e14bd0

Browse files
authored
TST Extend tests for scipy.sparse.*array in sklearn/cluster/tests/test_hierarchical.py (scikit-learn#27101)
1 parent d56dc5d commit 8e14bd0

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

sklearn/cluster/tests/test_hierarchical.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
import numpy as np
1414
import pytest
15-
from scipy import sparse
1615
from scipy.cluster import hierarchy
1716
from scipy.sparse.csgraph import connected_components
1817

@@ -48,6 +47,7 @@
4847
create_memmap_backed_data,
4948
ignore_warnings,
5049
)
50+
from sklearn.utils.fixes import LIL_CONTAINERS
5151

5252

5353
def test_linkage_misc():
@@ -176,7 +176,8 @@ def test_agglomerative_clustering_distances(
176176
assert not hasattr(clustering, "distances_")
177177

178178

179-
def test_agglomerative_clustering(global_random_seed):
179+
@pytest.mark.parametrize("lil_container", LIL_CONTAINERS)
180+
def test_agglomerative_clustering(global_random_seed, lil_container):
180181
# Check that we obtain the correct number of clusters with
181182
# agglomerative clustering.
182183
rng = np.random.RandomState(global_random_seed)
@@ -218,7 +219,7 @@ def test_agglomerative_clustering(global_random_seed):
218219
# Check that we raise a TypeError on dense matrices
219220
clustering = AgglomerativeClustering(
220221
n_clusters=10,
221-
connectivity=sparse.lil_matrix(connectivity.toarray()[:10, :10]),
222+
connectivity=lil_container(connectivity.toarray()[:10, :10]),
222223
linkage=linkage,
223224
)
224225
with pytest.raises(ValueError):

0 commit comments

Comments
 (0)