1
1
import numpy as np
2
2
import pytest
3
- import scipy .sparse as sp
4
3
5
4
from sklearn .datasets import load_iris
6
5
from sklearn .linear_model import Perceptron
7
6
from sklearn .utils import check_random_state
8
7
from sklearn .utils ._testing import assert_allclose , assert_array_almost_equal
8
+ from sklearn .utils .fixes import CSR_CONTAINERS
9
9
10
10
iris = load_iris ()
11
11
random_state = check_random_state (12 )
12
12
indices = np .arange (iris .data .shape [0 ])
13
13
random_state .shuffle (indices )
14
14
X = iris .data [indices ]
15
15
y = iris .target [indices ]
16
- X_csr = sp .csr_matrix (X )
17
- X_csr .sort_indices ()
18
16
19
17
20
18
class MyPerceptron :
@@ -40,12 +38,13 @@ def predict(self, X):
40
38
return np .sign (self .project (X ))
41
39
42
40
43
- def test_perceptron_accuracy ():
44
- for data in (X , X_csr ):
45
- clf = Perceptron (max_iter = 100 , tol = None , shuffle = False )
46
- clf .fit (data , y )
47
- score = clf .score (data , y )
48
- assert score > 0.7
41
+ @pytest .mark .parametrize ("container" , CSR_CONTAINERS + [np .array ])
42
+ def test_perceptron_accuracy (container ):
43
+ data = container (X )
44
+ clf = Perceptron (max_iter = 100 , tol = None , shuffle = False )
45
+ clf .fit (data , y )
46
+ score = clf .score (data , y )
47
+ assert score > 0.7
49
48
50
49
51
50
def test_perceptron_correctness ():
0 commit comments