Skip to content

Arm backend: Add missing bias in pass #11847

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@


from . import arm_pass_utils # noqa
from .arm_pass import ArmPass # noqa # usort: skip
from .add_bias_pass import AddBiasPass # noqa
from .annotate_channels_last_dim_order_pass import AnnotateChannelsLastDimOrder # noqa
from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass # noqa
from .arm_pass import ArmPass # noqa
from .broadcast_args_pass import BroadcastArgsPass # noqa
from .cast_int64_pass import CastInt64BuffersToInt32Pass # noqa
from .cast_to_int32_pass import CastToInt32Pass # noqa
Expand Down
62 changes: 62 additions & 0 deletions backends/arm/_passes/add_bias_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.backends.transforms.utils import create_constant_placeholder

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult
from torch.export.graph_signature import InputKind


class AddBiasPass(ArmPass):
    """Give every bias-less convolution an explicit all-zero bias input.

    TOSA requires convolution nodes to carry a bias input. Any targeted
    convolution with fewer than three input nodes gets a new constant
    placeholder filled with zeros, wired in as its third argument.
    """

    targeted_ops = (exir_ops.edge.aten.convolution.default,)

    def call(self, graph_module):
        """Add zero biases where missing and report whether anything changed.

        Args:
            graph_module: Graph module whose nodes are scanned for the
                targeted convolution ops.

        Returns:
            A ``PassResult`` wrapping the resulting graph module together
            with a flag that is True when at least one bias was added.
        """
        changed = False
        graph = graph_module.graph

        for node in graph.nodes:
            # Only function calls to one of the targeted ops are of interest.
            if node.op != "call_function" or node.target not in self.targeted_ops:
                continue

            # Three or more input nodes means a bias is already connected.
            if len(node.all_input_nodes) >= 3:
                continue

            weight = node.all_input_nodes[1]
            # The bias has one entry per leading-dimension element of the weight.
            num_out_channels = get_first_fake_tensor(weight).shape[0]

            # Quantized convolutions (non-empty output_qparams) get an int32
            # bias; everything else gets float32.
            is_quantized = (
                "output_qparams" in node.meta
                and len(node.meta["output_qparams"]) > 0
            )
            bias_dtype = torch.int32 if is_quantized else torch.float32
            zero_bias = torch.zeros(size=(num_out_channels,), dtype=bias_dtype)

            with graph.inserting_after(weight):
                bias_placeholder = create_constant_placeholder(
                    self.exported_program,
                    graph=graph,
                    kind=InputKind.PARAMETER,
                    data=zero_bias,
                    persistent_buffer=True,
                    name=f"{node.name}_bias",
                )

            # Slot 2 of the convolution's args is the bias position.
            node.update_arg(2, bias_placeholder)
            changed = True

        if not changed:
            return PassResult(graph_module, changed)

        # Hand the modified module to the parent class's call() and keep
        # whatever graph module it returns.
        graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, changed)
3 changes: 3 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

# pyre-unsafe
from executorch.backends.arm._passes import (
AddBiasPass,
AnnotateChannelsLastDimOrder,
AnnotateDecomposedMatmulPass,
BroadcastArgsPass,
Expand Down Expand Up @@ -134,6 +135,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul

self.add_pass(FuseViewCopyTransform())
self.add_pass(FuseConstantArgsPass(exported_program))
self.add_pass(AddBiasPass(exported_program))

self.add_pass(InsertTableOpsPass(exported_program))
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
Expand Down Expand Up @@ -194,6 +196,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul

self.add_pass(FuseViewCopyTransform())
self.add_pass(FuseConstantArgsPass(exported_program))
self.add_pass(AddBiasPass(exported_program))
self.add_pass(InsertTableOpsPass(exported_program))
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
self.add_pass(AnnotateChannelsLastDimOrder())
Expand Down
36 changes: 0 additions & 36 deletions backends/arm/operators/op_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,24 +109,6 @@ def define_node(
local_bound=False,
)

# Non-bias case.
if len(node.all_input_nodes) == 2:
# Create a zero bias tensor if none is present
out_channels = weight.shape[0]
bias_name = "bias" + node.name.split("default", 1)[1]
bias_type = output.dtype
if output.dtype == ts.DType.INT8:
# Conv is quantized to int8, but the TOSA operator has
# output type int32, and the bias must be the same type
# as the TOSA output type
bias_type = ts.DType.INT32
bias = tosa_graph.addConst(
[out_channels],
bias_type,
[0] * out_channels,
name=bias_name,
)

# The output type is int32 when input type is int8.
conv2d_output_name = output.name
if output.dtype == ts.DType.INT8:
Expand Down Expand Up @@ -313,24 +295,6 @@ def define_node(
name=f"{conv2d_output_name}_weight_zp",
)

# Non-bias case.
if len(node.all_input_nodes) == 2:
# Create a zero bias tensor if none is present
out_channels = weight.shape[0]
bias_name = f"{conv2d_output_name}_bias"
bias_type = output.dtype
if output.dtype == ts.DType.INT8:
# Conv is quantized to int8, but the TOSA operator has
# output type int32, and the bias must be the same type
# as the TOSA output type
bias_type = ts.DType.INT32
bias = tosa_graph.addConst(
[out_channels],
bias_type,
[0] * out_channels,
name=bias_name,
)

# Given input.shape is (N, Ci, H, W), and weight.shape is (Co, Ci/G, H, W)
in_channels = input.shape[1]
out_channels = weight.shape[0]
Expand Down
Loading