
Commit 8c5a2d9

Implement transform factories
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 859fd90 commit 8c5a2d9

File tree

17 files changed: +1209 −8 lines


src/compressed_tensors/transform/__init__.py

Lines changed: 5 additions & 0 deletions

@@ -18,3 +18,8 @@
 from .transform_args import *
 from .transform_scheme import *
 from .transform_config import *
+
+from .factory.base import *
+from .factory.hadamard import *
+from .factory.matrix_multiply import *
+from .factory.random_hadamard import *
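
Wiring the factories into the package namespace is what makes the registry work: each factory class registers itself with `@TransformFactory.register(...)` at import time, so these wildcard imports are what let `TransformFactory.from_scheme` resolve a scheme's `type` string. A minimal sketch of that lookup, assuming the registry keys are exactly the strings passed to `register()`:

from compressed_tensors.transform import TransformFactory  # importing the package registers the factories

# the same lookup from_scheme() performs internally
for key in ("hadamard", "random-hadamard", "matrix-mul"):
    factory_cls = TransformFactory.get_value_from_registry(name=key)
    print(f"{key!r} -> {factory_cls.__name__}")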
src/compressed_tensors/transform/factory/__init__.py

Lines changed: 13 additions & 0 deletions

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
src/compressed_tensors/transform/factory/base.py

Lines changed: 161 additions & 0 deletions

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod

import torch
import torch.nn.utils.parametrize as P
from compressed_tensors.quantization.lifecycle import is_target  # TODO: move to utils
from compressed_tensors.registry.registry import RegistryMixin, T
from compressed_tensors.transform import (
    TransformArgs,
    TransformLocation,
    TransformScheme,
)
from compressed_tensors.utils import (
    align_module_device,
    has_offloaded_params,
    patch_attr,
    register_offload_module,
    update_offload_parameter,
)
from torch import Tensor
from torch.nn import Module, Parameter


__all__ = ["TransformFactory", "TransformBase"]


class TransformFactory(RegistryMixin, ABC):
    """
    Abstract factory base used to create and apply transforms to a model

    :param name: name associated with transform scheme
    :param scheme: transform scheme which defines how transforms should be created
    :param seed: random seed used for transform weight randomization
    """

    def __init__(self, name: str, scheme: TransformScheme, seed: int = 42):
        self.name = name
        self.scheme = scheme
        self.seed = seed

    @classmethod
    def from_scheme(cls: type[T], scheme: TransformScheme, **kwargs) -> T:
        """
        Create a transform factory from a scheme

        :param scheme: defines how transforms should be created
        :param kwargs: TransformFactory constructor arguments
        :return: subclass of `TransformFactory` corresponding to the scheme type
        """
        constructor = cls.get_value_from_registry(name=scheme.type)
        return constructor(scheme=scheme, **kwargs)

    @abstractmethod
    def create_transform(self, module: Module, args: TransformArgs) -> "TransformBase":
        """
        Abstract method which defines how a transform should be created. May utilize
        caching to maximize shared memory

        :param module: parent module that transform will be applied to
        :param args: defines how the transform will be applied to the module
        :return: instance of TransformBase
        """
        raise NotImplementedError()

    def apply_to_model(self, model: Module):
        """
        Create transforms and apply them to the model

        :param model: module to apply transforms to
        """
        for arg in self.scheme.apply:
            for path, module in list(model.named_modules()):
                if is_target(path, module, arg.targets, arg.ignore):
                    self._apply_to_module(module, arg)

    def _apply_to_module(self, module: Module, args: TransformArgs):
        """
        Create transforms and apply them to the module

        :param module: target module to apply transforms to
        :param args: defines how the transform will be applied to the target module
        """
        # create transform as submodule
        transform_name = f"{self.name}_{args.location}"
        transform = self.create_transform(module, args)
        register_offload_module(module, transform_name, transform)  # (1)

        # register input transformation hook
        if args.location == TransformLocation.INPUT:

            def input_hook(_, args):
                input = args[0]
                return transform(input)

            module.register_forward_pre_hook(input_hook)

        # eagerly apply transformation to weight
        elif args.location in (
            TransformLocation.WEIGHT_INPUT,
            TransformLocation.WEIGHT_OUTPUT,
        ):
            assert isinstance(module, torch.nn.Linear)
            assert module.bias is None

            with torch.no_grad(), align_module_device(module):
                update_offload_parameter(module, "weight", transform(module.weight))

            if self.scheme.requires_grad:
                # for training, the weight changes with every forward pass
                # so we can leverage parametrization to propagate the gradient
                if has_offloaded_params(module):
                    raise ValueError("Offloaded training is not supported")
                P.register_parametrization(module, "weight", transform)

        # register output transformation hook
        elif args.location == TransformLocation.OUTPUT:

            def output_hook(_, _input, output):
                return transform(output)

            module.register_forward_hook(output_hook)

        # other locations such as q_attn and k_attn have not been implemented
        else:
            raise NotImplementedError()

        # (1) even in the `weight` cases, this submodule attachment is needed in order
        # to support saving in the frozen state


class TransformBase(Module, ABC):
    """
    Represents the application of a transform according to TransformArgs
    """

    args: TransformArgs
    weight: Parameter

    @abstractmethod
    def forward(self, value: Tensor) -> Tensor:
        raise NotImplementedError()

    def right_inverse(self, value: Tensor) -> Tensor:
        with patch_attr(self.args, "inverse", not self.args.inverse):
            return self.forward(value)

    def __repr__(self):
        return f"{self.__class__.__name__}(inverse={self.args.inverse})"
src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 79 additions & 0 deletions

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
from compressed_tensors.transform import TransformArgs, TransformScheme
from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
from compressed_tensors.transform.utils.utils import (
    apply_transform_weight,
    get_matrix_size,
)
from compressed_tensors.utils import get_offloaded_device
from compressed_tensors.utils.helpers import ParameterizedDefaultDict
from torch import Tensor, device, dtype
from torch.nn import Linear, Module, Parameter


@TransformFactory.register("hadamard")
class HadamardFactory(TransformFactory):
    """
    Factory used to apply Hadamard transforms to a model

    :param name: name associated with transform scheme
    :param scheme: transform scheme which defines how transforms should be created
    :param seed: random seed used for transform weight randomization
    """

    def __init__(self, name: str, scheme: TransformScheme, seed: int = 42):
        super().__init__(name, scheme, seed)
        self.weights = ParameterizedDefaultDict(self._create_weight)

    def create_transform(self, module: Module, args: TransformArgs):
        """
        Create a HadamardTransform for applying to a module. Transforms with the same
        size, dtype, and device are cached

        :param module: parent module that transform will be applied to
        :param args: defines how the transform will be applied to the module
        """
        assert isinstance(module, Linear)
        size = get_matrix_size(module, args.location)
        dtype = module.weight.dtype
        device = get_offloaded_device(module)

        weight = self.weights[size, dtype, device]
        return HadamardTransform(weight, args)

    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
        data = torch.tensor(deterministic_hadamard_matrix(size))  # TODO: seed=self.seed
        data = data.to(dtype=dtype, device=device)
        return Parameter(data, requires_grad=self.scheme.requires_grad)


class HadamardTransform(TransformBase):
    def __init__(self, weight: Parameter, args: TransformArgs):
        super().__init__()
        self.weight = weight
        self.args = args

    def forward(self, value: Tensor) -> Tensor:
        if not self.args.inverse:
            weight = self.weight
        else:
            weight = self.weight.T / self.weight.size(0)

        return apply_transform_weight(weight, value, self.args.location)
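
The inverse branch of `HadamardTransform.forward` leans on the defining property of a ±1 Hadamard matrix: for an n×n matrix H, H·Hᵀ = n·I, so H⁻¹ = Hᵀ/n, which is exactly `self.weight.T / self.weight.size(0)`. A self-contained sanity check of that identity using the classic Sylvester construction (used here instead of `deterministic_hadamard_matrix`, whose normalization is not visible in this diff):

import torch

def sylvester_hadamard(n: int) -> torch.Tensor:
    # classic +/-1 Hadamard construction, valid for n a power of two
    H = torch.ones((1, 1), dtype=torch.float64)
    while H.shape[0] < n:
        H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0)
    return H

n = 64
H = sylvester_hadamard(n)

# H @ (H.T / n) recovers the identity, matching the inverse branch above
assert torch.allclose(H @ (H.T / n), torch.eye(n, dtype=torch.float64))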
src/compressed_tensors/transform/factory/matrix_multiply.py

Lines changed: 86 additions & 0 deletions

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from compressed_tensors.transform import TransformArgs, TransformScheme
from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
from compressed_tensors.transform.utils.utils import (
    apply_transform_weight,
    get_matrix_size,
)
from compressed_tensors.utils import get_offloaded_device
from compressed_tensors.utils.helpers import ParameterizedDefaultDict
from torch import Tensor, device, dtype
from torch.nn import Linear, Module, Parameter


@TransformFactory.register("matrix-mul")
class RandomMatrixFactory(TransformFactory):
    """
    Factory used to apply random matrix transforms to a model

    :param name: name associated with transform scheme
    :param scheme: transform scheme which defines how transforms should be created
    :param seed: random seed used for transform weight randomization
    """

    def __init__(self, name: str, scheme: TransformScheme, seed: int = 42):
        super().__init__(name, scheme, seed)
        self.weights = ParameterizedDefaultDict(self._create_weight)
        self.inverses = ParameterizedDefaultDict(self._create_inverse)

    def create_transform(self, module: Module, args: TransformArgs):
        """
        Create a RandomMatrixTransform for applying to a module. Transforms with the
        same size, dtype, and device are cached

        :param module: parent module that transform will be applied to
        :param args: defines how the transform will be applied to the module
        """
        assert isinstance(module, Linear)
        size = get_matrix_size(module, args.location)
        dtype = module.weight.dtype
        device = get_offloaded_device(module)

        if not args.inverse:
            weight = self.weights[size, dtype, device]
        else:
            weight = self.inverses[size, dtype, device]
        return RandomMatrixTransform(weight, args)

    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
        data = torch.rand((size, size), dtype=dtype, device=device)
        return Parameter(data, requires_grad=self.scheme.requires_grad)

    def _create_inverse(self, size: int, dtype: dtype, device: device) -> Parameter:
        weight = self.weights[size, dtype, device]
        return Parameter(high_precision_invert(weight.data), requires_grad=False)


class RandomMatrixTransform(TransformBase):
    def __init__(self, weight: Tensor, args: TransformArgs):
        super().__init__()
        self.weight = weight  # is an inverse if args.inverse
        self.args = args

    def forward(self, value: Tensor) -> Parameter:
        return apply_transform_weight(self.weight, value, self.args.location)

    def right_inverse(self, value: Tensor) -> Tensor:
        inverse = high_precision_invert(self.weight)
        return apply_transform_weight(inverse, value, self.args.location)


def high_precision_invert(weight: Tensor) -> Tensor:
    return torch.linalg.inv(weight.to(torch.float32)).to(weight.dtype)
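
Unlike the Hadamard case there is no cheap closed-form inverse here, so the factory keeps two caches: `weights` for the random matrices and `inverses` for their inverses, with the inversion performed in float32 and cast back to guard against low-precision error. A short sketch of that round trip (the helper is reproduced locally so the snippet stands alone):

import torch

def high_precision_invert(weight: torch.Tensor) -> torch.Tensor:
    # same approach as the helper above: invert in float32, cast back to the original dtype
    return torch.linalg.inv(weight.to(torch.float32)).to(weight.dtype)

W = torch.rand((8, 8), dtype=torch.float16)
W_inv = high_precision_invert(W)

# the product should stay close to the identity despite half-precision storage
err = (W.float() @ W_inv.float() - torch.eye(8)).abs().max()
print(f"max deviation from identity: {err:.4f}")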
src/compressed_tensors/transform/factory/random_hadamard.py

Lines changed: 38 additions & 0 deletions

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from compressed_tensors.transform import HadamardFactory, TransformFactory
from compressed_tensors.transform.utils.hadamard import random_hadamard_matrix
from torch import device, dtype
from torch.nn import Parameter


@TransformFactory.register("random-hadamard")
class RandomHadamardFactory(HadamardFactory):
    """
    Factory used to apply random Hadamard transforms to a model

    :param name: name associated with transform scheme
    :param scheme: transform scheme which defines how transforms should be created
    :param seed: random seed used for transform weight randomization
    """

    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
        for key in self.weights.keys():
            if key[0] == size:
                return self.weights[key].to(dtype=dtype, device=device)

        data = random_hadamard_matrix(size)  # seed
        data = data.to(dtype=dtype, device=device)
        return Parameter(data, requires_grad=self.scheme.requires_grad)
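
One subtlety in the override above: the cache lookup matches on `size` alone and reuses an existing entry regardless of the dtype or device it was created for, casting on return. In effect each matrix size is only sampled once per factory, so every module sharing that size shares the same random Hadamard. A hypothetical stand-in for `ParameterizedDefaultDict` (whose exact interface is assumed from its use above) illustrates the behaviour:

import torch

weights = {}  # stand-in cache keyed by (size, dtype, device)

def get_weight(size, dtype, device):
    # reuse any cached matrix of the same size, whatever dtype/device it was built for
    for key, value in weights.items():
        if key[0] == size:
            return value.to(dtype=dtype, device=device)
    data = torch.randn(size, size)  # placeholder for random_hadamard_matrix(size)
    weights[(size, data.dtype, data.device)] = data
    return data.to(dtype=dtype, device=device)

a = get_weight(4, torch.float16, torch.device("cpu"))
b = get_weight(4, torch.float32, torch.device("cpu"))
assert torch.allclose(a.float(), b, atol=1e-2)  # same underlying matrix, different dtype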

src/compressed_tensors/transform/transform_args.py

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@
 from pydantic import BaseModel, Field, field_validator


-__all__ = ["TransformArgs"]
+__all__ = ["TransformLocation", "TransformArgs"]


 class TransformLocation(str, Enum):
