Skip to content

[PLT-1207] Added prompt classification for python object support #1700

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,6 @@
from .data.tiled_image import TiledBounds
from .data.tiled_image import TiledImageData
from .data.tiled_image import TileLayer

from .llm_prompt_response.prompt import PromptText
from .llm_prompt_response.prompt import PromptClassificationAnnotation
13 changes: 11 additions & 2 deletions libs/labelbox/src/labelbox/data/annotation_types/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from labelbox.schema import ontology
from .annotation import ClassificationAnnotation, ObjectAnnotation
from .relationship import RelationshipAnnotation
from .llm_prompt_response.prompt import PromptClassificationAnnotation
from .classification import ClassificationAnswer
from .data import AudioData, ConversationData, DicomData, DocumentData, HTMLData, ImageData, TextData, VideoData, LlmPromptCreationData, LlmPromptResponseCreationData, LlmResponseCreationData
from .geometry import Mask
Expand Down Expand Up @@ -50,7 +51,8 @@ class Label(pydantic_compat.BaseModel):
annotations: List[Union[ClassificationAnnotation, ObjectAnnotation,
VideoMaskAnnotation, ScalarMetric,
ConfusionMatrixMetric,
RelationshipAnnotation]] = []
RelationshipAnnotation,
PromptClassificationAnnotation]] = []
extra: Dict[str, Any] = {}

@pydantic_compat.root_validator(pre=True)
Expand Down Expand Up @@ -209,10 +211,17 @@ def validate_union(cls, value):
])
if not isinstance(value, list):
raise TypeError(f"Annotations must be a list. Found {type(value)}")

prompt_count = 0
for v in value:
if not isinstance(v, supported):
raise TypeError(
f"Annotations should be a list containing the following classes : {supported}. Found {type(v)}"
)
# Validates only one prompt annotation is included
if isinstance(v, PromptClassificationAnnotation):
prompt_count+=1
if prompt_count > 1:
raise TypeError(
f"Only one prompt annotation is allowed per label"
)
return value
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .prompt import PromptText
from .prompt import PromptClassificationAnnotation
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import Union

from labelbox.data.annotation_types.base_annotation import BaseAnnotation

from labelbox.data.mixins import ConfidenceMixin, CustomMetricsMixin

from labelbox import pydantic_compat


class PromptText(ConfidenceMixin, CustomMetricsMixin, pydantic_compat.BaseModel):
    """ Prompt text for LLM data generation

    Carries the text of a prompt in the required ``answer`` field. The
    ``confidence`` and ``custom_metrics`` keywords used in the example are
    presumably declared by ``ConfidenceMixin`` and ``CustomMetricsMixin``
    (their definitions are not shown here).

    >>> PromptText(answer = "some text answer",
    >>>     confidence = 0.5,
    >>>     custom_metrics = [
    >>>         {
    >>>             "name": "iou",
    >>>             "value": 0.1
    >>>         }])
    """
    # Text content of the prompt; required, no default.
    answer: str


class PromptClassificationAnnotation(BaseAnnotation, ConfidenceMixin,
                                     CustomMetricsMixin):
    """Prompt annotation (non localized)

    Wraps a single ``PromptText`` value together with the identifying fields
    inherited from ``BaseAnnotation``.

    >>> PromptClassificationAnnotation(
    >>>     value=PromptText(answer="my caption message"),
    >>>     feature_schema_id="my-feature-schema-id"
    >>> )

    Args:
        name (Optional[str])
        feature_schema_id (Optional[Cuid])
        value (PromptText): the prompt text carried by this annotation
    """

    # The prompt payload; only PromptText is accepted as a value type.
    value: PromptText
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from labelbox.utils import camel_case
from ...annotation_types.annotation import ClassificationAnnotation
from ...annotation_types.video import VideoClassificationAnnotation
from ...annotation_types.llm_prompt_response.prompt import PromptClassificationAnnotation, PromptText
from ...annotation_types.classification.classification import ClassificationAnswer, Dropdown, Text, Checklist, Radio
from ...annotation_types.types import Cuid
from ...annotation_types.data import TextData, VideoData, ImageData
Expand Down Expand Up @@ -150,6 +151,26 @@ def from_common(cls, radio: Radio, name: str,
schema_id=feature_schema_id)


