@@ -40,11 +40,10 @@ def apply_transforms_to_parameter(
4040 """
4141
4242 for transform_name , transform_values in transform_data .data .items ():
43- transform = getattr (module , transform_name )
44- apply = Transforms .fetch_apply (transform_values .get ("type" ))
43+ transform = transform_values .get ("transform" )
4544 call_args = transform_values .get ("call_args" )
46- transformed_param_data = apply (
47- input_tensor = module_parameter , transform = transform , ** call_args
45+ transformed_param_data = transform . apply (
46+ input_tensor = module_parameter , ** call_args
4847 )
4948 module_parameter .data .copy_ (transformed_param_data )
5049
@@ -67,10 +66,9 @@ def apply_inverse_transforms_to_parameter(
6766 """
6867
6968 for transform_name , transform_values in reversed (transform_data .data .items ()):
70- transform = getattr (module , transform_name )
71- inverse_apply = Transforms .fetch_inverse_apply (transform_values .get ("type" ))
69+ transform = transform_values .get ("transform" )
7270 call_args = transform_values .get ("call_args" )
73- transformed_param_data = inverse_apply (
74- input_tensor = module_parameter , transform = transform , ** call_args
71+ transformed_param_data = transform . inverse_apply (
72+ input_tensor = module_parameter , ** call_args
7573 )
7674 module_parameter .data .copy_ (transformed_param_data )
0 commit comments