diff --git a/torchao/quantization/pt2e/reference_representation_rewrite.py b/torchao/quantization/pt2e/reference_representation_rewrite.py index 6526c6044f..8d1875bfd9 100644 --- a/torchao/quantization/pt2e/reference_representation_rewrite.py +++ b/torchao/quantization/pt2e/reference_representation_rewrite.py @@ -14,7 +14,7 @@ from torch._higher_order_ops.out_dtype import out_dtype from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 from torch.fx import GraphModule -from torch.fx.subgraph_rewriter import replace_pattern +from torch.fx.subgraph_rewriter import replace_pattern_with_filters from torchao.quantization.pt2e.export_utils import WrapperModule from torchao.quantization.pt2e.utils import ( @@ -627,6 +627,7 @@ class _RewriteInfo: # post transformation on the exported pattern and replacement GraphModule pattern_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None + ignore_literals: bool = False def reference_representation_rewrite(model: GraphModule) -> GraphModule: @@ -830,6 +831,12 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule: replacement = replacement_post_trans(replacement) pattern.recompile() # type: ignore[attr-defined] replacement.recompile() # type: ignore[attr-defined] - replace_pattern(model, pattern, replacement) + replace_pattern_with_filters( + model, + pattern, + replacement, + match_filters=None, + ignore_literals=rewrite_info.ignore_literals, + ) # type: ignore[arg-type] return model