26
26
)
27
27
from compressed_tensors .utils import (
28
28
align_module_device ,
29
+ delete_offload_module ,
29
30
has_offloaded_params ,
30
31
patch_attr ,
31
32
register_offload_module ,
@@ -99,10 +100,10 @@ def _apply_to_module(self, module: Module, args: TransformArgs):
99
100
# create transform as submodule
100
101
transform_name = f"{ self .name } _{ args .location .value } "
101
102
transform = self .create_transform (module , args )
103
+ register_offload_module (module , transform_name , transform )
102
104
103
105
# register input transformation hook
104
106
if args .location == TransformLocation .INPUT :
105
- register_offload_module (module , transform_name , transform )
106
107
107
108
def input_hook (_ , args ):
108
109
input = args [0 ]
@@ -118,6 +119,7 @@ def input_hook(_, args):
118
119
assert isinstance (module , torch .nn .Linear )
119
120
assert module .bias is None
120
121
122
+ # fuse transform into weight
121
123
with torch .no_grad (), align_module_device (module ):
122
124
update_offload_parameter (module , "weight" , transform (module .weight ))
123
125
@@ -128,9 +130,11 @@ def input_hook(_, args):
128
130
raise ValueError ("Offloaded training is not supported" )
129
131
P .register_parametrization (module , "weight" , transform )
130
132
133
+ # transform is no longer needed (unfusing is not supported)
134
+ delete_offload_module (module , transform_name )
135
+
131
136
# register output transformation hook
132
137
elif args .location == TransformLocation .OUTPUT :
133
- register_offload_module (module , transform_name , transform )
134
138
135
139
def output_hook (_ , _input , output ):
136
140
return transform (output )
0 commit comments