Skip to content

Commit 57e5905

Browse files
_QuantizedConv2D and _QuantizedDepthwiseConv2D APIs attribute update (#1120)
1 parent c09e888 commit 57e5905

File tree

3 files changed

+47
-47
lines changed

3 files changed

+47
-47
lines changed

neural_compressor/adaptor/tf_utils/graph_rewriter/int8/fuse_conv_requantize.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def do_transformation(self):
218218

219219
if self.new_api and new_node.op in ('_QuantizedConv2D', '_QuantizedDepthwiseConv2D'):
220220
input_data_type = dtypes.qint8 if new_node.attr["Tinput"].type == dtypes.qint8 else dtypes.quint8
221-
Helper.set_attr_type_list(new_node, 'input_types', [
221+
Helper.set_attr_type_list(new_node, 'Thost_inputs', [
222222
input_data_type.as_datatype_enum,
223223
dtypes.qint8.as_datatype_enum,
224224
dtypes.float32.as_datatype_enum if new_node.attr["Tbias"].type == dtypes.float32 \
@@ -232,12 +232,12 @@ def do_transformation(self):
232232
])
233233

234234
if quantized_node_op not in ('_QuantizedConv2D', '_QuantizedDepthwiseConv2D'):
235-
Helper.set_attr_type_list(new_node, 'out_types', self.output_types)
235+
Helper.set_attr_type_list(new_node, 'Thost_outputs', self.output_types)
236236
new_node.attr["Tsummand"].CopyFrom(attr_value_pb2.AttrValue(type=self.output_types[0]))
237237
else:
238238
if str(quantized_node.attr['fused_ops'].list.s) == str([b"BiasAdd", b"_FusedHardSwish"]):
239239
self.fused_ops= [b"BiasAdd", b"_FusedHardSwish", b"Requantize"]
240-
Helper.set_attr_type_list(new_node, 'out_types', [
240+
Helper.set_attr_type_list(new_node, 'Thost_outputs', [
241241
requantize_node.attr['out_type'].type,
242242
dtypes.float32.as_datatype_enum,
243243
dtypes.float32.as_datatype_enum ])
@@ -247,7 +247,7 @@ def do_transformation(self):
247247
dtype_map_dict[requantize_node.attr['out_type'].type])
248248
elif str(quantized_node.attr['fused_ops'].list.s) == str([b"BiasAdd", b"_FusedSwish"]):
249249
self.fused_ops= [b"BiasAdd", b"_FusedSwish", b"Requantize"]
250-
Helper.set_attr_type_list(new_node, 'out_types', [
250+
Helper.set_attr_type_list(new_node, 'Thost_outputs', [
251251
requantize_node.attr['out_type'].type,
252252
dtypes.float32.as_datatype_enum,
253253
dtypes.float32.as_datatype_enum ])
@@ -257,7 +257,7 @@ def do_transformation(self):
257257
dtype_map_dict[requantize_node.attr['out_type'].type])
258258
elif str(quantized_node.attr['fused_ops'].list.s) == str([b"BiasAdd", b"Relu"]):
259259
self.fused_ops= [b"BiasAdd", b"Relu", b"Requantize"]
260-
Helper.set_attr_type_list(new_node, 'out_types', [
260+
Helper.set_attr_type_list(new_node, 'Thost_outputs', [
261261
requantize_node.attr['out_type'].type,
262262
dtypes.float32.as_datatype_enum,
263263
dtypes.float32.as_datatype_enum ])
@@ -267,7 +267,7 @@ def do_transformation(self):
267267
dtype_map_dict[requantize_node.attr['out_type'].type])
268268
elif str(quantized_node.attr['fused_ops'].list.s) == str([b"BiasAdd", b"LeakyRelu"]):
269269
self.fused_ops= [b"BiasAdd", b"LeakyRelu", b"Requantize"]
270-
Helper.set_attr_type_list(new_node, 'out_types', [
270+
Helper.set_attr_type_list(new_node, 'Thost_outputs', [
271271
requantize_node.attr['out_type'].type,
272272
dtypes.float32.as_datatype_enum,
273273
dtypes.float32.as_datatype_enum ])
@@ -277,7 +277,7 @@ def do_transformation(self):
277277
dtype_map_dict[requantize_node.attr['out_type'].type])
278278
elif str(quantized_node.attr['fused_ops'].list.s) == str([b"BiasAdd", b"Elu"]):
279279
self.fused_ops= [b"BiasAdd", b"Elu", b"Requantize"]
280-
Helper.set_attr_type_list(new_node, 'out_types', [
280+
Helper.set_attr_type_list(new_node, 'Thost_outputs', [
281281
requantize_node.attr['out_type'].type,
282282
dtypes.float32.as_datatype_enum,
283283
dtypes.float32.as_datatype_enum ])
@@ -287,7 +287,7 @@ def do_transformation(self):
287287
dtype_map_dict[requantize_node.attr['out_type'].type])
288288
elif str(quantized_node.attr['fused_ops'].list.s) == str([b"BiasAdd"]):
289289
self.fused_ops= [b"BiasAdd", b"Requantize"]
290-
Helper.set_attr_type_list(new_node, 'out_types', [
290+
Helper.set_attr_type_list(new_node, 'Thost_outputs', [
291291
requantize_node.attr['out_type'].type,
292292
dtypes.float32.as_datatype_enum,
293293
dtypes.float32.as_datatype_enum ])
@@ -296,7 +296,7 @@ def do_transformation(self):
296296
Helper.set_attr_dtype(new_node, "Tsummand", \
297297
dtype_map_dict[requantize_node.attr['out_type'].type])
298298
elif len(quantized_node.attr['fused_ops'].list.s) == 0:
299-
Helper.set_attr_type_list(new_node, 'input_types', [
299+
Helper.set_attr_type_list(new_node, 'Thost_inputs', [
300300
input_data_type.as_datatype_enum,
301301
dtypes.qint8.as_datatype_enum,
302302
#dtypes.float32.as_datatype_enum if new_node.attr["Tbias"].type == dtypes.float32 \
@@ -309,7 +309,7 @@ def do_transformation(self):
309309
dtypes.float32.as_datatype_enum,
310310
])
311311
self.fused_ops= [b"Requantize"]
312-
Helper.set_attr_type_list(new_node, 'out_types', [
312+
Helper.set_attr_type_list(new_node, 'Thost_outputs', [
313313
requantize_node.attr['out_type'].type,
314314
dtypes.float32.as_datatype_enum,
315315
dtypes.float32.as_datatype_enum ])
@@ -326,7 +326,7 @@ def do_transformation(self):
326326
input_data_type = dtypes.qint8 if new_node.attr["Tinput"].type == dtypes.qint8 else dtypes.quint8
327327
if len(quantized_node.attr['fused_ops'].list.s) == 0:
328328
Helper.set_attr_string_list(new_node, 'fused_ops', [ b'Requantize'])
329-
Helper.set_attr_type_list(new_node, 'input_types', [
329+
Helper.set_attr_type_list(new_node, 'Thost_inputs', [
330330
input_data_type.as_datatype_enum,
331331
dtypes.qint8.as_datatype_enum,
332332
dtypes.float32.as_datatype_enum,
@@ -336,7 +336,7 @@ def do_transformation(self):
336336
dtypes.float32.as_datatype_enum,
337337
dtypes.float32.as_datatype_enum,
338338
])
339-
Helper.set_attr_type_list(new_node, 'out_types', [
339+
Helper.set_attr_type_list(new_node, 'Thost_outputs', [
340340
requantize_node.attr['out_type'].type,
341341
dtypes.float32.as_datatype_enum,
342342
dtypes.float32.as_datatype_enum ])
@@ -346,7 +346,7 @@ def do_transformation(self):
346346
dtype_map_dict[requantize_node.attr['out_type'].type])
347347
elif str(quantized_node.attr['fused_ops'].list.s) == str([b"BiasAdd"]):
348348
Helper.set_attr_string_list(new_node, 'fused_ops', [b'BiasAdd', b'Requantize'])
349-
Helper.set_attr_type_list(new_node, 'input_types', [
349+
Helper.set_attr_type_list(new_node, 'Thost_inputs', [
350350
input_data_type.as_datatype_enum,
351351
dtypes.qint8.as_datatype_enum,
352352
dtypes.float32.as_datatype_enum if new_node.attr["Tbias"].type == dtypes.float32 else \
@@ -358,7 +358,7 @@ def do_transformation(self):
358358
dtypes.float32.as_datatype_enum,
359359
dtypes.float32.as_datatype_enum,
360360
])
361-
Helper.set_attr_type_list(new_node, 'out_types', [
361+
Helper.set_attr_type_list(new_node, 'Thost_outputs', [
362362
requantize_node.attr['out_type'].type,
363363
dtypes.float32.as_datatype_enum,
364364
dtypes.float32.as_datatype_enum ])
@@ -368,7 +368,7 @@ def do_transformation(self):
368368
dtype_map_dict[requantize_node.attr['out_type'].type])
369369
elif str(quantized_node.attr['fused_ops'].list.s) == str([b"BiasAdd", b"Relu"]):
370370
Helper.set_attr_string_list(new_node, 'fused_ops', [b'BiasAdd', b'Relu', b'Requantize'])
371-
Helper.set_attr_type_list(new_node, 'input_types', [
371+
Helper.set_attr_type_list(new_node, 'Thost_inputs', [
372372
input_data_type.as_datatype_enum,
373373
dtypes.qint8.as_datatype_enum,
374374
dtypes.float32.as_datatype_enum if new_node.attr["Tbias"].type == dtypes.float32 else \
@@ -380,7 +380,7 @@ def do_transformation(self):
380380
dtypes.float32.as_datatype_enum,
381381
dtypes.float32.as_datatype_enum,
382382
])
383-
Helper.set_attr_type_list(new_node, 'out_types', [
383+
Helper.set_attr_type_list(new_node, 'Thost_outputs', [
384384
requantize_node.attr['out_type'].type,
385385
dtypes.float32.as_datatype_enum,
386386
dtypes.float32.as_datatype_enum ])
@@ -497,7 +497,7 @@ def do_transformation(self):
497497
new_node.ClearField('input')
498498
new_node.input.extend(new_input)
499499
input_data_type = dtypes.qint8 if new_node.attr["Tinput"].type == dtypes.qint8 else dtypes.quint8
500-
Helper.set_attr_type_list(new_node, 'input_types', [
500+
Helper.set_attr_type_list(new_node, 'Thost_inputs', [
501501
input_data_type.as_datatype_enum,
502502
dtypes.qint8.as_datatype_enum,
503503
dtypes.float32.as_datatype_enum,
@@ -515,23 +515,23 @@ def do_transformation(self):
515515
else dtypes.qint8)
516516
if str(quantized_node.attr['fused_ops'].list.s) == str([b'BiasAdd', b'Sum', b'Relu']):
517517
self.fused_ops = [b'BiasAdd', b'Sum', b'Relu', b'Requantize']
518-
Helper.set_attr_type_list(new_node, 'out_types', [
518+
Helper.set_attr_type_list(new_node, 'Thost_outputs', [
519519
requantize_node.attr['out_type'].type,
520520
dtypes.float32.as_datatype_enum,
521521
dtypes.float32.as_datatype_enum ])
522522
Helper.set_attr_dtype(new_node, "out_type", \
523523
dtype_map_dict[requantize_node.attr['out_type'].type])
524524
elif str(quantized_node.attr['fused_ops'].list.s) == str([b'BiasAdd', b'Relu', b'Sum']):
525525
self.fused_ops = [b'BiasAdd', b'Relu', b'Sum', b'Requantize']
526-
Helper.set_attr_type_list(new_node, 'out_types', [
526+
Helper.set_attr_type_list(new_node, 'Thost_outputs', [
527527
requantize_node.attr['out_type'].type,
528528
dtypes.float32.as_datatype_enum,
529529
dtypes.float32.as_datatype_enum ])
530530
Helper.set_attr_dtype(new_node, "out_type", \
531531
dtype_map_dict[requantize_node.attr['out_type'].type])
532532
elif str(quantized_node.attr['fused_ops'].list.s) == str([b'BiasAdd', b'LeakyRelu', b'Sum']):
533533
self.fused_ops = [b'BiasAdd', b'LeakyRelu', b'Sum', b'Requantize']
534-
Helper.set_attr_type_list(new_node, 'out_types', [
534+
Helper.set_attr_type_list(new_node, 'Thost_outputs', [
535535
requantize_node.attr['out_type'].type,
536536
dtypes.float32.as_datatype_enum,
537537
dtypes.float32.as_datatype_enum ])
@@ -543,7 +543,7 @@ def do_transformation(self):
543543

544544
elif str(quantized_node.attr['fused_ops'].list.s) == str([b'BiasAdd', b'Sum']):
545545
self.fused_ops = [b'BiasAdd', b'Sum', b'Requantize']
546-
Helper.set_attr_type_list(new_node, 'out_types', [
546+
Helper.set_attr_type_list(new_node, 'Thost_outputs', [
547547
requantize_node.attr['out_type'].type,
548548
dtypes.float32.as_datatype_enum,
549549
dtypes.float32.as_datatype_enum ])

0 commit comments

Comments
 (0)