Skip to content

[Transform] apply_transform_config #348

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 88 commits into from
Jul 9, 2025
Merged
Show file tree
Hide file tree
Changes from 84 commits
Commits
Show all changes
88 commits
Select commit Hold shift + click to select a range
d8a10ec
add utilities
kylesayrs May 30, 2025
d2af054
add tests
kylesayrs May 30, 2025
e32d5b5
add additional tests
kylesayrs May 30, 2025
9d0518b
add utils and tests
kylesayrs May 30, 2025
8c5a2d9
Implement transform factories
kylesayrs May 30, 2025
809e367
Merge branch 'kylesayrs/transform_utils' into kylesayrs/transform_fac…
kylesayrs May 30, 2025
8d613b3
add permutations
kylesayrs May 31, 2025
57d171a
add delete_offload_module
kylesayrs May 31, 2025
d77bcef
Merge branch 'kylesayrs/transform-accelerate-utilities' into kylesayr…
kylesayrs May 31, 2025
ab73b43
Merge branch 'kylesayrs/transform-accelerate-utilities' into kylesayr…
kylesayrs May 31, 2025
4b55733
Merge branch 'kylesayrs/transform_factory' into kylesayrs/transform_p…
kylesayrs May 31, 2025
aa7d21b
key inverses by weight
kylesayrs May 31, 2025
6901e02
fix tests
kylesayrs May 31, 2025
47ae9fe
standardize random hadamard
kylesayrs May 31, 2025
34f1343
Merge branch 'kylesayrs/transform_utils' into kylesayrs/transform_fac…
kylesayrs May 31, 2025
1039100
prepend input hooks
kylesayrs May 31, 2025
5677553
Merge remote-tracking branch 'origin' into kylesayrs/transform_utils
kylesayrs Jun 5, 2025
68ec14e
apply sqrt division first
kylesayrs Jun 5, 2025
a62418a
Merge branch 'kylesayrs/transform_utils' into kylesayrs/transform_fac…
kylesayrs Jun 5, 2025
b117523
use divided hadamards
kylesayrs Jun 5, 2025
a46f754
fix typo
kylesayrs Jun 5, 2025
cb1cb52
add random option
kylesayrs Jun 5, 2025
7c02bb2
Merge branch 'kylesayrs/transform_utils' into kylesayrs/transform_fac…
kylesayrs Jun 5, 2025
02af1e9
use random seeds, rename matrix multiply
kylesayrs Jun 5, 2025
f45f3e9
add deterministic generation to random matrix
kylesayrs Jun 5, 2025
7a7abdf
fix perm math
kylesayrs Jun 5, 2025
6e52894
update docstrings
kylesayrs Jun 5, 2025
7230933
update docstrings
kylesayrs Jun 5, 2025
f74fe3e
Merge branch 'kylesayrs/transform_factory' into kylesayrs/transform_p…
kylesayrs Jun 5, 2025
92ddea9
cleanup
kylesayrs Jun 5, 2025
779956f
cleanup 2
kylesayrs Jun 5, 2025
fbd2939
Merge branch 'kylesayrs/transform_utils' into kylesayrs/transform_fac…
kylesayrs Jun 5, 2025
dd72b6a
make seed optional
kylesayrs Jun 5, 2025
4ae491d
Merge branch 'kylesayrs/transform_factory' into kylesayrs/transform_p…
kylesayrs Jun 5, 2025
da19b0f
remove iterable check and missing return value
kylesayrs Jun 9, 2025
7ab17ce
Merge branch 'main' into kylesayrs/transform_permutations
kylesayrs Jun 10, 2025
33df50f
Merge remote-tracking branch 'origin' into kylesayrs/transform_permut…
kylesayrs Jun 10, 2025
6e1ec39
Remove unrelated changes
kylesayrs Jun 10, 2025
938e702
simplify code
kylesayrs Jun 10, 2025
27bc0b3
implement apply, use in tests
kylesayrs Jun 10, 2025
a27db62
use hadamards database file
kylesayrs Jun 11, 2025
ce63955
try manifest
kylesayrs Jun 11, 2025
7ae5863
try setup, update hadamards list
kylesayrs Jun 11, 2025
67675c3
fix setup
kylesayrs Jun 11, 2025
f061db9
add docstrings, cleanup
kylesayrs Jun 11, 2025
4a84ce1
fix setup, thank you @dbarbuzzi
kylesayrs Jun 11, 2025
cde1066
remove numpy, add tests
kylesayrs Jun 11, 2025
1ba6195
solidify dtype, add gpu tests
kylesayrs Jun 11, 2025
c373345
fix docstring
kylesayrs Jun 11, 2025
fbaf47a
add device option
kylesayrs Jun 11, 2025
5a887f4
construct on execution device, cache on offload device
kylesayrs Jun 11, 2025
310fe6d
save construction device changes for later
kylesayrs Jun 11, 2025
b715329
construct on execution device, cache on offload device
kylesayrs Jun 11, 2025
249323c
cite nja sloane
kylesayrs Jun 11, 2025
1823af4
Merge branch 'kylesayrs/extend-hadamard', remote-tracking branch 'ori…
kylesayrs Jun 11, 2025
94a0bf5
Merge remote-tracking branch 'origin' into kylesayrs/extend-hadamard
kylesayrs Jun 11, 2025
cf066e0
Merge branch 'kylesayrs/extend-hadamard' into kylesayrs/transform_con…
kylesayrs Jun 11, 2025
c1a4a34
remove dreg
kylesayrs Jun 11, 2025
5807ee1
put on device via safe_open
kylesayrs Jun 11, 2025
ccb88ed
nits and docstrings
kylesayrs Jun 12, 2025
feba695
update docstring
kylesayrs Jun 12, 2025
c8f6b53
Merge branch 'kylesayrs/extend-hadamard' into kylesayrs/transform_con…
kylesayrs Jun 12, 2025
e7f08e1
Merge branch 'kylesayrs/transform_construct_cache_device' into kylesa…
kylesayrs Jun 12, 2025
75b9307
Merge remote-tracking branch 'origin' into kylesayrs/transform_permut…
kylesayrs Jun 12, 2025
b6a0dd4
Merge remote-tracking branch 'origin' into kylesayrs/transform_constr…
kylesayrs Jun 13, 2025
955f2f5
Merge
kylesayrs Jun 23, 2025
226f367
merge with construct: construct in float32
kylesayrs Jun 23, 2025
9745acb
Merge remote-tracking branch 'origin' into kylesayrs/transform_apply
kylesayrs Jun 23, 2025
fd3390a
construct with same dtype, constructing on fp32 found no difference
kylesayrs Jun 23, 2025
3c55003
Merge branch 'kylesayrs/transform_construct_cache_device' into kylesa…
kylesayrs Jun 23, 2025
ad29c15
remove unnecessary imports
kylesayrs Jun 23, 2025
85f40b5
bugfixes (#375)
brian-dellabetta Jul 2, 2025
500af9b
use factory_kwargs
kylesayrs Jul 7, 2025
8e36540
add frozen dict to deps
kylesayrs Jul 7, 2025
48653ec
Merge remote-tracking branch 'origin' into kylesayrs/transform_permut…
kylesayrs Jul 7, 2025
56df0f7
fix style
kylesayrs Jul 7, 2025
a251569
merge
kylesayrs Jul 7, 2025
cb5a32b
Merge remote-tracking branch 'origin' into kylesayrs/transform_apply
kylesayrs Jul 7, 2025
06e0346
Merge branch 'kylesayrs/transform_permutations' into kylesayrs/transf…
kylesayrs Jul 7, 2025
0a4fea5
Merge branch 'kylesayrs/transform_construct_cache_device' into kylesa…
kylesayrs Jul 7, 2025
49740c6
use delete_offload_module
kylesayrs Jul 7, 2025
7dc182b
Merge remote-tracking branch 'origin' into kylesayrs/transform_constr…
kylesayrs Jul 7, 2025
80db2ce
Merge branch 'kylesayrs/transform_construct_cache_device' into kylesa…
kylesayrs Jul 7, 2025
e06bbad
add docstrign
kylesayrs Jul 7, 2025
438bc13
Merge remote-tracking branch 'origin' into kylesayrs/transform_apply
kylesayrs Jul 7, 2025
fd77ecc
use parametrize
kylesayrs Jul 8, 2025
bbf9533
remove random from tests
kylesayrs Jul 8, 2025
853ffcf
Merge remote-tracking branch 'origin' into kylesayrs/transform_apply
kylesayrs Jul 9, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def _setup_packages() -> List:
)

