Skip to content

Commit d2a9905

Browse files
authored
Refine config check & add TF framework check (#1106)
1 parent e213cc5 commit d2a9905

File tree

4 files changed

+23
-4
lines changed

4 files changed

+23
-4
lines changed

neural_compressor/conf/config.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1242,8 +1242,12 @@ def _read_cfg(self, cfg_fname):
12421242
with open(cfg_fname, 'w') as f:
12431243
f.write(content)
12441244

1245-
return validated_cfg
1246-
1245+
return validated_cfg
1246+
except FileNotFoundError as f:
1247+
logger.error("{}.".format(f))
1248+
raise RuntimeError(
1249+
"The yaml file is not exist. Please check the file name or path."
1250+
)
12471251
except Exception as e:
12481252
logger.error("{}.".format(e))
12491253
raise RuntimeError(

neural_compressor/experimental/component.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,15 @@ def _init_with_conf(self):
7979
if self.cfg.model.framework != 'NA':
8080
self.framework = self.cfg.model.framework.lower()
8181
set_backend(self.framework)
82+
if self.framework == 'tensorflow' or self.framework == 'inteltensorflow':
83+
try:
84+
import tensorflow as tf
85+
except Exception as e:
86+
logger.error("{}.".format(e))
87+
raise RuntimeError(
88+
"The TensorFlow framework is not correctly installed. Please check your environment"
89+
)
90+
8291

8392
def pre_process(self):
8493
""" Initialize the dataloader and train/eval functions from yaml config.

neural_compressor/model/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,8 @@ def _is_mxnet(model):
179179
fwk_name = handler(model)
180180
if fwk_name != 'NA':
181181
break
182-
assert fwk_name != 'NA', 'Framework is not detected correctly from model format.'
182+
assert fwk_name != 'NA', 'Framework is not detected correctly from model format. This could be \
183+
caused by unsupported model or inappropriate framework installation.'
183184

184185
return fwk_name
185186

test/config/test_config.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -843,6 +843,12 @@ def test_data_type(self):
843843
transform_cfg = cfg['quantization']['calibration']['dataloader']['transform']['BilinearImagenet']
844844
self.assertTrue(isinstance(transform_cfg['mean_value'], list))
845845

846+
def test_yaml_detection(self):
847+
try:
848+
cfg = conf.Conf('not_exist.yaml').usr_cfg
849+
except:
850+
pass
851+
846852
def test_deep_set(self):
847853
from neural_compressor.conf.dotdict import DotDict, deep_set
848854
cfg = {'evaluation': {'accuracy': {}}}
@@ -860,7 +866,6 @@ def test_deep_set(self):
860866
multi_metrics2 = dot_cfg['evaluation']['accuracy']['multi_metrics']
861867
self.assertTrue(multi_metrics1 == multi_metrics2)
862868
self.assertTrue(list(multi_metrics1.keys()) == ['weight', 'mAP'])
863-
864869

865870
if __name__ == "__main__":
866871
unittest.main()

0 commit comments

Comments
 (0)