Arm backend: Adjust pooling input when not divisible by stride #11854

Merged
merged 2 commits on Jun 24, 2025
2 changes: 1 addition & 1 deletion backends/arm/_passes/__init__.py
@@ -67,7 +67,7 @@
ReplaceScalarWithTensorArgPassTOSAMI,
)
from .scalars_to_attribute_pass import ScalarsToAttributePass # noqa
from .size_adjust_conv2d_pass import SizeAdjustConv2DPass # noqa
from .size_adjust_input_pass import SizeAdjustInputPass # noqa
from .unsqueeze_before_repeat_pass import UnsqueezeBeforeRepeatPass # noqa
from .unsqueeze_scalar_placeholders_pass import UnsqueezeScalarPlaceholdersPass # noqa
from .replace_inf_values_pass import ReplaceInfValues # noqa # usort: skip
6 changes: 3 additions & 3 deletions backends/arm/_passes/arm_pass_manager.py
@@ -64,7 +64,7 @@
ReplaceScalarWithTensorArgPassTOSAMI,
RetraceFoldedDtypesPass,
ScalarsToAttributePass,
SizeAdjustConv2DPass,
SizeAdjustInputPass,
UnsqueezeBeforeRepeatPass,
UnsqueezeScalarPlaceholdersPass,
)
@@ -125,13 +125,13 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule

self.add_pass(DecomposeGroupedConv())
self.add_pass(RemoveClonePass())
self.add_pass(SizeAdjustConv2DPass())
self.add_pass(ConvertExpandCopyToRepeatPass())
self.add_pass(UnsqueezeBeforeRepeatPass())
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
self.add_pass(DecomposeSumPass())
self.add_pass(Conv1dUnsqueezePass())
self.add_pass(DecomposeMaxPool2DPass())
self.add_pass(SizeAdjustInputPass())
self.add_pass(DecomposeSelectPass())
self.add_pass(ConvertSqueezesToViewPass())

@@ -187,13 +187,13 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule

self.add_pass(DecomposeGroupedConv())
self.add_pass(RemoveClonePass())
self.add_pass(SizeAdjustConv2DPass())
self.add_pass(ConvertExpandCopyToRepeatPass())
self.add_pass(UnsqueezeBeforeRepeatPass())
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
self.add_pass(DecomposeSumPass())
self.add_pass(Conv1dUnsqueezePass())
self.add_pass(DecomposeMaxPool2DPass())
self.add_pass(SizeAdjustInputPass())
self.add_pass(DecomposeSelectPass())
self.add_pass(ConvertSqueezesToViewPass())

backends/arm/_passes/size_adjust_conv2d_pass.py → backends/arm/_passes/size_adjust_input_pass.py
@@ -1,35 +1,149 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import cast
from typing import cast, TypeAlias

import torch.fx
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

Slices: TypeAlias = list[tuple[int, int, int]]

def conv_remainder(input_length, pad, dilation, weight, stride):
conv2d_op = exir_ops.edge.aten.convolution.default
max_pooling_op = exir_ops.edge.aten.max_pool2d.default
avg_pooling_op = exir_ops.edge.aten.avg_pool2d.default
slice_op = exir_ops.edge.aten.slice_copy.Tensor

valid_operators = [conv2d_op, max_pooling_op, avg_pooling_op]


def conv_remainder(input_length, pad, dilation, weight, stride) -> int:
"""
Returns the remainder of input_length, given the padding, dilation,
stride, and kernel size.
"""
return (input_length + 2 * pad - dilation * (weight - 1) - 1) % stride
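# Illustrative worked example (not part of the PR diff): input_length=112,
# pad=1, dilation=1, weight=3, stride=2 gives
# (112 + 2*1 - 1*(3 - 1) - 1) % 2 = 111 % 2 = 1, i.e. one trailing row that
# the sliding window never reaches.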


class SizeAdjustConv2DPass(ExportPass):
def pooling_remainder(input_size, pad, kernel_size, stride) -> int:
"""
Returns the remainder of input_size, given the padding, stride, and
kernel size.
"""
return (input_size + 2 * pad - kernel_size) % stride
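# Illustrative worked example (not part of the PR diff): input_size=54,
# kernel_size=3, stride=3, pad=1 gives (54 + 2*1 - 3) % 3 = 53 % 3 = 2;
# the remainder exceeds the padding, so the input must be sliced.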


def get_slices_conv2d(conv_node: torch.fx.Node) -> Slices:
slices = []

input_node, weight, _, stride_hw, pad_hw, dilation_hw, _, _, _ = conv_node.args
weight_shape = cast(torch.fx.Node, weight).meta["val"].shape
input_shape = cast(torch.fx.Node, input_node).meta["val"].shape

for stride, pad, dilation, dim in zip(
cast(list, stride_hw),
cast(list, pad_hw),
cast(list, dilation_hw),
(2, 3),
):
remainder = conv_remainder(
input_shape[dim], pad, dilation, weight_shape[dim], stride
)
if remainder > pad:
adjustment = remainder - pad
args = (dim, 0, input_shape[dim] - adjustment)
slices.append(args)

return slices
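# Illustrative example (not part of the PR diff): for a convolution with a
# (C_out, C_in, 3, 3) weight, stride=[2, 2], padding=[0, 0], dilation=[1, 1]
# on a (1, C_in, 56, 56) input, the remainder is 1 on both spatial dims, so
# this returns [(2, 0, 55), (3, 0, 55)].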


def get_slices_pooling(pooling_node: torch.fx.Node) -> Slices:
slices = []

input_node = pooling_node.args[0]
kernel_size = pooling_node.args[1]
stride = pooling_node.args[2]
padding = pooling_node.args[3] if len(pooling_node.args) >= 4 else [0, 0]

# For the loop below, padding must be a list
if isinstance(padding, int):
padding = [padding, padding]

input_shape = cast(torch.fx.Node, input_node).meta["val"].shape

for kernel_length, stride_length, pad_size, dim in zip(
cast(list, kernel_size),
cast(list, stride),
cast(list, padding),
(2, 3),
):
remainder = pooling_remainder(
input_shape[dim], pad_size, kernel_length, stride_length
)
if remainder > pad_size:
adjustment = remainder - pad_size
args = (dim, 0, input_shape[dim] - adjustment)
slices.append(args)

return slices
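# Illustrative example (not part of the PR diff): for a pooling node with
# kernel_size=3, stride=3, padding=1 on a (1, 16, 54, 54) input, the
# remainder is 2 on both spatial dims, so this returns
# [(2, 0, 53), (3, 0, 53)], i.e. slice dims 2 and 3 down to 53 elements.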


def get_slices(node: torch.fx.Node) -> Slices:
"""
Returns the slice arguments needed to adjust the input of the given node.
"""
if node.target == conv2d_op:
return get_slices_conv2d(node)
elif node.target == max_pooling_op or node.target == avg_pooling_op:
return get_slices_pooling(node)
else:
raise ValueError(f"Unsupported node target, was expecting {valid_operators}")


def is_valid_operator(node: torch.fx.Node) -> bool:
if node.target == conv2d_op:
return True
elif node.target == max_pooling_op:
dilation = node.args[4] if len(node.args) >= 5 else 1
ceil_mode = node.args[5] if len(node.args) >= 6 else False

# Dilation should be handled first by DecomposeMaxPool2DPass
if isinstance(dilation, int):
if dilation > 1:
raise ValueError(
"Expected max_pool2d with dilation = 1, has DecomposeMaxPool2DPass been run?"
)
else:
dilation = cast(list, dilation)
if dilation[0] > 1 or dilation[1] > 1:
raise ValueError(
"Expected max_pool2d with dilation = [1, 1], has DecomposeMaxPool2DPass been run?"
)

# If using ceil mode for rounding, the input does not need adjusting
return not ceil_mode
elif node.target == avg_pooling_op:
ceil_mode = node.args[4] if len(node.args) >= 5 else False
count_include_pad = node.args[5] if len(node.args) >= 6 else True
divisor_override = node.args[6] if len(node.args) >= 7 else None
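# Only the default avg_pool2d semantics are safe to adjust here: ceil_mode
# rounds the output size up so no input truncation is needed, while
# count_include_pad and divisor_override change how the averages are
# computed, so truncating padding or input would alter the results.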

return not ceil_mode and not count_include_pad and divisor_override is None

return False


class SizeAdjustInputPass(ExportPass):
"""
Adjust the convolution input size to match the kernel size, padding, stride,
and dilation parameters. Pytorch allows the input and kernel shape to not
"match", in which case the remaining rows/columns are truncated. However,
matching the size is a requirement in the TOSA specification. In case the
input and kernel shape do not match, the following is done to meet the
specification:
Adjusts the input size to Conv2D and Pooling operators. PyTorch allows
the input and kernel shape to not "match", in which case the remaining
rows/columns are truncated. However, matching the size is a requirement
in the TOSA specification. In case the input and kernel shape do not
match, the following is performed to meet the specification:

1) The padding is truncated (done in the node visitor)
2) (if necessary) The input is truncated (done in this pass).
@@ -71,52 +185,33 @@ class SizeAdjustConv2DPass(ExportPass):
input.
"""

conv2d_op = exir_ops.edge.aten.convolution.default
slice_op = exir_ops.edge.aten.slice_copy.Tensor

def call(self, graph_module: torch.fx.GraphModule):
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
graph = graph_module.graph
modified_graph = False
for node in graph.nodes:
if node.op != "call_function":
continue
if node.target != self.conv2d_op:
if not is_valid_operator(node):
continue

conv_node = cast(torch.fx.Node, node)
input_node, weight, _, stride_hw, pad_hw, dilation_hw, _, _, _ = (
conv_node.args
)
weight_shape = cast(torch.fx.Node, weight).meta["val"].shape
input_shape = cast(torch.fx.Node, input_node).meta["val"].shape

slice_args = []
for stride, pad, dilation, dim in zip(
cast(list, stride_hw),
cast(list, pad_hw),
cast(list, dilation_hw),
(2, 3),
):
remainder = conv_remainder(
input_shape[dim], pad, dilation, weight_shape[dim], stride
)
if remainder > pad:
adjustment = remainder - pad
args = (dim, 0, input_shape[dim] - adjustment)
slice_args.append(args)
target_node = cast(torch.fx.Node, node)
slice_args = get_slices(target_node)

if len(slice_args) == 0:
continue

parent_node = node.args[0]
with graph_module.graph.inserting_before(node):
last_node = cast(torch.fx.Node, input_node)
last_node = cast(torch.fx.Node, parent_node)
for args in slice_args:
slice_node = create_node(graph, self.slice_op, (last_node,) + args)
slice_node = create_node(graph, slice_op, (last_node,) + args)
last_node = slice_node
conv_node.replace_input_with(cast(torch.fx.Node, input_node), last_node)
node.replace_input_with(cast(torch.fx.Node, parent_node), last_node)
modified_graph = True

if modified_graph:
graph_module = super().call(graph_module).graph_module
graph.eliminate_dead_code()
graph_module.recompile()

return PassResult(graph_module, True)
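
For intuition, a minimal sketch (not part of the PR diff; shapes follow the
"non_divisible_window_adjust_padding+input" test case below) of the rewrite
this pass performs:

# Before: max_pool2d(kernel_size=3, stride=3, padding=1) on x: (1, 16, 54, 54).
# pooling_remainder = (54 + 2*1 - 3) % 3 = 2 > padding, so the pass inserts:
#   x1 = slice_copy(x,  dim=2, start=0, end=53)   # (1, 16, 53, 54)
#   x2 = slice_copy(x1, dim=3, start=0, end=53)   # (1, 16, 53, 53)
#   y  = max_pool2d(x2, kernel_size=3, stride=3, padding=1)
# The remaining remainder of 1 equals the padding and is handled by the node
# visitor, which truncates the padding instead (step 1 in the docstring above).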
1 change: 0 additions & 1 deletion backends/arm/test/models/test_nn_functional.py
@@ -82,7 +82,6 @@ def forward(self, *args):
"test_data",
module_tests,
xfails={
"max_pool1d": "ValueError: Invalid TOSA graph",
"affine_grid": "Int64 input. Partition handling fails since arange int64 output is split between 2 partitions.",
},
)
22 changes: 19 additions & 3 deletions backends/arm/test/ops/test_avg_pool2d.py
@@ -51,15 +51,15 @@ def forward(self, *args, **kwargs):
AvgPool2d((4, 6), (1, 2), (2, 3)),
(torch.rand(1, 16, 50, 32),),
),
"non_divisible_window": lambda: (
"non_divisible_window_adjust_padding": lambda: (
AvgPool2d(3, 2, 1, count_include_pad=False),
(torch.rand(1, 16, 112, 112),),
),
"non_divisible_window_height": lambda: (
"non_divisible_window_adjust_padding_height": lambda: (
AvgPool2d(3, (2, 1), 1),
(torch.rand(1, 16, 56, 56),),
),
"non_divisible_window_width": lambda: (
"non_divisible_window_adjust_padding_width": lambda: (
AvgPool2d(3, (1, 2), 1, count_include_pad=False),
(torch.rand(1, 16, 56, 56),),
),
@@ -91,6 +91,22 @@ def forward(self, *args, **kwargs):
AvgPool2d(3, 2, 1, True, True, divisor_override=2),
(torch.rand(1, 1, 14, 14),),
),
"non_divisible_no_padding": lambda: (
AvgPool2d(3, 2, 0),
(torch.rand(1, 16, 56, 56),),
),
"non_divibile_window_adjust_padding+input": lambda: (
AvgPool2d(3, 3, 1, count_include_pad=False),
(torch.rand(1, 16, 54, 54),),
),
"non_divibile_window_height_adjust_padding+input": lambda: (
AvgPool2d(3, (3, 1), 1),
(torch.rand(1, 16, 54, 54),),
),
"non_divibile_window_width_adjust_padding+input": lambda: (
AvgPool2d(3, (1, 3), 1, count_include_pad=False),
(torch.rand(1, 16, 54, 54),),
),
}


25 changes: 25 additions & 0 deletions backends/arm/test/ops/test_max_pool.py
@@ -39,6 +39,31 @@
torch.rand(1, 16, 56, 56),
[3, (1, 2), 1, 1, True],
),
"non_divisible_window_adjust_padding": lambda: (
torch.rand(1, 16, 112, 112),
[3, 2, 1],
),
"non_divisible_window_height_adjust_padding": lambda: (
torch.rand(1, 16, 56, 56),
[3, (2, 1), 1],
),
"non_divisible_window_width_adjust_padding": lambda: (
torch.rand(1, 16, 56, 56),
[3, (1, 2), 1],
),
"non_divisble_no_padding": lambda: (torch.rand(1, 16, 56, 56), [3, 2, 0]),
"non_divisible_window_adjust_padding+input": lambda: (
torch.rand(1, 16, 54, 54),
[3, 3, 1],
),
"non_divisible_window_height_adjust_padding+input": lambda: (
torch.rand(1, 16, 54, 54),
[3, (3, 1), 1],
),
"non_divisible_window_width_adjust_padding+input": lambda: (
torch.rand(1, 16, 54, 54),
[3, (1, 3), 1],
),
}

test_data_suite_mult_batches = {