Skip to content

Commit 3daa25b

Browse files
authored
[PTDT-2854] MAL and GT support for pdf relationships (#1932)
2 parents 76c35ba + 8602a34 commit 3daa25b

File tree

3 files changed

+161
-7
lines changed

3 files changed

+161
-7
lines changed

libs/labelbox/src/labelbox/data/annotation_types/relationship.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
from typing import Union
12
from pydantic import BaseModel
23
from enum import Enum
34
from labelbox.data.annotation_types.annotation import (
45
BaseAnnotation,
56
ObjectAnnotation,
7+
ClassificationAnnotation,
68
)
79

810

@@ -11,7 +13,7 @@ class Type(Enum):
1113
UNIDIRECTIONAL = "unidirectional"
1214
BIDIRECTIONAL = "bidirectional"
1315

14-
source: ObjectAnnotation
16+
source: Union[ObjectAnnotation, ClassificationAnnotation]
1517
target: ObjectAnnotation
1618
type: Type = Type.UNIDIRECTIONAL
1719

libs/labelbox/src/labelbox/data/serialization/ndjson/label.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
VideoMaskAnnotation,
2525
VideoObjectAnnotation,
2626
)
27+
from labelbox.types import DocumentRectangle, DocumentEntity
2728
from .classification import (
2829
NDChecklistSubclass,
2930
NDClassification,
@@ -169,6 +170,7 @@ def _create_non_video_annotations(cls, label: Label):
169170
VideoClassificationAnnotation,
170171
VideoObjectAnnotation,
171172
VideoMaskAnnotation,
173+
RelationshipAnnotation,
172174
),
173175
)
174176
]
@@ -179,8 +181,6 @@ def _create_non_video_annotations(cls, label: Label):
179181
yield NDObject.from_common(annotation, label.data)
180182
elif isinstance(annotation, (ScalarMetric, ConfusionMatrixMetric)):
181183
yield NDMetricAnnotation.from_common(annotation, label.data)
182-
elif isinstance(annotation, RelationshipAnnotation):
183-
yield NDRelationship.from_common(annotation, label.data)
184184
elif isinstance(annotation, PromptClassificationAnnotation):
185185
yield NDPromptClassification.from_common(annotation, label.data)
186186
elif isinstance(annotation, MessageEvaluationTaskAnnotation):
@@ -191,19 +191,54 @@ def _create_non_video_annotations(cls, label: Label):
191191
)
192192

193193
@classmethod
194-
def _create_relationship_annotations(cls, label: Label):
194+
def _create_relationship_annotations(
195+
cls, label: Label
196+
) -> Generator[NDRelationship, None, None]:
197+
"""Processes relationship annotations from a label, converting them to NDJSON format.
198+
199+
Args:
200+
label: Label containing relationship annotations to be processed
201+
202+
Yields:
203+
NDRelationship: Validated relationship annotations in NDJSON format
204+
205+
Raises:
206+
TypeError: If source/target types are invalid:
207+
- Source:
208+
- For PDF target annotations (DocumentRectangle, DocumentEntity): source must be ObjectAnnotation or ClassificationAnnotation
209+
- For other target annotations: source must be ObjectAnnotation
210+
- Target:
211+
- Target must always be ObjectAnnotation
212+
"""
195213
for annotation in label.annotations:
196214
if isinstance(annotation, RelationshipAnnotation):
197215
uuid1 = uuid4()
198216
uuid2 = uuid4()
199217
source = copy.copy(annotation.value.source)
200218
target = copy.copy(annotation.value.target)
201-
if not isinstance(source, ObjectAnnotation) or not isinstance(
202-
target, ObjectAnnotation
219+
220+
# Check if source type is valid based on target type
221+
if isinstance(
222+
target.value, (DocumentRectangle, DocumentEntity)
203223
):
224+
if not isinstance(
225+
source, (ObjectAnnotation, ClassificationAnnotation)
226+
):
227+
raise TypeError(
228+
f"Unable to create relationship with invalid source. For PDF targets, "
229+
f"source must be ObjectAnnotation or ClassificationAnnotation. Got: {type(source)}"
230+
)
231+
elif not isinstance(source, ObjectAnnotation):
204232
raise TypeError(
205-
f"Unable to create relationship with non ObjectAnnotations. `Source: {type(source)} Target: {type(target)}`"
233+
f"Unable to create relationship with non ObjectAnnotation source: {type(source)}"
206234
)
235+
236+
# Check if target type is valid
237+
if not isinstance(target, ObjectAnnotation):
238+
raise TypeError(
239+
f"Unable to create relationship with non ObjectAnnotation target: {type(target)}"
240+
)
241+
207242
if not source._uuid:
208243
source._uuid = uuid1
209244
if not target._uuid:

libs/labelbox/tests/data/annotation_import/test_relationships.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,17 @@
1010
RelationshipAnnotation,
1111
Relationship,
1212
TextEntity,
13+
DocumentRectangle,
14+
DocumentEntity,
15+
Point,
16+
Text,
17+
ClassificationAnnotation,
18+
DocumentTextSelection,
19+
Radio,
20+
ClassificationAnswer,
21+
Checklist,
1322
)
23+
from labelbox.data.serialization.ndjson import NDJsonConverter
1424
import pytest
1525

