Skip to content

Commit 65acdd2

Browse files
authored
Merge pull request #202 from Labelbox/ms/lbv1-annotation-types
lbv1-converter
2 parents 944865a + ff1cdd6 commit 65acdd2

File tree

12 files changed

+812
-1
lines changed

12 files changed

+812
-1
lines changed

labelbox/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
name = "labelbox"
22
__version__ = "2.7.0"
33

4+
from labelbox.schema.project import Project
45
from labelbox.client import Client
56
from labelbox.schema.bulk_import_request import BulkImportRequest
6-
from labelbox.schema.project import Project
77
from labelbox.schema.dataset import Dataset
88
from labelbox.schema.data_row import DataRow
99
from labelbox.schema.label import Label
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .labelbox_v1 import LBV1Converter
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .converter import LBV1Converter
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
from typing import List, Union
2+
3+
from pydantic.main import BaseModel
4+
5+
from ...annotation_types.annotation import ClassificationAnnotation
6+
from ...annotation_types.classification import Checklist, ClassificationAnswer, Radio, Text, Dropdown
7+
from ...annotation_types.types import Cuid
8+
from .feature import LBV1Feature
9+
10+
11+
class LBV1ClassificationAnswer(LBV1Feature):
12+
...
13+
14+
15+
class LBV1Radio(LBV1Feature):
16+
answer: LBV1ClassificationAnswer
17+
18+
def to_common(self):
19+
return Radio(answer=ClassificationAnswer(
20+
schema_id=self.answer.schema_id,
21+
name=self.answer.title,
22+
extra={
23+
'feature_id': self.answer.feature_id,
24+
'value': self.answer.value
25+
}))
26+
27+
@classmethod
28+
def from_common(cls, radio: Radio, schema_id: Cuid, **extra) -> "LBV1Radio":
29+
return cls(schema_id=schema_id,
30+
answer=LBV1ClassificationAnswer(
31+
schema_id=radio.answer.schema_id,
32+
title=radio.answer.name,
33+
value=radio.answer.extra.get('value'),
34+
feature_id=radio.answer.extra.get('feature_id')),
35+
**extra)
36+
37+
38+
class LBV1Checklist(LBV1Feature):
39+
answers: List[LBV1ClassificationAnswer]
40+
41+
def to_common(self):
42+
return Checklist(answer=[
43+
ClassificationAnswer(schema_id=answer.schema_id,
44+
name=answer.title,
45+
extra={
46+
'feature_id': answer.feature_id,
47+
'value': answer.value
48+
}) for answer in self.answers
49+
])
50+
51+
@classmethod
52+
def from_common(cls, checklist: Checklist, schema_id: Cuid,
53+
**extra) -> "LBV1Checklist":
54+
return cls(schema_id=schema_id,
55+
answers=[
56+
LBV1ClassificationAnswer(
57+
schema_id=answer.schema_id,
58+
title=answer.name,
59+
value=answer.extra.get('value'),
60+
feature_id=answer.extra.get('feature_id'))
61+
for answer in checklist.answer
62+
],
63+
**extra)
64+
65+
66+
class LBV1Text(LBV1Feature):
67+
answer: str
68+
69+
def to_common(self):
70+
return Text(answer=self.answer)
71+
72+
@classmethod
73+
def from_common(cls, text: Text, schema_id: Cuid, **extra) -> "LBV1Text":
74+
return cls(schema_id=schema_id, answer=text.answer, **extra)
75+
76+
77+
class LBV1Classifications(BaseModel):
78+
classifications: List[Union[LBV1Radio, LBV1Checklist, LBV1Text]] = []
79+
80+
def to_common(self) -> List[ClassificationAnnotation]:
81+
classifications = [
82+
ClassificationAnnotation(value=classification.to_common(),
83+
classifications=[],
84+
name=classification.title,
85+
extra={
86+
'value': classification.value,
87+
'feature_id': classification.feature_id
88+
})
89+
for classification in self.classifications
90+
]
91+
return classifications
92+
93+
@classmethod
94+
def from_common(
95+
cls, annotations: List[ClassificationAnnotation]
96+
) -> "LBV1Classifications":
97+
classifications = []
98+
for annotation in annotations:
99+
classification = cls.lookup_classification(annotation)
100+
if classification is not None:
101+
classifications.append(
102+
classification.from_common(annotation.value,
103+
annotation.schema_id,
104+
**annotation.extra))
105+
else:
106+
raise TypeError(f"Unexpected type {type(annotation.value)}")
107+
return cls(classifications=classifications)
108+
109+
@staticmethod
110+
def lookup_classification(
111+
annotation: ClassificationAnnotation
112+
) -> Union[LBV1Text, LBV1Checklist, LBV1Radio]:
113+
return {
114+
Text: LBV1Text,
115+
Dropdown: LBV1Checklist,
116+
Checklist: LBV1Checklist,
117+
Radio: LBV1Radio
118+
}.get(type(annotation.value))
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
from typing import Any, Dict, Generator, Iterable
2+
import logging
3+
4+
import ndjson
5+
import requests
6+
from copy import deepcopy
7+
from requests.exceptions import HTTPError
8+
from google.api_core import retry
9+
10+
import labelbox
11+
from .label import LBV1Label
12+
from ...annotation_types.collection import (LabelCollection, LabelGenerator,
13+
PrefetchGenerator)
14+
15+
logger = logging.getLogger(__name__)
16+
17+
18+
class LBV1Converter:
19+
20+
@staticmethod
21+
def deserialize_video(json_data: Iterable[Dict[str, Any]],
22+
client: labelbox.Client):
23+
"""
24+
Converts a labelbox video export into the common labelbox format.
25+
26+
Args:
27+
json_data: An iterable representing the labelbox video export.
28+
client: The labelbox client for downloading video annotations
29+
Returns:
30+
LabelGenerator containing the video data.
31+
"""
32+
label_generator = (LBV1Label(**example).to_common()
33+
for example in LBV1VideoIterator(json_data, client))
34+
return LabelGenerator(data=label_generator)
35+
36+
@staticmethod
37+
def deserialize(json_data: Iterable[Dict[str, Any]]) -> LabelGenerator:
38+
"""
39+
Converts a labelbox export (non-video) into the common labelbox format.
40+
41+
Args:
42+
json_data: An iterable representing the labelbox export.
43+
Returns:
44+
LabelGenerator containing the export data.
45+
"""
46+
47+
def label_generator():
48+
for example in json_data:
49+
if 'frames' in example['Label']:
50+
raise ValueError(
51+
"Use `LBV1Converter.deserialize_video` to process video"
52+
)
53+
yield LBV1Label(**example).to_common()
54+
55+
return LabelGenerator(data=label_generator())
56+
57+
@staticmethod
58+
def serialize(
59+
labels: LabelCollection) -> Generator[Dict[str, Any], None, None]:
60+
"""
61+
Converts a labelbox common object to the labelbox json export format
62+
63+
Args:
64+
labels: Either a LabelList or a LabelGenerator (LabelCollection)
65+
Returns:
66+
A generator for accessing the labelbox json export representation of the data
67+
"""
68+
for label in labels:
69+
res = LBV1Label.from_common(label)
70+
yield res.dict(by_alias=True)
71+
72+
73+
class LBV1VideoIterator(PrefetchGenerator):
74+
"""
75+
Generator that fetches video annotations in the background to be faster.
76+
"""
77+
78+
def __init__(self, examples, client):
79+
self.client = client
80+
super().__init__(examples)
81+
82+
def _process(self, value):
83+
value = deepcopy(value)
84+
if 'frames' in value['Label']:
85+
req = self._request(value)
86+
value['Label'] = ndjson.loads(req)
87+
return value
88+
89+
@retry.Retry(predicate=retry.if_exception_type(HTTPError))
90+
def _request(self, value):
91+
req = requests.get(
92+
value['Label']['frames'],
93+
headers={"Authorization": f"Bearer {self.client.api_key}"})
94+
if req.status_code == 401:
95+
raise labelbox.exceptions.AuthenticationError("Invalid API key")
96+
req.raise_for_status()
97+
return req.text
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from typing import Optional
2+
3+
from pydantic import BaseModel, root_validator
4+
5+
from labelbox.utils import camel_case
6+
from ...annotation_types.types import Cuid
7+
8+
9+
class LBV1Feature(BaseModel):
10+
keyframe: Optional[bool] = None
11+
title: str = None
12+
value: Optional[str] = None
13+
schema_id: Optional[Cuid] = None
14+
feature_id: Optional[Cuid] = None
15+
16+
@root_validator
17+
def check_ids(cls, values):
18+
if values.get('value') is None:
19+
values['value'] = values['title']
20+
return values
21+
22+
def dict(self, *args, **kwargs):
23+
res = super().dict(*args, **kwargs)
24+
# This means these are no video frames ..
25+
if self.keyframe is None:
26+
res.pop('keyframe')
27+
return res
28+
29+
class Config:
30+
allow_population_by_field_name = True
31+
alias_generator = camel_case

0 commit comments

Comments
 (0)