Skip to content

Commit bca6004

Browse files
committed
fix skipping
1 parent 0f5bc73 commit bca6004

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tests/test_nnj.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def test_jacobians(self, model, input_shape, device, dtype):
137137
"""Test that the analytical jacobian of the model is consistent with finite
138138
order approximation
139139
"""
140-
if device == "cuda" and not torch.cuda.is_available():
140+
if "cuda" in device and not torch.cuda.is_available():
141141
pytest.skip("Test requires cuda support")
142142

143143
model = deepcopy(model).to(device=device, dtype=dtype).eval()
@@ -149,7 +149,7 @@ def test_jacobians(self, model, input_shape, device, dtype):
149149
@pytest.mark.parametrize("return_jac", [True, False])
150150
def test_jac_return(self, model, input_shape, device, return_jac):
151151
""" Test that all models returns the jacobian output if asked for it """
152-
if device == "cuda" and not torch.cuda.is_available():
152+
if "cuda" in device and not torch.cuda.is_available():
153153
pytest.skip("Test requires cuda support")
154154

155155
input = torch.randn(*input_shape, device=device)

0 commit comments

Comments
 (0)