Skip to content

Commit 851b29b

Browse files
Arm backend: Remove upsample ops from ops_to_not_decompose on U55 (#12236)
The upsample operators get legalized to the RESCALE Tosa operator, which is not supported on U55. By having the operator in ops_to_not_decompose in the Tosa partitioner the operator is both blocked from being delegated to the NPU and blocked from being run on CPU. Therefore, the upsample operators are removed from the ops_to_not_decompose list on U55 in the partitioner, allowing them to fall back to CPU. Signed-off-by: Sebastian Larsson <sebastian.larsson@arm.com>
1 parent 0c6e348 commit 851b29b

File tree

3 files changed

+130
-2
lines changed

3 files changed

+130
-2
lines changed

backends/arm/test/ops/test_upsample_bilinear2d.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@
1010

1111
from executorch.backends.arm.test.tester.test_pipeline import (
1212
EthosU85PipelineBI,
13+
OpNotSupportedPipeline,
1314
TosaPipelineBI,
1415
TosaPipelineMI,
1516
)
1617

1718
aten_op = "torch.ops.aten.upsample_bilinear2d.vec"
19+
exir_op = "executorch_exir_dialects_edge__ops_aten_upsample_bilinear2d_vec"
1820
input_t1 = Tuple[torch.Tensor] # Input x
1921

2022
test_data_suite_tosa = {
@@ -57,6 +59,10 @@
5759
"rand_one_and_half_size": (torch.rand(2, 4, 8, 3), (12, 4), None, False),
5860
}
5961

62+
test_data_u55 = {
63+
"rand_double_size": (torch.rand(2, 4, 8, 3), (16, 6), None, True),
64+
}
65+
6066

6167
class UpsamplingBilinear2d(torch.nn.Module):
6268
def __init__(
@@ -189,6 +195,60 @@ def test_upsample_bilinear2d_vec_tosa_BI_Upsample(
189195
pipeline.run()
190196

191197

198+
@common.parametrize("test_data", test_data_u55)
199+
@common.XfailIfNoCorstone300
200+
def test_upsample_bilinear2d_vec_U55_BI_Upsample_not_delegated(
201+
test_data: torch.Tensor,
202+
):
203+
test_data, size, scale_factor, compare_outputs = test_data
204+
pipeline = OpNotSupportedPipeline[input_t1](
205+
Upsample(size, scale_factor),
206+
(test_data,),
207+
{exir_op: 1},
208+
n_expected_delegates=0,
209+
quantize=True,
210+
u55_subset=True,
211+
)
212+
213+
pipeline.run()
214+
215+
216+
@common.parametrize("test_data", test_data_u55)
217+
@common.XfailIfNoCorstone300
218+
def test_upsample_bilinear2d_vec_U55_BI_Interpolate_not_delegated(
219+
test_data: torch.Tensor,
220+
):
221+
test_data, size, scale_factor, compare_outputs = test_data
222+
pipeline = OpNotSupportedPipeline[input_t1](
223+
Interpolate(size, scale_factor),
224+
(test_data,),
225+
{exir_op: 1},
226+
n_expected_delegates=0,
227+
quantize=True,
228+
u55_subset=True,
229+
)
230+
231+
pipeline.run()
232+
233+
234+
@common.parametrize("test_data", test_data_u55)
235+
@common.XfailIfNoCorstone300
236+
def test_upsample_bilinear2d_vec_U55_BI_UpsamplingBilinear2d_not_delegated(
237+
test_data: torch.Tensor,
238+
):
239+
test_data, size, scale_factor, compare_outputs = test_data
240+
pipeline = OpNotSupportedPipeline[input_t1](
241+
UpsamplingBilinear2d(size, scale_factor),
242+
(test_data,),
243+
{exir_op: 1},
244+
n_expected_delegates=0,
245+
quantize=True,
246+
u55_subset=True,
247+
)
248+
249+
pipeline.run()
250+
251+
192252
@common.parametrize("test_data", test_data_suite_Uxx)
193253
@common.XfailIfNoCorstone320
194254
def test_upsample_bilinear2d_vec_U85_BI_Upsample(test_data: input_t1):

backends/arm/test/ops/test_upsample_nearest2d.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@
99
from executorch.backends.arm.test import common
1010

1111
from executorch.backends.arm.test.tester.test_pipeline import (
12+
OpNotSupportedPipeline,
1213
TosaPipelineBI,
1314
TosaPipelineMI,
1415
)
1516

1617
aten_op = "torch.ops.aten.upsample_nearest2d.vec"
18+
exir_op = "executorch_exir_dialects_edge__ops_aten_upsample_nearest2d_vec"
1719
input_t1 = Tuple[torch.Tensor] # Input x
1820

1921
test_data_suite = {
@@ -40,6 +42,10 @@
4042
"rand_one_and_half_size": lambda: (torch.rand(2, 4, 8, 3), (12, 4), None, False),
4143
}
4244

45+
test_data_u55 = {
46+
"rand_double_size": lambda: (torch.rand(2, 4, 8, 3), (16, 6), None, True),
47+
}
48+
4349
test_data_suite_dynamic = {
4450
# (test_name, test_data, size, scale_factor, compare_outputs)
4551
"rand_double_scale": lambda: (torch.rand(2, 4, 8, 3), None, 2.0, False),
@@ -170,6 +176,59 @@ def test_upsample_nearest2d_vec_tosa_BI_nearest(test_data: torch.Tensor):
170176
)
171177
if not compare_outputs:
172178
pipeline.pop_stage(-1)
179+
pipeline.run()
180+
181+
182+
@common.parametrize("test_data", test_data_u55)
183+
@common.XfailIfNoCorstone300
184+
def test_upsample_nearest2d_vec_U55_BI_Upsample_not_delegated(
185+
test_data: torch.Tensor,
186+
):
187+
test_data, size, scale_factor, compare_outputs = test_data()
188+
pipeline = OpNotSupportedPipeline[input_t1](
189+
Upsample(size, scale_factor),
190+
(test_data,),
191+
{exir_op: 1},
192+
n_expected_delegates=0,
193+
quantize=True,
194+
u55_subset=True,
195+
)
196+
197+
pipeline.run()
198+
199+
200+
@common.parametrize("test_data", test_data_u55)
201+
@common.XfailIfNoCorstone300
202+
def test_upsample_nearest2d_vec_U55_BI_Interpolate_not_delegated(
203+
test_data: torch.Tensor,
204+
):
205+
test_data, size, scale_factor, compare_outputs = test_data()
206+
pipeline = OpNotSupportedPipeline[input_t1](
207+
Interpolate(size, scale_factor),
208+
(test_data,),
209+
{exir_op: 1},
210+
n_expected_delegates=0,
211+
quantize=True,
212+
u55_subset=True,
213+
)
214+
215+
pipeline.run()
216+
217+
218+
@common.parametrize("test_data", test_data_u55)
219+
@common.XfailIfNoCorstone300
220+
def test_upsample_nearest2d_vec_U55_BI_UpsamplingBilinear2d_not_delegated(
221+
test_data: torch.Tensor,
222+
):
223+
test_data, size, scale_factor, compare_outputs = test_data()
224+
pipeline = OpNotSupportedPipeline[input_t1](
225+
UpsamplingNearest2d(size, scale_factor),
226+
(test_data,),
227+
{exir_op: 1},
228+
n_expected_delegates=0,
229+
quantize=True,
230+
u55_subset=True,
231+
)
173232

174233
pipeline.run()
175234

@@ -327,4 +386,5 @@ def test_upsample_nearest2d_dynamic_BI_upsample(test_data: torch.Tensor):
327386
)
328387
if not compare_outputs:
329388
pipeline.pop_stage(-1)
389+
330390
pipeline.run()

backends/arm/tosa_partitioner.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,10 +174,18 @@ def filter_fn(node: torch.fx.Node) -> bool:
174174

175175
ops_to_not_decompose = [
176176
torch.ops.aten.linear.default,
177-
torch.ops.aten.upsample_bilinear2d.vec,
178-
torch.ops.aten.upsample_nearest2d.vec,
179177
torch.ops.aten.eye.default,
180178
torch.ops.aten.linspace.default,
181179
] + ops_to_not_decompose_if_quant_op
182180

181+
tosa_spec = get_tosa_spec(self.delegation_spec.compile_specs)
182+
if not tosa_spec.is_U55_subset:
183+
# Tosa operator "RESIZE" is not supported on U55. Since upsample_bilinear2d
184+
# and upsample_nearest2d decompose into that it will not be possible to
185+
# delegate those operators on U55. If we have said here to not decompose
186+
# them there will be an error saying the operator was not decomposed. It
187+
# will not be possible for it to end up on either CPU or NPU.
188+
ops_to_not_decompose.append(torch.ops.aten.upsample_nearest2d.vec)
189+
ops_to_not_decompose.append(torch.ops.aten.upsample_bilinear2d.vec)
190+
183191
return (ops_to_not_decompose, filter_fn)

0 commit comments

Comments
 (0)