Skip to content

Commit 19195be

Browse files
committed
populate _dynamic_tied_weights_keys
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 438bc13 commit 19195be

File tree

3 files changed

+41
-3
lines changed

3 files changed

+41
-3
lines changed

src/compressed_tensors/transform/factory/base.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
# limitations under the License.
1414

1515
from abc import ABC, abstractmethod
16-
from typing import Optional
16+
from collections import defaultdict
17+
from typing import List, Optional, Tuple
1718

1819
import torch
1920
import torch.nn.utils.parametrize as P
@@ -48,10 +49,13 @@ class TransformFactory(RegistryMixin, ABC):
4849
:param seed: random seed used to transform weight randomization
4950
"""
5051

52+
transforms: List["TransformBase"]
53+
5154
def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
5255
self.name = name
5356
self.scheme = scheme
5457
self.generator = torch.Generator()
58+
self.transforms = list()
5559
if seed is not None:
5660
self.generator.manual_seed(seed)
5761

@@ -90,6 +94,8 @@ def apply_to_model(self, model: Module):
9094
if is_target(name, module, arg.targets, arg.ignore):
9195
self._apply_to_module(module, arg)
9296

97+
self._update_tied_weights()
98+
9399
def _apply_to_module(self, module: Module, args: TransformArgs):
94100
"""
95101
Create transforms and apply them to the module
@@ -145,6 +151,28 @@ def output_hook(_, _input, output):
145151
else:
146152
raise NotImplementedError()
147153

154+
def _update_tied_weights(self):
155+
"""
156+
Populate the `_dynamic_tied_weights_keys` attribute of transforms,
157+
which is used by transformers to detect and remove shared pointers
158+
during saving
159+
"""
160+
# avoid issues with this method being called twice
161+
for transform in self.transforms:
162+
transform._dynamic_tied_weights_keys = list()
163+
164+
# map from data_ptrs to keys
165+
ptr_to_keys: dict[int, List[Tuple[TransformBase, str]]] = defaultdict(list)
166+
for transform in self.transforms:
167+
for name, param in transform.named_parameters(recurse=False):
168+
ptr_to_keys[param.data_ptr()].append((transform, name))
169+
170+
# populate `_dynamic_tied_weights_keys` if there is more than one key
171+
for shared_keys in ptr_to_keys.values():
172+
if len(shared_keys) > 1:
173+
for transform, name in shared_keys:
174+
transform._dynamic_tied_weights_keys.append(name)
175+
148176

149177
class TransformBase(Module, ABC):
150178
"""
@@ -153,6 +181,11 @@ class TransformBase(Module, ABC):
153181

154182
args: TransformArgs
155183
weight: Parameter
184+
_dynamic_tied_weights_keys: List[str]
185+
186+
def __init__(self):
187+
super().__init__()
188+
self._dynamic_tied_weights_keys = list()
156189

157190
@abstractmethod
158191
def forward(self, value: Tensor) -> Tensor:

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,10 @@ def create_transform(self, module: Module, args: TransformArgs):
6060
factory_kwargs = {"construct_device": exec_device}
6161
weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
6262
perm = self.perms[weight] if self.scheme.randomize else None
63-
return HadamardTransform(weight, perm, args)
63+
64+
transform = HadamardTransform(weight, perm, args)
65+
self.transforms.append(transform)
66+
return transform
6467

6568
def _create_weight(
6669
self,

src/compressed_tensors/transform/factory/matrix_multiply.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,9 @@ def create_transform(self, module: Module, args: TransformArgs):
5959
if args.inverse:
6060
weight = self.inverses[weight]
6161

62-
return RandomMatrixTransform(weight, args)
62+
transform = RandomMatrixTransform(weight, args)
63+
self.transforms.append(transform)
64+
return transform
6365

6466
def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
6567
# TODO: verify that weight is invertible (has non-zero determinant)

0 commit comments

Comments
 (0)