|
14 | 14 | from torch._higher_order_ops.out_dtype import out_dtype
|
15 | 15 | from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
|
16 | 16 | from torch.fx import GraphModule
|
17 |
| -from torch.fx.subgraph_rewriter import replace_pattern |
| 17 | +from torch.fx.subgraph_rewriter import replace_pattern_with_filters |
18 | 18 |
|
19 | 19 | from torchao.quantization.pt2e.export_utils import WrapperModule
|
20 | 20 | from torchao.quantization.pt2e.utils import (
|
@@ -627,6 +627,7 @@ class _RewriteInfo:
|
627 | 627 | # post transformation on the exported pattern and replacement GraphModule
|
628 | 628 | pattern_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None
|
629 | 629 | replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None
|
| 630 | + ignore_literals: bool = False |
630 | 631 |
|
631 | 632 |
|
632 | 633 | def reference_representation_rewrite(model: GraphModule) -> GraphModule:
|
@@ -830,6 +831,12 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
|
830 | 831 | replacement = replacement_post_trans(replacement)
|
831 | 832 | pattern.recompile() # type: ignore[attr-defined]
|
832 | 833 | replacement.recompile() # type: ignore[attr-defined]
|
833 |
| - replace_pattern(model, pattern, replacement) |
| 834 | + replace_pattern_with_filters( |
| 835 | + model, |
| 836 | + pattern, |
| 837 | + replacement, |
| 838 | + match_filters=None, |
| 839 | + ignore_literals=rewrite_info.ignore_literals, |
| 840 | + ) # type: ignore[arg-type] |
834 | 841 |
|
835 | 842 | return model
|
0 commit comments