Skip to content

Commit b009f47

Browse files
committed
ensure serializable
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 5a95fd2 commit b009f47

File tree

7 files changed

+77
-12
lines changed

7 files changed

+77
-12
lines changed

src/compressed_tensors/transform/factory/base.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,17 @@ def _apply_to_module(self, module: Module, args: TransformArgs):
103103
:param module: target module to apply transforms to
104104
:param args: defines how the transform will be applied to the target module
105105
"""
106+
if has_offloaded_params(module):
107+
if module._hf_hook.place_submodules:
108+
raise NotImplementedError(
109+
"Applying transforms to offloaded submodules with "
110+
"`place_submodules=True` is not supported"
111+
)
112+
106113
# create transform as submodule
107114
transform_name = f"{self.name}_{args.location.value}"
108115
transform = self.create_transform(module, args)
116+
self.transforms.append(transform)
109117
register_offload_module(module, transform_name, transform)
110118

111119
# register input transformation hook
@@ -136,8 +144,9 @@ def input_hook(_, args):
136144
raise ValueError("Offloaded training is not supported")
137145
P.register_parametrization(module, "weight", transform)
138146

139-
# transform is no longer needed (unfusing is not supported)
140-
delete_offload_module(module, transform_name)
147+
else:
148+
# transform is no longer needed (unfusing is not supported)
149+
delete_offload_module(module, transform_name)
141150

142151
# register output transformation hook
143152
elif args.location == TransformLocation.OUTPUT:
@@ -165,13 +174,20 @@ def _update_tied_weights(self):
165174
ptr_to_keys: dict[int, List[Tuple[TransformBase, str]]] = defaultdict(list)
166175
for transform in self.transforms:
167176
for name, param in transform.named_parameters(recurse=False):
177+
# NOTE: previously asserted that parent._hf_hook.place_submodules=False
178+
if has_offloaded_params(transform):
179+
param = transform._hf_hook.weights_map[name]
168180
ptr_to_keys[param.data_ptr()].append((transform, name))
169181

170182
# populate `_dynamic_tied_weights_keys` if there is more than one key
183+
# and ensure that they share tensors
171184
for shared_keys in ptr_to_keys.values():
172185
if len(shared_keys) > 1:
186+
tensor = getattr(shared_keys[0][0], shared_keys[0][1])
187+
173188
for transform, name in shared_keys:
174189
transform._dynamic_tied_weights_keys.append(name)
190+
setattr(transform, name, tensor)
175191

176192

177193
class TransformBase(Module, ABC):

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,7 @@ def create_transform(self, module: Module, args: TransformArgs):
6161
weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
6262
perm = self.perms[weight] if self.scheme.randomize else None
6363

64-
transform = HadamardTransform(weight, perm, args)
65-
self.transforms.append(transform)
66-
return transform
64+
return HadamardTransform(weight, perm, args)
6765

6866
def _create_weight(
6967
self,

src/compressed_tensors/transform/factory/matrix_multiply.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,7 @@ def create_transform(self, module: Module, args: TransformArgs):
5959
if args.inverse:
6060
weight = self.inverses[weight]
6161

62-
transform = RandomMatrixTransform(weight, args)
63-
self.transforms.append(transform)
64-
return transform
62+
return RandomMatrixTransform(weight, args)
6563

6664
def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
6765
# TODO: verify that weight is invertible (has non-zero determinant)
@@ -72,6 +70,7 @@ def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
7270

7371
def _create_inverse(self, weight: Parameter) -> Parameter:
7472
data = high_precision_invert(weight.data)
73+
data = data.contiguous() # ensure proper serialization
7574
return Parameter(data, requires_grad=False)
7675

7776

tests/test_transform/conftest.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@
1414

1515
import pytest
1616
import torch
17-
from compressed_tensors.transform import TransformArgs
17+
from compressed_tensors.transform import TransformArgs, TransformFactory
18+
from transformers import PretrainedConfig, PreTrainedModel
1819

1920

20-
class TransformableModel(torch.nn.Module):
21+
class TransformableModel(PreTrainedModel):
2122
def __init__(self, *sizes):
22-
super().__init__()
23+
super().__init__(config=PretrainedConfig())
2324
self.fcs = torch.nn.ModuleList(
2425
[
2526
torch.nn.Linear(sizes[index], sizes[index + 1], bias=False)

tests/test_transform/factory/test_correctness.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
apply_transform_config,
2323
)
2424
from compressed_tensors.utils import offloaded_dispatch
25+
from tests.test_transform.conftest import scheme_kwargs
2526
from tests.testing_utils import requires_accelerate, requires_gpu
2627

2728

tests/test_transform/factory/test_memory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
TransformArgs,
2121
TransformBase,
2222
TransformConfig,
23-
TransformFactory,
2423
TransformScheme,
24+
TransformFactory,
2525
apply_transform_config,
2626
)
2727
from compressed_tensors.utils import align_modules, offloaded_dispatch
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
import torch
17+
from compressed_tensors.transform import (
18+
TransformConfig,
19+
TransformScheme,
20+
apply_transform_config,
21+
)
22+
from compressed_tensors.utils import offloaded_dispatch
23+
from tests.test_transform.conftest import scheme_kwargs
24+
from tests.testing_utils import requires_accelerate, requires_gpu
25+
26+
27+
@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
28+
def test_serialization(scheme_kwargs, model_apply, tmp_path, offload=False):
29+
# get model, maybe offload
30+
model, apply = model_apply
31+
if offload:
32+
offloaded_dispatch(model, torch.device("cuda"))
33+
34+
# apply transforms to model
35+
config = TransformConfig(
36+
config_groups={"": TransformScheme(**scheme_kwargs, apply=apply)}
37+
)
38+
apply_transform_config(model, config)
39+
40+
# save model
41+
model.save_pretrained(tmp_path)
42+
43+
# TODO: reload model
44+
45+
46+
@requires_gpu
47+
@requires_accelerate()
48+
@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
49+
def test_serialization_offload(scheme_kwargs, model_apply, tmp_path):
50+
test_serialization(scheme_kwargs, model_apply, tmp_path, offload=True)

0 commit comments

Comments
 (0)