Skip to content

Commit cd0add0

Browse files
authored
Add UT for special concatv2 input case in TF newAPI (#1299)
1 parent a520746 commit cd0add0

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

test/tfnewapi/test_tensorflow_graph_qdq_concat_fusion.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,5 +194,43 @@ def test_concat_with_same_input_type(self):
194194
quantized_concat = True
195195
self.assertEqual(quantized_concat, True)
196196

197+
@disable_random()
198+
def test_concat_with_qint8_and_fp32_input_type(self):
199+
x = tf.compat.v1.placeholder(
200+
tf.float32, [1, 128, 128, 16], name="input")
201+
bias = tf.compat.v1.get_variable("bias", [16],
202+
initializer=tf.compat.v1.random_normal_initializer())
203+
204+
bias_add = tf.nn.bias_add(x, bias)
205+
206+
pool = tf.nn.avg_pool(x, ksize=1, strides=[1, 1, 1, 1], name='avgpool', padding="SAME")
207+
concat = tf.concat([bias_add, pool], 1)
208+
final_node = tf.nn.relu(concat , name='op_to_store')
209+
out_name = final_node.name.split(':')[0]
210+
with tf.compat.v1.Session() as sess:
211+
sess.run(tf.compat.v1.global_variables_initializer())
212+
output_graph_def = graph_util.convert_variables_to_constants(
213+
sess=sess,
214+
input_graph_def=sess.graph_def,
215+
output_node_names=[out_name])
216+
from neural_compressor.experimental import Quantization, common
217+
218+
quantizer = Quantization('fake_yaml.yaml')
219+
dataset = quantizer.dataset(
220+
'dummy', shape=(100, 128, 128, 16), label=True)
221+
quantizer.calib_dataloader = common.DataLoader(dataset)
222+
quantizer.eval_dataloader = common.DataLoader(dataset)
223+
quantizer.model = output_graph_def
224+
output_graph = quantizer.fit()
225+
dtype = None
226+
quantized_concat = False
227+
from tensorflow.python.framework import dtypes
228+
for i in output_graph.graph_def.node:
229+
if i.op == 'QuantizedConcatV2':
230+
dtype = dtypes.DType(i.attr['T'].type)
231+
quantized_concat = True
232+
self.assertEqual(quantized_concat, True)
233+
self.assertEqual(dtype, dtypes.qint8)
234+
197235
if __name__ == "__main__":
198236
unittest.main()

0 commit comments

Comments
 (0)