|
25 | 25 | from tests.testing_utils import requires_accelerate, requires_gpu
|
26 | 26 |
|
27 | 27 |
|
28 |
| -@pytest.mark.parametrize("type", TransformFactory.registered_names()) |
| 28 | +@pytest.mark.parametrize("type", ("hadamard", "random-hadamard")) |
29 | 29 | @pytest.mark.parametrize("randomized", (True, False))
|
30 | 30 | def test_correctness_linear(type, randomized):
|
31 | 31 | size = (4, 8)
|
@@ -54,7 +54,7 @@ def test_correctness_linear(type, randomized):
|
54 | 54 | assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
|
55 | 55 |
|
56 | 56 |
|
57 |
| -@pytest.mark.parametrize("type", TransformFactory.registered_names()) |
| 57 | +@pytest.mark.parametrize("type", ("hadamard", "random-hadamard")) |
58 | 58 | @pytest.mark.parametrize("randomized", (True, False))
|
59 | 59 | def test_correctness_model(type, randomized, model_apply, offload=False):
|
60 | 60 | # load model
|
@@ -83,7 +83,7 @@ def test_correctness_model(type, randomized, model_apply, offload=False):
|
83 | 83 |
|
84 | 84 | @requires_gpu
|
85 | 85 | @requires_accelerate()
|
86 |
| -@pytest.mark.parametrize("type", TransformFactory.registered_names()) |
| 86 | +@pytest.mark.parametrize("type", ("hadamard", "random-hadamard")) |
87 | 87 | @pytest.mark.parametrize("randomized", (True, False))
|
88 | 88 | def test_correctness_model_offload(type, randomized, model_apply):
|
89 | 89 | test_correctness_model(type, randomized, model_apply, offload=True)
|
0 commit comments