diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py index e7e298053c6..7c6d52a2a79 100644 --- a/backends/xnnpack/partition/config/generic_node_configs.py +++ b/backends/xnnpack/partition/config/generic_node_configs.py @@ -9,6 +9,7 @@ import logging from typing import cast, List, Optional +import numpy as np import torch from executorch.backends.xnnpack.partition.config.xnnpack_config import ( ConfigPrecisionType, @@ -106,6 +107,17 @@ def __init__(self, **kwargs): def supported_precision_types(self) -> List[ConfigPrecisionType]: return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT] + def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: + if not self.check_common_constraints(node, ep): + return False + # No support for add nodes with alpha != 1 + if "alpha" in node.kwargs and not np.isclose( + node.kwargs["alpha"], 1.0, atol=1e-9, rtol=1e-9 + ): + why(node, reason="Add node doesn't support alpha != 1") + return False + return True + class ReLUConfig(GenericNodePartitionerConfig): target_name = "relu.default" diff --git a/backends/xnnpack/test/ops/test_add.py b/backends/xnnpack/test/ops/test_add.py index 2416879f5ce..4fbb99e6696 100644 --- a/backends/xnnpack/test/ops/test_add.py +++ b/backends/xnnpack/test/ops/test_add.py @@ -240,3 +240,27 @@ def forward(self, x, z): .serialize() .run_method_and_compare_outputs() ) + + class AddWithAlpha(torch.nn.Module): + def forward(self, x, y): + # node with alpha = 1.0 will be partitioned + out1 = torch.add(x, y, alpha=1) + # node with alpha != 1.0 will not be partitioned + out2 = torch.add(x, y, alpha=2) + return out1, out2 + + def test_add_with_alpha(self): + inputs = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 4)) + ( + Tester(self.AddWithAlpha(), inputs) + .export() + .check_count({"torch.ops.aten.add.Tensor": 2}) + .to_edge_transform_and_lower() + # unpartitioned node + .check_count({"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1}) + # partitioned node + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + )