
[WIP] Add standalone batch norm support via depthwise conv conversion. #11844

Draft: wants to merge 1 commit into main
4 changes: 4 additions & 0 deletions backends/xnnpack/_passes/__init__.py
@@ -14,6 +14,9 @@
from executorch.backends.xnnpack._passes.conv1d_unsqueeze_pass import (
Conv1dUnsqueezePass,
)
from executorch.backends.xnnpack._passes.convert_batch_norm_to_depthwise_conv import (
ConvertBatchNormToDepthwiseConvPass,
)
from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
from executorch.backends.xnnpack._passes.convert_to_sdpa import ConvertToSDPAPass
from executorch.backends.xnnpack._passes.convert_to_upsample_bilinear2d import (
@@ -64,6 +67,7 @@ def __init__(
ConvertToSDPAPass,
ConstPropPass,
FuseBatchNormWithConvPass,
ConvertBatchNormToDepthwiseConvPass,
FuseActivationPass,
DecomposeConcatenate,
RemoveGetItemPass,
273 changes: 273 additions & 0 deletions backends/xnnpack/_passes/convert_batch_norm_to_depthwise_conv.py
@@ -0,0 +1,273 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import operator
from typing import Optional

import torch
from executorch.backends.transforms.utils import (
create_constant_placeholder,
delete_constant_placeholder,
)
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
from executorch.backends.xnnpack.utils.utils import (
get_param_tensor,
get_tensor_name,
is_param_node,
)
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult
from torch.export.graph_signature import InputKind


class ConvertBatchNormToDepthwiseConvPass(XNNPACKPass):
"""
Converts standalone batch norm operations to depthwise convolutions.
This allows XNNPACK to handle batch norm operations that cannot be fused
with preceding convolutions.

BatchNorm formula: y = (x - mean) / sqrt(var + eps) * weight + bias
This can be represented as a 1x1 depthwise convolution with:
- conv_weight = weight / sqrt(var + eps)
- conv_bias = bias - mean * weight / sqrt(var + eps)
"""

def call(self, graph_module: torch.fx.GraphModule):
graph = graph_module.graph
constant_placeholders_to_delete = set()
nodes_to_convert = []

# First pass: identify standalone batch norm nodes
for node in graph.nodes:
if (
node.target != exir_ops.edge.aten._native_batch_norm_legit_no_training.default
and node.target != exir_ops.edge.aten.native_batch_norm.default
):
continue

# Check if this batch norm can be fused with a preceding conv
# If so, skip it - the fusion pass will handle it
if self._can_be_fused_with_conv(node):
continue

# Check if this is a valid standalone batch norm to convert
if self._can_convert_to_depthwise_conv(node):
nodes_to_convert.append(node)

# Second pass: convert the identified nodes
for bn_node in nodes_to_convert:
conv_node = self._convert_batch_norm_to_depthwise_conv(
graph_module, bn_node, constant_placeholders_to_delete
)
if conv_node is not None:
# Replace all uses of batch norm getitem(0) with the conv node
for user in list(bn_node.users):
if user.target == operator.getitem and user.args[1] == 0:
user.replace_all_uses_with(conv_node)
graph.erase_node(user)

# Remove the batch norm node
graph.erase_node(bn_node)

# Clean up unused constant placeholders
if constant_placeholders_to_delete:
graph_module.graph.eliminate_dead_code()
for node in constant_placeholders_to_delete:
if node is not None and len(node.users) == 0:
delete_constant_placeholder(self.exported_program, node)

graph_module.recompile()
# Regenerate metadata and shape information
graph_module = super().call(graph_module).graph_module

return PassResult(graph_module, True)

def _can_be_fused_with_conv(self, bn_node: torch.fx.Node) -> bool:
"""Check if this batch norm can be fused with a preceding convolution."""
# Import here to avoid circular dependency
from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import (
FuseBatchNormWithConvPass,
)

input_node = bn_node.all_input_nodes[0]

# Check if input is a conv with single user (this batch norm)
if (
input_node.target == exir_ops.edge.aten.convolution.default
and len(input_node.users) == 1
):
return FuseBatchNormWithConvPass.can_fuse(
input_node, bn_node, self.exported_program
)

return False

def _can_convert_to_depthwise_conv(self, bn_node: torch.fx.Node) -> bool:
"""Check if this batch norm can be converted to depthwise conv."""

# All users must be getitem ops accessing the first element (output tensor)
for user in bn_node.users:
if user.target != operator.getitem or user.args[1] != 0:
return False

# Check that we have the required parameters
if len(bn_node.args) < 5:
return False

# Weight, bias, running_mean, running_var must be parameters
param_nodes = bn_node.args[1:5] # weight, bias, running_mean, running_var

for param_node in param_nodes:
if not isinstance(param_node, torch.fx.Node):
return False
if not is_param_node(self.exported_program, param_node):
return False

return True

def _convert_batch_norm_to_depthwise_conv(
self,
graph_module: torch.fx.GraphModule,
bn_node: torch.fx.Node,
constant_placeholders_to_delete: set,
) -> Optional[torch.fx.Node]:
"""Convert a batch norm node to a depthwise convolution."""

# Extract batch norm parameters
input_tensor = bn_node.args[0]

# Cast args to Node types for parameter access
bn_weight_node = bn_node.args[1] if isinstance(bn_node.args[1], torch.fx.Node) else None
bn_bias_node = bn_node.args[2] if isinstance(bn_node.args[2], torch.fx.Node) else None
running_mean_node = bn_node.args[3] if isinstance(bn_node.args[3], torch.fx.Node) else None
running_var_node = bn_node.args[4] if isinstance(bn_node.args[4], torch.fx.Node) else None

if any(node is None for node in [bn_weight_node, bn_bias_node, running_mean_node, running_var_node]):
return None

# These are guaranteed to be non-None now
assert bn_weight_node is not None
assert bn_bias_node is not None
assert running_mean_node is not None
assert running_var_node is not None

bn_weight = get_param_tensor(self.exported_program, bn_weight_node)
bn_bias = get_param_tensor(self.exported_program, bn_bias_node)
running_mean = get_param_tensor(self.exported_program, running_mean_node)
running_var = get_param_tensor(self.exported_program, running_var_node)

# Get epsilon value
if str(bn_node.target).endswith("native_batch_norm.default"):
eps = bn_node.args[7] if len(bn_node.args) > 7 else 1e-5
else: # _native_batch_norm_legit_no_training
eps = bn_node.args[6] if len(bn_node.args) > 6 else 1e-5

# Ensure eps is a float
if not isinstance(eps, (int, float)):
eps = 1e-5

if any(param is None for param in [bn_weight, bn_bias, running_mean, running_var]):
return None

# Ensure all parameters are tensors
assert isinstance(bn_weight, torch.Tensor)
assert isinstance(bn_bias, torch.Tensor)
assert isinstance(running_mean, torch.Tensor)
assert isinstance(running_var, torch.Tensor)

# Calculate depthwise conv parameters
# BatchNorm: y = (x - mean) / sqrt(var + eps) * weight + bias
# Depthwise Conv: y = x * conv_weight + conv_bias
# Therefore: conv_weight = weight / sqrt(var + eps)
# conv_bias = bias - mean * weight / sqrt(var + eps)

inv_std = torch.rsqrt(running_var + eps)
conv_weight_1d = bn_weight * inv_std
conv_bias_1d = bn_bias - running_mean * conv_weight_1d

# Reshape for depthwise conv: [C] -> [C, 1, 1, 1] for 2D conv
# Assuming 4D input tensor [N, C, H, W]
num_channels = conv_weight_1d.shape[0]
conv_weight = conv_weight_1d.view(num_channels, 1, 1, 1)
conv_bias = conv_bias_1d

# Create parameter names
bn_weight_name = get_tensor_name(self.exported_program, bn_weight_node)
conv_weight_name = (bn_weight_name + "_as_depthwise_conv_weight").replace(".", "_")
conv_bias_name = (bn_weight_name + "_as_depthwise_conv_bias").replace(".", "_")

# Create new parameter nodes
graph = graph_module.graph
with graph.inserting_before(bn_node):
conv_weight_node = create_constant_placeholder(
exp_program=self.exported_program,
graph=graph,
kind=InputKind.PARAMETER,
name=conv_weight_name,
data=conv_weight,
)

conv_bias_node = create_constant_placeholder(
exp_program=self.exported_program,
graph=graph,
kind=InputKind.PARAMETER,
name=conv_bias_name,
data=conv_bias,
)

# Create depthwise convolution node
# Args: input, weight, bias, stride, padding, dilation, transposed, output_padding, groups
conv_args = (
input_tensor, # input
conv_weight_node, # weight
conv_bias_node, # bias
[1, 1], # stride
[0, 0], # padding
[1, 1], # dilation
False, # transposed
[0, 0], # output_padding
num_channels, # groups (depthwise = groups = in_channels)
)

conv_node = graph.create_node(
"call_function",
exir_ops.edge.aten.convolution.default,
args=conv_args,
)

# Mark old parameters for deletion
constant_placeholders_to_delete.update(bn_node.args[1:5])

return conv_node

@staticmethod
def can_convert_standalone_batch_norm(
bn_node: torch.fx.Node, program: ExportedProgram
) -> bool:
"""
Static method to check if a standalone batch norm can be converted.
Used by the partitioner configuration.
"""
# All users must be getitem ops accessing the first element
for user in bn_node.users:
if user.target != operator.getitem or user.args[1] != 0:
return False

# Check that we have required parameters
if len(bn_node.args) < 5:
return False

# Weight, bias, running_mean, running_var must be parameters
param_nodes = bn_node.args[1:5]

for param_node in param_nodes:
if not isinstance(param_node, torch.fx.Node):
return False
if not is_param_node(program, param_node):
return False

return True
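
As a sanity check on the folding arithmetic used by this pass, the equivalence between a batch norm in eval mode and a 1x1 depthwise convolution with folded weights can be verified in plain PyTorch. This is a standalone sketch for illustration only; the tensor setup is hypothetical and not part of the pass:

import torch
import torch.nn.functional as F

C, eps = 8, 1e-5
x = torch.randn(2, C, 16, 16)

bn = torch.nn.BatchNorm2d(C, eps=eps).eval()
bn.running_mean.uniform_(-1.0, 1.0)
bn.running_var.uniform_(0.5, 2.0)
bn.weight.data.uniform_(0.5, 1.5)
bn.bias.data.uniform_(-1.0, 1.0)

# Fold the batch norm statistics exactly as the pass does:
#   conv_weight = weight / sqrt(var + eps)
#   conv_bias   = bias - mean * conv_weight
inv_std = torch.rsqrt(bn.running_var + eps)
conv_weight = (bn.weight * inv_std).view(C, 1, 1, 1)
conv_bias = bn.bias - bn.running_mean * bn.weight * inv_std

with torch.no_grad():
    y_bn = bn(x)
    y_conv = F.conv2d(x, conv_weight, conv_bias, stride=1, padding=0, groups=C)

assert torch.allclose(y_bn, y_conv, atol=1e-5)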
29 changes: 16 additions & 13 deletions backends/xnnpack/partition/config/node_configs.py
@@ -9,6 +9,9 @@
from typing import List, Optional

import torch
from executorch.backends.xnnpack._passes.convert_batch_norm_to_depthwise_conv import (
ConvertBatchNormToDepthwiseConvPass,
)
from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import (
FuseBatchNormWithConvPass,
)
@@ -35,20 +38,20 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
return False

bn = node
conv = node.all_input_nodes[0]

if conv.op != "call_function":
return False

conv_name = format_target_name(conv.target.__name__) # pyre-ignore

if conv_name not in ["convolution.default"]:
why(node, f"Invalid conv target {conv_name}")
return False
input_node = node.all_input_nodes[0]

can_fuse = FuseBatchNormWithConvPass.can_fuse(conv, bn, ep)
if not can_fuse:
why(node, "BatchNorm cannot be fused with Convolution")
# First check if this can be fused with a convolution
if input_node.op == "call_function":
conv_name = format_target_name(input_node.target.__name__) # pyre-ignore
if conv_name in ["convolution.default"]:
can_fuse = FuseBatchNormWithConvPass.can_fuse(input_node, bn, ep)
if can_fuse:
return True

# If not fuseable with conv, check if it can be converted to depthwise conv
can_convert = ConvertBatchNormToDepthwiseConvPass.can_convert_standalone_batch_norm(bn, ep)
if not can_convert:
why(node, "BatchNorm cannot be fused with Convolution or converted to depthwise conv")
return False

return True
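
For intuition, the updated constraint accepts two shapes of graph. A small illustration of the two module patterns the check now distinguishes (hypothetical modules, shown only to contrast the branches):

import torch

class FusableBN(torch.nn.Module):
    # conv -> bn, with the conv having a single user:
    # handled by FuseBatchNormWithConvPass as before.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(2, 2, 3)
        self.bn = torch.nn.BatchNorm2d(2)

    def forward(self, x):
        return self.bn(self.conv(x))

class StandaloneBN(torch.nn.Module):
    # bn with no preceding conv:
    # now handled by ConvertBatchNormToDepthwiseConvPass.
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm2d(2)

    def forward(self, x):
        return self.bn(x)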
33 changes: 30 additions & 3 deletions backends/xnnpack/test/passes/test_batch_norm_fusion.py
@@ -71,11 +71,36 @@ def test_q8_batch_norm_fusion(self):
.run_method_and_compare_outputs()
)

def test_fp32_standalone_batch_norm_converts_to_depthwise_conv(self):
"""
Test that standalone batch norms (i.e. batch norms that are not fused with a conv)
can be converted to depthwise convolutions and successfully partitioned and lowered.
"""

class StandaloneBN(torch.nn.Module):
def __init__(self):
super().__init__()
self.bn = torch.nn.BatchNorm2d(2)
# Run forward to set up batch norm statistics
self.forward(torch.randn(2, 2, 4, 4) * 2 + 2)

def forward(self, x):
return self.bn(x)

(
Tester(StandaloneBN().eval(), (torch.randn(2, 2, 4, 4),))
.export()
.to_edge()
.check_count({self.bn_name: 1})
.partition()
.check_count({self.bn_name: 0}) # Should be partitioned and converted
.run_method_and_compare_outputs()
)

def test_fp32_batch_norm_no_fusion_doesnt_partition(self):
"""
We do not currently support standalone batch norms (i.e. batch norms that are
not fused with a conv). This is planned, but until implemented, this test ensures
that we do not partition the standalone batch norm and then fail to lower.
DEPRECATED: We now support standalone batch norms by converting them to depthwise conv.
This test remains for backwards compatibility but may be removed in the future.
"""

class BN(torch.nn.Module):
@@ -86,6 +111,8 @@ def __init__(self):
def forward(self, x):
return self.bn(x)

# Note: This test is now testing the old behavior where standalone batch norms
# without proper initialization may not be convertible
(
Tester(BN(), (torch.randn(2, 2, 4, 4),))
.export()
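
End to end, a model whose batch norm is not preceded by a convolution should now be picked up by the XNNPACK partitioner rather than being left unpartitioned. A minimal lowering sketch using the standard ExecuTorch entry points (the exact flow is an assumption on my part, not taken from this PR):

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

# A bare BatchNorm2d is the simplest "standalone" case: no conv to fuse into.
model = torch.nn.BatchNorm2d(2).eval()
example_inputs = (torch.randn(2, 2, 4, 4),)

exported = torch.export.export(model, example_inputs)
# With ConvertBatchNormToDepthwiseConvPass registered, the batch norm is
# rewritten to a 1x1 depthwise convolution and delegated to XNNPACK.
lowered = to_edge_transform_and_lower(exported, partitioner=[XnnpackPartitioner()])
executorch_program = lowered.to_executorch()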