Skip to content

Commit 2482190

Browse files
authored
Fix deep set (#1101)
1 parent 57e5905 commit 2482190

File tree

3 files changed

+19
-2
lines changed

3 files changed

+19
-2
lines changed

examples/onnxrt/object_detection/onnx_model_zoo/DUC/quantization/ptq/DUC.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ evaluation:
3333
configs: # optional. if not specified, use all cores in 1 socket.
3434
cores_per_instance: 4
3535
num_of_instance: 1
36-
accuracy:
3736

3837
tuning:
3938
accuracy_criterion:

neural_compressor/conf/dotdict.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def deep_set(dictionary, keys, value):
4747
"""
4848
keys = keys.split('.')
4949
for key in keys[:-1]:
50-
dictionary = dictionary.setdefault(key, {})
50+
dictionary = dictionary.setdefault(key, DotDict())
5151
dictionary[keys[-1]] = value
5252

5353
class DotDict(dict):

test/config/test_config.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -843,6 +843,24 @@ def test_data_type(self):
843843
transform_cfg = cfg['quantization']['calibration']['dataloader']['transform']['BilinearImagenet']
844844
self.assertTrue(isinstance(transform_cfg['mean_value'], list))
845845

846+
def test_deep_set(self):
847+
from neural_compressor.conf.dotdict import DotDict, deep_set
848+
cfg = {'evaluation': {'accuracy': {}}}
849+
dot_cfg = DotDict(cfg)
850+
deep_set(dot_cfg, 'evaluation.accuracy.metric', 'iou')
851+
deep_set(dot_cfg, 'evaluation.accuracy.multi_metrics.weight', [0.1, 0,9])
852+
deep_set(dot_cfg, 'evaluation.accuracy.multi_metrics.mAP.anno_path', 'anno_path_test')
853+
self.assertTrue(dot_cfg.evaluation == dot_cfg['evaluation'])
854+
self.assertTrue(dot_cfg.evaluation.accuracy == dot_cfg['evaluation']['accuracy'])
855+
self.assertTrue(dot_cfg.evaluation.accuracy.metric == dot_cfg['evaluation']['accuracy']['metric'])
856+
self.assertTrue(dot_cfg.evaluation.accuracy.multi_metrics == dot_cfg['evaluation']['accuracy']['multi_metrics'])
857+
self.assertTrue(dot_cfg.evaluation.accuracy.multi_metrics.weight == [0.1, 0,9])
858+
self.assertTrue(dot_cfg.evaluation.accuracy.multi_metrics.mAP.anno_path == 'anno_path_test')
859+
multi_metrics1 = dot_cfg.evaluation.accuracy.multi_metrics
860+
multi_metrics2 = dot_cfg['evaluation']['accuracy']['multi_metrics']
861+
self.assertTrue(multi_metrics1 == multi_metrics2)
862+
self.assertTrue(list(multi_metrics1.keys()) == ['weight', 'mAP'])
863+
846864

847865
if __name__ == "__main__":
848866
unittest.main()

0 commit comments

Comments
 (0)