Skip to content

Commit 0ca12ce

Browse files
committed
Merge branch 'main' into inference_cost_breakdown
2 parents 4dd2000 + c5bd87f commit 0ca12ce

15 files changed

+1288
-51
lines changed

src/qonnx/core/modelwrapper.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -346,16 +346,19 @@ def find_producer(self, tensor_name):
346346
return x
347347
return None
348348

349-
def find_upstream(self, tensor_name, finder_fxn):
349+
def find_upstream(self, tensor_name, finder_fxn, keep_if_not_found=False):
350350
"""Follow the producer chain upstream, calling finder_fxn on each upstream
351351
node until it returns True or there are no nodes left. Returns the list
352-
of nodes visited, or None if finder_fxn did not return True."""
352+
of nodes visited, or None if finder_fxn did not return True. If
353+
keep_if_not_found is set to True, returns the list of nodes visited, even
354+
if finder_fxn never returned True, i.e., if the search terminated at an
355+
input or initializer."""
353356
visit_list = []
354357
current_tensor = tensor_name
355358
while True:
356359
current_producer = self.find_producer(current_tensor)
357360
if current_producer is None:
358-
return []
361+
return visit_list if keep_if_not_found else []
359362
else:
360363
found = finder_fxn(current_producer)
361364
visit_list.append(current_producer)
@@ -364,7 +367,7 @@ def find_upstream(self, tensor_name, finder_fxn):
364367
elif len(current_producer.input) > 0:
365368
current_tensor = current_producer.input[0]
366369
else:
367-
return None
370+
return visit_list if keep_if_not_found else None
368371

369372
def find_consumer(self, tensor_name):
370373
"""Finds and returns the node that consumes the tensor with given name.
@@ -532,7 +535,7 @@ def get_non_finn_nodes(self):
532535
return list(filter(lambda x: not util.is_finn_op(x.domain), self.graph.node))
533536

534537
def get_node_index(self, node):
535-
"""Returns current index of given node."""
538+
"""Returns current index of given node, or None if not found."""
536539
n_ind = 0
537540
try:
538541
for n in self.graph.node:
@@ -541,6 +544,17 @@ def get_node_index(self, node):
541544
n_ind += 1
542545
except ValueError:
543546
return None
547+
return None
548+
549+
def get_node_from_name(self, node_name):
    """Returns the node with the specified name, or None if not found.

    Performs a linear scan over graph.node; node names are unique in a
    well-formed ONNX graph, so the first match is the only match.

    :param node_name: Name of the node to look up.
    :return: The matching onnx NodeProto, or None if no node has that name.
    """
    # NOTE(review): the original wrapped this loop in try/except ValueError,
    # but nothing here raises ValueError (that pattern belongs to
    # index-based lookups like get_node_index); the dead handler is removed.
    for node in self.graph.node:
        if node.name == node_name:
            return node
    return None
544558

