Skip to content

Commit 9d589f3

Browse files
Fix graph_optimization bf16 convert issue (#1306)
1 parent 11f3dbf commit 9d589f3

File tree

4 files changed

+12
-12
lines changed

4 files changed

+12
-12
lines changed

neural_compressor/experimental/graph_optimization.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,14 @@ def __call__(self):
126126
self._model.output_tensor_names = cfg.model.outputs
127127
self._model.input_tensor_names = cfg.model.inputs
128128
self._model.workspace_path = cfg.tuning.workspace.path
129-
else:
130-
logger.warning("Only TensorFlow graph optimization is supported at current stage.")
131-
sys.exit(0)
129+
130+
if 'bf16' in self._precisions or \
131+
(cfg.mixed_precision and 'bf16' in cfg.mixed_precision.precisions) or \
132+
(cfg.graph_optimization and 'bf16' in cfg.graph_optimization.precisions):
133+
os.environ['MIX_PRECISION_TEST'] = '1'
134+
else:
135+
logger.warning("Only TensorFlow graph optimization is supported at current stage.")
136+
sys.exit(0)
132137

133138
# when eval_func is set, will be directly used and eval_dataloader can be None
134139
if self._eval_func is None:

neural_compressor/experimental/mixed_precision.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,10 @@ def __call__(self):
131131
self._model.input_tensor_names = cfg.model.inputs if \
132132
not self._input else self._input
133133
self._model.workspace_path = cfg.tuning.workspace.path
134+
if 'bf16' in self._precisions or \
135+
(cfg.mixed_precision and 'bf16' in cfg.mixed_precision.precisions) or \
136+
(cfg.graph_optimization and 'bf16' in cfg.graph_optimization.precisions):
137+
os.environ['MIX_PRECISION_TEST'] = '1'
134138

135139
# when eval_func is set, will be directly used and eval_dataloader can be None
136140
if self._eval_func is None:

test/graph_optimization/test_graph_optimization.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -177,14 +177,10 @@ class TestGraphOptimizationOnNonBF16Host(unittest.TestCase):
177177
@classmethod
178178
def setUpClass(self):
179179
build_fake_yaml()
180-
if CpuInfo().bf16:
181-
os.environ['MIX_PRECISION_TEST'] = '1'
182180

183181
@classmethod
184182
def tearDownClass(self):
185183
os.remove('fake_yaml.yaml')
186-
if CpuInfo().bf16:
187-
del os.environ['MIX_PRECISION_TEST']
188184

189185
@disable_random()
190186
@unittest.skipIf(tf.__version__ < "2.0", "does not support on 1.15up3")
@@ -239,7 +235,6 @@ class TestGraphOptimization(unittest.TestCase):
239235
@classmethod
240236
def setUpClass(self):
241237
os.environ['FORCE_BF16'] = '1'
242-
os.environ['MIX_PRECISION_TEST'] = '1'
243238
build_fake_yaml()
244239
build_fake_yaml_2()
245240
build_fake_yaml_3()
@@ -250,7 +245,6 @@ def setUpClass(self):
250245
@classmethod
251246
def tearDownClass(self):
252247
del os.environ['FORCE_BF16']
253-
del os.environ['MIX_PRECISION_TEST']
254248
os.remove('fake_yaml.yaml')
255249
os.remove('fake_yaml_2.yaml')
256250
os.remove('fake_yaml_3.yaml')
@@ -968,7 +962,6 @@ def test_graph_optimization_without_yaml_with_precisions(self):
968962
@disable_random()
969963
def test_graph_optimization_fp32_only_with_force_bf16(self):
970964
os.environ['FORCE_BF16'] = '1'
971-
os.environ['MIX_PRECISION_TEST'] = '1'
972965
x = tf.compat.v1.placeholder(tf.float32, [1, 56, 56, 16], name="input")
973966
top_relu = tf.nn.relu(x)
974967
paddings = tf.constant([[0, 0], [1, 1], [1, 1], [0, 0]])

test/mixed_precision/test_mixed_precision.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,6 @@ class TestMixedPrecision(unittest.TestCase):
293293
def setUpClass(self):
294294
os.environ['FORCE_FP16'] = '1'
295295
os.environ['FORCE_BF16'] = '1'
296-
os.environ['MIX_PRECISION_TEST'] = '1'
297296
self.onnx_model = build_matmul_model()
298297
self.matmul_dataset = MatmulDataset()
299298
self.tf_model = build_tf_graph()
@@ -304,7 +303,6 @@ def setUpClass(self):
304303
def tearDownClass(self):
305304
del os.environ['FORCE_FP16']
306305
del os.environ['FORCE_BF16']
307-
del os.environ['MIX_PRECISION_TEST']
308306
shutil.rmtree("./saved", ignore_errors=True)
309307
shutil.rmtree("./nc_workspace", ignore_errors=True)
310308
os.remove("test.yaml")

0 commit comments

Comments
 (0)