From edd5cbe7c2c809767101774b8a4816cea0076bf8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20J=C3=B3=C5=BAwiak?= Date: Tue, 3 Sep 2024 15:52:09 +0200 Subject: [PATCH] [PTDT-2372] Added support for MMC tasks annotations --- .../data/annotation_types/__init__.py | 4 +- .../labelbox/data/annotation_types/label.py | 3 +- .../src/labelbox/data/annotation_types/mmc.py | 44 +++++++++++++ .../data/serialization/ndjson/converter.py | 4 +- .../data/serialization/ndjson/label.py | 8 ++- .../labelbox/data/serialization/ndjson/mmc.py | 42 +++++++++++++ .../tests/data/assets/ndjson/mmc_import.json | 61 +++++++++++++++++++ .../data/serialization/ndjson/test_mmc.py | 27 ++++++++ 8 files changed, 189 insertions(+), 4 deletions(-) create mode 100644 libs/labelbox/src/labelbox/data/annotation_types/mmc.py create mode 100644 libs/labelbox/src/labelbox/data/serialization/ndjson/mmc.py create mode 100644 libs/labelbox/tests/data/assets/ndjson/mmc_import.json create mode 100644 libs/labelbox/tests/data/serialization/ndjson/test_mmc.py diff --git a/libs/labelbox/src/labelbox/data/annotation_types/__init__.py b/libs/labelbox/src/labelbox/data/annotation_types/__init__.py index 85b7ae1af..3d0442218 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/__init__.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/__init__.py @@ -63,4 +63,6 @@ from .data.tiled_image import TileLayer from .llm_prompt_response.prompt import PromptText -from .llm_prompt_response.prompt import PromptClassificationAnnotation \ No newline at end of file +from .llm_prompt_response.prompt import PromptClassificationAnnotation + +from .mmc import MessageInfo, OrderedMessageInfo, MessageSingleSelectionTask, MessageMultiSelectionTask, MessageRankingTask, MessageEvaluationTaskAnnotation diff --git a/libs/labelbox/src/labelbox/data/annotation_types/label.py b/libs/labelbox/src/labelbox/data/annotation_types/label.py index c7a0cb7b8..1ab4889f6 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/label.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/label.py @@ -18,6 +18,7 @@ from .types import Cuid from .video import VideoClassificationAnnotation from .video import VideoObjectAnnotation, VideoMaskAnnotation +from .mmc import MessageEvaluationTaskAnnotation from ..ontology import get_feature_schema_lookup DataType = Union[VideoData, ImageData, TextData, TiledImageData, AudioData, @@ -51,7 +52,7 @@ class Label(pydantic_compat.BaseModel): annotations: List[Union[ClassificationAnnotation, ObjectAnnotation, VideoMaskAnnotation, ScalarMetric, ConfusionMatrixMetric, RelationshipAnnotation, - PromptClassificationAnnotation]] = [] + PromptClassificationAnnotation, MessageEvaluationTaskAnnotation]] = [] extra: Dict[str, Any] = {} is_benchmark_reference: Optional[bool] = False diff --git a/libs/labelbox/src/labelbox/data/annotation_types/mmc.py b/libs/labelbox/src/labelbox/data/annotation_types/mmc.py new file mode 100644 index 000000000..29b33c62d --- /dev/null +++ b/libs/labelbox/src/labelbox/data/annotation_types/mmc.py @@ -0,0 +1,44 @@ +from abc import ABC +from typing import ClassVar, List, Union + +from labelbox import pydantic_compat +from labelbox.utils import _CamelCaseMixin +from labelbox.data.annotation_types.annotation import BaseAnnotation + + +class MessageInfo(_CamelCaseMixin): + message_id: str + model_config_name: str + + +class OrderedMessageInfo(MessageInfo): + order: int + + +class _BaseMessageEvaluationTask(_CamelCaseMixin, ABC): + format: ClassVar[str] + parent_message_id: str + + +class MessageSingleSelectionTask(_BaseMessageEvaluationTask, MessageInfo): + format: ClassVar[str] = "message-single-selection" + + +class MessageMultiSelectionTask(_BaseMessageEvaluationTask): + format: ClassVar[str] = "message-multi-selection" + selected_messages: List[MessageInfo] + + +class MessageRankingTask(_BaseMessageEvaluationTask): + format: ClassVar[str] = "message-ranking" + ranked_messages: List[OrderedMessageInfo] + + @pydantic_compat.validator("ranked_messages") + def _validate_ranked_messages(cls, v: List[OrderedMessageInfo]): + if not {msg.order for msg in v} == set(range(1, len(v) + 1)): + raise ValueError("Messages must be ordered by unique and consecutive natural numbers starting from 1") + return v + + +class MessageEvaluationTaskAnnotation(BaseAnnotation): + value: Union[MessageSingleSelectionTask, MessageMultiSelectionTask, MessageRankingTask] diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py index 2ffeb9727..07b1b59c0 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py @@ -14,6 +14,7 @@ from ...annotation_types.collection import LabelCollection, LabelGenerator from ...annotation_types.relationship import RelationshipAnnotation +from ...annotation_types.mmc import MessageEvaluationTaskAnnotation from .label import NDLabel logger = logging.getLogger(__name__) @@ -71,8 +72,9 @@ def serialize( ScalarMetric, ConfusionMatrixMetric, RelationshipAnnotation, + MessageEvaluationTaskAnnotation, ]] = [] - # First pass to get all RelatiohnshipAnnotaitons + # First pass to get all RelationshipAnnotaitons # and update the UUIDs of the source and target annotations for annotation in label.annotations: if isinstance(annotation, RelationshipAnnotation): diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py index 9d34c451b..29b239196 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py @@ -18,17 +18,19 @@ from ...annotation_types.classification import Dropdown from ...annotation_types.metrics import ScalarMetric, ConfusionMatrixMetric from ...annotation_types.llm_prompt_response.prompt import PromptClassificationAnnotation +from ...annotation_types.mmc import MessageEvaluationTaskAnnotation from .metric import NDScalarMetric, NDMetricAnnotation, NDConfusionMatrixMetric from .classification import NDChecklistSubclass, NDClassification, NDClassificationType, NDRadioSubclass, NDPromptClassification, NDPromptClassificationType, NDPromptText from .objects import NDObject, NDObjectType, NDSegments, NDDicomSegments, NDVideoMasks, NDDicomMasks +from .mmc import NDMessageTask from .relationship import NDRelationship from .base import DataRow AnnotationType = Union[NDObjectType, NDClassificationType, NDPromptClassificationType, NDConfusionMatrixMetric, NDScalarMetric, NDDicomSegments, NDSegments, NDDicomMasks, NDVideoMasks, NDRelationship, - NDPromptText] + NDPromptText, NDMessageTask] class NDLabel(pydantic_compat.BaseModel): @@ -126,6 +128,8 @@ def _generate_annotations( elif isinstance(ndjson_annotation, NDPromptClassificationType): annotation = NDPromptClassification.to_common(ndjson_annotation) annotations.append(annotation) + elif isinstance(ndjson_annotation, NDMessageTask): + annotations.append(ndjson_annotation.to_common()) else: raise TypeError( f"Unsupported annotation. {type(ndjson_annotation)}") @@ -277,6 +281,8 @@ def _create_non_video_annotations(cls, label: Label): yield NDRelationship.from_common(annotation, label.data) elif isinstance(annotation, PromptClassificationAnnotation): yield NDPromptClassification.from_common(annotation, label.data) + elif isinstance(annotation, MessageEvaluationTaskAnnotation): + yield NDMessageTask.from_common(annotation, label.data) else: raise TypeError( f"Unable to convert object to MAL format. `{type(getattr(annotation, 'value',annotation))}`" diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/mmc.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/mmc.py new file mode 100644 index 000000000..e7af6924c --- /dev/null +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/mmc.py @@ -0,0 +1,42 @@ +from typing import Any, Dict, List, Optional, Union + +from labelbox.utils import _CamelCaseMixin + +from .base import DataRow, NDAnnotation +from ...annotation_types.types import Cuid +from ...annotation_types.mmc import MessageSingleSelectionTask, MessageMultiSelectionTask, MessageRankingTask, MessageEvaluationTaskAnnotation + + +class MessageTaskData(_CamelCaseMixin): + format: str + data: Union[MessageSingleSelectionTask, MessageMultiSelectionTask, MessageRankingTask] + + +class NDMessageTask(NDAnnotation): + + message_evaluation_task: MessageTaskData + + def to_common(self) -> MessageEvaluationTaskAnnotation: + return MessageEvaluationTaskAnnotation( + name=self.name, + feature_schema_id=self.schema_id, + value=self.message_evaluation_task.data, + extra={"uuid": self.uuid}, + ) + + @classmethod + def from_common( + cls, + annotation: MessageEvaluationTaskAnnotation, + data: Any#Union[ImageData, TextData], + ) -> "NDMessageTask": + return cls( + uuid=str(annotation._uuid), + name=annotation.name, + schema_id=annotation.feature_schema_id, + data_row=DataRow(id=data.uid, global_key=data.global_key), + message_evaluation_task=MessageTaskData( + format=annotation.value.format, + data=annotation.value + ) + ) diff --git a/libs/labelbox/tests/data/assets/ndjson/mmc_import.json b/libs/labelbox/tests/data/assets/ndjson/mmc_import.json new file mode 100644 index 000000000..5053dc9da --- /dev/null +++ b/libs/labelbox/tests/data/assets/ndjson/mmc_import.json @@ -0,0 +1,61 @@ +[ + { + "dataRow": { + "id": "cnjencjencjfencvj" + }, + "uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72", + "name": "single-selection", + "messageEvaluationTask": { + "format": "message-single-selection", + "data": { + "messageId": "clxfzocbm00083b6v8vczsept", + "parentMessageId": "clxfznjb800073b6v43ppx9ca", + "modelConfigName": "GPT 5" + } + } + }, + { + "dataRow": { + "id": "cfcerfvergerfefj" + }, + "uuid": "gferf3a57-597e-48cb-8d8d-a8526fefe72", + "name": "multi-selection", + "messageEvaluationTask": { + "format": "message-multi-selection", + "data": { + "parentMessageId": "clxfznjb800073b6v43ppx9ca", + "selectedMessages": [ + { + "messageId": "clxfzocbm00083b6v8vczsept", + "modelConfigName": "GPT 5" + } + ] + } + } + }, + { + "dataRow": { + "id": "cwefgtrgrthveferfferffr" + }, + "uuid": "hybe3a57-5gt7e-48tgrb-8d8d-a852dswqde72", + "name": "ranking", + "messageEvaluationTask": { + "format": "message-ranking", + "data": { + "parentMessageId": "clxfznjb800073b6v43ppx9ca", + "rankedMessages": [ + { + "messageId": "clxfzocbm00083b6v8vczsept", + "modelConfigName": "GPT 4 with temperature 0.7", + "order": 1 + }, + { + "messageId": "clxfzocbm00093b6vx4ndisub", + "modelConfigName": "GPT 5", + "order": 2 + } + ] + } + } + } +] \ No newline at end of file diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_mmc.py b/libs/labelbox/tests/data/serialization/ndjson/test_mmc.py new file mode 100644 index 000000000..54202cccc --- /dev/null +++ b/libs/labelbox/tests/data/serialization/ndjson/test_mmc.py @@ -0,0 +1,27 @@ +import json + +import pytest + +from labelbox.data.serialization import NDJsonConverter +from labelbox.pydantic_compat import ValidationError + + +def test_message_task_annotation_serialization(): + with open('tests/data/assets/ndjson/mmc_import.json', 'r') as file: + data = json.load(file) + + deserialized = list(NDJsonConverter.deserialize(data)) + reserialized = list(NDJsonConverter.serialize(deserialized)) + + assert data == reserialized + + +def test_mesage_ranking_task_wrong_order_serialization(): + with open('tests/data/assets/ndjson/mmc_import.json', 'r') as file: + data = json.load(file) + + some_ranking_task = next(task for task in data if task["messageEvaluationTask"]["format"] == "message-ranking") + some_ranking_task["messageEvaluationTask"]["data"]["rankedMessages"][0]["order"] = 3 + + with pytest.raises(ValidationError): + list(NDJsonConverter.deserialize([some_ranking_task]))