545559
def get_tensor_layout(self, tensor_name):
546560
"""Returns the data layout annotation of tensor with given name.
Binary file not shown.
Binary file not shown.
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
# Copyright (c) 2023 Advanced Micro Devices, Inc.
2+
# All rights reserved.
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
#
7+
# * Redistributions of source code must retain the above copyright notice, this
8+
# list of conditions and the following disclaimer.
9+
#
10+
# * Redistributions in binary form must reproduce the above copyright notice,
11+
# this list of conditions and the following disclaimer in the documentation
12+
# and/or other materials provided with the distribution.
13+
#
14+
# * Neither the name of qonnx nor the names of its
15+
# contributors may be used to endorse or promote products derived from
16+
# this software without specific prior written permission.
17+
#
18+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28+
29+
import numpy as np
30+
from onnx import TensorProto, helper
31+
32+
from qonnx.core.modelwrapper import ModelWrapper
33+
from qonnx.transformation.base import Transformation
34+
from qonnx.transformation.general import GiveUniqueParameterTensors, SortGraph
35+
from qonnx.transformation.remove import RemoveIdentityOps
36+
37+
38+
class ExtractQuantScaleZeroPt(Transformation):
    """Extract any non-identity scale and zero-point Quant inputs as
    separate Div/Mul (for scale) and Add/Sub (for zeropoint) nodes,
    preceding and following the Quant node.

    After rewriting, the Quant node itself quantizes with scale=1 and
    zeropoint=0, and the original arithmetic is made explicit around it:
    Div(scale) -> Add(zeropt) -> Quant -> Sub(zeropt) -> Mul(scale).
    One Quant node is handled per pass; the transform reports
    graph_modified=True so it is re-applied until no work remains.
    """

    def apply(self, model: ModelWrapper):
        graph = model.graph
        for node in graph.node:
            if node.op_type == "Quant":
                quant_node = node
                # Quant input order is (data, scale, zeropoint, bitwidth);
                # the bitwidth input is left untouched by this transform.
                input_nm, scale_nm, zeropt_nm, _ = node.input
                scale_t = model.get_initializer(scale_nm)
                zeropt_t = model.get_initializer(zeropt_nm)
                ishp = model.get_tensor_shape(input_nm)
                extract_scale = False
                extract_zeropt = False
                # only extract when scale/zeropoint are initializer-backed
                # (constants) and not already the identity values (1 / 0)
                if scale_t is not None and (scale_t != 1).any():
                    extract_scale = True
                if zeropt_t is not None and (zeropt_t != 0).any():
                    extract_zeropt = True
                if (not extract_scale) and (not extract_zeropt):
                    continue
                # running_input tracks the tensor that should feed the Quant
                # node as new pre-processing nodes are inserted in front of it
                running_input = input_nm
                if extract_scale:
                    # create new Div node that divides the input
                    # by the scale
                    inp_scaled_nm = model.make_new_valueinfo_name()
                    inp_scaled = helper.make_tensor_value_info(
                        inp_scaled_nm,
                        TensorProto.FLOAT,
                        ishp,
                    )
                    graph.value_info.append(inp_scaled)
                    inp_scale_node = helper.make_node("Div", [running_input, scale_nm], [inp_scaled_nm])
                    graph.node.append(inp_scale_node)
                    # remove scale from Quant node by pointing its scale
                    # input at a fresh initializer fixed to 1.0
                    new_scale_nm = model.make_new_valueinfo_name()
                    model.set_initializer(new_scale_nm, np.asarray(1.0, dtype=np.float32))
                    quant_node.input[1] = new_scale_nm
                    running_input = inp_scaled_nm
                if extract_zeropt:
                    # create new Add node that adds the zeropoint to
                    # the scaled input
                    inp_zeropt_nm = model.make_new_valueinfo_name()
                    inp_zeropt = helper.make_tensor_value_info(
                        inp_zeropt_nm,
                        TensorProto.FLOAT,
                        ishp,
                    )
                    graph.value_info.append(inp_zeropt)
                    inp_zeropt_node = helper.make_node("Add", [running_input, zeropt_nm], [inp_zeropt_nm])
                    graph.node.append(inp_zeropt_node)
                    # remove zeropt from Quant node by pointing its zeropoint
                    # input at a fresh initializer fixed to 0.0
                    new_zeropt_nm = model.make_new_valueinfo_name()
                    model.set_initializer(new_zeropt_nm, np.asarray(0.0, dtype=np.float32))
                    quant_node.input[2] = new_zeropt_nm
                    running_input = inp_zeropt_nm
                # rewire node input to any newly created Div/Add nodes
                quant_node.input[0] = running_input
                last_node = quant_node
                final_output = quant_node.output[0]
                if extract_zeropt:
                    # create new Sub node that subtracts the zeropoint from
                    # the output
                    out_zeropt_nm = model.make_new_valueinfo_name()
                    out_zeropt = helper.make_tensor_value_info(
                        out_zeropt_nm,
                        TensorProto.FLOAT,
                        ishp,
                    )
                    graph.value_info.append(out_zeropt)
                    out_zeropt_node = helper.make_node("Sub", [out_zeropt_nm, zeropt_nm], [final_output])
                    last_node.output[0] = out_zeropt_nm
                    graph.node.append(out_zeropt_node)
                    # important: when tracking a pointer to newly added nodes,
                    # ensure the item from the container is used, and not the
                    # make_node result -- those are different objects
                    # e.g. if we use last_node = out_zeropt_node below,
                    # this will point to the wrong object and cause bugs later
                    last_node = graph.node[-1]
                if extract_scale:
                    # create new Mul node that applies the output scale
                    out_scale_nm = model.make_new_valueinfo_name()
                    out_scale = helper.make_tensor_value_info(
                        out_scale_nm,
                        TensorProto.FLOAT,
                        ishp,
                    )
                    # redirect the previous last node into the Mul's input
                    last_node.output[0] = out_scale_nm
                    graph.value_info.append(out_scale)
                    out_scale_node = helper.make_node("Mul", [out_scale_nm, scale_nm], [final_output])
                    graph.node.append(out_scale_node)

                if extract_scale or extract_zeropt:
                    # since we used append() for new nodes, need to call
                    # SortGraph to ensure correct (topological) order
                    model = model.transform(SortGraph())
                    # remove any identity ops created by the extraction
                    model = model.transform(RemoveIdentityOps())
                    # ensure unique parameter tensors
                    model = model.transform(GiveUniqueParameterTensors())
                    return model, True

        return model, False

src/qonnx/transformation/lower_convs_to_matmul.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,24 +32,7 @@
3232

3333
from qonnx.transformation.base import Transformation
3434
from qonnx.transformation.extract_conv_bias import ExtractBiasFromConv
35-
from qonnx.util.basic import get_by_name
36-
37-
38-
def _auto_pad_to_explicit_padding(autopad_str, idim_h, idim_w, k_h, k_w, stride_h, stride_w, n_dims):
39-
pad_total_h = (stride_h - 1) * idim_h - stride_h + k_h
40-
pad_total_w = (stride_w - 1) * idim_w - stride_w + k_w
41-
pad_half_small_h = int((pad_total_h / 2))
42-
pad_half_small_w = int((pad_total_w / 2))
43-
pad_half_large_h = pad_total_h - pad_half_small_h
44-
pad_half_large_w = pad_total_w - pad_half_small_w
45-
if autopad_str == "VALID":
46-
return [0 for i in range(2 * n_dims)]
47-
elif autopad_str == "SAME_UPPER":
48-
return [pad_half_small_h, pad_half_small_w, pad_half_large_h, pad_half_large_w]
49-
elif autopad_str == "SAME_LOWER":
50-
return [pad_half_large_h, pad_half_large_w, pad_half_small_h, pad_half_small_w]
51-
else:
52-
raise Exception("Unsupported auto_pad: " + autopad_str)
35+
from qonnx.util.basic import auto_pad_to_explicit_padding, get_by_name
5336

5437

5538
class LowerConvsToMatMul(Transformation):
@@ -100,7 +83,7 @@ def apply(self, model):
10083
# use specified padding
10184
pad = get_by_name(n.attribute, "pads").ints
10285
else:
103-
pad = _auto_pad_to_explicit_padding(
86+
pad = auto_pad_to_explicit_padding(
10487
auto_pad,
10588
ifm_dim_h,
10689
ifm_dim_w,

0 commit comments

Comments
 (0)