Skip to content

Commit fd77ecc

Browse files
committed
use parametrize
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 438bc13 commit fd77ecc

File tree

2 files changed

+26
-38
lines changed

2 files changed

+26
-38
lines changed

tests/test_transform/factory/test_correctness.py

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,12 @@
2525
from tests.testing_utils import requires_accelerate, requires_gpu
2626

2727

28-
def scheme_kwargs():
29-
all_types = TransformFactory.registered_names()
30-
base = [{"type": type} for type in all_types]
31-
randomized = [{"type": type, "randomize": True} for type in all_types]
32-
return base + randomized
33-
34-
35-
@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
36-
def test_correctness_linear(scheme_kwargs):
28+
@pytest.mark.parametrize("type", TransformFactory.registered_names())
29+
@pytest.mark.parametrize("randomized", (True, False))
30+
def test_correctness_linear(type, randomized):
3731
size = (4, 8)
3832
module = torch.nn.Linear(*size, bias=True)
39-
scheme = TransformScheme(**scheme_kwargs)
33+
scheme = TransformScheme(type=type, randomized=randomized)
4034
factory = TransformFactory.from_scheme(scheme, name="")
4135

4236
input_tfm = factory.create_transform(
@@ -60,8 +54,9 @@ def test_correctness_linear(scheme_kwargs):
6054
assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
6155

6256

63-
@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
64-
def test_correctness_model(scheme_kwargs, model_apply, offload=False):
57+
@pytest.mark.parametrize("type", TransformFactory.registered_names())
58+
@pytest.mark.parametrize("randomized", (True, False))
59+
def test_correctness_model(type, randomized, model_apply, offload=False):
6560
# load model
6661
model = model_apply[0]
6762
if offload:
@@ -76,10 +71,7 @@ def test_correctness_model(scheme_kwargs, model_apply, offload=False):
7671
# apply transforms
7772
config = TransformConfig(
7873
config_groups={
79-
"": TransformScheme(
80-
**scheme_kwargs,
81-
apply=model_apply[1],
82-
)
74+
"": TransformScheme(type=type, randomized=randomized, apply=model_apply[1])
8375
}
8476
)
8577
apply_transform_config(model, config)
@@ -91,6 +83,7 @@ def test_correctness_model(scheme_kwargs, model_apply, offload=False):
9183

9284
@requires_gpu
9385
@requires_accelerate()
94-
@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
95-
def test_correctness_model_offload(scheme_kwargs, model_apply):
96-
test_correctness_model(scheme_kwargs, model_apply, offload=True)
86+
@pytest.mark.parametrize("type", TransformFactory.registered_names())
87+
@pytest.mark.parametrize("randomized", (True, False))
88+
def test_correctness_model_offload(type, randomized, model_apply):
89+
test_correctness_model(type, randomized, model_apply, offload=True)

tests/test_transform/factory/test_memory.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,10 @@
2929
from tests.testing_utils import requires_accelerate, requires_gpu
3030

3131

32-
def scheme_kwargs():
33-
all_types = TransformFactory.registered_names()
34-
base = [{"type": type} for type in all_types]
35-
randomized = [{"type": type, "randomize": True} for type in all_types]
36-
return base + randomized
37-
38-
39-
@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
40-
def test_memory_sharing(scheme_kwargs, offload=False):
32+
@pytest.mark.parametrize("type", TransformFactory.registered_names())
33+
@pytest.mark.parametrize("randomized", (True, False))
34+
@pytest.mark.parametrize("requires_grad", (True, False))
35+
def test_memory_sharing(type, randomized, requires_grad, offload=False):
4136
# load model (maybe with offloading)
4237
model = TransformableModel(2, 2, 4, 4, 8, 8)
4338
if offload:
@@ -47,7 +42,9 @@ def test_memory_sharing(scheme_kwargs, offload=False):
4742
config = TransformConfig(
4843
config_groups={
4944
"": TransformScheme(
50-
**scheme_kwargs,
45+
type=type,
46+
randomzied=randomized,
47+
requires_grad=requires_grad,
5148
apply=[
5249
TransformArgs(targets="Linear", location="input"),
5350
TransformArgs(targets="Linear", location="output"),
@@ -87,12 +84,10 @@ def test_memory_sharing(scheme_kwargs, offload=False):
8784

8885
@requires_gpu
8986
@requires_accelerate()
90-
@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
91-
def test_memory_sharing_offload(scheme_kwargs):
92-
test_memory_sharing(scheme_kwargs, offload=True)
93-
94-
95-
@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
96-
def test_memory_sharing_training(scheme_kwargs):
97-
scheme_kwargs["requires_grad"] = True
98-
test_memory_sharing(scheme_kwargs, offload=False)
87+
@pytest.mark.parametrize("type", TransformFactory.registered_names())
88+
@pytest.mark.parametrize("randomized", (True, False))
89+
def test_memory_sharing_offload(
90+
type,
91+
randomized,
92+
):
93+
test_memory_sharing(type, randomized, requires_grad=False, offload=True)

0 commit comments

Comments
 (0)