Skip to content

Commit 27bc0b3

Browse files
committed
implement apply, use in tests
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 938e702 commit 27bc0b3

File tree

5 files changed

+132
-82
lines changed

5 files changed

+132
-82
lines changed

src/compressed_tensors/transform/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,4 @@
2323
from .factory.hadamard import *
2424
from .factory.matrix_multiply import *
2525
from .factory.random_hadamard import *
26+
from .apply import *
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import torch
16+
from compressed_tensors.transform import TransformConfig, TransformFactory
17+
18+
19+
__all__ = ["apply_transform_config"]
20+
21+
22+
def apply_transform_config(model: torch.nn.Module, config: TransformConfig):
23+
for name, scheme in config.config_groups.items():
24+
factory = TransformFactory.from_scheme(scheme, name=name)
25+
factory.apply_to_model(model)

tests/test_transform/conftest.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
import torch
17+
from compressed_tensors.transform import TransformArgs
18+
19+
20+
class TransformableModel(torch.nn.Module):
21+
def __init__(self, *sizes):
22+
super().__init__()
23+
self.fcs = torch.nn.ModuleList([])
24+
self.fcs.append(torch.nn.Linear(sizes[0], sizes[1], bias=False))
25+
for index in range(1, len(sizes) - 1):
26+
self.fcs.append(torch.nn.Linear(sizes[index], sizes[index + 1], bias=False))
27+
28+
def forward(self, x):
29+
for layer in self.fcs:
30+
x = layer(x)
31+
return x
32+
33+
34+
@pytest.fixture(scope="function")
35+
def model_apply():
36+
model = TransformableModel(2, 4, 8, 16, 32, 64)
37+
apply = [
38+
# weight output -> input
39+
TransformArgs(targets="fcs.0", location="weight_output"),
40+
TransformArgs(targets="fcs.1", location="input", inverse=True),
41+
# output -> weight input
42+
TransformArgs(targets="fcs.1", location="output"),
43+
TransformArgs(targets="fcs.2", location="weight_input", inverse=True),
44+
# output -> input
45+
TransformArgs(targets="fcs.2", location="output"),
46+
TransformArgs(targets="fcs.3", location="input", inverse=True),
47+
# weight output -> weight input
48+
TransformArgs(targets="fcs.3", location="weight_output"),
49+
TransformArgs(targets="fcs.4", location="weight_input", inverse=True),
50+
]
51+
52+
return model, apply

tests/test_transform/factory/test_correctness.py

Lines changed: 27 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -16,38 +16,27 @@
1616
import torch
1717
from compressed_tensors.transform import (
1818
TransformArgs,
19+
TransformConfig,
1920
TransformFactory,
2021
TransformScheme,
22+
apply_transform_config,
2123
)
2224
from compressed_tensors.utils import force_cpu_offload
2325
from tests.testing_utils import requires_accelerate, requires_gpu
2426

2527

26-
def all_schemes():
28+
def scheme_kwargs():
2729
all_types = TransformFactory.registered_names()
28-
base = [TransformScheme(type=type) for type in all_types]
29-
randomized = [TransformScheme(type=type, randomize=True) for type in all_types]
30+
base = [{"type": type} for type in all_types]
31+
randomized = [{"type": type, "randomize": True} for type in all_types]
3032
return base + randomized
3133

3234

33-
class TransformableModel(torch.nn.Module):
34-
def __init__(self, *sizes):
35-
super().__init__()
36-
self.fcs = torch.nn.ModuleList([])
37-
self.fcs.append(torch.nn.Linear(sizes[0], sizes[1], bias=False))
38-
for index in range(1, len(sizes) - 1):
39-
self.fcs.append(torch.nn.Linear(sizes[index], sizes[index + 1], bias=False))
40-
41-
def forward(self, x):
42-
for layer in self.fcs:
43-
x = layer(x)
44-
return x
45-
46-
47-
@pytest.mark.parametrize("scheme", all_schemes())
48-
def test_correctness_linear(scheme):
35+
@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
36+
def test_correctness_linear(scheme_kwargs):
4937
size = (4, 8)
5038
module = torch.nn.Linear(*size, bias=True)
39+
scheme = TransformScheme(**scheme_kwargs)
5140
factory = TransformFactory.from_scheme(scheme, name="")
5241

