Skip to content

[Transform] Implement multi-headed transforms #383

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 99 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
99 commits
Select commit Hold shift + click to select a range
d8a10ec
add utilities
kylesayrs May 30, 2025
d2af054
add tests
kylesayrs May 30, 2025
e32d5b5
add additional tests
kylesayrs May 30, 2025
9d0518b
add utils and tests
kylesayrs May 30, 2025
8c5a2d9
Implement transform factories
kylesayrs May 30, 2025
809e367
Merge branch 'kylesayrs/transform_utils' into kylesayrs/transform_fac…
kylesayrs May 30, 2025
8d613b3
add permutations
kylesayrs May 31, 2025
57d171a
add delete_offload_module
kylesayrs May 31, 2025
d77bcef
Merge branch 'kylesayrs/transform-accelerate-utilities' into kylesayr…
kylesayrs May 31, 2025
ab73b43
Merge branch 'kylesayrs/transform-accelerate-utilities' into kylesayr…
kylesayrs May 31, 2025
4b55733
Merge branch 'kylesayrs/transform_factory' into kylesayrs/transform_p…
kylesayrs May 31, 2025
aa7d21b
key inverses by weight
kylesayrs May 31, 2025
6901e02
fix tests
kylesayrs May 31, 2025
47ae9fe
standardize random hadamard
kylesayrs May 31, 2025
34f1343
Merge branch 'kylesayrs/transform_utils' into kylesayrs/transform_fac…
kylesayrs May 31, 2025
1039100
prepend input hooks
kylesayrs May 31, 2025
5677553
Merge remote-tracking branch 'origin' into kylesayrs/transform_utils
kylesayrs Jun 5, 2025
68ec14e
apply sqrt division first
kylesayrs Jun 5, 2025
a62418a
Merge branch 'kylesayrs/transform_utils' into kylesayrs/transform_fac…
kylesayrs Jun 5, 2025
b117523
use divided hadamards
kylesayrs Jun 5, 2025
a46f754
fix typo
kylesayrs Jun 5, 2025
cb1cb52
add random option
kylesayrs Jun 5, 2025
7c02bb2
Merge branch 'kylesayrs/transform_utils' into kylesayrs/transform_fac…
kylesayrs Jun 5, 2025
02af1e9
use random seeds, rename matrix multiply
kylesayrs Jun 5, 2025
f45f3e9
add deterministic generation to random matrix
kylesayrs Jun 5, 2025
7a7abdf
fix perm math
kylesayrs Jun 5, 2025
6e52894
update docstrings
kylesayrs Jun 5, 2025
7230933
update docstrings
kylesayrs Jun 5, 2025
f74fe3e
Merge branch 'kylesayrs/transform_factory' into kylesayrs/transform_p…
kylesayrs Jun 5, 2025
92ddea9
cleanup
kylesayrs Jun 5, 2025
779956f
cleanup 2
kylesayrs Jun 5, 2025
fbd2939
Merge branch 'kylesayrs/transform_utils' into kylesayrs/transform_fac…
kylesayrs Jun 5, 2025
dd72b6a
make seed optional
kylesayrs Jun 5, 2025
4ae491d
Merge branch 'kylesayrs/transform_factory' into kylesayrs/transform_p…
kylesayrs Jun 5, 2025
da19b0f
remove iterable check and missing return value
kylesayrs Jun 9, 2025
7ab17ce
Merge branch 'main' into kylesayrs/transform_permutations
kylesayrs Jun 10, 2025
33df50f
Merge remote-tracking branch 'origin' into kylesayrs/transform_permut…
kylesayrs Jun 10, 2025
6e1ec39
Remove unrelated changes
kylesayrs Jun 10, 2025
938e702
simplify code
kylesayrs Jun 10, 2025
27bc0b3
implement apply, use in tests
kylesayrs Jun 10, 2025
a27db62
use hadamards database file
kylesayrs Jun 11, 2025
ce63955
try manifest
kylesayrs Jun 11, 2025
7ae5863
try setup, update hadamards list
kylesayrs Jun 11, 2025
67675c3
fix setup
kylesayrs Jun 11, 2025
f061db9
add docstrings, cleanup
kylesayrs Jun 11, 2025
4a84ce1
fix setup, thank you @dbarbuzzi
kylesayrs Jun 11, 2025
cde1066
remove numpy, add tests
kylesayrs Jun 11, 2025
1ba6195
solidify dtype, add gpu tests
kylesayrs Jun 11, 2025
c373345
fix docstring
kylesayrs Jun 11, 2025
fbaf47a
add device option
kylesayrs Jun 11, 2025
5a887f4
construct on execution device, cache on offload device
kylesayrs Jun 11, 2025
310fe6d
save construction device changes for later
kylesayrs Jun 11, 2025
b715329
construct on execution device, cache on offload device
kylesayrs Jun 11, 2025
249323c
cite nja sloane
kylesayrs Jun 11, 2025
1823af4
Merge branch 'kylesayrs/extend-hadamard', remote-tracking branch 'ori…
kylesayrs Jun 11, 2025
94a0bf5
Merge remote-tracking branch 'origin' into kylesayrs/extend-hadamard
kylesayrs Jun 11, 2025
cf066e0
Merge branch 'kylesayrs/extend-hadamard' into kylesayrs/transform_con…
kylesayrs Jun 11, 2025
c1a4a34
remove dreg
kylesayrs Jun 11, 2025
5807ee1
put on device via safe_open
kylesayrs Jun 11, 2025
ccb88ed
nits and docstrings
kylesayrs Jun 12, 2025
feba695
update docstring
kylesayrs Jun 12, 2025
c8f6b53
Merge branch 'kylesayrs/extend-hadamard' into kylesayrs/transform_con…
kylesayrs Jun 12, 2025
e7f08e1
Merge branch 'kylesayrs/transform_construct_cache_device' into kylesa…
kylesayrs Jun 12, 2025
75b9307
Merge remote-tracking branch 'origin' into kylesayrs/transform_permut…
kylesayrs Jun 12, 2025
b6a0dd4
Merge remote-tracking branch 'origin' into kylesayrs/transform_constr…
kylesayrs Jun 13, 2025
955f2f5
Merge
kylesayrs Jun 23, 2025
226f367
merge with construct: construct in float32
kylesayrs Jun 23, 2025
9745acb
Merge remote-tracking branch 'origin' into kylesayrs/transform_apply
kylesayrs Jun 23, 2025
fd3390a
construct with same dtype, constructing on fp32 found no difference
kylesayrs Jun 23, 2025
3c55003
Merge branch 'kylesayrs/transform_construct_cache_device' into kylesa…
kylesayrs Jun 23, 2025
ad29c15
remove unnecessary imports
kylesayrs Jun 23, 2025
85f40b5
bugfixes (#375)
brian-dellabetta Jul 2, 2025
500af9b
use factory_kwargs
kylesayrs Jul 7, 2025
8e36540
add frozen dict to deps
kylesayrs Jul 7, 2025
48653ec
Merge remote-tracking branch 'origin' into kylesayrs/transform_permut…
kylesayrs Jul 7, 2025
56df0f7
fix style
kylesayrs Jul 7, 2025
a251569
merge
kylesayrs Jul 7, 2025
cb5a32b
Merge remote-tracking branch 'origin' into kylesayrs/transform_apply
kylesayrs Jul 7, 2025
06e0346
Merge branch 'kylesayrs/transform_permutations' into kylesayrs/transf…
kylesayrs Jul 7, 2025
0a4fea5
Merge branch 'kylesayrs/transform_construct_cache_device' into kylesa…
kylesayrs Jul 7, 2025
49740c6
use delete_offload_module
kylesayrs Jul 7, 2025
7dc182b
Merge remote-tracking branch 'origin' into kylesayrs/transform_constr…
kylesayrs Jul 7, 2025
80db2ce
Merge branch 'kylesayrs/transform_construct_cache_device' into kylesa…
kylesayrs Jul 7, 2025
e06bbad
add docstrign
kylesayrs Jul 7, 2025
438bc13
Merge remote-tracking branch 'origin' into kylesayrs/transform_apply
kylesayrs Jul 7, 2025
fd77ecc
use parametrize
kylesayrs Jul 8, 2025
bbf9533
remove random from tests
kylesayrs Jul 8, 2025
853ffcf
Merge remote-tracking branch 'origin' into kylesayrs/transform_apply
kylesayrs Jul 9, 2025
492218a
implement num_heads
kylesayrs Jul 8, 2025
cf606e7
implement head dim
kylesayrs Jul 8, 2025
116b9f9
add more tests
kylesayrs Jul 8, 2025
f220fb9
clean up reshaping
kylesayrs Jul 9, 2025
9039eb5
code cleanup and simplification
kylesayrs Jul 9, 2025
ac1eece
undo dtype changes
kylesayrs Jul 9, 2025
7334234
simplify tests
kylesayrs Jul 9, 2025
d1b4f83
rename function
kylesayrs Jul 9, 2025
97f237e
add docstring
kylesayrs Jul 9, 2025
df74532
Merge remote-tracking branch 'origin' into kylesayrs/transform-attent…
kylesayrs Jul 10, 2025
ed7e20b
Merge remote-tracking branch 'origin' into kylesayrs/transform-attent…
kylesayrs Jul 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions src/compressed_tensors/transform/factory/hadamard.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
from compressed_tensors.transform import TransformArgs, TransformScheme
from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
from compressed_tensors.transform.utils.utils import (
from compressed_tensors.transform.utils.matrix import (
apply_transform_weight,
get_matrix_size,
get_transform_size,
)
from compressed_tensors.utils import get_execution_device, get_offloaded_device
from compressed_tensors.utils.helpers import ParameterizedDefaultDict
Expand Down Expand Up @@ -52,15 +52,15 @@ def create_transform(self, module: Module, args: TransformArgs):
:param args: defines how the transform will be applied to the module
"""
assert isinstance(module, Linear)
size = get_matrix_size(module, args.location)
size = get_transform_size(module, args.location, self.scheme.head_dim)
dtype = module.weight.dtype
device = get_offloaded_device(module)
exec_device = get_execution_device(module)

factory_kwargs = {"construct_device": exec_device}
weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
perm = self.perms[weight] if self.scheme.randomize else None
return HadamardTransform(weight, perm, args)
return HadamardTransform(weight, perm, args, type(module))

def _create_weight(
self,
Expand All @@ -81,12 +81,17 @@ def _create_permutation(self, weight: Parameter) -> Parameter:

class HadamardTransform(TransformBase):
def __init__(
self, weight: Parameter, perm: Union[Parameter, None], args: TransformArgs
self,
weight: Parameter,
perm: Optional[Parameter],
args: TransformArgs,
module_type: type[torch.nn.Module],
):
super().__init__()
self.weight = weight
self.perm = perm
self.args = args
self.module_type = module_type

def forward(self, value: Tensor) -> Tensor:
weight = self.weight
Expand All @@ -97,4 +102,6 @@ def forward(self, value: Tensor) -> Tensor:
if self.args.inverse:
weight = weight.T

return apply_transform_weight(weight, value, self.args.location)
return apply_transform_weight(
weight, value, self.args.location, self.module_type
)
24 changes: 17 additions & 7 deletions src/compressed_tensors/transform/factory/matrix_multiply.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
import torch
from compressed_tensors.transform import TransformArgs, TransformScheme
from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
from compressed_tensors.transform.utils.utils import (
from compressed_tensors.transform.utils.matrix import (
apply_transform_weight,
get_matrix_size,
get_transform_size,
)
from compressed_tensors.utils import get_offloaded_device
from compressed_tensors.utils.helpers import ParameterizedDefaultDict
Expand Down Expand Up @@ -51,15 +51,15 @@ def create_transform(self, module: Module, args: TransformArgs):
:param args: defines how the transform will be applied to the module
"""
assert isinstance(module, Linear)
size = get_matrix_size(module, args.location)
size = get_transform_size(module, args.location, self.scheme.head_dim)
dtype = module.weight.dtype
device = get_offloaded_device(module)

weight = self.weights[size, dtype, device]
if args.inverse:
weight = self.inverses[weight]

return RandomMatrixTransform(weight, args)
return RandomMatrixTransform(weight, args, type(module))

def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
# TODO: verify that weight is invertible (has non-zero determinant)
Expand All @@ -74,17 +74,27 @@ def _create_inverse(self, weight: Parameter) -> Parameter:


class RandomMatrixTransform(TransformBase):
def __init__(self, weight: Tensor, args: TransformArgs):
def __init__(
self,
weight: Tensor,
args: TransformArgs,
module_type: type[torch.nn.Module],
):
super().__init__()
self.weight = weight # is an inverse if args.inverse
self.args = args
self.module_type = module_type

def forward(self, value: Tensor) -> Parameter:
return apply_transform_weight(self.weight, value, self.args.location)
return apply_transform_weight(
self.weight, value, self.args.location, self.module_type
)

def right_inverse(self, value: Tensor) -> Tensor:
inverse = high_precision_invert(self.weight)
return apply_transform_weight(inverse, value, self.args.location)
return apply_transform_weight(
inverse, value, self.args.location, self.module_type
)


def high_precision_invert(weight: Tensor) -> Tensor:
Expand Down
3 changes: 2 additions & 1 deletion src/compressed_tensors/transform/transform_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
from typing import List, Optional

from compressed_tensors.transform import TransformArgs
from pydantic import BaseModel, Field
Expand Down Expand Up @@ -40,3 +40,4 @@ class TransformScheme(BaseModel):
apply: List[TransformArgs] = Field(default_factory=list)
randomize: bool = Field(default=False)
requires_grad: bool = Field(default=False)
head_dim: Optional[int] = Field(default=None)
147 changes: 147 additions & 0 deletions src/compressed_tensors/transform/utils/matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional, Tuple

import torch
from compressed_tensors.transform import TransformLocation


__all__ = ["get_transform_size", "apply_transform_weight"]


def get_transform_size(
module: torch.nn.Module,
location: TransformLocation,
head_dim: Optional[int] = None,
) -> int:
"""
Determine the size of a transform matrix given its location on the module

:param module: module that matrix will be applied to
:param location: location on module
:param head_dim: size of head when transform is applied to mha
:return: size of matrix
"""
if isinstance(module, torch.nn.Linear):
if location in (TransformLocation.INPUT, TransformLocation.WEIGHT_INPUT):
size = module.in_features
else:
size = module.out_features
else:
raise NotImplementedError(f"Transforms on {type(module)} are not supported")

if head_dim is not None:
if size % head_dim != 0:
raise ValueError(
f"{head_dim} must divide {size} for {type(module)} at {location}"
)

size = head_dim

return size


def apply_transform_weight(
weight: torch.Tensor,
value: torch.Tensor,
location: TransformLocation,
module_type: type[torch.nn.Module],
) -> torch.Tensor:
"""
:param weight: transform weight to apply
:param value: value to apply weight to
:param location: determines how weight should be applied
:param model_type: result of type(module), passed in to determine application of
weight transform
:return: value after weight has been applied
"""
fn, axis = _get_transform_method(module_type, location)

assert weight.shape[0] == weight.shape[1]
head_dim = weight.shape[0]
num_heads = value.shape[axis] // head_dim

value = value.unflatten(axis, (num_heads, head_dim))
value = fn(weight, value)
value = value.flatten(axis - 1, axis)

return value


def _get_transform_method(
module_type: type[torch.nn.Module],
location: TransformLocation,
) -> Tuple[Callable[[torch.Tensor, torch.Tensor], torch.Tensor], int]:
"""
Using the transform location, determine how to apply the transform weight to the
given value wrt linear weights. For more info on input and output transforms,
see `TransformLocation`

The following explains how weights should be applied to values according to location

let x be input activation
W be weight,
yh, xh, Wh be transformed output, input, weight

note that
y = (x W.T) // torch.nn.Linear

Choose values for yh, xh, and Wh which incorporate matrix transforms

let V, Vi be transform matrices on input side
U, Ui be transform matrices on output side

pick xh = (x V)
Wh = (U.T W Vi.T)
yh = (y U)

The following shows that `yh = (xh) (Wh).T` for the chosen values of yh, xh, and Wh

(xh) (Wh).T = (x V) (U.T W Vi.T).T
= (x V) (Vi W.T U) // transpose matrix product identity
= (x W.T) U
= y U
= yh

:param weight: transform weight to apply
:param value: value to apply weight to
:param location: determines how weight should be applied
:return: value after transform weight has been applied
"""
fn = axis = None

if module_type == torch.nn.Linear:
if location == TransformLocation.INPUT:
fn = lambda weight, value: value @ weight
axis = -1

elif location == TransformLocation.WEIGHT_INPUT:
fn = lambda weight, value: value @ weight.T
axis = -1

elif location == TransformLocation.WEIGHT_OUTPUT:
fn = lambda weight, value: weight.T @ value
axis = -2

elif location == TransformLocation.OUTPUT:
fn = lambda weight, value: value @ weight
axis = -1

if fn is None:
raise NotImplementedError(
f"Applying transforms to {module_type} {location} is not supported"
)

return fn, axis
91 changes: 0 additions & 91 deletions src/compressed_tensors/transform/utils/utils.py

This file was deleted.

Loading