Skip to content

Commit 952f480

Browse files
authored
TST Extend tests for scipy.sparse.*array in sklearn/linear_model/tests/test_perceptron.py (scikit-learn#27160)
1 parent a551884 commit 952f480

File tree

1 file changed

+8
-9
lines changed

1 file changed

+8
-9
lines changed

sklearn/linear_model/tests/test_perceptron.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,18 @@
11
import numpy as np
22
import pytest
3-
import scipy.sparse as sp
43

54
from sklearn.datasets import load_iris
65
from sklearn.linear_model import Perceptron
76
from sklearn.utils import check_random_state
87
from sklearn.utils._testing import assert_allclose, assert_array_almost_equal
8+
from sklearn.utils.fixes import CSR_CONTAINERS
99

1010
iris = load_iris()
1111
random_state = check_random_state(12)
1212
indices = np.arange(iris.data.shape[0])
1313
random_state.shuffle(indices)
1414
X = iris.data[indices]
1515
y = iris.target[indices]
16-
X_csr = sp.csr_matrix(X)
17-
X_csr.sort_indices()
1816

1917

2018
class MyPerceptron:
@@ -40,12 +38,13 @@ def predict(self, X):
4038
return np.sign(self.project(X))
4139

4240

43-
def test_perceptron_accuracy():
44-
for data in (X, X_csr):
45-
clf = Perceptron(max_iter=100, tol=None, shuffle=False)
46-
clf.fit(data, y)
47-
score = clf.score(data, y)
48-
assert score > 0.7
41+
@pytest.mark.parametrize("container", CSR_CONTAINERS + [np.array])
42+
def test_perceptron_accuracy(container):
43+
data = container(X)
44+
clf = Perceptron(max_iter=100, tol=None, shuffle=False)
45+
clf.fit(data, y)
46+
score = clf.score(data, y)
47+
assert score > 0.7
4948

5049

5150
def test_perceptron_correctness():

0 commit comments

Comments
 (0)