
Commit 70a85d4

TFNewAPI support Quantized Matmul, BatchMatmul (#1116)
1 parent: 4f04f0c

9 files changed, +358 -152 lines


neural_compressor/adaptor/inteltensorflow.yaml

Lines changed: 5 additions & 2 deletions
@@ -279,8 +279,11 @@
         'Dequantize + MatMul + BiasAdd + Relu + QuantizeV2',
         'Dequantize + MatMul + BiasAdd + QuantizeV2',
         'Dequantize + MatMul + Relu + QuantizeV2',
-        'Dequantize + BatchMatMulV2',
-        'Dequantize + BatchMatMulV2 + Mul + Add',
+        'Dequantize + BatchMatMulV2 + Mul + QuantizeV2',
+        'Dequantize + BatchMatMulV2 + Add + QuantizeV2',
+        'Dequantize + BatchMatMulV2 + AddV2 + QuantizeV2',
+        'Dequantize + BatchMatMulV2 + Mul + Add + QuantizeV2',
+        'Dequantize + BatchMatMulV2 + Mul + AddV2 + QuantizeV2',
         'Dequantize + Conv3D + AddV2 + AddV2 + Relu + QuantizeV2',
         'Dequantize + Conv3D + Add + Relu + QuantizeV2',
         'Dequantize + Conv3D + AddV2 + Relu + QuantizeV2',
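
Each entry in this list is a fusion pattern: an ordered chain of op types between a Dequantize and, now for BatchMatMulV2 as well, a terminating QuantizeV2. A minimal sketch of how such a pattern string maps onto the op sequence a matcher would look for; the helper name is illustrative and not part of the repository:

# Illustrative only: split a fusion pattern string from inteltensorflow.yaml
# into the ordered list of op types it describes.
def parse_fusion_pattern(pattern):
    return [op.strip() for op in pattern.split('+')]

ops = parse_fusion_pattern('Dequantize + BatchMatMulV2 + Mul + AddV2 + QuantizeV2')
assert ops[0] == 'Dequantize' and ops[-1] == 'QuantizeV2'
print(ops)  # ['Dequantize', 'BatchMatMulV2', 'Mul', 'AddV2', 'QuantizeV2']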

neural_compressor/adaptor/tensorflow.py

Lines changed: 4 additions & 1 deletion
@@ -599,7 +599,8 @@ def _dump_model_op_stats(self, model_graphdef):
         int8_op_prefix_list = ['QuantizedConv2D', '_QuantizedConv3D', 'QuantizedDepthwise',
                                'QuantizedMaxPool', 'QuantizedAvgPool',
                                'QuantizedConcatV2', 'QuantizedMatMul',
-                               '_QuantizedFusedBatchNorm']
+                               '_QuantizedFusedBatchNorm', '_QuantizedMatMul',
+                               '_QuantizedBatchMatMul']
         from tensorflow.python.framework import dtypes

         res = {}
@@ -620,6 +621,8 @@ def _dump_model_op_stats(self, model_graphdef):
                    origin_op_type = 'FusedBatchNormV3'
                if origin_op_type == 'Depthwise':
                    origin_op_type = 'DepthwiseConv2dNative'
+               if origin_op_type == 'BatchMatMul':
+                   origin_op_type = 'BatchMatMulV2'
                res[origin_op_type]['INT8'] += 1

            if i.op in fp32_op_list:
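
_dump_model_op_stats walks the GraphDef, matches each op against the quantized-op prefixes above, and folds the count back onto the original FP32 op name; the added branch maps the fused _QuantizedBatchMatMul back to BatchMatMulV2. A hedged, stand-alone sketch of that mapping (the function is illustrative, not the repository's code):

# Illustrative reduction of the origin-type mapping used when counting INT8 ops.
def origin_op_type_of(quantized_op):
    origin = quantized_op.lstrip('_').replace('Quantized', '', 1)
    if origin == 'FusedBatchNorm':
        origin = 'FusedBatchNormV3'
    if origin == 'Depthwise':
        origin = 'DepthwiseConv2dNative'
    if origin == 'BatchMatMul':      # added in this commit
        origin = 'BatchMatMulV2'
    return origin

assert origin_op_type_of('_QuantizedBatchMatMul') == 'BatchMatMulV2'
assert origin_op_type_of('_QuantizedMatMul') == 'MatMul'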

neural_compressor/adaptor/tf_utils/graph_converter.py

Lines changed: 14 additions & 12 deletions
@@ -56,6 +56,8 @@
 from .graph_rewriter.int8.fuse_conv_requantize import FuseConvRequantizeTransformer
 from .graph_rewriter.int8.fuse_matmul_requantize import FuseMatMulRequantizeTransformer
 from .graph_rewriter.int8.fuse_matmul_requantize import FuseMatMulRequantizeDequantizeTransformer
+from .graph_rewriter.int8.fuse_matmul_requantize import FuseMatMulRequantizeNewAPITransformer
+from .graph_rewriter.int8.fuse_matmul_requantize import FuseMatMulRequantizeDequantizeNewAPITransformer
 from .graph_rewriter.int8.scale_propagation import ScaleProPagationTransformer
 from .graph_rewriter.bf16.bf16_convert import BF16Convert
 from .graph_rewriter.int8.post_quantized_op_cse import PostCseOptimizer
@@ -547,7 +549,7 @@ def _freeze_requantization_ranges(self, additional_data=None):
         self.scale_info.update(requant_min_max)

         self._tmp_graph_def = QuantizedRNNConverter(
-            self._tmp_graph_def, self._calibration_data, self._rnn_details).do_transformation()
+            self._tmp_graph_def, self._calibration_data, self._rnn_details, self.new_api).do_transformation()

         if 'scale_propagation_max_pooling' in self.recipes and \
                 self.recipes['scale_propagation_max_pooling']:
@@ -570,18 +572,18 @@ def _fuse_requantize_with_fused_quantized_node(self):

         if not self.fake_quant:
             # TODO Use MatMul and BatchMatMul new API
-            #if self.qdq_enabled:
-            #    self._tmp_graph_def = FuseMatMulRequantizeNewAPITransformer(
-            #        self._tmp_graph_def).do_transformation()
-            #
-            #    self._tmp_graph_def = FuseMatMulRequantizeDequantizeNewAPITransformer(
-            #        self._tmp_graph_def).do_transformation()
-            #else:
-            self._tmp_graph_def = FuseMatMulRequantizeTransformer(
-                self._tmp_graph_def).do_transformation()
+            if self.qdq_enabled:
+                self._tmp_graph_def = FuseMatMulRequantizeNewAPITransformer(
+                    self._tmp_graph_def).do_transformation()
+
+                self._tmp_graph_def = FuseMatMulRequantizeDequantizeNewAPITransformer(
+                    self._tmp_graph_def).do_transformation()
+            else:
+                self._tmp_graph_def = FuseMatMulRequantizeTransformer(
+                    self._tmp_graph_def).do_transformation()

-            self._tmp_graph_def = FuseMatMulRequantizeDequantizeTransformer(
-                self._tmp_graph_def).do_transformation()
+                self._tmp_graph_def = FuseMatMulRequantizeDequantizeTransformer(
+                    self._tmp_graph_def).do_transformation()

         self._tmp_graph_def = StripUnusedNodesOptimizer(
             self._tmp_graph_def,
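
With qdq_enabled (the QDQ / TF new-API path) the MatMul requantize fusion now routes through the NewAPI transformers, while the legacy transformer pair is kept for the old path. All of these rewriters share the same contract: construct with a GraphDef, call do_transformation(), get a rewritten GraphDef back. A hedged sketch of the selection, using the imports added above; the wrapper function itself is illustrative:

from neural_compressor.adaptor.tf_utils.graph_rewriter.int8.fuse_matmul_requantize import (
    FuseMatMulRequantizeTransformer,
    FuseMatMulRequantizeDequantizeTransformer,
    FuseMatMulRequantizeNewAPITransformer,
    FuseMatMulRequantizeDequantizeNewAPITransformer)

def fuse_matmul_requantize(graph_def, qdq_enabled):
    # Illustrative wrapper mirroring the branch in _fuse_requantize_with_fused_quantized_node.
    if qdq_enabled:   # TF new API: fused _QuantizedMatMul
        graph_def = FuseMatMulRequantizeNewAPITransformer(graph_def).do_transformation()
        graph_def = FuseMatMulRequantizeDequantizeNewAPITransformer(graph_def).do_transformation()
    else:             # legacy QuantizedMatMul* ops
        graph_def = FuseMatMulRequantizeTransformer(graph_def).do_transformation()
        graph_def = FuseMatMulRequantizeDequantizeTransformer(graph_def).do_transformation()
    return graph_def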

neural_compressor/adaptor/tf_utils/graph_rewriter/bf16/bf16_convert.py

Lines changed: 5 additions & 3 deletions
@@ -129,11 +129,10 @@ def _bf16_convert(self, bf16_node_name):
            return
        else:
            self.converted_ops.append(bf16_node.name)
-
+
        inputs_dt, outputs_dt = self._dtype(bf16_node)
        inputs_dt_val, outputs_dt_val = self._dtype_val(bf16_node)
        allowed_dt_val = self._allowed_dtype_val(bf16_node)
-
        for index, input_name in enumerate(bf16_node.input):
            if input_name.startswith('^'):
                continue
@@ -142,7 +141,6 @@ def _bf16_convert(self, bf16_node_name):
                                      input_name)]
            input_node = input_detail.node
            input_node_outputs = input_detail.outputs
-
            if inputs_dt[index] in allowed_dt_val and \
                dtypes.bfloat16.as_datatype_enum not in allowed_dt_val[inputs_dt[index]]:
                continue
@@ -239,6 +237,10 @@ def _model_bf16_convert(self):
            if bf16_node_name not in self.cur_graph.node_name_details:
                self.bf16_ops.remove(bf16_node_name)
                continue
+           else:
+               if "fused_ops" in self.cur_graph.node_name_details[bf16_node_name].node.attr:
+                   self.bf16_ops.remove(bf16_node_name)
+                   continue
        for bf16_node_name in set(self.bf16_ops):
            self._bf16_convert(bf16_node_name)
        return self.cur_graph.dump_graph()
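
The added else branch keeps nodes that already carry a fused_ops attr (the new fused quantized ops such as _QuantizedMatMul) out of the BF16 conversion list. A small hedged sketch of that check on a raw NodeDef, using TensorFlow's protobuf API; the node contents are illustrative:

from tensorflow.core.framework import node_def_pb2

def has_fused_ops(node):
    # Nodes carrying a 'fused_ops' attr are skipped by the BF16 converter.
    return "fused_ops" in node.attr

node = node_def_pb2.NodeDef(name="dense/matmul", op="_QuantizedMatMul")
node.attr["fused_ops"].list.s.extend([b"BiasAdd", b"Dequantize"])
assert has_fused_ops(node)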

neural_compressor/adaptor/tf_utils/graph_rewriter/int8/fuse_matmul_requantize.py

Lines changed: 43 additions & 28 deletions
@@ -261,8 +261,8 @@ def do_transformation(self):

        return self.graph_analyzer.dump_graph()

-class FuseMatMulRequantizeDequantizeNewAPITransformer(GraphRewriterBase): # pragma: no cover
-    """Fuse _QuantizedFusedMatMul + Requantize + Dequantize into _QuantizedFusedMatMulAndDequantize.
+class FuseMatMulRequantizeDequantizeNewAPITransformer(GraphRewriterBase):
+    """Fuse _QuantizedMatMul + Requantize + Dequantize into _QuantizedMatMul.
    """
    def __init__(self, model, device='cpu'):
        super().__init__(model)
@@ -275,20 +275,13 @@ def __init__(self, model, device='cpu'):
        self.eps = 1e-5

    def do_transformation(self):
-        fuse_pattern = [["_QuantizedFusedMatMul"], ['Requantize'], ['Dequantize'], ('Softmax',)]
+        fuse_pattern = [["_QuantizedMatMul"], ['Requantize'], ['Dequantize'], ('Softmax',)]

        target_nodes = self.graph_analyzer.query_fusion_pattern_nodes(fuse_pattern)
        for i in target_nodes:
-            # TODO Remove below checker once the TF's limitation removed.
-            if len(i) == 5:
-                continue
-
            quantized_node_name = i[0]
            quantized_node = self.graph_info[quantized_node_name].node
            requantize_node_name = i[1]
-            requantize_node = self.graph_info[requantize_node_name].node
-            requested_output_min_name = requantize_node.input[3]
-            requested_output_max_name = requantize_node.input[4]
            deq_node_name = i[2]

            quantized_node_op = i[-1][0]
@@ -299,26 +292,30 @@ def do_transformation(self):

            new_node = node_def_pb2.NodeDef()

-            new_node.op = quantized_node_op + "AndDequantize"
+            new_node.op = quantized_node_op
            new_node.name = requantize_node_name
            for _, value in enumerate(quantized_node.input):
                new_node.input.append(value)

-            #new_node.input.append(requested_output_min_name)
-            #new_node.input.append(requested_output_max_name)
            if 'T1' in quantized_node.attr:
                new_node.attr["T1"].CopyFrom(quantized_node.attr['T1'])
            if 'T2' in quantized_node.attr:
                new_node.attr["T2"].CopyFrom(quantized_node.attr['T2'])
-            if 'num_args' in quantized_node.attr:
-                new_node.attr["num_args"].CopyFrom(quantized_node.attr['num_args'])
+            if 'Tbias' in quantized_node.attr:
+                new_node.attr["Tbias"].CopyFrom(quantized_node.attr['Tbias'])
            if 'fused_ops' in quantized_node.attr:
                new_node.attr["fused_ops"].CopyFrom(quantized_node.attr["fused_ops"])
-
+            if 'input_quant_mode' in quantized_node.attr:
+                new_node.attr["input_quant_mode"].CopyFrom(quantized_node.attr["input_quant_mode"])
+            if 'output_quant_mode' in quantized_node.attr:
+                new_node.attr["output_quant_mode"].CopyFrom(quantized_node.attr["output_quant_mode"])
+            if 'Thost_inputs' in quantized_node.attr:
+                new_node.attr["Thost_inputs"].CopyFrom(quantized_node.attr["Thost_inputs"])
+            Helper.set_attr_type_list(new_node, 'Thost_outputs', [dtypes.float32.as_datatype_enum])
+            Helper.set_attr_string_list(new_node, 'fused_ops', [b'BiasAdd', b'Dequantize'])
            top_node_name = Helper.node_name_from_input(quantized_node.input[0])
            float32_type = dtypes.float32.as_datatype_enum
-            new_node.attr["Targs"].CopyFrom(attr_value_pb2.AttrValue(type=float32_type))
-            new_node.attr["Toutput"].CopyFrom(attr_value_pb2.AttrValue(type=float32_type))
+            new_node.attr["Tout"].CopyFrom(attr_value_pb2.AttrValue(type=float32_type))

            self.graph_analyzer.remove_node(requantize_node_name)

@@ -338,7 +335,7 @@ def do_transformation(self):

        return self.graph_analyzer.dump_graph()

-class FuseMatMulRequantizeNewAPITransformer(GraphRewriterBase): # pragma: no cover
+class FuseMatMulRequantizeNewAPITransformer(GraphRewriterBase):
    """Fuse newAPI Quantized MatMul Op with the successor Requantize Op.
    """
    def __init__(self, model, device='cpu'):
@@ -358,7 +355,7 @@ def do_transformation(self):

        while True:
            target_nodes = self.graph_analyzer.query_fusion_pattern_nodes(
-                [["_QuantizedFusedMatMul"], ['Requantize']])
+                [["_QuantizedMatMul"], ['Requantize']])
            if len(target_nodes) == 0:
                break

@@ -377,23 +374,41 @@ def do_transformation(self):

            new_node = node_def_pb2.NodeDef()

-            new_node.op = quantized_node_op + "AndRequantize"
+            new_node.op = quantized_node_op
            new_node.name = requantize_node_name
            for _, value in enumerate(quantized_node.input):
                new_node.input.append(value)
            new_node.input.append(requested_output_min_name)
            new_node.input.append(requested_output_max_name)
+
            if 'T1' in quantized_node.attr:
                new_node.attr["T1"].CopyFrom(quantized_node.attr['T1'])
            if 'T2' in quantized_node.attr:
                new_node.attr["T2"].CopyFrom(quantized_node.attr['T2'])
-            if 'num_args' in quantized_node.attr:
-                new_node.attr["num_args"].CopyFrom(quantized_node.attr["num_args"])
-            if 'Targs' in quantized_node.attr:
-                new_node.attr["Targs"].CopyFrom(quantized_node.attr["Targs"])
-            if 'fused_ops' in quantized_node.attr:
-                new_node.attr["fused_ops"].CopyFrom(quantized_node.attr["fused_ops"])
-            new_node.attr["Toutput"].CopyFrom(attr_value_pb2.AttrValue(type=uint8_type))
+            if 'Tbias' in quantized_node.attr:
+                new_node.attr["Tbias"].CopyFrom(quantized_node.attr["Targs"])
+            if 'U' in quantized_node.attr:
+                new_node.attr["U"].CopyFrom(quantized_node.attr["U"])
+            if 'input_quant_mode' in quantized_node.attr:
+                new_node.attr["input_quant_mode"].CopyFrom(quantized_node.attr["input_quant_mode"])
+            if 'output_quant_mode' in quantized_node.attr:
+                new_node.attr["output_quant_mode"].CopyFrom(quantized_node.attr["output_quant_mode"])
+            Helper.set_attr_type_list(new_node, "Thost_inputs", [
+                dtypes.quint8.as_datatype_enum,
+                dtypes.qint8.as_datatype_enum,
+                dtypes.float32.as_datatype_enum,
+                dtypes.float32.as_datatype_enum,
+                dtypes.float32.as_datatype_enum,
+                dtypes.float32.as_datatype_enum,
+                dtypes.float32.as_datatype_enum,
+                dtypes.float32.as_datatype_enum,
+                dtypes.float32.as_datatype_enum])
+            Helper.set_attr_type_list(new_node, 'Thost_outputs', [
+                dtypes.quint8.as_datatype_enum,
+                dtypes.float32.as_datatype_enum,
+                dtypes.float32.as_datatype_enum])
+            Helper.set_attr_string_list(new_node, 'fused_ops', [b'BiasAdd', b'Relu', b'Requantize'])
+            new_node.attr["Tout"].CopyFrom(attr_value_pb2.AttrValue(type=uint8_type))

            parent_node_name = Helper.node_name_from_input(quantized_node.input[0])
            self.graph_analyzer.replace_single_node(
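
The key change in this file: instead of renaming the op to an "AndRequantize" / "AndDequantize" variant, the fusion keeps the single _QuantizedMatMul op and encodes what was fused in the fused_ops string list plus the Thost_inputs / Thost_outputs type lists. A hedged sketch of the resulting NodeDef for the Requantize case, built with the raw protobuf API rather than the repository's Helper utilities; input names are placeholders:

from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import dtypes

node = node_def_pb2.NodeDef(name="matmul_requantized", op="_QuantizedMatMul")
# Nine host inputs: activation, weight, bias, activation min/max, weight min/max,
# requested output min/max (placeholder tensor names).
node.input.extend(["x", "w", "b", "x_min", "x_max", "w_min", "w_max",
                   "requested_out_min", "requested_out_max"])
node.attr["T1"].type = dtypes.quint8.as_datatype_enum    # activation dtype
node.attr["T2"].type = dtypes.qint8.as_datatype_enum     # weight dtype
node.attr["Tout"].type = dtypes.quint8.as_datatype_enum  # requantized output dtype
node.attr["fused_ops"].list.s.extend([b"BiasAdd", b"Relu", b"Requantize"])
node.attr["Thost_inputs"].list.type.extend(
    [dtypes.quint8.as_datatype_enum, dtypes.qint8.as_datatype_enum]
    + [dtypes.float32.as_datatype_enum] * 7)
node.attr["Thost_outputs"].list.type.extend(
    [dtypes.quint8.as_datatype_enum,
     dtypes.float32.as_datatype_enum, dtypes.float32.as_datatype_enum])
print(node)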

neural_compressor/adaptor/tf_utils/graph_rewriter/int8/rnn_convert.py

Lines changed: 37 additions & 7 deletions
@@ -30,10 +30,11 @@


class QuantizedRNNConverter(GraphRewriterBase):
-    def __init__(self, model, calibration_data, rnn_details):
+    def __init__(self, model, calibration_data, rnn_details, new_api=False):
        super().__init__(model)
        self.calibration_data = calibration_data
        self.rnn_details = rnn_details
+        self.new_api = new_api

    @dump_elapsed_time("Pass QuantizedRNNConverter")
    def do_transformation(self):
@@ -207,22 +208,51 @@ def do_transformation(self):

                quantized_matmul_input.append(enter_min_node.name)
                quantized_matmul_input.append(enter_max_node.name)
-                quantized_matmul_with_bias_node = Helper.create_node(
-                    'QuantizedMatMulWithBias', i[0] + '_quantized_mat_mul', quantized_matmul_input)
+                if self.new_api:
+                    quantized_matmul_with_bias_node = Helper.create_node(
+                        '_QuantizedMatMul', i[0] + '_quantized_mat_mul', quantized_matmul_input)
+                else:
+                    quantized_matmul_with_bias_node = Helper.create_node(
+                        'QuantizedMatMulWithBias', i[0] + '_quantized_mat_mul', quantized_matmul_input)
                Helper.set_attr_dtype(
                    quantized_matmul_with_bias_node, 'T1', dtypes.quint8)
                Helper.set_attr_dtype(
                    quantized_matmul_with_bias_node, 'T2', dtypes.qint8)
                Helper.set_attr_dtype(
                    quantized_matmul_with_bias_node, 'Tbias', dtypes.float32)
-                Helper.set_attr_dtype(
-                    quantized_matmul_with_bias_node, 'Toutput', dtypes.qint32)
+                if self.new_api:
+                    Helper.set_attr_dtype(
+                        quantized_matmul_with_bias_node, 'Tout', dtypes.qint32)
+                else:
+                    Helper.set_attr_dtype(
+                        quantized_matmul_with_bias_node, 'Toutput', dtypes.qint32)
                Helper.set_attr_bool(
                    quantized_matmul_with_bias_node, 'transpose_a', False)
                Helper.set_attr_bool(
                    quantized_matmul_with_bias_node, 'transpose_b', False)
-                Helper.set_attr_string(
-                    quantized_matmul_with_bias_node, 'input_quant_mode', b"MIN_FIRST")
+                if self.new_api:
+                    Helper.set_attr_string(
+                        quantized_matmul_with_bias_node, 'input_quant_mode', b"SCALED")
+                    Helper.set_attr_string(
+                        quantized_matmul_with_bias_node, 'output_quant_mode', b"SCALED")
+                    Helper.set_attr_string_list(quantized_matmul_with_bias_node, 'fused_ops', [b'BiasAdd'])
+                    Helper.set_attr_type_list(quantized_matmul_with_bias_node, 'Thost_inputs', [
+                        dtypes.quint8.as_datatype_enum,
+                        dtypes.qint8.as_datatype_enum,
+                        dtypes.float32.as_datatype_enum,
+                        dtypes.float32.as_datatype_enum,
+                        dtypes.float32.as_datatype_enum,
+                        dtypes.float32.as_datatype_enum,
+                        dtypes.float32.as_datatype_enum
+                    ])
+                    Helper.set_attr_type_list(quantized_matmul_with_bias_node, 'Thost_outputs', [
+                        dtypes.qint32.as_datatype_enum,
+                        dtypes.float32.as_datatype_enum,
+                        dtypes.float32.as_datatype_enum])
+                else:
+                    Helper.set_attr_string(
+                        quantized_matmul_with_bias_node, 'input_quant_mode', b"MIN_FIRST")
+
                g.add_node(quantized_matmul_with_bias_node,
                           quantize_node.name, [bias_node.name])
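
For the RNN internal MatMul the new-API node switches to the SCALED quantization modes and uses Tout instead of Toutput; the Thost_inputs / Thost_outputs lists describe its seven inputs and three outputs. A hedged reading of that layout follows; the per-slot comments are interpretation, not taken from the source:

from tensorflow.python.framework import dtypes

thost_inputs = [dtypes.quint8,                    # quantized activation
                dtypes.qint8,                     # weight
                dtypes.float32,                   # bias
                dtypes.float32, dtypes.float32,   # activation min / max
                dtypes.float32, dtypes.float32]   # weight min / max
thost_outputs = [dtypes.qint32,                   # accumulated output
                 dtypes.float32, dtypes.float32]  # output min / max
assert len(thost_inputs) == 7 and len(thost_outputs) == 3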

neural_compressor/adaptor/tf_utils/graph_rewriter/qdq/insert_qdq_pattern.py

Lines changed: 3 additions & 2 deletions
@@ -157,9 +157,10 @@ def _insert_qdq_pattern_for_common_ops(self, original_node, is_asymmetric):
            if each_input_name[0] == '^':
                continue

-            if self.node_name_mapping[original_node.name].op == "MatMul" or \
-                self.node_name_mapping[original_node.name].op == "BatchMatMulV2":
+            if self.node_name_mapping[original_node.name].op == "MatMul":
                dtype = dtypes.quint8
+            elif self.node_name_mapping[original_node.name].op == "BatchMatMulV2":
+                dtype = dtypes.qint8
            else:
                input_node_name = Helper.node_name_from_input(each_input_name)
                if input_node_name in self.graph_info:
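
The split above means MatMul activations keep being quantized to unsigned quint8, while BatchMatMulV2 inputs are now quantized to signed qint8. A minimal stand-alone illustration of that dtype choice using TensorFlow's public QuantizeV2 op; the tensors and ranges are made up, and the adaptor itself inserts these nodes during the QDQ pass:

import tensorflow as tf

x = tf.constant([[0.5, 1.5], [0.25, 2.0]])
q_matmul_in = tf.raw_ops.QuantizeV2(input=x, min_range=0.0, max_range=2.0, T=tf.quint8)
q_bmm_in = tf.raw_ops.QuantizeV2(input=x, min_range=-2.0, max_range=2.0, T=tf.qint8)
print(q_matmul_in.output.dtype, q_bmm_in.output.dtype)  # quint8, qint8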
