diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 9d1e7f2e01..8c8dcfe1d7 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -5,9 +5,10 @@ from . import arm_pass_utils # noqa +from .arm_pass import ArmPass # noqa # usort: skip +from .add_bias_pass import AddBiasPass # noqa from .annotate_channels_last_dim_order_pass import AnnotateChannelsLastDimOrder # noqa from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass # noqa -from .arm_pass import ArmPass # noqa from .broadcast_args_pass import BroadcastArgsPass # noqa from .cast_int64_pass import CastInt64BuffersToInt32Pass # noqa from .cast_to_int32_pass import CastToInt32Pass # noqa diff --git a/backends/arm/_passes/add_bias_pass.py b/backends/arm/_passes/add_bias_pass.py new file mode 100644 index 0000000000..31c0c0505c --- /dev/null +++ b/backends/arm/_passes/add_bias_pass.py @@ -0,0 +1,62 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.transforms.utils import create_constant_placeholder + +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import PassResult +from torch.export.graph_signature import InputKind + + +class AddBiasPass(ArmPass): + """TOSA requires convolution nodes to have a bias input. + This pass adds a bias input to convolution nodes that do not have one. + The bias is set to zero. 
+ """ + + targeted_ops = (exir_ops.edge.aten.convolution.default,) + + def call(self, graph_module): + modified = False + for node in graph_module.graph.nodes: + if node.op != "call_function": + continue + if node.target not in self.targeted_ops: + continue + + if len(node.all_input_nodes) < 3: + modified = True + # bias is missing + weight_node = node.all_input_nodes[1] + output_channels = get_first_fake_tensor(weight_node).shape[0] + # add a node containing zeros + # if quantized, use int32, otherwise use float32 + if ( + "output_qparams" in node.meta + and len(node.meta["output_qparams"]) > 0 + ): + bias_data = torch.zeros(size=(output_channels,), dtype=torch.int32) + else: + bias_data = torch.zeros( + size=(output_channels,), dtype=torch.float32 + ) + + with graph_module.graph.inserting_after(weight_node): + bias_node = create_constant_placeholder( + self.exported_program, + graph=graph_module.graph, + kind=InputKind.PARAMETER, + data=bias_data, + persistent_buffer=True, + name=f"{node.name}_bias", + ) + node.update_arg(2, bias_node) + + if modified: + graph_module = super().call(graph_module).graph_module + return PassResult(graph_module, modified) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index fee4fda978..d379e1e469 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -7,6 +7,7 @@ # pyre-unsafe from executorch.backends.arm._passes import ( + AddBiasPass, AnnotateChannelsLastDimOrder, AnnotateDecomposedMatmulPass, BroadcastArgsPass, @@ -134,6 +135,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul self.add_pass(FuseViewCopyTransform()) self.add_pass(FuseConstantArgsPass(exported_program)) + self.add_pass(AddBiasPass(exported_program)) self.add_pass(InsertTableOpsPass(exported_program)) self.add_pass(FuseEqualPlaceholdersPass(exported_program)) @@ -194,6 +196,7 @@ def _tosa_080_MI_pipeline(self, exported_program: 
ExportedProgram) -> GraphModul self.add_pass(FuseViewCopyTransform()) self.add_pass(FuseConstantArgsPass(exported_program)) + self.add_pass(AddBiasPass(exported_program)) self.add_pass(InsertTableOpsPass(exported_program)) self.add_pass(FuseEqualPlaceholdersPass(exported_program)) self.add_pass(AnnotateChannelsLastDimOrder()) diff --git a/backends/arm/operators/op_conv2d.py b/backends/arm/operators/op_conv2d.py index 772c20a2f3..3c73e7b32c 100644 --- a/backends/arm/operators/op_conv2d.py +++ b/backends/arm/operators/op_conv2d.py @@ -109,24 +109,6 @@ def define_node( local_bound=False, ) - # Non-bias case. - if len(node.all_input_nodes) == 2: - # Create a zero bias tensor if not presented - out_channels = weight.shape[0] - bias_name = "bias" + node.name.split("default", 1)[1] - bias_type = output.dtype - if output.dtype == ts.DType.INT8: - # Conv is quantized to int8, but the TOSA operator has - # output type int32, and the bias must be the same type - # as the TOSA output type - bias_type = ts.DType.INT32 - bias = tosa_graph.addConst( - [out_channels], - bias_type, - [0] * out_channels, - name=bias_name, - ) - # The output type is int32 when input type is int8. conv2d_output_name = output.name if output.dtype == ts.DType.INT8: @@ -313,24 +295,6 @@ def define_node( name=f"{conv2d_output_name}_weight_zp", ) - # Non-bias case. - if len(node.all_input_nodes) == 2: - # Create a zero bias tensor if not presented - out_channels = weight.shape[0] - bias_name = f"{conv2d_output_name}_bias" - bias_type = output.dtype - if output.dtype == ts.DType.INT8: - # Conv is quantized to int8, but the TOSA operator has - # output type int32, and the bias must be the same type - # as the TOSA output type - bias_type = ts.DType.INT32 - bias = tosa_graph.addConst( - [out_channels], - bias_type, - [0] * out_channels, - name=bias_name, - ) - # Given input.shape is (N, Ci, H, W), and weight.shape is (Co, Ci/G, H, W) in_channels = input.shape[1] out_channels = weight.shape[0]