Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions detector/yolov5_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
import cv2
import numpy as np
import yolov5
from learning_loop_node.data_classes import BoxDetection, ImageMetadata, ImagesMetadata, PointDetection
from learning_loop_node.data_classes import (
BoxDetection,
ImageMetadata,
ImagesMetadata,
PointDetection,
)
from learning_loop_node.detector.detector_logic import DetectorLogic
from learning_loop_node.enums import CategoryType

Expand Down Expand Up @@ -81,11 +86,19 @@ def clip_point(x: float, y: float, img_width: int, img_height: int) -> Tuple[flo
y = min(max(0, y), img_height)
return x, y

def evaluate(self, image: bytes) -> ImageMetadata:
def evaluate(self,
image: bytes,
tags: List[str],
source: Optional[str] = None,
creation_date: Optional[str] = None) -> ImageMetadata:
assert self.yolov5 is not None, 'init() must be executed first. Maybe loading the engine failed?!'
assert self.model_info is not None, 'model_info must be set before calling evaluate()'

image_metadata = ImageMetadata()
image_metadata.tags = tags
image_metadata.source = source
image_metadata.created = creation_date

try:
t = time.time()
cv_image = cv2.imdecode(np.frombuffer(image, np.uint8), cv2.IMREAD_COLOR)
Expand Down Expand Up @@ -128,7 +141,10 @@ def evaluate(self, image: bytes) -> ImageMetadata:
self.log.exception('inference failed')
return image_metadata

def batch_evaluate(self, images: List[bytes]) -> ImagesMetadata:
def batch_evaluate(self, images: List[bytes],
tags: List[str],
source: Optional[str] = None,
creation_date: Optional[str] = None) -> ImagesMetadata:
raise NotImplementedError('batch_evaluate is not implemented for Yolov5Detector')

def _create_engine(self, resolution: int, cat_count: int, model_variant: Optional[str], wts_file: str) -> str:
Expand Down