Commit 180226b
[Transform] Implement multi-headed transforms (#383)
* add utilities
* add tests
* add additional tests
* add utils and tests
* Implement transform factories
* add permutations
* add delete_offload_module
* key inverses by weight
* fix tests
* standardize random hadamard
* prepend input hooks
* apply sqrt division first
* use divided hadamards
* fix typo
* add random option
* use random seeds, rename matrix multiply
* add deterministic generation to random matrix
* fix perm math
* update docstrings
* update docstrings
* cleanup
* cleanup 2
* make seed optional
* remove iterable check and missing return value
* Remove unrelated changes
* simplify code
* implement apply, use in tests
* use hadamards database file
* try manifest
* try setup, update hadamards list
* fix setup
* add docstrings, cleanup
* fix setup, thank you @dbarbuzzi
* remove numpy, add tests
* solidify dtype, add gpu tests
* fix docstring
* add device option
* construct on execution device, cache on offload device
* save construction device changes for later
* construct on execution device, cache on offload device
* cite nja sloane
* remove dreg
* put on device via safe_open
* nits and docstrings
* update docstring
* Merge
* merge with construct: construct in float32
* construct with same dtype, constructing on fp32 found no difference
* remove unnecessary imports
* bugfixes (#375)
* use factory_kwargs
* add frozen dict to deps
* fix style
* merge
* use delete_offload_module
* add docstrign
* use parametrize
* remove random from tests
* implement num_heads
* implement head dim
* add more tests
* clean up reshaping
* code cleanup and simplification
* undo dtype changes
* simplify tests
* rename function
* add docstring
* refactor lambdas to _multihead_matmul function
* multihead_matmul bugfix
* support embeddings (#385)
* more unit test parameterizations

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
Co-authored-by: Brian Dellabetta <brian-dellabetta@users.noreply.github.com>
Co-authored-by: Brian Dellabetta <bdellabe@redhat.com>
1 parent ecbe770 commit 180226b

8 files changed: +363 additions, -119 deletions

src/compressed_tensors/transform/factory/base.py

Lines changed: 1 addition & 3 deletions

@@ -117,10 +117,8 @@ def input_hook(_, args):
             TransformLocation.WEIGHT_INPUT,
             TransformLocation.WEIGHT_OUTPUT,
         ):
-            assert isinstance(module, torch.nn.Linear)
-            assert module.bias is None
-
             # fuse transform into weight
+            assert hasattr(module, "weight")
             with torch.no_grad(), align_module_device(module):
                 update_offload_parameter(module, "weight", transform(module.weight))
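
A quick sketch (not from the commit) of why fusing into `module.weight` is sufficient here: pairing an orthogonal input-side transform with its inverse fused into the weight at WEIGHT_INPUT leaves a bias-free Linear unchanged. It uses `apply_transform_weight` from the new `transform/utils/matrix.py` shown later in this diff; the sizes are arbitrary.

import torch
from compressed_tensors.transform import TransformLocation
from compressed_tensors.transform.utils.matrix import apply_transform_weight

torch.manual_seed(0)
linear = torch.nn.Linear(8, 4, bias=False)
x = torch.randn(2, 8)

# any orthogonal V works; QR of a random matrix gives one
V, _ = torch.linalg.qr(torch.randn(8, 8))

# transform the activation on the input side, fuse the inverse (V.T) into the weight
x_t = apply_transform_weight(V, x, TransformLocation.INPUT, torch.nn.Linear)
W_t = apply_transform_weight(V.T, linear.weight, TransformLocation.WEIGHT_INPUT, torch.nn.Linear)

# y = x @ W.T is unchanged by the paired transforms
assert torch.allclose(x @ linear.weight.T, x_t @ W_t.T, atol=1e-5)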

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 15 additions & 8 deletions

@@ -19,9 +19,9 @@
 from compressed_tensors.transform import TransformArgs, TransformScheme
 from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
 from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
-from compressed_tensors.transform.utils.utils import (
+from compressed_tensors.transform.utils.matrix import (
     apply_transform_weight,
-    get_matrix_size,
+    get_transform_size,
 )
 from compressed_tensors.utils import get_execution_device, get_offloaded_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
@@ -52,16 +52,16 @@ def create_transform(self, module: Module, args: TransformArgs):
         :param module: parent module that transform will be applied to
         :param args: defines how the transform will be applied to the module
         """
-        assert isinstance(module, Linear)
-        size = get_matrix_size(module, args.location)
+        assert hasattr(module, "weight")
+        size = get_transform_size(module, args.location, self.scheme.head_dim)
         dtype = module.weight.dtype
         device = get_offloaded_device(module)
         exec_device = get_execution_device(module)
 
         factory_kwargs = {"construct_device": exec_device}
         weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
         perm = self.perms[weight] if self.scheme.randomize else None
-        return HadamardTransform(weight, perm, args)
+        return HadamardTransform(weight, perm, args, type(module))
 
     def _create_weight(
         self,
@@ -82,12 +82,17 @@ def _create_permutation(self, weight: Parameter) -> Parameter:
 
 class HadamardTransform(TransformBase):
     def __init__(
-        self, weight: Parameter, perm: Union[Parameter, None], args: TransformArgs
+        self,
+        weight: Parameter,
+        perm: Optional[Parameter],
+        args: TransformArgs,
+        module_type: type[torch.nn.Module],
     ):
         super().__init__()
         self.weight = weight
         self.perm = perm
         self.args = args
+        self.module_type = module_type
         self._scale = math.sqrt(weight.size(0))
 
     def forward(self, value: Tensor) -> Tensor:
@@ -98,5 +103,7 @@ def forward(self, value: Tensor) -> Tensor:
 
         if self.args.inverse:
             weight = weight.T
-
-        return apply_transform_weight(weight, value, self.args.location) / self._scale
+
+        return apply_transform_weight(
+            weight, value, self.args.location, self.module_type
+        ) / self._scale
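
The division by `self._scale` reflects that an unnormalized Hadamard matrix H of size n satisfies H @ H.T = n * I, so H / sqrt(n) is orthonormal and its inverse is its transpose. A minimal sketch (Sylvester construction assumed here for brevity, not the library's `deterministic_hadamard_matrix`):

import torch

# unnormalized Hadamard: H @ H.T = n * I
H2 = torch.tensor([[1.0, 1.0], [1.0, -1.0]])
H4 = torch.kron(H2, H2)          # Sylvester construction, n = 4
Hn = H4 / (4 ** 0.5)             # same normalization as `/ self._scale`

assert torch.allclose(Hn @ Hn.T, torch.eye(4))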

src/compressed_tensors/transform/factory/matrix_multiply.py

Lines changed: 18 additions & 8 deletions

@@ -17,9 +17,9 @@
 import torch
 from compressed_tensors.transform import TransformArgs, TransformScheme
 from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
-from compressed_tensors.transform.utils.utils import (
+from compressed_tensors.transform.utils.matrix import (
     apply_transform_weight,
-    get_matrix_size,
+    get_transform_size,
 )
 from compressed_tensors.utils import get_offloaded_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
@@ -50,16 +50,16 @@ def create_transform(self, module: Module, args: TransformArgs):
         :param module: parent module that transform will be applied to
         :param args: defines how the transform will be applied to the module
         """
-        assert isinstance(module, Linear)
-        size = get_matrix_size(module, args.location)
+        assert hasattr(module, "weight")
+        size = get_transform_size(module, args.location, self.scheme.head_dim)
         dtype = module.weight.dtype
         device = get_offloaded_device(module)
 
         weight = self.weights[size, dtype, device]
         if args.inverse:
             weight = self.inverses[weight]
 
-        return RandomMatrixTransform(weight, args)
+        return RandomMatrixTransform(weight, args, type(module))
 
     def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
         # TODO: verify that weight is invertible (has non-zero determinant)
@@ -74,17 +74,27 @@ def _create_inverse(self, weight: Parameter) -> Parameter:
 
 
 class RandomMatrixTransform(TransformBase):
-    def __init__(self, weight: Tensor, args: TransformArgs):
+    def __init__(
+        self,
+        weight: Tensor,
+        args: TransformArgs,
+        module_type: type[torch.nn.Module],
+    ):
         super().__init__()
         self.weight = weight  # is an inverse if args.inverse
         self.args = args
+        self.module_type = module_type
 
     def forward(self, value: Tensor) -> Parameter:
-        return apply_transform_weight(self.weight, value, self.args.location)
+        return apply_transform_weight(
+            self.weight, value, self.args.location, self.module_type
+        )
 
     def right_inverse(self, value: Tensor) -> Tensor:
         inverse = high_precision_invert(self.weight)
-        return apply_transform_weight(inverse, value, self.args.location)
+        return apply_transform_weight(
+            inverse, value, self.args.location, self.module_type
+        )
 
 
 def high_precision_invert(weight: Tensor) -> Tensor:
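
A hedged sketch of the `forward` / `right_inverse` round trip, with `torch.linalg.inv` standing in for `high_precision_invert` (whose body is not part of this diff); the sizes are arbitrary:

import torch
from compressed_tensors.transform import TransformLocation
from compressed_tensors.transform.utils.matrix import apply_transform_weight

torch.manual_seed(0)
R = torch.randn(8, 8, dtype=torch.float64)   # random (almost surely invertible) transform
x = torch.randn(2, 8, dtype=torch.float64)

# forward applies R at the INPUT location, i.e. x @ R
y = apply_transform_weight(R, x, TransformLocation.INPUT, torch.nn.Linear)
# applying R^-1 at the same location recovers the original value
x_back = apply_transform_weight(torch.linalg.inv(R), y, TransformLocation.INPUT, torch.nn.Linear)

assert torch.allclose(x, x_back)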

src/compressed_tensors/transform/transform_scheme.py

Lines changed: 2 additions & 1 deletion

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List
+from typing import List, Optional
 
 from compressed_tensors.transform import TransformArgs
 from pydantic import BaseModel, Field
@@ -40,3 +40,4 @@ class TransformScheme(BaseModel):
     apply: List[TransformArgs] = Field(default_factory=list)
     randomize: bool = Field(default=False)
     requires_grad: bool = Field(default=False)
+    head_dim: Optional[int] = Field(default=None)
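
The new `head_dim` field is consumed by `get_transform_size` in the `transform/utils/matrix.py` module added below: when set, each head gets its own (head_dim x head_dim) transform instead of one transform spanning the full feature dimension. A minimal sketch of the effect, with 4096 and 128 chosen purely for illustration:

import torch
from compressed_tensors.transform import TransformLocation
from compressed_tensors.transform.utils.matrix import get_transform_size

linear = torch.nn.Linear(4096, 4096, bias=False)

# without head_dim, the transform spans the whole input dimension
full = get_transform_size(linear, TransformLocation.WEIGHT_INPUT)
# with head_dim, the transform is sized per head (head_dim must divide 4096)
per_head = get_transform_size(linear, TransformLocation.WEIGHT_INPUT, head_dim=128)

assert (full, per_head) == (4096, 128)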
src/compressed_tensors/transform/utils/matrix.py (new file)

Lines changed: 179 additions & 0 deletions

@@ -0,0 +1,179 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Optional, Tuple
+
+import torch
+from compressed_tensors.transform import TransformLocation
+
+
+__all__ = ["get_transform_size", "apply_transform_weight"]
+
+
+def get_transform_size(
+    module: torch.nn.Module,
+    location: TransformLocation,
+    head_dim: Optional[int] = None,
+) -> int:
+    """
+    Determine the size of a transform matrix given its location on the module
+
+    :param module: module that matrix will be applied to
+    :param location: location on module
+    :param head_dim: size of head when transform is applied to mha
+    :return: size of matrix
+    """
+    if isinstance(module, torch.nn.Linear):
+        if location in (TransformLocation.INPUT, TransformLocation.WEIGHT_INPUT):
+            size = module.in_features
+        else:
+            size = module.out_features
+    elif isinstance(module, torch.nn.Embedding):
+        if location in (TransformLocation.INPUT, TransformLocation.WEIGHT_INPUT):
+            size = module.num_embeddings
+        else:
+            size = module.embedding_dim
+    else:
+        raise NotImplementedError(f"Transforms on {type(module)} are not supported")
+
+    if head_dim is not None:
+        if size % head_dim != 0:
+            raise ValueError(
+                f"{head_dim} must divide {size} for {type(module)} at {location}"
+            )
+
+        size = head_dim
+
+    return size
+
+
+def apply_transform_weight(
+    transform_weight: torch.Tensor,
+    value: torch.Tensor,
+    location: TransformLocation,
+    module_type: type[torch.nn.Module],
+) -> torch.Tensor:
+    """
+    Using the transform location, apply the transform_weight to the
+    given value wrt linear weights. For more info on input and output transforms,
+    see `TransformLocation`
+
+    The following explains how weights should be applied to values according to location
+
+    let  x  be input activation
+         W  be weight,
+         yh, xh, Wh  be transformed output, input, weight
+
+    note that
+        y = (x W.T)    // torch.nn.Linear
+
+    Choose values for yh, xh, and Wh which incorporate matrix transforms
+
+    let  V, Vi  be transform matrices on input side
+         U, Ui  be transform matrices on output side
+
+    pick  xh = (x V)
+          Wh = (U.T W Vi.T)
+          yh = (y U)
+
+    The following shows that `yh = (xh) (Wh).T` for the chosen values of yh, xh, and Wh
+
+    (xh) (Wh).T = (x V) (U.T W Vi.T).T
+                = (x V) (Vi W.T U)    // transpose matrix product identity
+                = (x W.T) U
+                = y U
+                = yh
+
+    :param transform_weight: transform weight to apply
+    :param value: value to apply transform_weight to
+    :param location: determines how weight should be applied
+    :param module_type: result of type(module), passed in to determine application of
+        weight transform
+    :return: value after transform_weight has been applied
+    """
+
+    assert transform_weight.shape[0] == transform_weight.shape[1]
+
+    if module_type == torch.nn.Linear:
+        if location == TransformLocation.INPUT:
+            return _multihead_matmul(value, transform_weight)
+
+        elif location == TransformLocation.WEIGHT_INPUT:
+            # equivalent to (transform_weight @ value.T).T
+            return _multihead_matmul(value, transform_weight.T)
+
+        elif location == TransformLocation.WEIGHT_OUTPUT:
+            # equivalent to (value.T @ transform_weight).T
+            return _multihead_matmul(transform_weight.T, value)

+        elif location == TransformLocation.OUTPUT:
+            return _multihead_matmul(value, transform_weight)
+
+    # similar derivation to torch.nn.Linear, but `y = (x W)`
+    elif module_type == torch.nn.Embedding:
+        if location == TransformLocation.INPUT:
+            return _multihead_matmul(value, transform_weight)
+
+        elif location == TransformLocation.WEIGHT_INPUT:
+            return _multihead_matmul(
+                transform_weight,
+                value,
+            )
+
+        elif location == TransformLocation.WEIGHT_OUTPUT:
+            return _multihead_matmul(value, transform_weight)
+
+        elif location == TransformLocation.OUTPUT:
+            return _multihead_matmul(value, transform_weight)
+
+    raise NotImplementedError(
+        f"Applying transforms to {module_type} {location} is not supported"
+    )
+
+
+def _multihead_matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
+    """
+    Performs A @ B for last two dims of two matrices A and B that possibly
+    have different shapes, as is the case in multi-headed dimension. If
+    shapes are different, this is equivalent to converting the last two dims
+    of the smaller matrix into a block-diagonal matrix with the same shape as
+    the last two dims of the larger matrix.
+
+    E.g. if A is half the size of B, this function will perform
+        [[A    ]    @   B
+         [    A]]
+
+    If B is a third of the size of A, this function will perform
+        A   @   [[B      ]
+                 [   B   ]
+                 [      B]]
+
+    This function will error out if the shapes are not evenly divisible
+
+    :param A: left-hand tensor
+    :param B: right-hand tensor
+    :return: result
+    """
+    if A.shape[-1] > B.shape[-2]:
+        head_dim = B.shape[-2]
+        num_heads = A.shape[-1] // head_dim
+        A = A.unflatten(-1, (num_heads, head_dim))
+        return (A @ B).flatten(-2, -1)
+    elif A.shape[-1] < B.shape[-2]:
+        head_dim = A.shape[-1]
+        num_heads = B.shape[-2] // head_dim
+        B = B.unflatten(-2, (num_heads, head_dim))
+        return (A @ B).flatten(-3, -2)
+    else:
+        return A @ B
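
To make the block-diagonal equivalence in `_multihead_matmul`'s docstring concrete, a small check (not part of the commit; it imports the private helper only for illustration) comparing per-head multiplication against an explicit `torch.block_diag` construction when A is wider than B:

import torch
from compressed_tensors.transform.utils.matrix import _multihead_matmul

torch.manual_seed(0)
head_dim, num_heads = 4, 3
A = torch.randn(2, num_heads * head_dim)      # e.g. an activation of width 12
B = torch.randn(head_dim, head_dim)           # one per-head transform

# reference: A @ block_diag(B, B, B), i.e. B applied to each 4-wide chunk of A
reference = A @ torch.block_diag(*([B] * num_heads))

assert torch.allclose(_multihead_matmul(A, B), reference, atol=1e-6)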
