Skip to content

Commit 93c2246

Browse files
authored
make tests pass mps (#528)
1 parent 6673d88 commit 93c2246

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

tests/acceptance/test_hooked_transformer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,8 +311,12 @@ def check_dtype(dtype, margin, no_processing=False):
311311
gc.collect()
312312

313313

314+
@pytest.mark.skipif(
315+
torch.backends.mps.is_available() or not torch.cuda.is_available(),
316+
reason="some operations unsupported by MPS: https://github.com/pytorch/pytorch/issues/77754 or no GPU",
317+
)
314318
@pytest.mark.parametrize("dtype", [torch.float64, torch.float32])
315-
def test_dtypes(dtype):
319+
def test_dtype_float(dtype):
316320
check_dtype(dtype, margin=5e-4)
317321

318322

0 commit comments

Comments
 (0)