def _setup_install_requires() -> List:
return ["torch>=1.7.0", "transformers", "pydantic>=2.0"]
return ["torch>=1.7.0", "transformers", "pydantic>=2.0", "frozendict"]

def _setup_extras() -> Dict:
return {
Expand Down
6 changes: 1 addition & 5 deletions src/compressed_tensors/quantization/lifecycle/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,7 @@ def apply_quantization_config(
# list of submodules to ignore
ignored_submodules = defaultdict(list)
# mark appropriate layers for quantization by setting their quantization schemes
for name, submodule in iter_named_quantizable_modules(
model,
include_children=True,
include_attn=True,
): # child modules and attention modules
for name, submodule in model.named_modules(): # child modules and attention modules
# potentially fix module name to remove FSDP wrapper prefix
name = fix_fsdp_module_name(name)
if matches := find_name_or_class_matches(name, submodule, config.ignore):
Expand Down
1 change: 1 addition & 0 deletions src/compressed_tensors/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@
from .factory.hadamard import *
from .factory.matrix_multiply import *
from .factory.random_hadamard import *
from .apply import *
32 changes: 32 additions & 0 deletions src/compressed_tensors/transform/apply.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from compressed_tensors.transform import TransformConfig, TransformFactory


__all__ = ["apply_transform_config"]


def apply_transform_config(model: torch.nn.Module, config: TransformConfig):
"""
Apply a transform config to a model. Weight transforms are fused into weights, while
activation transforms are attached as submodules and trigger via pytorch hooks

:param model: model to apply config to
:param config: transform config to apply
"""
for name, scheme in config.config_groups.items():
factory = TransformFactory.from_scheme(scheme, name=name)
factory.apply_to_model(model)
10 changes: 6 additions & 4 deletions src/compressed_tensors/transform/factory/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
)
from compressed_tensors.utils import (
align_module_device,
delete_offload_module,
has_offloaded_params,
patch_attr,
register_offload_module,
Expand Down Expand Up @@ -99,7 +100,7 @@ def _apply_to_module(self, module: Module, args: TransformArgs):
# create transform as submodule
transform_name = f"{self.name}_{args.location.value}"
transform = self.create_transform(module, args)
register_offload_module(module, transform_name, transform) # (1)
register_offload_module(module, transform_name, transform)

# register input transformation hook
if args.location == TransformLocation.INPUT:
Expand All @@ -118,6 +119,7 @@ def input_hook(_, args):
assert isinstance(module, torch.nn.Linear)
assert module.bias is None

# fuse transform into weight
with torch.no_grad(), align_module_device(module):
update_offload_parameter(module, "weight", transform(module.weight))

Expand All @@ -128,6 +130,9 @@ def input_hook(_, args):
raise ValueError("Offloaded training is not supported")
P.register_parametrization(module, "weight", transform)

# transform is no longer needed (unfusing is not supported)
delete_offload_module(module, transform_name)

# register output transformation hook
elif args.location == TransformLocation.OUTPUT:

Expand All @@ -140,9 +145,6 @@ def output_hook(_, _input, output):
else:
raise NotImplementedError()

# (1) even in the `weight` cases, this submodule attachment is needed in order
# to support saving in the frozen state


class TransformBase(Module, ABC):
"""
Expand Down
45 changes: 33 additions & 12 deletions src/compressed_tensors/transform/factory/hadamard.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
from typing import Optional, Union

import torch
from compressed_tensors.transform import TransformArgs, TransformScheme
Expand All @@ -22,7 +22,7 @@
apply_transform_weight,
get_matrix_size,
)
from compressed_tensors.utils import get_offloaded_device
from compressed_tensors.utils import get_execution_device, get_offloaded_device
from compressed_tensors.utils.helpers import ParameterizedDefaultDict
from torch import Tensor, device, dtype
from torch.nn import Linear, Module, Parameter
Expand All @@ -41,6 +41,7 @@ class HadamardFactory(TransformFactory):
def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
super().__init__(name, scheme, seed)
self.weights = ParameterizedDefaultDict(self._create_weight)
self.perms = ParameterizedDefaultDict(self._create_permutation)

def create_transform(self, module: Module, args: TransformArgs):
"""
Expand All @@ -54,26 +55,46 @@ def create_transform(self, module: Module, args: TransformArgs):
size = get_matrix_size(module, args.location)
dtype = module.weight.dtype
device = get_offloaded_device(module)
exec_device = get_execution_device(module)

weight = self.weights[size, dtype, device]
return HadamardTransform(weight, args)
factory_kwargs = {"construct_device": exec_device}
weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
perm = self.perms[weight] if self.scheme.randomize else None
return HadamardTransform(weight, perm, args)

def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
data = deterministic_hadamard_matrix(size, dtype, device)
data = data.to(dtype=dtype, device=device)
def _create_weight(
self,
size: int,
dtype: dtype,
device: device,
construct_device: device,
) -> Parameter:
# construct on execution device, cache on offload device
data = deterministic_hadamard_matrix(size, dtype, construct_device)
data = data.to(device=device)
return Parameter(data, requires_grad=self.scheme.requires_grad)

def _create_permutation(self, weight: Parameter) -> Parameter:
data = torch.randperm(weight.size(0), generator=self.generator)
return Parameter(data, requires_grad=False)


class HadamardTransform(TransformBase):
def __init__(self, weight: Parameter, args: TransformArgs):
def __init__(
self, weight: Parameter, perm: Union[Parameter, None], args: TransformArgs
):
super().__init__()
self.weight = weight
self.perm = perm
self.args = args

def forward(self, value: Tensor) -> Tensor:
if not self.args.inverse:
weight = self.weight
else:
weight = self.weight.T
weight = self.weight

if self.perm is not None:
weight = weight[self.perm][:, self.perm]

if self.args.inverse:
weight = weight.T

return apply_transform_weight(weight, value, self.args.location)
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def create_transform(self, module: Module, args: TransformArgs):
return RandomMatrixTransform(weight, args)

def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
# TODO: verify that weight is invertable (has non-zero determinant)
data = torch.rand(
(size, size), generator=self.generator, dtype=dtype, device=device
)
Expand Down
13 changes: 10 additions & 3 deletions src/compressed_tensors/transform/factory/random_hadamard.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,14 @@ class RandomHadamardFactory(HadamardFactory):
:param seed: random seed used to transform weight randomization
"""

def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
data = random_hadamard_matrix(size, dtype, device, self.generator)
data = data.to(dtype=dtype, device=device)
def _create_weight(
self,
size: int,
dtype: dtype,
device: device,
construct_device: device,
) -> Parameter:
# construct on execution device, cache on offload device
data = random_hadamard_matrix(size, dtype, construct_device, self.generator)
data = data.to(device=device)
return Parameter(data, requires_grad=self.scheme.requires_grad)
4 changes: 2 additions & 2 deletions src/compressed_tensors/transform/transform_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class TransformConfig(BaseModel):
inverse=True,
),
],
randomize_modules=True,
randomize=True,
),
"u": TransformScheme(
type="hadamard",
Expand All @@ -62,7 +62,7 @@ class TransformConfig(BaseModel):
targets=["Linear"], location="output", inverse=True # non-mergable
),
],
randomize_modules=True,
randomize=True,
),
}
)
Expand Down
7 changes: 3 additions & 4 deletions src/compressed_tensors/transform/transform_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,12 @@ class TransformScheme(BaseModel):
(see `Transforms.registered_names()`)
:param apply: list of TransformationArgs containing the information about the
modules that should be targeted by the specified transform
:param randomize_modules: True if unique transforms should be applied to each
unique module targeted by `apply`, otherwise reuse transform weights where
applicable
:param randomize: True if uniquely randomized transform weights should be used,
otherwise use identical transform weights where applicable
:param requires_grad: True if weights include gradients for training
"""

type: str
apply: List[TransformArgs] = Field(default_factory=list)
randomize_modules: bool = Field(default=False)
randomize: bool = Field(default=False)
requires_grad: bool = Field(default=False)
21 changes: 17 additions & 4 deletions src/compressed_tensors/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
import contextlib
import warnings
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional

import numpy
import torch
from frozendict import frozendict
from transformers import AutoConfig


Expand Down Expand Up @@ -373,11 +374,23 @@ class ParameterizedDefaultDict(dict):

def __init__(self, default_factory: Callable[[Any], Any]):
self.default_factory = default_factory
self._factory_kwargs = frozendict()

def __missing__(self, key):
def __missing__(self, key: Any) -> Any:
if isinstance(key, tuple):
value = self.default_factory(*key)
value = self.default_factory(*key, **self._factory_kwargs)
else:
value = self.default_factory(key)
value = self.default_factory(key, **self._factory_kwargs)
self[key] = value
return value

def get(self, *args, factory_kwargs: Mapping = frozendict()) -> Any:
"""
Similar to `__getitem__`, but allows passing kwargs to factory function

:param \\*args: args whose tuple will value will be treated as key
:param factory_kwargs: keyword arguments to pass to `default_factory`
:return: dictionary entry for given key
"""
with patch_attr(self, "_factory_kwargs", factory_kwargs):
return self[args]
5 changes: 2 additions & 3 deletions tests/test_transform/factory/test_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
TransformConfig,
TransformFactory,
TransformScheme,
apply_transform_config,
)
from compressed_tensors.utils import offloaded_dispatch
from tests.testing_utils import requires_accelerate, requires_gpu
Expand Down Expand Up @@ -81,9 +82,7 @@ def test_correctness_model(scheme_kwargs, model_apply, offload=False):
)
}
)
for name, scheme in config.config_groups.items():
factory = TransformFactory.from_scheme(scheme, name=name)
factory.apply_to_model(model)
apply_transform_config(model, config)

