Commit cca21bb

Enable transpose_b=true in new matmul API (#1129)
1 parent 61c72eb commit cca21bb

5 files changed: 24 additions & 22 deletions

neural_compressor/adaptor/tf_utils/graph_rewriter/int8/fuse_matmul_requantize.py

Lines changed: 8 additions & 2 deletions
@@ -276,7 +276,6 @@ def __init__(self, model, device='cpu'):
 
     def do_transformation(self):
        fuse_pattern = [["_QuantizedMatMul"], ['Requantize'], ['Dequantize'], ('Softmax',)]
-
        target_nodes = self.graph_analyzer.query_fusion_pattern_nodes(fuse_pattern)
        for i in target_nodes:
            quantized_node_name = i[0]
@@ -301,6 +300,10 @@ def do_transformation(self):
                new_node.attr["T1"].CopyFrom(quantized_node.attr['T1'])
            if 'T2' in quantized_node.attr:
                new_node.attr["T2"].CopyFrom(quantized_node.attr['T2'])
+           if 'transpose_b' in quantized_node.attr:
+               new_node.attr["transpose_b"].CopyFrom(quantized_node.attr['transpose_b'])
+           if 'transpose_a' in quantized_node.attr:
+               new_node.attr["transpose_a"].CopyFrom(quantized_node.attr['transpose_a'])
            if 'Tbias' in quantized_node.attr:
                new_node.attr["Tbias"].CopyFrom(quantized_node.attr['Tbias'])
            if 'fused_ops' in quantized_node.attr:
@@ -358,7 +361,6 @@ def do_transformation(self):
                [["_QuantizedMatMul"], ['Requantize']])
            if len(target_nodes) == 0:
                break
-
            i = target_nodes[0]
            quantized_node_name = i[0]
            quantized_node = self.graph_info[quantized_node_name].node
@@ -381,6 +383,10 @@ def do_transformation(self):
            new_node.input.append(requested_output_min_name)
            new_node.input.append(requested_output_max_name)
 
+           if 'transpose_b' in quantized_node.attr:
+               new_node.attr["transpose_b"].CopyFrom(quantized_node.attr['transpose_b'])
+           if 'transpose_a' in quantized_node.attr:
+               new_node.attr["transpose_a"].CopyFrom(quantized_node.attr['transpose_a'])
            if 'T1' in quantized_node.attr:
                new_node.attr["T1"].CopyFrom(quantized_node.attr['T1'])
            if 'T2' in quantized_node.attr:
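Both hunks use the same guarded NodeDef attribute forwarding. A minimal standalone sketch of that CopyFrom idiom (node names and values here are illustrative, not from the repo):

from tensorflow.core.framework import attr_value_pb2, node_def_pb2

# Source node carrying a transpose flag, as the matched _QuantizedMatMul would.
src = node_def_pb2.NodeDef(name="quantized_mm", op="_QuantizedMatMul")
src.attr["transpose_b"].CopyFrom(attr_value_pb2.AttrValue(b=True))

# Forward the attribute onto the rewritten node only when it is present,
# mirroring the guarded CopyFrom calls added in do_transformation.
new_node = node_def_pb2.NodeDef(name="fused_mm", op="_QuantizedMatMul")
if "transpose_b" in src.attr:
    new_node.attr["transpose_b"].CopyFrom(src.attr["transpose_b"])

print(new_node.attr["transpose_b"].b)  # True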

neural_compressor/adaptor/tf_utils/graph_rewriter/qdq/insert_qdq_pattern.py

Lines changed: 2 additions & 3 deletions
@@ -423,10 +423,9 @@ def _ignore_insert_qdq_pattern(self, matched_node_name):
            return True
 
        #TODO Remove below two lines once the TF enabled the QuantizedMatMul while
-       # transpose_a/transpose_a could be set to True.
+       # transpose_a could be set to True.
        if self.graph_info[matched_node_name].node.op == "MatMul":
-           if self.graph_info[matched_node_name].node.attr["transpose_a"].b == True or \
-               self.graph_info[matched_node_name].node.attr["transpose_b"].b == True:
+           if self.graph_info[matched_node_name].node.attr["transpose_a"].b == True:
                return True
 
        return False
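The effect is that only transpose_a=True still blocks Q/DQ insertion; transpose_b=True is now handled by the new _QuantizedMatMul. A hedged, self-contained restatement of the narrowed predicate (the helper name is illustrative):

from tensorflow.core.framework import attr_value_pb2, node_def_pb2

def skips_qdq_insertion(node):
    # After this commit only transpose_a=True is rejected for Q/DQ insertion;
    # a missing attr reads back as the default False.
    return node.op == "MatMul" and node.attr["transpose_a"].b

mm = node_def_pb2.NodeDef(name="mm", op="MatMul")
mm.attr["transpose_b"].CopyFrom(attr_value_pb2.AttrValue(b=True))
print(skips_qdq_insertion(mm))  # False: transpose_b alone no longer disables it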

neural_compressor/adaptor/tf_utils/quantize_graph/qdq/fuse_qdq_matmul.py

Lines changed: 4 additions & 7 deletions
@@ -187,17 +187,15 @@ def apply_matmul_biasadd_fusion(self, match_node_name):
        weights_max_name = weights_name[2]
 
        weight_node = self.node_name_mapping[helper.node_name_from_input(weights_name[0])].node
-
        # FIXME We only quantize the MatMul op which second input node type is const. This is a
        # workaround for RNN model like LTSM.
        if weight_node.op != 'Const':
            self.output_graph = self.input_graph
            return []
 
        #TODO Remove below two lines once the TF enabled the QuantizedMatMul while
-       # transpose_a/transpose_a could be set to True.
-       if matched_node.node.attr["transpose_a"].b == True or \
-           matched_node.node.attr["transpose_b"].b == True:
+       # transpose_a could be set to True.
+       if matched_node.node.attr["transpose_a"].b == True:
            self.output_graph = self.input_graph
            return []
 
@@ -582,9 +580,8 @@ def _is_match_matmul(self, patterns, qdq_inserted=False):
 
            if cur_node.op == "MatMul":
                #TODO Remove below two lines once the TF enabled the QuantizedMatMul while
-               # transpose_a/transpose_a could be set to True.
-               if cur_node.attr["transpose_a"].b == True or \
-                   cur_node.attr["transpose_b"].b == True:
+               # transpose_a could be set to True.
+               if cur_node.attr["transpose_a"].b == True:
                    continue
 
            weights_content = tensor_util.MakeNdarray(weight_node.attr['value'].tensor)
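As a reminder of what the fusion must now preserve (plain NumPy, not repo code): transpose_b=True transposes the second operand before the product, so the quantized kernel reads the constant weight matrix with its axes swapped.

import numpy as np

# Same operands as the updated unit test below.
x = np.array([[0.1, 0.2], [0.2, 0.3]])
y = np.array([[1.0, 2.0], [3.0, 4.0]])

# tf.matmul(x, y, transpose_b=True) computes x @ y.T.
print(np.matmul(x, y.T))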

neural_compressor/adaptor/tf_utils/quantize_graph/quantize_graph_matmul.py

Lines changed: 3 additions & 3 deletions
@@ -150,10 +150,10 @@ def apply_matmul_biasadd_fusion(self, match_node_name):
            self.output_graph = self.input_graph
            return []
 
-       #TODO Remove below two lines once the TF enabled the QuantizedMatMul while
-       # transpose_a/transpose_a could be set to True.
+       #TODO Remove below two lines once the TF enabled the old QuantizedMatMul while
+       # transpose_a/transpose_b could be set to True.
        if matched_node.node.attr["transpose_a"].b == True or \
-           matched_node.node.attr["transpose_b"].b == True:
+               matched_node.node.attr["transpose_b"].b == True:
            self.output_graph = self.input_graph
            return []
 
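Here only the comment and continuation-line indentation change: the legacy (non-new-API) quantizer still rejects either transpose flag. A standalone restatement of the retained predicate (the function name is illustrative):

from tensorflow.core.framework import attr_value_pb2, node_def_pb2

def old_api_blocks_quantization(node):
    # The legacy QuantizedMatMul path still skips a MatMul when either flag
    # is set; only the new-API rewriters above accept transpose_b=True.
    return node.attr["transpose_a"].b or node.attr["transpose_b"].b

mm = node_def_pb2.NodeDef(name="mm", op="MatMul")
mm.attr["transpose_b"].CopyFrom(attr_value_pb2.AttrValue(b=True))
print(old_api_blocks_quantization(mm))  # True: still skipped on the old path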

test/tfnewapi/test_tensorflow_graph_qdq_matmul_fusion.py

Lines changed: 7 additions & 7 deletions
@@ -176,15 +176,15 @@ def test_matmul_biasadd_requantize_dequantize_last_fusion(self):
        self.assertEqual(found_quantized_matmul, True)
 
    @disable_random()
-   def test_disable_matmul_fusion(self):
+   def test_matmul_fusion_with_transpose_b_true(self):
        g = tf.Graph()
        with g.as_default():
 
            x_data = np.array([[0.1, 0.2], [0.2, 0.3]])
            y_data = np.array([[1, 2], [3, 4]], dtype=np.float)
            x = tf.placeholder(tf.float32, shape=[2, 2], name='x')
            y = tf.constant(y_data, dtype=tf.float32, shape=[2, 2])
-           z = tf.matmul(x, y, name='no_quant_matmul')
+           z = tf.matmul(x, y, name='no_quant_matmul', transpose_b=True)
            z = tf.nn.relu6(z, name='op_to_store')
            found_quantized_matmul = False
 
@@ -201,21 +201,21 @@ def test_disable_matmul_fusion(self):
            output_graph = quantizer.fit()
 
            for i in output_graph.graph_def.node:
-               if i.op == '_QuantizedMatMul' and i.name == 'op_to_store':
+               if i.op == '_QuantizedMatMul':
                    found_quantized_matmul = True
                    break
-           self.assertEqual(found_quantized_matmul, False)
-
+           self.assertEqual(found_quantized_matmul, True)
+
    @disable_random()
-   def test_disable_matmul_fusion_with_transpose_b_true(self):
+   def test_disable_matmul_fusion(self):
        g = tf.Graph()
        with g.as_default():
 
            x_data = np.array([[0.1, 0.2], [0.2, 0.3]])
            y_data = np.array([[1, 2], [3, 4]], dtype=np.float)
            x = tf.placeholder(tf.float32, shape=[2, 2], name='x')
            y = tf.constant(y_data, dtype=tf.float32, shape=[2, 2])
-           z = tf.matmul(x, y, name='no_quant_matmul', transpose_b=True)
+           z = tf.matmul(x, y, name='no_quant_matmul')
            z = tf.nn.relu6(z, name='op_to_store')
            found_quantized_matmul = False
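The renamed test now asserts that fusion happens by scanning the produced GraphDef for any _QuantizedMatMul node, since requantize/dequantize fusion may fold away the original 'op_to_store' name. A minimal sketch of that scan (the quantizer setup is omitted; the GraphDef here is built by hand for illustration):

from tensorflow.core.framework import graph_pb2, node_def_pb2

def contains_quantized_matmul(graph_def):
    # Mirrors the loosened assertion: any _QuantizedMatMul node counts.
    return any(node.op == "_QuantizedMatMul" for node in graph_def.node)

gdef = graph_pb2.GraphDef()
gdef.node.append(node_def_pb2.NodeDef(name="mm_quantized", op="_QuantizedMatMul"))
print(contains_quantized_matmul(gdef))  # True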
