Skip to content

Commit 556207b

Browse files
author
Val Brodsky
committed
PR feedback
1 parent 11490ee commit 556207b

File tree

4 files changed

+53
-61
lines changed

4 files changed

+53
-61
lines changed

labelbox/data/annotation_types/data/generic_data_row_data.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,8 @@ def create_url(self, signer: Callable[[bytes], str]) -> Optional[str]:
1717
@pydantic_compat.root_validator(pre=True)
1818
def validate_one_datarow_key_present(cls, data):
1919
keys = ['external_id', 'global_key', 'uid']
20-
count = 0
21-
for key in keys:
22-
if data.get(key):
23-
count += 1
20+
count = sum([key in data for key in keys])
21+
2422
if count < 1:
2523
raise ValueError(f"Exactly one of {keys} must be present.")
2624
if count > 1:

tests/data/annotation_import/conftest.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1918,6 +1918,30 @@ def rename_cuid_key_recursive(d):
19181918
if isinstance(i, dict):
19191919
Helpers.rename_cuid_key_recursive(i)
19201920

1921+
@staticmethod
1922+
def set_project_media_type_from_data_type(project, data_type_class):
1923+
1924+
def to_pascal_case(name: str) -> str:
1925+
return "".join([word.capitalize() for word in name.split("_")])
1926+
1927+
data_type_string = data_type_class.__name__[:-4].lower()
1928+
media_type = to_pascal_case(data_type_string)
1929+
if media_type == "Conversation":
1930+
media_type = "Conversational"
1931+
elif media_type == "Llmpromptcreation":
1932+
media_type = "LLMPromptCreation"
1933+
elif media_type == "Llmpromptresponsecreation":
1934+
media_type = "LLMPromptResponseCreation"
1935+
elif media_type == "Llmresponsecreation":
1936+
media_type = "Text"
1937+
elif media_type == "Genericdatarow":
1938+
media_type = "Image"
1939+
project.update(media_type=MediaType[media_type])
1940+
1941+
@staticmethod
1942+
def find_data_row_filter(data_row):
1943+
return lambda dr: dr['data_row']['id'] == data_row.uid
1944+
19211945

19221946
@pytest.fixture
19231947
def helpers():

tests/data/annotation_import/test_data_types.py

Lines changed: 14 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -168,12 +168,13 @@ def test_import_data_types(
168168
data_row_json_by_data_type,
169169
annotations_by_data_type,
170170
data_type_class,
171+
helpers,
171172
):
172173
project = configured_project
173174
project_id = project.uid
174175
dataset = initial_dataset
175176

176-
set_project_media_type_from_data_type(project, data_type_class)
177+
helpers.set_project_media_type_from_data_type(project, data_type_class)
177178

178179
data_type_string = data_type_class.__name__[:-4].lower()
179180
data_row_ndjson = data_row_json_by_data_type[data_type_string]
@@ -211,12 +212,13 @@ def test_import_data_types_by_global_key(
211212
rand_gen,
212213
data_row_json_by_data_type,
213214
annotations_by_data_type,
215+
helpers,
214216
):
215217
project = configured_project
216218
project_id = project.uid
217219
dataset = initial_dataset
218220
data_type_class = ImageData
219-
set_project_media_type_from_data_type(project, data_type_class)
221+
helpers.set_project_media_type_from_data_type(project, data_type_class)
220222

221223
data_row_ndjson = data_row_json_by_data_type["image"]
222224
data_row_ndjson["global_key"] = str(uuid.uuid4())
@@ -257,24 +259,6 @@ def validate_iso_format(date_string: str):
257259
assert parsed_t.second is not None
258260

259261

