
Commit e6c480a

Author: Val Brodsky (committed)
Refactor Dataset create_data_rows_sync to use upsert
1 parent: a97300e · commit: e6c480a

File tree: 2 files changed (+37, -23 lines)

libs/labelbox/src/labelbox/schema/dataset.py

Lines changed: 32 additions & 17 deletions
@@ -15,7 +15,7 @@
 from io import StringIO
 import requests
 
-from labelbox.exceptions import InvalidQueryError, LabelboxError, ResourceNotFoundError, InvalidAttributeError
+from labelbox.exceptions import InvalidQueryError, LabelboxError, ResourceNotFoundError, ResourceCreationError
 from labelbox.orm.comparison import Comparison
 from labelbox.orm.db_object import DbObject, Updateable, Deletable, experimental
 from labelbox.orm.model import Entity, Field, Relationship
@@ -214,7 +214,10 @@ def convert_field_keys(items):
         res = self.client.execute(query_str, {**args, 'dataset': self.uid})
         return DataRow(self.client, res['createDataRow'])
 
-    def create_data_rows_sync(self, items) -> None:
+    def create_data_rows_sync(
+            self,
+            items,
+            file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT) -> None:
         """ Synchronously bulk upload data rows.
 
         Use this instead of `Dataset.create_data_rows` for smaller batches of data rows that need to be uploaded quickly.
@@ -228,6 +231,7 @@ def create_data_rows_sync(self, items) -> None:
             None. If the function doesn't raise an exception then the import was successful.
 
         Raises:
+            ResourceCreationError: Errors in data row upload
             InvalidQueryError: If the `items` parameter does not conform to
                 the specification in Dataset._create_descriptor_file or if the server did not accept the
                 DataRow creation request (unknown reason).
@@ -242,18 +246,25 @@ def create_data_rows_sync(self, items) -> None:
                 f"Dataset.create_data_rows_sync() supports a max of {max_data_rows_supported} data rows."
                 " For larger imports use the async function Dataset.create_data_rows()"
             )
-        descriptor_url = DescriptorFileCreator(self.client).create_one(
-            items, max_attachments_per_data_row=max_attachments_per_data_row)
-        dataset_param = "datasetId"
-        url_param = "jsonUrl"
-        query_str = """mutation AppendRowsToDatasetSyncPyApi($%s: ID!, $%s: String!){
-            appendRowsToDatasetSync(data:{datasetId: $%s, jsonFileUrl: $%s}
-            ){dataset{id}}} """ % (dataset_param, url_param, dataset_param,
-                                   url_param)
-        self.client.execute(query_str, {
-            dataset_param: self.uid,
-            url_param: descriptor_url
-        })
+        if file_upload_thread_count < 1:
+            raise ValueError(
+                "file_upload_thread_count must be a positive integer")
+
+        upload_items = self._separate_and_process_items(items)
+        specs = DataRowCreateItem.build(self.uid, upload_items)
+        task: DataUpsertTask = self._exec_upsert_data_rows(
+            specs, file_upload_thread_count)
+        task.wait_till_done()
+
+        if task.has_errors():
+            raise ResourceCreationError(
+                f"Data row upload errors: {task.errors}", cause=task.uid)
+        if task.status != "COMPLETE":
+            raise ResourceCreationError(
+                f"Data row upload did not complete, task status {task.status} task id {task.uid}"
+            )
+
+        return None
 
     def create_data_rows(self,
                          items,
@@ -287,14 +298,18 @@ def create_data_rows(self,
             raise ValueError(
                 "file_upload_thread_count must be a positive integer")
 
+        # Usage example
+        upload_items = self._separate_and_process_items(items)
+        specs = DataRowCreateItem.build(self.uid, upload_items)
+        return self._exec_upsert_data_rows(specs, file_upload_thread_count)
+
+    def _separate_and_process_items(self, items):
         string_items = [item for item in items if isinstance(item, str)]
         dict_items = [item for item in items if isinstance(item, dict)]
         dict_string_items = []
        if len(string_items) > 0:
             dict_string_items = self._build_from_local_paths(string_items)
-        specs = DataRowCreateItem.build(self.uid,
-                                        dict_items + dict_string_items)
-        return self._exec_upsert_data_rows(specs, file_upload_thread_count)
+        return dict_items + dict_string_items
 
     def _build_from_local_paths(
             self,
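
After this change, create_data_rows_sync builds DataRowCreateItem specs and runs them through the same upsert task machinery as create_data_rows, waits for the task, and raises ResourceCreationError if the task reports errors or does not finish with status COMPLETE. A minimal usage sketch of the refactored sync path, assuming a configured Labelbox client; the API key, dataset name, and image URLs below are placeholders:

    import labelbox as lb
    from labelbox.exceptions import ResourceCreationError

    client = lb.Client(api_key="<YOUR_API_KEY>")               # placeholder key
    dataset = client.create_dataset(name="sync-upload-demo")   # hypothetical dataset name

    try:
        # Small batches only: the method still enforces a maximum row count and
        # points larger imports to the async create_data_rows().
        dataset.create_data_rows_sync([
            {"row_data": "https://example.com/image-1.jpg", "global_key": "img-1"},
            {"row_data": "https://example.com/image-2.jpg", "global_key": "img-2"},
        ])
    except ResourceCreationError as err:
        # The message embeds task.errors; the upsert task id is passed as `cause`.
        print("Upload failed:", err)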

libs/labelbox/tests/integration/test_data_rows.py

Lines changed: 5 additions & 6 deletions
@@ -10,10 +10,9 @@
 
 from labelbox.schema.media_type import MediaType
 from labelbox import DataRow, AssetAttachment
-from labelbox.exceptions import MalformedQueryException
-from labelbox.schema.task import Task
+from labelbox.exceptions import MalformedQueryException, ResourceCreationError
+from labelbox.schema.task import Task, DataUpsertTask
 from labelbox.schema.data_row_metadata import DataRowMetadataField, DataRowMetadataKind
-import labelbox.exceptions
 
 SPLIT_SCHEMA_ID = "cko8sbczn0002h2dkdaxb5kal"
 TEST_SPLIT_ID = "cko8scbz70005h2dkastwhgqt"
@@ -1050,7 +1049,7 @@ def test_data_row_bulk_creation_sync_with_same_global_keys(
         dataset, sample_image):
     global_key_1 = str(uuid.uuid4())
 
-    with pytest.raises(labelbox.exceptions.MalformedQueryException) as exc_info:
+    with pytest.raises(ResourceCreationError) as exc_info:
         dataset.create_data_rows_sync([{
             DataRow.row_data: sample_image,
             DataRow.global_key: global_key_1
@@ -1061,8 +1060,8 @@ def test_data_row_bulk_creation_sync_with_same_global_keys(
 
     assert len(list(dataset.data_rows())) == 1
     assert list(dataset.data_rows())[0].global_key == global_key_1
-    assert "Some data rows were not imported. Check error output here" in str(
-        exc_info.value)
+    assert "Duplicate global key" in str(exc_info.value)
+    assert exc_info.value.args[1]  # task id
 
 
 @pytest.fixture
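
The updated test captures the new failure mode: submitting duplicate global keys through the sync path now raises ResourceCreationError, with the per-row errors in the message and the upsert task id available as the second positional argument. A rough sketch of how calling code might inspect that error, reusing the dataset from the sketch above; the image URL is a placeholder:

    import uuid
    from labelbox import DataRow
    from labelbox.exceptions import ResourceCreationError

    global_key = str(uuid.uuid4())
    rows = [
        {DataRow.row_data: "https://example.com/sample.jpg", DataRow.global_key: global_key},
        {DataRow.row_data: "https://example.com/sample.jpg", DataRow.global_key: global_key},
    ]

    try:
        dataset.create_data_rows_sync(rows)   # `dataset` as created in the earlier sketch
    except ResourceCreationError as err:
        message, task_id = err.args[0], err.args[1]   # error text and upsert task id
        print(f"Upsert task {task_id} failed: {message}")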
