Skip to content

Commit ee83127

Browse files
author
Matt Sokoloff
committed
bug fix
1 parent 1254354 commit ee83127

File tree

2 files changed

+43
-10
lines changed

2 files changed

+43
-10
lines changed

labelbox/data/serialization/labelbox_v1/classification.py

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import List, Union
22

33
from pydantic.main import BaseModel
4+
from pydantic.schema import schema
45

56
from ...annotation_types.annotation import ClassificationAnnotation
67
from ...annotation_types.classification import Checklist, ClassificationAnswer, Radio, Text, Dropdown
@@ -15,10 +16,8 @@ class LBV1ClassificationAnswer(LBV1Feature):
1516
class LBV1Radio(LBV1Feature):
1617
answer: LBV1ClassificationAnswer
1718

18-
def to_common(self):
19-
return Radio(answer=ClassificationAnswer(
20-
feature_schema_id=self.answer.schema_id,
21-
name=self.answer.title,
19+
def to_common(self) -> Radio:
20+
return Radio(answer=ClassificationAnswer(feature_schema_id=self.answer.schema_id, name=self.answer.title,
2221
extra={
2322
'feature_id': self.answer.feature_id,
2423
'value': self.answer.value
@@ -27,7 +26,8 @@ def to_common(self):
2726
@classmethod
2827
def from_common(cls, radio: Radio, feature_schema_id: Cuid,
2928
**extra) -> "LBV1Radio":
30-
return cls(schema_id=feature_schema_id,
29+
return cls(
30+
schema_id=feature_schema_id,
3131
answer=LBV1ClassificationAnswer(
3232
schema_id=radio.answer.feature_schema_id,
3333
title=radio.answer.name,
@@ -39,8 +39,9 @@ def from_common(cls, radio: Radio, feature_schema_id: Cuid,
3939
class LBV1Checklist(LBV1Feature):
4040
answers: List[LBV1ClassificationAnswer]
4141

42-
def to_common(self):
43-
return Checklist(answer=[
42+
def to_common(self) -> Checklist:
43+
return Checklist(
44+
answer=[
4445
ClassificationAnswer(feature_schema_id=answer.schema_id,
4546
name=answer.title,
4647
extra={
@@ -64,6 +65,34 @@ def from_common(cls, checklist: Checklist, feature_schema_id: Cuid,
6465
**extra)
6566

6667

68+
class LBV1Dropdown(LBV1Feature):
69+
answer: List[LBV1ClassificationAnswer]
70+
def to_common(self) -> Dropdown:
71+
return Dropdown(
72+
answer=[
73+
ClassificationAnswer(feature_schema_id=answer.schema_id,
74+
name=answer.title,
75+
extra={
76+
'feature_id': answer.feature_id,
77+
'value': answer.value
78+
}) for answer in self.answer
79+
])
80+
81+
@classmethod
82+
def from_common(cls, dropdown: Dropdown, feature_schema_id: Cuid,
83+
**extra) -> "LBV1Dropdown":
84+
return cls(schema_id = feature_schema_id,
85+
answers=[
86+
LBV1ClassificationAnswer(
87+
schema_id=answer.feature_schema_id,
88+
title=answer.name,
89+
value=answer.extra.get('value'),
90+
feature_id=answer.extra.get('feature_id'))
91+
for answer in dropdown.answer
92+
],
93+
**extra)
94+
95+
6796
class LBV1Text(LBV1Feature):
6897
answer: str
6998

@@ -77,7 +106,7 @@ def from_common(cls, text: Text, feature_schema_id: Cuid,
77106

78107

79108
class LBV1Classifications(BaseModel):
80-
classifications: List[Union[LBV1Radio, LBV1Checklist, LBV1Text]] = []
109+
classifications: List[Union[LBV1Text, LBV1Radio, LBV1Dropdown, LBV1Checklist]] = []
81110

82111
def to_common(self) -> List[ClassificationAnnotation]:
83112
classifications = [
@@ -112,10 +141,10 @@ def from_common(
112141
@staticmethod
113142
def lookup_classification(
114143
annotation: ClassificationAnnotation
115-
) -> Union[LBV1Text, LBV1Checklist, LBV1Radio]:
144+
) -> Union[LBV1Text, LBV1Checklist, LBV1Radio, LBV1Checklist]:
116145
return {
117146
Text: LBV1Text,
118-
Dropdown: LBV1Checklist,
147+
Dropdown: LBV1Dropdown,
119148
Checklist: LBV1Checklist,
120149
Radio: LBV1Radio
121150
}.get(type(annotation.value))

labelbox/data/serialization/ndjson/label.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from itertools import groupby
2+
from labelbox.data.annotation_types.classification.classification import Dropdown
23
from labelbox.data.annotation_types.metrics import ScalarMetric
34

45
from operator import itemgetter
@@ -102,6 +103,9 @@ def _create_non_video_annotations(cls, label: Label):
102103
]
103104
for annotation in non_video_annotations:
104105
if isinstance(annotation, ClassificationAnnotation):
106+
if isinstance(annotation.value, Dropdown):
107+
raise ValueError("Dropdowns are not supported by the NDJson format."
108+
" Please filter out Dropdown annotations before converting.")
105109
yield NDClassification.from_common(annotation, label.data)
106110
elif isinstance(annotation, ObjectAnnotation):
107111
yield NDObject.from_common(annotation, label.data)

0 commit comments

Comments
 (0)