Skip to content

Commit 484738d

Browse files
committed
[PTDT-2372] Added support for MMC tasks annotations
1 parent e20a774 commit 484738d

File tree

5 files changed

+97
-14
lines changed

5 files changed

+97
-14
lines changed

libs/labelbox/src/labelbox/data/annotation_types/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,6 @@
6363
from .data.tiled_image import TileLayer
6464

6565
from .llm_prompt_response.prompt import PromptText
66-
from .llm_prompt_response.prompt import PromptClassificationAnnotation
66+
from .llm_prompt_response.prompt import PromptClassificationAnnotation
67+
68+
from .mmc import MessageInfo, OrderedMessageInfo, MessageSingleSelection, MessageMultiSelection, MessageRanking

libs/labelbox/src/labelbox/data/annotation_types/annotation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from labelbox.data.annotation_types.classification.classification import ClassificationAnnotation
99
from .ner import DocumentEntity, TextEntity, ConversationEntity
10+
from .mmc import BaseMMCAnnotation
1011

1112

1213
class ObjectAnnotation(BaseAnnotation, ConfidenceMixin, CustomMetricsMixin):
@@ -28,5 +29,5 @@ class ObjectAnnotation(BaseAnnotation, ConfidenceMixin, CustomMetricsMixin):
2829
extra (Dict[str, Any])
2930
"""
3031

31-
value: Union[TextEntity, ConversationEntity, DocumentEntity, Geometry]
32+
value: Union[TextEntity, ConversationEntity, DocumentEntity, Geometry, BaseMMCAnnotation]
3233
classifications: List[ClassificationAnnotation] = []
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from abc import ABC
2+
from typing import ClassVar, List
3+
4+
from labelbox import pydantic_compat
5+
from labelbox.utils import _CamelCaseMixin
6+
7+
8+
class MessageInfo(_CamelCaseMixin):
9+
message_id: str
10+
model_config_name: str
11+
12+
13+
class OrderedMessageInfo(MessageInfo):
14+
order: int
15+
16+
17+
class BaseMMCAnnotation(_CamelCaseMixin, ABC):
18+
format: ClassVar[str]
19+
parent_message_id: str
20+
21+
22+
class MessageSingleSelection(BaseMMCAnnotation, MessageInfo):
23+
format: ClassVar[str] = "message-single-selection"
24+
25+
26+
class MessageMultiSelection(BaseMMCAnnotation):
27+
format: ClassVar[str] = "message-multi-selection"
28+
selected_messages: List[MessageInfo]
29+
30+
31+
class MessageRanking(BaseMMCAnnotation):
32+
format: ClassVar[str] = "message-ranking"
33+
ranked_messages: List[OrderedMessageInfo]
34+
35+
@pydantic_compat.validator("ranked_messages")
36+
def _validate_ranked_messages(cls, v: List[OrderedMessageInfo]):
37+
if not {msg.order for msg in v} == set(range(1, len(v) + 1)):
38+
raise ValueError("Messages must be ordered by unique and consecutive natural numbers starting from 1")
39+
return v

libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def serialize(
7272
ConfusionMatrixMetric,
7373
RelationshipAnnotation,
7474
]] = []
75-
# First pass to get all RelatiohnshipAnnotaitons
75+
# First pass to get all RelationshipAnnotaitons
7676
# and update the UUIDs of the source and target annotations
7777
for annotation in label.annotations:
7878
if isinstance(annotation, RelationshipAnnotation):

libs/labelbox/src/labelbox/data/serialization/ndjson/objects.py

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from labelbox.data.annotation_types.ner.conversation_entity import ConversationEntity
66
from labelbox.data.annotation_types.video import VideoObjectAnnotation, DICOMObjectAnnotation
77
from labelbox.data.mixins import ConfidenceMixin, CustomMetricsMixin, CustomMetric, CustomMetricsNotSupportedMixin
8+
from labelbox.utils import _CamelCaseMixin
89
import numpy as np
910

1011
from labelbox import pydantic_compat
@@ -19,6 +20,7 @@
1920
from ...annotation_types.geometry import DocumentRectangle, Rectangle, Polygon, Line, Point, Mask
2021
from ...annotation_types.annotation import ClassificationAnnotation, ObjectAnnotation
2122
from ...annotation_types.video import VideoMaskAnnotation, DICOMMaskAnnotation, MaskFrame, MaskInstance
23+
from ...annotation_types.mmc import BaseMMCAnnotation
2224
from .classification import NDClassification, NDSubclassification, NDSubclassificationType
2325
from .base import DataRow, NDAnnotation, NDJsonBase
2426

@@ -666,8 +668,53 @@ def from_common(
666668
custom_metrics=custom_metrics)
667669

668670

671+
class NDMessageTask(NDAnnotation):
672+
673+
class MessageTaskData(_CamelCaseMixin):
674+
format: str
675+
data: BaseMMCAnnotation
676+
677+
message_evaluation_task: MessageTaskData
678+
679+
@classmethod
680+
def from_common(
681+
cls,
682+
uuid: str,
683+
annotation_value: BaseMMCAnnotation,
684+
classifications: List[ClassificationAnnotation],
685+
name: str,
686+
feature_schema_id: Cuid,
687+
extra: Dict[str, Any],
688+
data: Union[ImageData, TextData],
689+
confidence: Optional[float] = None,
690+
custom_metrics: Optional[List[CustomMetric]] = None
691+
) -> "NDMessageTask":
692+
return cls(
693+
uuid=uuid,
694+
name=name,
695+
dataRow=DataRow(id=data.uid, global_key=data.global_key),
696+
message_evaluation_task=cls.MessageTaskData(
697+
format=annotation_value.format,
698+
data=annotation_value
699+
)
700+
)
701+
702+
669703
class NDObject:
670704

705+
_ANNOTATION_TO_SERIALIZER = {
706+
Line: NDLine,
707+
Point: NDPoint,
708+
Polygon: NDPolygon,
709+
Rectangle: NDRectangle,
710+
DocumentRectangle: NDDocumentRectangle,
711+
Mask: NDMask,
712+
TextEntity: NDTextEntity,
713+
DocumentEntity: NDDocumentEntity,
714+
ConversationEntity: NDConversationEntity,
715+
BaseMMCAnnotation: NDMessageTask
716+
}
717+
671718
@staticmethod
672719
def to_common(annotation: "NDObjectType") -> ObjectAnnotation:
673720
common_annotation = annotation.to_common()
@@ -755,17 +802,11 @@ def lookup_object(
755802
else:
756803
result = NDSegments
757804
else:
758-
result = {
759-
Line: NDLine,
760-
Point: NDPoint,
761-
Polygon: NDPolygon,
762-
Rectangle: NDRectangle,
763-
DocumentRectangle: NDDocumentRectangle,
764-
Mask: NDMask,
765-
TextEntity: NDTextEntity,
766-
DocumentEntity: NDDocumentEntity,
767-
ConversationEntity: NDConversationEntity,
768-
}.get(type(annotation.value))
805+
result = next((
806+
serializer_class
807+
for annotation_class, serializer_class in NDObject._ANNOTATION_TO_SERIALIZER.items()
808+
if isinstance(annotation.value, annotation_class)
809+
), None)
769810
if result is None:
770811
raise TypeError(
771812
f"Unable to convert object to MAL format. `{type(annotation.value)}`"

0 commit comments

Comments
 (0)