Skip to content

Arm backend: Add sinh decomposition pass and test #11848

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from .decompose_round_pass import DecomposeRoundPass # noqa
from .decompose_select import DecomposeSelectPass # noqa
from .decompose_silu_pass import DecomposeSiluPass # noqa
from .decompose_sinh_pass import DecomposeSinhPass # noqa
from .decompose_softmax_pass import DecomposeSoftmaxPass # noqa
from .decompose_softmax_unstable_pass import DecomposeSoftmaxUnstablePass # noqa
from .decompose_sqrt_pass import DecomposeSqrtPass # noqa
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
DecomposeRoundPass,
DecomposeSelectPass,
DecomposeSiluPass,
DecomposeSinhPass,
DecomposeSoftmaxPass,
DecomposeSoftmaxUnstablePass,
DecomposeSqrtPass,
Expand Down Expand Up @@ -146,6 +147,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
self.add_pass(DecomposeRoundPass())
self.add_pass(DecomposeSqrtPass())
self.add_pass(ConvertIntPowToMuls())
self.add_pass(DecomposeSinhPass())
self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
self.add_pass(DecomposeEmbeddingPass())
self.add_pass(FuseQuantizedActivationPass())
Expand Down
53 changes: 53 additions & 0 deletions backends/arm/_passes/decompose_sinh_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from executorch.backends.arm._passes import ArmPass
from executorch.exir.dialects._ops import ops as exir_ops


# For MI case
edge_sinh = exir_ops.edge.aten.sinh.default


class DecomposeSinhPass(ArmPass):
"""
A decomposition pass that decomposes Sinh operations into a
combination of supported TOSA-equivalent operations (MI).

Supported input ops:
- exir_ops.edge.aten.sinh.default

These are decomposed into exponentials, negation, subtraction,
and scalar multiplication.
"""

def call_operator(self, op, args, kwargs, meta):
if op is not edge_sinh:
return super().call_operator(op, args, kwargs, meta)

x = args

sub_op, exp_op, neg_op, mul_op = (
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.exp.default,
exir_ops.edge.aten.neg.default,
exir_ops.edge.aten.mul.Scalar,
)

# Exponential 1
exp1 = super().call_operator(exp_op, x, {}, meta, updated=True)

# Exponential 2
neg_x = super().call_operator(neg_op, x, {}, meta, updated=True)
exp2 = super().call_operator(exp_op, (neg_x,), {}, meta, updated=True)

# Subtraction
sub = super().call_operator(sub_op, (exp1, exp2), {}, meta, updated=True)

# Multiplication
out = super().call_operator(mul_op, (sub, 0.5), {}, meta, updated=True)

return out
1 change: 1 addition & 0 deletions backends/arm/_passes/insert_table_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class TableOps:
exir_ops.edge.aten.tanh.default: torch.tanh,
exir_ops.edge.aten.hardsigmoid.default: torch.nn.functional.hardsigmoid,
exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish,
exir_ops.edge.aten.sinh.default: torch.sinh,
}

# Targets that must be treated explicitly
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ def is_node_supported(
torch.ops.aten.scalar_tensor.default,
exir_ops.edge.aten.gelu.default,
exir_ops.edge.aten.alias_copy.default,
exir_ops.edge.aten.sinh.default,
]

return supported
Expand Down
1 change: 1 addition & 0 deletions backends/arm/quantizer/quantization_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def _match_pattern(
torch.ops.aten.full_like.default,
torch.ops.aten.pow.Tensor_Scalar,
torch.ops.aten.gelu.default,
torch.ops.aten.sinh.default,
]

_one_to_one_shared_input_qspec = [
Expand Down
78 changes: 78 additions & 0 deletions backends/arm/test/ops/test_sinh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
EthosU55PipelineBI,
EthosU85PipelineBI,
TosaPipelineBI,
TosaPipelineMI,
)

aten_op = "torch.ops.aten.sinh.default"
exir_op = "executorch_exir_dialects_edge__ops_aten__sinh_default"


input_t1 = Tuple[torch.Tensor] # Input x

test_data_suite = {
# (test_name, test_data)
"zeros": torch.zeros(10, 10, 10),
"zeros_alt_shape": torch.zeros(10, 3, 5),
"ones": torch.ones(10, 10, 10),
"rand": torch.rand(10, 10) - 0.5,
"rand_alt_shape": torch.rand(10, 3, 5) - 0.5,
"randn_pos": torch.randn(10) + 10,
"randn_neg": torch.randn(10) - 10,
"ramp": torch.arange(-16, 16, 0.2),
"large": 100 * torch.ones(1, 1),
"small": 0.000001 * torch.ones(1, 1),
}


class Sinh(torch.nn.Module):

def forward(self, x: torch.Tensor):
return torch.sinh(x)


@common.parametrize("test_data", test_data_suite)
def test_sinh_tosa_MI(test_data: Tuple):
pipeline = TosaPipelineMI[input_t1](
Sinh(),
(test_data,),
aten_op,
exir_op,
)
pipeline.run()


@common.parametrize("test_data", test_data_suite)
def test_sinh_tosa_BI(test_data: Tuple):
pipeline = TosaPipelineBI[input_t1](
Sinh(), (test_data,), aten_op=aten_op, exir_op=exir_op
)
pipeline.run()


@common.XfailIfNoCorstone300
@common.parametrize("test_data", test_data_suite)
def test_sinh_u55_BI(test_data: Tuple):
pipeline = EthosU55PipelineBI[input_t1](
Sinh(), (test_data,), aten_ops=aten_op, exir_ops=exir_op
)
pipeline.run()


@common.XfailIfNoCorstone320
@common.parametrize("test_data", test_data_suite)
def test_sinh_u85_BI(test_data: Tuple):
pipeline = EthosU85PipelineBI[input_t1](
Sinh(), (test_data,), aten_ops=aten_op, exir_ops=exir_op
)
pipeline.run()
Loading