From 1c29cd2564b3e4cb20c94f329ae14b6391bfc0ef Mon Sep 17 00:00:00 2001
From: Yufeng Shi
Date: Mon, 23 Jun 2025 17:14:26 +0100
Subject: [PATCH] Arm backend: Match fp32->int32 cast between PyTorch and TOSA's CAST

- PyTorch's .to() truncates values toward zero, while TOSA's CAST rounds
  values to the nearest integer, so the two behave differently for
  fp32->int32 casting.
- Fix the mismatched behaviour with the following transformation:
  Before:
    output = to_copy(x, torch.int32)
  After:
    is_non_negative = x >= 0
    floor_x = floor(x)
    ceil_x = ceil(x)
    decorated_x = where(is_non_negative, floor_x, ceil_x)
    output = to_copy(decorated_x, torch.int32)

Change-Id: I21286432eeb0a5e2f21865f3ac097051c921a9b3
Signed-off-by: Yufeng Shi
---
 backends/arm/_passes/__init__.py              |  1 +
 backends/arm/_passes/arm_pass_manager.py      |  4 +
 .../decorate_fp32_to_int32_casting_pass.py    | 78 ++++++++++++++++++
 ...est_decorate_fp32_to_int32_casting_pass.py | 80 +++++++++++++++++++
 4 files changed, 163 insertions(+)
 create mode 100644 backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py
 create mode 100644 backends/arm/test/passes/test_decorate_fp32_to_int32_casting_pass.py

diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py
index f9efa898331..b15aaa9bfc7 100644
--- a/backends/arm/_passes/__init__.py
+++ b/backends/arm/_passes/__init__.py
@@ -48,6 +48,7 @@
 from .decompose_sqrt_pass import DecomposeSqrtPass  # noqa
 from .decompose_sum_pass import DecomposeSumPass  # noqa
 from .decompose_var_pass import DecomposeVarPass  # noqa
+from .decorate_fp32_to_int32_casting_pass import DecorateFp32toInt32CastingPass  # noqa
 from .fold_qdq_with_annotated_qparams_pass import (  # noqa
     FoldAndAnnotateQParamsPass,
     QuantizeOperatorArguments,
diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index f4a8af27ff8..80d712575a8 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -51,6 +51,7 @@
     DecomposeSqrtPass,
     DecomposeSumPass,
     DecomposeVarPass,
+    DecorateFp32toInt32CastingPass,
     FoldAndAnnotateQParamsPass,
     FuseBatchnorm2DPass,
     FuseConstantArgsPass,
@@ -191,6 +192,9 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
         self.add_pass(MatchArgRanksPass(exported_program))
         self.add_pass(DecomposeAvgPool2d())
+        self.add_pass(
+            DecorateFp32toInt32CastingPass()
+        )  # Require that no new fp32->int32 casts are introduced after this pass
         self.add_pass(ComputeConstantOpsAOT(exported_program))
         self.add_pass(DecomposeGroupedConv())
diff --git a/backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py b/backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py
new file mode 100644
index 00000000000..d6f7ac2ceac
--- /dev/null
+++ b/backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py
@@ -0,0 +1,78 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+
+import torch
+from executorch.backends.arm._passes import ArmPass
+from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+def _get_decorated_ops(op):
+    if op in DecorateFp32toInt32CastingPass.targets:
+        return (
+            exir_ops.edge.aten.full.default,
+            exir_ops.edge.aten.ge.Tensor,
+            exir_ops.edge.aten.floor.default,
+            exir_ops.edge.aten.ceil.default,
+            exir_ops.edge.aten.where.self,
+        )
+    else:
+        raise RuntimeError(f"Can't get decorated ops for op {op}")
+
+
+class DecorateFp32toInt32CastingPass(ArmPass):
+    """
+    To lower PyTorch fp32 -> int32 casting to TOSA,
+    we need to transform the value with Ceil, Floor, and Where.
+    Before:
+        output = to_copy(x, dtype=torch.int32)
+    After:
+        %zero = full((1,), 0.0, dtype=torch.float32)
+        is_non_negative = x >= %zero
+        floor_x = floor(x)
+        ceil_x = ceil(x)
+        decorated_x = where(is_non_negative, floor_x, ceil_x)
+        output = to_copy(decorated_x, dtype=torch.int32)
+    """
+
+    targets = [
+        exir_ops.edge.aten._to_copy.default,
+        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+    ]
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op not in self.targets:
+            return super().call_operator(op, args, kwargs, meta)
+
+        input = get_node_arg(args, 0)
+        input_dtype = input.node.meta["val"].dtype
+        output_dtype = meta["val"].dtype
+
+        if not (input_dtype == torch.float32 and output_dtype == torch.int32):
+            return super().call_operator(op, args, kwargs, meta)
+
+        op_full, op_ge, op_floor, op_ceil, op_where = _get_decorated_ops(op)
+
+        zero = super().call_operator(
+            op_full,
+            args=((1,) * len(meta["val"].size()), 0.0),
+            kwargs={"dtype": torch.float32},
+            meta=meta,
+            updated=True,
+        )
+
+        is_non_negative = super().call_operator(
+            op_ge, (input, zero), {}, meta, updated=True
+        )
+        floor_x = super().call_operator(op_floor, (input,), {}, meta, updated=True)
+        ceil_x = super().call_operator(op_ceil, (input,), {}, meta, updated=True)
+        decorated_x = super().call_operator(
+            op_where, (is_non_negative, floor_x, ceil_x), {}, meta, updated=True
+        )
+
+        return super().call_operator(op, (decorated_x,), kwargs, meta, updated=True)
diff --git a/backends/arm/test/passes/test_decorate_fp32_to_int32_casting_pass.py b/backends/arm/test/passes/test_decorate_fp32_to_int32_casting_pass.py
new file mode 100644
index 00000000000..25312b89748
--- /dev/null
+++ b/backends/arm/test/passes/test_decorate_fp32_to_int32_casting_pass.py
@@ -0,0 +1,80 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple
+
+import torch
+from executorch.backends.arm.test import common
+
+from executorch.backends.arm.test.tester.test_pipeline import (
+    OpNotSupportedPipeline,
+    TosaPipelineMI,
+)
+
+input_t1 = Tuple[torch.Tensor]  # Input x
+
+
+class FP32ToINT32Casting(torch.nn.Module):
+    def __init__(self, target_dtype):
+        super().__init__()
+        self.target_dtype = target_dtype
+
+    def forward(self, x: torch.Tensor):
+        return x.to(self.target_dtype)
+
+
+test_data_fp32_input = {
+    "fp32_input_rank1": lambda: (
+        torch.rand((4), dtype=torch.float32),
+        torch.int32,
+    ),
+    "fp32_input_rank2": lambda: (
+        torch.rand((3, 4), dtype=torch.float32),
+        torch.int32,
+    ),
+    "fp32_input_rank3": lambda: (
+        torch.rand((2, 3, 4), dtype=torch.float32),
+        torch.int32,
+    ),
+    "fp32_input_rank4": lambda: (
+        torch.rand((1, 2, 3, 4), dtype=torch.float32),
+        torch.int32,
+    ),
+}
+
+
+@common.parametrize("test_data", test_data_fp32_input)
+def test_decorate_fp32_to_int32_casting_tosa_MI(test_data: Tuple):
+    test_tensor, target_dtype = test_data()
+    module = FP32ToINT32Casting(target_dtype)
+
+    pipeline = TosaPipelineMI[input_t1](
+        module,
+        (test_tensor,),
+        aten_op=[],
+        exir_op=[],
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_fp32_input)
+def test_decorate_fp32_to_int32_casting_tosa_BI(test_data: Tuple):
+    """
+    Casting operations involving floating-point dtypes are rejected in the BI/INT profile,
+    so the DecorateFp32toInt32CastingPass is not required there.
+    This BI test ensures that such casts are rejected as expected.
+    """
+    test_tensor, target_dtype = test_data()
+    module = FP32ToINT32Casting(target_dtype)
+
+    pipeline = OpNotSupportedPipeline[input_t1](
+        module,
+        (test_tensor,),
+        {
+            "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1
+        },
+        quantize=True,
+    )
+    pipeline.run()
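
Not part of the patch: a minimal standalone sketch of the rounding mismatch this pass works around, assuming TOSA's round-to-nearest CAST can be approximated with torch.round. It also checks that the floor/ceil/where decoration reproduces PyTorch's truncate-toward-zero semantics, so the decorated value gives the same result whether the final cast truncates or rounds.

import torch

x = torch.tensor([1.7, -1.7, 0.5, -0.5], dtype=torch.float32)

# PyTorch semantics: .to(torch.int32) truncates toward zero -> [1, -1, 0, 0]
truncated = x.to(torch.int32)

# TOSA CAST semantics, approximated here by rounding to nearest -> [2, -2, 0, 0]
rounded = torch.round(x).to(torch.int32)

# The pass's decoration: floor non-negative values, ceil negative values.
# The decorated tensor is already integral, so rounding and truncation agree on it.
decorated = torch.where(x >= 0, torch.floor(x), torch.ceil(x))

assert torch.equal(decorated.to(torch.int32), truncated)
assert torch.equal(torch.round(decorated).to(torch.int32), truncated)
assert not torch.equal(rounded, truncated)  # the original mismatch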