Skip to content

Commit 57bdb0b

Browse files
authored
Vb/remove label data plt 37 (#1527)
2 parents c0c45df + 556207b commit 57bdb0b

File tree

8 files changed

+519
-74
lines changed

8 files changed

+519
-74
lines changed
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from typing import Callable, Literal, Optional
2+
3+
from labelbox import pydantic_compat
4+
from labelbox.data.annotation_types.data.base_data import BaseData
5+
from labelbox.utils import _NoCoercionMixin
6+
7+
8+
class GenericDataRowData(BaseData, _NoCoercionMixin):
    """Data row reference that is not tied to a specific media type.

    Meant to replace the other DataType classes passed into Label. The raw
    input must contain exactly one of the keys ``external_id``,
    ``global_key`` or ``uid``; only the presence of a key is checked, not
    its value.
    """
    url: Optional[str] = None
    class_name: Literal["GenericDataRowData"] = "GenericDataRowData"

    def create_url(self, signer: Callable[[bytes], str]) -> Optional[str]:
        """Return the stored ``url`` as is; ``signer`` is not used."""
        return self.url

    @pydantic_compat.root_validator(pre=True)
    def validate_one_datarow_key_present(cls, data):
        """Reject the input unless exactly one identifier key is present.

        Raises:
            ValueError: if none, or more than one, of the identifier keys
                appear in ``data``.
        """
        id_keys = ['external_id', 'global_key', 'uid']
        present = [name for name in id_keys if name in data]

        if not present:
            raise ValueError(f"Exactly one of {id_keys} must be present.")
        if len(present) > 1:
            raise ValueError(f"Only one of {id_keys} can be present.")
        return data

labelbox/data/annotation_types/label.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55
from labelbox import pydantic_compat
66

77
import labelbox
8+
from labelbox.data.annotation_types.data.generic_data_row_data import GenericDataRowData
89
from labelbox.data.annotation_types.data.tiled_image import TiledImageData
910
from labelbox.schema import ontology
1011
from .annotation import ClassificationAnnotation, ObjectAnnotation
1112
from .relationship import RelationshipAnnotation
1213
from .classification import ClassificationAnswer
13-
from .data import AudioData, ConversationData, DicomData, DocumentData, HTMLData, ImageData, MaskData, TextData, VideoData, LlmPromptCreationData, LlmPromptResponseCreationData, LlmResponseCreationData
14+
from .data import AudioData, ConversationData, DicomData, DocumentData, HTMLData, ImageData, TextData, VideoData, LlmPromptCreationData, LlmPromptResponseCreationData, LlmResponseCreationData
1415
from .geometry import Mask
1516
from .metrics import ScalarMetric, ConfusionMatrixMetric
1617
from .types import Cuid
@@ -21,14 +22,14 @@
2122
DataType = Union[VideoData, ImageData, TextData, TiledImageData, AudioData,
2223
ConversationData, DicomData, DocumentData, HTMLData,
2324
LlmPromptCreationData, LlmPromptResponseCreationData,
24-
LlmResponseCreationData]
25+
LlmResponseCreationData, GenericDataRowData]
2526

2627

2728
class Label(pydantic_compat.BaseModel):
2829
"""Container for holding data and annotations
2930
3031
>>> Label(
31-
>>> data = ImageData(url = "http://my-img.jpg"),
32+
>>> data = {'global_key': 'my-data-row-key'} # also accepts uid, external_id as keys
3233
>>> annotations = [
3334
>>> ObjectAnnotation(
3435
>>> value = Point(x = 10, y = 10),
@@ -39,7 +40,8 @@ class Label(pydantic_compat.BaseModel):
3940
4041
Args:
4142
uid: Optional Label Id in Labelbox
42-
data: Data of Label, Image, Video, Text
43+
data: Data of the Label: Image, Video, Text, or a dict with a single key (uid | global_key | external_id).
44+
Note: passing media-specific data classes (Image, Video, Text, ...) as data is deprecated. Use GenericDataRowData or a dict with a single key instead.
4345
annotations: List of Annotations in the label
4446
extra: additional context
4547
"""
@@ -51,6 +53,16 @@ class Label(pydantic_compat.BaseModel):
5153
RelationshipAnnotation]] = []
5254
extra: Dict[str, Any] = {}
5355

