Skip to content

Commit 22e63cd

Browse files
authored
Fix for quantization modifier w/ DDP (#1594)
* Removes "module." from submodule names inserted by DDP * Renamed variables to make them consistent
1 parent 7abe53e commit 22e63cd

File tree

1 file changed

+8
-2
lines changed
  • src/sparseml/pytorch/sparsification/quantization

1 file changed

+8
-2
lines changed

src/sparseml/pytorch/sparsification/quantization/quantize.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -361,10 +361,13 @@ def _match_submodule_name_or_type(
361361
# 2. match the submodule prefix (longest first)
362362
submodule_match = ""
363363
for name_or_type in names_or_types:
364+
name_to_compare = submodule_name[:]
365+
if name_to_compare.startswith("module."):
366+
name_to_compare = name_to_compare[7:]
364367
if name_or_type == submodule.__class__.__name__:
365368
# type match, return type name
366369
return name_or_type
367-
if submodule_name.startswith(name_or_type) and (
370+
if name_to_compare.startswith(name_or_type) and (
368371
len(name_or_type) > len(submodule_match)
369372
):
370373
# match to most specific submodule name
@@ -422,7 +425,10 @@ def _get_unmatched_types_or_names(types_or_names):
422425
for type_or_name in types_or_names:
423426
matched = False
424427
for submodule_name, submodule in model.named_modules():
425-
if submodule_name.startswith(type_or_name) or (
428+
name_to_compare = submodule_name[:]
429+
if name_to_compare.startswith("module."):
430+
name_to_compare = name_to_compare[7:]
431+
if name_to_compare.startswith(type_or_name) or (
426432
submodule.__class__.__name__ == type_or_name
427433
):
428434
matched = True

0 commit comments

Comments
 (0)