class NDPromptTextSubclass(NDAnswer):
    # NDJSON-side counterpart of the common-format ``PromptText``.
    answer: str

    def to_common(self) -> PromptText:
        """Build the common-format ``PromptText`` from this NDJSON answer.

        Copies ``answer``, ``confidence`` and ``custom_metrics`` across
        unchanged.
        """
        common_fields = {
            "answer": self.answer,
            "confidence": self.confidence,
            "custom_metrics": self.custom_metrics,
        }
        return PromptText(**common_fields)

    @classmethod
    def from_common(cls, prompt_text: PromptText, name: str,
                    feature_schema_id: Cuid) -> "NDPromptTextSubclass":
        """Build the NDJSON answer from a common-format ``PromptText``.

        Args:
            prompt_text: source of ``answer``, ``confidence`` and
                ``custom_metrics``.
            name: stored in the ``name`` field.
            feature_schema_id: stored in the ``schema_id`` field.
        """
        nd_fields = {
            "answer": prompt_text.answer,
            "name": name,
            "schema_id": feature_schema_id,
            "confidence": prompt_text.confidence,
            "custom_metrics": prompt_text.custom_metrics,
        }
        return cls(**nd_fields)


# ====== End of subclasses


Expand Down Expand Up @@ -242,6 +263,28 @@ def from_common(
frames=extra.get('frames'),
message_id=message_id,
confidence=confidence)


class NDPromptText(NDAnnotation, NDPromptTextSubclass):
    # Top-level NDJSON prompt text annotation: the prompt answer plus the
    # data-row reference and uuid passed to the constructor below.

    @classmethod
    def from_common(
        cls,
        uuid: str,
        text: PromptText,
        name,
        data: Dict,
        feature_schema_id: Cuid,
        confidence: Optional[float] = None
    ) -> "NDPromptText":
        """Build the NDJSON annotation from a common-format ``PromptText``.

        Args:
            uuid: stored in the ``uuid`` field.
            text: source of ``answer``, ``confidence`` and ``custom_metrics``.
            name: stored in the ``name`` field.
            data: object whose ``uid`` and ``global_key`` attributes identify
                the data row (read by attribute, despite the ``Dict`` hint).
            feature_schema_id: stored in the ``schema_id`` field.
            confidence: accepted but not read; the confidence written out is
                ``text.confidence``.
        """
        # Resolve the data-row reference first so the constructor call below
        # reads as a flat field mapping.
        data_row = DataRow(id=data.uid, global_key=data.global_key)
        return cls(
            uuid=uuid,
            name=name,
            schema_id=feature_schema_id,
            data_row=data_row,
            answer=text.answer,
            confidence=text.confidence,
            custom_metrics=text.custom_metrics,
        )


class NDSubclassification:
Expand Down Expand Up @@ -333,6 +376,33 @@ def lookup_classification(
Radio: NDRadio
}.get(type(annotation.value))

class NDPromptClassification:
    """Converts prompt classification annotations between the NDJSON models
    and the common-format ``PromptClassificationAnnotation``."""

    @staticmethod
    def to_common(
        annotation: "NDPromptClassificationType"
    ) -> PromptClassificationAnnotation:
        """Convert an NDJSON prompt annotation to the common format.

        Args:
            annotation: NDJSON prompt annotation; its ``name``, ``schema_id``,
                ``uuid`` and ``confidence`` are copied onto the result.

        Returns:
            A ``PromptClassificationAnnotation`` whose ``value`` is the
            ``PromptText`` produced by ``annotation.to_common()`` and whose
            ``extra`` holds the original ``uuid``.
        """
        return PromptClassificationAnnotation(
            # Convert explicitly instead of handing the NDJSON model to the
            # PromptText-typed field and relying on implicit model coercion.
            value=annotation.to_common(),
            name=annotation.name,
            feature_schema_id=annotation.schema_id,
            extra={'uuid': annotation.uuid},
            confidence=annotation.confidence,
        )

    @classmethod
    def from_common(
        cls, annotation: PromptClassificationAnnotation,
        data: Union[VideoData, TextData, ImageData]
    ) -> "NDPromptText":
        """Convert a common-format prompt annotation to its NDJSON model.

        Args:
            annotation: the common-format annotation to serialize.
            data: the label's data object, forwarded to
                ``NDPromptText.from_common`` to build the data-row reference.

        Returns:
            The ``NDPromptText`` built by ``NDPromptText.from_common``.
        """
        return NDPromptText.from_common(
            str(annotation._uuid),
            annotation.value,
            annotation.name,
            data,
            annotation.feature_schema_id,
            annotation.confidence,
        )


