Skip to content

Commit 69572ae

Browse files
authored
Add support for accuracy validation for yolox-s model (#3964)
* Added support for pytorch yoloxs model to accuracy_checker. * Added support for Geti version of yolox-s model. * Minor update of the error message in python_launcher. * Refactored YoloxsAdapter for accuracy_checker. * Merged YoloxsAdapter processing methods into one. * Enable loading model checkpoint from url. * Fixed accuracy checker checkpoint handling in pytorch_launcher to pass unit tests. * Added check for empty labels to Yoloxs adapter. * Added config params for output layer names to accuracy_checker tool.
1 parent a14ddee commit 69572ae

File tree

7 files changed

+111
-7
lines changed

7 files changed

+111
-7
lines changed

tools/accuracy_checker/accuracy_checker/adapters/README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,11 @@ AccuracyChecker supports following set of adapters:
8989
* `output_name` - name of output layer (Optional).
9090
* `threshold` - minimal objectness score value for valid detections (Optional, default 0.001).
9191
* `num` - num parameter from DarkNet configuration file (Optional, default 5).
92+
* `yoloxs` - converting output of YOLOX model to `DetectionPrediction` representation.
93+
* `score_threshold` - minimal accepted score for valid detections (Optional, default 0.001).
94+
* `box_format_xywh` - enabling additional preprocessing when box format is xywh (default `False`).
95+
* `boxes_out` - name of output layer with boxes(Optional, default `boxes`).
96+
* `labels_out` - name of output layer with labels(Optional, default `labels`).
9297
* `yolo_v8_detection` - converting output of YOLO v8 family pretrained for object detection to `DetectionPrediction`.
9398
* `conf_threshold` - minimal confidence for filtering valid detections (Optional, default 0.25).
9499
* `multi_label` - allow to use multiple labels for the same box coordinates (Optional, default True).

tools/accuracy_checker/accuracy_checker/adapters/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@
8585
YoloV5Adapter,
8686
YolorAdapter,
8787
YoloxAdapter,
88+
YoloxsAdapter,
8889
YolofAdapter,
8990
# for adapter registration, it should be imported and added to __all__ list
9091
YoloV8DetectionAdapter
@@ -184,6 +185,7 @@
184185
'YoloV5Adapter',
185186
'YolorAdapter',
186187
'YoloxAdapter',
188+
'YoloxsAdapter',
187189
'YolofAdapter',
188190
'YoloV8DetectionAdapter',
189191

tools/accuracy_checker/accuracy_checker/adapters/yolo.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,74 @@ def xywh2xyxy(x):
798798
return y
799799

800800

801+
class YoloxsAdapter(Adapter):
802+
__provider__ = 'yoloxs'
803+
prediction_types = (DetectionPrediction, )
804+
805+
@classmethod
806+
def parameters(cls):
807+
parameters = super().parameters()
808+
parameters.update({
809+
'score_threshold': NumberField(value_type=float, optional=True, min_value=0, default=0.001,
810+
description="Minimal accepted score value for valid detections."),
811+
'box_format_xywh': BoolField(optional=True, default=False,
812+
description="Indicates that box output format is xywh."),
813+
'boxes_out': StringField(optional=True, default='boxes', description="Boxes output layer name."),
814+
'labels_out': StringField(optional=True, default='labels', description="Labels output layer name."),
815+
})
816+
return parameters
817+
818+
def configure(self):
819+
self.score_threshold = self.get_value_from_config('score_threshold')
820+
self.box_format_xywh = self.get_value_from_config('box_format_xywh')
821+
self.boxes_out = self.get_value_from_config('boxes_out')
822+
self.labels_out = self.get_value_from_config('labels_out')
823+
824+
def process(self, raw, identifiers, frame_meta):
825+
result = []
826+
raw_outputs = self._extract_predictions(raw, frame_meta)
827+
828+
num_classes = 0
829+
x_mins, y_mins, x_maxs, y_maxs = [], [], [], []
830+
831+
for identifier, meta in zip(identifiers, frame_meta):
832+
if len(self.additional_output_mapping) > 0:
833+
boxes = np.array(raw_outputs[self.additional_output_mapping[self.boxes_out]]).squeeze()
834+
labels = np.array(raw_outputs[self.additional_output_mapping[self.labels_out]]).squeeze()
835+
if not labels.shape:
836+
result.append(DetectionPrediction(identifier, [], [], [], [], [], [], meta))
837+
continue
838+
scores = boxes[:, 4]
839+
boxes = boxes[:, :4]
840+
else:
841+
output = np.array(raw_outputs[self.output_blob])
842+
num_classes = output.shape[1] - 5
843+
labels = np.argmax(output[:, 5: 5 + num_classes], axis=1)
844+
scores = output[:, 4]
845+
boxes = output[:, :4]
846+
847+
if num_classes > 0:
848+
class_max_confidences = np.max(output[:, 5: 5 + num_classes], axis=1)
849+
scores *= class_max_confidences
850+
851+
if self.box_format_xywh:
852+
boxes = xywh2xyxy(boxes)
853+
854+
mask = scores > self.score_threshold
855+
scores = scores[mask]
856+
labels = labels[mask]
857+
boxes = boxes[mask]
858+
859+
image_resize_ratio = meta['scale_x']
860+
if boxes.size > 0 and image_resize_ratio > 0:
861+
x_mins, y_mins, x_maxs, y_maxs = boxes.T / image_resize_ratio
862+
863+
result.append(DetectionPrediction(
864+
identifier, labels, scores, x_mins, y_mins, x_maxs, y_maxs, meta
865+
))
866+
return result
867+
868+
801869
class YoloV8DetectionAdapter(Adapter):
802870
"""
803871
class adapter for yolov8, yolov8 support multiple tasks, this class for object detection.

tools/accuracy_checker/accuracy_checker/annotation_converters/ms_coco.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,9 +206,10 @@ def _create_representations(
206206
image_full_path = self.images_dir / image[1]
207207
if not check_file_existence(image_full_path):
208208
content_errors.append('{}: does not exist'.format(image_full_path))
209-
detection_annotation = DetectionAnnotation(image[1], image_labels, xmins, ymins, xmaxs, ymaxs)
210-
detection_annotation.metadata['iscrowd'] = is_crowd
211-
detection_annotations.append(detection_annotation)
209+
if image_labels != []:
210+
detection_annotation = DetectionAnnotation(image[1], image_labels, xmins, ymins, xmaxs, ymaxs)
211+
detection_annotation.metadata['iscrowd'] = is_crowd
212+
detection_annotations.append(detection_annotation)
212213
progress_reporter.update(image_id, 1)
213214

214215
progress_reporter.finish()

tools/accuracy_checker/accuracy_checker/config/config_validator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ class ConfigValidator(BaseValidator):
8484
WARN_ON_EXTRA_ARGUMENT = _ExtraArgumentBehaviour.WARN
8585
ERROR_ON_EXTRA_ARGUMENT = _ExtraArgumentBehaviour.ERROR
8686
IGNORE_ON_EXTRA_ARGUMENT = _ExtraArgumentBehaviour.IGNORE
87-
acceptable_unknown_options = ['connector', '_command_line_mapping']
87+
acceptable_unknown_options = ['connector', '_command_line_mapping', 'model']
8888

8989
def __init__(self, config_uri, on_extra_argument=WARN_ON_EXTRA_ARGUMENT, fields=None, **kwargs):
9090
super().__init__(**kwargs)

tools/accuracy_checker/accuracy_checker/launcher/pytorch_launcher.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from contextlib import contextmanager
1818
import sys
1919
import importlib
20+
import urllib
21+
import re
2022
from collections import OrderedDict
2123

2224
import numpy as np
@@ -25,6 +27,7 @@
2527

2628
MODULE_REGEX = r'(?:\w+)(?:(?:.\w+)*)'
2729
DEVICE_REGEX = r'(?P<device>cpu$|cuda)?'
30+
CHECKPOINT_URL_REGEX = r'^https?://.*\.pth(\?.*)?(#.*)?$'
2831

2932

3033
class PyTorchLauncher(Launcher):
@@ -38,6 +41,10 @@ def parameters(cls):
3841
'checkpoint': PathField(
3942
check_exists=True, is_directory=False, optional=True, description='pre-trained model checkpoint'
4043
),
44+
'checkpoint_url': StringField(
45+
optional=True, regex=CHECKPOINT_URL_REGEX, description='Url link to pre-trained model checkpoint.'
46+
),
47+
'state_key': StringField(optional=True, regex=r'\w+', description='pre-trained model checkpoint state key'),
4148
'python_path': PathField(
4249
check_exists=True, is_directory=True, optional=True,
4350
description='appendix for PYTHONPATH for making network module visible in current python environment'
@@ -47,6 +54,9 @@ def parameters(cls):
4754
key_type=str, validate_values=False, optional=True, default={},
4855
description='keyword arguments for network module'
4956
),
57+
'init_method': StringField(
58+
optional=True, regex=r'\w+', description='Method name to be called for module initialization.'
59+
),
5060
'device': StringField(default='cpu', regex=DEVICE_REGEX),
5161
'batch': NumberField(value_type=int, min_value=1, optional=True, description="Batch size.", default=1),
5262
'output_names': ListField(
@@ -79,13 +89,17 @@ def __init__(self, config_entry: dict, *args, **kwargs):
7989
module_kwargs = config_entry.get("module_kwargs", {})
8090
self.device = self.get_value_from_config('device')
8191
self.cuda = 'cuda' in self.device
92+
checkpoint = config_entry.get('checkpoint')
93+
if checkpoint is None:
94+
checkpoint = config_entry.get('checkpoint_url')
8295
self.module = self.load_module(
8396
config_entry['module'],
8497
module_args,
8598
module_kwargs,
86-
config_entry.get('checkpoint'),
99+
checkpoint,
87100
config_entry.get('state_key'),
88-
config_entry.get("python_path")
101+
config_entry.get("python_path"),
102+
config_entry.get("init_method")
89103
)
90104

91105
self._batch = self.get_value_from_config('batch')
@@ -115,14 +129,25 @@ def batch(self):
115129
def output_blob(self):
116130
return next(iter(self.output_names))
117131

118-
def load_module(self, model_cls, module_args, module_kwargs, checkpoint=None, state_key=None, python_path=None):
132+
def load_module(self, model_cls, module_args, module_kwargs, checkpoint=None, state_key=None, python_path=None,
133+
init_method=None
134+
):
119135
module_parts = model_cls.split(".")
120136
model_cls = module_parts[-1]
121137
model_path = ".".join(module_parts[:-1])
122138
with append_to_path(python_path):
123139
model_cls = importlib.import_module(model_path).__getattribute__(model_cls)
124140
module = model_cls(*module_args, **module_kwargs)
141+
if init_method is not None:
142+
if hasattr(model_cls, init_method):
143+
init_method = getattr(module, init_method)
144+
module = init_method()
145+
else:
146+
raise ValueError(f'Could not call the method {init_method} in the module {model_cls}.')
147+
125148
if checkpoint:
149+
if isinstance(checkpoint, str) and re.match(CHECKPOINT_URL_REGEX, checkpoint):
150+
checkpoint = urllib.request.urlretrieve(checkpoint)[0]
126151
checkpoint = self._torch.load(
127152
checkpoint, map_location=None if self.cuda else self._torch.device('cpu')
128153
)

tools/accuracy_checker/accuracy_checker/launcher/pytorch_launcher_readme.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,12 @@ For enabling PyTorch launcher you need to add `framework: pytorch` in launchers
77
* `device` - specifies which device will be used for infer (`cpu`, `cuda` and so on).
88
* `module`- PyTorch network module for loading.
99
* `checkpoint` - pre-trained model checkpoint (Optional).
10+
* `checkpoint_url` - url link to pre-trained model checkpoint (Optional).
11+
* `state_key` - pre-trained model checkpoint state key (Optional).
1012
* `python_path` - appendix for PYTHONPATH for making network module visible in current python environment (Optional).
1113
* `module_args` - list of positional arguments for network module (Optional).
1214
* `module_kwargs` - dictionary (`key`: `value` where `key` is argument name, `value` is argument value) which represent network module keyword arguments.
15+
* `init_method` - method name to be called for module initialization (Optional).
1316
* `adapter` - approach how raw output will be converted to representation of dataset problem, some adapters can be specific to framework. You can find detailed instruction how to use adapters [here](../adapters/README.md).
1417
* `batch` - batch size for running model (Optional, default 1).
1518
* `use_openvino_backend` - use torch.compile feature with `openvino` backend (Optional, default `False`)

0 commit comments

Comments
 (0)