1626

@@ -220,3 +230,110 @@ def test_import_media_types(
220230

221231
assert label_import.state == AnnotationImportState.FINISHED
222232
assert len(label_import.errors) == 0
233+
234+
235+
def test_valid_classification_relationships():
    """Classification sources may be related to PDF object targets.

    Every case wraps a single relationship in a label, serializes that label
    with ``NDJsonConverter`` and asserts that exactly one item is produced.
    """

    def create_pdf_annotation(target_type: str) -> ObjectAnnotation:
        # Build the PDF value first, then wrap it; the annotation name is
        # the same string as the requested target type.
        if target_type == "bbox":
            pdf_value = DocumentRectangle(
                start=Point(x=0, y=0),
                end=Point(x=0.5, y=0.5),
                page=1,
                unit="PERCENT",
            )
        elif target_type == "entity":
            selection = DocumentTextSelection(
                token_ids=[], group_id="", page=1
            )
            pdf_value = DocumentEntity(page=1, textSelections=[selection])
        else:
            raise ValueError(f"Unknown target type: {target_type}")
        return ObjectAnnotation(name=target_type, value=pdf_value)

    def verify_relationship(
        source: ClassificationAnnotation, target: ObjectAnnotation
    ):
        link = Relationship(
            source=source,
            target=target,
            type=Relationship.Type.UNIDIRECTIONAL,
        )
        annotation = RelationshipAnnotation(name="relationship", value=link)
        label = Label(
            data={"global_key": "global_key"}, annotations=[annotation]
        )
        serialized = list(NDJsonConverter.serialize([label]))
        assert len(serialized) == 1

    # Free-text classification; reused as the source of the first two cases.
    text_source = ClassificationAnnotation(
        name="text", value=Text(answer="test")
    )

    # Radio classification carrying one nested radio sub-classification.
    nested_radio = ClassificationAnnotation(
        name="second_sub_radio_question",
        value=Radio(
            answer=ClassificationAnswer(name="second_sub_radio_answer")
        ),
    )
    radio_source = ClassificationAnnotation(
        name="sub_radio_question",
        value=Radio(
            answer=ClassificationAnswer(
                name="first_sub_radio_answer",
                classifications=[nested_radio],
            )
        ),
    )

    # Checklist classification with a single answer.
    checklist_source = ClassificationAnnotation(
        name="sub_checklist_question",
        value=Checklist(
            answer=[ClassificationAnswer(name="first_sub_checklist_answer")]
        ),
    )

    # (source, PDF target kind) pairs, in the order they are verified:
    #   1. Text      -> DocumentRectangle
    #   2. Text      -> DocumentEntity
    #   3. Radio     -> DocumentRectangle
    #   4. Checklist -> DocumentEntity
    cases = (
        (text_source, "bbox"),
        (text_source, "entity"),
        (radio_source, "bbox"),
        (checklist_source, "entity"),
    )
    for case_source, target_kind in cases:
        verify_relationship(case_source, create_pdf_annotation(target_kind))
314+
315+
316+
def test_classification_relationship_restrictions():
    """Check that a classification source is rejected for a non-PDF target.

    Only the Classification -> Point combination is exercised here:
    serializing it is expected to raise the "non ObjectAnnotation source"
    ``TypeError``.
    """
    text = ClassificationAnnotation(name="text", value=Text(answer="test"))
    point = ObjectAnnotation(name="point", value=Point(x=1, y=1))

    # Test case: Classification -> Point (invalid)
    # Should fail because classifications can only connect to PDF targets
    relationship = RelationshipAnnotation(
        name="relationship",
        value=Relationship(
            source=text,
            target=point,
            type=Relationship.Type.UNIDIRECTIONAL,
        ),
    )

    # Build the label outside the raises block so that only serialization
    # (not test setup) can satisfy the expected TypeError.
    label = Label(
        data={"global_key": "test_key"}, annotations=[relationship]
    )

    with pytest.raises(
        TypeError,
        match="Unable to create relationship with non ObjectAnnotation source: .*",
    ):
        list(NDJsonConverter.serialize([label]))

0 commit comments

Comments
 (0)