Skip to content

Commit 76dd198

Browse files
author
Val Brodsky
committed
Fix label serialization
1 parent 521f685 commit 76dd198

File tree

6 files changed

+36
-15
lines changed

6 files changed

+36
-15
lines changed

libs/labelbox/src/labelbox/data/annotation_types/feature.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@ class FeatureSchema(BaseModel):
1414
name: Optional[str] = None
1515
feature_schema_id: Optional[Cuid] = None
1616

17-
@model_validator(mode='before')
18-
@classmethod
19-
def must_set_one(cls, values):
20-
if values['feature_schema_id'] is None and values['name'] is None:
17+
@model_validator(mode='after')
18+
def must_set_one(self):
19+
if self.feature_schema_id is None and self.name is None:
2120
raise ValueError(
2221
"Must set either feature_schema_id or name for all feature schemas"
2322
)
24-
return values
23+
24+
return self
2525

2626
def dict(self, *args, **kwargs):
2727
res = super().dict(*args, **kwargs)

libs/labelbox/src/labelbox/data/serialization/ndjson/base.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,14 @@
77

88

99
class DataRow(_CamelCaseMixin):
10-
id: str = None
11-
global_key: str = None
10+
id: Optional[str] = None
11+
global_key: Optional[str] = None
1212

1313
@model_validator(mode='after')
1414
def must_set_one(self):
15-
if not is_exactly_one_set(self.id, self.global_key):
16-
raise ValueError("Must set either id or global_key")
15+
if self.id is None and self.global_key is None:
16+
raise ValueError(
17+
"Must set either id or global_key for all data rows")
1718

1819
return self
1920

libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,11 @@ def from_common(cls, checklist: Checklist, name: str,
116116

117117
def dict(self, *args, **kwargs):
118118
res = super().dict(*args, **kwargs)
119-
if 'answers' in res:
120-
res['answer'] = res.pop('answers')
119+
if kwargs.get('by_alias', False):
120+
key = 'answers'
121+
else:
122+
key = 'answer'
123+
res[key] = [a.dict(*args, **kwargs) for a in self.answer]
121124
return res
122125

123126

@@ -149,6 +152,11 @@ def from_common(cls, radio: Radio, name: str,
149152
name=name,
150153
schema_id=feature_schema_id)
151154

155+
def dict(self, *args, **kwargs):
156+
res = super().dict(*args, **kwargs)
157+
res['answer'] = self.answer.dict(*args, **kwargs)
158+
return res
159+
152160

153161
# ====== End of subclasses
154162

libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,10 @@ def serialize(
108108
label.annotations = uuid_safe_annotations
109109
for example in NDLabel.from_common([label]):
110110
annotation_uuid = getattr(example, "uuid", None)
111-
112-
res = example.dict(
111+
res = example.model_dump(
113112
by_alias=True,
114113
exclude={"uuid"} if annotation_uuid == "None" else None,
114+
exclude_none=True,
115115
)
116116
for k, v in list(res.items()):
117117
if k in IGNORE_IF_NONE and v is None:

libs/labelbox/src/labelbox/data/serialization/ndjson/label.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,6 @@ def _infer_media_type(
163163
data = VideoData
164164
elif DICOMObjectAnnotation in types:
165165
data = DicomData
166-
167166
if data_row.id:
168167
return data(uid=data_row.id)
169168
else:

libs/labelbox/tests/data/serialization/ndjson/test_video.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,26 @@
1414
from labelbox.data.serialization.ndjson.converter import NDJsonConverter
1515

1616

17+
def sorted_assert_list(a, b):
18+
assert sorted(a, key=lambda d: sorted(d.items())) == sorted(
19+
b, key=lambda d: sorted(d.items()))
20+
21+
1722
def test_video():
1823
with open('tests/data/assets/ndjson/video_import.json', 'r') as file:
1924
data = json.load(file)
2025

2126
res = list(NDJsonConverter.deserialize(data))
2227
res = list(NDJsonConverter.serialize(res))
23-
assert res == [data[2], data[0], data[1], data[3], data[4], data[5]]
28+
# assert res == [data[2], data[0], data[1], data[3], data[4], data[5]]
29+
assert (res[0]) == data[2]
30+
assert (res[1]) == data[0]
31+
answers = data[1].pop('answer')
32+
data[1]['answers'] = answers
33+
assert (res[2]) == data[1]
34+
assert (res[3]) == data[3]
35+
assert (res[4]) == data[4]
36+
assert (res[5]) == data[5]
2437

2538

2639
def test_video_name_only():

0 commit comments

Comments
 (0)