
Commit 7746479

[PLT-1207] Added prompt classification for python object support (#1700)
1 parent 259272c commit 7746479

10 files changed, +218 -10 lines changed

libs/labelbox/src/labelbox/data/annotation_types/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -61,3 +61,6 @@
 from .data.tiled_image import TiledBounds
 from .data.tiled_image import TiledImageData
 from .data.tiled_image import TileLayer
+
+from .llm_prompt_response.prompt import PromptText
+from .llm_prompt_response.prompt import PromptClassificationAnnotation

libs/labelbox/src/labelbox/data/annotation_types/label.py

Lines changed: 11 additions & 2 deletions

@@ -10,6 +10,7 @@
 from labelbox.schema import ontology
 from .annotation import ClassificationAnnotation, ObjectAnnotation
 from .relationship import RelationshipAnnotation
+from .llm_prompt_response.prompt import PromptClassificationAnnotation
 from .classification import ClassificationAnswer
 from .data import AudioData, ConversationData, DicomData, DocumentData, HTMLData, ImageData, TextData, VideoData, LlmPromptCreationData, LlmPromptResponseCreationData, LlmResponseCreationData
 from .geometry import Mask
@@ -50,7 +51,8 @@ class Label(pydantic_compat.BaseModel):
     annotations: List[Union[ClassificationAnnotation, ObjectAnnotation,
                             VideoMaskAnnotation, ScalarMetric,
                             ConfusionMatrixMetric,
-                            RelationshipAnnotation]] = []
+                            RelationshipAnnotation,
+                            PromptClassificationAnnotation]] = []
     extra: Dict[str, Any] = {}

     @pydantic_compat.root_validator(pre=True)
@@ -209,10 +211,17 @@ def validate_union(cls, value):
         ])
         if not isinstance(value, list):
             raise TypeError(f"Annotations must be a list. Found {type(value)}")
-
+        prompt_count = 0
         for v in value:
             if not isinstance(v, supported):
                 raise TypeError(
                     f"Annotations should be a list containing the following classes : {supported}. Found {type(v)}"
                 )
+            # Validates only one prompt annotation is included
+            if isinstance(v, PromptClassificationAnnotation):
+                prompt_count += 1
+                if prompt_count > 1:
+                    raise TypeError(
+                        f"Only one prompt annotation is allowed per label"
+                    )
         return value
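
For orientation, here is a minimal sketch of how the new validator behaves, using the labelbox.types aliases this commit exports (PromptClassificationAnnotation, PromptText); the global key and prompt text below are placeholder values. A label accepts at most one prompt annotation, and a second one surfaces as a pydantic validation error when the Label is constructed.

import labelbox.types as lb_types
from labelbox.pydantic_compat import ValidationError

# A single prompt annotation per label passes validation.
prompt = lb_types.PromptClassificationAnnotation(
    name="prompt text",
    value=lb_types.PromptText(answer="Summarize the document."))  # placeholder answer
label = lb_types.Label(data={"global_key": "my-global-key"},  # placeholder global key
                       annotations=[prompt])

# A second prompt annotation trips the prompt_count check in validate_union;
# pydantic reports the underlying TypeError as a ValidationError.
try:
    lb_types.Label(data={"global_key": "my-global-key"},
                   annotations=[prompt, prompt])
except ValidationError as err:
    print(err)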

libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+from .prompt import PromptText
+from .prompt import PromptClassificationAnnotation

libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/prompt.py

Lines changed: 39 additions & 0 deletions

@@ -0,0 +1,39 @@
+from typing import Union
+
+from labelbox.data.annotation_types.base_annotation import BaseAnnotation
+
+from labelbox.data.mixins import ConfidenceMixin, CustomMetricsMixin
+
+from labelbox import pydantic_compat
+
+
+class PromptText(ConfidenceMixin, CustomMetricsMixin, pydantic_compat.BaseModel):
+    """ Prompt text for LLM data generation
+
+    >>> PromptText(answer = "some text answer",
+    >>>     confidence = 0.5,
+    >>>     custom_metrics = [
+    >>>     {
+    >>>         "name": "iou",
+    >>>         "value": 0.1
+    >>>     }])
+    """
+    answer: str
+
+
+class PromptClassificationAnnotation(BaseAnnotation, ConfidenceMixin,
+                                     CustomMetricsMixin):
+    """Prompt annotation (non localized)
+
+    >>> PromptClassificationAnnotation(
+    >>>     value=PromptText(answer="my caption message"),
+    >>>     feature_schema_id="my-feature-schema-id"
+    >>> )
+
+    Args:
+        name (Optional[str])
+        feature_schema_id (Optional[Cuid])
+        value (PromptText)
+    """
+
+    value: PromptText

libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py

Lines changed: 72 additions & 0 deletions

@@ -7,6 +7,7 @@
 from labelbox.utils import camel_case
 from ...annotation_types.annotation import ClassificationAnnotation
 from ...annotation_types.video import VideoClassificationAnnotation
+from ...annotation_types.llm_prompt_response.prompt import PromptClassificationAnnotation, PromptText
 from ...annotation_types.classification.classification import ClassificationAnswer, Dropdown, Text, Checklist, Radio
 from ...annotation_types.types import Cuid
 from ...annotation_types.data import TextData, VideoData, ImageData
@@ -150,6 +151,26 @@ def from_common(cls, radio: Radio, name: str,
                   schema_id=feature_schema_id)


+class NDPromptTextSubclass(NDAnswer):
+    answer: str
+
+    def to_common(self) -> PromptText:
+        return PromptText(answer=self.answer,
+                          confidence=self.confidence,
+                          custom_metrics=self.custom_metrics)
+
+    @classmethod
+    def from_common(cls, prompt_text: PromptText, name: str,
+                    feature_schema_id: Cuid) -> "NDPromptTextSubclass":
+        return cls(
+            answer=prompt_text.answer,
+            name=name,
+            schema_id=feature_schema_id,
+            confidence=prompt_text.confidence,
+            custom_metrics=prompt_text.custom_metrics,
+        )
+
+
 # ====== End of subclasses


@@ -242,6 +263,28 @@ def from_common(
                   frames=extra.get('frames'),
                   message_id=message_id,
                   confidence=confidence)
+
+
+class NDPromptText(NDAnnotation, NDPromptTextSubclass):
+
+    @classmethod
+    def from_common(
+            cls,
+            uuid: str,
+            text: PromptText,
+            name,
+            data: Dict,
+            feature_schema_id: Cuid,
+            confidence: Optional[float] = None
+    ) -> "NDPromptText":
+        return cls(
+            answer=text.answer,
+            data_row=DataRow(id=data.uid, global_key=data.global_key),
+            name=name,
+            schema_id=feature_schema_id,
+            uuid=uuid,
+            confidence=text.confidence,
+            custom_metrics=text.custom_metrics)


 class NDSubclassification:
@@ -333,6 +376,33 @@ def lookup_classification(
         Radio: NDRadio
     }.get(type(annotation.value))

+class NDPromptClassification:
+
+    @staticmethod
+    def to_common(
+            annotation: "NDPromptClassificationType"
+    ) -> Union[PromptClassificationAnnotation]:
+        common = PromptClassificationAnnotation(
+            value=annotation,
+            name=annotation.name,
+            feature_schema_id=annotation.schema_id,
+            extra={'uuid': annotation.uuid},
+            confidence=annotation.confidence,
+        )
+
+        return common
+
+    @classmethod
+    def from_common(
+            cls, annotation: Union[PromptClassificationAnnotation],
+            data: Union[VideoData, TextData, ImageData]
+    ) -> Union[NDTextSubclass, NDChecklistSubclass, NDRadioSubclass]:
+        return NDPromptText.from_common(str(annotation._uuid), annotation.value,
+                                        annotation.name,
+                                        data,
+                                        annotation.feature_schema_id,
+                                        annotation.confidence)
+

 # Make sure to keep NDChecklistSubclass prior to NDRadioSubclass in the list,
 # otherwise list of answers gets parsed by NDRadio whereas NDChecklist must be used
@@ -345,8 +415,10 @@ def lookup_classification(
 NDRadioSubclass.update_forward_refs()
 NDRadio.update_forward_refs()
 NDText.update_forward_refs()
+NDPromptText.update_forward_refs()
 NDTextSubclass.update_forward_refs()

 # Make sure to keep NDChecklist prior to NDRadio in the list,
 # otherwise list of answers gets parsed by NDRadio whereas NDChecklist must be used
 NDClassificationType = Union[NDChecklist, NDRadio, NDText]
+NDPromptClassificationType = Union[NDPromptText]
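
For reference, a rough sketch of the NDJSON row that NDPromptText emits for a prompt annotation. The field names follow the from_common call above and the prompt_text_ndjson fixture in the new test file later in this commit; the ID values below are copied from that fixture and are placeholders, and confidence/custom metrics appear only when set on the annotation.

# Approximate shape of a serialized prompt text annotation (placeholder IDs).
prompt_text_row = {
    "answer": "the answer to the text questions right here",
    "name": "test",
    "schemaId": "ckrb1sfkn099c0y910wbo0p1a",
    "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"},
    "uuid": "<random-per-annotation>",  # generated at serialization time; the test strips it before comparing
}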

libs/labelbox/src/labelbox/data/serialization/ndjson/label.py

Lines changed: 12 additions & 4 deletions

@@ -12,20 +12,23 @@
 from ...annotation_types.video import VideoObjectAnnotation, VideoMaskAnnotation
 from ...annotation_types.collection import LabelCollection, LabelGenerator
 from ...annotation_types.data import DicomData, ImageData, TextData, VideoData
+from ...annotation_types.data.generic_data_row_data import GenericDataRowData
 from ...annotation_types.label import Label
 from ...annotation_types.ner import TextEntity, ConversationEntity
 from ...annotation_types.classification import Dropdown
 from ...annotation_types.metrics import ScalarMetric, ConfusionMatrixMetric
+from ...annotation_types.llm_prompt_response.prompt import PromptClassificationAnnotation

 from .metric import NDScalarMetric, NDMetricAnnotation, NDConfusionMatrixMetric
-from .classification import NDChecklistSubclass, NDClassification, NDClassificationType, NDRadioSubclass
+from .classification import NDChecklistSubclass, NDClassification, NDClassificationType, NDRadioSubclass, NDPromptClassification, NDPromptClassificationType, NDPromptText
 from .objects import NDObject, NDObjectType, NDSegments, NDDicomSegments, NDVideoMasks, NDDicomMasks
 from .relationship import NDRelationship
 from .base import DataRow

-AnnotationType = Union[NDObjectType, NDClassificationType,
+AnnotationType = Union[NDObjectType, NDClassificationType, NDPromptClassificationType,
                        NDConfusionMatrixMetric, NDScalarMetric, NDDicomSegments,
-                       NDSegments, NDDicomMasks, NDVideoMasks, NDRelationship]
+                       NDSegments, NDDicomMasks, NDVideoMasks, NDRelationship,
+                       NDPromptText]


 class NDLabel(pydantic_compat.BaseModel):
@@ -120,6 +123,9 @@ def _generate_annotations(
                           (NDScalarMetric, NDConfusionMatrixMetric)):
                 annotations.append(
                     NDMetricAnnotation.to_common(ndjson_annotation))
+            elif isinstance(ndjson_annotation, NDPromptClassificationType):
+                annotation = NDPromptClassification.to_common(ndjson_annotation)
+                annotations.append(annotation)
             else:
                 raise TypeError(
                     f"Unsupported annotation. {type(ndjson_annotation)}")
@@ -156,7 +162,7 @@ def _infer_media_type(
             raise ValueError("Missing annotations while inferring media type")

         types = {type(annotation) for annotation in annotations}
-        data = ImageData
+        data = GenericDataRowData
         if (TextEntity in types) or (ConversationEntity in types):
             data = TextData
         elif VideoClassificationAnnotation in types or VideoObjectAnnotation in types:
@@ -269,6 +275,8 @@ def _create_non_video_annotations(cls, label: Label):
            yield NDMetricAnnotation.from_common(annotation, label.data)
        elif isinstance(annotation, RelationshipAnnotation):
            yield NDRelationship.from_common(annotation, label.data)
+       elif isinstance(annotation, PromptClassificationAnnotation):
+           yield NDPromptClassification.from_common(annotation, label.data)
        else:
            raise TypeError(
                f"Unable to convert object to MAL format. `{type(getattr(annotation, 'value',annotation))}`"

libs/labelbox/tests/data/annotation_import/test_data_types.py

Lines changed: 1 addition & 0 deletions

@@ -423,3 +423,4 @@ def test_import_mal_annotations_global_key(client,

     assert import_annotations.errors == []
     # MAL Labels cannot be exported and compared to input labels
+

libs/labelbox/tests/data/annotation_types/test_label.py

Lines changed: 17 additions & 0 deletions

@@ -1,11 +1,14 @@
+from labelbox.pydantic_compat import ValidationError
 import numpy as np

 import labelbox.types as lb_types
 from labelbox import OntologyBuilder, Tool, Classification as OClassification, Option
 from labelbox.data.annotation_types import (ClassificationAnswer, Radio, Text,
                                             ClassificationAnnotation,
+                                            PromptText,
                                             ObjectAnnotation, Point, Line,
                                             ImageData, Label)
+import pytest


 def test_schema_assignment_geometry():
@@ -193,3 +196,17 @@ def test_initialize_label_no_coercion():
                               annotations=[ner_annotation])
     assert isinstance(label.data, lb_types.ConversationData)
     assert label.data.global_key == global_key
+
+def test_prompt_classification_validation():
+    global_key = 'global-key'
+    prompt_text = lb_types.PromptClassificationAnnotation(
+        name="prompt text",
+        value=PromptText(answer="test")
+    )
+    prompt_text_2 = lb_types.PromptClassificationAnnotation(
+        name="prompt text",
+        value=PromptText(answer="test")
+    )
+    with pytest.raises(ValidationError) as e_info:
+        label = Label(data={"global_key": global_key},
+                      annotations=[prompt_text, prompt_text_2])

Lines changed: 57 additions & 0 deletions

@@ -0,0 +1,57 @@
+from copy import copy
+import pytest
+import labelbox.types as lb_types
+from labelbox.data.serialization import NDJsonConverter
+from labelbox.data.serialization.ndjson.objects import NDDicomSegments, NDDicomSegment, NDDicomLine
+"""
+Data gen prompt test data
+"""
+
+prompt_text_annotation = lb_types.PromptClassificationAnnotation(
+    feature_schema_id="ckrb1sfkn099c0y910wbo0p1a",
+    name="test",
+    value=lb_types.PromptText(answer="the answer to the text questions right here"),
+)
+
+prompt_text_ndjson = {
+    "answer": "the answer to the text questions right here",
+    "name": "test",
+    "schemaId": "ckrb1sfkn099c0y910wbo0p1a",
+    "dataRow": {
+        "id": "ckrb1sf1i1g7i0ybcdc6oc8ct"
+    },
+}
+
+data_gen_label = lb_types.Label(
+    data={"uid": "ckrb1sf1i1g7i0ybcdc6oc8ct"},
+    annotations=[prompt_text_annotation]
+)
+
+"""
+Prompt annotation test
+"""
+
+def test_serialize_label():
+    serialized_label = next(NDJsonConverter().serialize([data_gen_label]))
+    # Remove the uuid field since it is randomly generated, cannot be specified, and is meant for relationships
+    del serialized_label["uuid"]
+    assert serialized_label == prompt_text_ndjson
+
+
+def test_deserialize_label():
+    deserialized_label = next(NDJsonConverter().deserialize([prompt_text_ndjson]))
+    if hasattr(deserialized_label.annotations[0], 'extra'):
+        # Extra fields are added to the deserialized label by default and must be removed for the comparison to match
+        deserialized_label.annotations[0].extra = {}
+    assert deserialized_label.annotations == data_gen_label.annotations
+
+
+def test_serialize_deserialize_label():
+    serialized = list(NDJsonConverter.serialize([data_gen_label]))
+    deserialized = next(NDJsonConverter.deserialize(serialized))
+    if hasattr(deserialized.annotations[0], 'extra'):
+        # Extra fields are added to the deserialized label by default and must be removed for the comparison to match
+        deserialized.annotations[0].extra = {}
+    print(data_gen_label.annotations)
+    print(deserialized.annotations)
+    assert deserialized.annotations == data_gen_label.annotations

libs/labelbox/tests/data/serialization/ndjson/test_rectangle.py

Lines changed: 4 additions & 4 deletions

@@ -25,7 +25,7 @@ def test_rectangle_inverted_start_end_points():
         ),
         extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"})

-    label = lb_types.Label(data=lb_types.ImageData(uid=DATAROW_ID),
+    label = lb_types.Label(data={"uid": DATAROW_ID},
                            annotations=[bbox])

     res = list(NDJsonConverter.serialize([label]))
@@ -43,7 +43,7 @@ def test_rectangle_inverted_start_end_points():
             "unit": None
         })

-    label = lb_types.Label(data=lb_types.ImageData(uid=DATAROW_ID),
+    label = lb_types.Label(data={"uid": DATAROW_ID},
                            annotations=[expected_bbox])

     res = list(NDJsonConverter.deserialize(res))
@@ -62,7 +62,7 @@ def test_rectangle_mixed_start_end_points():
         ),
         extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"})

-    label = lb_types.Label(data=lb_types.ImageData(uid=DATAROW_ID),
+    label = lb_types.Label(data={"uid": DATAROW_ID},
                            annotations=[bbox])

     res = list(NDJsonConverter.serialize([label]))
@@ -80,7 +80,7 @@ def test_rectangle_mixed_start_end_points():
             "unit": None
         })

-    label = lb_types.Label(data=lb_types.ImageData(uid=DATAROW_ID),
+    label = lb_types.Label(data={"uid": DATAROW_ID},
                            annotations=[bbox])

     res = list(NDJsonConverter.deserialize(res))
