
Commit dfdab84

Author: Val Brodsky (committed)

Fixing tests

1 parent f7d70fa, commit dfdab84

File tree

8 files changed: +94, -52 lines


libs/labelbox/src/labelbox/schema/dataset.py

Lines changed: 8 additions & 3 deletions

@@ -287,9 +287,14 @@ def create_data_rows(
 
         NOTE dicts and strings items can not be mixed in the same call. It is a responsibility of the caller to ensure that all items are of the same type.
         """
-        if isinstance(items[0], str):
-            items = self._build_from_local_paths(items)  # Assume list of file paths
-        specs = DataRowCreateItem.build(self.uid, items)
+        string_items = [item for item in items if isinstance(item, str)]
+        dict_items = [item for item in items if isinstance(item, dict)]
+        dict_string_items = []
+
+        if len(string_items) > 0:
+            dict_string_items = self._build_from_local_paths(string_items)
+        specs = DataRowCreateItem.build(self.uid,
+                                        dict_items + dict_string_items)
         return self._exec_upsert_data_rows(specs, file_upload_thread_count)
 
     def _build_from_local_paths(
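
The rewritten create_data_rows partitions the batch instead of guessing its type from the first element: string items (local file paths) are expanded via _build_from_local_paths, dict items pass through unchanged, and both groups are combined before the specs are built. A minimal standalone sketch of that partitioning, where to_spec() is a hypothetical stand-in for the SDK's path-upload helper:

# Sketch only: to_spec() stands in for Dataset._build_from_local_paths,
# which uploads each local file and returns a {"row_data": ...} dict.
def to_spec(path):
    return {"row_data": path}

def partition_items(items):
    string_items = [item for item in items if isinstance(item, str)]
    dict_items = [item for item in items if isinstance(item, dict)]
    # Previously a str at items[0] caused every item to be treated as a file path.
    dict_string_items = [to_spec(p) for p in string_items] if string_items else []
    return dict_items + dict_string_items

print(partition_items(["./cat.jpg", "./dog.jpg"]))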

libs/labelbox/src/labelbox/schema/internal/data_row_create_upsert.py

Lines changed: 14 additions & 7 deletions

@@ -25,13 +25,12 @@ def build(
             key = item.pop('key', None)
             if not key:
                 key = {'type': 'AUTO', 'value': ''}
-            elif isinstance(key, key_types):
+            elif isinstance(key, key_types):  # type: ignore
                 key = {'type': key.id_type.value, 'value': key.key}
             else:
                 if not key_types:
                     raise ValueError(
-                        f"Can not have a key for this item, got: {key}"
-                    )
+                        f"Can not have a key for this item, got: {key}")
                 raise ValueError(
                     f"Key must be an instance of {', '.join([t.__name__ for t in key_types])}, got: {type(item['key']).__name__}"
                 )

@@ -53,14 +52,22 @@ def is_empty(self) -> bool:
 class DataRowUpsertItem(DataRowItemBase):
 
     @classmethod
-    def build(cls, dataset_id: str,
-              items: List[dict]) -> List["DataRowUpsertItem"]:
+    def build(
+            cls,
+            dataset_id: str,
+            items: List[dict],
+            key_types: Optional[Tuple[type, ...]] = ()
+    ) -> List["DataRowItemBase"]:
         return super().build(dataset_id, items, (UniqueId, GlobalKey))
 
 
 class DataRowCreateItem(DataRowItemBase):
 
     @classmethod
-    def build(cls, dataset_id: str,
-              items: List[dict]) -> List["DataRowCreateItem"]:
+    def build(
+            cls,
+            dataset_id: str,
+            items: List[dict],
+            key_types: Optional[Tuple[type, ...]] = ()
+    ) -> List["DataRowItemBase"]:
         return super().build(dataset_id, items, ())
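
Both build classmethods now share one widened signature (an optional key_types tuple and a List["DataRowItemBase"] return type), even though each subclass still hard-codes the key types it forwards to the base class. A simplified, self-contained sketch of that dispatch pattern, with stand-in classes instead of the real DataRowItemBase and the SDK's identifier types:

from typing import List, Optional, Tuple

class UniqueId: ...      # stand-in for the SDK's unique-id identifier type
class GlobalKey: ...     # stand-in for the SDK's global-key identifier type

class ItemBase:
    @classmethod
    def build(cls, dataset_id: str, items: List[dict],
              key_types: Optional[Tuple[type, ...]] = ()) -> List["ItemBase"]:
        # The real base class validates each item's 'key' against key_types.
        return [cls() for _ in items]

class UpsertItem(ItemBase):
    @classmethod
    def build(cls, dataset_id, items, key_types=()):
        # Upserts may address existing rows, so identifier keys are allowed.
        return super().build(dataset_id, items, (UniqueId, GlobalKey))

class CreateItem(ItemBase):
    @classmethod
    def build(cls, dataset_id, items, key_types=()):
        # Creations forward an empty tuple: a 'key' on a brand-new row is rejected.
        return super().build(dataset_id, items, ())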
libs/labelbox/src/labelbox/schema/internal/datarow_upload_constants.py

Lines changed: 1 addition & 0 deletions

@@ -1,3 +1,4 @@
 MAX_DATAROW_PER_API_OPERATION = 150_000
 FILE_UPLOAD_THREAD_COUNT = 20
 UPSERT_CHUNK_SIZE = 10_000
+DOWNLOAD_RESULT_PAGE_SIZE = 5_000

libs/labelbox/src/labelbox/schema/task.py

Lines changed: 49 additions & 22 deletions

@@ -10,7 +10,10 @@
 from labelbox.orm.model import Field, Relationship, Entity
 
 from labelbox.pagination import PaginatedCollection
-from labelbox.schema.internal.datarow_upload_constants import MAX_DATAROW_PER_API_OPERATION
+from labelbox.schema.internal.datarow_upload_constants import (
+    MAX_DATAROW_PER_API_OPERATION,
+    DOWNLOAD_RESULT_PAGE_SIZE,
+)
 
 if TYPE_CHECKING:
     from labelbox import User

@@ -52,6 +55,10 @@ class Task(DbObject):
     created_by = Relationship.ToOne("User", False, "created_by")
     organization = Relationship.ToOne("Organization")
 
+    def __eq__(self, task):
+        return isinstance(
+            task, Task) and task.uid == self.uid and task.type == self.type
+
     # Import and upsert have several instances of special casing
     def is_creation_task(self) -> bool:
         return self.name == 'JSON Import' or self.type == 'adv-upsert-data-rows'

@@ -227,21 +234,23 @@ def __init__(self, *args, **kwargs):
         self._user = None
 
     @property
-    def result(self) -> Union[List[Dict[str, Any]]]:
+    def result(self) -> Optional[List[Dict[str, Any]]]:  # type: ignore
         if self.status == "FAILED":
             raise ValueError(f"Job failed. Errors : {self.errors}")
         return self._results_as_list()
 
     @property
-    def errors(self) -> Optional[Dict[str, Any]]:
+    def errors(self) -> Optional[List[Dict[str, Any]]]:  # type: ignore
         return self._errors_as_list()
 
     @property
-    def created_data_rows(self) -> Optional[Dict[str, Any]]:
+    def created_data_rows(  # type: ignore
+            self) -> Optional[List[Dict[str, Any]]]:
         return self.result
 
     @property
-    def failed_data_rows(self) -> Optional[Dict[str, Any]]:
+    def failed_data_rows(  # type: ignore
+            self) -> Optional[List[Dict[str, Any]]]:
         return self.errors
 
     @property

@@ -253,7 +262,7 @@ def errors_all(self) -> PaginatedCollection:
         return self._download_errors_paginated()
 
     def _download_results_paginated(self) -> PaginatedCollection:
-        page_size = 900  # hardcode to avoid overloading the server
+        page_size = DOWNLOAD_RESULT_PAGE_SIZE
         from_cursor = None
 
         query_str = """query SuccessesfulDataRowImportsPyApi($taskId: ID!, $first: Int, $from: String) {

@@ -292,7 +301,7 @@ def _download_results_paginated(self) -> PaginatedCollection:
         )
 
     def _download_errors_paginated(self) -> PaginatedCollection:
-        page_size = 5000  # hardcode to avoid overloading the server
+        page_size = DOWNLOAD_RESULT_PAGE_SIZE  # hardcode to avoid overloading the server
         from_cursor = None
 
         query_str = """query FailedDataRowImportsPyApi($taskId: ID!, $first: Int, $from: String) {

@@ -306,6 +315,16 @@ def _download_errors_paginated(self) -> PaginatedCollection:
                         externalId
                         globalKey
                         rowData
+                        metadata {
+                            schemaId
+                            value
+                            name
+                        }
+                        attachments {
+                            type
+                            value
+                            name
+                        }
                     }
                 }
             }

@@ -318,28 +337,30 @@ def _download_errors_paginated(self) -> PaginatedCollection:
             'from': from_cursor,
         }
 
+        def convert_errors_to_legacy_format(client, data_row):
+            spec = data_row.get('spec', {})
+            return {
+                'message':
+                    data_row.get('message'),
+                'failedDataRows': [{
+                    'externalId': spec.get('externalId'),
+                    'rowData': spec.get('rowData'),
+                    'globalKey': spec.get('globalKey'),
+                    'metadata': spec.get('metadata', []),
+                    'attachments': spec.get('attachments', []),
+                }]
+            }
+
         return PaginatedCollection(
             client=self.client,
             query=query_str,
             params=params,
             dereferencing=['failedDataRowImports', 'results'],
-            obj_class=lambda _, data_row: {
-                'error':
-                    data_row.get('message'),
-                'external_id':
-                    data_row.get('spec').get('externalId')
-                    if data_row.get('spec') else None,
-                'row_data':
-                    data_row.get('spec').get('rowData')
-                    if data_row.get('spec') else None,
-                'global_key':
-                    data_row.get('spec').get('globalKey')
-                    if data_row.get('spec') else None,
-            },
+            obj_class=convert_errors_to_legacy_format,
             cursor_path=['failedDataRowImports', 'after'],
         )
 
-    def _results_as_list(self) -> List[Dict[str, Any]]:
+    def _results_as_list(self) -> Optional[List[Dict[str, Any]]]:
         total_downloaded = 0
         results = []
         data = self._download_results_paginated()

@@ -350,9 +371,12 @@ def _results_as_list(self) -> List[Dict[str, Any]]:
             if total_downloaded >= self.__max_donwload_size:
                 break
 
+        if len(results) == 0:
+            return None
+
         return results
 
-    def _errors_as_list(self) -> List[Dict[str, Any]]:
+    def _errors_as_list(self) -> Optional[List[Dict[str, Any]]]:
         total_downloaded = 0
         errors = []
         data = self._download_errors_paginated()

@@ -363,4 +387,7 @@ def _errors_as_list(self) -> List[Dict[str, Any]]:
             if total_downloaded >= self.__max_donwload_size:
                 break
 
+        if len(errors) == 0:
+            return None
+
         return errors
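
The inline lambda that reshaped failed imports is replaced by convert_errors_to_legacy_format, which emits the legacy message / failedDataRows payload and now carries metadata and attachments, matching the extra fields added to the GraphQL query. A standalone sketch of the mapping, mirroring the committed helper on a made-up record:

# Mirrors the committed helper; the sample payload below is made up.
def convert_errors_to_legacy_format(client, data_row):
    spec = data_row.get('spec', {})
    return {
        'message': data_row.get('message'),
        'failedDataRows': [{
            'externalId': spec.get('externalId'),
            'rowData': spec.get('rowData'),
            'globalKey': spec.get('globalKey'),
            'metadata': spec.get('metadata', []),
            'attachments': spec.get('attachments', []),
        }]
    }

sample = {
    'message': "Duplicate global key: 'gk-1'",
    'spec': {'externalId': 'row-1', 'rowData': 'https://example.com/a.jpg'},
}
print(convert_errors_to_legacy_format(None, sample))

Note also that _results_as_list and _errors_as_list now return None instead of an empty list, which is why callers such as the tests below can assert task.errors is None.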

libs/labelbox/tests/integration/test_data_rows.py

Lines changed: 16 additions & 5 deletions

@@ -238,7 +238,7 @@ def test_data_row_bulk_creation_from_file(dataset, local_image_file, image_url):
     assert task.has_errors() is False
     results = [r for r in task.result_all]
     row_data = [result["row_data"] for result in results]
-    assert row_data == [image_url, image_url]
+    assert len(row_data) == 2
 
 
 def test_data_row_bulk_creation_from_row_data_file_external_id(

@@ -252,12 +252,14 @@ def test_data_row_bulk_creation_from_row_data_file_external_id(
         "row_data": image_url,
         'external_id': 'some_name2'
     }])
+    task.wait_till_done()
     assert task.status == "COMPLETE"
     assert len(task.result) == 2
     assert task.has_errors() is False
     results = [r for r in task.result_all]
     row_data = [result["row_data"] for result in results]
-    assert row_data == [image_url, image_url]
+    assert len(row_data) == 2
+    assert image_url in row_data
 
 
 def test_data_row_bulk_creation_from_row_data_file(dataset, rand_gen,

@@ -275,7 +277,7 @@ def test_data_row_bulk_creation_from_row_data_file(dataset, rand_gen,
     assert task.has_errors() is False
     results = [r for r in task.result_all]
     row_data = [result["row_data"] for result in results]
-    assert row_data == [image_url, image_url]
+    assert len(row_data) == 2
 
 
 @pytest.mark.slow

@@ -899,6 +901,7 @@ def test_create_data_rows_result(client, dataset, image_url):
             DataRow.external_id: "row1",
         },
     ])
+    task.wait_till_done()
     assert task.errors is None
     for result in task.result:
         client.get_data_row(result['id'])

@@ -973,8 +976,16 @@ def test_data_row_bulk_creation_with_same_global_keys(dataset, sample_image,
         'message'] == f"Duplicate global key: '{global_key_1}'"
     assert task.failed_data_rows[0]['failedDataRows'][0][
         'externalId'] == sample_image
-    assert task.created_data_rows[0]['externalId'] == sample_image
-    assert task.created_data_rows[0]['globalKey'] == global_key_1
+    assert task.created_data_rows[0]['external_id'] == sample_image
+    assert task.created_data_rows[0]['global_key'] == global_key_1
+
+    errors = task.errors_all
+    all_errors = [er for er in errors]
+    assert len(all_errors) == 1
+    assert task.has_errors() is True
+
+    all_results = [result for result in task.result_all]
+    assert len(all_results) == 1
 
 
 def test_data_row_delete_and_create_with_same_global_key(
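
These assertion changes follow from the Task changes above: results and errors are read only after task.wait_till_done(), the created_data_rows payload now uses snake_case keys (external_id, global_key), and the tests assert counts and membership rather than exact ordering. A hedged sketch of the adjusted pattern, assuming the same dataset and image_url pytest fixtures these tests already use:

# Assumes the `dataset` and `image_url` pytest fixtures from this test module.
task = dataset.create_data_rows([{"row_data": image_url}, {"row_data": image_url}])
task.wait_till_done()                      # finish before reading result/errors
assert task.errors is None                 # errors is None when nothing failed
row_data = [r["row_data"] for r in task.result_all]
assert len(row_data) == 2                  # assert on counts, not exact ordering
assert image_url in row_data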

libs/labelbox/tests/integration/test_data_rows_upsert.py

Lines changed: 2 additions & 3 deletions

@@ -208,9 +208,8 @@ def test_multiple_chunks(self, client, dataset, image_url):
         mocked_chunk_size = 3
         with patch('labelbox.client.Client.upload_data',
                    wraps=client.upload_data) as spy_some_function:
-            with patch(
-                    'labelbox.schema.dataset.Dataset._Dataset__upsert_chunk_size',
-                    new=mocked_chunk_size):
+            with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE',
+                       new=mocked_chunk_size):
                 task = dataset.upsert_data_rows([{
                     'row_data': image_url
                 } for i in range(10)])
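
The patch target moves because the chunk size is now the module-level UPSERT_CHUNK_SIZE constant rather than a name-mangled private attribute on Dataset. A minimal illustration of patching a constant by its dotted import path (assumes the labelbox package is importable; the value 3 is arbitrary):

# unittest.mock.patch replaces the attribute at its import location for the
# duration of the with-block, then restores the original value.
from unittest.mock import patch

import labelbox.schema.dataset as dataset_module

with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE', new=3):
    assert dataset_module.UPSERT_CHUNK_SIZE == 3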

libs/labelbox/tests/integration/test_dataset.py

Lines changed: 1 addition & 10 deletions

@@ -3,7 +3,7 @@
 from labelbox import Dataset
 from labelbox.exceptions import ResourceNotFoundError, MalformedQueryException, InvalidQueryError
 from labelbox.schema.dataset import MAX_DATAROW_PER_API_OPERATION
-from labelbox.schema.internal.datarow_uploader import DataRowUploader
+from labelbox.schema.internal.data_row_uploader import DataRowUploader
 
 
 def test_dataset(client, rand_gen):

@@ -166,12 +166,3 @@ def test_create_descriptor_file(dataset):
         'content_type': 'application/json',
         'filename': 'json_import.json'
     }
-
-
-def test_max_dataset_datarow_upload(dataset, image_url, rand_gen):
-    external_id = str(rand_gen)
-    items = [dict(row_data=image_url, external_id=external_id)
-            ] * (MAX_DATAROW_PER_API_OPERATION + 1)
-
-    with pytest.raises(MalformedQueryException):
-        dataset.create_data_rows(items)

libs/labelbox/tests/integration/test_task.py

Lines changed: 3 additions & 2 deletions

@@ -61,11 +61,12 @@ def test_task_success_json(dataset, image_url, snapshot):
 @pytest.mark.export_v1("export_v1 test remove later")
 def test_task_success_label_export(client, configured_project_with_label):
     project, _, _, _ = configured_project_with_label
-    project.export_labels()
+    # TODO: Move to export_v2
+    res = project.export_labels()
     user = client.get_user()
     task = None
     for task in user.created_tasks():
-        if task.name != 'JSON Import':
+        if task.name != 'JSON Import' and task.type != 'adv-upsert-data-rows':
             break
 
     with pytest.raises(ValueError) as exc_info:
