@@ -189,19 +189,21 @@ def decompress(
189
189
path_to_model_or_tensors , names_to_scheme
190
190
)
191
191
192
- def decompress_from_state_dict (
192
+ def _decompress_from_path (
193
193
self ,
194
- state_dict : Dict [str , torch . Tensor ],
194
+ path_to_model : Union [ str , Path , Dict [str , Any ] ],
195
195
names_to_scheme : Dict [str , QuantizationScheme ],
196
- ) -> Generator [Tuple [str , Dict [str , torch .Tensor ]], None , None ]:
197
- weight_mappings = get_nested_mappings_from_state_dict (
198
- state_dict , self .compression_param_names
196
+ device : str ,
197
+ ):
198
+ weight_mappings = get_nested_weight_mappings (
199
+ path_to_model , self .compression_param_names
199
200
)
200
201
for module_path in weight_mappings .keys ():
201
202
weight_data = {}
202
- for param_name , param_value in weight_mappings [module_path ].items ():
203
- weight_data [param_name ] = param_value
204
-
203
+ for param_name , safe_path in weight_mappings [module_path ].items ():
204
+ full_name = merge_names (module_path , param_name )
205
+ with safe_open (safe_path , framework = "pt" , device = device ) as f :
206
+ weight_data [param_name ] = f .get_tensor (full_name )
205
207
if "weight_scale" in weight_data :
206
208
quant_args = names_to_scheme [module_path ].weights
207
209
decompressed = self .decompress_weight (
@@ -210,21 +212,19 @@ def decompress_from_state_dict(
210
212
weight_data ["weight" ] = decompressed
211
213
yield module_path , weight_data
212
214
213
- def _decompress_from_path (
215
+ def decompress_from_state_dict (
214
216
self ,
215
- path_to_model : Union [ str , Path , Dict [str , Any ] ],
217
+ state_dict : Dict [str , torch . Tensor ],
216
218
names_to_scheme : Dict [str , QuantizationScheme ],
217
- device : str ,
218
- ):
219
- weight_mappings = get_nested_weight_mappings (
220
- path_to_model , self .compression_param_names
219
+ ) -> Generator [Tuple [str , Dict [str , torch .Tensor ]], None , None ]:
220
+ weight_mappings = get_nested_mappings_from_state_dict (
221
+ state_dict , self .compression_param_names
221
222
)
222
223
for module_path in weight_mappings .keys ():
223
224
weight_data = {}
224
- for param_name , safe_path in weight_mappings [module_path ].items ():
225
- full_name = merge_names (module_path , param_name )
226
- with safe_open (safe_path , framework = "pt" , device = device ) as f :
227
- weight_data [param_name ] = f .get_tensor (full_name )
225
+ for param_name , param_value in weight_mappings [module_path ].items ():
226
+ weight_data [param_name ] = param_value
227
+
228
228
if "weight_scale" in weight_data :
229
229
quant_args = names_to_scheme [module_path ].weights
230
230
decompressed = self .decompress_weight (
0 commit comments