Skip to content

Commit cd1f9e2

Browse files
Remove Dequantize Cast Optimizer MIN_FIRST limitation for spr-base (#940)
Signed-off-by: Lv, Liang1 <liang1.lv@intel.com>
1 parent af0aea4 commit cd1f9e2

File tree

2 files changed

+2
-3
lines changed

2 files changed

+2
-3
lines changed

neural_compressor/adaptor/tf_utils/graph_rewriter/bf16/dequantize_cast_optimizer.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,6 @@ def do_transformation(self):
5656

5757
if len(dq_outputs) > 1:
5858
continue
59-
if dq_node.attr["mode"].s == b"MIN_FIRST":
60-
continue
6159
cast_node = graph_info[i[1]].node
6260
cast_outputs = graph_info[i[1]].outputs
6361
all_cast_outputs_bf16 = True

test/tfnewapi/test_tensorflow_graph_dequantize_cast_optimizer_newapi.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,12 @@ def test_dequantize_cast_min_first(self):
6969
graph_def = build_fake_graphdef(set_min_first=True)
7070
converted_graph_def = DequantizeCastOptimizer(graph_def).do_transformation()
7171
hasCast = False
72+
# Remove MIN_FIRST limitation for spr-base, so the "Cast" will be removed now
7273
for i in converted_graph_def.node:
7374
if i.op == "Cast":
7475
hasCast = True
7576
break
76-
self.assertEqual(hasCast, True)
77+
self.assertEqual(hasCast, False)
7778

7879
@disable_random()
7980
def test_dequantize_cast_multiple_outputs(self):

0 commit comments

Comments
 (0)