Commit 3fd9386

Merge pull request #97 from fastmachinelearning/feature/tensor_stats
Range analysis improvements and better input shape override
2 parents 813128f + 3a33770

3 files changed: +96 -40 lines changed

src/qonnx/transformation/extract_conv_bias.py

Lines changed: 6 additions & 6 deletions

```diff
@@ -27,29 +27,29 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import warnings
-from onnx import TensorProto, helper
+from onnx import helper
 
 from qonnx.transformation.base import Transformation
 
 
 class ExtractBiasFromConv(Transformation):
     """
-    Extracts the (optional) Bias from a Conv node and inserts it behind the
-    Conv node as an Add node.
+    Extracts the (optional) Bias from a Conv(Transpose) node and inserts it behind the
+    Conv(Transpose) node as an Add node.
     """
 
     def apply(self, model):
         graph = model.graph
         node_ind = 0
         for n in graph.node:
             node_ind += 1
-            if n.op_type == "Conv":
+            if n.op_type in ["Conv", "ConvTranspose"]:
                 # Check if the node has a bias input
                 if len(n.input) > 2:
                     # Extract bias
                     bias = model.get_initializer(n.input[2])
                     if bias is None:
-                        warnings.warn(f"Could not extract bias from Conv node {n}")
+                        warnings.warn(f"Could not extract bias from node {n}")
                         continue
 
                     # Insert bias as Add node behind the Conv node
@@ -65,7 +65,7 @@ def apply(self, model):
 
                     act_add_tensor = helper.make_tensor_value_info(
                         model.make_new_valueinfo_name(),
-                        TensorProto.FLOAT,
+                        model.get_tensor_valueinfo(n.output[0]).type.tensor_type.elem_type,
                         out_shape,
                     )
                     graph.value_info.append(act_add_tensor)
```
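Taken together, the two hunks generalize the transformation from Conv to ConvTranspose and stop hard-coding the inserted Add tensor to float32, taking the element type from the conv output instead. A minimal usage sketch (the filenames are hypothetical; `ModelWrapper.transform` and `save` are standard QONNX entry points):

```python
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.extract_conv_bias import ExtractBiasFromConv

# Load any ONNX model containing Conv/ConvTranspose nodes with bias inputs.
model = ModelWrapper("model.onnx")
# Each Conv/ConvTranspose with a bias initializer loses its third input;
# the bias reappears as a standalone Add node right after the conv.
model = model.transform(ExtractBiasFromConv())
model.save("model_bias_extracted.onnx")
```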

src/qonnx/util/cleanup.py

Lines changed: 18 additions & 8 deletions

```diff
@@ -43,7 +43,7 @@
 from qonnx.transformation.quant_constant_folding import FoldTransposeIntoQuantInit
 
 
-def cleanup_model(model, preserve_qnt_ops=True, override_batchsize=None, extract_conv_bias=False):
+def cleanup_model(model, preserve_qnt_ops=True, override_inpsize=None, extract_conv_bias=False):
     """Execute the transformations for the cleanup function on a model level.
     This allows the reuse of the cleanup transformations, without needing to read/write the model from/to disk.
 
@@ -61,6 +61,19 @@ def cleanup_model(model, preserve_qnt_ops=True, override_batchsize=None, extract
         preserve_qnt_optypes = ["Quant", "BipolarQuant", "QuantizeLinear", "DequantizeLinear"]
     else:
         preserve_qnt_optypes = []
+
+    if override_inpsize is not None:
+        if type(override_inpsize) is str:
+            override_inpsize = eval(override_inpsize)
+        if type(override_inpsize) is int:
+            override_batchsize = override_inpsize
+            model = model.transform(ChangeBatchSize(override_batchsize))
+        elif type(override_inpsize) is tuple:
+            override_batchsize = override_inpsize[0]
+            model = model.transform(ChangeBatchSize(override_batchsize))
+            iname = model.graph.input[0].name
+            model.set_tensor_shape(iname, override_inpsize)
+
     cleanup_transformations = [
         InferShapes(),
         GiveUniqueParameterTensors(),
@@ -80,27 +93,24 @@ def cleanup_model(model, preserve_qnt_ops=True, override_batchsize=None, extract
     model = model.transform(GiveUniqueNodeNames())
     model = model.transform(GiveReadableTensorNames())
 
-    if override_batchsize is not None:
-        model = model.transform(ChangeBatchSize(override_batchsize))
-        model = model.transform(InferShapes())
-
     return model
 
 
-def cleanup(in_file, *, out_file=None, preserve_qnt_ops=True, override_batchsize: int = None, extract_conv_bias=False):
+def cleanup(in_file, *, out_file=None, preserve_qnt_ops=True, override_inpsize: str = None, extract_conv_bias=False):
     """Execute a set of graph transformations to clean-up the given ONNX file.
 
     :param in_file: Filename for the input ONNX model
     :param preserve_qnt_ops: Preserve weight quantization operators
     :param out_file: If set, filename for the output ONNX model. Set to in_file with _clean
         suffix otherwise.
-    :param override_batchsize: If specified, override the batch size for the ONNX graph
+    :param override_inpsize: If specified, override the input size (e.g. "(1,3,224,224)" to set all or
+        just 1 to set batchsize to 1) for the ONNX graph
     :param extract_conv_bias: If specified, separate Conv bias into its own Add node
     """
 
     model = ModelWrapper(in_file)
     model = cleanup_model(
-        model, preserve_qnt_ops=preserve_qnt_ops, override_batchsize=override_batchsize, extract_conv_bias=extract_conv_bias
+        model, preserve_qnt_ops=preserve_qnt_ops, override_inpsize=override_inpsize, extract_conv_bias=extract_conv_bias
     )
     if out_file is None:
         out_file = in_file.replace(".onnx", "_clean.onnx")
```
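The net effect: `override_inpsize` replaces `override_batchsize` and is applied before the cleanup transformations rather than after them. An int overrides only the batch dimension; a tuple overrides the whole input shape, with element 0 doubling as the new batch size; the string forms exist so the same value can come straight from the CLI. A short sketch of both call styles, with a hypothetical `model.onnx`:

```python
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.util.cleanup import cleanup_model

# int form: only the batch dimension changes
model = ModelWrapper("model.onnx")
model_b1 = cleanup_model(model, override_inpsize=1)

# tuple form: the entire input shape is overridden; the subsequent
# InferShapes() pass propagates the new shape through the graph
model = ModelWrapper("model.onnx")
model_full = cleanup_model(model, override_inpsize=(1, 3, 224, 224))
```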

src/qonnx/util/range_analysis.py

Lines changed: 72 additions & 26 deletions

```diff
@@ -60,13 +60,6 @@ def calculate_matvec_accumulator_extremum(matrix: np.ndarray, vec_min, vec_max):
     return (min_values, max_values)
 
 
-def propagate_range(node, model, range_dict):
-    iname = node.input[0]
-    node_irange = range_dict[iname]
-    for oname in node.output:
-        range_dict[oname] = node_irange
-
-
 def calc_gemm_range(node, model, range_dict):
     alpha = get_by_name(node.attribute, "alpha").f
     beta = get_by_name(node.attribute, "beta").f
```
```diff
@@ -172,10 +165,49 @@ def calc_conv_range(node, model, range_dict):
     range_dict[oname] = ret
 
 
+def calc_convtranspose_range(node, model, range_dict):
+    iname = node.input[0]
+    wname = node.input[1]
+    assert len(node.input) == 2, "Found unsupported ConvTranspose with bias"
+    oname = node.output[0]
+    irange = range_dict[iname]
+    imin, imax = irange
+    weights = model.get_initializer(wname)
+    assert weights is not None, "Uninitialized ConvTranspose weights"
+    groups = get_by_name(node.attribute, "group")
+    if groups is None:
+        # default to dense convs
+        groups = 1
+    else:
+        groups = groups.i
+    assert groups == 1, "Only dense (non-grouped) ConvTranspose is supported"
+    # do weight reshaping to treat Conv similar to MatMul
+    # (mh, mw) = (ofm, (ifm x k0 x k1 x ...))
+    conv_ofm = weights.shape[1]
+    conv_ifm = weights.shape[0]
+    weights = weights.transpose(1, 0, 2, 3).reshape(conv_ofm, -1)
+    k_total = weights.shape[1] // conv_ifm
+    if type(imin) is np.ndarray:
+        imin_rep = np.repeat(imin, k_total)
+        imax_rep = np.repeat(imax, k_total)
+    else:
+        imin_rep = imin
+        imax_rep = imax
+    dw_ret_min = []
+    dw_ret_max = []
+    for i in range(conv_ofm):
+        w_slice = weights[i, :].reshape(1, -1)
+        dw_ret = calculate_matvec_accumulator_extremum(w_slice, imin_rep, imax_rep)
+        dw_ret_min.append(dw_ret[0].item())
+        dw_ret_max.append(dw_ret[1].item())
+    ret = (np.asarray(dw_ret_min), np.asarray(dw_ret_max))
+    range_dict[oname] = ret
+
+
 def get_minmax_prototype_tensors(irange, ishp, inp_vi, i_channel_axis=1):
     proto_min = valueinfo_to_tensor(inp_vi)
     proto_max = valueinfo_to_tensor(inp_vi)
-    if type(irange[0]) in [float, int, np.float32, np.float64, np.uint8, np.int8]:
+    if type(irange[0]) in [float, int, np.float16, np.float32, np.float64, np.uint8, np.int8]:
         imin, imax = irange
         proto_min[...] = imin
         proto_max[...] = imax
```
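The new `calc_convtranspose_range` reuses `calculate_matvec_accumulator_extremum` by flattening each output channel of the (IFM, OFM, kH, kW) weight tensor into a single row, so the per-channel output range reduces to the classic interval bound on a dot product: positive weights pair with the input max, negative weights with the input min. A self-contained toy version of that bound (independent of QONNX; all names here are illustrative):

```python
import numpy as np

def matvec_extremum(w_row, vec_min, vec_max):
    # Interval-arithmetic bound on a dot product: positive weights pull the
    # accumulator toward the input max, negative weights toward the input min.
    acc_min = np.where(w_row > 0, w_row * vec_min, w_row * vec_max).sum()
    acc_max = np.where(w_row > 0, w_row * vec_max, w_row * vec_min).sum()
    return acc_min, acc_max

# ONNX ConvTranspose weights are laid out (IFM, OFM, kH, kW); flatten to
# (OFM, IFM * kH * kW) exactly as the diff does before bounding each channel.
w = np.arange(-8, 8, dtype=np.float32).reshape(2, 2, 2, 2)
w2d = w.transpose(1, 0, 2, 3).reshape(w.shape[1], -1)
for row in w2d:
    print(matvec_extremum(row, vec_min=0.0, vec_max=1.0))
```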
```diff
@@ -211,25 +243,34 @@ def calc_monotonic_range(node, model, range_dict, i_channel_axis=1):
         inp_vi = model.get_tensor_valueinfo(inp)
         proto_vectors.append(get_minmax_prototype_tensors(irange, ishp, inp_vi, i_channel_axis))
     # process all combinations of prototype vectors for dynamic inputs
-    running_min = None
-    running_max = None
+    running_min = [None for i in range(len(node.output))]
+    running_max = [None for i in range(len(node.output))]
     # create context for single-node execution
     ctx = {x: model.get_initializer(x) for x in node.input}
-    ctx[oname] = valueinfo_to_tensor(model.get_tensor_valueinfo(oname))
+    for oname in node.output:
+        ctx[oname] = valueinfo_to_tensor(model.get_tensor_valueinfo(oname))
+    # assume all outputs are homogenous wrt data layout (e.g. channel axis
+    # always lives in the same position)
     axes_to_min = [i for i in range(ctx[oname].ndim)]
     axes_to_min.remove(i_channel_axis)
     axes_to_min = tuple(axes_to_min)
     for inps in itertools.product(*proto_vectors):
         for i in range(n_dyn_inp):
             ctx[dyn_inps[i]] = inps[i]
         execute_node(node, ctx, model.graph, opset_version=opset_version)
-        # grab new output and update running min/max
-        out = ctx[oname]
-        chanwise_min = out.min(axis=axes_to_min).flatten()
-        chanwise_max = out.max(axis=axes_to_min).flatten()
-        running_min = np.minimum(chanwise_min, running_min).flatten() if running_min is not None else chanwise_min
-        running_max = np.maximum(chanwise_max, running_max).flatten() if running_max is not None else chanwise_max
-    range_dict[oname] = (running_min, running_max)
+        for oind, oname in enumerate(node.output):
+            # grab new output and update running min/max
+            out = ctx[oname]
+            chanwise_min = out.min(axis=axes_to_min).flatten()
+            chanwise_max = out.max(axis=axes_to_min).flatten()
+            running_min[oind] = (
+                np.minimum(chanwise_min, running_min[oind]).flatten() if running_min[oind] is not None else chanwise_min
+            )
+            running_max[oind] = (
+                np.maximum(chanwise_max, running_max[oind]).flatten() if running_max[oind] is not None else chanwise_max
+            )
+    for oind, oname in enumerate(node.output):
+        range_dict[oname] = (running_min[oind], running_max[oind])
 
 
 def calc_range_outdtype(node, model, range_dict):
```
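The rewrite keeps one running min/max vector per output, which is what lets multi-output ops such as Split participate in range analysis, but the underlying corner-enumeration idea is unchanged: execute the node on every combination of per-input min/max prototype tensors and fold the channelwise extrema. A toy illustration for a single Mul-like output (the product call is a stand-in for `execute_node`):

```python
import itertools
import numpy as np

# two dynamic inputs, each described by a (min, max) prototype pair
protos = [(-1.0, 2.0), (0.5, 3.0)]
running_min, running_max = None, None
for a, b in itertools.product(*protos):
    out = np.asarray([a * b], dtype=np.float32)  # stand-in for executing a Mul
    running_min = out if running_min is None else np.minimum(out, running_min)
    running_max = out if running_max is None else np.maximum(out, running_max)
print(running_min, running_max)  # [-3.] [6.]
```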
```diff
@@ -240,12 +281,13 @@ def calc_range_outdtype(node, model, range_dict):
 
 
 optype_to_range_calc = {
-    "Transpose": propagate_range,
+    "Transpose": calc_monotonic_range,
     "MatMul": calc_matmul_range,
     "Conv": calc_conv_range,
+    "ConvTranspose": calc_convtranspose_range,
     "QuantMaxNorm": calc_range_outdtype,
-    "Flatten": propagate_range,
-    "Reshape": propagate_range,
+    "Flatten": calc_monotonic_range,
+    "Reshape": calc_monotonic_range,
     "Quant": calc_monotonic_range,
     "BipolarQuant": calc_monotonic_range,
     "Mul": calc_monotonic_range,
@@ -254,7 +296,7 @@
     "Add": calc_monotonic_range,
     "BatchNormalization": calc_monotonic_range,
     "Relu": calc_monotonic_range,
-    "Pad": propagate_range,
+    "Pad": calc_monotonic_range,
     "AveragePool": calc_monotonic_range,
     "Trunc": calc_range_outdtype,
     "MaxPool": calc_monotonic_range,
@@ -267,6 +309,7 @@
     "Clip": calc_monotonic_range,
     "Sigmoid": calc_monotonic_range,
     "Concat": calc_monotonic_range,
+    "Split": calc_monotonic_range,
 }
```
```diff
@@ -320,8 +363,12 @@ def range_analysis(
             range_min = None
             range_max = None
         else:
-            irange = irange.split(",")
-            range_min, range_max = float(irange[0]), float(irange[1])
+            irange = eval(irange)
+            range_min, range_max = irange
+            if isinstance(range_min, list):
+                range_min = np.asarray(range_min, dtype=np.float32)
+            if isinstance(range_max, list):
+                range_max = np.asarray(range_max, dtype=np.float32)
     elif isinstance(irange, tuple):
         range_min, range_max = irange
     else:
```
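Because `irange` strings are now parsed with `eval` rather than a comma split, both a scalar pair and per-channel lists work. A hedged sketch (the input filename is hypothetical, and keyword names other than `irange` follow the surrounding `range_analysis(...)` signature, which this diff does not show in full):

```python
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.util.range_analysis import range_analysis

model = ModelWrapper("model.onnx")  # hypothetical input model
# scalar input range, equivalent to the old "0,1" comma syntax
res = range_analysis(model, irange="(0, 1)")
# per-channel input ranges; the lists become float32 arrays in the hunk above
res = range_analysis(model, irange="([0, 0, 0], [1, 2, 255])")
```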
```diff
@@ -350,9 +397,8 @@
     for node in model.graph.node:
         dyn_inputs = [x for x in node.input if is_dyn_input(x, model)]
         inprange_ok = all([x in range_dict.keys() for x in dyn_inputs])
-        outcount_ok = len(node.output) == 1
         op_ok = node.op_type in optype_to_range_calc.keys()
-        if inprange_ok and op_ok and outcount_ok:
+        if inprange_ok and op_ok:
            range_calc_fxn = optype_to_range_calc[node.op_type]
             range_calc_fxn(node, model, range_dict)
             out_range = range_dict[node.output[0]]
```
