From 4ef6668f116474d694dda14c1ab695d8268a2096 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Thu, 10 Jul 2025 11:33:10 -0700 Subject: [PATCH 1/2] Allow pattern replacement to ignore literals Summary: This is necessary because sometimes the patterns found have literals include tuple of ints kind of literals. This values shouldnt be used for pattern matching since often they are based on consts derived from example inputs. THis is not exactly a safe thing to do in general so by default it is turned off Test Plan: Subsequent diff adds a pattern that relies on this Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- .../quantization/pt2e/reference_representation_rewrite.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchao/quantization/pt2e/reference_representation_rewrite.py b/torchao/quantization/pt2e/reference_representation_rewrite.py index 6526c6044f..cf9738ddc7 100644 --- a/torchao/quantization/pt2e/reference_representation_rewrite.py +++ b/torchao/quantization/pt2e/reference_representation_rewrite.py @@ -14,7 +14,7 @@ from torch._higher_order_ops.out_dtype import out_dtype from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 from torch.fx import GraphModule -from torch.fx.subgraph_rewriter import replace_pattern +from torch.fx.subgraph_rewriter import replace_pattern_with_filters from torchao.quantization.pt2e.export_utils import WrapperModule from torchao.quantization.pt2e.utils import ( @@ -627,6 +627,7 @@ class _RewriteInfo: # post transformation on the exported pattern and replacement GraphModule pattern_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None + ignore_literals: bool = False def reference_representation_rewrite(model: GraphModule) -> GraphModule: @@ -830,6 +831,6 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule: replacement = replacement_post_trans(replacement) pattern.recompile() # type: ignore[attr-defined] replacement.recompile() # type: ignore[attr-defined] - replace_pattern(model, pattern, replacement) + replace_pattern_with_filters(model, pattern, replacement, match_filters=None, ignore_literals=rewrite_info.ignore_literals) # type: ignore[arg-type] return model From fd04d3b54570096a1dd2748e6dc2059d25a15269 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Fri, 11 Jul 2025 07:55:47 -0700 Subject: [PATCH 2/2] Update on "Allow pattern replacement to ignore literals" Summary: This is necessary because sometimes the patterns found have literals include tuple of ints kind of literals. This values shouldnt be used for pattern matching since often they are based on consts derived from example inputs. THis is not exactly a safe thing to do in general so by default it is turned off Test Plan: Subsequent diff adds a pattern that relies on this Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- .../quantization/pt2e/reference_representation_rewrite.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/torchao/quantization/pt2e/reference_representation_rewrite.py b/torchao/quantization/pt2e/reference_representation_rewrite.py index cf9738ddc7..8d1875bfd9 100644 --- a/torchao/quantization/pt2e/reference_representation_rewrite.py +++ b/torchao/quantization/pt2e/reference_representation_rewrite.py @@ -831,6 +831,12 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule: replacement = replacement_post_trans(replacement) pattern.recompile() # type: ignore[attr-defined] replacement.recompile() # type: ignore[attr-defined] - replace_pattern_with_filters(model, pattern, replacement, match_filters=None, ignore_literals=rewrite_info.ignore_literals) # type: ignore[arg-type] + replace_pattern_with_filters( + model, + pattern, + replacement, + match_filters=None, + ignore_literals=rewrite_info.ignore_literals, + ) # type: ignore[arg-type] return model