
Milestone2.1: Partition to_dim_order_copy op in XNN delegate #11286


Open · wants to merge 1 commit into base: main
@@ -395,6 +395,11 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
# The node requires nchw inputs
for input_node in node.all_input_nodes:
self.input_to_nchw(graph_module, input_node, node)
elif node.target == exir_ops.edge.aten._to_copy.default:
if node.kwargs["memory_format"] == torch.channels_last:
self.mark_as_nhwc_node(node)
else:
self.mark_as_nchw_node(node)
else:
# The node can have inputs in any format (but all must be the
# same format)
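For context, here is a minimal sketch (not part of the diff; the module and input shape are illustrative) of the kind of graph the new branch handles: a .to(memory_format=...) call is captured by torch.export as an aten._to_copy.default node whose memory_format kwarg the pass above inspects, tagging the node as NHWC for torch.channels_last and as NCHW otherwise.

import torch
from torch.export import export


class ToChannelsLast(torch.nn.Module):
    def forward(self, x):
        # Traced as aten._to_copy.default(x, memory_format=torch.channels_last)
        return x.to(memory_format=torch.channels_last)


# Exporting makes the copy node visible; after lowering to the edge dialect it
# becomes the op matched by the elif branch above.
ep = export(ToChannelsLast(), (torch.randn(1, 3, 8, 8),))
print(ep.graph)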
2 changes: 2 additions & 0 deletions backends/xnnpack/partition/config/__init__.py
@@ -49,6 +49,7 @@
SoftmaxConfig,
SquareRootConfig,
SubConfig,
ToDimOrderCopyConfig,
UpsampleBilinear2dConfig,
)
from executorch.backends.xnnpack.partition.config.node_configs import (
@@ -99,6 +100,7 @@
PreluConfig,
ReciprocalSquareRootConfig,
ReLUConfig,
ToDimOrderCopyConfig,
# SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
SigmoidConfig,
SliceCopyConfig,
29 changes: 29 additions & 0 deletions backends/xnnpack/partition/config/generic_node_configs.py
@@ -371,6 +371,35 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
return [ConfigPrecisionType.FP32]


class ToDimOrderCopyConfig(GenericNodePartitionerConfig):
target_name = "_to_dim_order_copy.default"

def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
"""
Only partition dim-order conversions; dtype conversions are not supported.
"""
if not self.check_common_constraints(node, ep):
return False

# Get input node and compare dtypes
input_node = get_input_node(node, 0)
input_dtype = input_node.meta["val"].dtype
output_dtype = node.meta["val"].dtype

# Return False if doing dtype conversion
if input_dtype != output_dtype:
why(
node,
reason=f"dtype conversion from {input_dtype} to {output_dtype} is not supported",
)
return False

return True

def supported_precision_types(self) -> List[ConfigPrecisionType]:
return [ConfigPrecisionType.FP32]


class MeanDimConfig(GenericNodePartitionerConfig):
target_name = "mean.dim"

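As a rough illustration of the constraint above (the modules and input below are made up for this sketch, not taken from the PR), a .to() that only changes memory format keeps the input and output dtypes equal and so is partitionable, while one that also casts is rejected:

import torch
from torch.export import export


class FormatOnly(torch.nn.Module):
    def forward(self, x):
        # dtype unchanged: passes check_constraints
        return x.to(memory_format=torch.channels_last)


class FormatAndCast(torch.nn.Module):
    def forward(self, x):
        # dtype changes from int32 to float32: rejected by check_constraints
        return x.to(torch.float, memory_format=torch.channels_last)


x = torch.randint(0, 10, (1, 3, 6, 6), dtype=torch.int32)
for mod in (FormatOnly(), FormatAndCast()):
    ep = export(mod, (x,))
    # The .to() call shows up as a single copy node; comparing its input and
    # output dtypes mirrors ToDimOrderCopyConfig.check_constraints.
    copy_node = next(
        n
        for n in ep.graph.nodes
        if n.op == "call_function" and "_to_copy" in str(n.target)
    )
    in_dtype = copy_node.all_input_nodes[0].meta["val"].dtype
    out_dtype = copy_node.meta["val"].dtype
    print(type(mod).__name__, in_dtype, "->", out_dtype)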
85 changes: 85 additions & 0 deletions backends/xnnpack/test/ops/test_to_copy.py
@@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from executorch.backends.xnnpack.test.tester import Tester


class TestChannelsLastTaggedReshapePass(unittest.TestCase):
def setUp(self):
torch._dynamo.reset()

def run_tester(self, module, inputs):
tester = Tester(
module.eval(),
inputs,
)
tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()

class ChannelLastBeforeLinear(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 3)

def forward(self, x):
y = x.to(memory_format=torch.channels_last)
return self.linear(y)

ChannelLastBeforeLinearModule = ChannelLastBeforeLinear()

def test_channel_last_before_linear(self):
self.run_tester(self.ChannelLastBeforeLinearModule, (torch.randn(1, 3, 3, 3),))

class ContiguousBeforeConv(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)

def forward(self, x):
y = x.to(memory_format=torch.contiguous_format)
return self.conv(y)

ContiguousBeforeConvModule = ContiguousBeforeConv()

def test_contiguous_before_conv(self):
self.run_tester(self.ContiguousBeforeConvModule, (torch.randn(1, 3, 6, 6),))

class DtypeAndMemoryFormatConversion(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)

def forward(self, x):
y = x.to(torch.float, memory_format=torch.channels_last)
return self.conv(y)

DtypeAndMemoryFormatConversionModule = DtypeAndMemoryFormatConversion()

def test_dtype_and_memory_format_conversion(self):
self.run_tester(
self.DtypeAndMemoryFormatConversionModule,
(torch.randint(0, 10, (1, 3, 6, 6), dtype=torch.int32),),
)

class DtypeAndMemoryFormatWithLinear(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 3)

def forward(self, x):
y = x.to(torch.float, memory_format=torch.channels_last)
return self.linear(y)

DtypeAndMemoryFormatWithLinearModule = DtypeAndMemoryFormatWithLinear()

def test_dtype_and_memory_format_with_linear(self):
self.run_tester(
self.DtypeAndMemoryFormatWithLinearModule,
(torch.randint(0, 10, (1, 3, 3, 3), dtype=torch.int16),),
)
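The tests above validate numerics end to end. A natural extension, sketched below under the assumption that the Tester pipeline exposes the same check_count stage used by other XNNPACK op tests, would also assert that the dim-order copy is actually delegated:

import torch

from executorch.backends.xnnpack.test.tester import Tester


class ToChannelsLastOnly(torch.nn.Module):
    def forward(self, x):
        return x.to(memory_format=torch.channels_last)


(
    Tester(ToChannelsLastOnly().eval(), (torch.randn(1, 3, 6, 6),))
    .export()
    .to_edge_transform_and_lower()
    # Expect a single delegate call, i.e. the copy was partitioned to XNNPACK.
    .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
    .to_executorch()
    .serialize()
    .run_method_and_compare_outputs()
)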
@@ -173,6 +173,23 @@ def test_fp32_channels_last_tagged_reshape_pass(self):
.run_method_and_compare_outputs()
)

class LinearConvDimSwap(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 3, 3)
self.linear1 = torch.nn.Linear(4, 3)

def forward(self, x):
y = self.linear1(x)
y = y.to(memory_format=torch.channels_last)
y = y.to(memory_format=torch.contiguous_format)
return self.conv1(y)

LinearConvDimSwapModule = LinearConvDimSwap()

def test_conv_linear_dim_order_swap_partitioner(self):
self.run_tester(self.LinearConvDimSwapModule, (torch.randn(1, 3, 6, 4),))

def test_qs8_channels_last_tagged_reshape_pass(self):
for module, num_reshape in self.modules.items():
(