Skip to content

Commit 05c4735

Browse files
committed
update
1 parent adda166 commit 05c4735

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,12 +194,12 @@ def process_transforms_config(
194194
dtype=dtype,
195195
**transform_creation_args,
196196
)
197-
setattr(submodule, transform_name, transform)
197+
transform.register_to_module(name=transform_name, module=submodule)
198198

199199
# add relevant transform data to the submodule as well
200200
data = {
201201
transform_name: {
202-
"type": transform_type,
202+
"transform": transform,
203203
"call_args": transform_arg.call_args,
204204
}
205205
}

src/compressed_tensors/transforms/apply.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,10 @@ def apply_transforms_to_parameter(
4040
"""
4141

4242
for transform_name, transform_values in transform_data.data.items():
43-
transform = getattr(module, transform_name)
44-
apply = Transforms.fetch_apply(transform_values.get("type"))
43+
transform = transform_values.get("transform")
4544
call_args = transform_values.get("call_args")
46-
transformed_param_data = apply(
47-
input_tensor=module_parameter, transform=transform, **call_args
45+
transformed_param_data = transform.apply(
46+
input_tensor=module_parameter, **call_args
4847
)
4948
module_parameter.data.copy_(transformed_param_data)
5049

@@ -67,10 +66,9 @@ def apply_inverse_transforms_to_parameter(
6766
"""
6867

6968
for transform_name, transform_values in reversed(transform_data.data.items()):
70-
transform = getattr(module, transform_name)
71-
inverse_apply = Transforms.fetch_inverse_apply(transform_values.get("type"))
69+
transform = transform_values.get("transform")
7270
call_args = transform_values.get("call_args")
73-
transformed_param_data = inverse_apply(
74-
input_tensor=module_parameter, transform=transform, **call_args
71+
transformed_param_data = transform.inverse_apply(
72+
input_tensor=module_parameter, **call_args
7573
)
7674
module_parameter.data.copy_(transformed_param_data)

0 commit comments

Comments
 (0)