Commit 39a8b0b

Refactor Dataset create_data_rows_sync to use upsert (#1694)
1 parent d142f7b commit 39a8b0b

7 files changed: +98 -155 lines

libs/labelbox/src/labelbox/schema/dataset.py

Lines changed: 56 additions & 87 deletions
@@ -15,7 +15,7 @@
 from io import StringIO
 import requests
 
-from labelbox.exceptions import InvalidQueryError, LabelboxError, ResourceNotFoundError, InvalidAttributeError
+from labelbox.exceptions import InvalidQueryError, LabelboxError, ResourceNotFoundError, ResourceCreationError
 from labelbox.orm.comparison import Comparison
 from labelbox.orm.db_object import DbObject, Updateable, Deletable, experimental
 from labelbox.orm.model import Entity, Field, Relationship
@@ -124,7 +124,6 @@ def data_rows(
 
     def create_data_row(self, items=None, **kwargs) -> "DataRow":
         """ Creates a single DataRow belonging to this dataset.
-
         >>> dataset.create_data_row(row_data="http://my_site.com/photos/img_01.jpg")
 
         Args:
@@ -139,82 +138,31 @@ def create_data_row(self, items=None, **kwargs) -> "DataRow":
                 in `kwargs`.
             InvalidAttributeError: in case the DB object type does not contain
                 any of the field names given in `kwargs`.
-
+            ResourceCreationError: If data row creation failed on the server side.
         """
         invalid_argument_error = "Argument to create_data_row() must be either a dictionary, or kwargs containing `row_data` at minimum"
 
-        def convert_field_keys(items):
-            if not isinstance(items, dict):
-                raise InvalidQueryError(invalid_argument_error)
-            return {
-                key.name if isinstance(key, Field) else key: value
-                for key, value in items.items()
-            }
-
         if items is not None and len(kwargs) > 0:
             raise InvalidQueryError(invalid_argument_error)
 
-        DataRow = Entity.DataRow
-        args = convert_field_keys(items) if items is not None else kwargs
-
-        if DataRow.row_data.name not in args:
-            raise InvalidQueryError(
-                "DataRow.row_data missing when creating DataRow.")
-
-        row_data = args[DataRow.row_data.name]
-
-        if isinstance(row_data, str) and row_data.startswith("s3:/"):
-            raise InvalidQueryError(
-                "row_data: s3 assets must start with 'https'.")
-
-        if not isinstance(row_data, str):
-            # If the row data is an object, upload as a string
-            args[DataRow.row_data.name] = json.dumps(row_data)
-        elif os.path.exists(row_data):
-            # If row data is a local file path, upload it to server.
-            args[DataRow.row_data.name] = self.client.upload_file(row_data)
-
-        # Parse metadata fields, if they are provided
-        if DataRow.metadata_fields.name in args:
-            mdo = self.client.get_data_row_metadata_ontology()
-            args[DataRow.metadata_fields.name] = mdo.parse_upsert_metadata(
-                args[DataRow.metadata_fields.name])
-
-        if "embeddings" in args:
-            args["embeddings"] = [
-                EmbeddingVector(**e).to_gql() for e in args["embeddings"]
-            ]
+        args = items if items is not None else kwargs
 
-        query_str = """mutation CreateDataRowPyApi(
-            $row_data: String!,
-            $metadata_fields: [DataRowCustomMetadataUpsertInput!],
-            $attachments: [DataRowAttachmentInput!],
-            $media_type : MediaType,
-            $external_id : String,
-            $global_key : String,
-            $dataset: ID!,
-            $embeddings: [DataRowEmbeddingVectorInput!]
-            ){
-                createDataRow(
-                    data:
-                      {
-                        rowData: $row_data
-                        mediaType: $media_type
-                        metadataFields: $metadata_fields
-                        externalId: $external_id
-                        globalKey: $global_key
-                        attachments: $attachments
-                        dataset: {connect: {id: $dataset}}
-                        embeddings: $embeddings
-                    }
-                )
-                {%s}
-            }
-        """ % query.results_query_part(Entity.DataRow)
-        res = self.client.execute(query_str, {**args, 'dataset': self.uid})
-        return DataRow(self.client, res['createDataRow'])
+        file_upload_thread_count = 1
+        completed_task = self._create_data_rows_sync(
+            [args], file_upload_thread_count=file_upload_thread_count)
 
-    def create_data_rows_sync(self, items) -> None:
+        res = completed_task.result
+        if res is None or len(res) == 0:
+            raise ResourceCreationError(
+                f"Data row upload did not complete, task status {completed_task.status} task id {completed_task.uid}"
+            )
+
+        return self.client.get_data_row(res[0]['id'])
+
+    def create_data_rows_sync(
+            self,
+            items,
+            file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT) -> None:
         """ Synchronously bulk upload data rows.
 
         Use this instead of `Dataset.create_data_rows` for smaller batches of data rows that need to be uploaded quickly.
@@ -228,32 +176,49 @@ def create_data_rows_sync(self, items) -> None:
             None. If the function doesn't raise an exception then the import was successful.
 
         Raises:
-            InvalidQueryError: If the `items` parameter does not conform to
+            ResourceCreationError: If the `items` parameter does not conform to
                 the specification in Dataset._create_descriptor_file or if the server did not accept the
                 DataRow creation request (unknown reason).
             InvalidAttributeError: If there are fields in `items` not valid for
                 a DataRow.
             ValueError: When the upload parameters are invalid
         """
+        warnings.warn(
+            "This method is deprecated and will be "
+            "removed in a future release. Please use create_data_rows instead.")
+
+        self._create_data_rows_sync(
+            items, file_upload_thread_count=file_upload_thread_count)
+
+        return None  # Return None if no exception is raised
+
+    def _create_data_rows_sync(self,
+                               items,
+                               file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT
+                              ) -> "DataUpsertTask":
         max_data_rows_supported = 1000
-        max_attachments_per_data_row = 5
         if len(items) > max_data_rows_supported:
             raise ValueError(
                 f"Dataset.create_data_rows_sync() supports a max of {max_data_rows_supported} data rows."
                 " For larger imports use the async function Dataset.create_data_rows()"
             )
-        descriptor_url = DescriptorFileCreator(self.client).create_one(
-            items, max_attachments_per_data_row=max_attachments_per_data_row)
-        dataset_param = "datasetId"
-        url_param = "jsonUrl"
-        query_str = """mutation AppendRowsToDatasetSyncPyApi($%s: ID!, $%s: String!){
-            appendRowsToDatasetSync(data:{datasetId: $%s, jsonFileUrl: $%s}
-            ){dataset{id}}} """ % (dataset_param, url_param, dataset_param,
-                                   url_param)
-        self.client.execute(query_str, {
-            dataset_param: self.uid,
-            url_param: descriptor_url
-        })
+        if file_upload_thread_count < 1:
+            raise ValueError(
+                "file_upload_thread_count must be a positive integer")
+
+        task: DataUpsertTask = self.create_data_rows(items,
+                                                     file_upload_thread_count)
+        task.wait_till_done()
+
+        if task.has_errors():
+            raise ResourceCreationError(
+                f"Data row upload errors: {task.errors}", cause=task.uid)
+        if task.status != "COMPLETE":
+            raise ResourceCreationError(
+                f"Data row upload did not complete, task status {task.status} task id {task.uid}"
+            )
+
+        return task
 
     def create_data_rows(self,
                          items,
@@ -287,14 +252,18 @@ def create_data_rows(self,
             raise ValueError(
                 "file_upload_thread_count must be a positive integer")
 
+        # Usage example
+        upload_items = self._separate_and_process_items(items)
+        specs = DataRowCreateItem.build(self.uid, upload_items)
+        return self._exec_upsert_data_rows(specs, file_upload_thread_count)
+
+    def _separate_and_process_items(self, items):
         string_items = [item for item in items if isinstance(item, str)]
         dict_items = [item for item in items if isinstance(item, dict)]
         dict_string_items = []
         if len(string_items) > 0:
             dict_string_items = self._build_from_local_paths(string_items)
-        specs = DataRowCreateItem.build(self.uid,
-                                        dict_items + dict_string_items)
-        return self._exec_upsert_data_rows(specs, file_upload_thread_count)
+        return dict_items + dict_string_items
 
     def _build_from_local_paths(
         self,

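For context, a minimal usage sketch of the refactored call paths above. It is illustrative only (not part of the commit) and assumes a configured Labelbox API key and an existing dataset; the credentials, dataset UID, and URLs are placeholders.

import labelbox as lb

client = lb.Client(api_key="YOUR_API_KEY")        # placeholder credentials
dataset = client.get_dataset("YOUR_DATASET_UID")  # placeholder dataset id

# create_data_row now delegates to the upsert-based _create_data_rows_sync
# helper and returns the freshly fetched DataRow; server-side failures surface
# as ResourceCreationError rather than InvalidQueryError.
data_row = dataset.create_data_row(
    row_data="https://example.com/photos/img_01.jpg")

# create_data_rows_sync keeps its None-on-success contract, but it is now a
# deprecated wrapper that warns and routes through the same upsert task.
dataset.create_data_rows_sync(
    [{"row_data": "https://example.com/photos/img_02.jpg"}])
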
libs/labelbox/src/labelbox/schema/internal/descriptor_file_creator.py

Lines changed: 5 additions & 33 deletions
@@ -33,26 +33,10 @@ class DescriptorFileCreator:
 
     def __init__(self, client: "Client"):
         self.client = client
-        """"
-        This method is used to convert a list to json and upload it in a file to gcs.
-        It will create multiple files if the size of upload is greater than max_chunk_size_bytes in bytes,
-        It uploads the files to gcs in parallel, and return a list of urls
 
-        Args:
-            items: The list to upload
-            is_upsert (bool): Whether the upload is an upsert
-            max_attachments_per_data_row (int): The maximum number of attachments per data row
-            max_chunk_size_bytes (int): The maximum size of the file in bytes
-        """
-
-    def create(self,
-               items,
-               max_attachments_per_data_row=None,
-               max_chunk_size_bytes=None) -> List[str]:
+    def create(self, items, max_chunk_size_bytes=None) -> List[str]:
         is_upsert = True  # This class will only support upsert use cases
-        items = self._prepare_items_for_upload(items,
-                                               max_attachments_per_data_row,
-                                               is_upsert=is_upsert)
+        items = self._prepare_items_for_upload(items, is_upsert=is_upsert)
         json_chunks = self._chunk_down_by_bytes(items, max_chunk_size_bytes)
         with ThreadPoolExecutor(FILE_UPLOAD_THREAD_COUNT) as executor:
             futures = [
@@ -62,19 +46,15 @@ def create(self,
             ]
             return [future.result() for future in as_completed(futures)]
 
-    def create_one(self, items, max_attachments_per_data_row=None) -> List[str]:
-        items = self._prepare_items_for_upload(items,
-                                               max_attachments_per_data_row)
+    def create_one(self, items) -> List[str]:
+        items = self._prepare_items_for_upload(items,)
         # Prepare and upload the descriptor file
         data = json.dumps(items)
         return self.client.upload_data(data,
                                        content_type="application/json",
                                        filename="json_import.json")
 
-    def _prepare_items_for_upload(self,
-                                  items,
-                                  max_attachments_per_data_row=None,
-                                  is_upsert=False):
+    def _prepare_items_for_upload(self, items, is_upsert=False):
         """
         This function is used to prepare the input file. The user defined input is validated, processed, and json stringified.
         Finally the json data is uploaded to gcs and a uri is returned. This uri can be passed as a parameter to a mutation that uploads data rows
@@ -102,8 +82,6 @@ def _prepare_items_for_upload(self,
 
         Args:
             items (iterable of (dict or str)): See above for details.
-            max_attachments_per_data_row (Optional[int]): Param used during attachment validation to determine
-                if the user has provided too many attachments.
 
         Returns:
             uri (string): A reference to the uploaded json data.
@@ -137,12 +115,6 @@ def validate_attachments(item):
             attachments = item.get('attachments')
             if attachments:
                 if isinstance(attachments, list):
-                    if max_attachments_per_data_row and len(
-                            attachments) > max_attachments_per_data_row:
-                        raise ValueError(
-                            f"Max attachments number of supported attachments per data row is {max_attachments_per_data_row}."
-                            f" Found {len(attachments)}. Condense multiple attachments into one with the HTML attachment type if necessary."
-                        )
                     for attachment in attachments:
                         AssetAttachment.validate_attachment_json(attachment)
                 else:

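A hedged sketch of the slimmed-down DescriptorFileCreator interface after this change: the max_attachments_per_data_row parameter is removed from create, create_one, and _prepare_items_for_upload, so attachment counts are no longer validated client-side. The `client` object and item payload below are placeholders, not part of the commit.

from labelbox.schema.internal.descriptor_file_creator import DescriptorFileCreator

creator = DescriptorFileCreator(client)  # `client` is an existing labelbox.Client

items = [{"row_data": "https://example.com/photos/img_01.jpg"}]  # placeholder payload

# create_one uploads a single JSON descriptor file and returns a reference to it.
descriptor_url = creator.create_one(items)

# create splits large uploads into chunks of at most max_chunk_size_bytes and
# uploads them in parallel, returning one reference per chunk.
chunk_urls = creator.create(items, max_chunk_size_bytes=10_000_000)
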
libs/labelbox/tests/data/annotation_import/conftest.py

Lines changed: 8 additions & 5 deletions
@@ -622,12 +622,15 @@ def configured_project(client, initial_dataset, ontology, rand_gen, image_url):
     data_row_ids = []
 
     ontologies = ontology["tools"] + ontology["classifications"]
+    data_row_data = []
     for ind in range(len(ontologies)):
-        data_row_ids.append(
-            dataset.create_data_row(
-                row_data=image_url,
-                global_key=f"gk_{ontologies[ind]['name']}_{rand_gen(str)}",
-            ).uid)
+        data_row_data.append({
+            "row_data": image_url,
+            "global_key": f"gk_{ontologies[ind]['name']}_{rand_gen(str)}"
+        })
+    task = dataset.create_data_rows(data_row_data)
+    task.wait_till_done()
+    data_row_ids = [row['id'] for row in task.result]
     project._wait_until_data_rows_are_processed(data_row_ids=data_row_ids,
                                                 sleep_interval=3)
 

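The fixture change above replaces per-row create_data_row calls with one bulk create_data_rows upload. A minimal sketch of that pattern, with placeholder URLs and global keys and an existing `dataset`, looks like this:

data_row_data = [
    {"row_data": "https://example.com/img.jpg", "global_key": f"gk_{i}"}
    for i in range(3)
]  # placeholder rows

task = dataset.create_data_rows(data_row_data)  # returns a DataUpsertTask
task.wait_till_done()                           # block until the upsert finishes
data_row_ids = [row["id"] for row in task.result]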