diff --git a/libs/labelbox/src/labelbox/data/annotation_types/__init__.py b/libs/labelbox/src/labelbox/data/annotation_types/__init__.py index 6eaa205bc..85b7ae1af 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/__init__.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/__init__.py @@ -61,3 +61,6 @@ from .data.tiled_image import TiledBounds from .data.tiled_image import TiledImageData from .data.tiled_image import TileLayer + +from .llm_prompt_response.prompt import PromptText +from .llm_prompt_response.prompt import PromptClassificationAnnotation \ No newline at end of file diff --git a/libs/labelbox/src/labelbox/data/annotation_types/label.py b/libs/labelbox/src/labelbox/data/annotation_types/label.py index f31dbdcda..cd209a493 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/label.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/label.py @@ -10,6 +10,7 @@ from labelbox.schema import ontology from .annotation import ClassificationAnnotation, ObjectAnnotation from .relationship import RelationshipAnnotation +from .llm_prompt_response.prompt import PromptClassificationAnnotation from .classification import ClassificationAnswer from .data import AudioData, ConversationData, DicomData, DocumentData, HTMLData, ImageData, TextData, VideoData, LlmPromptCreationData, LlmPromptResponseCreationData, LlmResponseCreationData from .geometry import Mask @@ -50,7 +51,8 @@ class Label(pydantic_compat.BaseModel): annotations: List[Union[ClassificationAnnotation, ObjectAnnotation, VideoMaskAnnotation, ScalarMetric, ConfusionMatrixMetric, - RelationshipAnnotation]] = [] + RelationshipAnnotation, + PromptClassificationAnnotation]] = [] extra: Dict[str, Any] = {} @pydantic_compat.root_validator(pre=True) @@ -209,10 +211,17 @@ def validate_union(cls, value): ]) if not isinstance(value, list): raise TypeError(f"Annotations must be a list. Found {type(value)}") - + prompt_count = 0 for v in value: if not isinstance(v, supported): raise TypeError( f"Annotations should be a list containing the following classes : {supported}. Found {type(v)}" ) + # Validates only one prompt annotation is included + if isinstance(v, PromptClassificationAnnotation): + prompt_count+=1 + if prompt_count > 1: + raise TypeError( + f"Only one prompt annotation is allowed per label" + ) return value diff --git a/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/__init__.py b/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/__init__.py new file mode 100644 index 000000000..7c0b63abc --- /dev/null +++ b/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/__init__.py @@ -0,0 +1,2 @@ +from .prompt import PromptText +from .prompt import PromptClassificationAnnotation \ No newline at end of file diff --git a/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/prompt.py b/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/prompt.py new file mode 100644 index 000000000..c235526b0 --- /dev/null +++ b/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/prompt.py @@ -0,0 +1,39 @@ +from typing import Union + +from labelbox.data.annotation_types.base_annotation import BaseAnnotation + +from labelbox.data.mixins import ConfidenceMixin, CustomMetricsMixin + +from labelbox import pydantic_compat + + +class PromptText(ConfidenceMixin, CustomMetricsMixin, pydantic_compat.BaseModel): + """ Prompt text for LLM data generation + + >>> PromptText(answer = "some text answer", + >>> confidence = 0.5, + >>> custom_metrics = [ + >>> { + >>> "name": "iou", + >>> "value": 0.1 + >>> }]) + """ + answer: str + + +class PromptClassificationAnnotation(BaseAnnotation, ConfidenceMixin, + CustomMetricsMixin): + """Prompt annotation (non localized) + + >>> PromptClassificationAnnotation( + >>> value=PromptText(answer="my caption message"), + >>> feature_schema_id="my-feature-schema-id" + >>> ) + + Args: + name (Optional[str]) + feature_schema_id (Optional[Cuid]) + value (Union[Text]) + """ + + value: PromptText \ No newline at end of file diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py index 028eeded8..46b8fc91f 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py @@ -7,6 +7,7 @@ from labelbox.utils import camel_case from ...annotation_types.annotation import ClassificationAnnotation from ...annotation_types.video import VideoClassificationAnnotation +from ...annotation_types.llm_prompt_response.prompt import PromptClassificationAnnotation, PromptText from ...annotation_types.classification.classification import ClassificationAnswer, Dropdown, Text, Checklist, Radio from ...annotation_types.types import Cuid from ...annotation_types.data import TextData, VideoData, ImageData @@ -150,6 +151,26 @@ def from_common(cls, radio: Radio, name: str, schema_id=feature_schema_id) +class NDPromptTextSubclass(NDAnswer): + answer: str + + def to_common(self) -> PromptText: + return PromptText(answer=self.answer, + confidence=self.confidence, + custom_metrics=self.custom_metrics) + + @classmethod + def from_common(cls, prompt_text: PromptText, name: str, + feature_schema_id: Cuid) -> "NDPromptTextSubclass": + return cls( + answer=prompt_text.answer, + name=name, + schema_id=feature_schema_id, + confidence=prompt_text.confidence, + custom_metrics=prompt_text.custom_metrics, + ) + + # ====== End of subclasses @@ -242,6 +263,28 @@ def from_common( frames=extra.get('frames'), message_id=message_id, confidence=confidence) + + +class NDPromptText(NDAnnotation, NDPromptTextSubclass): + + @classmethod + def from_common( + cls, + uuid: str, + text: PromptText, + name, + data: Dict, + feature_schema_id: Cuid, + confidence: Optional[float] = None + ) -> "NDPromptText": + return cls( + answer=text.answer, + data_row=DataRow(id=data.uid, global_key=data.global_key), + name=name, + schema_id=feature_schema_id, + uuid=uuid, + confidence=text.confidence, + custom_metrics=text.custom_metrics) class NDSubclassification: @@ -333,6 +376,33 @@ def lookup_classification( Radio: NDRadio }.get(type(annotation.value)) +class NDPromptClassification: + + @staticmethod + def to_common( + annotation: "NDPromptClassificationType" + ) -> Union[PromptClassificationAnnotation]: + common = PromptClassificationAnnotation( + value=annotation, + name=annotation.name, + feature_schema_id=annotation.schema_id, + extra={'uuid': annotation.uuid}, + confidence=annotation.confidence, + ) + + return common + + @classmethod + def from_common( + cls, annotation: Union[PromptClassificationAnnotation], + data: Union[VideoData, TextData, ImageData] + ) -> Union[NDTextSubclass, NDChecklistSubclass, NDRadioSubclass]: + return NDPromptText.from_common(str(annotation._uuid), annotation.value, + annotation.name, + data, + annotation.feature_schema_id, + annotation.confidence) + # Make sure to keep NDChecklistSubclass prior to NDRadioSubclass in the list, # otherwise list of answers gets parsed by NDRadio whereas NDChecklist must be used @@ -345,8 +415,10 @@ def lookup_classification( NDRadioSubclass.update_forward_refs() NDRadio.update_forward_refs() NDText.update_forward_refs() +NDPromptText.update_forward_refs() NDTextSubclass.update_forward_refs() # Make sure to keep NDChecklist prior to NDRadio in the list, # otherwise list of answers gets parsed by NDRadio whereas NDChecklist must be used NDClassificationType = Union[NDChecklist, NDRadio, NDText] +NDPromptClassificationType = Union[NDPromptText] \ No newline at end of file diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py index 1b649a80e..9d34c451b 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py @@ -12,20 +12,23 @@ from ...annotation_types.video import VideoObjectAnnotation, VideoMaskAnnotation from ...annotation_types.collection import LabelCollection, LabelGenerator from ...annotation_types.data import DicomData, ImageData, TextData, VideoData +from ...annotation_types.data.generic_data_row_data import GenericDataRowData from ...annotation_types.label import Label from ...annotation_types.ner import TextEntity, ConversationEntity from ...annotation_types.classification import Dropdown from ...annotation_types.metrics import ScalarMetric, ConfusionMatrixMetric +from ...annotation_types.llm_prompt_response.prompt import PromptClassificationAnnotation from .metric import NDScalarMetric, NDMetricAnnotation, NDConfusionMatrixMetric -from .classification import NDChecklistSubclass, NDClassification, NDClassificationType, NDRadioSubclass +from .classification import NDChecklistSubclass, NDClassification, NDClassificationType, NDRadioSubclass, NDPromptClassification, NDPromptClassificationType, NDPromptText from .objects import NDObject, NDObjectType, NDSegments, NDDicomSegments, NDVideoMasks, NDDicomMasks from .relationship import NDRelationship from .base import DataRow -AnnotationType = Union[NDObjectType, NDClassificationType, +AnnotationType = Union[NDObjectType, NDClassificationType, NDPromptClassificationType, NDConfusionMatrixMetric, NDScalarMetric, NDDicomSegments, - NDSegments, NDDicomMasks, NDVideoMasks, NDRelationship] + NDSegments, NDDicomMasks, NDVideoMasks, NDRelationship, + NDPromptText] class NDLabel(pydantic_compat.BaseModel): @@ -120,6 +123,9 @@ def _generate_annotations( (NDScalarMetric, NDConfusionMatrixMetric)): annotations.append( NDMetricAnnotation.to_common(ndjson_annotation)) + elif isinstance(ndjson_annotation, NDPromptClassificationType): + annotation = NDPromptClassification.to_common(ndjson_annotation) + annotations.append(annotation) else: raise TypeError( f"Unsupported annotation. {type(ndjson_annotation)}") @@ -156,7 +162,7 @@ def _infer_media_type( raise ValueError("Missing annotations while inferring media type") types = {type(annotation) for annotation in annotations} - data = ImageData + data = GenericDataRowData if (TextEntity in types) or (ConversationEntity in types): data = TextData elif VideoClassificationAnnotation in types or VideoObjectAnnotation in types: @@ -269,6 +275,8 @@ def _create_non_video_annotations(cls, label: Label): yield NDMetricAnnotation.from_common(annotation, label.data) elif isinstance(annotation, RelationshipAnnotation): yield NDRelationship.from_common(annotation, label.data) + elif isinstance(annotation, PromptClassificationAnnotation): + yield NDPromptClassification.from_common(annotation, label.data) else: raise TypeError( f"Unable to convert object to MAL format. `{type(getattr(annotation, 'value',annotation))}`" diff --git a/libs/labelbox/tests/data/annotation_import/test_data_types.py b/libs/labelbox/tests/data/annotation_import/test_data_types.py index 880c4896d..d607c4a3c 100644 --- a/libs/labelbox/tests/data/annotation_import/test_data_types.py +++ b/libs/labelbox/tests/data/annotation_import/test_data_types.py @@ -423,3 +423,4 @@ def test_import_mal_annotations_global_key(client, assert import_annotations.errors == [] # MAL Labels cannot be exported and compared to input labels + \ No newline at end of file diff --git a/libs/labelbox/tests/data/annotation_types/test_label.py b/libs/labelbox/tests/data/annotation_types/test_label.py index ee83b1b50..a6947cd4b 100644 --- a/libs/labelbox/tests/data/annotation_types/test_label.py +++ b/libs/labelbox/tests/data/annotation_types/test_label.py @@ -1,11 +1,14 @@ +from labelbox.pydantic_compat import ValidationError import numpy as np import labelbox.types as lb_types from labelbox import OntologyBuilder, Tool, Classification as OClassification, Option from labelbox.data.annotation_types import (ClassificationAnswer, Radio, Text, ClassificationAnnotation, + PromptText, ObjectAnnotation, Point, Line, ImageData, Label) +import pytest def test_schema_assignment_geometry(): @@ -193,3 +196,17 @@ def test_initialize_label_no_coercion(): annotations=[ner_annotation]) assert isinstance(label.data, lb_types.ConversationData) assert label.data.global_key == global_key + +def test_prompt_classification_validation(): + global_key = 'global-key' + prompt_text = lb_types.PromptClassificationAnnotation( + name="prompt text", + value=PromptText(answer="test") + ) + prompt_text_2 = lb_types.PromptClassificationAnnotation( + name="prompt text", + value=PromptText(answer="test") + ) + with pytest.raises(ValidationError) as e_info: + label = Label(data={"global_key": global_key}, + annotations=[prompt_text, prompt_text_2]) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_data_gen.py b/libs/labelbox/tests/data/serialization/ndjson/test_data_gen.py new file mode 100644 index 000000000..7b7b33994 --- /dev/null +++ b/libs/labelbox/tests/data/serialization/ndjson/test_data_gen.py @@ -0,0 +1,57 @@ +from copy import copy +import pytest +import labelbox.types as lb_types +from labelbox.data.serialization import NDJsonConverter +from labelbox.data.serialization.ndjson.objects import NDDicomSegments, NDDicomSegment, NDDicomLine +""" +Data gen prompt test data +""" + +prompt_text_annotation = lb_types.PromptClassificationAnnotation( + feature_schema_id="ckrb1sfkn099c0y910wbo0p1a", + name="test", + value=lb_types.PromptText(answer="the answer to the text questions right here"), + ) + +prompt_text_ndjson = { + "answer": "the answer to the text questions right here", + "name": "test", + "schemaId": "ckrb1sfkn099c0y910wbo0p1a", + "dataRow": { + "id": "ckrb1sf1i1g7i0ybcdc6oc8ct" + }, + } + +data_gen_label = lb_types.Label( + data={"uid": "ckrb1sf1i1g7i0ybcdc6oc8ct"}, + annotations=[prompt_text_annotation] +) + +""" +Prompt annotation test +""" + +def test_serialize_label(): + serialized_label = next(NDJsonConverter().serialize([data_gen_label])) + # Remove uuid field since this is a random value that can not be specified also meant for relationships + del serialized_label["uuid"] + assert serialized_label == prompt_text_ndjson + + +def test_deserialize_label(): + deserialized_label = next(NDJsonConverter().deserialize([prompt_text_ndjson])) + if hasattr(deserialized_label.annotations[0], 'extra'): + # Extra fields are added to deserialized label by default need removed to match + deserialized_label.annotations[0].extra = {} + assert deserialized_label.annotations == data_gen_label.annotations + + +def test_serialize_deserialize_label(): + serialized = list(NDJsonConverter.serialize([data_gen_label])) + deserialized = next(NDJsonConverter.deserialize(serialized)) + if hasattr(deserialized.annotations[0], 'extra'): + # Extra fields are added to deserialized label by default need removed to match + deserialized.annotations[0].extra = {} + print(data_gen_label.annotations) + print(deserialized.annotations) + assert deserialized.annotations == data_gen_label.annotations diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_rectangle.py b/libs/labelbox/tests/data/serialization/ndjson/test_rectangle.py index 0764e2988..73099c12f 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_rectangle.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_rectangle.py @@ -25,7 +25,7 @@ def test_rectangle_inverted_start_end_points(): ), extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"}) - label = lb_types.Label(data=lb_types.ImageData(uid=DATAROW_ID), + label = lb_types.Label(data={"uid":DATAROW_ID}, annotations=[bbox]) res = list(NDJsonConverter.serialize([label])) @@ -43,7 +43,7 @@ def test_rectangle_inverted_start_end_points(): "unit": None }) - label = lb_types.Label(data=lb_types.ImageData(uid=DATAROW_ID), + label = lb_types.Label(data={"uid":DATAROW_ID}, annotations=[expected_bbox]) res = list(NDJsonConverter.deserialize(res)) @@ -62,7 +62,7 @@ def test_rectangle_mixed_start_end_points(): ), extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"}) - label = lb_types.Label(data=lb_types.ImageData(uid=DATAROW_ID), + label = lb_types.Label(data={"uid":DATAROW_ID}, annotations=[bbox]) res = list(NDJsonConverter.serialize([label])) @@ -80,7 +80,7 @@ def test_rectangle_mixed_start_end_points(): "unit": None }) - label = lb_types.Label(data=lb_types.ImageData(uid=DATAROW_ID), + label = lb_types.Label(data={"uid":DATAROW_ID}, annotations=[bbox]) res = list(NDJsonConverter.deserialize(res))