diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 440938fd49..db6c4faeb4 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -39,6 +39,7 @@ from .decompose_round_pass import DecomposeRoundPass # noqa from .decompose_select import DecomposeSelectPass # noqa from .decompose_silu_pass import DecomposeSiluPass # noqa +from .decompose_sinh_pass import DecomposeSinhPass # noqa from .decompose_softmax_pass import DecomposeSoftmaxPass # noqa from .decompose_softmax_unstable_pass import DecomposeSoftmaxUnstablePass # noqa from .decompose_sqrt_pass import DecomposeSqrtPass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 07a4416cd7..199c24c248 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -42,6 +42,7 @@ DecomposeRoundPass, DecomposeSelectPass, DecomposeSiluPass, + DecomposeSinhPass, DecomposeSoftmaxPass, DecomposeSoftmaxUnstablePass, DecomposeSqrtPass, @@ -151,6 +152,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul self.add_pass(DecomposeSqrtPass()) self.add_pass(ConvertIntPowToMuls()) self.add_pass(CastBoolToInt8Pass()) + self.add_pass(DecomposeSinhPass()) self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI()) self.add_pass(DecomposeEmbeddingPass()) self.add_pass(FuseQuantizedActivationPass()) diff --git a/backends/arm/_passes/decompose_sinh_pass.py b/backends/arm/_passes/decompose_sinh_pass.py new file mode 100644 index 0000000000..7192eb9bf7 --- /dev/null +++ b/backends/arm/_passes/decompose_sinh_pass.py @@ -0,0 +1,53 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from executorch.backends.arm._passes import ArmPass +from executorch.exir.dialects._ops import ops as exir_ops + + +# For MI case +edge_sinh = exir_ops.edge.aten.sinh.default + + +class DecomposeSinhPass(ArmPass): + """ + A decomposition pass that decomposes Sinh operations into a + combination of supported TOSA-equivalent operations (MI). + + Supported input ops: + - exir_ops.edge.aten.sinh.default + + These are decomposed into exponentials, negation, subtraction, + and scalar multiplication. + """ + + def call_operator(self, op, args, kwargs, meta): + if op is not edge_sinh: + return super().call_operator(op, args, kwargs, meta) + + x = args + + sub_op, exp_op, neg_op, mul_op = ( + exir_ops.edge.aten.sub.Tensor, + exir_ops.edge.aten.exp.default, + exir_ops.edge.aten.neg.default, + exir_ops.edge.aten.mul.Scalar, + ) + + # Exponential 1 + exp1 = super().call_operator(exp_op, x, {}, meta, updated=True) + + # Exponential 2 + neg_x = super().call_operator(neg_op, x, {}, meta, updated=True) + exp2 = super().call_operator(exp_op, (neg_x,), {}, meta, updated=True) + + # Subtraction + sub = super().call_operator(sub_op, (exp1, exp2), {}, meta, updated=True) + + # Multiplication + out = super().call_operator(mul_op, (sub, 0.5), {}, meta, updated=True) + + return out diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py index 402ed0253c..c579fcb030 100644 --- a/backends/arm/_passes/insert_table_ops.py +++ b/backends/arm/_passes/insert_table_ops.py @@ -53,6 +53,7 @@ class TableOps: exir_ops.edge.aten.tanh.default: torch.tanh, exir_ops.edge.aten.hardsigmoid.default: torch.nn.functional.hardsigmoid, exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish, + exir_ops.edge.aten.sinh.default: torch.sinh, } # Targets that must be treated explicitly diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 7a893acaf8..639df53610 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -243,6 +243,7 @@ def is_node_supported( torch.ops.aten.scalar_tensor.default, exir_ops.edge.aten.gelu.default, exir_ops.edge.aten.alias_copy.default, + exir_ops.edge.aten.sinh.default, ] return supported diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index 83a648c7d8..c6415c6377 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -213,6 +213,7 @@ def _match_pattern( torch.ops.aten.full_like.default, torch.ops.aten.pow.Tensor_Scalar, torch.ops.aten.gelu.default, + torch.ops.aten.sinh.default, ] _one_to_one_shared_input_qspec = [ diff --git a/backends/arm/test/ops/test_sinh.py b/backends/arm/test/ops/test_sinh.py new file mode 100644 index 0000000000..fd6cbf2b65 --- /dev/null +++ b/backends/arm/test/ops/test_sinh.py @@ -0,0 +1,78 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineBI, + EthosU85PipelineBI, + TosaPipelineBI, + TosaPipelineMI, +) + +aten_op = "torch.ops.aten.sinh.default" +exir_op = "executorch_exir_dialects_edge__ops_aten__sinh_default" + + +input_t1 = Tuple[torch.Tensor] # Input x + +test_data_suite = { + # (test_name, test_data) + "zeros": torch.zeros(10, 10, 10), + "zeros_alt_shape": torch.zeros(10, 3, 5), + "ones": torch.ones(10, 10, 10), + "rand": torch.rand(10, 10) - 0.5, + "rand_alt_shape": torch.rand(10, 3, 5) - 0.5, + "randn_pos": torch.randn(10) + 10, + "randn_neg": torch.randn(10) - 10, + "ramp": torch.arange(-16, 16, 0.2), + "large": 100 * torch.ones(1, 1), + "small": 0.000001 * torch.ones(1, 1), +} + + +class Sinh(torch.nn.Module): + + def forward(self, x: torch.Tensor): + return torch.sinh(x) + + +@common.parametrize("test_data", test_data_suite) +def test_sinh_tosa_MI(test_data: Tuple): + pipeline = TosaPipelineMI[input_t1]( + Sinh(), + (test_data,), + aten_op, + exir_op, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +def test_sinh_tosa_BI(test_data: Tuple): + pipeline = TosaPipelineBI[input_t1]( + Sinh(), (test_data,), aten_op=aten_op, exir_op=exir_op + ) + pipeline.run() + + +@common.XfailIfNoCorstone300 +@common.parametrize("test_data", test_data_suite) +def test_sinh_u55_BI(test_data: Tuple): + pipeline = EthosU55PipelineBI[input_t1]( + Sinh(), (test_data,), aten_ops=aten_op, exir_ops=exir_op + ) + pipeline.run() + + +@common.XfailIfNoCorstone320 +@common.parametrize("test_data", test_data_suite) +def test_sinh_u85_BI(test_data: Tuple): + pipeline = EthosU85PipelineBI[input_t1]( + Sinh(), (test_data,), aten_ops=aten_op, exir_ops=exir_op + ) + pipeline.run()