From 0be4680be8cf312bbd65ac4c142b4dd4aabd3ace Mon Sep 17 00:00:00 2001 From: Dipika Date: Sun, 23 Mar 2025 23:05:22 +0000 Subject: [PATCH 1/5] update --- .../quantization/lifecycle/apply.py | 3 ++ src/compressed_tensors/transforms/base.py | 11 ++-- src/compressed_tensors/transforms/hadamard.py | 48 +++++++++++++---- .../transforms/hadamard_utils.py | 30 +++++++++-- .../transforms/matrix_multiply.py | 2 +- .../transforms/random_hadamard.py | 52 ++++++++++++++++--- src/compressed_tensors/transforms/temp.py | 39 ++++++++++++++ 7 files changed, 157 insertions(+), 28 deletions(-) create mode 100644 src/compressed_tensors/transforms/temp.py diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index c2c5a704..a7a8a601 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -199,6 +199,7 @@ def process_transforms_config( dtype=dtype, **transform_creation_args, ) + transform.register_to_module( name=transform_name, module=submodule ) @@ -217,6 +218,8 @@ def process_transforms_config( else: transform_data = TransformData(data=OrderedDict(data)) submodule.transform_data = transform_data + breakpoint() + # 10358 for now mib; 1/3 of memory if caching/sharing parameter data return model diff --git a/src/compressed_tensors/transforms/base.py b/src/compressed_tensors/transforms/base.py index 56adcb62..88e41c86 100644 --- a/src/compressed_tensors/transforms/base.py +++ b/src/compressed_tensors/transforms/base.py @@ -33,11 +33,8 @@ class Transforms(RegistryMixin): def __init__( self, transform: torch.Tensor, - learnable: Optional[bool] = True, - device: Optional[Union[str, torch.device]] = "cuda", - dtype: Optional[torch.dtype] = torch.bfloat16, + learnable: Optional[bool] = False, ): - self.learnable = learnable """ Base class for setting up transforms. The registry creates transforms as parameters which can be attached to modules. @@ -62,10 +59,10 @@ def __init__( :param transform: transform (e.g. 
torch.Tensor, scalar) to be applied """ - if self.learnable: - self.transform = torch.nn.Parameter(transform.to(dtype).to(device)) + if learnable: + self.transform = torch.nn.Parameter(transform) else: - self.transform = torch.nn.Buffer(transform.to(dtype).to(device)) + self.transform = torch.nn.Buffer(transform) # register to class for easy offloading, serialization, deserialization def register_to_module(self, name: str, module: torch.nn.Module): diff --git a/src/compressed_tensors/transforms/hadamard.py b/src/compressed_tensors/transforms/hadamard.py index ef0e27a4..54f523e1 100644 --- a/src/compressed_tensors/transforms/hadamard.py +++ b/src/compressed_tensors/transforms/hadamard.py @@ -16,7 +16,10 @@ import torch from compressed_tensors.transforms import Transforms -from compressed_tensors.transforms.hadamard_utils import deterministic_hadamard_matrix +from compressed_tensors.transforms.hadamard_utils import ( + SingletonHadamardRegistry, + deterministic_hadamard_matrix, +) from compressed_tensors.transforms.utils import apply_matrix_transform @@ -28,9 +31,9 @@ def __init__( empty: Optional[bool] = False, device: Optional[Union[str, torch.device]] = "cuda", dtype: Optional[torch.dtype] = torch.bfloat16, - *args, - **kwargs, + learnable: Optional[bool] = False, ): + """ Produces a hadamard matrix with dims (size, size), with values -1 and 1, and the property HH.T == nI i.e the transformation @@ -46,13 +49,38 @@ def __init__( :param dtype: type to cast the rotation matrix to """ - if not empty: - # TODO: this is deterministic; we should just serialize the size - transform = torch.Tensor(deterministic_hadamard_matrix(size=size)) + self.learnable = learnable + self.hadamard_registry = SingletonHadamardRegistry() + self.size = size + self.dtype = dtype + + if empty: + # If saved, would have a different lifecycle (would be loaded and not be + # the same parameter, for now) + # Would take more memory + transform = torch.empty((size, size)).to(dtype).to(device) else: - transform = torch.empty((size, size)) + transform = self.fetch().to(device) + + super().__init__(transform=transform, learnable=learnable) + + # if not learnable, save parameter + if not self.learnable and size not in self.hadamard_registry._data: + self.hadamard_registry.set_hadamard(size, self.transform) + + def fetch(self): + # TODO: this is deterministic; we should just serialize the size + transform = self.hadamard_registry.get_hadamard(self.size) + if transform is None: + transform = torch.Tensor(deterministic_hadamard_matrix(size=self.size)).to( + self.dtype + ) + + # if learnable, save actual data, not parameter + if self.learnable: + self.hadamard_registry.set_hadamard(self.size, transform) - super().__init__(transform=transform, dtype=dtype, device=device) + return transform def apply( self, @@ -61,7 +89,7 @@ def apply( first: bool = True, ) -> torch.Tensor: return apply_matrix_transform( - transform=self.transform, + transform=self.transform.to(input_tensor.device), input_tensor=input_tensor, transpose=transpose, first=first, @@ -87,7 +115,7 @@ def inverse_apply( # need to normalize before sending back return ( apply_matrix_transform( - transform=self.transform, + transform=self.transform.to(input_tensor.device), input_tensor=input_tensor, transpose=transpose, first=first, diff --git a/src/compressed_tensors/transforms/hadamard_utils.py b/src/compressed_tensors/transforms/hadamard_utils.py index 2cbd74d8..f8f04632 100644 --- a/src/compressed_tensors/transforms/hadamard_utils.py +++ 
b/src/compressed_tensors/transforms/hadamard_utils.py @@ -18,7 +18,29 @@ import torch -__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix"] +__all__ = [ + "random_hadamard_matrix", + "deterministic_hadamard_matrix", + "SingletonHadamardRegistry", +] + + +class SingletonHadamardRegistry: + _instance = None + + def __new__(cls): + # Check if the instance already exists + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._data = {} # Initialize the data storage + return cls._instance + + def set_hadamard(self, key, value): + self._data[key] = value + + def get_hadamard(self, key): + return self._data.get(key, None) + # adapted from: # https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py @@ -59,6 +81,7 @@ def deterministic_hadamard_matrix(size: int): # https://github.com/Dao-AILab/fast-hadamard-transform/tree/master +# ToDo: should no longer be random, call something else --> different generation type than scipy? def random_hadamard_matrix(size: int) -> torch.Tensor: """ Produces a randomly generated Hadamard matrix. @@ -73,7 +96,8 @@ def random_hadamard_matrix(size: int) -> torch.Tensor: # the matrix generated to be reproducible # Benefits: support other shapes / non powers of 2, support randomization - Q = torch.randint(low=0, high=2, size=(size,)).to(torch.float64) + # Q = torch.randint(low=1, high=2, size=(size,)).to(torch.float64) + Q = torch.ones(size).to(torch.float64) Q = Q * 2 - 1 Q = torch.diag(Q) return _matmul_hadU(Q) @@ -129,7 +153,7 @@ def _matmul_hadU(X, transpose=False): input = hadK.view(1, K, K).to(input) @ input # normalize - return input.view(X.shape) / torch.tensor(n).sqrt() + return input.view(X.shape) def _is_pow2(n): diff --git a/src/compressed_tensors/transforms/matrix_multiply.py b/src/compressed_tensors/transforms/matrix_multiply.py index a06d61f2..6fb06667 100644 --- a/src/compressed_tensors/transforms/matrix_multiply.py +++ b/src/compressed_tensors/transforms/matrix_multiply.py @@ -17,7 +17,7 @@ from compressed_tensors.transforms.utils import apply_matrix_transform -# TODO: fix loading +# TODO: fix loading + add generic matrix registry? @Transforms.register("matrix-mul") class MatrixMultiply(Transforms): def apply( diff --git a/src/compressed_tensors/transforms/random_hadamard.py b/src/compressed_tensors/transforms/random_hadamard.py index 162269c5..aac116d4 100644 --- a/src/compressed_tensors/transforms/random_hadamard.py +++ b/src/compressed_tensors/transforms/random_hadamard.py @@ -12,14 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import math from typing import Optional, Union import torch from compressed_tensors.transforms import Transforms -from compressed_tensors.transforms.hadamard_utils import random_hadamard_matrix +from compressed_tensors.transforms.hadamard_utils import ( + SingletonHadamardRegistry, + random_hadamard_matrix, +) from compressed_tensors.transforms.utils import apply_matrix_transform +# TODO: allow randomness for both potentially, separate by generation type +# this will make randomness a creation arg instead @Transforms.register("random-hadamard") class RandomHadamard(Transforms): def __init__( @@ -28,6 +34,7 @@ def __init__( empty: Optional[bool] = False, device: Optional[Union[str, torch.device]] = "cuda", dtype: Optional[torch.dtype] = torch.bfloat16, + learnable: Optional[bool] = False, ): """ Produces a randomly generated matrix with dims (size, size), with values @@ -52,13 +59,44 @@ def __init__( we will not have to store the entire matrix. Will need to consider accuracy implications. """ + self.learnable = learnable + self.size = size + self.normalized_size = math.sqrt(self.size) + self.dtype = dtype + self.device = device + # TODO: potentially lives outside of the registry + # And caching is controlled by llmcompressor + self.hadamard_registry = SingletonHadamardRegistry() - if not empty: - transform = random_hadamard_matrix(size=size) + self.permutation = ( + (torch.randint(low=0, high=2, size=(self.size,)).to(torch.float64) * 2 - 1) + .to(self.dtype) + .to(self.device) + ) + + if empty: + # If saved, would have a different lifecycle (would be loaded and registered + # Would take more memory + transform = torch.empty((size, size)).to(dtype) else: - transform = torch.empty((size, size)) + transform = self.fetch() + + super().__init__(transform=transform, learnable=self.learnable) + + # not learnable, cache parameter + if not self.learnable and size not in self.hadamard_registry._data: + self.hadamard_registry.set_hadamard(self.size, self.transform) + + def fetch(self): + deterministic_had = self.hadamard_registry.get_hadamard(self.size) + if deterministic_had is None: + deterministic_had = random_hadamard_matrix(size=self.size).to(self.dtype) + # learnable, cache data + if self.learnable: + self.hadamard_registry.set_hadamard(self.size, deterministic_had) - super().__init__(transform=transform, device=device, dtype=dtype) + deterministic_had = deterministic_had.to(self.device) + return (deterministic_had * self.permutation) / self.normalized_size def apply( self, @@ -67,7 +105,7 @@ def apply( first: bool = True, ) -> torch.Tensor: return apply_matrix_transform( - transform=self.transform, + transform=self.transform.to(input_tensor.device), input_tensor=input_tensor, transpose=transpose, first=first, @@ -92,7 +130,7 @@ def inverse_apply( transpose = not transpose return apply_matrix_transform( - transform=self.transform, + transform=self.transform.to(input_tensor.device), input_tensor=input_tensor, transpose=transpose, first=first, diff --git a/src/compressed_tensors/transforms/temp.py b/src/compressed_tensors/transforms/temp.py new file mode 100644 index 00000000..1ed7643e --- /dev/null +++ b/src/compressed_tensors/transforms/temp.py @@ -0,0 +1,39 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from compressed_tensors.transforms.hadamard_utils import ( + SingletonHadamardRegistry, + random_hadamard_matrix, +) + + +size = 2048 +dtype = torch.bfloat16 +hadamard_registry = SingletonHadamardRegistry() +deterministic_had = hadamard_registry.get_hadamard(size) +# fetch the deterministic had from the registry, if already precomputed +if deterministic_had is None: + deterministic_had = random_hadamard_matrix(size=size).to(dtype) + hadamard_registry.set_hadamard(size, deterministic_had) + +out = random_hadamard_matrix(size) +Q = torch.randint(low=0, high=2, size=(size,)).to(torch.float64) +Q = Q * 2 - 1 + +breakpoint() +new_out = out * Q +new_out = new_out / torch.tensor(size).sqrt() +assert torch.equal(torch.eye(size), torch.round(new_out @ new_out.T)) +breakpoint() From 097caf068aebd7b1692b295cb65c1b6593bdaf8b Mon Sep 17 00:00:00 2001 From: Dipika Date: Mon, 24 Mar 2025 00:00:15 +0000 Subject: [PATCH 2/5] update random --- .../transforms/random_hadamard.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/compressed_tensors/transforms/random_hadamard.py b/src/compressed_tensors/transforms/random_hadamard.py index aac116d4..f4b1f68f 100644 --- a/src/compressed_tensors/transforms/random_hadamard.py +++ b/src/compressed_tensors/transforms/random_hadamard.py @@ -68,6 +68,8 @@ def __init__( # And caching is controlled by llmcompressor self.hadamard_registry = SingletonHadamardRegistry() + # TODO: need to register randomness as well + # Are they training just this and the actual hadamard is a buffer? 
self.permutation = ( (torch.randint(low=0, high=2, size=(self.size,)).to(torch.float64) * 2 - 1) .to(self.dtype) @@ -79,7 +81,7 @@ def __init__( # Would take more memory transform = torch.empty((size, size)).to(dtype) else: - transform = self.fetch() + transform = self.fetch().to(device) super().__init__(transform=transform, learnable=self.learnable) @@ -88,15 +90,15 @@ def __init__( self.hadamard_registry.set_hadamard(self.size, self.transform) def fetch(self): - deterministic_had = self.hadamard_registry.get_hadamard(self.size) - if deterministic_had is None: - deterministic_had = random_hadamard_matrix(size=self.size).to(self.dtype) + transform = self.hadamard_registry.get_hadamard(self.size) + if transform is None: + transform = random_hadamard_matrix(size=self.size).to(self.dtype) # learnable, cache data if self.learnable: - self.hadamard_registry.set_hadamard(self.size, deterministic_had) + self.hadamard_registry.set_hadamard(self.size, transform) - deterministic_had = deterministic_had.to(self.device) - return (deterministic_had * self.permutation) / self.normalized_size + return transform + # return (deterministic_had * self.permutation) / self.normalized_size def apply( self, @@ -105,7 +107,8 @@ def apply( first: bool = True, ) -> torch.Tensor: return apply_matrix_transform( - transform=self.transform.to(input_tensor.device), + transform=(self.permutation * self.transform.to(input_tensor.device)) + / self.normalized_size, input_tensor=input_tensor, transpose=transpose, first=first, @@ -130,7 +133,8 @@ def inverse_apply( transpose = not transpose return apply_matrix_transform( - transform=self.transform.to(input_tensor.device), + transform=(self.permutation * self.transform.to(input_tensor.device)) + / self.normalized_size, input_tensor=input_tensor, transpose=transpose, first=first, From b7f5c832cd35e8bdc439144073d422e8b4a030a5 Mon Sep 17 00:00:00 2001 From: Dipika Date: Mon, 24 Mar 2025 22:34:24 +0000 Subject: [PATCH 3/5] update --- .../quantization/lifecycle/apply.py | 3 +- src/compressed_tensors/transforms/base.py | 36 ++++++++----- src/compressed_tensors/transforms/hadamard.py | 39 +++++--------- .../transforms/hadamard_utils.py | 18 ------- .../transforms/matrix_multiply.py | 43 ++++++++++++++- .../transforms/random_hadamard.py | 52 +++++++------------ .../transforms/transform_scheme.py | 9 ++-- src/compressed_tensors/transforms/utils.py | 22 +++++++- 8 files changed, 125 insertions(+), 97 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index a7a8a601..c84b94de 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -180,6 +180,7 @@ def process_transforms_config( # only support weight parameters for now, assume one value in # module targets transform_name = f"{module_targets[0]}_transform_{idx}" + perm_name = f"{module_targets[0]}_perm_{idx}" # create an empty tensor OR create a new transform dtype = getattr(submodule, module_targets[0]).dtype @@ -201,7 +202,7 @@ def process_transforms_config( ) transform.register_to_module( - name=transform_name, module=submodule + name=transform_name, module=submodule, perm_name=perm_name ) # add relevant transform data to the submodule as well diff --git a/src/compressed_tensors/transforms/base.py b/src/compressed_tensors/transforms/base.py index 88e41c86..f0052095 100644 --- a/src/compressed_tensors/transforms/base.py +++ b/src/compressed_tensors/transforms/base.py @@ -31,9 
+31,7 @@ class Transforms(RegistryMixin): def __init__( - self, - transform: torch.Tensor, - learnable: Optional[bool] = False, + self, transform: torch.Tensor, permutation: Optional[torch.Tensor] = None ): """ Base class for setting up transforms. The registry creates transforms @@ -59,34 +57,44 @@ def __init__( :param transform: transform (e.g. torch.Tensor, scalar) to be applied """ - if learnable: - self.transform = torch.nn.Parameter(transform) - else: - self.transform = torch.nn.Buffer(transform) + self.register_permutation = False + self.transform = torch.nn.Buffer(transform) + self.permutation = ( + torch.nn.Buffer(permutation) if permutation is not None else None + ) # register to class for easy offloading, serialization, deserialization - def register_to_module(self, name: str, module: torch.nn.Module): - if self.learnable: - register_offload_parameter(module, name, self.transform) - else: - # TODO: have to verify serialization/offloading - module.register_buffer(name, self.transform) + # TODO: Manage lifecycle of permutation and transform buffers + def register_to_module(self, name: str, module: torch.nn.Module, perm_name: str): + module.register_buffer(name, self.transform) + if self.permutation is not None: + module.register_buffer(perm_name, self.permutation) def update_transform( self, data: torch.Tensor, + permutation_data: Optional[torch.Tensor] = None, module: Optional[torch.nn.Module] = None, name: Optional[str] = None, + permutation_name: Optional[str] = None, ): if module is None: self.transform.data.copy_(data) + if self.permutation is not None and permutation_data is not None: + self.permutation.data.copy_(permutation_data) + else: # If updating the module parameter data, assumes this is also the transform # data if name is None: - raise ValueError("Name and module are required to update parma data") + raise ValueError( + "Name and module are required to update transform data" + ) update_parameter_data(module, data, name) + if self.permutation is not None and permutation_data is not None: + update_parameter_data(module, permutation_data, permutation_name) + def apply(self, input_tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor: """ Apply the transform to the module diff --git a/src/compressed_tensors/transforms/hadamard.py b/src/compressed_tensors/transforms/hadamard.py index 54f523e1..6c91d177 100644 --- a/src/compressed_tensors/transforms/hadamard.py +++ b/src/compressed_tensors/transforms/hadamard.py @@ -16,11 +16,11 @@ import torch from compressed_tensors.transforms import Transforms -from compressed_tensors.transforms.hadamard_utils import ( - SingletonHadamardRegistry, - deterministic_hadamard_matrix, +from compressed_tensors.transforms.hadamard_utils import deterministic_hadamard_matrix +from compressed_tensors.transforms.utils import ( + SingletonMatrixRegistry, + apply_matrix_transform, ) -from compressed_tensors.transforms.utils import apply_matrix_transform @Transforms.register("hadamard") @@ -31,7 +31,6 @@ def __init__( empty: Optional[bool] = False, device: Optional[Union[str, torch.device]] = "cuda", dtype: Optional[torch.dtype] = torch.bfloat16, - learnable: Optional[bool] = False, ): """ @@ -49,37 +48,27 @@ def __init__( :param dtype: type to cast the rotation matrix to """ - self.learnable = learnable - self.hadamard_registry = SingletonHadamardRegistry() + self.matrix_registry = SingletonMatrixRegistry() self.size = size - self.dtype = dtype if empty: # If saved, would have a different lifecycle (would be loaded and not be # the same 
parameter, for now) # Would take more memory - transform = torch.empty((size, size)).to(dtype).to(device) + transform = torch.empty((size, size)).to(dtype) else: - transform = self.fetch().to(device) + transform = self.fetch().to(dtype).to(device) - super().__init__(transform=transform, learnable=learnable) + super().__init__(transform=transform) - # if not learnable, save parameter - if not self.learnable and size not in self.hadamard_registry._data: - self.hadamard_registry.set_hadamard(size, self.transform) + if not self.matrix_registry.contains(size): + self.matrix_registry.set_matrix(size, self.transform) def fetch(self): # TODO: this is deterministic; we should just serialize the size - transform = self.hadamard_registry.get_hadamard(self.size) + transform = self.matrix_registry.get_matrix(self.size) if transform is None: - transform = torch.Tensor(deterministic_hadamard_matrix(size=self.size)).to( - self.dtype - ) - - # if learnable, save actual data, not parameter - if self.learnable: - self.hadamard_registry.set_hadamard(self.size, transform) - + transform = torch.Tensor(deterministic_hadamard_matrix(size=self.size)) return transform def apply( @@ -89,7 +78,7 @@ def apply( first: bool = True, ) -> torch.Tensor: return apply_matrix_transform( - transform=self.transform.to(input_tensor.device), + transform=self.transform, input_tensor=input_tensor, transpose=transpose, first=first, @@ -115,7 +104,7 @@ def inverse_apply( # need to normalize before sending back return ( apply_matrix_transform( - transform=self.transform.to(input_tensor.device), + transform=self.transform, input_tensor=input_tensor, transpose=transpose, first=first, diff --git a/src/compressed_tensors/transforms/hadamard_utils.py b/src/compressed_tensors/transforms/hadamard_utils.py index f8f04632..de9364cd 100644 --- a/src/compressed_tensors/transforms/hadamard_utils.py +++ b/src/compressed_tensors/transforms/hadamard_utils.py @@ -24,24 +24,6 @@ "SingletonHadamardRegistry", ] - -class SingletonHadamardRegistry: - _instance = None - - def __new__(cls): - # Check if the instance already exists - if cls._instance is None: - cls._instance = super().__new__(cls) - cls._instance._data = {} # Initialize the data storage - return cls._instance - - def set_hadamard(self, key, value): - self._data[key] = value - - def get_hadamard(self, key): - return self._data.get(key, None) - - # adapted from: # https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py def deterministic_hadamard_matrix(size: int): diff --git a/src/compressed_tensors/transforms/matrix_multiply.py b/src/compressed_tensors/transforms/matrix_multiply.py index 6fb06667..599cc11a 100644 --- a/src/compressed_tensors/transforms/matrix_multiply.py +++ b/src/compressed_tensors/transforms/matrix_multiply.py @@ -12,14 +12,55 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional, Union + import torch from compressed_tensors.transforms import Transforms -from compressed_tensors.transforms.utils import apply_matrix_transform +from compressed_tensors.transforms.utils import ( + SingletonMatrixRegistry, + apply_matrix_transform, +) # TODO: fix loading + add generic matrix registry? 
@Transforms.register("matrix-mul") class MatrixMultiply(Transforms): + def __init__( + self, + name: str, + transform_data: torch.Tensor, + size: Optional[int] = None, + empty: Optional[bool] = False, + device: Optional[Union[str, torch.device]] = "cpu", + dtype: Optional[torch.dtype] = torch.bfloat16, + ): + + if empty and size is None: + raise ValueError( + "size is required when setting up empty transforms for deserialization " + ) + + # name required to either pull a cached matrix or cache a matrix itself + # will assume each name corresponds to a unique matrix + self.name = name + self.matrix_registry = SingletonMatrixRegistry() + + if empty: + transform = torch.empty((size, size)).to(dtype) + else: + transform = self.fetch().to(dtype).to(device) + + super().__init__(transform=transform) + + if not self.matrix_registry.contains(self.name): + self.matrix_registry.set_matrix(self.name, self.transform) + + def fetch(self): + transform = self.matrix_registry.get_matrix(self.name) + if transform is None: + transform = transform_data + return transform + def apply( self, input_tensor: torch.Tensor, diff --git a/src/compressed_tensors/transforms/random_hadamard.py b/src/compressed_tensors/transforms/random_hadamard.py index f4b1f68f..5a349a31 100644 --- a/src/compressed_tensors/transforms/random_hadamard.py +++ b/src/compressed_tensors/transforms/random_hadamard.py @@ -17,11 +17,11 @@ import torch from compressed_tensors.transforms import Transforms -from compressed_tensors.transforms.hadamard_utils import ( - SingletonHadamardRegistry, - random_hadamard_matrix, +from compressed_tensors.transforms.hadamard_utils import random_hadamard_matrix +from compressed_tensors.transforms.utils import ( + SingletonMatrixRegistry, + apply_matrix_transform, ) -from compressed_tensors.transforms.utils import apply_matrix_transform # TODO: allow randomness for both potentially, separate by generation type @@ -34,7 +34,6 @@ def __init__( empty: Optional[bool] = False, device: Optional[Union[str, torch.device]] = "cuda", dtype: Optional[torch.dtype] = torch.bfloat16, - learnable: Optional[bool] = False, ): """ Produces a randomly generated matrix with dims (size, size), with values @@ -59,46 +58,35 @@ def __init__( we will not have to store the entire matrix. Will need to consider accuracy implications. """ - self.learnable = learnable self.size = size self.normalized_size = math.sqrt(self.size) - self.dtype = dtype - self.device = device # TODO: potentially lives outside of the registry # And caching is controlled by llmcompressor - self.hadamard_registry = SingletonHadamardRegistry() - - # TODO: need to register randomness as well - # Are they training just this and the actual hadamard is a buffer? 
- self.permutation = ( - (torch.randint(low=0, high=2, size=(self.size,)).to(torch.float64) * 2 - 1) - .to(self.dtype) - .to(self.device) - ) + self.matrix_registry = SingletonMatrixRegistry() if empty: # If saved, would have a different lifecycle (would be loaded and registered # Would take more memory transform = torch.empty((size, size)).to(dtype) + permutation = torch.empty((size)).to(dtype) else: - transform = self.fetch().to(device) + transform = self.fetch().to(dtype).to(device) + permutation = ( + (torch.randint(low=0, high=2, size=(self.size,)) * 2 - 1) + .to(dtype) + .to(device) + ) - super().__init__(transform=transform, learnable=self.learnable) + super().__init__(transform=transform, permutation=permutation) - # not learnable, cache parameter - if not self.learnable and size not in self.hadamard_registry._data: - self.hadamard_registry.set_hadamard(self.size, self.transform) + if not self.matrix_registry.contains(size): + self.matrix_registry.set_matrix(self.size, self.transform) def fetch(self): - transform = self.hadamard_registry.get_hadamard(self.size) + transform = self.matrix_registry.get_matrix(self.size) if transform is None: - transform = random_hadamard_matrix(size=self.size).to(self.dtype) - # learnable, cache data - if self.learnable: - self.hadamard_registry.set_hadamard(self.size, transform) - + transform = random_hadamard_matrix(size=self.size) return transform - # return (deterministic_had * self.permutation) / self.normalized_size def apply( self, @@ -107,8 +95,7 @@ def apply( first: bool = True, ) -> torch.Tensor: return apply_matrix_transform( - transform=(self.permutation * self.transform.to(input_tensor.device)) - / self.normalized_size, + transform=(self.transform * self.permutation) / self.normalized_size, input_tensor=input_tensor, transpose=transpose, first=first, @@ -133,8 +120,7 @@ def inverse_apply( transpose = not transpose return apply_matrix_transform( - transform=(self.permutation * self.transform.to(input_tensor.device)) - / self.normalized_size, + transform=(self.transform * self.permutation) / self.normalized_size, input_tensor=input_tensor, transpose=transpose, first=first, diff --git a/src/compressed_tensors/transforms/transform_scheme.py b/src/compressed_tensors/transforms/transform_scheme.py index f1770cc4..012e4128 100644 --- a/src/compressed_tensors/transforms/transform_scheme.py +++ b/src/compressed_tensors/transforms/transform_scheme.py @@ -30,16 +30,17 @@ class TransformationScheme(BaseModel): :param groups: includes TransformationArgs containing the information about the layers that should be targeted by the specified transform. By providing a list, users have the ability to apply the same transform type (with the same set - of arguments) to different layers. + of arguments) to different layers. No :param transform_creation_args: arguments needed to initialize the transform, if any - :param global_transform: whether an identical transform is applied to all the - TransformationArgs in the groups list """ + # TODO: maybe we don't need "global" + # If it's the same transform (but different call_args) in the list? 
+ # Come back to this + transform_type: str groups: List[TransformationArgs] - global_transform: bool = False transform_creation_args: Optional[Dict[str, Any]] = None @field_validator("transform_type", mode="before") diff --git a/src/compressed_tensors/transforms/utils.py b/src/compressed_tensors/transforms/utils.py index 997c91f1..83364155 100644 --- a/src/compressed_tensors/transforms/utils.py +++ b/src/compressed_tensors/transforms/utils.py @@ -15,7 +15,27 @@ import torch -__all__ = ["apply_matrix_transform"] +__all__ = ["apply_matrix_transform", "SingletonMatrixRegistry"] + + +class SingletonMatrixRegistry: + _instance = None + + def __new__(cls): + # Check if the instance already exists + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._data = {} # Initialize the data storage + return cls._instance + + def set_matrix(self, key, value): + self._data[key] = value + + def get_matrix(self, key): + return self._data.get(key, None) + + def contains(self, key): + return key in self._data def apply_matrix_transform( From 9ff4eda66dcb28bc2b59629211f80d57313797d3 Mon Sep 17 00:00:00 2001 From: Dipika Date: Wed, 26 Mar 2025 21:52:37 +0000 Subject: [PATCH 4/5] clean-up --- .../quantization/lifecycle/apply.py | 22 ++++++---- src/compressed_tensors/transforms/base.py | 40 +++++++++++-------- src/compressed_tensors/transforms/hadamard.py | 11 +++-- .../transforms/matrix_multiply.py | 22 ++++++---- .../transforms/random_hadamard.py | 17 +++++--- 5 files changed, 70 insertions(+), 42 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index c84b94de..789522ed 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -62,6 +62,7 @@ ] from compressed_tensors.quantization.utils.helpers import is_module_quantized +from compressed_tensors.utils import has_offloaded_params from compressed_tensors.utils.safetensors_load import ( get_quantization_state_dict, get_weight_mappings, @@ -88,11 +89,14 @@ def load_transforms(model: Module, model_name_or_path: str): if transform_data: for transform_name, transform_values in transform_data.data.items(): full_name = f"{name}.{transform_name}" - transform_data = state_dict.get(full_name, None) + full_per_name = full_name.replace("transform", "perm") + dict_data = state_dict.get(full_name, None) + permutation_data = state_dict.get(full_per_name, None) transform = transform_values.get("transform") - transform.register_to_module(name=transform_name, module=submodule) + transform.update_device(module=submodule) + transform.register_to_module(module=submodule) transform.update_transform( - module=submodule, data=transform_data, name=transform_name + module=submodule, data=dict_data, permutation_data=permutation_data ) @@ -180,7 +184,7 @@ def process_transforms_config( # only support weight parameters for now, assume one value in # module targets transform_name = f"{module_targets[0]}_transform_{idx}" - perm_name = f"{module_targets[0]}_perm_{idx}" + permutation_name = f"{module_targets[0]}_perm_{idx}" # create an empty tensor OR create a new transform dtype = getattr(submodule, module_targets[0]).dtype @@ -192,18 +196,21 @@ def process_transforms_config( transform_type, dtype=dtype, empty=True, + transform_name=transform_name, + permutation_name=permutation_name, **transform_creation_args, ) else: transform = Transforms.load_from_registry( transform_type, dtype=dtype, + 
transform_name=transform_name, + permutation_name=permutation_name, + device=next(submodule.parameters()).device, **transform_creation_args, ) - transform.register_to_module( - name=transform_name, module=submodule, perm_name=perm_name - ) + transform.register_to_module(module=submodule) # add relevant transform data to the submodule as well data = { @@ -219,7 +226,6 @@ def process_transforms_config( else: transform_data = TransformData(data=OrderedDict(data)) submodule.transform_data = transform_data - breakpoint() # 10358 for now mib; 1/3 of memory if caching/sharing parameter data return model diff --git a/src/compressed_tensors/transforms/base.py b/src/compressed_tensors/transforms/base.py index f0052095..715c0625 100644 --- a/src/compressed_tensors/transforms/base.py +++ b/src/compressed_tensors/transforms/base.py @@ -31,7 +31,11 @@ class Transforms(RegistryMixin): def __init__( - self, transform: torch.Tensor, permutation: Optional[torch.Tensor] = None + self, + transform: torch.Tensor, + transform_name: str, + permutation: Optional[torch.Tensor] = None, + permutation_name: Optional[str] = None, ): """ Base class for setting up transforms. The registry creates transforms @@ -57,26 +61,34 @@ def __init__( :param transform: transform (e.g. torch.Tensor, scalar) to be applied """ - self.register_permutation = False - self.transform = torch.nn.Buffer(transform) + self.transform = torch.nn.Parameter(transform, requires_grad=False) + self.transform_name = transform_name self.permutation = ( - torch.nn.Buffer(permutation) if permutation is not None else None + torch.nn.Parameter(permutation, requires_grad=False) + if permutation is not None + else None ) + self.permutation_name = permutation_name + + def update_device(self, module: torch.nn.Module): + # Helper function required for deserialization + module_device = next(module.parameters()).device + self.transform.data = self.transform.data.to(module_device) + if self.permutation is not None: + self.permutation.data = self.permutation.data.to(module_device) # register to class for easy offloading, serialization, deserialization # TODO: Manage lifecycle of permutation and transform buffers - def register_to_module(self, name: str, module: torch.nn.Module, perm_name: str): - module.register_buffer(name, self.transform) + def register_to_module(self, module: torch.nn.Module): + register_offload_parameter(module, self.transform_name, self.transform) if self.permutation is not None: - module.register_buffer(perm_name, self.permutation) + register_offload_parameter(module, self.permutation_name, self.permutation) def update_transform( self, data: torch.Tensor, permutation_data: Optional[torch.Tensor] = None, module: Optional[torch.nn.Module] = None, - name: Optional[str] = None, - permutation_name: Optional[str] = None, ): if module is None: self.transform.data.copy_(data) @@ -85,15 +97,11 @@ def update_transform( else: # If updating the module parameter data, assumes this is also the transform - # data - if name is None: - raise ValueError( - "Name and module are required to update transform data" - ) - update_parameter_data(module, data, name) + # is already registered/shared data + update_parameter_data(module, data, self.transform_name) if self.permutation is not None and permutation_data is not None: - update_parameter_data(module, permutation_data, permutation_name) + update_parameter_data(module, permutation_data, self.permutation_name) def apply(self, input_tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor: """ diff --git 
a/src/compressed_tensors/transforms/hadamard.py b/src/compressed_tensors/transforms/hadamard.py index 6c91d177..7f280e55 100644 --- a/src/compressed_tensors/transforms/hadamard.py +++ b/src/compressed_tensors/transforms/hadamard.py @@ -28,11 +28,13 @@ class Hadamard(Transforms): def __init__( self, size: int, + transform_name: str, empty: Optional[bool] = False, - device: Optional[Union[str, torch.device]] = "cuda", + device: Optional[Union[str, torch.device]] = "cpu", dtype: Optional[torch.dtype] = torch.bfloat16, + *args, + **kwargs, ): - """ Produces a hadamard matrix with dims (size, size), with values -1 and 1, and the property HH.T == nI i.e the transformation @@ -52,14 +54,11 @@ def __init__( self.size = size if empty: - # If saved, would have a different lifecycle (would be loaded and not be - # the same parameter, for now) - # Would take more memory transform = torch.empty((size, size)).to(dtype) else: transform = self.fetch().to(dtype).to(device) - super().__init__(transform=transform) + super().__init__(transform=transform, transform_name=transform_name) if not self.matrix_registry.contains(size): self.matrix_registry.set_matrix(size, self.transform) diff --git a/src/compressed_tensors/transforms/matrix_multiply.py b/src/compressed_tensors/transforms/matrix_multiply.py index 599cc11a..5216fd6e 100644 --- a/src/compressed_tensors/transforms/matrix_multiply.py +++ b/src/compressed_tensors/transforms/matrix_multiply.py @@ -22,22 +22,29 @@ ) -# TODO: fix loading + add generic matrix registry? @Transforms.register("matrix-mul") class MatrixMultiply(Transforms): def __init__( self, name: str, - transform_data: torch.Tensor, + transform_name: str, + transform_data: Optional[torch.Tensor] = None, size: Optional[int] = None, empty: Optional[bool] = False, device: Optional[Union[str, torch.device]] = "cpu", dtype: Optional[torch.dtype] = torch.bfloat16, + *args, + **kwargs, ): if empty and size is None: raise ValueError( - "size is required when setting up empty transforms for deserialization " + "size is required when setting up parameters for deserialization " + ) + + if not empty and transform_data is None: + raise ValueError( + "transform_data is required when initializing matrix-multiply transforms" ) # name required to either pull a cached matrix or cache a matrix itself @@ -45,20 +52,21 @@ def __init__( self.name = name self.matrix_registry = SingletonMatrixRegistry() + # Can we get rid of the size for deserialization? 
if empty: transform = torch.empty((size, size)).to(dtype) else: - transform = self.fetch().to(dtype).to(device) + transform = self.fetch(transform_data).to(dtype).to(device) - super().__init__(transform=transform) + super().__init__(transform=transform, transform_name=tranform_name) if not self.matrix_registry.contains(self.name): self.matrix_registry.set_matrix(self.name, self.transform) - def fetch(self): + def fetch(self, transform_data: torch.Tensor): transform = self.matrix_registry.get_matrix(self.name) if transform is None: - transform = transform_data + return transform_data return transform def apply( diff --git a/src/compressed_tensors/transforms/random_hadamard.py b/src/compressed_tensors/transforms/random_hadamard.py index 5a349a31..20e52fda 100644 --- a/src/compressed_tensors/transforms/random_hadamard.py +++ b/src/compressed_tensors/transforms/random_hadamard.py @@ -31,9 +31,13 @@ class RandomHadamard(Transforms): def __init__( self, size: int, + transform_name: str, + permutation_name: str, empty: Optional[bool] = False, - device: Optional[Union[str, torch.device]] = "cuda", + device: Optional[Union[str, torch.device]] = "cpu", dtype: Optional[torch.dtype] = torch.bfloat16, + *args, + **kwargs, ): """ Produces a randomly generated matrix with dims (size, size), with values @@ -65,10 +69,8 @@ def __init__( self.matrix_registry = SingletonMatrixRegistry() if empty: - # If saved, would have a different lifecycle (would be loaded and registered - # Would take more memory transform = torch.empty((size, size)).to(dtype) - permutation = torch.empty((size)).to(dtype) + permutation = torch.empty((size)).to(dtype).to(device) else: transform = self.fetch().to(dtype).to(device) permutation = ( @@ -77,7 +79,12 @@ def __init__( .to(device) ) - super().__init__(transform=transform, permutation=permutation) + super().__init__( + transform=transform, + permutation=permutation, + transform_name=transform_name, + permutation_name=permutation_name, + ) if not self.matrix_registry.contains(size): self.matrix_registry.set_matrix(self.size, self.transform) From ff50ef3fe4e949bed1abc5e2fae7e36c77ed79b6 Mon Sep 17 00:00:00 2001 From: Dipika Date: Wed, 26 Mar 2025 22:21:47 +0000 Subject: [PATCH 5/5] fix permutation weight loading --- src/compressed_tensors/quantization/lifecycle/apply.py | 2 +- src/compressed_tensors/transforms/transform_scheme.py | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 789522ed..255754df 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -79,7 +79,7 @@ def load_transforms(model: Module, model_name_or_path: str): state_dict = {} for weight_name, safe_path in weight_mappings.items(): - if "transform" in weight_name: + if "transform" in weight_name or "_perm_" in weight_name: with safe_open(safe_path, framework="pt", device="cpu") as f: state_dict[weight_name] = f.get_tensor(weight_name) diff --git a/src/compressed_tensors/transforms/transform_scheme.py b/src/compressed_tensors/transforms/transform_scheme.py index 012e4128..74d300bb 100644 --- a/src/compressed_tensors/transforms/transform_scheme.py +++ b/src/compressed_tensors/transforms/transform_scheme.py @@ -31,16 +31,14 @@ class TransformationScheme(BaseModel): layers that should be targeted by the specified transform. 
By providing a list, users have the ability to apply the same transform type (with the same set of arguments) to different layers. No + :param shared: if an identical transform is to be used for all the groups :param transform_creation_args: arguments needed to initialize the transform, if any """ - # TODO: maybe we don't need "global" - # If it's the same transform (but different call_args) in the list? - # Come back to this - transform_type: str groups: List[TransformationArgs] + shared: bool = False transform_creation_args: Optional[Dict[str, Any]] = None @field_validator("transform_type", mode="before")
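
The patches above converge on a caching scheme for the Hadamard-based transforms: one unnormalized Hadamard matrix per size is stored in a process-wide SingletonMatrixRegistry and shared across layers, the per-layer randomness is factored out into a separate +/-1 "permutation" buffer that is registered and serialized on its own, and the 1/sqrt(n) normalization is deferred to apply time. Below is a minimal standalone sketch of the property this relies on; it mirrors the check in src/compressed_tensors/transforms/temp.py and assumes the SingletonMatrixRegistry and random_hadamard_matrix helpers keep the signatures introduced in this series (an illustration, not part of the patches).

import math

import torch
from compressed_tensors.transforms.hadamard_utils import random_hadamard_matrix
from compressed_tensors.transforms.utils import SingletonMatrixRegistry

size = 2048
dtype = torch.float64

# The unnormalized base Hadamard (entries +/-1, with H @ H.T == n * I) is built
# once per size and shared across modules through the singleton registry.
registry = SingletonMatrixRegistry()
base = registry.get_matrix(size)
if base is None:
    base = random_hadamard_matrix(size=size).to(dtype)
    registry.set_matrix(size, base)

# Per-layer randomness lives in a separate +/-1 sign vector ("permutation"),
# so the cached base matrix stays identical for every layer of the same size.
signs = torch.randint(low=0, high=2, size=(size,)).to(dtype) * 2 - 1

# Scaling columns by +/-1 and normalizing by sqrt(n) at apply time keeps the
# rotation orthonormal: ((H * s) / sqrt(n)) @ ((H * s) / sqrt(n)).T == I
rotation = (base * signs) / math.sqrt(size)
identity = torch.eye(size, dtype=dtype)
assert torch.allclose(rotation @ rotation.T, identity, atol=1e-6)

Because the sign vector squares to the identity, the shared base matrix never needs to be duplicated or re-randomized per layer, which is where the memory saving noted in the first patch comes from.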