5342
input_tfm = factory.create_transform(
@@ -71,44 +60,37 @@ def test_correctness_linear(scheme):
7160
assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
7261

7362

74-
@pytest.mark.parametrize("scheme", all_schemes())
75-
def test_correctness_model(scheme, offload=False):
63+
@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
64+
def test_correctness_model(scheme_kwargs, model_apply, offload=False):
7665
# load model
77-
model = TransformableModel(2, 4, 8, 16, 32, 64)
66+
model = model_apply[0]
7867
if offload:
7968
model = force_cpu_offload(model, torch.device("cuda"))
8069

81-
# create factory
82-
scheme.apply = [
83-
# weight output -> input
84-
TransformArgs(targets="fcs.0", location="weight_output"),
85-
TransformArgs(targets="fcs.1", location="input", inverse=True),
86-
# output -> weight input
87-
TransformArgs(targets="fcs.1", location="output"),
88-
TransformArgs(targets="fcs.2", location="weight_input", inverse=True),
89-
# output -> input
90-
TransformArgs(targets="fcs.2", location="output"),
91-
TransformArgs(targets="fcs.3", location="input", inverse=True),
92-
# weight output -> weight input
93-
TransformArgs(targets="fcs.3", location="weight_output"),
94-
TransformArgs(targets="fcs.4", location="weight_input", inverse=True),
95-
]
96-
factory = TransformFactory.from_scheme(scheme, name="")
97-
98-
# create inputs
70+
# get output
9971
input = torch.rand((17, model.fcs[0].in_features))
10072
if offload:
10173
input = input.to(torch.device("cuda"))
74+
true_output = model(input)
75+
76+
# apply transforms
77+
config = TransformConfig(
78+
config_groups={
79+
"": TransformScheme(
80+
**scheme_kwargs,
81+
apply=model_apply[1],
82+
)
83+
}
84+
)
85+
apply_transform_config(model, config)
10286

10387
# compare outputs
104-
true_output = model(input)
105-
factory.apply_to_model(model)
10688
output = model(input)
10789
assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
10890

10991

11092
@requires_gpu
11193
@requires_accelerate()
112-
@pytest.mark.parametrize("scheme", all_schemes())
113-
def test_correctness_model_offload(scheme):
114-
test_correctness_model(scheme, offload=True)
94+
@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
95+
def test_correctness_model_offload(scheme_kwargs, model_apply):
96+
test_correctness_model(scheme_kwargs, model_apply, offload=True)

tests/test_transform/factory/test_memory.py

Lines changed: 27 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -19,53 +19,43 @@
1919
from compressed_tensors.transform import (
2020
TransformArgs,
2121
TransformBase,
22+
TransformConfig,
2223
TransformFactory,
2324
TransformScheme,
25+
apply_transform_config,
2426
)
2527
from compressed_tensors.utils import align_modules, force_cpu_offload
28+
from tests.test_transform.conftest import TransformableModel
2629
from tests.testing_utils import requires_accelerate, requires_gpu
2730

2831

29-
def all_schemes():
32+
def scheme_kwargs():
3033
all_types = TransformFactory.registered_names()
31-
base = [TransformScheme(type=type) for type in all_types]
32-
randomized = [TransformScheme(type=type, randomize=True) for type in all_types]
34+
base = [{"type": type} for type in all_types]
35+
randomized = [{"type": type, "randomize": True} for type in all_types]
3336
return base + randomized
3437

3538

36-
class TransformableModel(torch.nn.Module):
37-
def __init__(self, *sizes):
38-
super().__init__()
39-
self.fcs = torch.nn.ModuleList([])
40-
self.fcs.append(torch.nn.Linear(sizes[0], sizes[1], bias=False))
41-
for index in range(1, len(sizes) - 1):
42-
self.fcs.append(torch.nn.Linear(sizes[index], sizes[index + 1], bias=False))
43-
44-
def forward(self, x):
45-
for layer in self.fcs:
46-
x = layer(x)
47-
return x
48-
49-
50-
@pytest.mark.parametrize("scheme", all_schemes())
51-
def test_memory_sharing(scheme, offload=False):
52-
# load scheme and factory
53-
scheme = TransformScheme(
54-
type="hadamard",
55-
apply=[
56-
TransformArgs(targets="Linear", location="input"),
57-
TransformArgs(targets="Linear", location="output"),
58-
],
59-
)
60-
factory = TransformFactory.from_scheme(scheme, name="")
61-
39+
@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
40+
def test_memory_sharing(scheme_kwargs, offload=False):
6241
# load model (maybe with offloading)
6342
model = TransformableModel(2, 2, 4, 4, 8, 8)
6443
if offload:
6544
force_cpu_offload(model, torch.device("cuda"))
6645

6746
# add transforms to model
68-
factory.apply_to_model(model)
47+
config = TransformConfig(
48+
config_groups={
49+
"": TransformScheme(
50+
**scheme_kwargs,
51+
apply=[
52+
TransformArgs(targets="Linear", location="input"),
53+
TransformArgs(targets="Linear", location="output"),
54+
],
55+
)
56+
}
57+
)
58+
apply_transform_config(model, config)
6959

7060
# check that memory is shared when onloaded
7161
with align_modules(model.modules()):
@@ -97,12 +87,12 @@ def test_memory_sharing(scheme, offload=False):
9787

9888
@requires_gpu
9989
@requires_accelerate()
100-
@pytest.mark.parametrize("scheme", all_schemes())
101-
def test_memory_sharing_offload(scheme):
102-
test_memory_sharing(scheme, offload=True)
90+
@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
91+
def test_memory_sharing_offload(scheme_kwargs):
92+
test_memory_sharing(scheme_kwargs, offload=True)
10393

10494

105-
@pytest.mark.parametrize("scheme", all_schemes())
106-
def test_memory_sharing_training(scheme):
107-
scheme.requires_grad = True
108-
test_memory_sharing(scheme, offload=False)
95+
@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
96+
def test_memory_sharing_training(scheme_kwargs):
97+
scheme_kwargs["requires_grad"] = True
98+
test_memory_sharing(scheme_kwargs, offload=False)

0 commit comments

Comments
 (0)