diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py
index e7e298053c6..5875e35cdd7 100644
--- a/backends/xnnpack/partition/config/generic_node_configs.py
+++ b/backends/xnnpack/partition/config/generic_node_configs.py
@@ -9,6 +9,7 @@
 import logging
 from typing import cast, List, Optional
 
+import numpy as np
 import torch
 from executorch.backends.xnnpack.partition.config.xnnpack_config import (
     ConfigPrecisionType,
@@ -523,6 +524,17 @@ class SubConfig(GenericNodePartitionerConfig):
     def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT]
 
+    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
+        if not self.check_common_constraints(node, ep):
+            return False
+        # No support for sub nodes with alpha != 1
+        if "alpha" in node.kwargs and not np.isclose(
+            node.kwargs["alpha"], 1.0, atol=1e-9, rtol=1e-9
+        ):
+            why(node, reason="Sub node doesn't support alpha != 1")
+            return False
+        return True
+
 
 class BMMConfig(GenericNodePartitionerConfig):
     """
diff --git a/backends/xnnpack/test/ops/test_sub.py b/backends/xnnpack/test/ops/test_sub.py
index 06219730ddb..de958e2186e 100644
--- a/backends/xnnpack/test/ops/test_sub.py
+++ b/backends/xnnpack/test/ops/test_sub.py
@@ -152,3 +152,27 @@ def forward(self, x, y):
             .serialize()
             .run_method_and_compare_outputs()
         )
+
+    class SubWithAlpha(torch.nn.Module):
+        def forward(self, x, y):
+            # node with alpha = 1.0 will be partitioned
+            out1 = torch.sub(x, y, alpha=1)
+            # node with alpha != 1.0 will not be partitioned
+            out2 = torch.sub(x, y, alpha=2)
+            return out1, out2
+
+    def test_sub_with_alpha(self):
+        inputs = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 4))
+        (
+            Tester(self.SubWithAlpha(), inputs)
+            .export()
+            .check_count({"torch.ops.aten.sub.Tensor": 2})
+            .to_edge_transform_and_lower()
+            # unpartitioned node
+            .check_count({"executorch_exir_dialects_edge__ops_aten_sub_Tensor": 1})
+            # partitioned node
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
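
Note on the semantics behind the constraint (illustrative sketch, not part of the diff): torch.sub(x, y, alpha=a) computes x - a * y, so alpha == 1 is the only case that reduces to the plain elementwise subtraction referenced by the "No support for sub nodes with alpha != 1" comment above, which is why the partitioner keeps alpha != 1 nodes out of the delegate. A minimal standalone check of those semantics:

    import torch

    x = torch.randn(1, 1, 4, 4)
    y = torch.randn(1, 1, 4, 4)

    # alpha == 1: plain elementwise subtraction, eligible for partitioning
    assert torch.allclose(torch.sub(x, y, alpha=1), x - y)
    # alpha != 1: the subtrahend is scaled first, so the node is left unpartitioned
    assert torch.allclose(torch.sub(x, y, alpha=2), x - 2 * y)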