From e4a96289999c56fdb2cffb029da8f42e3b7039a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20G=C5=82uszek?= Date: Wed, 18 Dec 2024 14:30:27 +0100 Subject: [PATCH 1/8] MAL and GT support for pdf relationships --- .../data/annotation_types/relationship.py | 4 +- .../data/serialization/ndjson/label.py | 44 ++++--- .../annotation_import/test_relationships.py | 111 +++++++++++++++++- 3 files changed, 138 insertions(+), 21 deletions(-) diff --git a/libs/labelbox/src/labelbox/data/annotation_types/relationship.py b/libs/labelbox/src/labelbox/data/annotation_types/relationship.py index b65f21d16..0e9c4e934 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/relationship.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/relationship.py @@ -1,8 +1,10 @@ +from typing import Union from pydantic import BaseModel from enum import Enum from labelbox.data.annotation_types.annotation import ( BaseAnnotation, ObjectAnnotation, + ClassificationAnnotation, ) @@ -11,7 +13,7 @@ class Type(Enum): UNIDIRECTIONAL = "unidirectional" BIDIRECTIONAL = "bidirectional" - source: ObjectAnnotation + source: Union[ObjectAnnotation, ClassificationAnnotation] target: ObjectAnnotation type: Type = Type.UNIDIRECTIONAL diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py index e822f3c42..52f2c353d 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py @@ -2,7 +2,7 @@ import copy from itertools import groupby from operator import itemgetter -from typing import Generator, List, Tuple, Union +from typing import Generator, List, Tuple, Union, Iterator, Dict from uuid import uuid4 from pydantic import BaseModel @@ -24,6 +24,7 @@ VideoMaskAnnotation, VideoObjectAnnotation, ) +from labelbox.types import DocumentRectangle, DocumentEntity from .classification import ( NDChecklistSubclass, NDClassification, @@ -61,9 +62,7 @@ class NDLabel(BaseModel): annotations: AnnotationType @classmethod - def from_common( - cls, data: LabelCollection - ) -> Generator["NDLabel", None, None]: + def from_common(cls, data: LabelCollection) -> Generator["NDLabel", None, None]: for label in data: yield from cls._create_relationship_annotations(label) yield from cls._create_non_video_annotations(label) @@ -127,16 +126,12 @@ def _create_video_annotations( if isinstance( annot, (VideoClassificationAnnotation, VideoObjectAnnotation) ): - video_annotations[annot.feature_schema_id or annot.name].append( - annot - ) + video_annotations[annot.feature_schema_id or annot.name].append(annot) elif isinstance(annot, VideoMaskAnnotation): yield NDObject.from_common(annotation=annot, data=label.data) for annotation_group in video_annotations.values(): - segment_frame_ranges = cls._get_segment_frame_ranges( - annotation_group - ) + segment_frame_ranges = cls._get_segment_frame_ranges(annotation_group) if isinstance(annotation_group[0], VideoClassificationAnnotation): annotation = annotation_group[0] frames_data = [] @@ -169,6 +164,7 @@ def _create_non_video_annotations(cls, label: Label): VideoClassificationAnnotation, VideoObjectAnnotation, VideoMaskAnnotation, + RelationshipAnnotation, ), ) ] @@ -179,8 +175,6 @@ def _create_non_video_annotations(cls, label: Label): yield NDObject.from_common(annotation, label.data) elif isinstance(annotation, (ScalarMetric, ConfusionMatrixMetric)): yield NDMetricAnnotation.from_common(annotation, label.data) - elif isinstance(annotation, RelationshipAnnotation): - yield NDRelationship.from_common(annotation, label.data) elif isinstance(annotation, PromptClassificationAnnotation): yield NDPromptClassification.from_common(annotation, label.data) elif isinstance(annotation, MessageEvaluationTaskAnnotation): @@ -191,19 +185,35 @@ def _create_non_video_annotations(cls, label: Label): ) @classmethod - def _create_relationship_annotations(cls, label: Label): + def _create_relationship_annotations( + cls, label: Label + ) -> Generator[NDRelationship, None, None]: for annotation in label.annotations: if isinstance(annotation, RelationshipAnnotation): uuid1 = uuid4() uuid2 = uuid4() source = copy.copy(annotation.value.source) target = copy.copy(annotation.value.target) - if not isinstance(source, ObjectAnnotation) or not isinstance( - target, ObjectAnnotation - ): + + # Check if source type is valid based on target type + if isinstance(target.value, (DocumentRectangle, DocumentEntity)): + if not isinstance( + source, (ObjectAnnotation, ClassificationAnnotation) + ): + raise TypeError( + f"Unable to create relationship with invalid source. For PDF targets, " + f"source must be ObjectAnnotation or ClassificationAnnotation. Got: {type(source)}" + ) + elif not isinstance(source, ObjectAnnotation): raise TypeError( - f"Unable to create relationship with non ObjectAnnotations. `Source: {type(source)} Target: {type(target)}`" + f"Unable to create relationship with non ObjectAnnotation source: {type(source)}" ) + + if not isinstance(target, ObjectAnnotation): + raise TypeError( + f"Unable to create relationship with non ObjectAnnotation target: {type(target)}" + ) + if not source._uuid: source._uuid = uuid1 if not target._uuid: diff --git a/libs/labelbox/tests/data/annotation_import/test_relationships.py b/libs/labelbox/tests/data/annotation_import/test_relationships.py index 1335261e5..61e11640b 100644 --- a/libs/labelbox/tests/data/annotation_import/test_relationships.py +++ b/libs/labelbox/tests/data/annotation_import/test_relationships.py @@ -10,7 +10,17 @@ RelationshipAnnotation, Relationship, TextEntity, + DocumentRectangle, + DocumentEntity, + Point, + Text, + ClassificationAnnotation, + DocumentTextSelection, + Radio, + ClassificationAnswer, + Checklist, ) +from labelbox.data.serialization.ndjson import NDJsonConverter import pytest @@ -169,9 +179,7 @@ def configured_project( data_row_data = [] for _ in range(3): - data_row_data.append( - data_row_json_by_media_type[media_type](rand_gen(str)) - ) + data_row_data.append(data_row_json_by_media_type[media_type](rand_gen(str))) task = dataset.create_data_rows(data_row_data) task.wait_till_done() @@ -220,3 +228,100 @@ def test_import_media_types( assert label_import.state == AnnotationImportState.FINISHED assert len(label_import.errors) == 0 + + +def test_valid_classification_relationships(): + def create_pdf_annotation(target_type: str) -> ObjectAnnotation: + if target_type == "bbox": + return ObjectAnnotation( + name="bbox", + value=DocumentRectangle( + start=Point(x=0, y=0), + end=Point(x=0.5, y=0.5), + page=1, + unit="PERCENT", + ), + ) + elif target_type == "entity": + return ObjectAnnotation( + name="entity", + value=DocumentEntity( + page=1, + textSelections=[ + DocumentTextSelection(token_ids=[], group_id="", page=1) + ], + ), + ) + raise ValueError(f"Unknown target type: {target_type}") + + def verify_relationship(source: ClassificationAnnotation, target: ObjectAnnotation): + relationship = RelationshipAnnotation( + name="relationship", + value=Relationship( + source=source, + target=target, + type=Relationship.Type.UNIDIRECTIONAL, + ), + ) + label = Label(data={"global_key": "global_key"}, annotations=[relationship]) + result = list(NDJsonConverter.serialize([label])) + assert len(result) == 1 + + # Test case 1: Text Classification -> DocumentRectangle + text_source = ClassificationAnnotation(name="text", value=Text(answer="test")) + verify_relationship(text_source, create_pdf_annotation("bbox")) + + # Test case 2: Text Classification -> DocumentEntity + verify_relationship(text_source, create_pdf_annotation("entity")) + + # Test case 3: Radio Classification -> DocumentRectangle + radio_source = ClassificationAnnotation( + name="sub_radio_question", + value=Radio( + answer=ClassificationAnswer( + name="first_sub_radio_answer", + classifications=[ + ClassificationAnnotation( + name="second_sub_radio_question", + value=Radio( + answer=ClassificationAnswer(name="second_sub_radio_answer") + ), + ) + ], + ) + ), + ) + verify_relationship(radio_source, create_pdf_annotation("bbox")) + + # Test case 4: Checklist Classification -> DocumentEntity + checklist_source = ClassificationAnnotation( + name="sub_checklist_question", + value=Checklist( + answer=[ClassificationAnswer(name="first_sub_checklist_answer")] + ), + ) + verify_relationship(checklist_source, create_pdf_annotation("entity")) + + +def test_classification_relationship_restrictions(): + """Test all relationship validation error messages.""" + text = ClassificationAnnotation(name="text", value=Text(answer="test")) + point = ObjectAnnotation(name="point", value=Point(x=1, y=1)) + + # Test case: Classification -> Point (invalid) + # Should fail because classifications can only connect to PDF targets + relationship = RelationshipAnnotation( + name="relationship", + value=Relationship( + source=text, + target=point, + type=Relationship.Type.UNIDIRECTIONAL, + ), + ) + + with pytest.raises( + TypeError, + match="Unable to create relationship with non ObjectAnnotation source: .*", + ): + label = Label(data={"global_key": "test_key"}, annotations=[relationship]) + list(NDJsonConverter.serialize([label])) From ac91670402ec98ff27c9ef666d1c11281ea0f832 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20G=C5=82uszek?= Date: Wed, 18 Dec 2024 14:52:22 +0100 Subject: [PATCH 2/8] Add a comment to create_relationship_annotations function --- .../labelbox/data/serialization/ndjson/label.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py index 52f2c353d..b4796b05d 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py @@ -188,6 +188,20 @@ def _create_non_video_annotations(cls, label: Label): def _create_relationship_annotations( cls, label: Label ) -> Generator[NDRelationship, None, None]: + """Creates relationship annotations following validation rules for source and target types. + + Args: + label: Label containing relationship annotations to be processed + + Yields: + NDRelationship: Validated relationship annotations in NDJSON format + + Raises: + TypeError: If source/target types violate the validation rules: + - Invalid source type for PDF target + - Non-ObjectAnnotation source for non-PDF target + - Non-ObjectAnnotation target + """ for annotation in label.annotations: if isinstance(annotation, RelationshipAnnotation): uuid1 = uuid4() From e879b7923060958d67a692246fde297dab9c21f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20G=C5=82uszek?= Date: Wed, 18 Dec 2024 15:31:09 +0100 Subject: [PATCH 3/8] Update comment --- .../data/serialization/ndjson/label.py | 26 +++++++++++++------ .../annotation_import/test_relationships.py | 24 ++++++++++++----- 2 files changed, 36 insertions(+), 14 deletions(-) diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py index b4796b05d..7c6a993e6 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py @@ -62,7 +62,9 @@ class NDLabel(BaseModel): annotations: AnnotationType @classmethod - def from_common(cls, data: LabelCollection) -> Generator["NDLabel", None, None]: + def from_common( + cls, data: LabelCollection + ) -> Generator["NDLabel", None, None]: for label in data: yield from cls._create_relationship_annotations(label) yield from cls._create_non_video_annotations(label) @@ -126,12 +128,16 @@ def _create_video_annotations( if isinstance( annot, (VideoClassificationAnnotation, VideoObjectAnnotation) ): - video_annotations[annot.feature_schema_id or annot.name].append(annot) + video_annotations[annot.feature_schema_id or annot.name].append( + annot + ) elif isinstance(annot, VideoMaskAnnotation): yield NDObject.from_common(annotation=annot, data=label.data) for annotation_group in video_annotations.values(): - segment_frame_ranges = cls._get_segment_frame_ranges(annotation_group) + segment_frame_ranges = cls._get_segment_frame_ranges( + annotation_group + ) if isinstance(annotation_group[0], VideoClassificationAnnotation): annotation = annotation_group[0] frames_data = [] @@ -197,10 +203,12 @@ def _create_relationship_annotations( NDRelationship: Validated relationship annotations in NDJSON format Raises: - TypeError: If source/target types violate the validation rules: - - Invalid source type for PDF target - - Non-ObjectAnnotation source for non-PDF target - - Non-ObjectAnnotation target + TypeError: If source/target types are invalid: + - Source: + - For PDF target annotations (DocumentRectangle, DocumentEntity): source must be ObjectAnnotation or ClassificationAnnotation + - For other target annotations: source must be ObjectAnnotation + - Target: + - Target must always be ObjectAnnotation """ for annotation in label.annotations: if isinstance(annotation, RelationshipAnnotation): @@ -210,7 +218,9 @@ def _create_relationship_annotations( target = copy.copy(annotation.value.target) # Check if source type is valid based on target type - if isinstance(target.value, (DocumentRectangle, DocumentEntity)): + if isinstance( + target.value, (DocumentRectangle, DocumentEntity) + ): if not isinstance( source, (ObjectAnnotation, ClassificationAnnotation) ): diff --git a/libs/labelbox/tests/data/annotation_import/test_relationships.py b/libs/labelbox/tests/data/annotation_import/test_relationships.py index 61e11640b..f4a80dab9 100644 --- a/libs/labelbox/tests/data/annotation_import/test_relationships.py +++ b/libs/labelbox/tests/data/annotation_import/test_relationships.py @@ -179,7 +179,9 @@ def configured_project( data_row_data = [] for _ in range(3): - data_row_data.append(data_row_json_by_media_type[media_type](rand_gen(str))) + data_row_data.append( + data_row_json_by_media_type[media_type](rand_gen(str)) + ) task = dataset.create_data_rows(data_row_data) task.wait_till_done() @@ -254,7 +256,9 @@ def create_pdf_annotation(target_type: str) -> ObjectAnnotation: ) raise ValueError(f"Unknown target type: {target_type}") - def verify_relationship(source: ClassificationAnnotation, target: ObjectAnnotation): + def verify_relationship( + source: ClassificationAnnotation, target: ObjectAnnotation + ): relationship = RelationshipAnnotation( name="relationship", value=Relationship( @@ -263,12 +267,16 @@ def verify_relationship(source: ClassificationAnnotation, target: ObjectAnnotati type=Relationship.Type.UNIDIRECTIONAL, ), ) - label = Label(data={"global_key": "global_key"}, annotations=[relationship]) + label = Label( + data={"global_key": "global_key"}, annotations=[relationship] + ) result = list(NDJsonConverter.serialize([label])) assert len(result) == 1 # Test case 1: Text Classification -> DocumentRectangle - text_source = ClassificationAnnotation(name="text", value=Text(answer="test")) + text_source = ClassificationAnnotation( + name="text", value=Text(answer="test") + ) verify_relationship(text_source, create_pdf_annotation("bbox")) # Test case 2: Text Classification -> DocumentEntity @@ -284,7 +292,9 @@ def verify_relationship(source: ClassificationAnnotation, target: ObjectAnnotati ClassificationAnnotation( name="second_sub_radio_question", value=Radio( - answer=ClassificationAnswer(name="second_sub_radio_answer") + answer=ClassificationAnswer( + name="second_sub_radio_answer" + ) ), ) ], @@ -323,5 +333,7 @@ def test_classification_relationship_restrictions(): TypeError, match="Unable to create relationship with non ObjectAnnotation source: .*", ): - label = Label(data={"global_key": "test_key"}, annotations=[relationship]) + label = Label( + data={"global_key": "test_key"}, annotations=[relationship] + ) list(NDJsonConverter.serialize([label])) From dde802e609bdc7750cae4a77432c4f21a90d62ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20G=C5=82uszek?= Date: Wed, 18 Dec 2024 15:47:05 +0100 Subject: [PATCH 4/8] Fix formatting --- .../tests/data/annotation_import/conftest.py | 48 +++++++++---------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/libs/labelbox/tests/data/annotation_import/conftest.py b/libs/labelbox/tests/data/annotation_import/conftest.py index 3e851752e..2f773e9cb 100644 --- a/libs/labelbox/tests/data/annotation_import/conftest.py +++ b/libs/labelbox/tests/data/annotation_import/conftest.py @@ -2210,21 +2210,21 @@ def expected_export_v2_video(): "classifications": [ { "name": "checklist_index", - "value": "checklist_index", + "value": "checklist_index", "checklist_answers": [ { "name": "first_checklist_answer", "value": "first_checklist_answer", - "classifications": [] + "classifications": [], }, { - "name": "second_checklist_answer", + "name": "second_checklist_answer", "value": "second_checklist_answer", - "classifications": [] - } - ] + "classifications": [], + }, + ], } - ] + ], }, "13": { "objects": {}, @@ -2235,17 +2235,17 @@ def expected_export_v2_video(): "checklist_answers": [ { "name": "first_checklist_answer", - "value": "first_checklist_answer", - "classifications": [] + "value": "first_checklist_answer", + "classifications": [], }, { "name": "second_checklist_answer", "value": "second_checklist_answer", - "classifications": [] - } - ] + "classifications": [], + }, + ], } - ] + ], }, "18": { "objects": {}, @@ -2257,16 +2257,16 @@ def expected_export_v2_video(): { "name": "first_checklist_answer", "value": "first_checklist_answer", - "classifications": [] + "classifications": [], }, { "name": "second_checklist_answer", "value": "second_checklist_answer", - "classifications": [] - } - ] + "classifications": [], + }, + ], } - ] + ], }, "19": { "objects": {}, @@ -2278,17 +2278,17 @@ def expected_export_v2_video(): { "name": "first_checklist_answer", "value": "first_checklist_answer", - "classifications": [] + "classifications": [], }, { "name": "second_checklist_answer", "value": "second_checklist_answer", - "classifications": [] - } - ] + "classifications": [], + }, + ], } - ] - } + ], + }, }, "segments": { "": [[7, 13], [18, 19]], From b48940fed97c28b13fdd1bdeb29f3edea3a927d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20G=C5=82uszek?= Date: Wed, 18 Dec 2024 17:26:42 +0100 Subject: [PATCH 5/8] Add comment --- libs/labelbox/src/labelbox/data/serialization/ndjson/label.py | 1 + 1 file changed, 1 insertion(+) diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py index 7c6a993e6..ebf5e698a 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py @@ -233,6 +233,7 @@ def _create_relationship_annotations( f"Unable to create relationship with non ObjectAnnotation source: {type(source)}" ) + # Check if target type is valid if not isinstance(target, ObjectAnnotation): raise TypeError( f"Unable to create relationship with non ObjectAnnotation target: {type(target)}" From 07025827a9c8de42aa31e3ac8afaece6512fc77b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20G=C5=82uszek?= Date: Wed, 18 Dec 2024 17:30:22 +0100 Subject: [PATCH 6/8] Update comment --- libs/labelbox/src/labelbox/data/serialization/ndjson/label.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py index ebf5e698a..4c8057657 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py @@ -194,7 +194,7 @@ def _create_non_video_annotations(cls, label: Label): def _create_relationship_annotations( cls, label: Label ) -> Generator[NDRelationship, None, None]: - """Creates relationship annotations following validation rules for source and target types. + """Creates relationship annotations. Args: label: Label containing relationship annotations to be processed From d095d44e06017efa1c43139c3722b58f8d6267b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20G=C5=82uszek?= Date: Thu, 19 Dec 2024 09:45:03 +0100 Subject: [PATCH 7/8] Remove unused imports --- libs/labelbox/src/labelbox/data/serialization/ndjson/label.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py index 4c8057657..bae988586 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py @@ -2,7 +2,7 @@ import copy from itertools import groupby from operator import itemgetter -from typing import Generator, List, Tuple, Union, Iterator, Dict +from typing import Generator, List, Tuple, Union from uuid import uuid4 from pydantic import BaseModel From 7279476709af2ea295d65d898a2a88c212c87691 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20G=C5=82uszek?= Date: Thu, 19 Dec 2024 13:54:02 +0100 Subject: [PATCH 8/8] Update docstring for _create_relationship_annotations method to clarify its purpose in processing relationship annotations into NDJSON format. --- libs/labelbox/src/labelbox/data/serialization/ndjson/label.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py index bae988586..5b146b660 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py @@ -194,7 +194,7 @@ def _create_non_video_annotations(cls, label: Label): def _create_relationship_annotations( cls, label: Label ) -> Generator[NDRelationship, None, None]: - """Creates relationship annotations. + """Processes relationship annotations from a label, converting them to NDJSON format. Args: label: Label containing relationship annotations to be processed