56+
@pydantic_compat.root_validator(pre=True)
def validate_data(cls, label):
    """Normalize ``label["data"]`` before the fields are validated.

    * A dict is replaced by a copy tagged with
      ``class_name = "GenericDataRowData"`` (presumably so the ``DataType``
      union resolves to GenericDataRowData -- confirm against the field
      declaration). The caller's dict is left unmodified.
    * A GenericDataRowData instance is accepted without a warning.
    * Any other non-None object triggers the deprecation warning.
    * Missing or None data is left untouched so that it is reported by the
      field validation that follows instead of raising KeyError here.

    Args:
        label: raw input values for the model.

    Returns:
        The same ``label`` mapping, with ``data`` tagged when it was a dict.
    """
    data = label.get("data")
    if isinstance(data, dict):
        # Tag a copy: the dict belongs to the caller and may be reused.
        label["data"] = {**data, "class_name": "GenericDataRowData"}
    elif data is not None and not isinstance(data, GenericDataRowData):
        # Only the legacy media-specific classes are deprecated; warning on
        # GenericDataRowData itself would contradict the advice given.
        warnings.warn(
            f"Using {type(data).__name__} class for label.data is deprecated. "
            "Use a dict or an instance of GenericDataRowData instead.")
    return label
65+
5466
def object_annotations(self) -> List[ObjectAnnotation]:
    """Return the annotations selected by ``_get_annotations_by_type(ObjectAnnotation)``."""
    return self._get_annotations_by_type(ObjectAnnotation)
5668

labelbox/schema/id_type.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,18 @@
1-
from strenum import StrEnum
1+
import sys
22

3+
if sys.version_info >= (3, 9):
4+
from strenum import StrEnum
35

