Skip to content

Commit 407070b

Browse files
authored
TST Extend tests for scipy.sparse.*array in sklearn/utils/tests/test_class_weight.py (scikit-learn#27188)
1 parent 8e14bd0 commit 407070b

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

sklearn/utils/class_weight.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,10 +157,11 @@ def compute_sample_weight(class_weight, y, *, indices=None):
157157

158158
expanded_class_weight = []
159159
for k in range(n_outputs):
160-
y_full = y[:, k]
161-
if sparse.issparse(y_full):
160+
if sparse.issparse(y):
162161
# Ok to densify a single column at a time
163-
y_full = y_full.toarray().flatten()
162+
y_full = y[:, [k]].toarray().flatten()
163+
else:
164+
y_full = y[:, k]
164165
classes_full = np.unique(y_full)
165166
classes_missing = None
166167

sklearn/utils/tests/test_class_weight.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import numpy as np
22
import pytest
33
from numpy.testing import assert_allclose
4-
from scipy import sparse
54

65
from sklearn.datasets import make_blobs
76
from sklearn.linear_model import LogisticRegression
87
from sklearn.tree import DecisionTreeClassifier
98
from sklearn.utils._testing import assert_almost_equal, assert_array_almost_equal
109
from sklearn.utils.class_weight import compute_class_weight, compute_sample_weight
10+
from sklearn.utils.fixes import CSC_CONTAINERS
1111

1212

1313
def test_compute_class_weight():
@@ -308,8 +308,9 @@ def test_class_weight_does_not_contains_more_classes():
308308
tree.fit([[0, 0, 1], [1, 0, 1], [1, 2, 0]], [0, 0, 1])
309309

310310

311-
def test_compute_sample_weight_sparse():
311+
@pytest.mark.parametrize("csc_container", CSC_CONTAINERS)
312+
def test_compute_sample_weight_sparse(csc_container):
312313
"""Check that we can compute weight for sparse `y`."""
313-
y = sparse.csc_matrix(np.asarray([0, 1, 1])).T
314+
y = csc_container(np.asarray([0, 1, 1])).T
314315
sample_weight = compute_sample_weight("balanced", y)
315316
assert_allclose(sample_weight, [1.5, 0.75, 0.75])

0 commit comments

Comments
 (0)