
Commit 348b306

Add pre/post processing to base_custom_evaluator.py
1 parent 15c89fc commit 348b306

File tree

1 file changed: +33 -5 lines changed

tools/accuracy_checker/accuracy_checker/evaluators/custom_evaluators/base_custom_evaluator.py

Lines changed: 33 additions & 5 deletions
@@ -18,16 +18,20 @@
 from ...progress_reporters import ProgressReporter
 from ..quantization_model_evaluator import create_dataset_attributes
 from ...launcher import create_launcher
-
+from ...postprocessor import PostprocessingExecutor
+from ...preprocessor import PreprocessingExecutor
 
 # base class for custom evaluators
 class BaseCustomEvaluator(BaseEvaluator):
-    def __init__(self, dataset_config, launcher, orig_config):
+    def __init__(self, dataset_config, launcher, orig_config,
+                 preprocessor=None, postprocessor=None):
         self.dataset_config = dataset_config
         self.dataset = None
+        self.input_feeder = None
+        self.adapter = None
         self.preprocessing_executor = None
-        self.preprocessor = None
-        self.postprocessor = None
+        self.preprocessor = preprocessor
+        self.postprocessor = postprocessor
         self.metric_executor = None
         self.launcher = launcher
         self._metrics_results = []
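
With this change the base constructor accepts ready-made pre/post processing executors instead of always initialising them to None. A minimal sketch of a subclass forwarding the new keyword arguments; the import path and the MyEvaluator name are assumptions for illustration, not part of the commit:

# Hypothetical subclass; assumes the accuracy_checker package from
# tools/accuracy_checker is installed and importable.
from accuracy_checker.evaluators.custom_evaluators.base_custom_evaluator import BaseCustomEvaluator


class MyEvaluator(BaseCustomEvaluator):
    def __init__(self, dataset_config, launcher, orig_config,
                 preprocessor=None, postprocessor=None):
        # Forward the optional executors so the base class stores them
        # rather than leaving self.preprocessor / self.postprocessor unset.
        super().__init__(dataset_config, launcher, orig_config,
                         preprocessor=preprocessor, postprocessor=postprocessor)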
@@ -47,6 +51,30 @@ def get_dataset_and_launcher_info(config):
         launcher = create_launcher(launcher_config, delayed_model_loading=True)
         return dataset_config, launcher, launcher_config
 
+    @classmethod
+    def get_evaluator_init_info(cls, model_config, delayed_annotation_loading=False):
+        launcher_config = model_config['launchers'][0]
+        datasets = model_config['datasets']
+        dataset_config = datasets[0]
+        dataset_name = dataset_config['name']
+
+        runtime_framework = launcher_config['framework']
+        enable_runtime_preprocessing = False
+        if runtime_framework in ['dlsdk', 'openvino']:
+            enable_runtime_preprocessing = dataset_config.get('_ie_preprocessing', False)
+        preprocessor = PreprocessingExecutor(
+            dataset_config.get('preprocessing'), dataset_name,
+            enable_runtime_preprocessing=enable_runtime_preprocessing, runtime_framework=runtime_framework
+        )
+
+        if launcher_config['framework'] == 'dlsdk' and 'device' not in launcher_config:
+            launcher_config['device'] = 'CPU'
+        launcher = create_launcher(launcher_config, delayed_model_loading=True)
+        dataset_metadata = {}
+        postprocessor = PostprocessingExecutor(dataset_config.get('postprocessing'), dataset_name, dataset_metadata)
+
+        return datasets, launcher, preprocessor, postprocessor
+
     def process_dataset(self, subset=None, num_images=None, check_progress=False, dataset_tag='',
                         output_callback=None, allow_pairwise_subset=False, dump_prediction_to_annotation=False,
                         calculate_metrics=True, **kwargs):
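
The new get_evaluator_init_info classmethod bundles construction of the dataset list, launcher, preprocessor, and postprocessor from a single model config. A hedged usage sketch follows; the config values are placeholders that mirror the keys the method reads ('launchers', 'datasets', 'name', 'preprocessing', 'postprocessing', '_ie_preprocessing'), and actually creating the launcher still requires the named framework to be installed:

# Illustrative config; values are assumptions, not taken from the commit.
model_config = {
    'launchers': [{'framework': 'openvino', 'model': 'model.xml', 'device': 'CPU'}],
    'datasets': [{
        'name': 'sample_dataset',
        'preprocessing': [{'type': 'resize', 'size': 224}],
        'postprocessing': [],
        '_ie_preprocessing': True,  # opt in to runtime preprocessing on dlsdk/openvino
    }],
}

datasets, launcher, preprocessor, postprocessor = \
    BaseCustomEvaluator.get_evaluator_init_info(model_config)

# MyEvaluator is the hypothetical subclass from the earlier sketch.
evaluator = MyEvaluator(datasets[0], launcher, model_config,
                        preprocessor=preprocessor, postprocessor=postprocessor)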
@@ -164,7 +192,7 @@ def register_metric(self, metric_config):
         elif isinstance(metric_config, dict):
             self.metric_executor.register_metric(metric_config)
         else:
-            raise ValueError('Unsupported metric configuration type {}'.format(type(metric_config)))
+            raise ValueError(f'Unsupported metric configuration type {type(metric_config)}')
 
     def get_metrics_attributes(self):
         if not self.metric_executor:
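
The behaviour of register_metric is unchanged by the f-string rewrite: dict configs are registered, unsupported types still raise ValueError. A short illustration, continuing the hypothetical evaluator from the sketches above (it assumes self.metric_executor has already been initialised):

# Register a metric from a dict config, then trigger the error branch.
evaluator.register_metric({'type': 'accuracy', 'top_k': 1})

try:
    evaluator.register_metric(42)  # not a supported configuration type
except ValueError as err:
    print(err)  # Unsupported metric configuration type <class 'int'>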
