
Commit 98a0cd7

[Transform] Update tests to use conftest file (#367)
* update tests to use conf

  Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* cleanup

  Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent f5b3e71 commit 98a0cd7

File tree

3 files changed (+114, -92 lines)

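For context: pytest automatically discovers fixtures defined in a conftest.py placed above the test files, so the shared model and transform arguments introduced below become available to every test under tests/test_transform/ without an explicit import. A minimal sketch of the mechanism (the test file and assertion are hypothetical, for illustration only):

# tests/test_transform/test_example.py (hypothetical)
# pytest matches the parameter name `model_apply` against the fixture of the
# same name in tests/test_transform/conftest.py and injects its return value.
def test_example(model_apply):
    model, apply = model_apply  # the fixture returns a (model, args) tuple
    assert len(apply) == 8      # eight TransformArgs entries, as defined below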

tests/test_transform/conftest.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import torch
+from compressed_tensors.transform import TransformArgs
+
+
+class TransformableModel(torch.nn.Module):
+    def __init__(self, *sizes):
+        super().__init__()
+        self.fcs = torch.nn.ModuleList(
+            [
+                torch.nn.Linear(sizes[index], sizes[index + 1], bias=False)
+                for index in range(0, len(sizes) - 1)
+            ]
+        )
+
+    def forward(self, x):
+        for layer in self.fcs:
+            x = layer(x)
+        return x
+
+
+@pytest.fixture(scope="function")
+def model_apply():
+    model = TransformableModel(2, 4, 8, 16, 32, 64)
+    apply = [
+        # weight output -> input
+        TransformArgs(targets="fcs.0", location="weight_output"),
+        TransformArgs(targets="fcs.1", location="input", inverse=True),
+        # output -> weight input
+        TransformArgs(targets="fcs.1", location="output"),
+        TransformArgs(targets="fcs.2", location="weight_input", inverse=True),
+        # output -> input
+        TransformArgs(targets="fcs.2", location="output"),
+        TransformArgs(targets="fcs.3", location="input", inverse=True),
+        # weight output -> weight input
+        TransformArgs(targets="fcs.3", location="weight_output"),
+        TransformArgs(targets="fcs.4", location="weight_input", inverse=True),
+    ]
+
+    return model, apply
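The apply list above pairs every transform with its inverse at the next point in the dataflow, which is why the tests can assert that model outputs are unchanged. A toy illustration of that identity using a plain orthogonal matrix (a conceptual sketch, not the library's implementation):

import torch

torch.manual_seed(0)
fc0 = torch.nn.Linear(4, 8, bias=False)
fc1 = torch.nn.Linear(8, 8, bias=False)

# QR of a random matrix yields an orthogonal U, so U.T acts as its inverse
U = torch.linalg.qr(torch.randn(8, 8)).Q

x = torch.randn(3, 4)
true_output = fc1(fc0(x))
# transform fc0's output, then invert it before fc1: the result is unchanged
output = fc1(fc0(x) @ U @ U.T)
assert torch.allclose(true_output, output, atol=1e-5)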

tests/test_transform/factory/test_correctness.py

Lines changed: 30 additions & 49 deletions
@@ -16,34 +16,26 @@
 import torch
 from compressed_tensors.transform import (
     TransformArgs,
+    TransformConfig,
     TransformFactory,
     TransformScheme,
 )
 from compressed_tensors.utils import offloaded_dispatch
 from tests.testing_utils import requires_accelerate, requires_gpu
 
 
-class TransformableModel(torch.nn.Module):
-    def __init__(self, *sizes):
-        super().__init__()
-        self.fcs = torch.nn.ModuleList([])
-        self.fcs.append(torch.nn.Linear(sizes[0], sizes[1], bias=False))
-        for index in range(1, len(sizes) - 1):
-            self.fcs.append(torch.nn.Linear(sizes[index], sizes[index + 1], bias=False))
+def scheme_kwargs():
+    all_types = TransformFactory.registered_names()
+    base = [{"type": type} for type in all_types]
+    randomized = [{"type": type, "randomize": True} for type in all_types]
+    return base + randomized
 
-    def forward(self, x):
-        for layer in self.fcs:
-            x = layer(x)
-        return x
 
-
-@pytest.mark.parametrize(
-    "scheme",
-    [TransformScheme(type=name) for name in TransformFactory.registered_names()],
-)
-def test_correctness_linear(scheme):
+@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
+def test_correctness_linear(scheme_kwargs):
     size = (4, 8)
     module = torch.nn.Linear(*size, bias=True)
+    scheme = TransformScheme(**scheme_kwargs)
     factory = TransformFactory.from_scheme(scheme, name="")
 
     input_tfm = factory.create_transform(
@@ -67,50 +59,39 @@ def test_correctness_linear(scheme):
     assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
 
 
-@pytest.mark.parametrize(
-    "scheme",
-    [TransformScheme(type=name) for name in TransformFactory.registered_names()],
-)
-def test_correctness_model(scheme, offload=False):
+@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
+def test_correctness_model(scheme_kwargs, model_apply, offload=False):
     # load model
-    model = TransformableModel(2, 4, 8, 16, 32, 64)
+    model = model_apply[0]
     if offload:
         model = offloaded_dispatch(model, torch.device("cuda"))
 
-    # create factory
-    scheme.apply = [
-        # weight output -> input
-        TransformArgs(targets="fcs.0", location="weight_output"),
-        TransformArgs(targets="fcs.1", location="input", inverse=True),
-        # output -> weight input
-        TransformArgs(targets="fcs.1", location="output"),
-        TransformArgs(targets="fcs.2", location="weight_input", inverse=True),
-        # output -> input
-        TransformArgs(targets="fcs.2", location="output"),
-        TransformArgs(targets="fcs.3", location="input", inverse=True),
-        # weight output -> weight input
-        TransformArgs(targets="fcs.3", location="weight_output"),
-        TransformArgs(targets="fcs.4", location="weight_input", inverse=True),
-    ]
-    factory = TransformFactory.from_scheme(scheme, name="")
-
-    # create inputs
+    # get output
     input = torch.rand((17, model.fcs[0].in_features))
     if offload:
         input = input.to(torch.device("cuda"))
+    true_output = model(input)
+
+    # apply transforms
+    config = TransformConfig(
+        config_groups={
+            "": TransformScheme(
+                **scheme_kwargs,
+                apply=model_apply[1],
+            )
+        }
+    )
+    for name, scheme in config.config_groups.items():
+        factory = TransformFactory.from_scheme(scheme, name=name)
+        factory.apply_to_model(model)
 
     # compare outputs
-    true_output = model(input)
-    factory.apply_to_model(model)
     output = model(input)
     assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
 
 
 @requires_gpu
 @requires_accelerate()
-@pytest.mark.parametrize(
-    "scheme",
-    [TransformScheme(type=name) for name in TransformFactory.registered_names()],
-)
-def test_correctness_model_offload(scheme):
-    test_correctness_model(scheme, offload=True)
+@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
+def test_correctness_model_offload(scheme_kwargs, model_apply):
+    test_correctness_model(scheme_kwargs, model_apply, offload=True)
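Both refactored test files now route transform application through a TransformConfig and iterate its config_groups instead of calling a single factory directly. A condensed sketch of that shared flow (the API names come from the diff itself; "hadamard" is taken from the old test body and assumed to be a registered type):

import torch
from compressed_tensors.transform import (
    TransformArgs,
    TransformConfig,
    TransformFactory,
    TransformScheme,
)

model = torch.nn.Linear(8, 8, bias=False)
config = TransformConfig(
    config_groups={
        "": TransformScheme(
            type="hadamard",
            apply=[
                TransformArgs(targets="Linear", location="input"),
                TransformArgs(targets="Linear", location="output"),
            ],
        )
    }
)
# one factory per config group, each applied across the whole model
for name, scheme in config.config_groups.items():
    factory = TransformFactory.from_scheme(scheme, name=name)
    factory.apply_to_model(model)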

tests/test_transform/factory/test_memory.py

Lines changed: 30 additions & 43 deletions
@@ -19,49 +19,44 @@
 from compressed_tensors.transform import (
     TransformArgs,
     TransformBase,
+    TransformConfig,
     TransformFactory,
     TransformScheme,
 )
 from compressed_tensors.utils import align_modules, offloaded_dispatch
+from tests.test_transform.conftest import TransformableModel
 from tests.testing_utils import requires_accelerate, requires_gpu
 
 
-class TransformableModel(torch.nn.Module):
-    def __init__(self, *sizes):
-        super().__init__()
-        self.fcs = torch.nn.ModuleList([])
-        self.fcs.append(torch.nn.Linear(sizes[0], sizes[1], bias=False))
-        for index in range(1, len(sizes) - 1):
-            self.fcs.append(torch.nn.Linear(sizes[index], sizes[index + 1], bias=False))
+def scheme_kwargs():
+    all_types = TransformFactory.registered_names()
+    base = [{"type": type} for type in all_types]
+    randomized = [{"type": type, "randomize": True} for type in all_types]
+    return base + randomized
 
-    def forward(self, x):
-        for layer in self.fcs:
-            x = layer(x)
-        return x
-
-
-@pytest.mark.parametrize(
-    "scheme",
-    [TransformScheme(type=name) for name in TransformFactory.registered_names()],
-)
-def test_memory_sharing(scheme, offload=False):
-    # load scheme and factory
-    scheme = TransformScheme(
-        type="hadamard",
-        apply=[
-            TransformArgs(targets="Linear", location="input"),
-            TransformArgs(targets="Linear", location="output"),
-        ],
-    )
-    factory = TransformFactory.from_scheme(scheme, name="")
 
+@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
+def test_memory_sharing(scheme_kwargs, offload=False):
     # load model (maybe with offloading)
     model = TransformableModel(2, 2, 4, 4, 8, 8)
     if offload:
         offloaded_dispatch(model, torch.device("cuda"))
 
     # add transforms to model
-    factory.apply_to_model(model)
+    config = TransformConfig(
+        config_groups={
+            "": TransformScheme(
+                **scheme_kwargs,
+                apply=[
+                    TransformArgs(targets="Linear", location="input"),
+                    TransformArgs(targets="Linear", location="output"),
+                ],
+            )
+        }
+    )
+    for name, scheme in config.config_groups.items():
+        factory = TransformFactory.from_scheme(scheme, name=name)
+        factory.apply_to_model(model)
 
     # check that memory is shared when onloaded
     with align_modules(model.modules()):
@@ -93,20 +88,12 @@ def test_memory_sharing(scheme, offload=False):
 
 @requires_gpu
 @requires_accelerate()
-@pytest.mark.parametrize(
-    "scheme",
-    [TransformScheme(type=name) for name in TransformFactory.registered_names()],
-)
-def test_memory_sharing_offload(scheme):
-    test_memory_sharing(scheme, offload=True)
+@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
+def test_memory_sharing_offload(scheme_kwargs):
+    test_memory_sharing(scheme_kwargs, offload=True)
 
 
-@pytest.mark.parametrize(
-    "scheme",
-    [
-        TransformScheme(type=name, requires_grad=True)
-        for name in TransformFactory.registered_names()
-    ],
-)
-def test_memory_sharing_training(scheme):
-    test_memory_sharing(scheme, offload=False)
+@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
+def test_memory_sharing_training(scheme_kwargs):
+    scheme_kwargs["requires_grad"] = True
+    test_memory_sharing(scheme_kwargs, offload=False)
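The memory test's sizes (2, 2, 4, 4, 8, 8) repeat, so several layers need transforms of the same shape, and the point of the test is that those transforms reuse one tensor rather than allocating copies. A hedged sketch of the kind of check involved (the assertion bodies are outside this diff, and the .weight attribute on TransformBase is an assumption):

from collections import defaultdict

from compressed_tensors.transform import TransformBase
from compressed_tensors.utils import align_modules

# assumes `model` already has transforms applied, as in test_memory_sharing
with align_modules(model.modules()):
    by_shape = defaultdict(set)
    for module in model.modules():
        if isinstance(module, TransformBase):
            # .weight is assumed; collect storage pointers per weight shape
            by_shape[module.weight.shape].add(module.weight.data_ptr())

# if memory is shared, same-shaped transform weights point at one storage
assert all(len(ptrs) == 1 for ptrs in by_shape.values())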

0 commit comments