@@ -195,33 +195,33 @@ def _decompress_from_path(
195195 weight_mappings = get_nested_weight_mappings (
196196 path_to_model , self .compression_param_names
197197 )
198- for weight_name in weight_mappings .keys ():
198+ for module_path in weight_mappings .keys ():
199199 weight_data = {}
200- for param_name , safe_path in weight_mappings [weight_name ].items ():
201- full_name = merge_names (weight_name , param_name )
200+ for param_name , safe_path in weight_mappings [module_path ].items ():
201+ full_name = merge_names (module_path , param_name )
202202 with safe_open (safe_path , framework = "pt" , device = device ) as f :
203203 weight_data [param_name ] = f .get_tensor (full_name )
204204 if "weight_scale" in weight_data :
205- quant_args = names_to_scheme [weight_name ].weights
205+ quant_args = names_to_scheme [module_path ].weights
206206 decompressed = self .decompress_weight (
207207 compressed_data = weight_data , quantization_args = quant_args
208208 )
209209 weight_data ["weight" ] = decompressed
210- yield weight_name , weight_data
210+ yield module_path , weight_data
211211
212212 def _decompress_from_state_dict (self , state_dict , names_to_scheme ):
213213 weight_mappings = get_nested_mappings_from_state_dict (
214214 state_dict , self .compression_param_names
215215 )
216- for weight_name in weight_mappings .keys ():
216+ for module_path in weight_mappings .keys ():
217217 weight_data = {}
218- for param_name , param_value in weight_mappings [weight_name ].items ():
218+ for param_name , param_value in weight_mappings [module_path ].items ():
219219 weight_data [param_name ] = param_value
220220
221221 if "weight_scale" in weight_data :
222- quant_args = names_to_scheme [weight_name ]
222+ quant_args = names_to_scheme [module_path ]
223223 decompressed = self .decompress_weight (
224224 compressed_data = weight_data , quantization_args = quant_args
225225 )
226226 weight_data ["weight" ] = decompressed
227- yield weight_name , weight_data
227+ yield module_path , weight_data
0 commit comments