@@ -40,11 +40,10 @@ def apply_transforms_to_parameter(
40
40
"""
41
41
42
42
for transform_name , transform_values in transform_data .data .items ():
43
- transform = getattr (module , transform_name )
44
- apply = Transforms .fetch_apply (transform_values .get ("type" ))
43
+ transform = transform_values .get ("transform" )
45
44
call_args = transform_values .get ("call_args" )
46
- transformed_param_data = apply (
47
- input_tensor = module_parameter , transform = transform , ** call_args
45
+ transformed_param_data = transform . apply (
46
+ input_tensor = module_parameter , ** call_args
48
47
)
49
48
module_parameter .data .copy_ (transformed_param_data )
50
49
@@ -67,10 +66,9 @@ def apply_inverse_transforms_to_parameter(
67
66
"""
68
67
69
68
for transform_name , transform_values in reversed (transform_data .data .items ()):
70
- transform = getattr (module , transform_name )
71
- inverse_apply = Transforms .fetch_inverse_apply (transform_values .get ("type" ))
69
+ transform = transform_values .get ("transform" )
72
70
call_args = transform_values .get ("call_args" )
73
- transformed_param_data = inverse_apply (
74
- input_tensor = module_parameter , transform = transform , ** call_args
71
+ transformed_param_data = transform . inverse_apply (
72
+ input_tensor = module_parameter , ** call_args
75
73
)
76
74
module_parameter .data .copy_ (transformed_param_data )
0 commit comments