Skip to content

Commit 97047c0

Browse files
authored
Arm backend: Add decomposition pass and test for asin (#12241)
Add decomposition pass and test for asin Signed-off-by: Emma Kujala <emma.kujala@arm.com>
1 parent 487da8c commit 97047c0

File tree

7 files changed

+287
-0
lines changed

7 files changed

+287
-0
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from .convert_to_clamp import ConvertToClampPass # noqa
2525
from .decompose_acosh_pass import DecomposeAcoshPass # noqa
2626
from .decompose_adaptive_avg_pool2d_pass import DecomposeAdaptiveAvgPool2dPass # noqa
27+
from .decompose_asin_pass import DecomposeAsinPass # noqa
2728
from .decompose_atan_pass import DecomposeAtanPass # noqa
2829
from .decompose_avg_pool2d import DecomposeAvgPool2d # noqa
2930
from .decompose_batch_norm_no_stats import DecomposeBatchNormNoStatsPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
ConvertToClampPass,
3030
DecomposeAcoshPass,
3131
DecomposeAdaptiveAvgPool2dPass,
32+
DecomposeAsinPass,
3233
DecomposeAtanPass,
3334
DecomposeAvgPool2d,
3435
DecomposeBatchNormNoStatsPass,
@@ -158,6 +159,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
158159
def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
159160
self.add_pass(DecomposeRoundPass())
160161
self.add_pass(DecomposeAcoshPass())
162+
self.add_pass(DecomposeAsinPass())
161163
self.add_pass(DecomposeSqrtPass())
162164
self.add_pass(DecomposeAtanPass())
163165
self.add_pass(ConvertIntPowToMuls())
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
import logging
9+
from math import pi
10+
11+
import torch
12+
13+
from executorch.backends.arm._passes import ArmPass
14+
from executorch.exir.dialects._ops import ops as exir_ops
15+
16+
# For MI case
17+
# Edge-dialect asin targets handled by the decomposition code in this module;
# membership in this tuple is what selects an operator for decomposition.
edge_asin_op = (exir_ops.edge.aten.asin.default,)
18+
19+
20+
def get_asin_decomposition(op) -> tuple:
    """
    Return the edge operators needed to decompose an asin op.

    The operators come back in a fixed order (mul.Tensor, add.Tensor,
    mul.Scalar, sqrt, abs, sub.Scalar, div.Tensor, gt.Scalar, lt.Scalar,
    sub.Tensor, full_like, where.self, neg) so that callers can unpack the
    tuple positionally.

    Raises:
        RuntimeError: if op is not one of the targets in edge_asin_op.
    """
    if op not in edge_asin_op:
        raise RuntimeError(f"Can't get asin decomposition for op {op}")

    aten = exir_ops.edge.aten
    return (
        aten.mul.Tensor,
        aten.add.Tensor,
        aten.mul.Scalar,
        aten.sqrt.default,
        aten.abs.default,
        aten.sub.Scalar,
        aten.div.Tensor,
        aten.gt.Scalar,
        aten.lt.Scalar,
        aten.sub.Tensor,
        aten.full_like.default,
        aten.where.self,
        aten.neg.default,
    )
39+
40+
41+
class DecomposeAsinPass(ArmPass):
    """
    This pass decomposes asin into a rational approximation for small values
    and a transformed rational approximation for large values.

    With R(t) = P(t) / Q(t), where P and Q are the polynomials whose
    coefficients are defined in call_operator:

    Example:
        y = asin(x)
    Becomes:
        if abs(x) <= 0.5:
            y = |x| + |x|^3 * R(x^2)
        else:
            z = (1 - |x|) / 2
            s = sqrt(z)
            y = π/2 - 2 * (s + s^3 * R(z))
        y = -y if x < 0

    Both branches are always computed; the result is selected element-wise
    with where ops.
    """

    def _emit(self, op, args: tuple, meta):
        """
        Emit a single node of the decomposition through the parent pass.

        Every node is created with empty kwargs and the trailing positional
        flag set to True (presumably ArmPass's "updated" marker - confirm
        against ArmPass.call_operator).
        """
        return super().call_operator(op, args, {}, meta, True)

    def _build_polynomial(
        self, coefficients: list[float], variable: torch.Tensor, meta: dict[str, str]
    ) -> torch.Tensor:
        """
        Helper function to build polynomial from coefficients and variable.

        Emits nodes computing
            c[0] + c[1] * v + c[2] * v^2 + ... + c[n] * v^n
        where c is `coefficients` (lowest order first) and v is `variable`.

        Args:
            coefficients: polynomial coefficients, constant term first. Must
                contain at least one element.
            variable: the value the polynomial is evaluated at.
            meta: node metadata forwarded to every emitted operator.

        Returns:
            The value produced by the last emitted node, i.e. the polynomial.
        """
        full_like_op, add_op, mul_op_scalar, mul_op = (
            exir_ops.edge.aten.full_like.default,
            exir_ops.edge.aten.add.Tensor,
            exir_ops.edge.aten.mul.Scalar,
            exir_ops.edge.aten.mul.Tensor,
        )
        # Constant term, shaped like the variable.
        result = self._emit(full_like_op, (variable, coefficients[0]), meta)

        # `power` holds v^degree. It is advanced by multiplying with the base
        # variable (not with itself) so that the exponents go 1, 2, 3, ...
        power = variable
        highest_degree = len(coefficients) - 1
        for degree, coeff in enumerate(coefficients[1:], start=1):
            term = self._emit(mul_op_scalar, (power, coeff), meta)
            result = self._emit(add_op, (result, term), meta)
            # Skip the multiply after the last term; its result would be unused.
            if degree < highest_degree:
                power = self._emit(mul_op, (power, variable), meta)
        return result

    def call_operator(self, op, args, kwargs, meta):
        """
        Replace an asin op with its decomposition; forward any other op.

        Args:
            op: the operator being visited.
            args: positional arguments of the op; args[0] is the asin input.
            kwargs: keyword arguments of the op (only used when forwarding).
            meta: node metadata, forwarded to every emitted operator.

        Returns:
            The result of the parent call_operator for non-asin ops, otherwise
            the value of the final where node of the decomposition.
        """
        if op not in edge_asin_op:
            return super().call_operator(op, args, kwargs, meta)

        # Only log once we know an asin is actually being approximated.
        logging.info(
            f"Approximating asin. This may introduce small numerical errors. For details, see {__file__}."
        )

        x = args[0]
        half = 0.5
        one = 1.0
        neg_half = -0.5
        two = 2.0
        pi_over_2 = pi / 2.0
        zero = 0.0
        neg_one = -1.0

        (
            mul_op,
            add_op,
            mul_op_scalar,
            sqrt_op,
            abs_op,
            sub_op_scalar,
            div_op,
            gt_op,
            lt_op,
            sub_op,
            full_like_op,
            where_op,
            neg_op,
        ) = get_asin_decomposition(op)

        # Coefficients for the rational approximation, calculated with the Minimax (Remez) method
        p_coefficients = [
            1.6666667163e-01,
            -3.2556581497e-01,
            2.0121252537e-01,
            -4.0055535734e-02,
            7.9153501429e-04,
        ]

        q_coefficients = [1.0, -2.4033949375e00, 2.0209457874e00, -6.8828397989e-01]

        x_abs = self._emit(abs_op, (x,), meta)

        # Step 1: compute asin_small - rational approximation for [0,0.5]
        # asin_small = |x| + |x|^3 * P(x^2) / Q(x^2)
        x_sq = self._emit(mul_op, (x_abs, x_abs), meta)
        x_cubed = self._emit(mul_op, (x_abs, x_sq), meta)

        # The polynomials are functions of x^2, not of |x|.
        p_small = self._build_polynomial(p_coefficients, x_sq, meta)
        q_small = self._build_polynomial(q_coefficients, x_sq, meta)
        numer_small = self._emit(mul_op, (x_cubed, p_small), meta)
        r_small = self._emit(div_op, (numer_small, q_small), meta)
        asin_small = self._emit(add_op, (x_abs, r_small), meta)

        # Step 2: Compute the transformed approximation for large values
        # Calculate z = -0.5 * (|x| - 1)
        tmp_ones = self._emit(full_like_op, (x_abs, one), meta)
        tmp = self._emit(sub_op, (x_abs, tmp_ones), meta)
        z = self._emit(mul_op_scalar, (tmp, neg_half), meta)

        # Calculate s-terms: s = sqrt(z), s^3
        s = self._emit(sqrt_op, (z,), meta)
        s_sq = self._emit(mul_op, (s, s), meta)
        s_cubed = self._emit(mul_op, (s_sq, s), meta)

        p_large = self._build_polynomial(p_coefficients, z, meta)
        q_large = self._build_polynomial(q_coefficients, z, meta)

        # Calculate r_large = s^3 * P(z) / Q(z)
        numer_large = self._emit(mul_op, (s_cubed, p_large), meta)
        r_large = self._emit(div_op, (numer_large, q_large), meta)

        # Calculate asin_large = pi/2 - 2 * (s + s^3 * P(z) / Q(z)),
        # emitted as -1 * (2 * (s + r_large) - pi/2).
        t1 = self._emit(add_op, (s, r_large), meta)
        t2 = self._emit(mul_op_scalar, (t1, two), meta)
        diff = self._emit(sub_op_scalar, (t2, pi_over_2), meta)
        tmp_neg_ones = self._emit(full_like_op, (diff, neg_one), meta)
        asin_large = self._emit(mul_op, (diff, tmp_neg_ones), meta)

        # Combine branches: large approximation where |x| > 0.5, small otherwise
        is_large = self._emit(gt_op, (x_abs, half), meta)
        asin_unsigned = self._emit(
            where_op, (is_large, asin_large, asin_small), meta
        )

        # Handle x < 0 using asin(-x) = -asin(x)
        is_neg = self._emit(lt_op, (x, zero), meta)
        negated_asin = self._emit(neg_op, (asin_unsigned,), meta)
        asin_signed = self._emit(
            where_op, (is_neg, negated_asin, asin_unsigned), meta
        )

        return asin_signed

backends/arm/_passes/insert_table_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class TableOps:
5656
exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish,
5757
exir_ops.edge.aten.sinh.default: torch.sinh,
5858
exir_ops.edge.aten.acosh.default: torch.acosh,
59+
exir_ops.edge.aten.asin.default: torch.asin,
5960
}
6061

