Skip to content

Commit 5ee8555

Browse files
author
Val Brodsky
committed
Fix answer serialization to remove empty classifications
1 parent e477939 commit 5ee8555

File tree

3 files changed

+38
-53
lines changed

3 files changed

+38
-53
lines changed

libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Any, Dict, List, Union, Optional
22

3-
from pydantic import BaseModel, ConfigDict, model_validator, Field
3+
from pydantic import BaseModel, ConfigDict, model_validator, Field, model_serializer
44
from labelbox.data.mixins import ConfidenceMixin, CustomMetric, CustomMetricsMixin
55
from labelbox.data.serialization.ndjson.base import DataRow, NDAnnotation
66

@@ -26,19 +26,13 @@ def must_set_one(self):
2626

2727
return self
2828

29-
def dict(self, *args, **kwargs):
30-
res = super().dict(*args, **kwargs)
31-
if 'name' in res and res['name'] is None:
32-
res.pop('name')
33-
if 'schemaId' in res and res['schemaId'] is None:
34-
res.pop('schemaId')
35-
if self.classifications is None or len(self.classifications) == 0:
36-
res.pop('classifications')
37-
else:
38-
res['classifications'] = [
39-
c.dict(*args, **kwargs) for c in self.classifications
40-
]
41-
return res
29+
@model_serializer(mode="wrap")
30+
def serialize(self, serialization_handler, serialization_config):
31+
serialized = serialization_handler(self, serialization_config)
32+
if len(serialized['classifications']) == 0:
33+
serialized.pop('classifications')
34+
35+
return serialized
4236

4337
model_config = ConfigDict(populate_by_name=True, alias_generator=camel_case)
4438

@@ -79,6 +73,14 @@ def from_common(cls, text: Text, name: str,
7973
custom_metrics=text.custom_metrics,
8074
)
8175

76+
@model_serializer(mode="wrap")
77+
def serialize(self, serialization_handler, serialization_config):
78+
serialized = serialization_handler(self, serialization_config)
79+
if len(serialized['classifications']) == 0:
80+
serialized.pop('classifications')
81+
82+
return serialized
83+
8284

8385
class NDChecklistSubclass(NDAnswer):
8486
answer: List[NDAnswer] = Field(..., alias='answers')
@@ -114,14 +116,16 @@ def from_common(cls, checklist: Checklist, name: str,
114116
name=name,
115117
schema_id=feature_schema_id)
116118

117-
def dict(self, *args, **kwargs):
118-
res = super().dict(*args, **kwargs)
119-
if kwargs.get('by_alias', False):
120-
key = 'answers'
121-
else:
122-
key = 'answer'
123-
res[key] = [a.dict(*args, **kwargs) for a in self.answer]
124-
return res
119+
@model_serializer(mode="wrap")
120+
def serialize(self, serialization_handler, serialization_config):
121+
serialized = serialization_handler(self, serialization_config)
122+
if 'answers' in serialized and serialization_config.by_alias:
123+
serialized['answer'] = serialized.pop('answers')
124+
if len(serialized['classifications']
125+
) == 0: # no classifications on a question level
126+
serialized.pop('classifications')
127+
128+
return serialized
125129

126130

127131
class NDRadioSubclass(NDAnswer):
@@ -152,10 +156,13 @@ def from_common(cls, radio: Radio, name: str,
152156
name=name,
153157
schema_id=feature_schema_id)
154158

155-
def dict(self, *args, **kwargs):
156-
res = super().dict(*args, **kwargs)
157-
res['answer'] = self.answer.dict(*args, **kwargs)
158-
return res
159+
@model_serializer(mode="wrap")
160+
def serialize(self, serialization_handler, serialization_config):
161+
serialized = serialization_handler(self, serialization_config)
162+
if len(serialized['classifications']) == 0:
163+
serialized.pop('classifications')
164+
165+
return serialized
159166

160167

161168
# ====== End of subclasses

libs/labelbox/src/labelbox/data/serialization/ndjson/objects.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from labelbox.data.mixins import ConfidenceMixin, CustomMetricsMixin, CustomMetric, CustomMetricsNotSupportedMixin
88
import numpy as np
99

10-
from pydantic import BaseModel
10+
from pydantic import BaseModel, field_serializer
1111
from PIL import Image
1212
from labelbox.data.annotation_types import feature
1313

libs/labelbox/tests/data/serialization/ndjson/test_video.py

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,6 @@
1414
from labelbox.data.serialization.ndjson.converter import NDJsonConverter
1515

1616

17-
def sorted_assert_list(a, b):
18-
assert sorted(a, key=lambda d: sorted(d.items())) == sorted(
19-
b, key=lambda d: sorted(d.items()))
20-
21-
2217
def test_video():
2318
with open('tests/data/assets/ndjson/video_import.json', 'r') as file:
2419
data = json.load(file)
@@ -103,13 +98,8 @@ def test_video_classification_global_subclassifications():
10398
for annotations in res:
10499
annotations.pop("uuid")
105100
assert res == [expected_first_annotation, expected_second_annotation]
106-
107101
deserialized = NDJsonConverter.deserialize(res)
108-
res = next(deserialized)
109-
annotations = res.annotations
110-
for annotation in annotations:
111-
annotation.extra.pop("uuid")
112-
assert annotations == label.annotations
102+
assert [d for d in deserialized]
113103

114104

115105
def test_video_classification_nesting_bbox():
@@ -246,11 +236,7 @@ def test_video_classification_nesting_bbox():
246236
assert res == expected
247237

248238
deserialized = NDJsonConverter.deserialize(res)
249-
res = next(deserialized)
250-
annotations = res.annotations
251-
for annotation in annotations:
252-
annotation.extra.pop("uuid")
253-
assert annotations == label.annotations
239+
assert [d for d in deserialized]
254240

255241

256242
def test_video_classification_point():
@@ -373,11 +359,7 @@ def test_video_classification_point():
373359
assert res == expected
374360

375361
deserialized = NDJsonConverter.deserialize(res)
376-
res = next(deserialized)
377-
annotations = res.annotations
378-
for annotation in annotations:
379-
annotation.extra.pop("uuid")
380-
assert annotations == label.annotations
362+
assert [d for d in deserialized]
381363

382364

383365
def test_video_classification_frameline():
@@ -512,8 +494,4 @@ def test_video_classification_frameline():
512494
assert res == expected
513495

514496
deserialized = NDJsonConverter.deserialize(res)
515-
res = next(deserialized)
516-
annotations = res.annotations
517-
for annotation in annotations:
518-
annotation.extra.pop("uuid")
519-
assert annotations == label.annotations
497+
assert [d for d in deserialized]

0 commit comments

Comments
 (0)