Skip to content

Commit 0c6e348

Browse files
authored
Arm backend: Add pass and test for adaptive_avg_pool2d (#12190)
Add full support and tests for adaptive_avg_pool2d Signed-off-by: Emma Kujala <emma.kujala@arm.com>
1 parent f858e0d commit 0c6e348

File tree

5 files changed

+260
-0
lines changed

5 files changed

+260
-0
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from .convert_squeezes_to_view import ConvertSqueezesToViewPass # noqa
2424
from .convert_to_clamp import ConvertToClampPass # noqa
2525
from .decompose_acosh_pass import DecomposeAcoshPass # noqa
26+
from .decompose_adaptive_avg_pool2d_pass import DecomposeAdaptiveAvgPool2dPass # noqa
2627
from .decompose_atan_pass import DecomposeAtanPass # noqa
2728
from .decompose_avg_pool2d import DecomposeAvgPool2d # noqa
2829
from .decompose_batch_norm_no_stats import DecomposeBatchNormNoStatsPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
ConvertSqueezesToViewPass,
2929
ConvertToClampPass,
3030
DecomposeAcoshPass,
31+
DecomposeAdaptiveAvgPool2dPass,
3132
DecomposeAtanPass,
3233
DecomposeAvgPool2d,
3334
DecomposeBatchNormNoStatsPass,
@@ -127,6 +128,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
127128
if self.tosa_spec.is_U55_subset:
128129
self.add_pass(BroadcastArgsPass())
129130
self.add_pass(DecomposeLinearPass())
131+
self.add_pass(DecomposeAdaptiveAvgPool2dPass())
130132
self.add_pass(DecomposeAvgPool2d())
131133
self.add_pass(ComputeConstantOpsAOT(exported_program))
132134

@@ -194,6 +196,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
194196
self.add_pass(RetraceFoldedDtypesPass())
195197
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
196198
self.add_pass(MatchArgRanksPass(exported_program))
199+
self.add_pass(DecomposeAdaptiveAvgPool2dPass())
197200
self.add_pass(DecomposeAvgPool2d())
198201
self.add_pass(ComputeConstantOpsAOT(exported_program))
199202

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from math import ceil, floor
7+
8+
import torch
9+
10+
from executorch.backends.arm._passes import ArmPass
11+
12+
from executorch.exir.dialects._ops import ops as exir_ops
13+
14+
# Edge-dialect adaptive average pooling overload(s) that this pass decomposes.
edge_ops = (exir_ops.edge.aten._adaptive_avg_pool2d.default,)
15+
# ATen adaptive average pooling overload(s) that this pass decomposes.
aten_ops = (torch.ops.aten.adaptive_avg_pool2d.default,)
16+
17+
18+
def _get_decomposition(op) -> tuple:
    """Return the ops used to decompose ``op``, taken from the same dialect.

    Args:
        op: The operator being decomposed; expected to be a member of
            ``edge_ops`` or ``aten_ops``.

    Returns:
        A ``(avg_pool2d, slice_copy, cat)`` tuple of operator overloads.

    Raises:
        RuntimeError: If ``op`` is in neither ``edge_ops`` nor ``aten_ops``.
    """
    # Pick the operator namespace first, then resolve the three overloads once.
    if op in edge_ops:
        namespace = exir_ops.edge.aten
    elif op in aten_ops:
        namespace = torch.ops.aten
    else:
        raise RuntimeError(f"Unable to get decomposition for op {op}")

    return (
        namespace.avg_pool2d.default,
        namespace.slice_copy.Tensor,
        namespace.cat.default,
    )
32+
33+
34+
class DecomposeAdaptiveAvgPool2dPass(ArmPass):
    """
    Decomposes AdaptiveAvgPool2d into AvgPool2d operations.

    An input tensor of shape (N, C, H, W) is transformed into an output tensor
    of shape (N, C, output_size_h, output_size_w).

    The output is of size output_size_h x output_size_w for any input.

    Each output element is produced by slicing its pooling region out of the
    input, averaging that region with an AvgPool2d whose kernel covers the
    whole region, and concatenating the results: first along width (dim=3) to
    build a row, then along height (dim=2) to build the full output.
    """

    def call_operator(self, op, args, kwargs, meta, updated=False):
        """Replace an adaptive average pool with slice/avg_pool2d/cat operators.

        Args:
            op: The operator being visited. Operators not listed in
                ``edge_ops`` or ``aten_ops`` are forwarded unchanged.
            args: Operator arguments; ``args[0]`` is the 4-D input and
                ``args[1]`` is the ``(output_size_h, output_size_w)`` pair.
            kwargs: Keyword arguments, forwarded to every emitted operator.
            meta: Node metadata, forwarded to every emitted operator.
            updated: Forwarded as-is when the operator is not decomposed.

        Returns:
            The result of the final concatenation along height, or the result
            of the parent ``call_operator`` when ``op`` is not decomposed.
        """
        if op not in (edge_ops + aten_ops):
            return super().call_operator(op, args, kwargs, meta, updated)

        avg_pool2d_op, slice_op, cat_op = _get_decomposition(op)

        x = args[0]
        _, _, input_size_h, input_size_w = x.data.shape

        (output_size_h, output_size_w) = args[1]

        # Vela currently only allows a stride in the interval of [1,3] for AvgPool2d.
        # To accommodate this, the AvgPool2d op is applied to pooling regions and the results are concatenated.

        # The column regions depend only on the output column index, so they
        # are computed once instead of once per output row.
        w_regions = [
            (
                floor(out_j * input_size_w / output_size_w),
                ceil((out_j + 1) * input_size_w / output_size_w),
            )
            for out_j in range(output_size_w)
        ]

        res = []
        for out_i in range(output_size_h):
            # The row region depends only on the output row index.
            start_h = floor(out_i * input_size_h / output_size_h)
            end_h = ceil((out_i + 1) * input_size_h / output_size_h)
            kernel_h = end_h - start_h

            # Slice along H once per row; every column of the row reuses it,
            # which avoids emitting one identical H-slice per output element.
            x_h = super().call_operator(
                slice_op, (x, 2, start_h, end_h), kwargs, meta, True
            )

            row = []
            for start_w, end_w in w_regions:
                # Slice along W
                x_hw = super().call_operator(
                    slice_op, (x_h, 3, start_w, end_w), kwargs, meta, True
                )

                # Apply avg pooling with kernel size equal to the pooling region
                kernel_w = end_w - start_w
                pool_args = (x_hw, (kernel_h, kernel_w), (1, 1), (0, 0))
                pooled = super().call_operator(
                    avg_pool2d_op, pool_args, kwargs, meta, True
                )
                row.append(pooled)

            # Concatenate row results along width (dim=3)
            row_tensor = super().call_operator(cat_op, (row, 3), kwargs, meta, True)
            res.append(row_tensor)

        # Concatenate all rows along height (dim=2)
        out = super().call_operator(cat_op, (res, 2), kwargs, meta, True)
        return out

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ def is_node_supported(
249249
exir_ops.edge.aten.sinh.default,
250250
exir_ops.edge.aten.atan.default,
251251
exir_ops.edge.aten.acosh.default,
252+
exir_ops.edge.aten._adaptive_avg_pool2d.default,
252253
exir_ops.edge.aten.sign.default,
253254
]
254255

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
10+
from executorch.backends.arm.test import common
11+
12+
from executorch.backends.arm.test.tester.test_pipeline import (
13+
EthosU55PipelineBI,
14+
EthosU85PipelineBI,
15+
TosaPipelineBI,
16+
TosaPipelineMI,
17+
)
18+
19+
# Edge-dialect op name handed to every pipeline below as its exir op argument.
# It names avg_pool2d rather than adaptive_avg_pool2d.
# NOTE(review): presumably because the adaptive pool is decomposed into
# avg_pool2d before the pipeline checks for it; confirm.
exir_op = "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"
20+
21+
# Input type used to parametrize the pipelines: a one-element tuple holding
# the input tensor.
input_t = Tuple[torch.Tensor]
22+
23+
24+
class AdaptiveAvgPool2d(torch.nn.AdaptiveAvgPool2d):
    """Thin subclass of ``torch.nn.AdaptiveAvgPool2d`` used as the test module.

    ``forward`` hands every argument to the parent implementation unchanged.
    """

    def forward(self, *inputs, **options):
        """Run the parent class's forward pass with the given arguments."""
        return torch.nn.AdaptiveAvgPool2d.forward(self, *inputs, **options)
27+
28+
29+
# Maps a test-case name to a zero-argument factory that returns
# (module, (input_tensor,)). Each case is described by the requested output
# size and the shape of a random input; the factory builds both on demand.
test_modules = {
    case_name: (
        # Bind the per-case values as defaults so that every lambda keeps its
        # own output size and input shape.
        lambda size=output_size, shape=input_shape: (
            AdaptiveAvgPool2d(size),
            (torch.rand(*shape),),
        )
    )
    for case_name, output_size, input_shape in (
        ("output_bigger_than_input_1_to_3", (3, 3), (1, 3, 1, 1)),
        ("output_bigger_than_input_7_to_10", (10, 10), (1, 3, 7, 7)),
        ("output_1x1", (1, 1), (1, 4, 8, 8)),
        ("output_2x2", (2, 2), (1, 4, 10, 10)),
        ("output_4x4", (4, 4), (1, 5, 15, 15)),
        ("output_2x3", (2, 3), (1, 3, 9, 13)),
        ("output_h_keep", (2, None), (1, 3, 10, 16)),
        ("output_w_keep", (None, 4), (1, 3, 14, 20)),
        ("output_5x5", (5, 5), (1, 3, 25, 25)),
        ("output_3x5", (3, 5), (1, 3, 15, 20)),
        ("output_7x1", (7, 1), (1, 3, 21, 3)),
        ("output_1x7", (1, 7), (1, 3, 3, 21)),
        ("output_3xNone", (3, None), (1, 3, 9, 24)),
        ("output_Nonex3", (None, 3), (1, 3, 24, 9)),
        ("pool_h_static_w_none", (3, None), (1, 3, 9, 17)),
        ("pool_h_none_w_static", (None, 5), (1, 3, 15, 25)),
        ("identity_pool", (10, 10), (1, 3, 10, 10)),
        ("non_divisible_5x5_from_17x17", (5, 5), (1, 3, 17, 17)),
        ("pool_height_only", (1, 6), (1, 3, 12, 6)),
        ("pool_width_only", (6, 1), (1, 3, 6, 12)),
        ("extreme_input_large", (1, 1), (1, 3, 128, 128)),
        ("single_channel_input", (4, 4), (1, 1, 16, 16)),
        ("high_channel_count", (2, 2), (1, 1024, 32, 32)),
        # Common input/output sizes found in models
        ("output_7x7_from_14x14", (7, 7), (1, 512, 14, 14)),
        ("output_1x1_from_8x8", (1, 1), (1, 2048, 8, 8)),
        ("output_1x1_from_19", (1, 1), (1, 2560, 19, 19)),
        ("output_1x1_from_7x7", (1, 1), (1, 1280, 7, 7)),
    )
}
110+
111+
112+
@common.parametrize("test_module", test_modules)
def test_adaptive_avg_pool2d_tosa_MI(test_module):
    """Run each adaptive average pool case through the TOSA MI pipeline."""
    module, inputs = test_module()
    TosaPipelineMI[input_t](module, inputs, aten_op=[], exir_op=exir_op).run()
123+
124+
125+
@common.parametrize("test_module", test_modules)
def test_adaptive_avg_pool2d_tosa_BI(test_module):
    """Run each adaptive average pool case through the TOSA BI pipeline."""
    module, inputs = test_module()
    TosaPipelineBI[input_t](module, inputs, aten_op=[], exir_op=exir_op).run()
136+
137+
138+
@common.parametrize("test_module", test_modules)
@common.XfailIfNoCorstone300
def test_adaptive_avg_pool2d_u55_BI(test_module):
    """Run each adaptive average pool case through the Ethos-U55 BI pipeline."""
    module, inputs = test_module()
    EthosU55PipelineBI[input_t](
        module, inputs, aten_ops=[], exir_ops=exir_op
    ).run()
150+
151+
152+
@common.parametrize("test_module", test_modules)
@common.XfailIfNoCorstone320
def test_adaptive_avg_pool2d_u85_BI(test_module):
    """Run each adaptive average pool case through the Ethos-U85 BI pipeline."""
    module, inputs = test_module()
    EthosU85PipelineBI[input_t](
        module, inputs, aten_ops=[], exir_ops=exir_op
    ).run()

0 commit comments

Comments
 (0)