@@ -84,11 +84,16 @@ def load_transforms(model: Module, model_name_or_path: str):
84
84
85
85
for name , submodule in iter_named_leaf_modules (model ):
86
86
transform_data = getattr (submodule , "transform_data" , None )
87
+
87
88
if transform_data :
88
- for transform_name , transform_data in transform_data .data .items ():
89
+ for transform_name , transform_values in transform_data .data .items ():
89
90
full_name = f"{ name } .{ transform_name } "
90
91
transform_data = state_dict .get (full_name , None )
91
- update_parameter_data (submodule , transform_data , transform_name )
92
+ transform = transform_values .get ("transform" )
93
+ transform .register_to_module (name = transform_name , module = submodule )
94
+ transform .update_transform (
95
+ module = submodule , data = transform_data , name = transform_name
96
+ )
92
97
93
98
94
99
def load_pretrained_quantization (model : Module , model_name_or_path : str ):
@@ -194,7 +199,9 @@ def process_transforms_config(
194
199
dtype = dtype ,
195
200
** transform_creation_args ,
196
201
)
197
- transform .register_to_module (name = transform_name , module = submodule )
202
+ transform .register_to_module (
203
+ name = transform_name , module = submodule
204
+ )
198
205
199
206
# add relevant transform data to the submodule as well
200
207
data = {
0 commit comments