@@ -18,16 +18,20 @@
 from ...progress_reporters import ProgressReporter
 from ..quantization_model_evaluator import create_dataset_attributes
 from ...launcher import create_launcher
-
+from ...postprocessor import PostprocessingExecutor
+from ...preprocessor import PreprocessingExecutor
 
 # base class for custom evaluators
 class BaseCustomEvaluator(BaseEvaluator):
-    def __init__(self, dataset_config, launcher, orig_config):
+    def __init__(self, dataset_config, launcher, orig_config,
+                 preprocessor=None, postprocessor=None):
         self.dataset_config = dataset_config
         self.dataset = None
+        self.input_feeder = None
+        self.adapter = None
         self.preprocessing_executor = None
-        self.preprocessor = None
-        self.postprocessor = None
+        self.preprocessor = preprocessor
+        self.postprocessor = postprocessor
         self.metric_executor = None
         self.launcher = launcher
         self._metrics_results = []
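Note: both new keyword arguments default to None, so existing subclasses that call the old three-argument constructor keep working, while factories can now inject ready-made executors. A minimal sketch of the two call styles (MyEvaluator is a hypothetical subclass, not part of this diff):

# Old call style still works: preprocessor/postprocessor stay None.
evaluator = MyEvaluator(dataset_config, launcher, orig_config)

# New call style: inject executors built elsewhere,
# e.g. by get_evaluator_init_info (added below).
evaluator = MyEvaluator(dataset_config, launcher, orig_config,
                        preprocessor=preprocessor, postprocessor=postprocessor)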
@@ -47,6 +51,30 @@ def get_dataset_and_launcher_info(config):
         launcher = create_launcher(launcher_config, delayed_model_loading=True)
         return dataset_config, launcher, launcher_config
 
+    @classmethod
+    def get_evaluator_init_info(cls, model_config, delayed_annotation_loading=False):
+        launcher_config = model_config['launchers'][0]
+        datasets = model_config['datasets']
+        dataset_config = datasets[0]
+        dataset_name = dataset_config['name']
+
+        runtime_framework = launcher_config['framework']
+        enable_runtime_preprocessing = False
+        if runtime_framework in ['dlsdk', 'openvino']:
+            enable_runtime_preprocessing = dataset_config.get('_ie_preprocessing', False)
+        preprocessor = PreprocessingExecutor(
+            dataset_config.get('preprocessing'), dataset_name,
+            enable_runtime_preprocessing=enable_runtime_preprocessing, runtime_framework=runtime_framework
+        )
+
+        if launcher_config['framework'] == 'dlsdk' and 'device' not in launcher_config:
+            launcher_config['device'] = 'CPU'
+        launcher = create_launcher(launcher_config, delayed_model_loading=True)
+        dataset_metadata = {}
+        postprocessor = PostprocessingExecutor(dataset_config.get('postprocessing'), dataset_name, dataset_metadata)
+
+        return datasets, launcher, preprocessor, postprocessor
+
     def process_dataset(self, subset=None, num_images=None, check_progress=False, dataset_tag='',
                         output_callback=None, allow_pairwise_subset=False, dump_prediction_to_annotation=False,
                         calculate_metrics=True, **kwargs):
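A subclass factory would typically unpack the tuple returned by get_evaluator_init_info and forward it to the extended constructor. A minimal sketch under that assumption (the from_configs wiring shown here is hypothetical, not part of this diff):

@classmethod
def from_configs(cls, config, delayed_model_loading=False, orig_config=None):
    # Unpack dataset list, launcher and executors prepared by the base class.
    datasets, launcher, preprocessor, postprocessor = cls.get_evaluator_init_info(config)
    return cls(datasets, launcher, orig_config,
               preprocessor=preprocessor, postprocessor=postprocessor)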
@@ -164,7 +192,7 @@ def register_metric(self, metric_config):
         elif isinstance(metric_config, dict):
             self.metric_executor.register_metric(metric_config)
         else:
-            raise ValueError('Unsupported metric configuration type {}'.format(type(metric_config)))
+            raise ValueError(f'Unsupported metric configuration type {type(metric_config)}')
 
     def get_metrics_attributes(self):
         if not self.metric_executor:
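For reference, register_metric accepts either a metric name or a full metric config dict; anything else now raises the f-string ValueError above (a string branch presumably precedes the elif shown). A hedged usage sketch with illustrative metric configs:

evaluator.register_metric('accuracy')                        # str: metric name
evaluator.register_metric({'type': 'accuracy', 'top_k': 5})  # dict: full config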