Skip to content

Commit d533a87

Browse files
authored
Arm backend: Add sign decomposition pass and test (#12159)
Decomposes sign into other operators. Signed-off-by: Teo Bergkvist <teo.bergkvist@arm.com>
1 parent 5f70823 commit d533a87

File tree

6 files changed

+165
-0
lines changed

6 files changed

+165
-0
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from .decompose_ne_pass import DecomposeNotEqualPass # noqa
4242
from .decompose_round_pass import DecomposeRoundPass # noqa
4343
from .decompose_select import DecomposeSelectPass # noqa
44+
from .decompose_sign_pass import DecomposeSignPass # noqa
4445
from .decompose_silu_pass import DecomposeSiluPass # noqa
4546
from .decompose_sinh_pass import DecomposeSinhPass # noqa
4647
from .decompose_softmax_pass import DecomposeSoftmaxPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
DecomposeNotEqualPass,
4545
DecomposeRoundPass,
4646
DecomposeSelectPass,
47+
DecomposeSignPass,
4748
DecomposeSiluPass,
4849
DecomposeSinhPass,
4950
DecomposeSoftmaxPass,
@@ -158,6 +159,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
158159
self.add_pass(ConvertIntPowToMuls())
159160
self.add_pass(CastBoolToInt8Pass())
160161
self.add_pass(DecomposeSinhPass())
162+
self.add_pass(DecomposeSignPass())
161163
self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
162164
self.add_pass(DecomposeEmbeddingPass())
163165
self.add_pass(FuseQuantizedActivationPass())
@@ -242,6 +244,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
242244
self.add_pass(DecomposeScaledDotProductAttention())
243245
self.add_pass(DecomposeRoundPass())
244246
self.add_pass(CastBoolToInt8Pass())
247+
self.add_pass(DecomposeSignPass())
245248
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
246249
self.add_pass(ScalarsToAttributePass())
247250
self.add_pass(DecomposeGroupNormPass())
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
8+
from executorch.backends.arm._passes import ArmPass
9+
from executorch.exir.dialects._ops import ops as exir_ops
10+
11+
12+
# For MI case
13+
edge_sign = exir_ops.edge.aten.sign.default
14+
# For BI case
15+
aten_sign = torch.ops.aten.sign.default
16+
17+
18+
def get_ops(op):
19+
"""Returns the appropriate operator functions based on the input operator."""
20+
if op == edge_sign:
21+
return (
22+
exir_ops.edge.aten.gt.Scalar,
23+
exir_ops.edge.aten.lt.Scalar,
24+
exir_ops.edge.aten.where.self,
25+
exir_ops.edge.aten.neg.default,
26+
exir_ops.edge.aten.mul.Scalar,
27+
exir_ops.edge.aten.add.Scalar,
28+
)
29+
elif op == aten_sign:
30+
return (
31+
torch.ops.aten.gt.Scalar,
32+
torch.ops.aten.lt.Scalar,
33+
torch.ops.aten.where.self,
34+
torch.ops.aten.neg.default,
35+
torch.ops.aten.mul.Scalar,
36+
torch.ops.aten.add.Scalar,
37+
)
38+
else:
39+
raise ValueError(f"Unsupported operator: {op}")
40+
41+
42+
class DecomposeSignPass(ArmPass):
43+
"""Decomposes the sign operator into a sequence of operations that are supported by the Arm backend."""
44+
45+
def call_operator(self, op, args, kwargs, meta):
46+
if op not in (edge_sign, aten_sign):
47+
return super().call_operator(op, args, kwargs, meta)
48+
49+
gt_op, lt_op, where_op, neg_op, mul_op, add_op = get_ops(op)
50+
51+
x = args[0]
52+
53+
gt_mask = super().call_operator(gt_op, (x, 0.0), {}, meta, updated=True)
54+
lt_mask = super().call_operator(lt_op, (x, 0.0), {}, meta, updated=True)
55+
56+
zeros = super().call_operator(mul_op, (x, 0.0), {}, meta, updated=True)
57+
ones = super().call_operator(add_op, (zeros, 1.0), {}, meta, updated=True)
58+
neg_ones = super().call_operator(neg_op, (ones,), {}, meta, updated=True)
59+
60+
negative_tensor = super().call_operator(
61+
where_op, (lt_mask, neg_ones, zeros), {}, meta, updated=True
62+
)
63+
positive_tensor = super().call_operator(
64+
where_op, (gt_mask, ones, zeros), {}, meta, updated=True
65+
)
66+
67+
return super().call_operator(
68+
where_op,
69+
(lt_mask, negative_tensor, positive_tensor),
70+
{},
71+
meta,
72+
updated=True,
73+
)

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ def is_node_supported(
249249
exir_ops.edge.aten.sinh.default,
250250
exir_ops.edge.aten.atan.default,
251251
exir_ops.edge.aten.acosh.default,
252+
exir_ops.edge.aten.sign.default,
252253
]
253254

254255
return supported

backends/arm/quantizer/quantization_annotator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ def _match_pattern(
216216
torch.ops.aten.sinh.default,
217217
torch.ops.aten.atan.default,
218218
torch.ops.aten.acosh.default,
219+
torch.ops.aten.sign.default,
219220
]
220221

221222
_one_to_one_shared_input_qspec = [

backends/arm/test/ops/test_sign.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import pytest
9+
import torch
10+
from executorch.backends.arm.test import common
11+
from executorch.backends.arm.test.tester.test_pipeline import (
12+
EthosU55PipelineBI,
13+
EthosU85PipelineBI,
14+
TosaPipelineBI,
15+
TosaPipelineMI,
16+
)
17+
18+
aten_op = "torch.ops.aten.sign.default"
19+
exir_op = "executorch_exir_dialects_edge__ops_aten__sign_default"
20+
21+
input_t1 = Tuple[torch.Tensor]
22+
23+
test_data_suite = {
24+
"zeros": torch.zeros(3, 5),
25+
"ones": torch.ones(4, 4),
26+
"neg_ones": -torch.ones(4, 4),
27+
"mixed_signs": torch.tensor([[-2.0, -1.0, 0.0, 1.0, 2.0]]),
28+
"positive_ramp": torch.arange(0.1, 1.1, 0.2),
29+
"negative_ramp": torch.arange(-1.0, -0.1, 0.2),
30+
"small_values": torch.tensor([-1e-7, 0.0, 1e-7]),
31+
"rand": torch.rand(10, 10) - 0.5,
32+
"rand_alt_shape": torch.rand(10, 3, 5) - 0.5,
33+
"high_magnitude": torch.tensor([-1e6, -10.0, 0.0, 10.0, 1e6]),
34+
}
35+
36+
37+
class Sign(torch.nn.Module):
38+
def forward(self, x: torch.Tensor):
39+
return torch.sign(x)
40+
41+
42+
@common.parametrize("test_data", test_data_suite)
43+
def test_sign_tosa_MI(test_data: Tuple):
44+
pipeline = TosaPipelineMI[input_t1](
45+
Sign(),
46+
(test_data,),
47+
aten_op=aten_op,
48+
exir_op=exir_op,
49+
)
50+
pipeline.run()
51+
52+
53+
@common.parametrize("test_data", test_data_suite)
54+
def test_sign_tosa_BI(test_data: Tuple):
55+
pipeline = TosaPipelineBI[input_t1](
56+
Sign(),
57+
(test_data,),
58+
aten_op=[],
59+
exir_op=exir_op,
60+
)
61+
pipeline.run()
62+
63+
64+
@common.XfailIfNoCorstone300
65+
@common.parametrize("test_data", test_data_suite)
66+
@pytest.mark.xfail(reason="where.self not supported on U55")
67+
def test_sign_u55_BI(test_data: Tuple):
68+
pipeline = EthosU55PipelineBI[input_t1](
69+
Sign(),
70+
(test_data,),
71+
aten_ops=[],
72+
exir_ops=exir_op,
73+
)
74+
pipeline.run()
75+
76+
77+
@common.XfailIfNoCorstone320
78+
@common.parametrize("test_data", test_data_suite)
79+
def test_sign_u85_BI(test_data: Tuple):
80+
pipeline = EthosU85PipelineBI[input_t1](
81+
Sign(),
82+
(test_data,),
83+
aten_ops=[],
84+
exir_ops=exir_op,
85+
)
86+
pipeline.run()

0 commit comments

Comments
 (0)