@@ -234,11 +234,11 @@ def get_nested_weight_mappings(
234
234
for key , file_location in weight_mappings .items ():
235
235
matched = False
236
236
for param_name in params_to_nest :
237
- dense_param = match_param_name (key , param_name )
238
- if dense_param :
239
- if dense_param not in nested_weight_mappings :
240
- nested_weight_mappings [dense_param ] = {}
241
- nested_weight_mappings [dense_param ][param_name ] = file_location
237
+ module_path = match_param_name (key , param_name )
238
+ if module_path :
239
+ if module_path not in nested_weight_mappings :
240
+ nested_weight_mappings [module_path ] = {}
241
+ nested_weight_mappings [module_path ][param_name ] = file_location
242
242
matched = True
243
243
if return_unmatched_params and not matched :
244
244
unmatched_params [key ] = file_location
@@ -271,11 +271,11 @@ def get_nested_mappings_from_state_dict(
271
271
nested_weight_mappings = {}
272
272
for key in state_dict .keys ():
273
273
for param_name in params_to_nest :
274
- dense_param = match_param_name (key , param_name )
275
- if dense_param :
276
- if dense_param not in nested_weight_mappings :
277
- nested_weight_mappings [dense_param ] = {}
278
- nested_weight_mappings [dense_param ][param_name ] = state_dict [key ]
274
+ module_path = match_param_name (key , param_name )
275
+ if module_path :
276
+ if module_path not in nested_weight_mappings :
277
+ nested_weight_mappings [module_path ] = {}
278
+ nested_weight_mappings [module_path ][param_name ] = state_dict [key ]
279
279
return nested_weight_mappings
280
280
281
281
0 commit comments