
Commit 1315388

Arm backend: Match fp32->int32 cast between pytorch and TOSA's CAST (#12243)
- PyTorch's .to() truncates values toward zero, while TOSA's CAST rounds values to the nearest int, so the two behave differently for fp32->int32 casting.
- Fix the mismatched behaviour with the following transformation:

  Before:
      output = to_copy(x, torch.int32)
  After:
      is_non_negative = x >= 0
      floor_x = floor(x)
      ceil_x = ceil(x)
      decorated_x = where(is_non_negative, floor_x, ceil_x)
      output = to_copy(decorated_x, torch.int32)

Change-Id: I21286432eeb0a5e2f21865f3ac097051c921a9b3

cc @digantdesai @freddan80 @per @zingo @oscarandersson8218

Signed-off-by: Yufeng Shi <yufeng.shi@arm.com>
1 parent ba19c75 commit 1315388
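The mismatch is easy to reproduce in eager PyTorch. The following is a minimal sketch (not part of the commit) that compares truncation with a round-to-nearest cast and checks that the floor/ceil/where decoration reproduces PyTorch's truncation semantics:

import torch

x = torch.tensor([1.7, -1.7, 0.5, -0.5], dtype=torch.float32)

# PyTorch .to(): truncation toward zero.
truncated = x.to(torch.int32)  # tensor([ 1, -1,  0,  0], dtype=torch.int32)

# A round-to-nearest cast (the behaviour the commit attributes to TOSA's CAST)
# gives a different result.
rounded = torch.round(x).to(torch.int32)  # tensor([ 2, -2,  0,  0], dtype=torch.int32)

# The decoration makes the value integral before the cast, so the backend's
# rounding mode no longer affects the result.
decorated = torch.where(x >= 0.0, torch.floor(x), torch.ceil(x))
assert torch.equal(decorated.to(torch.int32), truncated)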

File tree: 4 files changed (+163, -0 lines)

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -51,6 +51,7 @@
 from .decompose_sqrt_pass import DecomposeSqrtPass  # noqa
 from .decompose_sum_pass import DecomposeSumPass  # noqa
 from .decompose_var_pass import DecomposeVarPass  # noqa
+from .decorate_fp32_to_int32_casting_pass import DecorateFp32toInt32CastingPass  # noqa
 from .fold_qdq_with_annotated_qparams_pass import (  # noqa
     FoldAndAnnotateQParamsPass,
     QuantizeOperatorArguments,
backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 0 deletions
@@ -56,6 +56,7 @@
     DecomposeSqrtPass,
     DecomposeSumPass,
     DecomposeVarPass,
+    DecorateFp32toInt32CastingPass,
     FoldAndAnnotateQParamsPass,
     FuseBatchnorm2DPass,
     FuseConstantArgsPass,
@@ -200,6 +201,9 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(MatchArgRanksPass(exported_program))
         self.add_pass(DecomposeAdaptiveAvgPool2dPass())
         self.add_pass(DecomposeAvgPool2d())
+        self.add_pass(
+            DecorateFp32toInt32CastingPass()
+        )  # Require that no new fp32->int32 is introduced after this pass
         self.add_pass(ComputeConstantOpsAOT(exported_program))

         self.add_pass(DecomposeGroupedConv())
backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py (new file)

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
from executorch.exir.dialects._ops import ops as exir_ops


def _get_decorated_ops(op):
    if op in DecorateFp32toInt32CastingPass.targets:
        return (
            exir_ops.edge.aten.full.default,
            exir_ops.edge.aten.ge.Tensor,
            exir_ops.edge.aten.floor.default,
            exir_ops.edge.aten.ceil.default,
            exir_ops.edge.aten.where.self,
        )
    else:
        raise RuntimeError(f"Can't get decorated ops for op {op}")


class DecorateFp32toInt32CastingPass(ArmPass):
    """
    To lower pytorch fp32 -> int32 casting to TOSA,
    we need to transform the value with Ceil, Floor, and Where.
    Before:
        output = to_copy(x, dtype=torch.int32)
    After:
        %zero = full((1,), 0.0, dtype=torch.float32)
        is_non_negative = x >= %zero
        floor_x = floor(x)
        ceil_x = ceil(x)
        decorated_x = where(is_non_negative, floor_x, ceil_x)
        output = to_copy(decorated_x, dtype=torch.int32)
    """

    targets = [
        exir_ops.edge.aten._to_copy.default,
        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
    ]

    def call_operator(self, op, args, kwargs, meta):
        if op not in self.targets:
            return super().call_operator(op, args, kwargs, meta)

        input = get_node_arg(args, 0)
        input_dtype = input.node.meta["val"].dtype
        output_dtype = meta["val"].dtype

        if not (input_dtype == torch.float32 and output_dtype == torch.int32):
            return super().call_operator(op, args, kwargs, meta)

        op_full, op_ge, op_floor, op_ceil, op_where = _get_decorated_ops(op)

        zero = super().call_operator(
            op_full,
            args=((1,) * len(meta["val"].size()), 0.0),
            kwargs={"dtype": torch.float32},
            meta=meta,
            updated=True,
        )

        is_non_negative = super().call_operator(
            op_ge, (input, zero), {}, meta, updated=True
        )
        floor_x = super().call_operator(op_floor, (input,), {}, meta, updated=True)
        ceil_x = super().call_operator(op_ceil, (input,), {}, meta, updated=True)
        decorated_x = super().call_operator(
            op_where, (is_non_negative, floor_x, ceil_x), {}, meta, updated=True
        )

        return super().call_operator(op, (decorated_x,), kwargs, meta, updated=True)
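One detail worth noting in the pass above: the zero used for the comparison is built with full((1,) * rank, 0.0), so it broadcasts against an input of any rank. A small eager-mode sketch of that broadcasting behaviour (variable names here are illustrative, not taken from the commit):

import torch

x = torch.rand(2, 3, 4)  # any fp32 input

# Rank-matched constant, analogous to the aten.full node the pass inserts.
zero = torch.full((1,) * x.dim(), 0.0, dtype=torch.float32)  # shape (1, 1, 1)

is_non_negative = x >= zero  # broadcasts (1, 1, 1) against (2, 3, 4)
decorated = torch.where(is_non_negative, torch.floor(x), torch.ceil(x))
assert decorated.shape == x.shape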
Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch
from executorch.backends.arm.test import common

from executorch.backends.arm.test.tester.test_pipeline import (
    OpNotSupportedPipeline,
    TosaPipelineMI,
)

input_t1 = Tuple[torch.Tensor]  # Input x


class FP32ToINT32Casting(torch.nn.Module):
    def __init__(self, target_dtype):
        super().__init__()
        self.target_dtype = target_dtype

    def forward(self, x: torch.Tensor):
        return x.to(self.target_dtype)


test_data_fp32_input = {
    "fp32_input_rank1": lambda: (
        torch.rand((4), dtype=torch.float32),
        torch.int32,
    ),
    "fp32_input_rank2": lambda: (
        torch.rand((3, 4), dtype=torch.float32),
        torch.int32,
    ),
    "fp32_input_rank3": lambda: (
        torch.rand((2, 3, 4), dtype=torch.float32),
        torch.int32,
    ),
    "fp32_input_rank4": lambda: (
        torch.rand((1, 2, 3, 4), dtype=torch.float32),
        torch.int32,
    ),
}


@common.parametrize("test_data", test_data_fp32_input)
def test_decorate_fp32_to_int32_casting_tosa_MI(test_data: Tuple):
    test_tensor, target_dtype = test_data()
    module = FP32ToINT32Casting(target_dtype)

    pipeline = TosaPipelineMI[input_t1](
        module,
        (test_tensor,),
        aten_op=[],
        exir_op=[],
    )
    pipeline.run()


@common.parametrize("test_data", test_data_fp32_input)
def test_decorate_fp32_to_int32_casting_tosa_BI(test_data: Tuple):
    """
    Casting operations involving floating-point dtypes are rejected in the BI/INT profile,
    so the DecorateFp32toInt32CastingPass is not required there.
    This BI test ensures that such a cast is rejected as expected.
    """
    test_tensor, target_dtype = test_data()
    module = FP32ToINT32Casting(target_dtype)

    pipeline = OpNotSupportedPipeline[input_t1](
        module,
        (test_tensor,),
        {
            "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1
        },
        quantize=True,
    )
    pipeline.run()
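Assuming these tests sit in the Arm backend's test tree (the exact file path is not shown in this view, so the path below is a guess), they can be run with pytest in the usual way:

python -m pytest backends/arm/test/passes/test_decorate_fp32_to_int32_casting_pass.py -v  # hypothetical path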
