@@ -84,11 +84,16 @@ def load_transforms(model: Module, model_name_or_path: str):
8484
8585 for name , submodule in iter_named_leaf_modules (model ):
8686 transform_data = getattr (submodule , "transform_data" , None )
87+
8788 if transform_data :
88- for transform_name , transform_data in transform_data .data .items ():
89+ for transform_name , transform_values in transform_data .data .items ():
8990 full_name = f"{ name } .{ transform_name } "
9091 transform_data = state_dict .get (full_name , None )
91- update_parameter_data (submodule , transform_data , transform_name )
92+ transform = transform_values .get ("transform" )
93+ transform .register_to_module (name = transform_name , module = submodule )
94+ transform .update_transform (
95+ module = submodule , data = transform_data , name = transform_name
96+ )
9297
9398
9499def load_pretrained_quantization (model : Module , model_name_or_path : str ):
@@ -194,7 +199,9 @@ def process_transforms_config(
194199 dtype = dtype ,
195200 ** transform_creation_args ,
196201 )
197- transform .register_to_module (name = transform_name , module = submodule )
202+ transform .register_to_module (
203+ name = transform_name , module = submodule
204+ )
198205
199206 # add relevant transform data to the submodule as well
200207 data = {
0 commit comments