Skip to content

Commit bbf9533

Browse files
committed
remove random from tests
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent fd77ecc commit bbf9533

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

tests/test_transform/factory/test_correctness.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from tests.testing_utils import requires_accelerate, requires_gpu
2626

2727

28-
@pytest.mark.parametrize("type", TransformFactory.registered_names())
28+
@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
2929
@pytest.mark.parametrize("randomized", (True, False))
3030
def test_correctness_linear(type, randomized):
3131
size = (4, 8)
@@ -54,7 +54,7 @@ def test_correctness_linear(type, randomized):
5454
assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
5555

5656

57-
@pytest.mark.parametrize("type", TransformFactory.registered_names())
57+
@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
5858
@pytest.mark.parametrize("randomized", (True, False))
5959
def test_correctness_model(type, randomized, model_apply, offload=False):
6060
# load model
@@ -83,7 +83,7 @@ def test_correctness_model(type, randomized, model_apply, offload=False):
8383

8484
@requires_gpu
8585
@requires_accelerate()
86-
@pytest.mark.parametrize("type", TransformFactory.registered_names())
86+
@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
8787
@pytest.mark.parametrize("randomized", (True, False))
8888
def test_correctness_model_offload(type, randomized, model_apply):
8989
test_correctness_model(type, randomized, model_apply, offload=True)

tests/test_transform/factory/test_memory.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
TransformArgs,
2121
TransformBase,
2222
TransformConfig,
23-
TransformFactory,
2423
TransformScheme,
2524
apply_transform_config,
2625
)
@@ -29,7 +28,7 @@
2928
from tests.testing_utils import requires_accelerate, requires_gpu
3029

3130

32-
@pytest.mark.parametrize("type", TransformFactory.registered_names())
31+
@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
3332
@pytest.mark.parametrize("randomized", (True, False))
3433
@pytest.mark.parametrize("requires_grad", (True, False))
3534
def test_memory_sharing(type, randomized, requires_grad, offload=False):
@@ -84,7 +83,7 @@ def test_memory_sharing(type, randomized, requires_grad, offload=False):
8483

8584
@requires_gpu
8685
@requires_accelerate()
87-
@pytest.mark.parametrize("type", TransformFactory.registered_names())
86+
@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
8887
@pytest.mark.parametrize("randomized", (True, False))
8988
def test_memory_sharing_offload(
9089
type,

0 commit comments

Comments
 (0)