[PTDT-2372] Added support for MMC tasks annotations #1787

Merged 1 commit on Sep 4, 2024

@@ -63,4 +63,6 @@
 from .data.tiled_image import TileLayer

 from .llm_prompt_response.prompt import PromptText
 from .llm_prompt_response.prompt import PromptClassificationAnnotation
+
+from .mmc import MessageInfo, OrderedMessageInfo, MessageSingleSelectionTask, MessageMultiSelectionTask, MessageRankingTask, MessageEvaluationTaskAnnotation
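
With this re-export in place (assuming the hunk above is the annotation_types package __init__, whose file header was lost in the page capture), the new MMC types become importable at package level — a minimal sketch:

from labelbox.data.annotation_types import (
    MessageInfo,
    OrderedMessageInfo,
    MessageRankingTask,
    MessageEvaluationTaskAnnotation,
)
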
3 changes: 2 additions & 1 deletion libs/labelbox/src/labelbox/data/annotation_types/label.py
@@ -18,6 +18,7 @@
 from .types import Cuid
 from .video import VideoClassificationAnnotation
 from .video import VideoObjectAnnotation, VideoMaskAnnotation
+from .mmc import MessageEvaluationTaskAnnotation
 from ..ontology import get_feature_schema_lookup

 DataType = Union[VideoData, ImageData, TextData, TiledImageData, AudioData,
@@ -51,7 +52,7 @@ class Label(pydantic_compat.BaseModel):
     annotations: List[Union[ClassificationAnnotation, ObjectAnnotation,
                             VideoMaskAnnotation, ScalarMetric,
                             ConfusionMatrixMetric, RelationshipAnnotation,
-                            PromptClassificationAnnotation]] = []
+                            PromptClassificationAnnotation, MessageEvaluationTaskAnnotation]] = []
     extra: Dict[str, Any] = {}
     is_benchmark_reference: Optional[bool] = False
44 changes: 44 additions & 0 deletions libs/labelbox/src/labelbox/data/annotation_types/mmc.py
@@ -0,0 +1,44 @@
from abc import ABC
from typing import ClassVar, List, Union

from labelbox import pydantic_compat
from labelbox.utils import _CamelCaseMixin
from labelbox.data.annotation_types.annotation import BaseAnnotation


class MessageInfo(_CamelCaseMixin):
    message_id: str
    model_config_name: str


class OrderedMessageInfo(MessageInfo):
    order: int


class _BaseMessageEvaluationTask(_CamelCaseMixin, ABC):
    format: ClassVar[str]
    parent_message_id: str


class MessageSingleSelectionTask(_BaseMessageEvaluationTask, MessageInfo):
    format: ClassVar[str] = "message-single-selection"


class MessageMultiSelectionTask(_BaseMessageEvaluationTask):
    format: ClassVar[str] = "message-multi-selection"
    selected_messages: List[MessageInfo]


class MessageRankingTask(_BaseMessageEvaluationTask):
    format: ClassVar[str] = "message-ranking"
    ranked_messages: List[OrderedMessageInfo]

    @pydantic_compat.validator("ranked_messages")
    def _validate_ranked_messages(cls, v: List[OrderedMessageInfo]):
        if {msg.order for msg in v} != set(range(1, len(v) + 1)):
            raise ValueError("Messages must be ordered by unique and consecutive natural numbers starting from 1")
        return v


class MessageEvaluationTaskAnnotation(BaseAnnotation):
    value: Union[MessageSingleSelectionTask, MessageMultiSelectionTask, MessageRankingTask]
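
A construction sketch follows (illustrative, not part of the diff; IDs and names are placeholders, and it assumes _CamelCaseMixin allows population by snake_case field names and that BaseAnnotation auto-generates the private _uuid, as from_common in the serializer below suggests):

from labelbox.data.annotation_types.mmc import (
    MessageEvaluationTaskAnnotation,
    MessageRankingTask,
    OrderedMessageInfo,
)

# Orders must be exactly the consecutive natural numbers 1..n, or the
# validator above raises a ValueError.
task = MessageRankingTask(
    parent_message_id="parent-msg-id",  # placeholder
    ranked_messages=[
        OrderedMessageInfo(message_id="msg-a", model_config_name="config-a", order=1),
        OrderedMessageInfo(message_id="msg-b", model_config_name="config-b", order=2),
    ],
)

# Wrapping the task makes it acceptable to Label.annotations and the
# NDJSON serializer.
annotation = MessageEvaluationTaskAnnotation(name="ranking", value=task)
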
@@ -14,6 +14,7 @@

 from ...annotation_types.collection import LabelCollection, LabelGenerator
 from ...annotation_types.relationship import RelationshipAnnotation
+from ...annotation_types.mmc import MessageEvaluationTaskAnnotation
 from .label import NDLabel

 logger = logging.getLogger(__name__)
@@ -71,8 +72,9 @@ def serialize(
             ScalarMetric,
             ConfusionMatrixMetric,
             RelationshipAnnotation,
+            MessageEvaluationTaskAnnotation,
         ]] = []
-        # First pass to get all RelatiohnshipAnnotaitons
+        # First pass to get all RelationshipAnnotations
         # and update the UUIDs of the source and target annotations
         for annotation in label.annotations:
             if isinstance(annotation, RelationshipAnnotation):
@@ -18,17 +18,19 @@
 from ...annotation_types.classification import Dropdown
 from ...annotation_types.metrics import ScalarMetric, ConfusionMatrixMetric
 from ...annotation_types.llm_prompt_response.prompt import PromptClassificationAnnotation
+from ...annotation_types.mmc import MessageEvaluationTaskAnnotation

 from .metric import NDScalarMetric, NDMetricAnnotation, NDConfusionMatrixMetric
 from .classification import NDChecklistSubclass, NDClassification, NDClassificationType, NDRadioSubclass, NDPromptClassification, NDPromptClassificationType, NDPromptText
 from .objects import NDObject, NDObjectType, NDSegments, NDDicomSegments, NDVideoMasks, NDDicomMasks
+from .mmc import NDMessageTask
 from .relationship import NDRelationship
 from .base import DataRow

 AnnotationType = Union[NDObjectType, NDClassificationType, NDPromptClassificationType,
                        NDConfusionMatrixMetric, NDScalarMetric, NDDicomSegments,
                        NDSegments, NDDicomMasks, NDVideoMasks, NDRelationship,
-                       NDPromptText]
+                       NDPromptText, NDMessageTask]


 class NDLabel(pydantic_compat.BaseModel):
@@ -126,6 +128,8 @@ def _generate_annotations(
             elif isinstance(ndjson_annotation, NDPromptClassificationType):
                 annotation = NDPromptClassification.to_common(ndjson_annotation)
                 annotations.append(annotation)
+            elif isinstance(ndjson_annotation, NDMessageTask):
+                annotations.append(ndjson_annotation.to_common())
             else:
                 raise TypeError(
                     f"Unsupported annotation. {type(ndjson_annotation)}")
@@ -277,6 +281,8 @@ def _create_non_video_annotations(cls, label: Label):
             yield NDRelationship.from_common(annotation, label.data)
         elif isinstance(annotation, PromptClassificationAnnotation):
             yield NDPromptClassification.from_common(annotation, label.data)
+        elif isinstance(annotation, MessageEvaluationTaskAnnotation):
+            yield NDMessageTask.from_common(annotation, label.data)
         else:
             raise TypeError(
                 f"Unable to convert object to MAL format. `{type(getattr(annotation, 'value', annotation))}`"
42 changes: 42 additions & 0 deletions libs/labelbox/src/labelbox/data/serialization/ndjson/mmc.py
@@ -0,0 +1,42 @@
from typing import Any, Dict, List, Optional, Union

from labelbox.utils import _CamelCaseMixin

from .base import DataRow, NDAnnotation
from ...annotation_types.types import Cuid
from ...annotation_types.mmc import MessageSingleSelectionTask, MessageMultiSelectionTask, MessageRankingTask, MessageEvaluationTaskAnnotation


class MessageTaskData(_CamelCaseMixin):
    format: str
    data: Union[MessageSingleSelectionTask, MessageMultiSelectionTask, MessageRankingTask]


class NDMessageTask(NDAnnotation):

    message_evaluation_task: MessageTaskData

    def to_common(self) -> MessageEvaluationTaskAnnotation:
        return MessageEvaluationTaskAnnotation(
            name=self.name,
            feature_schema_id=self.schema_id,
            value=self.message_evaluation_task.data,
            extra={"uuid": self.uuid},
        )

    @classmethod
    def from_common(
        cls,
        annotation: MessageEvaluationTaskAnnotation,
        data: Any,  # Union[ImageData, TextData]
    ) -> "NDMessageTask":
        return cls(
            uuid=str(annotation._uuid),
            name=annotation.name,
            schema_id=annotation.feature_schema_id,
            data_row=DataRow(id=data.uid, global_key=data.global_key),
            message_evaluation_task=MessageTaskData(
                format=annotation.value.format,
                data=annotation.value
            )
        )
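
A round-trip sketch (illustrative, not part of the diff; the payload mirrors the test asset below with placeholder IDs, and it assumes NDAnnotation inherits pydantic's parse_obj and the camel-case aliasing used elsewhere in this module):

payload = {
    "uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72",  # placeholder
    "dataRow": {"id": "placeholder-data-row-id"},
    "name": "single-selection",
    "messageEvaluationTask": {
        "format": "message-single-selection",
        "data": {
            "messageId": "msg-a",
            "parentMessageId": "parent-msg-id",
            "modelConfigName": "config-a",
        },
    },
}

# Parse the NDJSON dict into the serializer model, then convert to the
# common annotation type defined in annotation_types/mmc.py.
nd_task = NDMessageTask.parse_obj(payload)
annotation = nd_task.to_common()
assert annotation.value.message_id == "msg-a"
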
61 changes: 61 additions & 0 deletions libs/labelbox/tests/data/assets/ndjson/mmc_import.json
@@ -0,0 +1,61 @@
[
    {
        "dataRow": {
            "id": "cnjencjencjfencvj"
        },
        "uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72",
        "name": "single-selection",
        "messageEvaluationTask": {
            "format": "message-single-selection",
            "data": {
                "messageId": "clxfzocbm00083b6v8vczsept",
                "parentMessageId": "clxfznjb800073b6v43ppx9ca",
                "modelConfigName": "GPT 5"
            }
        }
    },
    {
        "dataRow": {
            "id": "cfcerfvergerfefj"
        },
        "uuid": "gferf3a57-597e-48cb-8d8d-a8526fefe72",
        "name": "multi-selection",
        "messageEvaluationTask": {
            "format": "message-multi-selection",
            "data": {
                "parentMessageId": "clxfznjb800073b6v43ppx9ca",
                "selectedMessages": [
                    {
                        "messageId": "clxfzocbm00083b6v8vczsept",
                        "modelConfigName": "GPT 5"
                    }
                ]
            }
        }
    },
    {
        "dataRow": {
            "id": "cwefgtrgrthveferfferffr"
        },
        "uuid": "hybe3a57-5gt7e-48tgrb-8d8d-a852dswqde72",
        "name": "ranking",
        "messageEvaluationTask": {
            "format": "message-ranking",
            "data": {
                "parentMessageId": "clxfznjb800073b6v43ppx9ca",
                "rankedMessages": [
                    {
                        "messageId": "clxfzocbm00083b6v8vczsept",
                        "modelConfigName": "GPT 4 with temperature 0.7",
                        "order": 1
                    },
                    {
                        "messageId": "clxfzocbm00093b6vx4ndisub",
                        "modelConfigName": "GPT 5",
                        "order": 2
                    }
                ]
            }
        }
    }
]
27 changes: 27 additions & 0 deletions libs/labelbox/tests/data/serialization/ndjson/test_mmc.py
@@ -0,0 +1,27 @@
import json

import pytest

from labelbox.data.serialization import NDJsonConverter
from labelbox.pydantic_compat import ValidationError


def test_message_task_annotation_serialization():
    with open('tests/data/assets/ndjson/mmc_import.json', 'r') as file:
        data = json.load(file)

    deserialized = list(NDJsonConverter.deserialize(data))
    reserialized = list(NDJsonConverter.serialize(deserialized))

    assert data == reserialized


def test_message_ranking_task_wrong_order_serialization():
    with open('tests/data/assets/ndjson/mmc_import.json', 'r') as file:
        data = json.load(file)

    some_ranking_task = next(task for task in data if task["messageEvaluationTask"]["format"] == "message-ranking")
    some_ranking_task["messageEvaluationTask"]["data"]["rankedMessages"][0]["order"] = 3

    with pytest.raises(ValidationError):
        list(NDJsonConverter.deserialize([some_ranking_task]))
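
These tests exercise the full NDJsonConverter round trip against the asset above; assuming a repo checkout, they should run from the libs/labelbox directory with pytest tests/data/serialization/ndjson/test_mmc.py (the asset path in the tests is relative to that directory).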