260-
def to_pascal_case(name: str) -> str:
261-
return "".join([word.capitalize() for word in name.split("_")])
262-
263-
264-
def set_project_media_type_from_data_type(project, data_type_class):
265-
data_type_string = data_type_class.__name__[:-4].lower()
266-
media_type = to_pascal_case(data_type_string)
267-
if media_type == "Conversation":
268-
media_type = "Conversational"
269-
elif media_type == "Llmpromptcreation":
270-
media_type = "LLMPromptCreation"
271-
elif media_type == "Llmpromptresponsecreation":
272-
media_type = "LLMPromptResponseCreation"
273-
elif media_type == "Llmresponsecreation":
274-
media_type = "Text"
275-
project.update(media_type=MediaType[media_type])
276-
277-
278262
@pytest.mark.parametrize(
279263
"data_type_class",
280264
[
@@ -307,7 +291,7 @@ def test_import_data_types_v2(
307291
dataset = initial_dataset
308292
project_id = project.uid
309293

310-
set_project_media_type_from_data_type(project, data_type_class)
294+
helpers.set_project_media_type_from_data_type(project, data_type_class)
311295

312296
data_type_string = data_type_class.__name__[:-4].lower()
313297
data_row_ndjson = data_row_json_by_data_type[data_type_string]
@@ -371,10 +355,11 @@ def test_import_label_annotations(
371355
data_class,
372356
annotations,
373357
rand_gen,
358+
helpers,
374359
):
375360
project = configured_project_with_one_data_row
376361
dataset = initial_dataset
377-
set_project_media_type_from_data_type(project, data_class)
362+
helpers.set_project_media_type_from_data_type(project, data_class)
378363

379364
data_row_json = data_row_json_by_data_type[data_type]
380365
data_row = create_data_row_for_project(project, dataset, data_row_json,
@@ -442,10 +427,11 @@ def test_import_mal_annotations(
442427
annotations,
443428
rand_gen,
444429
one_datarow,
430+
helpers,
445431
):
446432
data_row = one_datarow
447-
set_project_media_type_from_data_type(configured_project_with_one_data_row,
448-
data_class)
433+
helpers.set_project_media_type_from_data_type(
434+
configured_project_with_one_data_row, data_class)
449435

450436
configured_project_with_one_data_row.create_batch(
451437
rand_gen(str),
@@ -471,12 +457,13 @@ def test_import_mal_annotations(
471457

472458
def test_import_mal_annotations_global_key(client,
473459
configured_project_with_one_data_row,
474-
rand_gen, one_datarow_global_key):
460+
rand_gen, one_datarow_global_key,
461+
helpers):
475462
data_class = lb_types.VideoData
476463
data_row = one_datarow_global_key
477464
annotations = [video_mask_annotation]
478-
set_project_media_type_from_data_type(configured_project_with_one_data_row,
479-
data_class)
465+
helpers.set_project_media_type_from_data_type(
466+
configured_project_with_one_data_row, data_class)
480467

481468
configured_project_with_one_data_row.create_batch(
482469
rand_gen(str),

tests/data/annotation_import/test_generic_data_types.py

Lines changed: 13 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,13 @@ def test_import_data_types_by_global_key(
8686
data_row_json_by_data_type,
8787
annotations_by_data_type,
8888
export_v2_test_helpers,
89+
helpers,
8990
):
9091
project = configured_project
9192
project_id = project.uid
9293
dataset = initial_dataset
9394
data_type_class = ImageData
94-
set_project_media_type_from_data_type(project, data_type_class)
95+
helpers.set_project_media_type_from_data_type(project, data_type_class)
9596

9697
data_row_ndjson = data_row_json_by_data_type["image"]
9798
data_row_ndjson["global_key"] = str(uuid.uuid4())
@@ -142,26 +143,6 @@ def validate_iso_format(date_string: str):
142143
assert parsed_t.second is not None
143144

144145

145-
def to_pascal_case(name: str) -> str:
146-
return "".join([word.capitalize() for word in name.split("_")])
147-
148-
149-
def set_project_media_type_from_data_type(project, data_type_class):
150-
data_type_string = data_type_class.__name__[:-4].lower()
151-
media_type = to_pascal_case(data_type_string)
152-
if media_type == "Conversation":
153-
media_type = "Conversational"
154-
elif media_type == "Llmpromptcreation":
155-
media_type = "LLMPromptCreation"
156-
elif media_type == "Llmpromptresponsecreation":
157-
media_type = "LLMPromptResponseCreation"
158-
elif media_type == "Llmresponsecreation":
159-
media_type = "Text"
160-
elif media_type == "Genericdatarow":
161-
media_type = "Image"
162-
project.update(media_type=MediaType[media_type])
163-
164-
165146
@pytest.mark.parametrize(
166147
"data_type_class",
167148
[
@@ -194,7 +175,7 @@ def test_import_data_types_v2(
194175
dataset = initial_dataset
195176
project_id = project.uid
196177

197-
set_project_media_type_from_data_type(project, data_type_class)
178+
helpers.set_project_media_type_from_data_type(project, data_type_class)
198179

199180
data_type_string = data_type_class.__name__[:-4].lower()
200181
data_row_ndjson = data_row_json_by_data_type[data_type_string]
@@ -221,8 +202,9 @@ def test_import_data_types_v2(
221202
# to be similar to tests/integration/test_task_queue.py
222203

223204
result = export_v2_test_helpers.run_project_export_v2_task(project)
224-
find_data_row = lambda dr: dr['data_row']['id'] == data_row.uid
225-
exported_data = list(filter(find_data_row, result))[0]
205+
206+
exported_data = next(
207+
dr for dr in result if dr['data_row']['id'] == data_row.uid)
226208
assert exported_data
227209

228210
# timestamp fields are in iso format
@@ -274,15 +256,15 @@ def one_datarow_global_key(client, rand_gen, data_row_json_by_data_type):
274256
def test_import_mal_annotations(
275257
client,
276258
configured_project_with_one_data_row,
277-
data_type,
278259
data_class,
279260
annotations,
280261
rand_gen,
281262
one_datarow,
263+
helpers,
282264
):
283265
data_row = one_datarow
284-
set_project_media_type_from_data_type(configured_project_with_one_data_row,
285-
data_class)
266+
helpers.set_project_media_type_from_data_type(
267+
configured_project_with_one_data_row, data_class)
286268

287269
configured_project_with_one_data_row.create_batch(
288270
rand_gen(str),
@@ -307,12 +289,13 @@ def test_import_mal_annotations(
307289

308290
def test_import_mal_annotations_global_key(client,
309291
configured_project_with_one_data_row,
310-
rand_gen, one_datarow_global_key):
292+
rand_gen, one_datarow_global_key,
293+
helpers):
311294
data_class = lb_types.VideoData
312295
data_row = one_datarow_global_key
313296
annotations = [video_mask_annotation]
314-
set_project_media_type_from_data_type(configured_project_with_one_data_row,
315-
data_class)
297+
helpers.set_project_media_type_from_data_type(
298+
configured_project_with_one_data_row, data_class)
316299

317300
configured_project_with_one_data_row.create_batch(
318301
rand_gen(str),

0 commit comments

Comments
 (0)