6162
# Targets that must be treated explicitly

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ def is_node_supported(
251251
exir_ops.edge.aten.acosh.default,
252252
exir_ops.edge.aten._adaptive_avg_pool2d.default,
253253
exir_ops.edge.aten.sign.default,
254+
exir_ops.edge.aten.asin.default,
254255
]
255256

256257
return supported

backends/arm/quantizer/quantization_annotator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ def _match_pattern(
217217
torch.ops.aten.atan.default,
218218
torch.ops.aten.acosh.default,
219219
torch.ops.aten.sign.default,
220+
torch.ops.aten.asin.default,
220221
]
221222

222223
_one_to_one_shared_input_qspec = [

backends/arm/test/ops/test_asin.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
10+
from executorch.backends.arm.test import common
11+
from executorch.backends.arm.test.tester.test_pipeline import (
12+
EthosU55PipelineBI,
13+
EthosU85PipelineBI,
14+
TosaPipelineBI,
15+
TosaPipelineMI,
16+
)
17+
18+
input_t = Tuple[torch.Tensor] # Input x
19+
aten_op = "torch.ops.aten.asin.default"
20+
21+
test_data_suite = {
22+
"zeros": lambda: torch.zeros(1, 5, 3, 2), # valid: asin(0) = 0
23+
"ones": lambda: torch.ones(10, 5, 15), # edge case: asin(1) = pi/2
24+
"neg_ones": lambda: -torch.ones(10, 5, 15), # edge case: asin(-1) = -pi/2
25+
"rand": lambda: (torch.rand(10, 10, 5) * 2) - 1, # uniform random in [-1, 1]
26+
"ramp": lambda: torch.linspace(-1.0, 1.0, steps=160), # full domain coverage
27+
"near_bounds": lambda: torch.tensor(
28+
[-0.999, -0.9, -0.5, 0.0, 0.5, 0.9, 0.999]
29+
), # precision edge values
30+
"pos_rand": lambda: torch.rand(7, 10, 2), # positive random values in [0, 1]
31+
}
32+
33+
34+
class Asin(torch.nn.Module):
    """Minimal module whose forward pass is a single element-wise asin."""

    def forward(self, x):
        """Return the arcsine of every element of x."""
        return x.asin()
37+
38+
39+
@common.parametrize("test_data", test_data_suite)
def test_asin_tosa_MI(test_data: Tuple):
    """Run the Asin module through TosaPipelineMI for one test_data_suite entry."""
    module = Asin()
    # Each suite entry is a zero-argument factory; call it to get the input.
    inputs = (test_data(),)
    TosaPipelineMI[input_t](module, inputs, aten_op, exir_op=[]).run()
48+
49+
50+
@common.parametrize("test_data", test_data_suite)
def test_asin_tosa_BI(test_data: Tuple):
    """Run the Asin module through TosaPipelineBI for one test_data_suite entry."""
    module = Asin()
    # Each suite entry is a zero-argument factory; call it to get the input.
    inputs = (test_data(),)
    TosaPipelineBI[input_t](module, inputs, aten_op=[], exir_op=[]).run()
59+
60+
61+
@common.parametrize("test_data", test_data_suite)
@common.XfailIfNoCorstone300
def test_asin_u55_BI(test_data: Tuple):
    """Run the Asin module through EthosU55PipelineBI for one test_data_suite entry."""
    module = Asin()
    # Each suite entry is a zero-argument factory; call it to get the input.
    inputs = (test_data(),)
    EthosU55PipelineBI[input_t](module, inputs, aten_ops=[]).run()
70+
71+
72+
@common.parametrize("test_data", test_data_suite)
@common.XfailIfNoCorstone320
def test_asin_u85_BI(test_data: Tuple):
    """Run the Asin module through EthosU85PipelineBI for one test_data_suite entry."""
    module = Asin()
    # Each suite entry is a zero-argument factory; call it to get the input.
    inputs = (test_data(),)
    EthosU85PipelineBI[input_t](module, inputs, aten_ops=[]).run()

0 commit comments

Comments
 (0)