4-
class IdType(StrEnum):
6+
class BaseStrEnum(StrEnum):
7+
pass
8+
else:
9+
from enum import Enum
10+
11+
class BaseStrEnum(str, Enum):
12+
pass
13+
14+
15+
class IdType(BaseStrEnum):
516
"""
617
The type of id used to identify a data row.
718

tests/conftest.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -121,13 +121,13 @@ def rest_url(environ: str) -> str:
121121
return 'http://host.docker.internal:8080/api/v1'
122122

123123

124-
def testing_api_key(environ: str) -> str:
125-
for var in [
126-
"LABELBOX_TEST_API_KEY_PROD", "LABELBOX_TEST_API_KEY_STAGING",
127-
"LABELBOX_TEST_API_KEY_CUSTOM", "LABELBOX_TEST_API_KEY_LOCAL",
128-
"LABELBOX_TEST_API_KEY"
129-
]:
130-
value = os.environ.get(var)
124+
def testing_api_key(environ: Environ) -> str:
125+
keys = [
126+
f"LABELBOX_TEST_API_KEY_{environ.value.upper()}",
127+
"LABELBOX_TEST_API_KEY"
128+
]
129+
for key in keys:
130+
value = os.environ.get(key)
131131
if value is not None:
132132
return value
133133
raise Exception("Cannot find API to use for tests")
@@ -147,7 +147,6 @@ def __init__(self, environ: str) -> None:
147147
api_url = graphql_url(environ)
148148
api_key = testing_api_key(environ)
149149
rest_endpoint = rest_url(environ)
150-
151150
super().__init__(api_key,
152151
api_url,
153152
enable_experimental=True,

tests/data/annotation_import/conftest.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1885,3 +1885,64 @@ def bbox_video_annotation_objects():
18851885
]
18861886

18871887
return bbox_annotation
1888+
1889+
1890+
class Helpers:
    """Utility functions shared by the annotation import tests.

    All members are static; the class is only a namespace.
    """

    # Names produced by the generic "strip 'Data' + capitalize" rule that do
    # not match a MediaType member, mapped to the member name to use instead.
    _MEDIA_TYPE_OVERRIDES = {
        "Conversation": "Conversational",
        "Llmpromptcreation": "LLMPromptCreation",
        "Llmpromptresponsecreation": "LLMPromptResponseCreation",
        "Llmresponsecreation": "Text",
        "Genericdatarow": "Image",
    }

    @staticmethod
    def _nested_dicts(d):
        """Yield each dict stored as a value of ``d`` or inside a list value of ``d``."""
        for value in d.values():
            if isinstance(value, dict):
                yield value
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, dict):
                        yield item

    @staticmethod
    def remove_keys_recursive(d, keys):
        """Delete every key in ``keys`` from ``d`` and from nested dicts, in place.

        Nested dicts are reached through dict values and through dicts that
        are direct elements of list values.
        """
        for key in keys:
            d.pop(key, None)
        for child in Helpers._nested_dicts(d):
            Helpers.remove_keys_recursive(child, keys)

    @staticmethod
    def rename_cuid_key_recursive(d):
        """Rename every cuid-looking key in ``d`` (recursively) to ``"<cuid>"``, in place.

        NOTE: the cuid check is deliberately primitive (a 25-character string
        that is not purely alphabetic). If one dict holds several such keys
        they all collapse onto the single ``"<cuid>"`` key and the last one
        wins. Not written with performance in mind; fine for the small to
        mid-size dicts used in the tests.
        """
        new_key = "<cuid>"
        # Iterate over a snapshot of the keys because the dict is modified.
        for key in list(d):
            # Non-string keys cannot be cuids; skip them instead of failing
            # on len().
            if isinstance(key, str) and len(key) == 25 and not key.isalpha():
                d[new_key] = d.pop(key)
        for child in Helpers._nested_dicts(d):
            Helpers.rename_cuid_key_recursive(child)

    @staticmethod
    def set_project_media_type_from_data_type(project, data_type_class):
        """Update ``project`` with the MediaType derived from ``data_type_class``.

        The class name without its last four characters (the ``Data`` suffix
        of the data classes) is lower-cased and capitalized, then mapped
        through ``_MEDIA_TYPE_OVERRIDES`` before being looked up in
        ``MediaType``.
        """
        data_type_string = data_type_class.__name__[:-4].lower()
        media_type = "".join(
            word.capitalize() for word in data_type_string.split("_"))
        media_type = Helpers._MEDIA_TYPE_OVERRIDES.get(media_type, media_type)
        project.update(media_type=MediaType[media_type])

    @staticmethod
    def find_data_row_filter(data_row):
        """Return a predicate matching dicts whose ``['data_row']['id']`` equals ``data_row.uid``."""
        return lambda dr: dr['data_row']['id'] == data_row.uid
1944+
1945+
1946+
@pytest.fixture
def helpers():
    """Fixture that returns the ``Helpers`` class itself (not an instance)."""
    return Helpers

tests/data/annotation_import/test_data_types.py

Lines changed: 18 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import labelbox as lb
77
from labelbox.data.annotation_types.data.video import VideoData
8-
from labelbox.schema.data_row import DataRow
98
from labelbox.schema.media_type import MediaType
109
import labelbox.types as lb_types
1110
from labelbox.data.annotation_types.data import (
@@ -70,35 +69,6 @@
7069
]
7170

7271

73-
def remove_keys_recursive(d, keys):
74-
for k in keys:
75-
if k in d:
76-
del d[k]
77-
for k, v in d.items():
78-
if isinstance(v, dict):
79-
remove_keys_recursive(v, keys)
80-
elif isinstance(v, list):
81-
for i in v:
82-
if isinstance(i, dict):
83-
remove_keys_recursive(i, keys)
84-
85-
86-
# NOTE this uses quite a primitive check for cuids but I do not think it is worth coming up with a better one
87-
# Also this function is NOT written with performance in mind, good for small to mid size dicts like we have in our test
88-
def rename_cuid_key_recursive(d):
89-
new_key = "<cuid>"
90-
for k in list(d.keys()):
91-
if len(k) == 25 and not k.isalpha(): # primitive check for cuid
92-
d[new_key] = d.pop(k)
93-
for k, v in d.items():
94-
if isinstance(v, dict):
95-
rename_cuid_key_recursive(v)
96-
elif isinstance(v, list):
97-
for i in v:
98-
if isinstance(i, dict):
99-
rename_cuid_key_recursive(i)
100-
101-
10272
def get_annotation_comparison_dicts_from_labels(labels):
10373
labels_ndjson = list(NDJsonConverter.serialize(labels))
10474
for annotation in labels_ndjson:
@@ -198,12 +168,13 @@ def test_import_data_types(
198168
data_row_json_by_data_type,
199169
annotations_by_data_type,
200170
data_type_class,
171+
helpers,
201172
):
202173
project = configured_project
203174
project_id = project.uid
204175
dataset = initial_dataset
205176

206-
set_project_media_type_from_data_type(project, data_type_class)
177+
helpers.set_project_media_type_from_data_type(project, data_type_class)
207178

208179
data_type_string = data_type_class.__name__[:-4].lower()
209180
data_row_ndjson = data_row_json_by_data_type[data_type_string]
@@ -241,12 +212,13 @@ def test_import_data_types_by_global_key(
241212
rand_gen,
242213
data_row_json_by_data_type,
243214
annotations_by_data_type,
215+
helpers,
244216
):
245217
project = configured_project
246218
project_id = project.uid
247219
dataset = initial_dataset
248220
data_type_class = ImageData
249-
set_project_media_type_from_data_type(project, data_type_class)
221+
helpers.set_project_media_type_from_data_type(project, data_type_class)
250222

251223
data_row_ndjson = data_row_json_by_data_type["image"]
252224
data_row_ndjson["global_key"] = str(uuid.uuid4())
@@ -287,24 +259,6 @@ def validate_iso_format(date_string: str):
287259
assert parsed_t.second is not None
288260

289261

290-
def to_pascal_case(name: str) -> str:
291-
return "".join([word.capitalize() for word in name.split("_")])
292-
293-
294-
def set_project_media_type_from_data_type(project, data_type_class):
295-
data_type_string = data_type_class.__name__[:-4].lower()
296-
media_type = to_pascal_case(data_type_string)
297-
if media_type == "Conversation":
298-
media_type = "Conversational"
299-
elif media_type == "Llmpromptcreation":
300-
media_type = "LLMPromptCreation"
301-
elif media_type == "Llmpromptresponsecreation":
302-
media_type = "LLMPromptResponseCreation"
303-
elif media_type == "Llmresponsecreation":
304-
media_type = "Text"
305-
project.update(media_type=MediaType[media_type])
306-
307-
308262
@pytest.mark.parametrize(
309263
"data_type_class",
310264
[
@@ -331,12 +285,13 @@ def test_import_data_types_v2(
331285
exports_v2_by_data_type,
332286
export_v2_test_helpers,
333287
rand_gen,
288+
helpers,
334289
):
335290
project = configured_project
336291
dataset = initial_dataset
337292
project_id = project.uid
338293

339-
set_project_media_type_from_data_type(project, data_type_class)
294+
helpers.set_project_media_type_from_data_type(project, data_type_class)
340295

341296
data_type_string = data_type_class.__name__[:-4].lower()
342297
data_row_ndjson = data_row_json_by_data_type[data_type_string]
@@ -381,9 +336,9 @@ def test_import_data_types_v2(
381336
exported_project_labels = exported_project["labels"][0]
382337
exported_annotations = exported_project_labels["annotations"]
383338

384-
remove_keys_recursive(exported_annotations,
385-
["feature_id", "feature_schema_id"])
386-
rename_cuid_key_recursive(exported_annotations)
339+
helpers.remove_keys_recursive(exported_annotations,
340+
["feature_id", "feature_schema_id"])
341+
helpers.rename_cuid_key_recursive(exported_annotations)
387342
assert exported_annotations == exports_v2_by_data_type[data_type_string]
388343

389344
data_row = client.get_data_row(data_row.uid)
@@ -400,10 +355,11 @@ def test_import_label_annotations(
400355
data_class,
401356
annotations,
402357
rand_gen,
358+
helpers,
403359
):
404360
project = configured_project_with_one_data_row
405361
dataset = initial_dataset
406-
set_project_media_type_from_data_type(project, data_class)
362+
helpers.set_project_media_type_from_data_type(project, data_class)
407363

408364
data_row_json = data_row_json_by_data_type[data_type]
409365
data_row = create_data_row_for_project(project, dataset, data_row_json,
@@ -471,10 +427,11 @@ def test_import_mal_annotations(
471427
annotations,
472428
rand_gen,
473429
one_datarow,
430+
helpers,
474431
):
475432
data_row = one_datarow
476-
set_project_media_type_from_data_type(configured_project_with_one_data_row,
477-
data_class)
433+
helpers.set_project_media_type_from_data_type(
434+
configured_project_with_one_data_row, data_class)
478435

479436
configured_project_with_one_data_row.create_batch(
480437
rand_gen(str),
@@ -500,12 +457,13 @@ def test_import_mal_annotations(
500457

501458
def test_import_mal_annotations_global_key(client,
502459
configured_project_with_one_data_row,
503-
rand_gen, one_datarow_global_key):
460+
rand_gen, one_datarow_global_key,
461+
helpers):
504462
data_class = lb_types.VideoData
505463
data_row = one_datarow_global_key
506464
annotations = [video_mask_annotation]
507-
set_project_media_type_from_data_type(configured_project_with_one_data_row,
508-
data_class)
465+
helpers.set_project_media_type_from_data_type(
466+
configured_project_with_one_data_row, data_class)
509467

510468
configured_project_with_one_data_row.create_batch(
511469
rand_gen(str),

0 commit comments

Comments
 (0)