
Commit dd6caa3

Milestone2.1: Partition to_dim_order_copy op in XNN delegate (#12220)
### Summary

This PR adds support for the `to_dim_order_copy` operation in the XNNPACK delegate partitioner, enabling direct handling of memory format conversions that users initiate via `.to(memory_format=...)` calls. Delegating these conversions produces more compact graphs by avoiding unnecessary partition boundaries at memory format conversion points, and it eliminates the overhead of switching between the runtime and the delegate, reducing both execution time and memory footprint. The implementation leverages XNNPACK's optimized memory format conversion routines, which are designed for efficient tensor layout transformations across hardware targets.

### Test plan

Confirmed the expected output for user-specified dim order conversions, as well as correct partitioning. Added individual tests for the `to_copy` op verifying that it changes dim order and dtype where appropriate, and added a test module confirming that the `to_copy` nodes are absorbed into the XNNPACK partition rather than left in a separate one.
1 parent c2de265 commit dd6caa3
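
For context, a minimal sketch of the user-facing pattern this change targets: a model that calls `.to(memory_format=...)` before a delegated op, lowered through the standard ExecuTorch flow. The export calls (`torch.export.export`, `to_edge_transform_and_lower`, `XnnpackPartitioner`) and the shapes below are illustrative assumptions, not part of this commit:

```python
import torch

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class ConvWithDimOrderSwap(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x):
        # User-initiated memory format conversion; in the edge dialect this
        # becomes a _to_dim_order_copy node, which the partitioner can now
        # place in the same XNNPACK partition as the convolution.
        y = x.to(memory_format=torch.channels_last)
        return self.conv(y)


ep = torch.export.export(ConvWithDimOrderSwap().eval(), (torch.randn(1, 3, 8, 8),))
edge = to_edge_transform_and_lower(ep, partitioner=[XnnpackPartitioner()])
# Expectation: one delegated call containing both the copy and the conv,
# rather than a partition boundary at the memory format conversion.
print(edge.exported_program().graph_module.code)
```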

File tree

4 files changed: +164 -1 lines changed


backends/xnnpack/partition/config/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -50,6 +50,7 @@
     SquareRootConfig,
     SubConfig,
     TanhConfig,
+    ToDimOrderCopyConfig,
     UpsampleBilinear2dConfig,
 )
 from executorch.backends.xnnpack.partition.config.node_configs import (
@@ -102,6 +103,7 @@
     ReciprocalSquareRootConfig,
     ReLUConfig,
     TanhConfig,
+    ToDimOrderCopyConfig,
     SigmoidConfig,
     SliceCopyConfig,
     SoftmaxConfig,

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 29 additions & 0 deletions
@@ -425,6 +425,35 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32]


+class ToDimOrderCopyConfig(GenericNodePartitionerConfig):
+    target_name = "_to_dim_order_copy.default"
+
+    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
+        """
+        Only support dim order conversion partitioning, not DType conversions
+        """
+        if not self.check_common_constraints(node, ep):
+            return False
+
+        # Get input node and compare dtypes
+        input_node = get_input_node(node, 0)
+        input_dtype = input_node.meta["val"].dtype
+        output_dtype = node.meta["val"].dtype
+
+        # Return False if doing dtype conversion
+        if input_dtype != output_dtype:
+            why(
+                node,
+                reason=f"dtype conversion from {input_dtype} to {output_dtype} is not supported",
+            )
+            return False
+
+        return True
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT]
+
+
 class MeanDimConfig(GenericNodePartitionerConfig):
     target_name = "mean.dim"
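
The constraint in `ToDimOrderCopyConfig.check_constraints` reduces to a single dtype comparison. Below is a toy illustration of the rule it enforces; the helper function is hypothetical and exists only to show which conversions the config accepts:

```python
import torch


def is_partitionable_to_dim_order_copy(
    input_dtype: torch.dtype, output_dtype: torch.dtype
) -> bool:
    # Mirrors the gate above: layout-only (dim order) conversions are
    # delegated to XNNPACK; conversions that also change dtype are not.
    return input_dtype == output_dtype


# x.to(memory_format=torch.channels_last): layout changes, dtype does not -> delegated
assert is_partitionable_to_dim_order_copy(torch.float32, torch.float32)
# x.to(torch.float, memory_format=torch.channels_last): dtype changes -> left to the runtime
assert not is_partitionable_to_dim_order_copy(torch.int32, torch.float32)
```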

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+
+from executorch.backends.xnnpack.test.tester import Tester
+
+
+class TestChannelsLastTaggedReshapePass(unittest.TestCase):
+    def setUp(self):
+        torch._dynamo.reset()
+
+    def run_tester(self, module, inputs):
+        tester = Tester(
+            module.eval(),
+            inputs,
+        )
+        tester.export().to_edge_transform_and_lower().check_not(
+            ["executorch_exir_dialects_edge__ops_aten__to_copy_default"]
+        ).to_executorch().serialize().run_method_and_compare_outputs()
+
+    class ChannelLastBeforeLinear(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(3, 3)
+
+        def forward(self, x):
+            y = x.to(memory_format=torch.channels_last)
+            return self.linear(y)
+
+    ChannelLastBeforeLinearModule = ChannelLastBeforeLinear()
+
+    def test_channel_last_before_linear(self):
+        self.run_tester(self.ChannelLastBeforeLinearModule, (torch.randn(1, 3, 3, 3),))
+
+    class ContiguousBeforeConv(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 3, 3)
+
+        def forward(self, x):
+            y = x.to(memory_format=torch.contiguous_format)
+            return self.conv(y)
+
+    ContiguousBeforeConvModule = ContiguousBeforeConv()
+
+    def test_contiguous_before_conv(self):
+        self.run_tester(self.ContiguousBeforeConvModule, (torch.randn(1, 3, 6, 6),))
+
+    class DtypeAndMemoryFormatConversion(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 3, 3)
+
+        def forward(self, x):
+            y = x.to(torch.float, memory_format=torch.channels_last)
+            return self.conv(y)
+
+    DtypeAndMemoryFormatConversionModule = DtypeAndMemoryFormatConversion()
+
+    def test_dtype_and_memory_format_conversion(self):
+        self.run_tester(
+            self.DtypeAndMemoryFormatConversionModule,
+            (torch.randint(0, 10, (1, 3, 6, 6), dtype=torch.int32),),
+        )
+
+    class DtypeAndMemoryFormatWithLinear(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(3, 3)
+
+        def forward(self, x):
+            y = x.to(torch.float, memory_format=torch.channels_last)
+            return self.linear(y)
+
+    DtypeAndMemoryFormatWithLinearModule = DtypeAndMemoryFormatWithLinear()
+
+    def test_dtype_and_memory_format_with_linear(self):
+        self.run_tester(
+            self.DtypeAndMemoryFormatWithLinearModule,
+            (torch.randint(0, 10, (1, 3, 3, 3), dtype=torch.int16),),
+        )
+
+    class QuantizedToCopy(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 3, 3)
+            self.conv2 = torch.nn.Conv2d(3, 3, 3)
+
+        def forward(self, x):
+            y = self.conv(x)
+            y = y.to(memory_format=torch.contiguous_format)
+            return self.conv2(y)
+
+    QuantizedToCopyModule = QuantizedToCopy()
+
+    def test_quantized_to_copy(self):
+        tester = Tester(
+            self.QuantizedToCopyModule.eval(),
+            (torch.randn(1, 3, 9, 9),),
+        )
+
+        tester.quantize().export().to_edge_transform_and_lower().check_not(
+            [
+                "executorch_exir_dialects_edge__ops_aten__to_copy_default",
+                "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
+            ]
+        ).to_executorch().serialize().run_method_and_compare_outputs(qtol=0.01)

backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py

Lines changed: 20 additions & 1 deletion
@@ -54,7 +54,9 @@ def run_tester(self, module, inputs):
             module.eval(),
             inputs,
         )
-        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
+        tester.export().to_edge_transform_and_lower().check_not(
+            ["executorch_exir_dialects_edge__ops_aten__to_copy_default"]
+        ).to_executorch().serialize().run_method_and_compare_outputs()

     class LinearConv(torch.nn.Module):
         def __init__(self):
@@ -179,6 +181,23 @@ def test_fp32_channels_last_tagged_reshape_pass(self):
             .run_method_and_compare_outputs()
         )

+    class LinearConvDimSwap(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 3, 3)
+            self.linear1 = torch.nn.Linear(4, 3)
+
+        def forward(self, x):
+            y = self.linear1(x)
+            y = y.to(memory_format=torch.channels_last)
+            y = y.to(memory_format=torch.contiguous_format)
+            return self.conv1(y)
+
+    LinearConvDimSwapModule = LinearConvDimSwap()
+
+    def test_conv_linear_dim_order_swap_partitioner(self):
+        self.run_tester(self.LinearConvDimSwapModule, (torch.randn(1, 3, 6, 4),))
+
     def test_qs8_channels_last_tagged_reshape_pass(self):
         for module, num_reshape in self.modules.items():
             (
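
As a usage note, the `check_not` assertions above confirm that no standalone copy op is left in the edge graph; a complementary check, only sketched here, is that a delegate call is actually present. This sketch assumes the Tester `check` stage and the `executorch_call_delegate` call target commonly asserted in ExecuTorch delegate tests, and it redefines the dim-swap module at top level purely for illustration:

```python
import torch

from executorch.backends.xnnpack.test.tester import Tester


class LinearConvDimSwap(torch.nn.Module):
    # Same shape as the test module above: linear -> channels_last -> contiguous -> conv.
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 3, 3)
        self.linear1 = torch.nn.Linear(4, 3)

    def forward(self, x):
        y = self.linear1(x)
        y = y.to(memory_format=torch.channels_last)
        y = y.to(memory_format=torch.contiguous_format)
        return self.conv1(y)


(
    Tester(LinearConvDimSwap().eval(), (torch.randn(1, 3, 6, 4),))
    .export()
    .to_edge_transform_and_lower()
    # Delegate call present: the copy ops were absorbed into the XNNPACK partition.
    .check(["torch.ops.higher_order.executorch_call_delegate"])
    # No standalone copy op left in the edge graph.
    .check_not(["executorch_exir_dialects_edge__ops_aten__to_copy_default"])
    .to_executorch()
    .serialize()
    .run_method_and_compare_outputs()
)
```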
