
Commit 4b81ac7

[Transform] Factory classes with shared memory and offloading (#316)
* add utilities
* add tests
* add additional tests
* add utils and tests
* Implement transform factories
* add delete_offload_module
* key inverses by weight
* fix tests
* standardize random hadamard
* prepend input hooks
* apply sqrt division first
* use divided hadamards
* fix typo
* add random option
* use random seeds, rename matrix multiply
* add deterministic generation to random matrix
* update docstrings
* update docstrings
* make seed optional

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent cef3b60 commit 4b81ac7


8 files changed: +613 -0 lines changed


src/compressed_tensors/transform/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -18,3 +18,8 @@
from .transform_args import *
from .transform_scheme import *
from .transform_config import *

from .factory.base import *
from .factory.hadamard import *
from .factory.matrix_multiply import *
from .factory.random_hadamard import *
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
src/compressed_tensors/transform/factory/base.py

Lines changed: 164 additions & 0 deletions
@@ -0,0 +1,164 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Optional

import torch
import torch.nn.utils.parametrize as P
from compressed_tensors.quantization.lifecycle import is_target  # TODO: move to utils
from compressed_tensors.registry.registry import RegistryMixin, T
from compressed_tensors.transform import (
    TransformArgs,
    TransformLocation,
    TransformScheme,
)
from compressed_tensors.utils import (
    align_module_device,
    has_offloaded_params,
    patch_attr,
    register_offload_module,
    update_offload_parameter,
)
from torch import Tensor
from torch.nn import Module, Parameter


__all__ = ["TransformFactory", "TransformBase"]


class TransformFactory(RegistryMixin, ABC):
    """
    Abstract factory base used to create and apply transforms to a model

    :param name: name associated with transform scheme
    :param scheme: transform scheme which defines how transforms should be created
    :param seed: random seed used for transform weight randomization
    """

    def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
        self.name = name
        self.scheme = scheme
        self.generator = torch.Generator()
        if seed is not None:
            self.generator.manual_seed(seed)

    @classmethod
    def from_scheme(cls: type[T], scheme: TransformScheme, **kwargs) -> T:
        """
        Create a transform factory from a scheme

        :param scheme: defines how transforms should be created
        :param kwargs: TransformFactory constructor arguments
        :return: subclass of `TransformFactory` corresponding to the scheme type
        """
        constructor = cls.get_value_from_registry(name=scheme.type)
        return constructor(scheme=scheme, **kwargs)

    @abstractmethod
    def create_transform(self, module: Module, args: TransformArgs) -> "TransformBase":
        """
        Abstract method which defines how a transform should be created. May utilize
        caching to maximize shared memory

        :param module: parent module that transform will be applied to
        :param args: defines how the transform will be applied to the module
        :return: instance of TransformBase
        """
        raise NotImplementedError()

    def apply_to_model(self, model: Module):
        """
        Create transforms and apply them to the model

        :param model: module to apply transforms to
        """
        for arg in self.scheme.apply:
            for name, module in list(model.named_modules()):
                if is_target(name, module, arg.targets, arg.ignore):
                    self._apply_to_module(module, arg)

    def _apply_to_module(self, module: Module, args: TransformArgs):
        """
        Create transforms and apply them to the module

        :param module: target module to apply transforms to
        :param args: defines how the transform will be applied to the target module
        """
        # create transform as submodule
        transform_name = f"{self.name}_{args.location}"
        transform = self.create_transform(module, args)
        register_offload_module(module, transform_name, transform)  # (1)

        # register input transformation hook
        if args.location == TransformLocation.INPUT:

            def input_hook(_, args):
                input = args[0]
                return transform(input)

            module.register_forward_pre_hook(input_hook, prepend=True)

        # eagerly apply transformation to weight
        elif args.location in (
            TransformLocation.WEIGHT_INPUT,
            TransformLocation.WEIGHT_OUTPUT,
        ):
            assert isinstance(module, torch.nn.Linear)
            assert module.bias is None

            with torch.no_grad(), align_module_device(module):
                update_offload_parameter(module, "weight", transform(module.weight))

            if self.scheme.requires_grad:
                # for training, the weight changes with every forward pass
                # so we can leverage parametrization to propagate the gradient
                if has_offloaded_params(module):
                    raise ValueError("Offloaded training is not supported")
                P.register_parametrization(module, "weight", transform)

        # register output transformation hook
        elif args.location == TransformLocation.OUTPUT:

            def output_hook(_, _input, output):
                return transform(output)

            module.register_forward_hook(output_hook)

        # other locations such as q_attn and k_attn have not been implemented
        else:
            raise NotImplementedError()

    # (1) even in the `weight` cases, this submodule attachment is needed in order
    # to support saving in the frozen state


class TransformBase(Module, ABC):
    """
    Represents the application of a transform according to TransformArgs
    """

    args: TransformArgs
    weight: Parameter

    @abstractmethod
    def forward(self, value: Tensor) -> Tensor:
        raise NotImplementedError()

    def right_inverse(self, value: Tensor) -> Tensor:
        with patch_attr(self.args, "inverse", not self.args.inverse):
            return self.forward(value)

    def __repr__(self):
        return f"{self.__class__.__name__}(inverse={self.args.inverse})"
src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
from compressed_tensors.transform import TransformArgs, TransformScheme
from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
from compressed_tensors.transform.utils.utils import (
    apply_transform_weight,
    get_matrix_size,
)
from compressed_tensors.utils import get_offloaded_device
from compressed_tensors.utils.helpers import ParameterizedDefaultDict
from torch import Tensor, device, dtype
from torch.nn import Linear, Module, Parameter


@TransformFactory.register("hadamard")
class HadamardFactory(TransformFactory):
    """
    Factory used to apply hadamard transforms to a model

    :param name: name associated with transform scheme
    :param scheme: transform scheme which defines how transforms should be created
    :param seed: random seed used for transform weight randomization
    """

    def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
        super().__init__(name, scheme, seed)
        self.weights = ParameterizedDefaultDict(self._create_weight)

    def create_transform(self, module: Module, args: TransformArgs):
        """
        Create a HadamardTransform for applying to a module. Transforms with the same
        size, dtype, and device are cached

        :param module: parent module that transform will be applied to
        :param args: defines how the transform will be applied to the module
        """
        assert isinstance(module, Linear)
        size = get_matrix_size(module, args.location)
        dtype = module.weight.dtype
        device = get_offloaded_device(module)

        weight = self.weights[size, dtype, device]
        return HadamardTransform(weight, args)

    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
        data = deterministic_hadamard_matrix(size)
        data = data.to(dtype=dtype, device=device)
        return Parameter(data, requires_grad=self.scheme.requires_grad)


class HadamardTransform(TransformBase):
    def __init__(self, weight: Parameter, args: TransformArgs):
        super().__init__()
        self.weight = weight
        self.args = args

    def forward(self, value: Tensor) -> Tensor:
        if not self.args.inverse:
            weight = self.weight
        else:
            weight = self.weight.T

        return apply_transform_weight(weight, value, self.args.location)
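Because HadamardFactory keys its ParameterizedDefaultDict on (size, dtype, device), every transform created for modules that share those properties reuses one Parameter, which is the shared-memory behavior named in the commit title. A minimal sketch of that reuse, assuming hypothetical TransformScheme and TransformArgs constructor kwargs:

import torch
from compressed_tensors.transform import HadamardFactory, TransformArgs, TransformScheme

scheme = TransformScheme(type="hadamard")  # kwargs assumed
args = TransformArgs(targets=["Linear"], location="weight_input")  # kwargs assumed

factory = HadamardFactory(name="r1", scheme=scheme)
a = torch.nn.Linear(64, 64, bias=False)
b = torch.nn.Linear(64, 64, bias=False)

t_a = factory.create_transform(a, args)
t_b = factory.create_transform(b, args)

# same (size, dtype, device) cache key -> one shared Hadamard weight tensor
assert t_a.weight is t_b.weight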
src/compressed_tensors/transform/factory/matrix_multiply.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
from compressed_tensors.transform import TransformArgs, TransformScheme
from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
from compressed_tensors.transform.utils.utils import (
    apply_transform_weight,
    get_matrix_size,
)
from compressed_tensors.utils import get_offloaded_device
from compressed_tensors.utils.helpers import ParameterizedDefaultDict
from torch import Tensor, device, dtype
from torch.nn import Linear, Module, Parameter


@TransformFactory.register("random-matrix")
class RandomMatrixFactory(TransformFactory):
    """
    Factory used to apply random matrix transforms to a model

    :param name: name associated with transform scheme
    :param scheme: transform scheme which defines how transforms should be created
    :param seed: random seed used for transform weight randomization
    """

    def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
        super().__init__(name, scheme, seed)
        self.weights = ParameterizedDefaultDict(self._create_weight)
        self.inverses = ParameterizedDefaultDict(self._create_inverse)

    def create_transform(self, module: Module, args: TransformArgs):
        """
        Create a RandomMatrixTransform for applying to a module. Transforms with the
        same size, dtype, and device are cached

        :param module: parent module that transform will be applied to
        :param args: defines how the transform will be applied to the module
        """
        assert isinstance(module, Linear)
        size = get_matrix_size(module, args.location)
        dtype = module.weight.dtype
        device = get_offloaded_device(module)

        weight = self.weights[size, dtype, device]
        if args.inverse:
            weight = self.inverses[weight]

        return RandomMatrixTransform(weight, args)

    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
        data = torch.rand(
            (size, size), generator=self.generator, dtype=dtype, device=device
        )
        return Parameter(data, requires_grad=self.scheme.requires_grad)

    def _create_inverse(self, weight: Parameter) -> Parameter:
        data = high_precision_invert(weight.data)
        return Parameter(data, requires_grad=False)


class RandomMatrixTransform(TransformBase):
    def __init__(self, weight: Tensor, args: TransformArgs):
        super().__init__()
        self.weight = weight  # is an inverse if args.inverse
        self.args = args

    def forward(self, value: Tensor) -> Tensor:
        return apply_transform_weight(self.weight, value, self.args.location)

    def right_inverse(self, value: Tensor) -> Tensor:
        inverse = high_precision_invert(self.weight)
        return apply_transform_weight(inverse, value, self.args.location)


def high_precision_invert(weight: Tensor) -> Tensor:
    return torch.linalg.inv(weight.to(torch.float32)).to(weight.dtype)
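The inverse handling above inverts in float32 and casts back, so a transform followed by its inverse recovers the original values up to conditioning error. A small numeric illustration of that idea using plain matmuls; the per-location orientation that apply_transform_weight would choose is not modeled here, and the sizes are arbitrary stand-ins.

import torch

torch.manual_seed(0)
weight = torch.rand(64, 64)  # stand-in for a cached random-matrix weight
value = torch.rand(8, 64)    # stand-in for a value being transformed

# same recipe as high_precision_invert: invert in float32, cast back to weight dtype
inverse = torch.linalg.inv(weight.to(torch.float32)).to(weight.dtype)

recovered = (value @ weight) @ inverse
print("max reconstruction error:", (recovered - value).abs().max().item())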
src/compressed_tensors/transform/factory/random_hadamard.py

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from compressed_tensors.transform import HadamardFactory, TransformFactory
from compressed_tensors.transform.utils.hadamard import random_hadamard_matrix
from torch import device, dtype
from torch.nn import Parameter


@TransformFactory.register("random-hadamard")
class RandomHadamardFactory(HadamardFactory):
    """
    Factory used to apply random hadamard transforms to a model

    :param name: name associated with transform scheme
    :param scheme: transform scheme which defines how transforms should be created
    :param seed: random seed used for transform weight randomization
    """

    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
        data = random_hadamard_matrix(size, self.generator)
        data = data.to(dtype=dtype, device=device)
        return Parameter(data, requires_grad=self.scheme.requires_grad)
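Each subclass registers itself under a string name, so TransformFactory.from_scheme resolves the right factory from scheme.type alone; the optional seed threaded through the base constructor makes the random Hadamard and random matrix weights reproducible. A sketch of the dispatch, again with assumed TransformScheme kwargs:

from compressed_tensors.transform import TransformFactory, TransformScheme

for type_name in ("hadamard", "random-hadamard", "random-matrix"):
    scheme = TransformScheme(type=type_name)  # other kwargs assumed/defaulted
    factory = TransformFactory.from_scheme(scheme, name="demo", seed=42)
    print(type_name, "->", type(factory).__name__)

# expected, per the register() calls in this commit:
#   hadamard        -> HadamardFactory
#   random-hadamard -> RandomHadamardFactory
#   random-matrix   -> RandomMatrixFactory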
