diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py
index 440938fd49c..d35403992b6 100644
--- a/backends/arm/_passes/__init__.py
+++ b/backends/arm/_passes/__init__.py
@@ -67,7 +67,7 @@
     ReplaceScalarWithTensorArgPassTOSAMI,
 )
 from .scalars_to_attribute_pass import ScalarsToAttributePass  # noqa
-from .size_adjust_conv2d_pass import SizeAdjustConv2DPass  # noqa
+from .size_adjust_input_pass import SizeAdjustInputPass  # noqa
 from .unsqueeze_before_repeat_pass import UnsqueezeBeforeRepeatPass  # noqa
 from .unsqueeze_scalar_placeholders_pass import UnsqueezeScalarPlaceholdersPass  # noqa
 from .replace_inf_values_pass import ReplaceInfValues  # noqa  # usort: skip
diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index 07a4416cd74..41e76bde0b7 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -64,7 +64,7 @@
     ReplaceScalarWithTensorArgPassTOSAMI,
     RetraceFoldedDtypesPass,
     ScalarsToAttributePass,
-    SizeAdjustConv2DPass,
+    SizeAdjustInputPass,
     UnsqueezeBeforeRepeatPass,
     UnsqueezeScalarPlaceholdersPass,
 )
@@ -125,13 +125,13 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(DecomposeGroupedConv())
         self.add_pass(RemoveClonePass())
-        self.add_pass(SizeAdjustConv2DPass())
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(UnsqueezeBeforeRepeatPass())
         self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
         self.add_pass(DecomposeSumPass())
         self.add_pass(Conv1dUnsqueezePass())
         self.add_pass(DecomposeMaxPool2DPass())
+        self.add_pass(SizeAdjustInputPass())
         self.add_pass(DecomposeSelectPass())
         self.add_pass(ConvertSqueezesToViewPass())
@@ -187,13 +187,13 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(DecomposeGroupedConv())
         self.add_pass(RemoveClonePass())
-        self.add_pass(SizeAdjustConv2DPass())
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(UnsqueezeBeforeRepeatPass())
         self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
         self.add_pass(DecomposeSumPass())
         self.add_pass(Conv1dUnsqueezePass())
         self.add_pass(DecomposeMaxPool2DPass())
+        self.add_pass(SizeAdjustInputPass())
         self.add_pass(DecomposeSelectPass())
         self.add_pass(ConvertSqueezesToViewPass())
diff --git a/backends/arm/_passes/size_adjust_conv2d_pass.py b/backends/arm/_passes/size_adjust_input_pass.py
similarity index 50%
rename from backends/arm/_passes/size_adjust_conv2d_pass.py
rename to backends/arm/_passes/size_adjust_input_pass.py
index ee811273438..e87d65c450f 100644
--- a/backends/arm/_passes/size_adjust_conv2d_pass.py
+++ b/backends/arm/_passes/size_adjust_input_pass.py
@@ -1,20 +1,28 @@
 # Copyright 2024-2025 Arm Limited and/or its affiliates.
-# All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 # pyre-unsafe

-from typing import cast
+from typing import cast, TypeAlias

 import torch.fx
 from executorch.backends.arm._passes.arm_pass_utils import create_node
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult

+Slices: TypeAlias = list[tuple[int, int, int]]

-def conv_remainder(input_length, pad, dilation, weight, stride):
+conv2d_op = exir_ops.edge.aten.convolution.default
+max_pooling_op = exir_ops.edge.aten.max_pool2d.default
+avg_pooling_op = exir_ops.edge.aten.avg_pool2d.default
+slice_op = exir_ops.edge.aten.slice_copy.Tensor
+
+valid_operators = [conv2d_op, max_pooling_op, avg_pooling_op]
+
+
+def conv_remainder(input_length, pad, dilation, weight, stride) -> int:
     """
     Returns the remainder of input_length; given the padding, dilation, stride,
     and kernel size.
@@ -22,14 +30,120 @@ def conv_remainder(input_length, pad, dilation, weight, stride):
     return (input_length + 2 * pad - dilation * (weight - 1) - 1) % stride


-class SizeAdjustConv2DPass(ExportPass):
+def pooling_remainder(input_size, pad, kernel_size, stride) -> int:
+    """
+    Returns the remainder of input_size given the padding, stride, and
+    kernel size.
+    """
+    return (input_size + 2 * pad - kernel_size) % stride
+
+
+def get_slices_conv2d(conv_node: torch.fx.Node) -> Slices:
+    slices = []
+
+    input_node, weight, _, stride_hw, pad_hw, dilation_hw, _, _, _ = conv_node.args
+    weight_shape = cast(torch.fx.Node, weight).meta["val"].shape
+    input_shape = cast(torch.fx.Node, input_node).meta["val"].shape
+
+    # Dims 2 and 3 are the spatial (H, W) dimensions of the NCHW input
+    for stride, pad, dilation, dim in zip(
+        cast(list, stride_hw),
+        cast(list, pad_hw),
+        cast(list, dilation_hw),
+        (2, 3),
+    ):
+        remainder = conv_remainder(
+            input_shape[dim], pad, dilation, weight_shape[dim], stride
+        )
+        if remainder > pad:
+            # Only the part of the remainder not absorbed by the padding
+            # needs to be sliced off the input
+            adjustment = remainder - pad
+            args = (dim, 0, input_shape[dim] - adjustment)
+            slices.append(args)
+
+    return slices
+
+
+def get_slices_pooling(pooling_node: torch.fx.Node) -> Slices:
+    slices = []
+
+    input_node = pooling_node.args[0]
+    kernel_size = pooling_node.args[1]
+    stride = pooling_node.args[2]
+    padding = pooling_node.args[3] if len(pooling_node.args) >= 4 else [0, 0]
+
+    # For the loop below, padding must be a list
+    if isinstance(padding, int):
+        padding = [padding, padding]
+
+    input_shape = cast(torch.fx.Node, input_node).meta["val"].shape
+
+    for kernel_length, stride_length, pad_size, dim in zip(
+        cast(list, kernel_size),
+        cast(list, stride),
+        cast(list, padding),
+        (2, 3),
+    ):
+        remainder = pooling_remainder(
+            input_shape[dim], pad_size, kernel_length, stride_length
+        )
+        if remainder > pad_size:
+            adjustment = remainder - pad_size
+            args = (dim, 0, input_shape[dim] - adjustment)
+            slices.append(args)
+
+    return slices
+
+
+def get_slices(node: torch.fx.Node) -> Slices:
+    """
+    Returns the slice arguments needed to adjust the input of the given node.
+ """ + if node.target == conv2d_op: + return get_slices_conv2d(node) + elif node.target == max_pooling_op or node.target == avg_pooling_op: + return get_slices_pooling(node) + else: + raise ValueError(f"Unsupported node target, was expecting {valid_operators}") + + +def is_valid_operator(node: torch.fx.Node) -> bool: + if node.target == conv2d_op: + return True + elif node.target == max_pooling_op: + dilation = node.args[4] if len(node.args) >= 5 else 1 + ceil_mode = node.args[5] if len(node.args) >= 6 else False + + # Dilation should be handled first by DecomposeMaxPool2DPass + if isinstance(dilation, int): + if dilation > 1: + raise ValueError( + "Expected max_pool2d with dilation = 1, has DecomposeMaxPool2DPass been run?" + ) + else: + dilation = cast(list, dilation) + if dilation[0] > 1 or dilation[1] > 1: + raise ValueError( + "Expected max_pool2d with dilation = [1, 1], has DecomposeMaxPool2DPass been run?" + ) + + # If using ceil mode for rounding, the input does not need adjusting + return not ceil_mode + elif node.target == avg_pooling_op: + ceil_mode = node.args[4] if len(node.args) >= 5 else False + count_include_pad = node.args[5] if len(node.args) >= 6 else True + divisor_override = node.args[6] if len(node.args) >= 7 else None + + return not ceil_mode and not count_include_pad and divisor_override is None + + return False + + +class SizeAdjustInputPass(ExportPass): """ - Adjust the convolution input size to match the kernel size, padding, stride, - and dilation parameters. Pytorch allows the input and kernel shape to not - "match", in which case the remaining rows/columns are truncated. However, - matching the size is a requirement in the TOSA specification. In case the - input and kernel shape do not match, the following is done to meet the - specification: + Adjusts the input size to Conv2D and Pooling operators. PyTorch allows + the input and kernel shape to not "match", in which case the remaining + rows/columns are truncated. However, matching the size is a requirement + in the TOSA specification. In case the input and kernel shape do not + match, the following is performed to meet the specification: 1) The padding is truncated (done in the node visitor) 2) (if neccessary) The input is truncated (done in this pass)." @@ -71,52 +185,33 @@ class SizeAdjustConv2DPass(ExportPass): input. 
""" - conv2d_op = exir_ops.edge.aten.convolution.default - slice_op = exir_ops.edge.aten.slice_copy.Tensor - - def call(self, graph_module: torch.fx.GraphModule): + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: graph = graph_module.graph modified_graph = False for node in graph.nodes: if node.op != "call_function": continue - if node.target != self.conv2d_op: + if not is_valid_operator(node): continue - conv_node = cast(torch.fx.Node, node) - input_node, weight, _, stride_hw, pad_hw, dilation_hw, _, _, _ = ( - conv_node.args - ) - weight_shape = cast(torch.fx.Node, weight).meta["val"].shape - input_shape = cast(torch.fx.Node, input_node).meta["val"].shape - - slice_args = [] - for stride, pad, dilation, dim in zip( - cast(list, stride_hw), - cast(list, pad_hw), - cast(list, dilation_hw), - (2, 3), - ): - remainder = conv_remainder( - input_shape[dim], pad, dilation, weight_shape[dim], stride - ) - if remainder > pad: - adjustment = remainder - pad - args = (dim, 0, input_shape[dim] - adjustment) - slice_args.append(args) + target_node = cast(torch.fx.Node, node) + slice_args = get_slices(target_node) + if len(slice_args) == 0: continue + parent_node = node.args[0] with graph_module.graph.inserting_before(node): - last_node = cast(torch.fx.Node, input_node) + last_node = cast(torch.fx.Node, parent_node) for args in slice_args: - slice_node = create_node(graph, self.slice_op, (last_node,) + args) + slice_node = create_node(graph, slice_op, (last_node,) + args) last_node = slice_node - conv_node.replace_input_with(cast(torch.fx.Node, input_node), last_node) + node.replace_input_with(cast(torch.fx.Node, parent_node), last_node) modified_graph = True if modified_graph: graph_module = super().call(graph_module).graph_module graph.eliminate_dead_code() graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/arm/test/models/test_nn_functional.py b/backends/arm/test/models/test_nn_functional.py index a802807184c..7c5c98cdcb3 100644 --- a/backends/arm/test/models/test_nn_functional.py +++ b/backends/arm/test/models/test_nn_functional.py @@ -82,7 +82,6 @@ def forward(self, *args): "test_data", module_tests, xfails={ - "max_pool1d": "ValueError: Invalid TOSA graph", "affine_grid": "Int64 input. 
Partition handling fails since arange int64 output is split between 2 partitions.", }, ) diff --git a/backends/arm/test/ops/test_avg_pool2d.py b/backends/arm/test/ops/test_avg_pool2d.py index e2bbfc3a8cd..d5c4561537e 100644 --- a/backends/arm/test/ops/test_avg_pool2d.py +++ b/backends/arm/test/ops/test_avg_pool2d.py @@ -51,15 +51,15 @@ def forward(self, *args, **kwargs): AvgPool2d((4, 6), (1, 2), (2, 3)), (torch.rand(1, 16, 50, 32),), ), - "non_divisible_window": lambda: ( + "non_divisible_window_adjust_padding": lambda: ( AvgPool2d(3, 2, 1, count_include_pad=False), (torch.rand(1, 16, 112, 112),), ), - "non_divisible_window_height": lambda: ( + "non_divisible_window_adjust_padding_height": lambda: ( AvgPool2d(3, (2, 1), 1), (torch.rand(1, 16, 56, 56),), ), - "non_divisible_window_width": lambda: ( + "non_divisible_window_adjust_padding_width": lambda: ( AvgPool2d(3, (1, 2), 1, count_include_pad=False), (torch.rand(1, 16, 56, 56),), ), @@ -91,6 +91,22 @@ def forward(self, *args, **kwargs): AvgPool2d(3, 2, 1, True, True, divisor_override=2), (torch.rand(1, 1, 14, 14),), ), + "non_divisible_no_padding": lambda: ( + AvgPool2d(3, 2, 0), + (torch.rand(1, 16, 56, 56),), + ), + "non_divibile_window_adjust_padding+input": lambda: ( + AvgPool2d(3, 3, 1, count_include_pad=False), + (torch.rand(1, 16, 54, 54),), + ), + "non_divibile_window_height_adjust_padding+input": lambda: ( + AvgPool2d(3, (3, 1), 1), + (torch.rand(1, 16, 54, 54),), + ), + "non_divibile_window_width_adjust_padding+input": lambda: ( + AvgPool2d(3, (1, 3), 1, count_include_pad=False), + (torch.rand(1, 16, 54, 54),), + ), } diff --git a/backends/arm/test/ops/test_max_pool.py b/backends/arm/test/ops/test_max_pool.py index 55340a565e5..b2aa263de39 100644 --- a/backends/arm/test/ops/test_max_pool.py +++ b/backends/arm/test/ops/test_max_pool.py @@ -39,6 +39,31 @@ torch.rand(1, 16, 56, 56), [3, (1, 2), 1, 1, True], ), + "non_divisible_window_adjust_padding": lambda: ( + torch.rand(1, 16, 112, 112), + [3, 2, 1], + ), + "non_divisible_window_height_adjust_padding": lambda: ( + torch.rand(1, 16, 56, 56), + [3, (2, 1), 1], + ), + "non_divisible_window_width_adjust_padding": lambda: ( + torch.rand(1, 16, 56, 56), + [3, (1, 2), 1], + ), + "non_divisble_no_padding": lambda: (torch.rand(1, 16, 56, 56), [3, 2, 0]), + "non_divisible_window_adjust_padding+input": lambda: ( + torch.rand(1, 16, 54, 54), + [3, 3, 1], + ), + "non_divisible_window_height_adjust_padding+input": lambda: ( + torch.rand(1, 16, 54, 54), + [3, (3, 1), 1], + ), + "non_divisible_window_width_adjust_padding+input": lambda: ( + torch.rand(1, 16, 54, 54), + [3, (1, 3), 1], + ), } test_data_suite_mult_batches = {
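
For reference, the size-adjustment arithmetic introduced above can be exercised in isolation. The sketch below is illustrative, not code from this patch: `pooling_remainder` mirrors the helper added in `size_adjust_input_pass.py`, while `adjusted_size` is a hypothetical wrapper, and the example shapes mirror the new `non_divisible_*` test cases.

```python
# Standalone sketch of the pooling size-adjustment arithmetic.
# `adjusted_size` is illustrative only; it does not exist in the patch.


def pooling_remainder(input_size: int, pad: int, kernel_size: int, stride: int) -> int:
    # Rows/columns left over after the last complete kernel window.
    return (input_size + 2 * pad - kernel_size) % stride


def adjusted_size(input_size: int, pad: int, kernel_size: int, stride: int) -> int:
    # The input is only sliced for the part of the remainder that the
    # (already truncated) padding cannot absorb.
    remainder = pooling_remainder(input_size, pad, kernel_size, stride)
    return input_size - max(remainder - pad, 0)


# 112x112, kernel 3, stride 2, pad 1: remainder = (112 + 2 - 3) % 2 = 1 <= pad,
# so only the padding needs adjusting ("adjust_padding" test cases).
assert adjusted_size(112, pad=1, kernel_size=3, stride=2) == 112

# 54x54, kernel 3, stride 3, pad 1: remainder = (54 + 2 - 3) % 3 = 2 > pad,
# so one trailing row/column is sliced off ("adjust_padding+input" test cases).
assert adjusted_size(54, pad=1, kernel_size=3, stride=3) == 53
```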
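The truncation is value-preserving because PyTorch floors the output size: elements past the last kernel window never influence the result, so slicing them off satisfies the TOSA exact-match requirement without changing numerics. A small check (again, not part of the patch) using one of the new 54x54 max_pool test shapes:

```python
import torch
import torch.nn.functional as F

x = torch.rand(1, 16, 54, 54)

# With kernel 3, stride 3, padding 1 the last window ends at row/col 52,
# so row/col 53 is never read and can be sliced away.
full = F.max_pool2d(x, kernel_size=3, stride=3, padding=1)
sliced = F.max_pool2d(x[:, :, :53, :53], kernel_size=3, stride=3, padding=1)

assert full.shape == sliced.shape == (1, 16, 18, 18)
assert torch.equal(full, sliced)
```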