Skip to content

Commit 492218a

Browse files
committed
implement num_heads
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 853ffcf commit 492218a

File tree

4 files changed

+53
-16
lines changed

4 files changed

+53
-16
lines changed

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from compressed_tensors.transform import TransformArgs, TransformScheme
1919
from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
2020
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
21-
from compressed_tensors.transform.utils.utils import (
21+
from compressed_tensors.transform.utils.matrix import (
2222
apply_transform_weight,
2323
get_matrix_size,
2424
)
@@ -52,15 +52,16 @@ def create_transform(self, module: Module, args: TransformArgs):
5252
:param args: defines how the transform will be applied to the module
5353
"""
5454
assert isinstance(module, Linear)
55-
size = get_matrix_size(module, args.location)
55+
num_heads = self.scheme.num_heads
56+
size = get_matrix_size(module, args.location, num_heads)
5657
dtype = module.weight.dtype
5758
device = get_offloaded_device(module)
5859
exec_device = get_execution_device(module)
5960

6061
factory_kwargs = {"construct_device": exec_device}
6162
weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
6263
perm = self.perms[weight] if self.scheme.randomize else None
63-
return HadamardTransform(weight, perm, args)
64+
return HadamardTransform(weight, perm, args, num_heads)
6465

6566
def _create_weight(
6667
self,
@@ -81,12 +82,17 @@ def _create_permutation(self, weight: Parameter) -> Parameter:
8182

8283
class HadamardTransform(TransformBase):
8384
def __init__(
84-
self, weight: Parameter, perm: Union[Parameter, None], args: TransformArgs
85+
self,
86+
weight: Parameter,
87+
perm: Optional[Parameter],
88+
args: TransformArgs,
89+
num_heads: Optional[int],
8590
):
8691
super().__init__()
8792
self.weight = weight
8893
self.perm = perm
8994
self.args = args
95+
self.num_heads = num_heads
9096

9197
def forward(self, value: Tensor) -> Tensor:
9298
weight = self.weight
@@ -97,4 +103,4 @@ def forward(self, value: Tensor) -> Tensor:
97103
if self.args.inverse:
98104
weight = weight.T
99105

100-
return apply_transform_weight(weight, value, self.args.location)
106+
return apply_transform_weight(weight, value, self.args.location, self.num_heads)

src/compressed_tensors/transform/factory/matrix_multiply.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import torch
1818
from compressed_tensors.transform import TransformArgs, TransformScheme
1919
from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
20-
from compressed_tensors.transform.utils.utils import (
20+
from compressed_tensors.transform.utils.matrix import (
2121
apply_transform_weight,
2222
get_matrix_size,
2323
)
@@ -51,15 +51,16 @@ def create_transform(self, module: Module, args: TransformArgs):
5151
:param args: defines how the transform will be applied to the module
5252
"""
5353
assert isinstance(module, Linear)
54-
size = get_matrix_size(module, args.location)
54+
num_heads = self.scheme.num_heads
55+
size = get_matrix_size(module, args.location, num_heads)
5556
dtype = module.weight.dtype
5657
device = get_offloaded_device(module)
5758

5859
weight = self.weights[size, dtype, device]
5960
if args.inverse:
6061
weight = self.inverses[weight]
6162

62-
return RandomMatrixTransform(weight, args)
63+
return RandomMatrixTransform(weight, args, num_heads)
6364

6465
def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
6566
# TODO: verify that weight is invertible (has non-zero determinant)
@@ -74,17 +75,22 @@ def _create_inverse(self, weight: Parameter) -> Parameter:
7475

7576

7677
class RandomMatrixTransform(TransformBase):
    """
    Transform module which applies a matrix weight to a value at the location
    described by its transform args

    :param weight: matrix to apply; this is already the inverse matrix when
        `args.inverse` is set
    :param args: defines how the transform is applied; `args.location` is passed
        to `apply_transform_weight` on every call
    :param num_heads: passed through to `apply_transform_weight`; `None` means the
        weight is applied as-is
    """

    def __init__(self, weight: Tensor, args: TransformArgs, num_heads: Optional[int]):
        super().__init__()
        self.weight = weight  # is an inverse if args.inverse
        self.args = args
        self.num_heads = num_heads

    def forward(self, value: Tensor) -> Tensor:
        """
        Apply the stored weight to `value`

        :param value: tensor to transform
        :return: result of `apply_transform_weight` for the stored weight, location
            and number of heads
        """
        return apply_transform_weight(
            self.weight, value, self.args.location, self.num_heads
        )

    def right_inverse(self, value: Tensor) -> Tensor:
        """
        Apply the inverse of the stored weight to `value`

        The inverse is recomputed from `self.weight` with `high_precision_invert`
        on each call and applied at the same location as `forward`.

        :param value: tensor to transform
        :return: result of `apply_transform_weight` for the inverted weight
        """
        inverse = high_precision_invert(self.weight)
        return apply_transform_weight(
            inverse, value, self.args.location, self.num_heads
        )
8894

8995

9096
def high_precision_invert(weight: Tensor) -> Tensor:

src/compressed_tensors/transform/transform_scheme.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import List
15+
from typing import List, Optional
1616

1717
from compressed_tensors.transform import TransformArgs
1818
from pydantic import BaseModel, Field
@@ -40,3 +40,4 @@ class TransformScheme(BaseModel):
4040
apply: List[TransformArgs] = Field(default_factory=list)
4141
randomize: bool = Field(default=False)
4242
requires_grad: bool = Field(default=False)
43+
num_heads: Optional[int] = Field(default=None)

src/compressed_tensors/transform/utils/utils.py renamed to src/compressed_tensors/transform/utils/matrix.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,20 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from typing import Optional
16+
1517
import torch
1618
from compressed_tensors.transform import TransformLocation
1719

1820

1921
__all__ = ["get_matrix_size", "apply_transform_weight"]
2022

2123

22-
def get_matrix_size(module: torch.nn.Module, location: TransformLocation) -> int:
24+
def get_matrix_size(
    module: torch.nn.Module,
    location: TransformLocation,
    num_heads: Optional[int] = None,
) -> int:
    """
    Determine the size of a matrix given its location on the module

    :param module: module the matrix will be applied to; must be a
        `torch.nn.Linear`
    :param location: location of the transform on the module. Input-side
        locations select `in_features`, every other location selects
        `out_features`
    :param num_heads: if provided, the selected feature dimension is split evenly
        across this many heads and the per-head size is returned
    :return: size of matrix
    :raises ValueError: if `num_heads` is provided but is not positive
    """
    assert isinstance(module, torch.nn.Linear)

    if location in ("input", TransformLocation.WEIGHT_INPUT):
        size = module.in_features
    else:
        size = module.out_features

    if num_heads is None:
        return size

    # Reject zero/negative head counts up front: zero would otherwise surface as
    # a bare ZeroDivisionError and a negative count would yield a negative size
    if num_heads <= 0:
        raise ValueError(f"num_heads must be a positive integer, got {num_heads}")

    assert (
        size % num_heads == 0
    ), f"matrix size {size} is not divisible by num_heads {num_heads}"

    return size // num_heads
3548

3649

3750
def apply_transform_weight(
    weight: torch.Tensor,
    value: torch.Tensor,
    location: TransformLocation,
    num_heads: Optional[int] = None,
) -> torch.Tensor:
    """
    Apply a transform weight to a value, optionally sharing one per-head weight
    across `num_heads` heads

    When `num_heads` is provided, `weight` is expanded into a block-diagonal
    matrix holding `num_heads` copies of `weight` on its diagonal before being
    applied, so that each head-sized slice of the transformed dimension is
    multiplied by `weight` independently of the other heads.

    :param weight: transform weight; the per-head weight when `num_heads` is
        provided
    :param value: value to apply weight to
    :param location: determines how weight should be applied
    :param num_heads: number of heads sharing `weight`, or `None` to apply
        `weight` unchanged
    :return: value after transform weight has been applied
    """
    if num_heads is not None:
        # Off-diagonal blocks must be zero: filling them with copies of `weight`
        # would mix heads together and make the expanded matrix singular, so an
        # inverse transform could no longer undo it
        weight = torch.block_diag(*([weight] * num_heads))

    return apply_transform_weight_linear(weight, value, location)
60+
61+
62+
def apply_transform_weight_linear(
63+
weight: torch.Tensor,
64+
value: torch.Tensor,
65+
location: TransformLocation,
66+
):
4267
"""
4368
Using the transform location, determine how to apply the transform weight to the
4469
given value. For more info on input and output transforms, see `TransformLocation`
@@ -74,7 +99,6 @@ def apply_transform_weight(
7499
:param location: determines how weight should be applied
75100
:return: value after transform weight has been applied
76101
"""
77-
78102
if location == TransformLocation.INPUT:
79103
return value @ weight
80104

0 commit comments

Comments
 (0)