File tree 1 file changed +8
-2
lines changed
src/sparseml/pytorch/sparsification/quantization
1 file changed +8
-2
lines changed Original file line number Diff line number Diff line change @@ -361,10 +361,13 @@ def _match_submodule_name_or_type(
361
361
# 2. match the submodule prefix (longest first)
362
362
submodule_match = ""
363
363
for name_or_type in names_or_types :
364
+ name_to_compare = submodule_name [:]
365
+ if name_to_compare .startswith ("module." ):
366
+ name_to_compare = name_to_compare [7 :]
364
367
if name_or_type == submodule .__class__ .__name__ :
365
368
# type match, return type name
366
369
return name_or_type
367
- if submodule_name .startswith (name_or_type ) and (
370
+ if name_to_compare .startswith (name_or_type ) and (
368
371
len (name_or_type ) > len (submodule_match )
369
372
):
370
373
# match to most specific submodule name
@@ -422,7 +425,10 @@ def _get_unmatched_types_or_names(types_or_names):
422
425
for type_or_name in types_or_names :
423
426
matched = False
424
427
for submodule_name , submodule in model .named_modules ():
425
- if submodule_name .startswith (type_or_name ) or (
428
+ name_to_compare = submodule_name [:]
429
+ if name_to_compare .startswith ("module." ):
430
+ name_to_compare = name_to_compare [7 :]
431
+ if name_to_compare .startswith (type_or_name ) or (
426
432
submodule .__class__ .__name__ == type_or_name
427
433
):
428
434
matched = True
You can’t perform that action at this time.
0 commit comments