Skip to content

Commit a6a23ed

Browse files
authored
Merge pull request #163 from fastmachinelearning/fix/chanlast_forking_transpose
Don't remove opposing transposes if first one is forking
2 parents cf640b9 + cf7c56e commit a6a23ed

File tree

3 files changed

+59
-1
lines changed

3 files changed

+59
-1
lines changed
18.8 KB
Binary file not shown.

src/qonnx/transformation/channels_last.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,10 @@ def apply(self, model):
301301
ndim = len(input_shape)
302302
if list(to_channels_first_args(ndim)) == perm_1:
303303
successor_nodes = model.find_direct_successors(n)
304-
if successor_nodes is None:
304+
# skip if:
305+
# - this Transpose has no successors (nothing to do)
306+
# - this Transpose output is forking (cannot remove)
307+
if successor_nodes is None or len(successor_nodes) > 1:
305308
continue
306309
successor_node = successor_nodes[0]
307310

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright (c) 2024 Advanced Micro Devices, Inc.
2+
# All rights reserved.
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
#
7+
# * Redistributions of source code must retain the above copyright notice, this
8+
# list of conditions and the following disclaimer.
9+
#
10+
# * Redistributions in binary form must reproduce the above copyright notice,
11+
# this list of conditions and the following disclaimer in the documentation
12+
# and/or other materials provided with the distribution.
13+
#
14+
# * Neither the name of qonnx nor the names of its
15+
# contributors may be used to endorse or promote products derived from
16+
# this software without specific prior written permission.
17+
#
18+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28+
29+
import numpy as np
30+
from pkgutil import get_data
31+
32+
import qonnx.core.onnx_exec as oxe
33+
from qonnx.core.modelwrapper import ModelWrapper
34+
from qonnx.transformation.channels_last import ConvertToChannelsLastAndClean
35+
from qonnx.util.basic import gen_finn_dt_tensor
36+
37+
38+
def test_channelslast_residual():
39+
raw_m = get_data("qonnx.data", "onnx/residual_block_clean.onnx")
40+
model = ModelWrapper(raw_m)
41+
iname = model.graph.input[0].name
42+
idt = model.get_tensor_datatype(iname)
43+
ishape = model.get_tensor_shape(iname)
44+
idict = {iname: gen_finn_dt_tensor(idt, ishape)}
45+
oname = model.graph.output[0].name
46+
expected_out = oxe.execute_onnx(model, idict)[oname]
47+
model = model.transform(ConvertToChannelsLastAndClean(make_input_channels_last=False))
48+
expected_ops = ["Transpose", "Conv", "Conv", "Relu", "Conv", "Relu", "Add", "MaxPool", "Transpose"]
49+
ops = [x.op_type for x in model.graph.node]
50+
assert ops == expected_ops, "Did not found expected op sequence after lowering and channels-last"
51+
for node in model.graph.node:
52+
if node.op_type in ["Conv", "MaxPool"]:
53+
assert node.domain == "qonnx.custom_op.channels_last"
54+
out = oxe.execute_onnx(model, idict)[oname]
55+
assert np.isclose(expected_out, out, atol=1e-4).all()

0 commit comments

Comments
 (0)