Skip to content

Commit b163bd9

Browse files
[Transform] apply_transform_config (#348)
* add utilities Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * add tests Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * add additional tests Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * add utils and tests Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * Implement transform factories Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * add permutations Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * add delete_offload_module Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * key inverses by weight Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * fix tests Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * standardize random hadamard Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * prepend input hooks Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * apply sqrt division first Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * use divided hadamards Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * fix typo Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * add random option Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * use random seeds, rename matrix multiply Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * add deterministic generation to random matrix Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * fix perm math Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * update docstrings Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * update docstrings Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * cleanup Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * cleanup 2 Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * make seed optional Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * remove iterable check and missing return value Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * Remove unrelated changes * simplify code Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * implement apply, use in tests Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * use hadamards database file Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * try 
manifest Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * try setup, update hadamards list Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * fix setup Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * add docstrings, cleanup Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * fix setup, thank you @dbarbuzzi Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * remove numpy, add tests Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * solidify dtype, add gpu tests Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * fix docstring Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * add device option Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * construct on execution device, cache on offload device Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * save construction device changes for later Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * construct on execution device, cache on offload device * cite nja sloane Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * remove dreg Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * put on device via safe_open Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * nits and docstrings Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * update docstring Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * Merge * merge with construct: construct in float32 Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * construct with same dtype, constructing on fp32 found no difference Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * remove unnecessary imports Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * bugfixes (#375) Signed-off-by: Brian Dellabetta <bdellabe@redhat.com> * use factory_kwargs Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * add frozen dict to deps Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * fix style Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * merge Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * use delete_offload_module Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * add docstrign Signed-off-by: 
Kyle Sayers <kylesayrs@gmail.com> * use parametrize Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * remove random from tests Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> --------- Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> Signed-off-by: Brian Dellabetta <bdellabe@redhat.com> Co-authored-by: Brian Dellabetta <brian-dellabetta@users.noreply.github.com>
1 parent e4eb3fb commit b163bd9

File tree

5 files changed

+69
-49
lines changed

5 files changed

+69
-49
lines changed

src/compressed_tensors/transform/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,4 @@
2323
from .factory.hadamard import *
2424
from .factory.matrix_multiply import *
2525
from .factory.random_hadamard import *
26+
from .apply import *
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import torch
16+
from compressed_tensors.transform import TransformConfig, TransformFactory
17+
18+
19+
__all__ = ["apply_transform_config"]
20+
21+
22+
def apply_transform_config(model: torch.nn.Module, config: TransformConfig):
23+
"""
24+
Apply a transform config to a model. Weight transforms are fused into weights, while
25+
activation transforms are attached as submodules and trigger via pytorch hooks
26+
27+
:param model: model to apply config to
28+
:param config: transform config to apply
29+
"""
30+
for name, scheme in config.config_groups.items():
31+
factory = TransformFactory.from_scheme(scheme, name=name)
32+
factory.apply_to_model(model)

src/compressed_tensors/transform/factory/base.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
)
2828
from compressed_tensors.utils import (
2929
align_module_device,
30+
delete_offload_module,
3031
has_offloaded_params,
3132
patch_attr,
3233
register_offload_module,
@@ -100,7 +101,7 @@ def _apply_to_module(self, module: Module, args: TransformArgs):
100101
# create transform as submodule
101102
transform_name = f"{self.name}_{args.location.value}"
102103
transform = self.create_transform(module, args)
103-
register_offload_module(module, transform_name, transform) # (1)
104+
register_offload_module(module, transform_name, transform)
104105

105106
# register input transformation hook
106107
if args.location == TransformLocation.INPUT:
@@ -119,6 +120,7 @@ def input_hook(_, args):
119120
assert isinstance(module, torch.nn.Linear)
120121
assert module.bias is None
121122

123+
# fuse transform into weight
122124
with torch.no_grad(), align_module_device(module):
123125
update_offload_parameter(module, "weight", transform(module.weight))
124126

@@ -129,6 +131,9 @@ def input_hook(_, args):
129131
raise ValueError("Offloaded training is not supported")
130132
P.register_parametrization(module, "weight", transform)
131133

134+
# transform is no longer needed (unfusing is not supported)
135+
delete_offload_module(module, transform_name)
136+
132137
# register output transformation hook
133138
elif args.location == TransformLocation.OUTPUT:
134139

@@ -141,9 +146,6 @@ def output_hook(_, _input, output):
141146
else:
142147
raise NotImplementedError()
143148

144-
# (1) even in the `weight` cases, this submodule attachment is needed in order
145-
# to support saving in the frozen state
146-
147149

148150
class TransformBase(InternalModule, ABC):
149151
"""

tests/test_transform/factory/test_correctness.py

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,18 @@
1919
TransformConfig,
2020
TransformFactory,
2121
TransformScheme,
22+
apply_transform_config,
2223
)
2324
from compressed_tensors.utils import offloaded_dispatch
2425
from tests.testing_utils import requires_accelerate, requires_gpu
2526

2627

27-
def scheme_kwargs():
28-
all_types = TransformFactory.registered_names()
29-
base = [{"type": type} for type in all_types]
30-
randomized = [{"type": type, "randomize": True} for type in all_types]
31-
return base + randomized
32-
33-
34-
@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
35-
def test_correctness_linear(scheme_kwargs):
28+
@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
29+
@pytest.mark.parametrize("randomized", (True, False))
30+
def test_correctness_linear(type, randomized):
3631
size = (4, 8)
3732
module = torch.nn.Linear(*size, bias=True)
38-
scheme = TransformScheme(**scheme_kwargs)
33+
scheme = TransformScheme(type=type, randomized=randomized)
3934
factory = TransformFactory.from_scheme(scheme, name="")
4035

4136
input_tfm = factory.create_transform(
@@ -59,8 +54,9 @@ def test_correctness_linear(scheme_kwargs):
5954
assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
6055

6156

62-
@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
63-
def test_correctness_model(scheme_kwargs, model_apply, offload=False):
57+
@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
58+
@pytest.mark.parametrize("randomized", (True, False))
59+
def test_correctness_model(type, randomized, model_apply, offload=False):
6460
# load model
6561
model = model_apply[0]
6662
if offload:
@@ -75,15 +71,10 @@ def test_correctness_model(scheme_kwargs, model_apply, offload=False):
7571
# apply transforms
7672
config = TransformConfig(
7773
config_groups={
78-
"": TransformScheme(
79-
**scheme_kwargs,
80-
apply=model_apply[1],
81-
)
74+
"": TransformScheme(type=type, randomized=randomized, apply=model_apply[1])
8275
}
8376
)
84-
for name, scheme in config.config_groups.items():
85-
factory = TransformFactory.from_scheme(scheme, name=name)
86-
factory.apply_to_model(model)
77+
apply_transform_config(model, config)
8778

8879
# compare outputs
8980
output = model(input)
@@ -92,6 +83,7 @@ def test_correctness_model(scheme_kwargs, model_apply, offload=False):
9283

9384
@requires_gpu
9485
@requires_accelerate()
95-
@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
96-
def test_correctness_model_offload(scheme_kwargs, model_apply):
97-
test_correctness_model(scheme_kwargs, model_apply, offload=True)
86+
@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
87+
@pytest.mark.parametrize("randomized", (True, False))
88+
def test_correctness_model_offload(type, randomized, model_apply):
89+
test_correctness_model(type, randomized, model_apply, offload=True)

tests/test_transform/factory/test_memory.py

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,23 +20,18 @@
2020
TransformArgs,
2121
TransformBase,
2222
TransformConfig,
23-
TransformFactory,
2423
TransformScheme,
24+
apply_transform_config,
2525
)
2626
from compressed_tensors.utils import align_modules, offloaded_dispatch
2727
from tests.test_transform.conftest import TransformableModel
2828
from tests.testing_utils import requires_accelerate, requires_gpu
2929

3030

31-
def scheme_kwargs():
32-
all_types = TransformFactory.registered_names()
33-
base = [{"type": type} for type in all_types]
34-
randomized = [{"type": type, "randomize": True} for type in all_types]
35-
return base + randomized
36-
37-
38-
@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
39-
def test_memory_sharing(scheme_kwargs, offload=False):
31+
@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
32+
@pytest.mark.parametrize("randomized", (True, False))
33+
@pytest.mark.parametrize("requires_grad", (True, False))
34+
def test_memory_sharing(type, randomized, requires_grad, offload=False):
4035
# load model (maybe with offloading)
4136
model = TransformableModel(2, 2, 4, 4, 8, 8)
4237
if offload:
@@ -46,17 +41,17 @@ def test_memory_sharing(scheme_kwargs, offload=False):
4641
config = TransformConfig(
4742
config_groups={
4843
"": TransformScheme(
49-
**scheme_kwargs,
44+
type=type,
45+
randomized=randomized,
46+
requires_grad=requires_grad,
5047
apply=[
5148
TransformArgs(targets="Linear", location="input"),
5249
TransformArgs(targets="Linear", location="output"),
5350
],
5451
)
5552
}
5653
)
57-
for name, scheme in config.config_groups.items():
58-
factory = TransformFactory.from_scheme(scheme, name=name)
59-
factory.apply_to_model(model)
54+
apply_transform_config(model, config)
6055

6156
# check that memory is shared when onloaded
6257
with align_modules(model.modules()):
@@ -88,12 +83,10 @@ def test_memory_sharing(scheme_kwargs, offload=False):
8883

8984
@requires_gpu
9085
@requires_accelerate()
91-
@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
92-
def test_memory_sharing_offload(scheme_kwargs):
93-
test_memory_sharing(scheme_kwargs, offload=True)
94-
95-
96-
@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
97-
def test_memory_sharing_training(scheme_kwargs):
98-
scheme_kwargs["requires_grad"] = True
99-
test_memory_sharing(scheme_kwargs, offload=False)
86+
@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
87+
@pytest.mark.parametrize("randomized", (True, False))
88+
def test_memory_sharing_offload(
89+
type,
90+
randomized,
91+
):
92+
test_memory_sharing(type, randomized, requires_grad=False, offload=True)

0 commit comments

Comments
 (0)