24
24
25
25
Inputs = namedtuple ("case" , ["x" , "y" ])
26
26
27
- _cpu_cases = [Inputs (x = torch .randn (100 ,), y = torch .randn (100 ,)),
28
- Inputs (x = torch .randn (100 , requires_grad = True ), y = torch .randn (100 ,requires_grad = True )),
29
- # test that list/numpy arrays still works
30
- Inputs (x = [1 ,2 ,3 ,4 ], y = [1 ,2 ,3 ,4 ]),
31
- Inputs (x = np .random .randn (100 ,), y = np .random .randn (100 ,)),
32
- # test that we can mix
33
- Inputs (x = torch .randn (100 ,), y = torch .randn (100 , requires_grad = True )),
34
- Inputs (x = np .random .randn (100 ,), y = torch .randn (100 , requires_grad = True )),
35
- Inputs (x = torch .randn (5 ,), y = [1 ,2 ,3 ,4 ,5 ]),
36
- ]
27
+ _cpu_cases = [
28
+ Inputs (x = torch .randn (100 ), y = torch .randn (100 )),
29
+ Inputs (x = torch .randn (100 , requires_grad = True ), y = torch .randn (100 , requires_grad = True )),
30
+ # test that list/numpy arrays still works
31
+ Inputs (x = [1 , 2 , 3 , 4 ], y = [1 , 2 , 3 , 4 ]),
32
+ Inputs (x = np .random .randn (100 ), y = np .random .randn (100 )),
33
+ # test that we can mix
34
+ Inputs (x = torch .randn (100 ), y = torch .randn (100 , requires_grad = True )),
35
+ Inputs (x = np .random .randn (100 ), y = torch .randn (100 , requires_grad = True )),
36
+ Inputs (x = torch .randn (5 ), y = [1 , 2 , 3 , 4 , 5 ]),
37
+ ]
37
38
38
- _gpu_cases = [Inputs (x = torch .randn (100 , device = 'cuda' ), y = torch .randn (100 , device = 'cuda' )),
39
- Inputs (x = torch .randn (100 ,requires_grad = True , device = 'cuda' ), y = torch .randn (100 ,requires_grad = True , device = 'cuda' )),
40
- ]
39
+ _gpu_cases = [
40
+ Inputs (x = torch .randn (100 , device = "cuda" ), y = torch .randn (100 , device = "cuda" )),
41
+ Inputs (
42
+ x = torch .randn (100 , requires_grad = True , device = "cuda" ), y = torch .randn (100 , requires_grad = True , device = "cuda" )
43
+ ),
44
+ ]
41
45
42
46
43
-
44
- _members_to_check = [name for name , member in getmembers (plt )
45
- if isfunction (member ) and not name .startswith ('_' )]
47
+ _members_to_check = [name for name , member in getmembers (plt ) if isfunction (member ) and not name .startswith ("_" )]
46
48
47
49
48
50
def string_compare (text1 , text2 ):
49
51
if text1 is None and text2 is None :
50
52
return True
51
53
remove = string .punctuation + string .whitespace
52
- return text1 .translate (str .maketrans (dict .fromkeys (remove ))) == text2 .translate (str .maketrans (dict .fromkeys (remove )))
54
+ return text1 .translate (str .maketrans (dict .fromkeys (remove ))) == text2 .translate (
55
+ str .maketrans (dict .fromkeys (remove ))
56
+ )
53
57
54
58
55
59
@pytest .mark .parametrize ("member" , _members_to_check )
@@ -59,15 +63,14 @@ def test_members(member):
59
63
assert member in dir (tp )
60
64
61
65
62
- @pytest .mark .parametrize (' test_case' , _cpu_cases )
66
+ @pytest .mark .parametrize (" test_case" , _cpu_cases )
63
67
def test_cpu (test_case ):
64
68
""" test that it works on cpu """
65
- assert tp .plot (test_case .x , test_case .y , '.' )
69
+ assert tp .plot (test_case .x , test_case .y , "." )
66
70
67
71
68
- @pytest .mark .skipif (not torch .cuda .is_available (), reason = ' test requires cuda' )
69
- @pytest .mark .parametrize (' test_case' , _gpu_cases )
72
+ @pytest .mark .skipif (not torch .cuda .is_available (), reason = " test requires cuda" )
73
+ @pytest .mark .parametrize (" test_case" , _gpu_cases )
70
74
def test_gpu (test_case ):
71
75
""" test that it works on gpu """
72
- assert tp .plot (test_case .x , test_case .y , '.' )
73
-
76
+ assert tp .plot (test_case .x , test_case .y , "." )
0 commit comments