# compare outputs
output = model(input)
Expand Down
5 changes: 2 additions & 3 deletions tests/test_transform/factory/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
TransformConfig,
TransformFactory,
TransformScheme,
apply_transform_config,
)
from compressed_tensors.utils import align_modules, offloaded_dispatch
from tests.test_transform.conftest import TransformableModel
Expand Down Expand Up @@ -54,9 +55,7 @@ def test_memory_sharing(scheme_kwargs, offload=False):
)
}
)
for name, scheme in config.config_groups.items():
factory = TransformFactory.from_scheme(scheme, name=name)
factory.apply_to_model(model)
apply_transform_config(model, config)

# check that memory is shared when onloaded
with align_modules(model.modules()):
Expand Down
8 changes: 4 additions & 4 deletions tests/test_transform/test_transform_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_basic_scheme():
type="hadamard",
apply=[basic_args],
)
assert not scheme.randomize_modules
assert not scheme.randomize
assert scheme.type == "hadamard"
assert len(scheme.apply) == 1
assert isinstance(scheme.apply[0], TransformArgs)
Expand All @@ -43,10 +43,10 @@ def test_multiple_groups_global():
scheme = TransformScheme(
type="hadamard",
apply=[embedding_args, linear_args],
randomize_modules=True,
randomize=True,
)

assert scheme.randomize_modules
assert scheme.randomize
assert scheme.type == "hadamard"
assert len(scheme.apply) == 2
assert isinstance(scheme.apply[0], TransformArgs)
Expand All @@ -69,6 +69,6 @@ def test_multiple_groups():
apply=apply,
)

assert not scheme.randomize_modules
assert not scheme.randomize
assert scheme.type == "hadamard"
assert len(scheme.apply) == 20