# Make sure to keep NDChecklistSubclass prior to NDRadioSubclass in the list,
# otherwise list of answers gets parsed by NDRadio whereas NDChecklist must be used
Expand All @@ -345,8 +415,10 @@ def lookup_classification(
NDRadioSubclass.update_forward_refs()
NDRadio.update_forward_refs()
NDText.update_forward_refs()
NDPromptText.update_forward_refs()
NDTextSubclass.update_forward_refs()

# Make sure to keep NDChecklist prior to NDRadio in the list,
# otherwise list of answers gets parsed by NDRadio whereas NDChecklist must be used
NDClassificationType = Union[NDChecklist, NDRadio, NDText]
NDPromptClassificationType = Union[NDPromptText]
16 changes: 12 additions & 4 deletions libs/labelbox/src/labelbox/data/serialization/ndjson/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,23 @@
from ...annotation_types.video import VideoObjectAnnotation, VideoMaskAnnotation
from ...annotation_types.collection import LabelCollection, LabelGenerator
from ...annotation_types.data import DicomData, ImageData, TextData, VideoData
from ...annotation_types.data.generic_data_row_data import GenericDataRowData
from ...annotation_types.label import Label
from ...annotation_types.ner import TextEntity, ConversationEntity
from ...annotation_types.classification import Dropdown
from ...annotation_types.metrics import ScalarMetric, ConfusionMatrixMetric
from ...annotation_types.llm_prompt_response.prompt import PromptClassificationAnnotation

from .metric import NDScalarMetric, NDMetricAnnotation, NDConfusionMatrixMetric
from .classification import NDChecklistSubclass, NDClassification, NDClassificationType, NDRadioSubclass
from .classification import NDChecklistSubclass, NDClassification, NDClassificationType, NDRadioSubclass, NDPromptClassification, NDPromptClassificationType, NDPromptText
from .objects import NDObject, NDObjectType, NDSegments, NDDicomSegments, NDVideoMasks, NDDicomMasks
from .relationship import NDRelationship
from .base import DataRow

AnnotationType = Union[NDObjectType, NDClassificationType,
AnnotationType = Union[NDObjectType, NDClassificationType, NDPromptClassificationType,
NDConfusionMatrixMetric, NDScalarMetric, NDDicomSegments,
NDSegments, NDDicomMasks, NDVideoMasks, NDRelationship]
NDSegments, NDDicomMasks, NDVideoMasks, NDRelationship,
NDPromptText]


class NDLabel(pydantic_compat.BaseModel):
Expand Down Expand Up @@ -120,6 +123,9 @@ def _generate_annotations(
(NDScalarMetric, NDConfusionMatrixMetric)):
annotations.append(
NDMetricAnnotation.to_common(ndjson_annotation))
elif isinstance(ndjson_annotation, NDPromptClassificationType):
annotation = NDPromptClassification.to_common(ndjson_annotation)
annotations.append(annotation)
else:
raise TypeError(
f"Unsupported annotation. {type(ndjson_annotation)}")
Expand Down Expand Up @@ -156,7 +162,7 @@ def _infer_media_type(
raise ValueError("Missing annotations while inferring media type")

types = {type(annotation) for annotation in annotations}
data = ImageData
data = GenericDataRowData
if (TextEntity in types) or (ConversationEntity in types):
data = TextData
elif VideoClassificationAnnotation in types or VideoObjectAnnotation in types:
Expand Down Expand Up @@ -269,6 +275,8 @@ def _create_non_video_annotations(cls, label: Label):
yield NDMetricAnnotation.from_common(annotation, label.data)
elif isinstance(annotation, RelationshipAnnotation):
yield NDRelationship.from_common(annotation, label.data)
elif isinstance(annotation, PromptClassificationAnnotation):
yield NDPromptClassification.from_common(annotation, label.data)
else:
raise TypeError(
f"Unable to convert object to MAL format. `{type(getattr(annotation, 'value',annotation))}`"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -423,3 +423,4 @@ def test_import_mal_annotations_global_key(client,

assert import_annotations.errors == []
# MAL Labels cannot be exported and compared to input labels

17 changes: 17 additions & 0 deletions libs/labelbox/tests/data/annotation_types/test_label.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from labelbox.pydantic_compat import ValidationError
import numpy as np

import labelbox.types as lb_types
from labelbox import OntologyBuilder, Tool, Classification as OClassification, Option
from labelbox.data.annotation_types import (ClassificationAnswer, Radio, Text,
ClassificationAnnotation,
PromptText,
ObjectAnnotation, Point, Line,
ImageData, Label)
import pytest


def test_schema_assignment_geometry():
Expand Down Expand Up @@ -193,3 +196,17 @@ def test_initialize_label_no_coercion():
annotations=[ner_annotation])
assert isinstance(label.data, lb_types.ConversationData)
assert label.data.global_key == global_key

def test_prompt_classification_validation():
    """A label holding two prompt annotations must fail validation."""
    duplicated_prompts = [
        lb_types.PromptClassificationAnnotation(
            name="prompt text",
            value=PromptText(answer="test"),
        ) for _ in range(2)
    ]
    with pytest.raises(ValidationError):
        Label(data={"global_key": 'global-key'},
              annotations=duplicated_prompts)
57 changes: 57 additions & 0 deletions libs/labelbox/tests/data/serialization/ndjson/test_data_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from copy import copy
import pytest
import labelbox.types as lb_types
from labelbox.data.serialization import NDJsonConverter
from labelbox.data.serialization.ndjson.objects import NDDicomSegments, NDDicomSegment, NDDicomLine
"""
Data gen prompt test data
"""

# Common-format prompt annotation shared by the tests in this module.
prompt_text_annotation = lb_types.PromptClassificationAnnotation(
    feature_schema_id="ckrb1sfkn099c0y910wbo0p1a",
    name="test",
    value=lb_types.PromptText(answer="the answer to the text questions right here"),
)

# NDJSON form of the annotation above. It has no "uuid" key: the serialize
# test deletes that key from the converter output before comparing.
prompt_text_ndjson = {
    "answer": "the answer to the text questions right here",
    "name": "test",
    "schemaId": "ckrb1sfkn099c0y910wbo0p1a",
    "dataRow": {
        "id": "ckrb1sf1i1g7i0ybcdc6oc8ct"
    },
}

# Label wrapping the annotation; its uid matches the "dataRow" id above.
data_gen_label = lb_types.Label(
    data={"uid": "ckrb1sf1i1g7i0ybcdc6oc8ct"},
    annotations=[prompt_text_annotation]
)

"""
Prompt annotation test
"""

def test_serialize_label():
    """Serializing the fixture label yields the expected NDJSON payload."""
    converter = NDJsonConverter()
    serialized_label = next(converter.serialize([data_gen_label]))
    # The uuid is randomly generated, so it cannot be part of the expected
    # payload; drop it before comparing.
    serialized_label.pop("uuid")
    assert serialized_label == prompt_text_ndjson


def test_deserialize_label():
    """Deserializing the NDJSON payload reproduces the fixture annotations."""
    converter = NDJsonConverter()
    deserialized_label = next(converter.deserialize([prompt_text_ndjson]))
    first_annotation = deserialized_label.annotations[0]
    if hasattr(first_annotation, 'extra'):
        # Deserialization fills in `extra`; clear it so the annotation can be
        # compared with the hand-built fixture.
        first_annotation.extra = {}
    assert deserialized_label.annotations == data_gen_label.annotations


def test_serialize_deserialize_label():
    """A serialize/deserialize round trip preserves the annotations."""
    ndjson_rows = list(NDJsonConverter.serialize([data_gen_label]))
    round_tripped = next(NDJsonConverter.deserialize(ndjson_rows))
    first_annotation = round_tripped.annotations[0]
    if hasattr(first_annotation, 'extra'):
        # Deserialization fills in `extra`; clear it so the annotation can be
        # compared with the hand-built fixture.
        first_annotation.extra = {}
    print(data_gen_label.annotations)
    print(round_tripped.annotations)
    assert round_tripped.annotations == data_gen_label.annotations
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_rectangle_inverted_start_end_points():
),
extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"})

label = lb_types.Label(data=lb_types.ImageData(uid=DATAROW_ID),
label = lb_types.Label(data={"uid":DATAROW_ID},
annotations=[bbox])

res = list(NDJsonConverter.serialize([label]))
Expand All @@ -43,7 +43,7 @@ def test_rectangle_inverted_start_end_points():
"unit": None
})

label = lb_types.Label(data=lb_types.ImageData(uid=DATAROW_ID),
label = lb_types.Label(data={"uid":DATAROW_ID},
annotations=[expected_bbox])

res = list(NDJsonConverter.deserialize(res))
Expand All @@ -62,7 +62,7 @@ def test_rectangle_mixed_start_end_points():
),
extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"})

label = lb_types.Label(data=lb_types.ImageData(uid=DATAROW_ID),
label = lb_types.Label(data={"uid":DATAROW_ID},
annotations=[bbox])

res = list(NDJsonConverter.serialize([label]))
Expand All @@ -80,7 +80,7 @@ def test_rectangle_mixed_start_end_points():
"unit": None
})

label = lb_types.Label(data=lb_types.ImageData(uid=DATAROW_ID),
label = lb_types.Label(data={"uid":DATAROW_ID},
annotations=[bbox])

res = list(NDJsonConverter.deserialize(res))
Expand Down