@@ -261,8 +261,8 @@ def do_transformation(self):
 
         return self.graph_analyzer.dump_graph()
 
 
-class FuseMatMulRequantizeDequantizeNewAPITransformer(GraphRewriterBase): # pragma: no cover
-    """Fuse _QuantizedFusedMatMul + Requantize + Dequantize into _QuantizedFusedMatMulAndDequantize.
+class FuseMatMulRequantizeDequantizeNewAPITransformer(GraphRewriterBase):
+    """Fuse _QuantizedMatMul + Requantize + Dequantize into _QuantizedMatMul.
     """
     def __init__(self, model, device='cpu'):
         super().__init__(model)
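
For reference, a minimal usage sketch of the renamed transformer (illustrative, not part of this patch): it assumes `model` is a frozen GraphDef, as the constructor and `do_transformation` above suggest, and the import path is only a guess at this module's location.

# Hypothetical usage sketch; import path and graph_def variable are illustrative.
from neural_compressor.adaptor.tf_utils.graph_rewriter.int8.fuse_matmul_requantize import (
    FuseMatMulRequantizeDequantizeNewAPITransformer)

def fuse_matmul_requantize_dequantize(graph_def):
    # Rewrites matching _QuantizedMatMul -> Requantize -> Dequantize chains
    # and returns the rewritten GraphDef.
    return FuseMatMulRequantizeDequantizeNewAPITransformer(graph_def).do_transformation()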
@@ -275,20 +275,13 @@ def __init__(self, model, device='cpu'):
         self.eps = 1e-5
 
     def do_transformation(self):
-        fuse_pattern = [["_QuantizedFusedMatMul"], ['Requantize'], ['Dequantize'], ('Softmax',)]
+        fuse_pattern = [["_QuantizedMatMul"], ['Requantize'], ['Dequantize'], ('Softmax',)]
 
         target_nodes = self.graph_analyzer.query_fusion_pattern_nodes(fuse_pattern)
         for i in target_nodes:
-            # TODO Remove below checker once the TF's limitation removed.
-            if len(i) == 5:
-                continue
-
             quantized_node_name = i[0]
             quantized_node = self.graph_info[quantized_node_name].node
             requantize_node_name = i[1]
-            requantize_node = self.graph_info[requantize_node_name].node
-            requested_output_min_name = requantize_node.input[3]
-            requested_output_max_name = requantize_node.input[4]
             deq_node_name = i[2]
 
             quantized_node_op = i[-1][0]
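
The indexing above (i[0], i[1], i[2], i[-1][0]) relies on query_fusion_pattern_nodes returning, for each match, the matched node names in pattern order followed by a list of their op types; a hypothetical match entry for illustration (node names invented):

# Hypothetical match entry; node names are invented for illustration.
match = ['dense/MatMul_eightbit_quantized',    # i[0]: the _QuantizedMatMul node
         'dense/MatMul_eightbit_requantize',   # i[1]: the Requantize node to fold
         'dense/MatMul_eightbit_dequantize',   # i[2]: the Dequantize node to fold
         ['_QuantizedMatMul', 'Requantize', 'Dequantize']]  # i[-1]: matched op types
quantized_node_op = match[-1][0]               # '_QuantizedMatMul'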
@@ -299,26 +292,30 @@ def do_transformation(self):
 
             new_node = node_def_pb2.NodeDef()
 
-            new_node.op = quantized_node_op + "AndDequantize"
+            new_node.op = quantized_node_op
             new_node.name = requantize_node_name
             for _, value in enumerate(quantized_node.input):
                 new_node.input.append(value)
 
-            #new_node.input.append(requested_output_min_name)
-            #new_node.input.append(requested_output_max_name)
             if 'T1' in quantized_node.attr:
                 new_node.attr["T1"].CopyFrom(quantized_node.attr['T1'])
             if 'T2' in quantized_node.attr:
                 new_node.attr["T2"].CopyFrom(quantized_node.attr['T2'])
-            if 'num_args' in quantized_node.attr:
-                new_node.attr["num_args"].CopyFrom(quantized_node.attr['num_args'])
+            if 'Tbias' in quantized_node.attr:
+                new_node.attr["Tbias"].CopyFrom(quantized_node.attr['Tbias'])
             if 'fused_ops' in quantized_node.attr:
                 new_node.attr["fused_ops"].CopyFrom(quantized_node.attr["fused_ops"])
-
+            if 'input_quant_mode' in quantized_node.attr:
+                new_node.attr["input_quant_mode"].CopyFrom(quantized_node.attr["input_quant_mode"])
+            if 'output_quant_mode' in quantized_node.attr:
+                new_node.attr["output_quant_mode"].CopyFrom(quantized_node.attr["output_quant_mode"])
+            if 'Thost_inputs' in quantized_node.attr:
+                new_node.attr["Thost_inputs"].CopyFrom(quantized_node.attr["Thost_inputs"])
+            Helper.set_attr_type_list(new_node, 'Thost_outputs', [dtypes.float32.as_datatype_enum])
+            Helper.set_attr_string_list(new_node, 'fused_ops', [b'BiasAdd', b'Dequantize'])
             top_node_name = Helper.node_name_from_input(quantized_node.input[0])
             float32_type = dtypes.float32.as_datatype_enum
-            new_node.attr["Targs"].CopyFrom(attr_value_pb2.AttrValue(type=float32_type))
-            new_node.attr["Toutput"].CopyFrom(attr_value_pb2.AttrValue(type=float32_type))
+            new_node.attr["Tout"].CopyFrom(attr_value_pb2.AttrValue(type=float32_type))
 
             self.graph_analyzer.remove_node(requantize_node_name)
 
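
Helper.set_attr_type_list and Helper.set_attr_string_list are not defined in these hunks; as a rough sketch, assuming they simply fill the list field of a TensorFlow AttrValue proto (the real helpers may differ in detail), they amount to:

from tensorflow.core.framework import attr_value_pb2

def set_attr_type_list(node, key, dtype_enums):
    # list(type) attribute, e.g. Thost_outputs = [DT_FLOAT]
    node.attr[key].CopyFrom(attr_value_pb2.AttrValue(
        list=attr_value_pb2.AttrValue.ListValue(type=dtype_enums)))

def set_attr_string_list(node, key, values):
    # list(string) attribute, e.g. fused_ops = [b'BiasAdd', b'Dequantize']
    node.attr[key].CopyFrom(attr_value_pb2.AttrValue(
        list=attr_value_pb2.AttrValue.ListValue(s=values)))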
@@ -338,7 +335,7 @@ def do_transformation(self):
 
         return self.graph_analyzer.dump_graph()
 
-class FuseMatMulRequantizeNewAPITransformer(GraphRewriterBase): # pragma: no cover
+class FuseMatMulRequantizeNewAPITransformer(GraphRewriterBase):
     """Fuse newAPI Quantized MatMul Op with the successor Requantize Op.
     """
     def __init__(self, model, device='cpu'):
@@ -358,7 +355,7 @@ def do_transformation(self):
 
         while True:
             target_nodes = self.graph_analyzer.query_fusion_pattern_nodes(
-                [["_QuantizedFusedMatMul"], ['Requantize']])
+                [["_QuantizedMatMul"], ['Requantize']])
             if len(target_nodes) == 0:
                 break
 
@@ -377,23 +374,41 @@ def do_transformation(self):
 
             new_node = node_def_pb2.NodeDef()
 
-            new_node.op = quantized_node_op + "AndRequantize"
+            new_node.op = quantized_node_op
             new_node.name = requantize_node_name
             for _, value in enumerate(quantized_node.input):
                 new_node.input.append(value)
             new_node.input.append(requested_output_min_name)
             new_node.input.append(requested_output_max_name)
+
             if 'T1' in quantized_node.attr:
                 new_node.attr["T1"].CopyFrom(quantized_node.attr['T1'])
             if 'T2' in quantized_node.attr:
                 new_node.attr["T2"].CopyFrom(quantized_node.attr['T2'])
-            if 'num_args' in quantized_node.attr:
-                new_node.attr["num_args"].CopyFrom(quantized_node.attr["num_args"])
-            if 'Targs' in quantized_node.attr:
-                new_node.attr["Targs"].CopyFrom(quantized_node.attr["Targs"])
-            if 'fused_ops' in quantized_node.attr:
-                new_node.attr["fused_ops"].CopyFrom(quantized_node.attr["fused_ops"])
-            new_node.attr["Toutput"].CopyFrom(attr_value_pb2.AttrValue(type=uint8_type))
+            if 'Tbias' in quantized_node.attr:
+                new_node.attr["Tbias"].CopyFrom(quantized_node.attr["Targs"])
+            if 'U' in quantized_node.attr:
+                new_node.attr["U"].CopyFrom(quantized_node.attr["U"])
+            if 'input_quant_mode' in quantized_node.attr:
+                new_node.attr["input_quant_mode"].CopyFrom(quantized_node.attr["input_quant_mode"])
+            if 'output_quant_mode' in quantized_node.attr:
+                new_node.attr["output_quant_mode"].CopyFrom(quantized_node.attr["output_quant_mode"])
+            Helper.set_attr_type_list(new_node, "Thost_inputs", [
+                dtypes.quint8.as_datatype_enum,
+                dtypes.qint8.as_datatype_enum,
+                dtypes.float32.as_datatype_enum,
+                dtypes.float32.as_datatype_enum,
+                dtypes.float32.as_datatype_enum,
+                dtypes.float32.as_datatype_enum,
+                dtypes.float32.as_datatype_enum,
+                dtypes.float32.as_datatype_enum,
+                dtypes.float32.as_datatype_enum])
+            Helper.set_attr_type_list(new_node, 'Thost_outputs', [
+                dtypes.quint8.as_datatype_enum,
+                dtypes.float32.as_datatype_enum,
+                dtypes.float32.as_datatype_enum])
+            Helper.set_attr_string_list(new_node, 'fused_ops', [b'BiasAdd', b'Relu', b'Requantize'])
+            new_node.attr["Tout"].CopyFrom(attr_value_pb2.AttrValue(type=uint8_type))
 
             parent_node_name = Helper.node_name_from_input(quantized_node.input[0])
             self.graph_analyzer.replace_single_node(
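
Read together with the Dequantize hunk above, the attribute rewrites suggest the following old-to-new mapping for the unified _QuantizedMatMul op; this is inferred from the diff itself, not from an op specification.

# Inferred from the rewrites in this diff; not an authoritative op spec.
OLD_TO_NEW_ATTRS = {
    'Targs': 'Tbias',    # bias dtype (new Tbias is copied from the old Targs value)
    'Toutput': 'Tout',   # output dtype
    'num_args': None,    # dropped; input dtypes are now enumerated in Thost_inputs
}
# fused_ops now also names the terminal op, e.g. [b'BiasAdd', b'Relu', b'Requantize']
# or [b'BiasAdd', b'Dequantize'], instead of encoding it in the op name
# ("...AndRequantize" / "...AndDequantize").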