Skip to content

Commit 58cb352

Browse files
committed
Allow pattern replacement to ignore literals
Summary: This is necessary because sometimes the patterns found have literals include tuple of ints kind of literals. This values shouldnt be used for pattern matching since often they are based on consts derived from example inputs. THis is not exactly a safe thing to do in general so by default it is turned off Test Plan: Subsequent diff adds a pattern that relies on this Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 05225ec Pull Request resolved: #2519
1 parent b281af7 commit 58cb352

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

torchao/quantization/pt2e/reference_representation_rewrite.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from torch._higher_order_ops.out_dtype import out_dtype
1515
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
1616
from torch.fx import GraphModule
17-
from torch.fx.subgraph_rewriter import replace_pattern
17+
from torch.fx.subgraph_rewriter import replace_pattern_with_filters
1818

1919
from torchao.quantization.pt2e.export_utils import WrapperModule
2020
from torchao.quantization.pt2e.utils import (
@@ -627,6 +627,7 @@ class _RewriteInfo:
627627
# post transformation on the exported pattern and replacement GraphModule
628628
pattern_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None
629629
replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None
630+
ignore_literals: bool = False
630631

631632

632633
def reference_representation_rewrite(model: GraphModule) -> GraphModule:
@@ -830,6 +831,12 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
830831
replacement = replacement_post_trans(replacement)
831832
pattern.recompile() # type: ignore[attr-defined]
832833
replacement.recompile() # type: ignore[attr-defined]
833-
replace_pattern(model, pattern, replacement)
834+
replace_pattern_with_filters(
835+
model,
836+
pattern,
837+
replacement,
838+
match_filters=None,
839+
ignore_literals=rewrite_info.ignore_literals,
840+
) # type: ignore[arg-type]
834841

835842
return model

0 commit comments

Comments
 (0)