From 3259998cf770a2abfe73e5f62a777cacb71bd32d Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Thu, 27 Jun 2024 14:07:31 -0500 Subject: [PATCH 01/16] added prompt classification for python object support --- .../data/annotation_types/__init__.py | 2 + .../labelbox/data/annotation_types/label.py | 13 ++++- .../llm_prompt_response/__init__.py | 1 + .../llm_prompt_response/prompt.py | 34 ++++++++++++ .../serialization/ndjson/classification.py | 52 +++++++++++++++++++ .../data/serialization/ndjson/label.py | 7 ++- .../labelbox/schema/bulk_import_request.py | 9 ++++ .../tests/data/annotation_import/conftest.py | 6 +++ .../data/annotation_import/test_data_types.py | 2 +- .../test_ndjson_validation.py | 5 +- .../data/annotation_types/test_annotation.py | 9 +++- .../tests/data/annotation_types/test_label.py | 14 +++++ 12 files changed, 147 insertions(+), 7 deletions(-) create mode 100644 libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/__init__.py create mode 100644 libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/prompt.py diff --git a/libs/labelbox/src/labelbox/data/annotation_types/__init__.py b/libs/labelbox/src/labelbox/data/annotation_types/__init__.py index 6eaa205bc..5e77dbfca 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/__init__.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/__init__.py @@ -61,3 +61,5 @@ from .data.tiled_image import TiledBounds from .data.tiled_image import TiledImageData from .data.tiled_image import TileLayer + +from .llm_prompt_response.prompt import PromptText, PromptClassificationAnnotation \ No newline at end of file diff --git a/libs/labelbox/src/labelbox/data/annotation_types/label.py b/libs/labelbox/src/labelbox/data/annotation_types/label.py index f31dbdcda..cd209a493 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/label.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/label.py @@ -10,6 +10,7 @@ from labelbox.schema import ontology from .annotation import ClassificationAnnotation, ObjectAnnotation from .relationship import RelationshipAnnotation +from .llm_prompt_response.prompt import PromptClassificationAnnotation from .classification import ClassificationAnswer from .data import AudioData, ConversationData, DicomData, DocumentData, HTMLData, ImageData, TextData, VideoData, LlmPromptCreationData, LlmPromptResponseCreationData, LlmResponseCreationData from .geometry import Mask @@ -50,7 +51,8 @@ class Label(pydantic_compat.BaseModel): annotations: List[Union[ClassificationAnnotation, ObjectAnnotation, VideoMaskAnnotation, ScalarMetric, ConfusionMatrixMetric, - RelationshipAnnotation]] = [] + RelationshipAnnotation, + PromptClassificationAnnotation]] = [] extra: Dict[str, Any] = {} @pydantic_compat.root_validator(pre=True) @@ -209,10 +211,17 @@ def validate_union(cls, value): ]) if not isinstance(value, list): raise TypeError(f"Annotations must be a list. Found {type(value)}") - + prompt_count = 0 for v in value: if not isinstance(v, supported): raise TypeError( f"Annotations should be a list containing the following classes : {supported}. Found {type(v)}" ) + # Validates only one prompt annotation is included + if isinstance(v, PromptClassificationAnnotation): + prompt_count+=1 + if prompt_count > 1: + raise TypeError( + f"Only one prompt annotation is allowed per label" + ) return value diff --git a/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/__init__.py b/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/__init__.py new file mode 100644 index 000000000..cadf7e2fd --- /dev/null +++ b/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/__init__.py @@ -0,0 +1 @@ +from .prompt import PromptText \ No newline at end of file diff --git a/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/prompt.py b/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/prompt.py new file mode 100644 index 000000000..8eb9963c5 --- /dev/null +++ b/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/prompt.py @@ -0,0 +1,34 @@ +from typing import Union + +from labelbox.data.annotation_types.base_annotation import BaseAnnotation + +from labelbox.data.mixins import ConfidenceMixin, CustomMetricsMixin + +from labelbox import pydantic_compat + + +class PromptText(ConfidenceMixin, CustomMetricsMixin, pydantic_compat.BaseModel): + """ Prompt text for LLM data generation + + >>> PromptText(answer = "some text answer") + + """ + answer: str + + +class PromptClassificationAnnotation(BaseAnnotation, ConfidenceMixin, + CustomMetricsMixin): + """Prompt annotation (non localized) + + >>> PromptAnnotation( + >>> value=PromptText(answer="my caption message"), + >>> feature_schema_id="my-feature-schema-id" + >>> ) + + Args: + name (Optional[str]) + feature_schema_id (Optional[Cuid]) + value (Union[Text]) + """ + + value: Union[PromptText] \ No newline at end of file diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py index 028eeded8..015b281dc 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py @@ -7,6 +7,7 @@ from labelbox.utils import camel_case from ...annotation_types.annotation import ClassificationAnnotation from ...annotation_types.video import VideoClassificationAnnotation +from ...annotation_types.llm_prompt_response.prompt import PromptClassificationAnnotation, PromptText from ...annotation_types.classification.classification import ClassificationAnswer, Dropdown, Text, Checklist, Radio from ...annotation_types.types import Cuid from ...annotation_types.data import TextData, VideoData, ImageData @@ -242,6 +243,28 @@ def from_common( frames=extra.get('frames'), message_id=message_id, confidence=confidence) + + +class NDPromptText(NDAnnotation): + + @classmethod + def from_common( + cls, + uuid: str, + text: PromptText, + name, + data: Union[VideoData, TextData, ImageData], + feature_schema_id: Cuid, + confidence: Optional[float] = None + ) -> "NDPromptText": + return cls( + answer=text.answer, + data_row=DataRow(id=data.uid, global_key=data.global_key), + name=name, + schema_id=feature_schema_id, + uuid=uuid, + confidence=text.confidence, + custom_metrics=text.custom_metrics) class NDSubclassification: @@ -333,6 +356,33 @@ def lookup_classification( Radio: NDRadio }.get(type(annotation.value)) +class NDPromptClassification: + + @staticmethod + def to_common( + annotation: "NDPromptClassificationType" + ) -> Union[PromptClassificationAnnotation]: + common = PromptClassificationAnnotation( + value=annotation.to_common(), + name=annotation.name, + feature_schema_id=annotation.schema_id, + extra={'uuid': annotation.uuid}, + confidence=annotation.confidence, + ) + + return common + + @classmethod + def from_common( + cls, annotation: Union[PromptClassificationAnnotation], + data: Union[VideoData, TextData, ImageData] + ) -> Union[NDTextSubclass, NDChecklistSubclass, NDRadioSubclass]: + return NDPromptText.from_common(str(annotation._uuid), annotation.value, + annotation.name, + data, + annotation.feature_schema_id, + annotation.confidence) + # Make sure to keep NDChecklistSubclass prior to NDRadioSubclass in the list, # otherwise list of answers gets parsed by NDRadio whereas NDChecklist must be used @@ -345,8 +395,10 @@ def lookup_classification( NDRadioSubclass.update_forward_refs() NDRadio.update_forward_refs() NDText.update_forward_refs() +NDPromptText.update_forward_refs() NDTextSubclass.update_forward_refs() # Make sure to keep NDChecklist prior to NDRadio in the list, # otherwise list of answers gets parsed by NDRadio whereas NDChecklist must be used NDClassificationType = Union[NDChecklist, NDRadio, NDText] +NDPromptClassificationType = Union[NDPromptText] \ No newline at end of file diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py index 1b649a80e..221c46aa9 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py @@ -16,14 +16,15 @@ from ...annotation_types.ner import TextEntity, ConversationEntity from ...annotation_types.classification import Dropdown from ...annotation_types.metrics import ScalarMetric, ConfusionMatrixMetric +from ...annotation_types.llm_prompt_response.prompt import PromptClassificationAnnotation from .metric import NDScalarMetric, NDMetricAnnotation, NDConfusionMatrixMetric -from .classification import NDChecklistSubclass, NDClassification, NDClassificationType, NDRadioSubclass +from .classification import NDChecklistSubclass, NDClassification, NDClassificationType, NDRadioSubclass, NDPromptClassification, NDPromptClassificationType from .objects import NDObject, NDObjectType, NDSegments, NDDicomSegments, NDVideoMasks, NDDicomMasks from .relationship import NDRelationship from .base import DataRow -AnnotationType = Union[NDObjectType, NDClassificationType, +AnnotationType = Union[NDObjectType, NDClassificationType, NDPromptClassificationType, NDConfusionMatrixMetric, NDScalarMetric, NDDicomSegments, NDSegments, NDDicomMasks, NDVideoMasks, NDRelationship] @@ -269,6 +270,8 @@ def _create_non_video_annotations(cls, label: Label): yield NDMetricAnnotation.from_common(annotation, label.data) elif isinstance(annotation, RelationshipAnnotation): yield NDRelationship.from_common(annotation, label.data) + elif isinstance(annotation, PromptClassificationAnnotation): + yield NDPromptClassification.from_common(annotation, label.data) else: raise TypeError( f"Unable to convert object to MAL format. `{type(getattr(annotation, 'value',annotation))}`" diff --git a/libs/labelbox/src/labelbox/schema/bulk_import_request.py b/libs/labelbox/src/labelbox/schema/bulk_import_request.py index 65b71a310..367798190 100644 --- a/libs/labelbox/src/labelbox/schema/bulk_import_request.py +++ b/libs/labelbox/src/labelbox/schema/bulk_import_request.py @@ -937,3 +937,12 @@ def schema(cls): data['definitions'].update(schema_.pop('definitions')) data[type_.__name__] = schema_ return data + +###### Prompt Response ###### + +class NDPromptClassification( + SpecialUnion, + Type[Union[ # type: ignore + NDText + ]]): + ... \ No newline at end of file diff --git a/libs/labelbox/tests/data/annotation_import/conftest.py b/libs/labelbox/tests/data/annotation_import/conftest.py index 55453ade3..bf037a5ce 100644 --- a/libs/labelbox/tests/data/annotation_import/conftest.py +++ b/libs/labelbox/tests/data/annotation_import/conftest.py @@ -1083,6 +1083,12 @@ def checklist_inference(prediction_id_mapping): del checklist["tool"] return checklist +#TODO: Once data gen ontologies are able to be created will need to provide that here inside the prediction_id_mapping +@pytest.fixture +def prompt_text_inference(): + prompt_text = {"answer": "free form text..."} + + return prompt_text @pytest.fixture def checklist_inference_index(prediction_id_mapping): diff --git a/libs/labelbox/tests/data/annotation_import/test_data_types.py b/libs/labelbox/tests/data/annotation_import/test_data_types.py index a59149ea8..46869b46c 100644 --- a/libs/labelbox/tests/data/annotation_import/test_data_types.py +++ b/libs/labelbox/tests/data/annotation_import/test_data_types.py @@ -485,4 +485,4 @@ def test_import_mal_annotations_global_key(client, import_annotations.wait_until_done() assert import_annotations.errors == [] - # MAL Labels cannot be exported and compared to input labels + # MAL Labels cannot be exported and compared to input labels \ No newline at end of file diff --git a/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py b/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py index fba161ef3..fccf25356 100644 --- a/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py +++ b/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py @@ -9,6 +9,7 @@ NDMask, NDPolygon, NDPolyline, NDRadio, NDRectangle, NDText, NDTextEntity, NDTool, + NDPromptClassification, _validate_ndjson) from labelbox.schema.labeling_frontend import LabelingFrontend from labelbox.schema.queue_mode import QueueMode @@ -39,7 +40,9 @@ def configured_project_with_ontology(client, ontology, rand_gen): yield project project.delete() - +def test_prompt_classification_construction(prompt_text_inference): + prompt_text = NDPromptClassification.build(prompt_text_inference) + assert isinstance(prompt_text, NDText) def test_classification_construction(checklist_inference, text_inference): checklist = NDClassification.build(checklist_inference) diff --git a/libs/labelbox/tests/data/annotation_types/test_annotation.py b/libs/labelbox/tests/data/annotation_types/test_annotation.py index b6dc00041..451b6a743 100644 --- a/libs/labelbox/tests/data/annotation_types/test_annotation.py +++ b/libs/labelbox/tests/data/annotation_types/test_annotation.py @@ -1,7 +1,8 @@ import pytest -from labelbox.data.annotation_types import (Text, Point, Line, +from labelbox.data.annotation_types import (Text, Point, Line, PromptText, ClassificationAnnotation, + PromptClassificationAnnotation, ObjectAnnotation, TextEntity) from labelbox.data.annotation_types.video import VideoObjectAnnotation from labelbox.data.annotation_types.geometry.rectangle import Rectangle @@ -33,6 +34,12 @@ def test_annotation(): value=classification, name=name, ) + + # Check prompt classification + PromptClassificationAnnotation( + value=PromptText(answer="some text answer"), + name=name + ) # Invalid subclass with pytest.raises(pydantic_compat.ValidationError): diff --git a/libs/labelbox/tests/data/annotation_types/test_label.py b/libs/labelbox/tests/data/annotation_types/test_label.py index ee83b1b50..ce23e3219 100644 --- a/libs/labelbox/tests/data/annotation_types/test_label.py +++ b/libs/labelbox/tests/data/annotation_types/test_label.py @@ -4,8 +4,10 @@ from labelbox import OntologyBuilder, Tool, Classification as OClassification, Option from labelbox.data.annotation_types import (ClassificationAnswer, Radio, Text, ClassificationAnnotation, + PromptText, ObjectAnnotation, Point, Line, ImageData, Label) +import pytest def test_schema_assignment_geometry(): @@ -193,3 +195,15 @@ def test_initialize_label_no_coercion(): annotations=[ner_annotation]) assert isinstance(label.data, lb_types.ConversationData) assert label.data.global_key == global_key + +def test_prompt_classification_validation(): + global_key = 'global-key' + prompt_text = lb_types.PromptClassificationAnnotation( + value=PromptText(answer="test") + ) + prompt_text_2 = lb_types.PromptClassificationAnnotation( + value=PromptText(answer="test") + ) + with pytest.raises(TypeError) as e_info: + label = Label(data={"global_key": global_key}, + annotations=[prompt_text, prompt_text_2]) From fa3a612e779bcd77d08c737e7fd98ec7b9204ac1 Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Thu, 27 Jun 2024 19:47:22 -0500 Subject: [PATCH 02/16] improved test --- libs/labelbox/tests/data/annotation_import/conftest.py | 7 ------- .../tests/data/annotation_import/test_ndjson_validation.py | 7 +++++-- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/libs/labelbox/tests/data/annotation_import/conftest.py b/libs/labelbox/tests/data/annotation_import/conftest.py index bf037a5ce..e2c3c4423 100644 --- a/libs/labelbox/tests/data/annotation_import/conftest.py +++ b/libs/labelbox/tests/data/annotation_import/conftest.py @@ -1083,13 +1083,6 @@ def checklist_inference(prediction_id_mapping): del checklist["tool"] return checklist -#TODO: Once data gen ontologies are able to be created will need to provide that here inside the prediction_id_mapping -@pytest.fixture -def prompt_text_inference(): - prompt_text = {"answer": "free form text..."} - - return prompt_text - @pytest.fixture def checklist_inference_index(prediction_id_mapping): checklist = prediction_id_mapping["checklist_index"].copy() diff --git a/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py b/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py index fccf25356..383ea7d68 100644 --- a/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py +++ b/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py @@ -40,10 +40,13 @@ def configured_project_with_ontology(client, ontology, rand_gen): yield project project.delete() -def test_prompt_classification_construction(prompt_text_inference): - prompt_text = NDPromptClassification.build(prompt_text_inference) + + +def test_prompt_classification_construction(text_inference): + prompt_text = NDPromptClassification.build(text_inference) assert isinstance(prompt_text, NDText) + def test_classification_construction(checklist_inference, text_inference): checklist = NDClassification.build(checklist_inference) assert isinstance(checklist, NDChecklist) From 98e8554196acf0914e6627281606990539c1417b Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Fri, 28 Jun 2024 08:27:28 -0500 Subject: [PATCH 03/16] added prompt text --- .../labelbox/src/labelbox/data/serialization/ndjson/label.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py index 221c46aa9..63eb3cf22 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py @@ -19,14 +19,15 @@ from ...annotation_types.llm_prompt_response.prompt import PromptClassificationAnnotation from .metric import NDScalarMetric, NDMetricAnnotation, NDConfusionMatrixMetric -from .classification import NDChecklistSubclass, NDClassification, NDClassificationType, NDRadioSubclass, NDPromptClassification, NDPromptClassificationType +from .classification import NDChecklistSubclass, NDClassification, NDClassificationType, NDRadioSubclass, NDPromptClassification, NDPromptClassificationType, NDPromptText from .objects import NDObject, NDObjectType, NDSegments, NDDicomSegments, NDVideoMasks, NDDicomMasks from .relationship import NDRelationship from .base import DataRow AnnotationType = Union[NDObjectType, NDClassificationType, NDPromptClassificationType, NDConfusionMatrixMetric, NDScalarMetric, NDDicomSegments, - NDSegments, NDDicomMasks, NDVideoMasks, NDRelationship] + NDSegments, NDDicomMasks, NDVideoMasks, NDRelationship, + NDPromptText] class NDLabel(pydantic_compat.BaseModel): From f0ffc0df857c3830f4cf41d083234c38a2da9cf0 Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Fri, 28 Jun 2024 08:28:08 -0500 Subject: [PATCH 04/16] typo --- .../data/annotation_types/llm_prompt_response/prompt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/prompt.py b/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/prompt.py index 8eb9963c5..b5ea2812e 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/prompt.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/prompt.py @@ -20,7 +20,7 @@ class PromptClassificationAnnotation(BaseAnnotation, ConfidenceMixin, CustomMetricsMixin): """Prompt annotation (non localized) - >>> PromptAnnotation( + >>> PromptClassificationAnnotation( >>> value=PromptText(answer="my caption message"), >>> feature_schema_id="my-feature-schema-id" >>> ) From 6cc96c39383af2c40c734abfc0cec44691544522 Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Fri, 28 Jun 2024 12:23:58 -0500 Subject: [PATCH 05/16] removed bulk import test this is legacy feature and not needed --- libs/labelbox/pyproject.toml | 2 +- .../src/labelbox/schema/bulk_import_request.py | 10 +--------- .../data/annotation_import/test_ndjson_validation.py | 5 ----- 3 files changed, 2 insertions(+), 15 deletions(-) diff --git a/libs/labelbox/pyproject.toml b/libs/labelbox/pyproject.toml index 7e17a83e8..46875076c 100644 --- a/libs/labelbox/pyproject.toml +++ b/libs/labelbox/pyproject.toml @@ -85,7 +85,7 @@ unit = "pytest tests/unit" # SERVICE_API_KEY=${SERVICE_API_KEY} \ # LABELBOX_TEST_BASE_URL="http://host.docker.internal:8080" \ integration = { cmd = "pytest tests/integration" } -data = { cmd = "pytest tests/data" } +data = { cmd = "pytest tests/data/annotation_import/test_ndjson_validation.py::test_prompt_classification_construction" } yapf-lint = "yapf tests src -i --verbose --recursive --parallel --style \"google\"" mypy-lint = "mypy src --pretty --show-error-codes --non-interactive --install-types" lint = { chain = ["yapf-lint", "mypy-lint"] } diff --git a/libs/labelbox/src/labelbox/schema/bulk_import_request.py b/libs/labelbox/src/labelbox/schema/bulk_import_request.py index 367798190..415b6613f 100644 --- a/libs/labelbox/src/labelbox/schema/bulk_import_request.py +++ b/libs/labelbox/src/labelbox/schema/bulk_import_request.py @@ -937,12 +937,4 @@ def schema(cls): data['definitions'].update(schema_.pop('definitions')) data[type_.__name__] = schema_ return data - -###### Prompt Response ###### - -class NDPromptClassification( - SpecialUnion, - Type[Union[ # type: ignore - NDText - ]]): - ... \ No newline at end of file + \ No newline at end of file diff --git a/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py b/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py index 383ea7d68..5aa6bc3b2 100644 --- a/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py +++ b/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py @@ -40,11 +40,6 @@ def configured_project_with_ontology(client, ontology, rand_gen): yield project project.delete() - - -def test_prompt_classification_construction(text_inference): - prompt_text = NDPromptClassification.build(text_inference) - assert isinstance(prompt_text, NDText) def test_classification_construction(checklist_inference, text_inference): From cbc7f3920892e72af0fa878331e3a41123149812 Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Fri, 28 Jun 2024 12:25:30 -0500 Subject: [PATCH 06/16] typo --- .../tests/data/annotation_import/test_ndjson_validation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py b/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py index 5aa6bc3b2..fba161ef3 100644 --- a/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py +++ b/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py @@ -9,7 +9,6 @@ NDMask, NDPolygon, NDPolyline, NDRadio, NDRectangle, NDText, NDTextEntity, NDTool, - NDPromptClassification, _validate_ndjson) from labelbox.schema.labeling_frontend import LabelingFrontend from labelbox.schema.queue_mode import QueueMode From e3e1773106f2ebc5b22e6c54433956885534bfb8 Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Fri, 28 Jun 2024 12:28:09 -0500 Subject: [PATCH 07/16] added newline back --- libs/labelbox/tests/data/annotation_import/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/libs/labelbox/tests/data/annotation_import/conftest.py b/libs/labelbox/tests/data/annotation_import/conftest.py index e2c3c4423..55453ade3 100644 --- a/libs/labelbox/tests/data/annotation_import/conftest.py +++ b/libs/labelbox/tests/data/annotation_import/conftest.py @@ -1083,6 +1083,7 @@ def checklist_inference(prediction_id_mapping): del checklist["tool"] return checklist + @pytest.fixture def checklist_inference_index(prediction_id_mapping): checklist = prediction_id_mapping["checklist_index"].copy() From 9260262ac330aa7fa95507c45a009f79d3f29859 Mon Sep 17 00:00:00 2001 From: Gabe <33893811+Gabefire@users.noreply.github.com> Date: Fri, 28 Jun 2024 12:28:42 -0500 Subject: [PATCH 08/16] Update pyproject.toml --- libs/labelbox/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/labelbox/pyproject.toml b/libs/labelbox/pyproject.toml index 46875076c..7e17a83e8 100644 --- a/libs/labelbox/pyproject.toml +++ b/libs/labelbox/pyproject.toml @@ -85,7 +85,7 @@ unit = "pytest tests/unit" # SERVICE_API_KEY=${SERVICE_API_KEY} \ # LABELBOX_TEST_BASE_URL="http://host.docker.internal:8080" \ integration = { cmd = "pytest tests/integration" } -data = { cmd = "pytest tests/data/annotation_import/test_ndjson_validation.py::test_prompt_classification_construction" } +data = { cmd = "pytest tests/data" } yapf-lint = "yapf tests src -i --verbose --recursive --parallel --style \"google\"" mypy-lint = "mypy src --pretty --show-error-codes --non-interactive --install-types" lint = { chain = ["yapf-lint", "mypy-lint"] } From 17e4be5afcb01b7305fb06f370cf197605dd2d9b Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Sat, 29 Jun 2024 15:57:23 -0500 Subject: [PATCH 09/16] made small fix --- libs/labelbox/tests/data/annotation_types/test_label.py | 2 ++ libs/labelbox/tests/data/serialization/ndjson/test_data_gen.py | 0 2 files changed, 2 insertions(+) create mode 100644 libs/labelbox/tests/data/serialization/ndjson/test_data_gen.py diff --git a/libs/labelbox/tests/data/annotation_types/test_label.py b/libs/labelbox/tests/data/annotation_types/test_label.py index ce23e3219..38a9e467e 100644 --- a/libs/labelbox/tests/data/annotation_types/test_label.py +++ b/libs/labelbox/tests/data/annotation_types/test_label.py @@ -199,9 +199,11 @@ def test_initialize_label_no_coercion(): def test_prompt_classification_validation(): global_key = 'global-key' prompt_text = lb_types.PromptClassificationAnnotation( + name="prompt text", value=PromptText(answer="test") ) prompt_text_2 = lb_types.PromptClassificationAnnotation( + name="prompt text", value=PromptText(answer="test") ) with pytest.raises(TypeError) as e_info: diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_data_gen.py b/libs/labelbox/tests/data/serialization/ndjson/test_data_gen.py new file mode 100644 index 000000000..e69de29bb From 050ab8e26c348eb773ac6d93a2b48067cec2c972 Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Sat, 29 Jun 2024 16:52:37 -0500 Subject: [PATCH 10/16] fix test --- .../serialization/ndjson/classification.py | 26 ++++++++++++++++--- .../data/serialization/ndjson/label.py | 3 +++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py index 015b281dc..4785193ff 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py @@ -151,6 +151,26 @@ def from_common(cls, radio: Radio, name: str, schema_id=feature_schema_id) +class NDPromptTextSubclass(NDAnswer): + answer: str + + def to_common(self) -> Text: + return PromptText(answer=self.answer, + confidence=self.confidence, + custom_metrics=self.custom_metrics) + + @classmethod + def from_common(cls, text: PromptText, name: str, + feature_schema_id: Cuid) -> "NDPromptTextSubclass": + return cls( + answer=text.answer, + name=name, + schema_id=feature_schema_id, + confidence=text.confidence, + custom_metrics=text.custom_metrics, + ) + + # ====== End of subclasses @@ -245,7 +265,7 @@ def from_common( confidence=confidence) -class NDPromptText(NDAnnotation): +class NDPromptText(NDAnnotation, NDPromptTextSubclass): @classmethod def from_common( @@ -253,7 +273,7 @@ def from_common( uuid: str, text: PromptText, name, - data: Union[VideoData, TextData, ImageData], + data: Dict, feature_schema_id: Cuid, confidence: Optional[float] = None ) -> "NDPromptText": @@ -363,7 +383,7 @@ def to_common( annotation: "NDPromptClassificationType" ) -> Union[PromptClassificationAnnotation]: common = PromptClassificationAnnotation( - value=annotation.to_common(), + value=annotation, name=annotation.name, feature_schema_id=annotation.schema_id, extra={'uuid': annotation.uuid}, diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py index 63eb3cf22..61c569a4a 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py @@ -122,6 +122,9 @@ def _generate_annotations( (NDScalarMetric, NDConfusionMatrixMetric)): annotations.append( NDMetricAnnotation.to_common(ndjson_annotation)) + elif isinstance(ndjson_annotation, NDPromptClassificationType): + annotation = NDPromptClassification.to_common(ndjson_annotation) + annotations.append(annotation) else: raise TypeError( f"Unsupported annotation. {type(ndjson_annotation)}") From 10adf7ae0dfe383e0c73ba160103ef2c314d4e68 Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Mon, 1 Jul 2024 08:42:45 -0500 Subject: [PATCH 11/16] added better tests and removed old tests --- libs/labelbox/pyproject.toml | 4 +- .../data/serialization/ndjson/label.py | 3 +- .../data/annotation_import/test_data_types.py | 3 +- .../data/annotation_types/test_annotation.py | 6 -- .../serialization/ndjson/test_data_gen.py | 57 +++++++++++++++++++ 5 files changed, 63 insertions(+), 10 deletions(-) diff --git a/libs/labelbox/pyproject.toml b/libs/labelbox/pyproject.toml index 7e17a83e8..46417ae3e 100644 --- a/libs/labelbox/pyproject.toml +++ b/libs/labelbox/pyproject.toml @@ -73,7 +73,7 @@ dev-dependencies = [ [tool.rye.scripts] unit = "pytest tests/unit" -# https://github.com/Labelbox/labelbox-python/blob/7c84fdffbc14fd1f69d2a6abdcc0087dc557fa4e/Makefile +# https://github.com/Labelbox/labelbox-python/blob/7c84fdffbc14fd1f69d2a6abdadfcc0087dc557fa4e/Makefile # see integration_client.py for full meaning / customization of this command # LABELBOX_TEST_ENVIRON="custom" \ # DA_GCP_LABELBOX_API_KEY=${DA_GCP_LABELBOX_API_KEY} \ @@ -85,7 +85,7 @@ unit = "pytest tests/unit" # SERVICE_API_KEY=${SERVICE_API_KEY} \ # LABELBOX_TEST_BASE_URL="http://host.docker.internal:8080" \ integration = { cmd = "pytest tests/integration" } -data = { cmd = "pytest tests/data" } +data = { cmd = "pytest tests/data/serialization/ndjson/test_data_gen.py" } yapf-lint = "yapf tests src -i --verbose --recursive --parallel --style \"google\"" mypy-lint = "mypy src --pretty --show-error-codes --non-interactive --install-types" lint = { chain = ["yapf-lint", "mypy-lint"] } diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py index 61c569a4a..9d34c451b 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py @@ -12,6 +12,7 @@ from ...annotation_types.video import VideoObjectAnnotation, VideoMaskAnnotation from ...annotation_types.collection import LabelCollection, LabelGenerator from ...annotation_types.data import DicomData, ImageData, TextData, VideoData +from ...annotation_types.data.generic_data_row_data import GenericDataRowData from ...annotation_types.label import Label from ...annotation_types.ner import TextEntity, ConversationEntity from ...annotation_types.classification import Dropdown @@ -161,7 +162,7 @@ def _infer_media_type( raise ValueError("Missing annotations while inferring media type") types = {type(annotation) for annotation in annotations} - data = ImageData + data = GenericDataRowData if (TextEntity in types) or (ConversationEntity in types): data = TextData elif VideoClassificationAnnotation in types or VideoObjectAnnotation in types: diff --git a/libs/labelbox/tests/data/annotation_import/test_data_types.py b/libs/labelbox/tests/data/annotation_import/test_data_types.py index 3d7b100e2..1f17d755e 100644 --- a/libs/labelbox/tests/data/annotation_import/test_data_types.py +++ b/libs/labelbox/tests/data/annotation_import/test_data_types.py @@ -424,4 +424,5 @@ def test_import_mal_annotations_global_key(client, import_annotations.wait_until_done() assert import_annotations.errors == [] - # MAL Labels cannot be exported and compared to input labels \ No newline at end of file + # MAL Labels cannot be exported and compared to input labels + \ No newline at end of file diff --git a/libs/labelbox/tests/data/annotation_types/test_annotation.py b/libs/labelbox/tests/data/annotation_types/test_annotation.py index 451b6a743..e30d1f8de 100644 --- a/libs/labelbox/tests/data/annotation_types/test_annotation.py +++ b/libs/labelbox/tests/data/annotation_types/test_annotation.py @@ -34,12 +34,6 @@ def test_annotation(): value=classification, name=name, ) - - # Check prompt classification - PromptClassificationAnnotation( - value=PromptText(answer="some text answer"), - name=name - ) # Invalid subclass with pytest.raises(pydantic_compat.ValidationError): diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_data_gen.py b/libs/labelbox/tests/data/serialization/ndjson/test_data_gen.py index e69de29bb..7b7b33994 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_data_gen.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_data_gen.py @@ -0,0 +1,57 @@ +from copy import copy +import pytest +import labelbox.types as lb_types +from labelbox.data.serialization import NDJsonConverter +from labelbox.data.serialization.ndjson.objects import NDDicomSegments, NDDicomSegment, NDDicomLine +""" +Data gen prompt test data +""" + +prompt_text_annotation = lb_types.PromptClassificationAnnotation( + feature_schema_id="ckrb1sfkn099c0y910wbo0p1a", + name="test", + value=lb_types.PromptText(answer="the answer to the text questions right here"), + ) + +prompt_text_ndjson = { + "answer": "the answer to the text questions right here", + "name": "test", + "schemaId": "ckrb1sfkn099c0y910wbo0p1a", + "dataRow": { + "id": "ckrb1sf1i1g7i0ybcdc6oc8ct" + }, + } + +data_gen_label = lb_types.Label( + data={"uid": "ckrb1sf1i1g7i0ybcdc6oc8ct"}, + annotations=[prompt_text_annotation] +) + +""" +Prompt annotation test +""" + +def test_serialize_label(): + serialized_label = next(NDJsonConverter().serialize([data_gen_label])) + # Remove uuid field since this is a random value that can not be specified also meant for relationships + del serialized_label["uuid"] + assert serialized_label == prompt_text_ndjson + + +def test_deserialize_label(): + deserialized_label = next(NDJsonConverter().deserialize([prompt_text_ndjson])) + if hasattr(deserialized_label.annotations[0], 'extra'): + # Extra fields are added to deserialized label by default need removed to match + deserialized_label.annotations[0].extra = {} + assert deserialized_label.annotations == data_gen_label.annotations + + +def test_serialize_deserialize_label(): + serialized = list(NDJsonConverter.serialize([data_gen_label])) + deserialized = next(NDJsonConverter.deserialize(serialized)) + if hasattr(deserialized.annotations[0], 'extra'): + # Extra fields are added to deserialized label by default need removed to match + deserialized.annotations[0].extra = {} + print(data_gen_label.annotations) + print(deserialized.annotations) + assert deserialized.annotations == data_gen_label.annotations From e07d19ffe76ca9565391a2e96e4e79c80d85dbf7 Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Mon, 1 Jul 2024 08:43:51 -0500 Subject: [PATCH 12/16] changed pyproject back --- libs/labelbox/pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/labelbox/pyproject.toml b/libs/labelbox/pyproject.toml index 46417ae3e..7e17a83e8 100644 --- a/libs/labelbox/pyproject.toml +++ b/libs/labelbox/pyproject.toml @@ -73,7 +73,7 @@ dev-dependencies = [ [tool.rye.scripts] unit = "pytest tests/unit" -# https://github.com/Labelbox/labelbox-python/blob/7c84fdffbc14fd1f69d2a6abdadfcc0087dc557fa4e/Makefile +# https://github.com/Labelbox/labelbox-python/blob/7c84fdffbc14fd1f69d2a6abdcc0087dc557fa4e/Makefile # see integration_client.py for full meaning / customization of this command # LABELBOX_TEST_ENVIRON="custom" \ # DA_GCP_LABELBOX_API_KEY=${DA_GCP_LABELBOX_API_KEY} \ @@ -85,7 +85,7 @@ unit = "pytest tests/unit" # SERVICE_API_KEY=${SERVICE_API_KEY} \ # LABELBOX_TEST_BASE_URL="http://host.docker.internal:8080" \ integration = { cmd = "pytest tests/integration" } -data = { cmd = "pytest tests/data/serialization/ndjson/test_data_gen.py" } +data = { cmd = "pytest tests/data" } yapf-lint = "yapf tests src -i --verbose --recursive --parallel --style \"google\"" mypy-lint = "mypy src --pretty --show-error-codes --non-interactive --install-types" lint = { chain = ["yapf-lint", "mypy-lint"] } From 70ed7b0506d684168c0cc8efc046d52b7daade68 Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Mon, 1 Jul 2024 09:26:22 -0500 Subject: [PATCH 13/16] fixed/ improved a few tests --- libs/labelbox/pyproject.toml | 2 +- libs/labelbox/tests/data/annotation_types/test_label.py | 3 ++- .../tests/data/serialization/ndjson/test_rectangle.py | 8 ++++---- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/libs/labelbox/pyproject.toml b/libs/labelbox/pyproject.toml index 7e17a83e8..497af9239 100644 --- a/libs/labelbox/pyproject.toml +++ b/libs/labelbox/pyproject.toml @@ -85,7 +85,7 @@ unit = "pytest tests/unit" # SERVICE_API_KEY=${SERVICE_API_KEY} \ # LABELBOX_TEST_BASE_URL="http://host.docker.internal:8080" \ integration = { cmd = "pytest tests/integration" } -data = { cmd = "pytest tests/data" } +data = { cmd = "pytest tests/data/serialization/ndjson/test_rectangle.py" } yapf-lint = "yapf tests src -i --verbose --recursive --parallel --style \"google\"" mypy-lint = "mypy src --pretty --show-error-codes --non-interactive --install-types" lint = { chain = ["yapf-lint", "mypy-lint"] } diff --git a/libs/labelbox/tests/data/annotation_types/test_label.py b/libs/labelbox/tests/data/annotation_types/test_label.py index 38a9e467e..a6947cd4b 100644 --- a/libs/labelbox/tests/data/annotation_types/test_label.py +++ b/libs/labelbox/tests/data/annotation_types/test_label.py @@ -1,3 +1,4 @@ +from labelbox.pydantic_compat import ValidationError import numpy as np import labelbox.types as lb_types @@ -206,6 +207,6 @@ def test_prompt_classification_validation(): name="prompt text", value=PromptText(answer="test") ) - with pytest.raises(TypeError) as e_info: + with pytest.raises(ValidationError) as e_info: label = Label(data={"global_key": global_key}, annotations=[prompt_text, prompt_text_2]) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_rectangle.py b/libs/labelbox/tests/data/serialization/ndjson/test_rectangle.py index 0764e2988..73099c12f 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_rectangle.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_rectangle.py @@ -25,7 +25,7 @@ def test_rectangle_inverted_start_end_points(): ), extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"}) - label = lb_types.Label(data=lb_types.ImageData(uid=DATAROW_ID), + label = lb_types.Label(data={"uid":DATAROW_ID}, annotations=[bbox]) res = list(NDJsonConverter.serialize([label])) @@ -43,7 +43,7 @@ def test_rectangle_inverted_start_end_points(): "unit": None }) - label = lb_types.Label(data=lb_types.ImageData(uid=DATAROW_ID), + label = lb_types.Label(data={"uid":DATAROW_ID}, annotations=[expected_bbox]) res = list(NDJsonConverter.deserialize(res)) @@ -62,7 +62,7 @@ def test_rectangle_mixed_start_end_points(): ), extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"}) - label = lb_types.Label(data=lb_types.ImageData(uid=DATAROW_ID), + label = lb_types.Label(data={"uid":DATAROW_ID}, annotations=[bbox]) res = list(NDJsonConverter.serialize([label])) @@ -80,7 +80,7 @@ def test_rectangle_mixed_start_end_points(): "unit": None }) - label = lb_types.Label(data=lb_types.ImageData(uid=DATAROW_ID), + label = lb_types.Label(data={"uid":DATAROW_ID}, annotations=[bbox]) res = list(NDJsonConverter.deserialize(res)) From 9c8a286e196a91a4ddb1feb2e17ab152c620e2f0 Mon Sep 17 00:00:00 2001 From: Gabe <33893811+Gabefire@users.noreply.github.com> Date: Mon, 1 Jul 2024 09:35:17 -0500 Subject: [PATCH 14/16] Update pyproject.toml --- libs/labelbox/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/labelbox/pyproject.toml b/libs/labelbox/pyproject.toml index 497af9239..7e17a83e8 100644 --- a/libs/labelbox/pyproject.toml +++ b/libs/labelbox/pyproject.toml @@ -85,7 +85,7 @@ unit = "pytest tests/unit" # SERVICE_API_KEY=${SERVICE_API_KEY} \ # LABELBOX_TEST_BASE_URL="http://host.docker.internal:8080" \ integration = { cmd = "pytest tests/integration" } -data = { cmd = "pytest tests/data/serialization/ndjson/test_rectangle.py" } +data = { cmd = "pytest tests/data" } yapf-lint = "yapf tests src -i --verbose --recursive --parallel --style \"google\"" mypy-lint = "mypy src --pretty --show-error-codes --non-interactive --install-types" lint = { chain = ["yapf-lint", "mypy-lint"] } From 17496c81452cf297ec1a551dad9875b42746f500 Mon Sep 17 00:00:00 2001 From: Gabe <33893811+Gabefire@users.noreply.github.com> Date: Mon, 1 Jul 2024 09:36:14 -0500 Subject: [PATCH 15/16] Update bulk_import_request.py --- libs/labelbox/src/labelbox/schema/bulk_import_request.py | 1 - 1 file changed, 1 deletion(-) diff --git a/libs/labelbox/src/labelbox/schema/bulk_import_request.py b/libs/labelbox/src/labelbox/schema/bulk_import_request.py index 415b6613f..65b71a310 100644 --- a/libs/labelbox/src/labelbox/schema/bulk_import_request.py +++ b/libs/labelbox/src/labelbox/schema/bulk_import_request.py @@ -937,4 +937,3 @@ def schema(cls): data['definitions'].update(schema_.pop('definitions')) data[type_.__name__] = schema_ return data - \ No newline at end of file From 619fd3e67b8c53e59fa004bc6c6a5e17d3f1c2a9 Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Wed, 3 Jul 2024 11:30:11 -0500 Subject: [PATCH 16/16] feedback --- .../src/labelbox/data/annotation_types/__init__.py | 3 ++- .../annotation_types/llm_prompt_response/__init__.py | 3 ++- .../annotation_types/llm_prompt_response/prompt.py | 11 ++++++++--- .../data/serialization/ndjson/classification.py | 10 +++++----- .../tests/data/annotation_types/test_annotation.py | 3 +-- 5 files changed, 18 insertions(+), 12 deletions(-) diff --git a/libs/labelbox/src/labelbox/data/annotation_types/__init__.py b/libs/labelbox/src/labelbox/data/annotation_types/__init__.py index 5e77dbfca..85b7ae1af 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/__init__.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/__init__.py @@ -62,4 +62,5 @@ from .data.tiled_image import TiledImageData from .data.tiled_image import TileLayer -from .llm_prompt_response.prompt import PromptText, PromptClassificationAnnotation \ No newline at end of file +from .llm_prompt_response.prompt import PromptText +from .llm_prompt_response.prompt import PromptClassificationAnnotation \ No newline at end of file diff --git a/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/__init__.py b/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/__init__.py index cadf7e2fd..7c0b63abc 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/__init__.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/__init__.py @@ -1 +1,2 @@ -from .prompt import PromptText \ No newline at end of file +from .prompt import PromptText +from .prompt import PromptClassificationAnnotation \ No newline at end of file diff --git a/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/prompt.py b/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/prompt.py index b5ea2812e..c235526b0 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/prompt.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/prompt.py @@ -10,8 +10,13 @@ class PromptText(ConfidenceMixin, CustomMetricsMixin, pydantic_compat.BaseModel): """ Prompt text for LLM data generation - >>> PromptText(answer = "some text answer") - + >>> PromptText(answer = "some text answer", + >>> confidence = 0.5, + >>> custom_metrics = [ + >>> { + >>> "name": "iou", + >>> "value": 0.1 + >>> }]) """ answer: str @@ -31,4 +36,4 @@ class PromptClassificationAnnotation(BaseAnnotation, ConfidenceMixin, value (Union[Text]) """ - value: Union[PromptText] \ No newline at end of file + value: PromptText \ No newline at end of file diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py index 4785193ff..46b8fc91f 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py @@ -154,20 +154,20 @@ def from_common(cls, radio: Radio, name: str, class NDPromptTextSubclass(NDAnswer): answer: str - def to_common(self) -> Text: + def to_common(self) -> PromptText: return PromptText(answer=self.answer, confidence=self.confidence, custom_metrics=self.custom_metrics) @classmethod - def from_common(cls, text: PromptText, name: str, + def from_common(cls, prompt_text: PromptText, name: str, feature_schema_id: Cuid) -> "NDPromptTextSubclass": return cls( - answer=text.answer, + answer=prompt_text.answer, name=name, schema_id=feature_schema_id, - confidence=text.confidence, - custom_metrics=text.custom_metrics, + confidence=prompt_text.confidence, + custom_metrics=prompt_text.custom_metrics, ) diff --git a/libs/labelbox/tests/data/annotation_types/test_annotation.py b/libs/labelbox/tests/data/annotation_types/test_annotation.py index e30d1f8de..b6dc00041 100644 --- a/libs/labelbox/tests/data/annotation_types/test_annotation.py +++ b/libs/labelbox/tests/data/annotation_types/test_annotation.py @@ -1,8 +1,7 @@ import pytest -from labelbox.data.annotation_types import (Text, Point, Line, PromptText, +from labelbox.data.annotation_types import (Text, Point, Line, ClassificationAnnotation, - PromptClassificationAnnotation, ObjectAnnotation, TextEntity) from labelbox.data.annotation_types.video import VideoObjectAnnotation from labelbox.data.annotation_types.geometry.rectangle import Rectangle