@@ -60,13 +60,6 @@ def calculate_matvec_accumulator_extremum(matrix: np.ndarray, vec_min, vec_max):
     return (min_values, max_values)
 
 
-def propagate_range(node, model, range_dict):
-    iname = node.input[0]
-    node_irange = range_dict[iname]
-    for oname in node.output:
-        range_dict[oname] = node_irange
-
-
 def calc_gemm_range(node, model, range_dict):
     alpha = get_by_name(node.attribute, "alpha").f
     beta = get_by_name(node.attribute, "beta").f
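Aside on the removal above: propagate_range simply copied the input interval to every output, which is only safe when the op leaves values untouched; the registrations below now point at calc_monotonic_range, which bounds outputs by executing the node on extreme-valued inputs. That is sound for any elementwise-monotonic op. A minimal standalone sketch of the bounding idea, using a hypothetical helper that is not part of this codebase:

import numpy as np

# Sketch: for an elementwise op f that is monotonic (in either direction),
# the output interval over [imin, imax] is spanned by f(imin) and f(imax);
# taking elementwise min/max also covers monotonically decreasing ops.
def monotonic_range(f, imin, imax):
    lo, hi = f(imin), f(imax)
    return np.minimum(lo, hi), np.maximum(lo, hi)

omin, omax = monotonic_range(np.negative, np.array([-2.0]), np.array([3.0]))
# omin == [-3.0], omax == [2.0]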
@@ -172,10 +165,49 @@ def calc_conv_range(node, model, range_dict):
     range_dict[oname] = ret
 
 
+def calc_convtranspose_range(node, model, range_dict):
+    iname = node.input[0]
+    wname = node.input[1]
+    assert len(node.input) == 2, "Found unsupported ConvTranspose with bias"
+    oname = node.output[0]
+    irange = range_dict[iname]
+    imin, imax = irange
+    weights = model.get_initializer(wname)
+    assert weights is not None, "Uninitialized ConvTranspose weights"
+    groups = get_by_name(node.attribute, "group")
+    if groups is None:
+        # default to dense convs
+        groups = 1
+    else:
+        groups = groups.i
+    assert groups == 1, "Only dense (non-grouped) ConvTranspose is supported"
+    # do weight reshaping to treat Conv similar to MatMul
+    # (mh, mw) = (ofm, (ifm x k0 x k1 x ...))
+    conv_ofm = weights.shape[1]
+    conv_ifm = weights.shape[0]
+    weights = weights.transpose(1, 0, 2, 3).reshape(conv_ofm, -1)
+    k_total = weights.shape[1] // conv_ifm
+    if type(imin) is np.ndarray:
+        imin_rep = np.repeat(imin, k_total)
+        imax_rep = np.repeat(imax, k_total)
+    else:
+        imin_rep = imin
+        imax_rep = imax
+    dw_ret_min = []
+    dw_ret_max = []
+    for i in range(conv_ofm):
+        w_slice = weights[i, :].reshape(1, -1)
+        dw_ret = calculate_matvec_accumulator_extremum(w_slice, imin_rep, imax_rep)
+        dw_ret_min.append(dw_ret[0].item())
+        dw_ret_max.append(dw_ret[1].item())
+    ret = (np.asarray(dw_ret_min), np.asarray(dw_ret_max))
+    range_dict[oname] = ret
+
+
 def get_minmax_prototype_tensors(irange, ishp, inp_vi, i_channel_axis=1):
     proto_min = valueinfo_to_tensor(inp_vi)
     proto_max = valueinfo_to_tensor(inp_vi)
-    if type(irange[0]) in [float, int, np.float32, np.float64, np.uint8, np.int8]:
+    if type(irange[0]) in [float, int, np.float16, np.float32, np.float64, np.uint8, np.int8]:
         imin, imax = irange
         proto_min[...] = imin
         proto_max[...] = imax
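One detail of the new handler worth spelling out (an aside, with made-up shapes): ONNX lays out ConvTranspose weights as (IFM, OFM/groups, kH, kW), with output channels on axis 1 rather than axis 0 as in Conv. The transpose(1, 0, 2, 3) therefore brings output channels to the front before each one is flattened into a row for calculate_matvec_accumulator_extremum, and np.repeat aligns a per-channel input range with the flattened (ifm, kh, kw) blocks:

import numpy as np

# Made-up shapes for illustration only.
ifm, ofm, kh, kw = 3, 8, 2, 2
w = np.random.rand(ifm, ofm, kh, kw)          # ConvTranspose weight layout

w_mat = w.transpose(1, 0, 2, 3).reshape(ofm, -1)
k_total = w_mat.shape[1] // ifm               # kernel positions per input channel
assert w_mat.shape == (ofm, ifm * kh * kw) and k_total == kh * kw

# per-channel input minima, repeated to match the flattened block layout
imin = np.array([-1.0, 0.0, -0.5])
imin_rep = np.repeat(imin, k_total)           # [-1,-1,-1,-1, 0,0,0,0, -0.5,...]
assert imin_rep.shape == (ifm * kh * kw,)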
@@ -211,25 +243,34 @@ def calc_monotonic_range(node, model, range_dict, i_channel_axis=1):
         inp_vi = model.get_tensor_valueinfo(inp)
         proto_vectors.append(get_minmax_prototype_tensors(irange, ishp, inp_vi, i_channel_axis))
     # process all combinations of prototype vectors for dynamic inputs
-    running_min = None
-    running_max = None
+    running_min = [None for i in range(len(node.output))]
+    running_max = [None for i in range(len(node.output))]
     # create context for single-node execution
     ctx = {x: model.get_initializer(x) for x in node.input}
-    ctx[oname] = valueinfo_to_tensor(model.get_tensor_valueinfo(oname))
+    for oname in node.output:
+        ctx[oname] = valueinfo_to_tensor(model.get_tensor_valueinfo(oname))
+    # assume all outputs are homogenous wrt data layout (e.g. channel axis
+    # always lives in the same position)
     axes_to_min = [i for i in range(ctx[oname].ndim)]
     axes_to_min.remove(i_channel_axis)
     axes_to_min = tuple(axes_to_min)
     for inps in itertools.product(*proto_vectors):
         for i in range(n_dyn_inp):
             ctx[dyn_inps[i]] = inps[i]
         execute_node(node, ctx, model.graph, opset_version=opset_version)
-        # grab new output and update running min/max
-        out = ctx[oname]
-        chanwise_min = out.min(axis=axes_to_min).flatten()
-        chanwise_max = out.max(axis=axes_to_min).flatten()
-        running_min = np.minimum(chanwise_min, running_min).flatten() if running_min is not None else chanwise_min
-        running_max = np.maximum(chanwise_max, running_max).flatten() if running_max is not None else chanwise_max
-    range_dict[oname] = (running_min, running_max)
+        for oind, oname in enumerate(node.output):
+            # grab new output and update running min/max
+            out = ctx[oname]
+            chanwise_min = out.min(axis=axes_to_min).flatten()
+            chanwise_max = out.max(axis=axes_to_min).flatten()
+            running_min[oind] = (
+                np.minimum(chanwise_min, running_min[oind]).flatten() if running_min[oind] is not None else chanwise_min
+            )
+            running_max[oind] = (
+                np.maximum(chanwise_max, running_max[oind]).flatten() if running_max[oind] is not None else chanwise_max
+            )
+    for oind, oname in enumerate(node.output):
+        range_dict[oname] = (running_min[oind], running_max[oind])
 
 
 def calc_range_outdtype(node, model, range_dict):
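Why the itertools.product loop above only needs min/max prototype tensors (an aside, not part of the diff): for an op that is monotonic in each input separately, the output extrema over an input box are attained at its corners, so enumerating all min/max combinations is exact. The change simply tracks one running interval per output instead of assuming a single output. A standalone sketch of the corner argument, with hypothetical names:

import itertools
import numpy as np

# Sketch: evaluate f at every corner of the input box and take the
# elementwise extrema; exact when f is monotonic in each input.
def corner_range(f, input_ranges):
    outs = [f(*corner) for corner in itertools.product(*input_ranges)]
    return np.min(outs, axis=0), np.max(outs, axis=0)

lo, hi = corner_range(lambda a, b: a * b, [(-1.0, 2.0), (3.0, 5.0)])
# corners of a*b over [-1,2] x [3,5] are -3, -5, 6, 10 -> lo == -5.0, hi == 10.0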
@@ -240,12 +281,13 @@ def calc_range_outdtype(node, model, range_dict):
 
 
 optype_to_range_calc = {
-    "Transpose": propagate_range,
+    "Transpose": calc_monotonic_range,
     "MatMul": calc_matmul_range,
     "Conv": calc_conv_range,
+    "ConvTranspose": calc_convtranspose_range,
     "QuantMaxNorm": calc_range_outdtype,
-    "Flatten": propagate_range,
-    "Reshape": propagate_range,
+    "Flatten": calc_monotonic_range,
+    "Reshape": calc_monotonic_range,
     "Quant": calc_monotonic_range,
     "BipolarQuant": calc_monotonic_range,
     "Mul": calc_monotonic_range,
@@ -254,7 +296,7 @@ def calc_range_outdtype(node, model, range_dict):
     "Add": calc_monotonic_range,
     "BatchNormalization": calc_monotonic_range,
     "Relu": calc_monotonic_range,
-    "Pad": propagate_range,
+    "Pad": calc_monotonic_range,
     "AveragePool": calc_monotonic_range,
     "Trunc": calc_range_outdtype,
     "MaxPool": calc_monotonic_range,
@@ -267,6 +309,7 @@ def calc_range_outdtype(node, model, range_dict):
     "Clip": calc_monotonic_range,
     "Sigmoid": calc_monotonic_range,
     "Concat": calc_monotonic_range,
+    "Split": calc_monotonic_range,
 }
 
 
@@ -320,8 +363,12 @@ def range_analysis(
             range_min = None
             range_max = None
         else:
-            irange = irange.split(",")
-            range_min, range_max = float(irange[0]), float(irange[1])
+            irange = eval(irange)
+            range_min, range_max = irange
+            if isinstance(range_min, list):
+                range_min = np.asarray(range_min, dtype=np.float32)
+            if isinstance(range_max, list):
+                range_max = np.asarray(range_max, dtype=np.float32)
     elif isinstance(irange, tuple):
         range_min, range_max = irange
     else:
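With the string now parsed via eval rather than split on a comma, per-tensor lists can be supplied in addition to scalar pairs. Hypothetical invocations (assuming the irange keyword is passed through to range_analysis as shown above):

# scalar pair, equivalent to the old "0,1" form:
range_analysis(model, irange="(0.0, 1.0)")

# per-channel ranges as lists; converted to float32 arrays by the code above
range_analysis(model, irange="([-1.0, 0.0, -0.5], [1.0, 2.0, 0.5])")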
@@ -350,9 +397,8 @@ def range_analysis(
     for node in model.graph.node:
         dyn_inputs = [x for x in node.input if is_dyn_input(x, model)]
         inprange_ok = all([x in range_dict.keys() for x in dyn_inputs])
-        outcount_ok = len(node.output) == 1
         op_ok = node.op_type in optype_to_range_calc.keys()
-        if inprange_ok and op_ok and outcount_ok:
+        if inprange_ok and op_ok:
             range_calc_fxn = optype_to_range_calc[node.op_type]
             range_calc_fxn(node, model, range_dict)
             out_range = range_dict[node.output[0]]