diff --git a/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py b/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py
index 1d824d234ee..f7b260023d4 100644
--- a/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py
+++ b/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py
@@ -395,6 +395,11 @@ def call(self, graph_module: torch.fx.GraphModule):  # noqa: C901
                 # The node requires nchw inputs
                 for input_node in node.all_input_nodes:
                     self.input_to_nchw(graph_module, input_node, node)
+            elif node.target == exir_ops.edge.aten._to_copy.default:
+                if node.kwargs["memory_format"] == torch.channels_last:
+                    self.mark_as_nhwc_node(node)
+                else:
+                    self.mark_as_nchw_node(node)
             else:
                 # The node can have inputs in any format (but all must be the
                 # same format)
diff --git a/backends/xnnpack/partition/config/__init__.py b/backends/xnnpack/partition/config/__init__.py
index 553b10f60d1..2f1e4cb8c56 100644
--- a/backends/xnnpack/partition/config/__init__.py
+++ b/backends/xnnpack/partition/config/__init__.py
@@ -49,6 +49,7 @@
     SoftmaxConfig,
     SquareRootConfig,
     SubConfig,
+    ToDimOrderCopyConfig,
     UpsampleBilinear2dConfig,
 )
 from executorch.backends.xnnpack.partition.config.node_configs import (
@@ -99,6 +100,7 @@
     PreluConfig,
     ReciprocalSquareRootConfig,
     ReLUConfig,
+    ToDimOrderCopyConfig,
     # SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
     SigmoidConfig,
     SliceCopyConfig,
diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py
index 46922e47010..686f5093989 100644
--- a/backends/xnnpack/partition/config/generic_node_configs.py
+++ b/backends/xnnpack/partition/config/generic_node_configs.py
@@ -371,6 +371,35 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32]
 
 
+class ToDimOrderCopyConfig(GenericNodePartitionerConfig):
+    target_name = "_to_dim_order_copy.default"
+
+    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
+        """
+        Only support dim order conversion partitioning, not DType conversions
+        """
+        if not self.check_common_constraints(node, ep):
+            return False
+
+        # Get input node and compare dtypes
+        input_node = get_input_node(node, 0)
+        input_dtype = input_node.meta["val"].dtype
+        output_dtype = node.meta["val"].dtype
+
+        # Return False if doing dtype conversion
+        if input_dtype != output_dtype:
+            why(
+                node,
+                reason=f"dtype conversion from {input_dtype} to {output_dtype} is not supported",
+            )
+            return False
+
+        return True
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
+
+
 class MeanDimConfig(GenericNodePartitionerConfig):
     target_name = "mean.dim"
 
diff --git a/backends/xnnpack/test/ops/test_to_copy.py b/backends/xnnpack/test/ops/test_to_copy.py
new file mode 100644
index 00000000000..d336cda5f7e
--- /dev/null
+++ b/backends/xnnpack/test/ops/test_to_copy.py
@@ -0,0 +1,85 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+
+from executorch.backends.xnnpack.test.tester import Tester
+
+
+class TestChannelsLastTaggedReshapePass(unittest.TestCase):
+    def setUp(self):
+        torch._dynamo.reset()
+
+    def run_tester(self, module, inputs):
+        tester = Tester(
+            module.eval(),
+            inputs,
+        )
+        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
+
+    class ChannelLastBeforeLinear(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(3, 3)
+
+        def forward(self, x):
+            y = x.to(memory_format=torch.channels_last)
+            return self.linear(y)
+
+    ChannelLastBeforeLinearModule = ChannelLastBeforeLinear()
+
+    def test_channel_last_before_linear(self):
+        self.run_tester(self.ChannelLastBeforeLinearModule, (torch.randn(1, 3, 3, 3),))
+
+    class ContiguousBeforeConv(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 3, 3)
+
+        def forward(self, x):
+            y = x.to(memory_format=torch.contiguous_format)
+            return self.conv(y)
+
+    ContiguousBeforeConvModule = ContiguousBeforeConv()
+
+    def test_contiguous_before_conv(self):
+        self.run_tester(self.ContiguousBeforeConvModule, (torch.randn(1, 3, 6, 6),))
+
+    class DtypeAndMemoryFormatConversion(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 3, 3)
+
+        def forward(self, x):
+            y = x.to(torch.float, memory_format=torch.channels_last)
+            return self.conv(y)
+
+    DtypeAndMemoryFormatConversionModule = DtypeAndMemoryFormatConversion()
+
+    def test_dtype_and_memory_format_conversion(self):
+        self.run_tester(
+            self.DtypeAndMemoryFormatConversionModule,
+            (torch.randint(0, 10, (1, 3, 6, 6), dtype=torch.int32),),
+        )
+
+    class DtypeAndMemoryFormatWithLinear(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(3, 3)
+
+        def forward(self, x):
+            y = x.to(torch.float, memory_format=torch.channels_last)
+            return self.linear(y)
+
+    DtypeAndMemoryFormatWithLinearModule = DtypeAndMemoryFormatWithLinear()
+
+    def test_dtype_and_memory_format_with_linear(self):
+        self.run_tester(
+            self.DtypeAndMemoryFormatWithLinearModule,
+            (torch.randint(0, 10, (1, 3, 3, 3), dtype=torch.int16),),
+        )
diff --git a/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py b/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py
index cfc409b4596..21295123920 100644
--- a/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py
+++ b/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py
@@ -173,6 +173,23 @@ def test_fp32_channels_last_tagged_reshape_pass(self):
             .run_method_and_compare_outputs()
         )
 
+    class LinearConvDimSwap(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 3, 3)
+            self.linear1 = torch.nn.Linear(4, 3)
+
+        def forward(self, x):
+            y = self.linear1(x)
+            y = y.to(memory_format=torch.channels_last)
+            y = y.to(memory_format=torch.contiguous_format)
+            return self.conv1(y)
+
+    LinearConvDimSwapModule = LinearConvDimSwap()
+
+    def test_conv_linear_dim_order_swap_partitioner(self):
+        self.run_tester(self.LinearConvDimSwapModule, (torch.randn(1, 3, 6, 4),))
+
     def test_qs8_channels_last_tagged_reshape_pass(self):
         for module, num_reshape in self.modules.items():
             (
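Below is a minimal sketch, not part of the patch, of how the dim-order-only partitioning added here could be exercised end to end. It assumes the public ExecuTorch export entry points (torch.export.export, executorch.exir.to_edge_transform_and_lower, and XnnpackPartitioner); the DimOrderSwap module name is made up for illustration.

import torch

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class DimOrderSwap(torch.nn.Module):
    # Hypothetical example module: converts the input to channels_last before a
    # conv, which should yield a dim-order-only _to_dim_order_copy node in the
    # edge graph (no dtype change), the case ToDimOrderCopyConfig accepts.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x):
        return self.conv(x.to(memory_format=torch.channels_last))


ep = torch.export.export(DimOrderSwap().eval(), (torch.randn(1, 3, 6, 6),))
# With ToDimOrderCopyConfig registered, the copy node should be delegated to
# XNNPACK instead of falling back to the portable kernel.
edge = to_edge_transform_and_lower(ep, partitioner=[XnnpackPartitioner()])
et_program = edge.to_executorch()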