[Transforms] apply transforms to torch.nn.Embedding modules #377

Draft · wants to merge 75 commits into base: main

Changes from 71 commits (75 commits total)

Commits
d8a10ec
add utilities
kylesayrs May 30, 2025
d2af054
add tests
kylesayrs May 30, 2025
e32d5b5
add additional tests
kylesayrs May 30, 2025
9d0518b
add utils and tests
kylesayrs May 30, 2025
8c5a2d9
Implement transform factories
kylesayrs May 30, 2025
809e367
Merge branch 'kylesayrs/transform_utils' into kylesayrs/transform_fac…
kylesayrs May 30, 2025
8d613b3
add permutations
kylesayrs May 31, 2025
57d171a
add delete_offload_module
kylesayrs May 31, 2025
d77bcef
Merge branch 'kylesayrs/transform-accelerate-utilities' into kylesayr…
kylesayrs May 31, 2025
ab73b43
Merge branch 'kylesayrs/transform-accelerate-utilities' into kylesayr…
kylesayrs May 31, 2025
4b55733
Merge branch 'kylesayrs/transform_factory' into kylesayrs/transform_p…
kylesayrs May 31, 2025
aa7d21b
key inverses by weight
kylesayrs May 31, 2025
6901e02
fix tests
kylesayrs May 31, 2025
47ae9fe
standardize random hadamard
kylesayrs May 31, 2025
34f1343
Merge branch 'kylesayrs/transform_utils' into kylesayrs/transform_fac…
kylesayrs May 31, 2025
1039100
prepend input hooks
kylesayrs May 31, 2025
5677553
Merge remote-tracking branch 'origin' into kylesayrs/transform_utils
kylesayrs Jun 5, 2025
68ec14e
apply sqrt division first
kylesayrs Jun 5, 2025
a62418a
Merge branch 'kylesayrs/transform_utils' into kylesayrs/transform_fac…
kylesayrs Jun 5, 2025
b117523
use divided hadamards
kylesayrs Jun 5, 2025
a46f754
fix typo
kylesayrs Jun 5, 2025
cb1cb52
add random option
kylesayrs Jun 5, 2025
7c02bb2
Merge branch 'kylesayrs/transform_utils' into kylesayrs/transform_fac…
kylesayrs Jun 5, 2025
02af1e9
use random seeds, rename matrix multiply
kylesayrs Jun 5, 2025
f45f3e9
add deterministic generation to random matrix
kylesayrs Jun 5, 2025
7a7abdf
fix perm math
kylesayrs Jun 5, 2025
6e52894
update docstrings
kylesayrs Jun 5, 2025
7230933
update docstrings
kylesayrs Jun 5, 2025
f74fe3e
Merge branch 'kylesayrs/transform_factory' into kylesayrs/transform_p…
kylesayrs Jun 5, 2025
92ddea9
cleanup
kylesayrs Jun 5, 2025
779956f
cleanup 2
kylesayrs Jun 5, 2025
fbd2939
Merge branch 'kylesayrs/transform_utils' into kylesayrs/transform_fac…
kylesayrs Jun 5, 2025
dd72b6a
make seed optional
kylesayrs Jun 5, 2025
4ae491d
Merge branch 'kylesayrs/transform_factory' into kylesayrs/transform_p…
kylesayrs Jun 5, 2025
da19b0f
remove iterable check and missing return value
kylesayrs Jun 9, 2025
7ab17ce
Merge branch 'main' into kylesayrs/transform_permutations
kylesayrs Jun 10, 2025
33df50f
Merge remote-tracking branch 'origin' into kylesayrs/transform_permut…
kylesayrs Jun 10, 2025
6e1ec39
Remove unrelated changes
kylesayrs Jun 10, 2025
938e702
simplify code
kylesayrs Jun 10, 2025
27bc0b3
implement apply, use in tests
kylesayrs Jun 10, 2025
a27db62
use hadamards database file
kylesayrs Jun 11, 2025
ce63955
try manifest
kylesayrs Jun 11, 2025
7ae5863
try setup, update hadamards list
kylesayrs Jun 11, 2025
67675c3
fix setup
kylesayrs Jun 11, 2025
f061db9
add docstrings, cleanup
kylesayrs Jun 11, 2025
4a84ce1
fix setup, thank you @dbarbuzzi
kylesayrs Jun 11, 2025
cde1066
remove numpy, add tests
kylesayrs Jun 11, 2025
1ba6195
solidify dtype, add gpu tests
kylesayrs Jun 11, 2025
c373345
fix docstring
kylesayrs Jun 11, 2025
fbaf47a
add device option
kylesayrs Jun 11, 2025
5a887f4
construct on execution device, cache on offload device
kylesayrs Jun 11, 2025
310fe6d
save construction device changes for later
kylesayrs Jun 11, 2025
b715329
construct on execution device, cache on offload device
kylesayrs Jun 11, 2025
249323c
cite nja sloane
kylesayrs Jun 11, 2025
1823af4
Merge branch 'kylesayrs/extend-hadamard', remote-tracking branch 'ori…
kylesayrs Jun 11, 2025
94a0bf5
Merge remote-tracking branch 'origin' into kylesayrs/extend-hadamard
kylesayrs Jun 11, 2025
cf066e0
Merge branch 'kylesayrs/extend-hadamard' into kylesayrs/transform_con…
kylesayrs Jun 11, 2025
c1a4a34
remove dreg
kylesayrs Jun 11, 2025
5807ee1
put on device via safe_open
kylesayrs Jun 11, 2025
ccb88ed
nits and docstrings
kylesayrs Jun 12, 2025
feba695
update docstring
kylesayrs Jun 12, 2025
c8f6b53
Merge branch 'kylesayrs/extend-hadamard' into kylesayrs/transform_con…
kylesayrs Jun 12, 2025
e7f08e1
Merge branch 'kylesayrs/transform_construct_cache_device' into kylesa…
kylesayrs Jun 12, 2025
b6a0dd4
Merge remote-tracking branch 'origin' into kylesayrs/transform_constr…
kylesayrs Jun 13, 2025
955f2f5
Merge
kylesayrs Jun 23, 2025
226f367
merge with construct: construct in float32
kylesayrs Jun 23, 2025
9745acb
Merge remote-tracking branch 'origin' into kylesayrs/transform_apply
kylesayrs Jun 23, 2025
fd3390a
construct with same dtype, constructing on fp32 found no difference
kylesayrs Jun 23, 2025
3c55003
Merge branch 'kylesayrs/transform_construct_cache_device' into kylesa…
kylesayrs Jun 23, 2025
85f40b5
bugfixes (#375)
brian-dellabetta Jul 2, 2025
9c9f4aa
utils to apply transforms to torch.nn.Embedding modules
brian-dellabetta Jul 2, 2025
76c1014
Update src/compressed_tensors/transform/utils/utils.py
brian-dellabetta Jul 5, 2025
24a0307
is_linear -> module_type
brian-dellabetta Jul 7, 2025
db3b5a4
spinquant r2 working poc
brian-dellabetta Jul 7, 2025
af7a254
head_dim -> num_heads config change
brian-dellabetta Jul 8, 2025
6 changes: 1 addition & 5 deletions src/compressed_tensors/quantization/lifecycle/apply.py
@@ -152,11 +152,7 @@ def apply_quantization_config(
# list of submodules to ignore
ignored_submodules = defaultdict(list)
# mark appropriate layers for quantization by setting their quantization schemes
for name, submodule in iter_named_quantizable_modules(
model,
include_children=True,
include_attn=True,
): # child modules and attention modules
for name, submodule in model.named_modules(): # child modules and attention modules
# potentially fix module name to remove FSDP wrapper prefix
name = fix_fsdp_module_name(name)
if matches := find_name_or_class_matches(name, submodule, config.ignore):
1 change: 1 addition & 0 deletions src/compressed_tensors/transform/__init__.py
@@ -23,3 +23,4 @@
from .factory.hadamard import *
from .factory.matrix_multiply import *
from .factory.random_hadamard import *
from .apply import *
25 changes: 25 additions & 0 deletions src/compressed_tensors/transform/apply.py
@@ -0,0 +1,25 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from compressed_tensors.transform import TransformConfig, TransformFactory


__all__ = ["apply_transform_config"]


def apply_transform_config(model: torch.nn.Module, config: TransformConfig):
for name, scheme in config.config_groups.items():
factory = TransformFactory.from_scheme(scheme, name=name)
factory.apply_to_model(model)
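
As a rough usage sketch of the new apply_transform_config entrypoint (the scheme name, targets, and location string below are illustrative assumptions, not values taken from this PR; bias-free Linear modules are used because weight-location transforms assert the module has no bias):

import torch
from compressed_tensors.transform import (
    TransformArgs,
    TransformConfig,
    TransformScheme,
    apply_transform_config,
)

model = torch.nn.Sequential(
    torch.nn.Linear(64, 64, bias=False),
    torch.nn.Linear(64, 64, bias=False),
)
config = TransformConfig(
    config_groups={
        "v": TransformScheme(
            type="hadamard",
            apply=[TransformArgs(targets=["Linear"], location="weight_input")],
        )
    }
)
apply_transform_config(model, config)  # one factory per scheme is applied to the model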
10 changes: 4 additions & 6 deletions src/compressed_tensors/transform/factory/base.py
@@ -99,10 +99,10 @@ def _apply_to_module(self, module: Module, args: TransformArgs):
# create transform as submodule
transform_name = f"{self.name}_{args.location.value}"
transform = self.create_transform(module, args)
register_offload_module(module, transform_name, transform) # (1)

# register input transformation hook
if args.location == TransformLocation.INPUT:
register_offload_module(module, transform_name, transform)

def input_hook(_, args):
input = args[0]
@@ -115,8 +115,8 @@ def input_hook(_, args):
TransformLocation.WEIGHT_INPUT,
TransformLocation.WEIGHT_OUTPUT,
):
assert isinstance(module, torch.nn.Linear)
assert module.bias is None
assert isinstance(module, (torch.nn.Linear, torch.nn.Embedding))
assert not hasattr(module, "bias") or module.bias is None

with torch.no_grad(), align_module_device(module):
update_offload_parameter(module, "weight", transform(module.weight))
@@ -130,6 +130,7 @@ def input_hook(_, args):

# register output transformation hook
elif args.location == TransformLocation.OUTPUT:
register_offload_module(module, transform_name, transform)

def output_hook(_, _input, output):
return transform(output)
@@ -140,9 +141,6 @@ def output_hook(_, _input, output):
else:
raise NotImplementedError()

# (1) even in the `weight` cases, this submodule attachment is needed in order
# to support saving in the frozen state


class TransformBase(Module, ABC):
"""
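
A rough sketch (assumptions, not this PR's exact code) of the input-hook pattern used by the input branch of _apply_to_module above, where the transform runs on activations before the wrapped module:

import torch

def attach_input_transform(module: torch.nn.Module, transform: torch.nn.Module) -> None:
    def input_hook(_, args):
        # transform the first positional input, forward the rest unchanged
        return (transform(args[0]), *args[1:])

    # prepend=True runs this hook before any previously registered pre-hooks
    module.register_forward_pre_hook(input_hook, prepend=True)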
54 changes: 40 additions & 14 deletions src/compressed_tensors/transform/factory/hadamard.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
from typing import Optional, Union

import torch
from compressed_tensors.transform import TransformArgs, TransformScheme
@@ -22,7 +22,7 @@
apply_transform_weight,
get_matrix_size,
)
from compressed_tensors.utils import get_offloaded_device
from compressed_tensors.utils import get_execution_device, get_offloaded_device
from compressed_tensors.utils.helpers import ParameterizedDefaultDict
from torch import Tensor, device, dtype
from torch.nn import Linear, Module, Parameter
@@ -41,6 +41,7 @@ class HadamardFactory(TransformFactory):
def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
super().__init__(name, scheme, seed)
self.weights = ParameterizedDefaultDict(self._create_weight)
self.perms = ParameterizedDefaultDict(self._create_permutation)

def create_transform(self, module: Module, args: TransformArgs):
"""
@@ -50,30 +51,55 @@ def create_transform(self, module: Module, args: TransformArgs):
:param module: parent module that transform will be applied to
:param args: defines how the transform will be applied to the module
"""
assert isinstance(module, Linear)
is_linear = isinstance(module, Linear)
assert hasattr(module, "weight")
size = get_matrix_size(module, args.location)
dtype = module.weight.dtype
device = get_offloaded_device(module)
exec_device = get_execution_device(module)

weight = self.weights[size, dtype, device]
return HadamardTransform(weight, args)
weight = self.weights.get(size, dtype, device, construct_device=exec_device)
perm = self.perms[weight] if self.scheme.randomize else None
return HadamardTransform(weight, perm, args, is_linear)

def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
data = deterministic_hadamard_matrix(size, dtype, device)
data = data.to(dtype=dtype, device=device)
def _create_weight(
self,
size: int,
dtype: dtype,
device: device,
construct_device: device,
) -> Parameter:
# construct on execution device, cache on offload device
data = deterministic_hadamard_matrix(size, dtype, construct_device)
data = data.to(device=device)
return Parameter(data, requires_grad=self.scheme.requires_grad)

def _create_permutation(self, weight: Parameter) -> Parameter:
data = torch.randperm(weight.size(0), generator=self.generator)
return Parameter(data, requires_grad=False)


class HadamardTransform(TransformBase):
def __init__(self, weight: Parameter, args: TransformArgs):
def __init__(
self,
weight: Parameter,
perm: Union[Parameter, None],
args: TransformArgs,
is_linear: bool,
):
super().__init__()
self.weight = weight
self.perm = perm
self.args = args
self.is_linear = is_linear

def forward(self, value: Tensor) -> Tensor:
if not self.args.inverse:
weight = self.weight
else:
weight = self.weight.T
weight = self.weight

if self.perm is not None:
weight = weight[self.perm][:, self.perm]

if self.args.inverse:
weight = weight.T

return apply_transform_weight(weight, value, self.args.location)
return apply_transform_weight(weight, value, self.args.location, self.is_linear)
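
Sanity-check sketch for the new permutation path (assuming the deterministic_hadamard_matrix signature shown elsewhere in this diff): applying the same random permutation to a Hadamard matrix's rows and columns keeps its rows mutually orthogonal, so forward() can still use the transpose as the inverse.

import torch
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix

size = 64
H = deterministic_hadamard_matrix(size, torch.float32, torch.device("cpu"))
perm = torch.randperm(size)
H_perm = H[perm][:, perm]
gram = H_perm @ H_perm.T
# rows have equal norm, so the Gram matrix is a multiple of the identity
assert torch.allclose(gram, gram[0, 0] * torch.eye(size), atol=1e-5)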
14 changes: 11 additions & 3 deletions src/compressed_tensors/transform/factory/random_hadamard.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from compressed_tensors.transform import HadamardFactory, TransformFactory
from compressed_tensors.transform.utils.hadamard import random_hadamard_matrix
from torch import device, dtype
@@ -28,7 +29,14 @@ class RandomHadamardFactory(HadamardFactory):
:param seed: random seed used to transform weight randomization
"""

def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
data = random_hadamard_matrix(size, dtype, device, self.generator)
data = data.to(dtype=dtype, device=device)
def _create_weight(
self,
size: int,
dtype: dtype,
device: device,
construct_device: device,
) -> Parameter:
# construct on execution device, cache on offload device
data = random_hadamard_matrix(size, dtype, construct_device, self.generator)
data = data.to(device=device)
return Parameter(data, requires_grad=self.scheme.requires_grad)
4 changes: 2 additions & 2 deletions src/compressed_tensors/transform/transform_config.py
@@ -49,7 +49,7 @@ class TransformConfig(BaseModel):
inverse=True,
),
],
randomize_modules=True,
randomize=True,
),
"u": TransformScheme(
type="hadamard",
@@ -62,7 +62,7 @@
targets=["Linear"], location="output", inverse=True # non-mergable
),
],
randomize_modules=True,
randomize=True,
),
}
)
7 changes: 3 additions & 4 deletions src/compressed_tensors/transform/transform_scheme.py
@@ -31,13 +31,12 @@ class TransformScheme(BaseModel):
(see `Transforms.registered_names()`)
:param apply: list of TransformationArgs containing the information about the
modules that should be targeted by the specified transform
:param randomize_modules: True if unique transforms should be applied to each
unique module targeted by `apply`, otherwise reuse transform weights where
applicable
:param randomize: True if uniquely randomized transform weights should be used,
otherwise use identical transform weights where applicable
:param requires_grad: True if weights include gradients for training
"""

type: str
apply: List[TransformArgs] = Field(default_factory=list)
randomize_modules: bool = Field(default=False)
randomize: bool = Field(default=False)
requires_grad: bool = Field(default=False)
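
Illustrative construction of a scheme using the renamed flag (the target and location values are assumptions): with randomize=True, the Hadamard factory above additionally draws a random permutation for each created transform weight.

from compressed_tensors.transform import TransformArgs, TransformScheme

scheme = TransformScheme(
    type="hadamard",
    apply=[TransformArgs(targets=["Linear"], location="input")],
    randomize=True,
)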
4 changes: 3 additions & 1 deletion src/compressed_tensors/transform/utils/hadamard.py
@@ -51,7 +51,9 @@ def deterministic_hadamard_matrix(

log2 = int(math.log2(size))
if size != 2**log2:
raise ValueError("Cannot construct deterministic hadamard of size != 2^n")
raise ValueError(
f"Cannot construct deterministic hadamard of size {size} != 2^n"
)

H = torch.tensor([[1]], dtype=dtype, device=device)

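
Minimal sketch of the Sylvester construction behind the power-of-two check above (illustrative; the library's deterministic_hadamard_matrix may also normalize the result, e.g. dividing by sqrt(size)):

import math
import torch

def sylvester_hadamard(size: int, dtype=torch.float32, device="cpu") -> torch.Tensor:
    if size < 1 or size != 2 ** int(math.log2(size)):
        raise ValueError(f"Cannot construct deterministic hadamard of size {size} != 2^n")
    H = torch.tensor([[1.0]], dtype=dtype, device=device)
    for _ in range(int(math.log2(size))):
        # H_{2n} = [[H, H], [H, -H]]
        H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0)
    return H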
52 changes: 38 additions & 14 deletions src/compressed_tensors/transform/utils/utils.py
@@ -27,17 +27,28 @@ def get_matrix_size(module: torch.nn.Module, location: TransformLocation) -> int
:param location: location on module
:return: size of matrix
"""
assert isinstance(module, torch.nn.Linear)
if location in ("input", TransformLocation.WEIGHT_INPUT):
return module.in_features
else:
return module.out_features
if isinstance(module, torch.nn.Linear):
if location in ("input", TransformLocation.WEIGHT_INPUT):
return module.in_features
else:
return module.out_features
elif isinstance(module, torch.nn.Embedding):
if location in ("input", TransformLocation.WEIGHT_INPUT):
return module.num_embeddings
else:
return module.embedding_dim

raise ValueError(
f"Unsupported module type {type(module)}, "
"should be either Linear or Embedding."
)


def apply_transform_weight(
weight: torch.Tensor,
transform_weight: torch.Tensor,
value: torch.Tensor,
location: TransformLocation,
is_linear: bool = True,
) -> torch.Tensor:
"""
Using the transform location, determine how to apply the transform weight to the
@@ -69,23 +80,36 @@
= y U
= yh

:param weight: transform weight to apply
:param value: value to apply weight to
:param location: determines how weight should be applied
:return: value after transform weight has been applied
:param transform_weight: transform weight to apply
:param value: value to apply transform_weight to
:param location: determines how transform_weight should be applied
:param is_linear: if value belongs to the weights of a Linear module
This is needed because torch uses convention:
Linear(in_features,out_features) has weight shape (out_features, in_features)
But other modules (e.g. torch.nn.Embedding) don't:
Embedding(num_embeddings, embedding_dim) has weight shape
(num_embeddings, embedding_dim)
:return: value after transform_weight has been applied
"""

if location == TransformLocation.INPUT:
return value @ weight
return value @ transform_weight

elif location == TransformLocation.WEIGHT_INPUT:
return value @ weight.T
if is_linear:
return value @ transform_weight.T
else:
# TODO is this ever needed?
raise NotImplementedError()

elif location == TransformLocation.WEIGHT_OUTPUT:
return weight.T @ value
if is_linear:
return transform_weight.T @ value
else:
return value @ transform_weight

elif location == TransformLocation.OUTPUT:
return value @ weight
return value @ transform_weight

else:
raise NotImplementedError(f"{location} has not been implemented yet")
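
Shape-convention sketch motivating the is_linear flag (hypothetical sizes): torch stores Linear weights transposed relative to Embedding weights, so the "output side" of the weight sits on a different axis for the two module types.

import torch

linear = torch.nn.Linear(in_features=8, out_features=16)
embedding = torch.nn.Embedding(num_embeddings=100, embedding_dim=8)

print(linear.weight.shape)     # torch.Size([16, 8])  -> (out_features, in_features)
print(embedding.weight.shape)  # torch.Size([100, 8]) -> (num_embeddings, embedding_dim)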
11 changes: 8 additions & 3 deletions src/compressed_tensors/utils/helpers.py
@@ -373,11 +373,16 @@ class ParameterizedDefaultDict(dict):

def __init__(self, default_factory: Callable[[Any], Any]):
self.default_factory = default_factory
self._kwargs = {}

def __missing__(self, key):
def __missing__(self, key: Any) -> Any:
if isinstance(key, tuple):
value = self.default_factory(*key)
value = self.default_factory(*key, **self._kwargs)
else:
value = self.default_factory(key)
value = self.default_factory(key, **self._kwargs)
self[key] = value
return value

def get(self, *args, **kwargs) -> Any:
with patch_attr(self, "_kwargs", kwargs):
return self[args]
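
Usage sketch of the new get() path (the factory function and its kwarg are assumptions), mirroring how HadamardFactory passes construct_device without making it part of the cache key:

import torch
from compressed_tensors.utils.helpers import ParameterizedDefaultDict

def make_weight(size, dtype, device, construct_device=torch.device("cpu")):
    # construct on construct_device, then cache the result on device
    return torch.empty(size, size, dtype=dtype, device=construct_device).to(device)

weights = ParameterizedDefaultDict(make_weight)
w = weights.get(4, torch.float32, torch.device("cpu"), construct_device=torch.device("cpu"))
w2 = weights[4, torch.float32, torch.device("cpu")]  # cached; the factory is not called again
assert w is w2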