24
24
get_nested_weight_mappings ,
25
25
merge_names ,
26
26
)
27
+ from compressed_tensors .utils .safetensors_load import match_param_name
27
28
from safetensors import safe_open
28
29
from torch import Tensor
29
30
from tqdm import tqdm
@@ -223,9 +224,7 @@ def decompress_from_state_dict(
223
224
state_dict , self .compression_param_names
224
225
)
225
226
for module_path in weight_mappings .keys ():
226
- weight_data = {}
227
- for param_name , param_value in weight_mappings [module_path ].items ():
228
- weight_data [param_name ] = param_value
227
+ weight_data = weight_mappings [module_path ].copy ()
229
228
230
229
if "weight_scale" in weight_data :
231
230
quant_args = names_to_scheme [module_path ].weights
@@ -234,3 +233,31 @@ def decompress_from_state_dict(
234
233
)
235
234
weight_data ["weight" ] = decompressed
236
235
yield module_path , weight_data
236
+
237
def decompress_module_from_state_dict(
    self,
    prefix: str,
    state_dict: Dict[str, torch.Tensor],
    scheme: QuantizationScheme,
) -> Dict[str, torch.Tensor]:
    """
    Decompress the parameters belonging to a single module; used only by
    in-memory decompression pathways

    :param prefix: prefix of state_dict, typically the path to the module
    :param state_dict: state dict containing module parameter values
    :param scheme: quantization scheme of module to decompress
    :return: state dict with weight decompressed if applicable
    """
    dot_prefix = f"{prefix}."

    # Strip the module path so parameter names are local to this module
    local_params = {
        name.removeprefix(dot_prefix): tensor
        for name, tensor in state_dict.items()
    }

    # A compressed module carries a "weight_scale" param; rebuild its weight
    if "weight_scale" in local_params:
        local_params["weight"] = self.decompress_weight(
            compressed_data=local_params, quantization_args=scheme.weights
        )

    # Restore the fully-qualified parameter names before returning
    return {dot_prefix + name: tensor for name, tensor in local_params.items()}
0 commit comments