Arm backend: Match fp32->int32 cast between pytorch and TOSA's CAST #12243

Merged

1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
@@ -51,6 +51,7 @@
from .decompose_sqrt_pass import DecomposeSqrtPass # noqa
from .decompose_sum_pass import DecomposeSumPass # noqa
from .decompose_var_pass import DecomposeVarPass # noqa
from .decorate_fp32_to_int32_casting_pass import DecorateFp32toInt32CastingPass # noqa
from .fold_qdq_with_annotated_qparams_pass import ( # noqa
FoldAndAnnotateQParamsPass,
QuantizeOperatorArguments,
4 changes: 4 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
@@ -56,6 +56,7 @@
DecomposeSqrtPass,
DecomposeSumPass,
DecomposeVarPass,
DecorateFp32toInt32CastingPass,
FoldAndAnnotateQParamsPass,
FuseBatchnorm2DPass,
FuseConstantArgsPass,
@@ -200,6 +201,9 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
self.add_pass(MatchArgRanksPass(exported_program))
self.add_pass(DecomposeAdaptiveAvgPool2dPass())
self.add_pass(DecomposeAvgPool2d())
self.add_pass(
DecorateFp32toInt32CastingPass()
) # Require that no new fp32->int32 is introduced after this pass
self.add_pass(ComputeConstantOpsAOT(exported_program))

self.add_pass(DecomposeGroupedConv())
78 changes: 78 additions & 0 deletions backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py
@@ -0,0 +1,78 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
from executorch.exir.dialects._ops import ops as exir_ops


def _get_decorated_ops(op):
if op in DecorateFp32toInt32CastingPass.targets:
return (
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.ge.Tensor,
exir_ops.edge.aten.floor.default,
exir_ops.edge.aten.ceil.default,
exir_ops.edge.aten.where.self,
)
else:
raise RuntimeError(f"Can't get decorated ops for op {op}")


class DecorateFp32toInt32CastingPass(ArmPass):
"""
To lower a PyTorch fp32 -> int32 cast to TOSA, decorate the cast input with
Floor, Ceil, and Where so that the value is truncated toward zero before the
cast, matching PyTorch's casting semantics.
Before:
output = to_copy(x, dtype=torch.int32)
After:
%zero = full((1,), 0.0, dtype=torch.float32)
is_non_negative = x >= %zero
floor_x = floor(x)
ceil_x = ceil(x)
decorated_x = where(is_non_negative, floor_x, ceil_x)
output = to_copy(decorated_x, dtype=torch.int32)
"""

targets = [
exir_ops.edge.aten._to_copy.default,
exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
]

def call_operator(self, op, args, kwargs, meta):
if op not in self.targets:
return super().call_operator(op, args, kwargs, meta)

input = get_node_arg(args, 0)
input_dtype = input.node.meta["val"].dtype
output_dtype = meta["val"].dtype

if not (input_dtype == torch.float32 and output_dtype == torch.int32):
return super().call_operator(op, args, kwargs, meta)

op_full, op_ge, op_floor, op_ceil, op_where = _get_decorated_ops(op)

zero = super().call_operator(
op_full,
args=((1,) * len(meta["val"].size()), 0.0),
kwargs={"dtype": torch.float32},
meta=meta,
updated=True,
)

is_non_negative = super().call_operator(
op_ge, (input, zero), {}, meta, updated=True
)
floor_x = super().call_operator(op_floor, (input,), {}, meta, updated=True)
ceil_x = super().call_operator(op_ceil, (input,), {}, meta, updated=True)
decorated_x = super().call_operator(
op_where, (is_non_negative, floor_x, ceil_x), {}, meta, updated=True
)

return super().call_operator(op, (decorated_x,), kwargs, meta, updated=True)
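
A minimal eager-mode sketch (not part of this PR) of why the floor/ceil/where decoration matches PyTorch's fp32 -> int32 cast, which truncates toward zero for both signs; the tensor values below are illustrative only:

import torch

x = torch.tensor([-2.7, -0.3, 0.0, 0.3, 2.7], dtype=torch.float32)

# PyTorch's cast truncates toward zero: [-2, 0, 0, 0, 2]
reference = x.to(torch.int32)

# Decorated form used by the pass: floor non-negative values, ceil negative
# ones, then cast; casting an already-integral float value is exact.
decorated = torch.where(x >= torch.zeros_like(x), x.floor(), x.ceil()).to(torch.int32)

assert torch.equal(reference, decorated)
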
@@ -0,0 +1,80 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch
from executorch.backends.arm.test import common

from executorch.backends.arm.test.tester.test_pipeline import (
OpNotSupportedPipeline,
TosaPipelineMI,
)

input_t1 = Tuple[torch.Tensor] # Input x


class FP32ToINT32Casting(torch.nn.Module):
def __init__(self, target_dtype):
super().__init__()
self.target_dtype = target_dtype

def forward(self, x: torch.Tensor):
return x.to(self.target_dtype)


test_data_fp32_input = {
"fp32_input_rank1": lambda: (
torch.rand((4), dtype=torch.float32),
torch.int32,
),
"fp32_input_rank2": lambda: (
torch.rand((3, 4), dtype=torch.float32),
torch.int32,
),
"fp32_input_rank3": lambda: (
torch.rand((2, 3, 4), dtype=torch.float32),
torch.int32,
),
"fp32_input_rank4": lambda: (
torch.rand((1, 2, 3, 4), dtype=torch.float32),
torch.int32,
),
}


@common.parametrize("test_data", test_data_fp32_input)
def test_decorate_fp32_to_int32_casting_tosa_MI(test_data: Tuple):
test_tensor, target_dtype = test_data()
module = FP32ToINT32Casting(target_dtype)

pipeline = TosaPipelineMI[input_t1](
module,
(test_tensor,),
aten_op=[],
exir_op=[],
)
pipeline.run()


@common.parametrize("test_data", test_data_fp32_input)
def test_decorate_fp32_to_int32_casting_tosa_BI(test_data: Tuple):
"""
Cast operations involving floating-point dtypes are rejected in the BI/INT
profile, so DecorateFp32toInt32CastingPass is not required there. This BI
test checks that such a cast is rejected as expected.
"""
test_tensor, target_dtype = test_data()
module = FP32ToINT32Casting(target_dtype)

pipeline = OpNotSupportedPipeline[input_t1](
module,
(test_tensor,),
{
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1
},
quantize=True,
)
pipeline.run()
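
The test data above uses torch.rand, which draws values from [0, 1), so the where in the decorated graph always selects the floor result. A hypothetical extra MI case with mixed-sign input (not part of this PR, reusing the module, pipeline, and imports above) could look like:

test_data_fp32_signed_input = {
    "fp32_input_signed_rank1": lambda: (
        torch.tensor([-2.7, -0.3, 0.3, 2.7], dtype=torch.float32),
        torch.int32,
    ),
}


@common.parametrize("test_data", test_data_fp32_signed_input)
def test_decorate_fp32_to_int32_casting_signed_tosa_MI(test_data: Tuple):
    test_tensor, target_dtype = test_data()
    module = FP32ToINT32Casting(target_dtype)

    pipeline = TosaPipelineMI[input_t1](
        module,
        (test_tensor,),
        aten_op=[],
        exir_op=[],
    )
    pipeline.run()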