
Commit 5cb9a44

populate _dynamic_tied_weights_keys

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

1 parent 3c55003 commit 5cb9a44

File tree: 3 files changed, +41 -3 lines changed

src/compressed_tensors/transform/factory/base.py

Lines changed: 34 additions & 1 deletion
@@ -13,7 +13,8 @@
 # limitations under the License.
 
 from abc import ABC, abstractmethod
-from typing import Optional
+from collections import defaultdict
+from typing import List, Optional, Tuple
 
 import torch
 import torch.nn.utils.parametrize as P
@@ -47,10 +48,13 @@ class TransformFactory(RegistryMixin, ABC):
     :param seed: random seed used to transform weight randomization
     """
 
+    transforms: List["TransformBase"]
+
     def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
         self.name = name
         self.scheme = scheme
         self.generator = torch.Generator()
+        self.transforms = list()
         if seed is not None:
             self.generator.manual_seed(seed)
 
@@ -89,6 +93,8 @@ def apply_to_model(self, model: Module):
                 if is_target(name, module, arg.targets, arg.ignore):
                     self._apply_to_module(module, arg)
 
+        self._update_tied_weights()
+
     def _apply_to_module(self, module: Module, args: TransformArgs):
         """
         Create transforms and apply them to the module
@@ -143,6 +149,28 @@ def output_hook(_, _input, output):
         # (1) even in the `weight` cases, this submodule attachment is needed in order
         # to support saving in the frozen state
 
+    def _update_tied_weights(self):
+        """
+        Populate the `_dynamic_tied_weights_keys` attribute of transforms,
+        which is used by transformers to detect and remove shared pointers
+        during saving
+        """
+        # avoid issues with this method being called twice
+        for transform in self.transforms:
+            transform._dynamic_tied_weights_keys = list()
+
+        # map from data_ptrs to keys
+        ptr_to_keys: dict[int, List[Tuple[TransformBase, str]]] = defaultdict(list)
+        for transform in self.transforms:
+            for name, param in transform.named_parameters(recurse=False):
+                ptr_to_keys[param.data_ptr()].append((transform, name))
+
+        # populate `_dynamic_tied_weights_keys` if there is more than one key
+        for shared_keys in ptr_to_keys.values():
+            if len(shared_keys) > 1:
+                for transform, name in shared_keys:
+                    transform._dynamic_tied_weights_keys.append(name)
+
 
 class TransformBase(Module, ABC):
     """
@@ -151,6 +179,11 @@ class TransformBase(Module, ABC):
 
     args: TransformArgs
     weight: Parameter
+    _dynamic_tied_weights_keys: List[str]
+
+    def __init__(self):
+        super().__init__()
+        self._dynamic_tied_weights_keys = list()
 
     @abstractmethod
     def forward(self, value: Tensor) -> Tensor:
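For context on what the new `_update_tied_weights` method is doing: parameters are bucketed by `Tensor.data_ptr()`, so any two transform modules holding the same underlying storage land in the same bucket and have that parameter name recorded; resetting the key lists first keeps the method idempotent if `apply_to_model` runs more than once. Below is a minimal, self-contained sketch of the same grouping logic, with a hypothetical `ToyTransform` standing in for the real `TransformBase`:

# A minimal, self-contained sketch of the grouping logic above. `ToyTransform`
# is a hypothetical stand-in for TransformBase, not part of compressed-tensors.
from collections import defaultdict
from typing import List, Tuple

import torch
from torch.nn import Module, Parameter


class ToyTransform(Module):
    """Stand-in for TransformBase: one weight plus the tied-keys attribute."""

    def __init__(self, weight: Parameter):
        super().__init__()
        self.weight = weight  # may be shared across instances
        self._dynamic_tied_weights_keys: List[str] = []


# two transforms share one Parameter; a third owns its own copy
shared = Parameter(torch.randn(4, 4))
transforms = [
    ToyTransform(shared),
    ToyTransform(shared),
    ToyTransform(Parameter(torch.randn(4, 4))),
]

# same idea as `_update_tied_weights`: bucket parameters by storage pointer
ptr_to_keys: dict[int, List[Tuple[ToyTransform, str]]] = defaultdict(list)
for transform in transforms:
    for name, param in transform.named_parameters(recurse=False):
        ptr_to_keys[param.data_ptr()].append((transform, name))

# mark a key as tied only when its storage appears under more than one module
for shared_keys in ptr_to_keys.values():
    if len(shared_keys) > 1:
        for transform, name in shared_keys:
            transform._dynamic_tied_weights_keys.append(name)

print([t._dynamic_tied_weights_keys for t in transforms])
# -> [['weight'], ['weight'], []]: only the shared pair is marked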

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 4 additions & 1 deletion
@@ -59,7 +59,10 @@ def create_transform(self, module: Module, args: TransformArgs):
 
         weight = self.weights.get(size, dtype, device, construct_device=exec_device)
         perm = self.perms[weight] if self.scheme.randomize else None
-        return HadamardTransform(weight, perm, args)
+
+        transform = HadamardTransform(weight, perm, args)
+        self.transforms.append(transform)
+        return transform
 
     def _create_weight(
         self,
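Worth noting: `create_transform` pulls its weight from what appears to be a shared container keyed by size/dtype/device (the `self.weights.get(...)` lookup above), so two targeted modules of the same shape can receive literally the same `Parameter` object. Appending each constructed `HadamardTransform` to `self.transforms` is what makes that sharing visible to `_update_tied_weights`; without the registration step, the shared pointers would go undetected at save time.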

src/compressed_tensors/transform/factory/matrix_multiply.py

Lines changed: 3 additions & 1 deletion
@@ -59,7 +59,9 @@ def create_transform(self, module: Module, args: TransformArgs):
         if args.inverse:
             weight = self.inverses[weight]
 
-        return RandomMatrixTransform(weight, args)
+        transform = RandomMatrixTransform(weight, args)
+        self.transforms.append(transform)
+        return transform
 
     def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
         data = torch.rand(
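Once populated, `_dynamic_tied_weights_keys` plays the same role as the static `_tied_weights_keys` attribute that transformers already reads: as the docstring in base.py notes, transformers uses these keys at save time to detect shared pointers and keep only one copy of each shared tensor, which matters because safetensors refuses to serialize aliased storage.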
