Skip to content

Commit 579337d

Browse files
committed
move loading step to transformers
1 parent 05c4735 commit 579337d

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

-3
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
QuantizationStatus,
4040
apply_quantization_config,
4141
load_pretrained_quantization,
42-
load_transforms,
4342
)
4443
from compressed_tensors.quantization.lifecycle import expand_target_names
4544
from compressed_tensors.quantization.quant_args import QuantizationArgs
@@ -472,8 +471,6 @@ def decompress(self, model_path: str, model: Module):
472471
)
473472
load_pretrained_quantization(model, model_path)
474473

475-
load_transforms(model, model_path)
476-
477474
model_path_or_state_dict = (
478475
model.state_dict() if sparse_decompressed else model_path
479476
)

src/compressed_tensors/quantization/lifecycle/apply.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,16 @@ def load_transforms(model: Module, model_name_or_path: str):
8484

8585
for name, submodule in iter_named_leaf_modules(model):
8686
transform_data = getattr(submodule, "transform_data", None)
87+
8788
if transform_data:
88-
for transform_name, transform_data in transform_data.data.items():
89+
for transform_name, transform_values in transform_data.data.items():
8990
full_name = f"{name}.{transform_name}"
9091
transform_data = state_dict.get(full_name, None)
91-
update_parameter_data(submodule, transform_data, transform_name)
92+
transform = transform_values.get("transform")
93+
transform.register_to_module(name=transform_name, module=submodule)
94+
transform.update_transform(
95+
module=submodule, data=transform_data, name=transform_name
96+
)
9297

9398

9499
def load_pretrained_quantization(model: Module, model_name_or_path: str):
@@ -194,7 +199,9 @@ def process_transforms_config(
194199
dtype=dtype,
195200
**transform_creation_args,
196201
)
197-
transform.register_to_module(name=transform_name, module=submodule)
202+
transform.register_to_module(
203+
name=transform_name, module=submodule
204+
)
198205

199206
# add relevant transform data to the submodule as well
200207
data = {

0 commit comments

Comments
 (0)