@@ -99,10 +99,10 @@ def _apply_to_module(self, module: Module, args: TransformArgs):
99
99
# create transform as submodule
100
100
transform_name = f"{ self .name } _{ args .location .value } "
101
101
transform = self .create_transform (module , args )
102
- register_offload_module (module , transform_name , transform ) # (1)
103
102
104
103
# register input transformation hook
105
104
if args .location == TransformLocation .INPUT :
105
+ register_offload_module (module , transform_name , transform )
106
106
107
107
def input_hook (_ , args ):
108
108
input = args [0 ]
@@ -130,6 +130,7 @@ def input_hook(_, args):
130
130
131
131
# register output transformation hook
132
132
elif args .location == TransformLocation .OUTPUT :
133
+ register_offload_module (module , transform_name , transform )
133
134
134
135
def output_hook (_ , _input , output ):
135
136
return transform (output )
@@ -140,9 +141,6 @@ def output_hook(_, _input, output):
140
141
else :
141
142
raise NotImplementedError ()
142
143
143
- # (1) even in the `weight` cases, this submodule attachment is needed in order
144
- # to support saving in the frozen state
145
-
146
144
147
145
class TransformBase (Module , ABC ):
148
146
"""
0 commit comments