Refactor Dataset create_data_rows_sync to use upsert #1694

Merged · 13 commits · Jul 1, 2024
143 changes: 56 additions & 87 deletions libs/labelbox/src/labelbox/schema/dataset.py
@@ -15,7 +15,7 @@
from io import StringIO
import requests

from labelbox.exceptions import InvalidQueryError, LabelboxError, ResourceNotFoundError, InvalidAttributeError
from labelbox.exceptions import InvalidQueryError, LabelboxError, ResourceNotFoundError, ResourceCreationError
from labelbox.orm.comparison import Comparison
from labelbox.orm.db_object import DbObject, Updateable, Deletable, experimental
from labelbox.orm.model import Entity, Field, Relationship
@@ -124,7 +124,6 @@ def data_rows(

def create_data_row(self, items=None, **kwargs) -> "DataRow":
""" Creates a single DataRow belonging to this dataset.

>>> dataset.create_data_row(row_data="http://my_site.com/photos/img_01.jpg")

Args:
@@ -139,82 +138,31 @@ def create_data_row(self, items=None, **kwargs) -> "DataRow":
in `kwargs`.
InvalidAttributeError: in case the DB object type does not contain
any of the field names given in `kwargs`.

ResourceCreationError: If data row creation failed on the server side.
"""
invalid_argument_error = "Argument to create_data_row() must be either a dictionary, or kwargs containing `row_data` at minimum"

def convert_field_keys(items):
if not isinstance(items, dict):
raise InvalidQueryError(invalid_argument_error)
return {
key.name if isinstance(key, Field) else key: value
for key, value in items.items()
}

if items is not None and len(kwargs) > 0:
raise InvalidQueryError(invalid_argument_error)

DataRow = Entity.DataRow
args = convert_field_keys(items) if items is not None else kwargs

if DataRow.row_data.name not in args:
raise InvalidQueryError(
"DataRow.row_data missing when creating DataRow.")

row_data = args[DataRow.row_data.name]

if isinstance(row_data, str) and row_data.startswith("s3:/"):
raise InvalidQueryError(
"row_data: s3 assets must start with 'https'.")

if not isinstance(row_data, str):
# If the row data is an object, upload as a string
args[DataRow.row_data.name] = json.dumps(row_data)
elif os.path.exists(row_data):
# If row data is a local file path, upload it to server.
args[DataRow.row_data.name] = self.client.upload_file(row_data)

# Parse metadata fields, if they are provided
if DataRow.metadata_fields.name in args:
mdo = self.client.get_data_row_metadata_ontology()
args[DataRow.metadata_fields.name] = mdo.parse_upsert_metadata(
args[DataRow.metadata_fields.name])

if "embeddings" in args:
args["embeddings"] = [
EmbeddingVector(**e).to_gql() for e in args["embeddings"]
]
args = items if items is not None else kwargs

query_str = """mutation CreateDataRowPyApi(
$row_data: String!,
$metadata_fields: [DataRowCustomMetadataUpsertInput!],
$attachments: [DataRowAttachmentInput!],
$media_type : MediaType,
$external_id : String,
$global_key : String,
$dataset: ID!,
$embeddings: [DataRowEmbeddingVectorInput!]
){
createDataRow(
data:
{
rowData: $row_data
mediaType: $media_type
metadataFields: $metadata_fields
externalId: $external_id
globalKey: $global_key
attachments: $attachments
dataset: {connect: {id: $dataset}}
embeddings: $embeddings
}
)
{%s}
}
""" % query.results_query_part(Entity.DataRow)
res = self.client.execute(query_str, {**args, 'dataset': self.uid})
return DataRow(self.client, res['createDataRow'])
file_upload_thread_count = 1
completed_task = self._create_data_rows_sync(
[args], file_upload_thread_count=file_upload_thread_count)

def create_data_rows_sync(self, items) -> None:
res = completed_task.result
if res is None or len(res) == 0:
raise ResourceCreationError(
f"Data row upload did not complete, task status {completed_task.status} task id {completed_task.uid}"
)

return self.client.get_data_row(res[0]['id'])
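The refactored create_data_row now routes a single item through the synchronous upsert path and re-fetches the created row. A minimal call-site sketch, assuming an authenticated client and a placeholder dataset ID (neither is taken from this PR):

```python
from labelbox import Client

client = Client(api_key="<YOUR_API_KEY>")     # hypothetical API key
dataset = client.get_dataset("<DATASET_ID>")  # placeholder dataset ID

# Internally this builds a one-item batch, runs _create_data_rows_sync, and
# returns the hydrated DataRow fetched via client.get_data_row().
data_row = dataset.create_data_row(
    row_data="http://my_site.com/photos/img_01.jpg",
    global_key="img_01",
)
print(data_row.uid)
```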

def create_data_rows_sync(
self,
items,
file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT) -> None:
""" Synchronously bulk upload data rows.

Use this instead of `Dataset.create_data_rows` for smaller batches of data rows that need to be uploaded quickly.
@@ -228,32 +176,49 @@ def create_data_rows_sync(self, items) -> None:
None. If the function doesn't raise an exception then the import was successful.

Raises:
InvalidQueryError: If the `items` parameter does not conform to
ResourceCreationError: If the `items` parameter does not conform to
the specification in Dataset._create_descriptor_file or if the server did not accept the
DataRow creation request (unknown reason).
InvalidAttributeError: If there are fields in `items` not valid for
a DataRow.
ValueError: When the upload parameters are invalid
"""
warnings.warn(
"This method is deprecated and will be "
"removed in a future release. Please use create_data_rows instead.")

self._create_data_rows_sync(
items, file_upload_thread_count=file_upload_thread_count)

return None # Return None if no exception is raised
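
Callers of the now-deprecated create_data_rows_sync can switch to the upsert-based bulk path. A hedged migration sketch, assuming dataset is an existing Dataset and items is a list of data-row dicts prepared by the caller:

```python
items = [{"row_data": "https://example.com/img_01.jpg", "global_key": "img_01"}]

# Deprecated synchronous wrapper: still works, but emits a deprecation warning
# and simply delegates to _create_data_rows_sync.
dataset.create_data_rows_sync(items)

# Preferred path after this refactor: the bulk upsert returns a DataUpsertTask
# whose completion and errors the caller inspects explicitly.
task = dataset.create_data_rows(items)
task.wait_till_done()
if task.has_errors():
    raise RuntimeError(f"Upload failed: {task.errors}")
```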

def _create_data_rows_sync(self,
items,
file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT
) -> "DataUpsertTask":
max_data_rows_supported = 1000
max_attachments_per_data_row = 5
if len(items) > max_data_rows_supported:
raise ValueError(
f"Dataset.create_data_rows_sync() supports a max of {max_data_rows_supported} data rows."
" For larger imports use the async function Dataset.create_data_rows()"
)
descriptor_url = DescriptorFileCreator(self.client).create_one(
items, max_attachments_per_data_row=max_attachments_per_data_row)
dataset_param = "datasetId"
url_param = "jsonUrl"
query_str = """mutation AppendRowsToDatasetSyncPyApi($%s: ID!, $%s: String!){
appendRowsToDatasetSync(data:{datasetId: $%s, jsonFileUrl: $%s}
){dataset{id}}} """ % (dataset_param, url_param, dataset_param,
url_param)
self.client.execute(query_str, {
dataset_param: self.uid,
url_param: descriptor_url
})
if file_upload_thread_count < 1:
raise ValueError(
"file_upload_thread_count must be a positive integer")

task: DataUpsertTask = self.create_data_rows(items,
file_upload_thread_count)
task.wait_till_done()
Reviewer comment (Contributor): nit: Can this call create_data_rows to obtain a DataUpsertTask to avoid code duplication?


if task.has_errors():
raise ResourceCreationError(
f"Data row upload errors: {task.errors}", cause=task.uid)
if task.status != "COMPLETE":
raise ResourceCreationError(
f"Data row upload did not complete, task status {task.status} task id {task.uid}"
)

return task

def create_data_rows(self,
items,
@@ -287,14 +252,18 @@ def create_data_rows(self,
raise ValueError(
"file_upload_thread_count must be a positive integer")

# Separate local-path items from dict items, build upsert specs, and execute the upsert
upload_items = self._separate_and_process_items(items)
specs = DataRowCreateItem.build(self.uid, upload_items)
return self._exec_upsert_data_rows(specs, file_upload_thread_count)

def _separate_and_process_items(self, items):
string_items = [item for item in items if isinstance(item, str)]
dict_items = [item for item in items if isinstance(item, dict)]
dict_string_items = []
if len(string_items) > 0:
dict_string_items = self._build_from_local_paths(string_items)
specs = DataRowCreateItem.build(self.uid,
dict_items + dict_string_items)
return self._exec_upsert_data_rows(specs, file_upload_thread_count)
return dict_items + dict_string_items
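
Because _separate_and_process_items splits string entries (treated as local file paths) from dict entries, create_data_rows accepts a mixed list. An illustrative sketch with placeholder paths and keys:

```python
mixed_items = [
    "/tmp/photo_01.jpg",  # local path: converted by _build_from_local_paths
    {
        "row_data": "https://example.com/photos/img_02.jpg",  # hosted asset
        "global_key": "img_02",
    },
]

task = dataset.create_data_rows(mixed_items)
task.wait_till_done()
print(task.status)
```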

def _build_from_local_paths(
self,
@@ -33,26 +33,10 @@ class DescriptorFileCreator:

def __init__(self, client: "Client"):
self.client = client
""""
This method is used to convert a list to json and upload it in a file to gcs.
It will create multiple files if the size of upload is greater than max_chunk_size_bytes in bytes,
It uploads the files to gcs in parallel, and return a list of urls

Args:
items: The list to upload
is_upsert (bool): Whether the upload is an upsert
max_attachments_per_data_row (int): The maximum number of attachments per data row
max_chunk_size_bytes (int): The maximum size of the file in bytes
"""

def create(self,
items,
max_attachments_per_data_row=None,
max_chunk_size_bytes=None) -> List[str]:
def create(self, items, max_chunk_size_bytes=None) -> List[str]:
is_upsert = True # This class will only support upsert use cases
items = self._prepare_items_for_upload(items,
max_attachments_per_data_row,
is_upsert=is_upsert)
items = self._prepare_items_for_upload(items, is_upsert=is_upsert)
json_chunks = self._chunk_down_by_bytes(items, max_chunk_size_bytes)
with ThreadPoolExecutor(FILE_UPLOAD_THREAD_COUNT) as executor:
futures = [
@@ -62,19 +46,15 @@ def create(self,
]
return [future.result() for future in as_completed(futures)]

def create_one(self, items, max_attachments_per_data_row=None) -> List[str]:
items = self._prepare_items_for_upload(items,
max_attachments_per_data_row)
def create_one(self, items) -> List[str]:
items = self._prepare_items_for_upload(items)
# Prepare and upload the descriptor file
data = json.dumps(items)
return self.client.upload_data(data,
content_type="application/json",
filename="json_import.json")

def _prepare_items_for_upload(self,
items,
max_attachments_per_data_row=None,
is_upsert=False):
def _prepare_items_for_upload(self, items, is_upsert=False):
"""
This function is used to prepare the input file. The user-defined input is validated, processed, and JSON-stringified.
Finally, the JSON data is uploaded to GCS and a URI is returned. This URI can be passed as a parameter to a mutation that uploads data rows.
@@ -102,8 +82,6 @@

Args:
items (iterable of (dict or str)): See above for details.
max_attachments_per_data_row (Optional[int]): Param used during attachment validation to determine
if the user has provided too many attachments.

Returns:
uri (string): A reference to the uploaded json data.
@@ -137,12 +115,6 @@ def validate_attachments(item):
attachments = item.get('attachments')
if attachments:
if isinstance(attachments, list):
if max_attachments_per_data_row and len(
attachments) > max_attachments_per_data_row:
raise ValueError(
f"Max attachments number of supported attachments per data row is {max_attachments_per_data_row}."
f" Found {len(attachments)}. Condense multiple attachments into one with the HTML attachment type if necessary."
)
for attachment in attachments:
AssetAttachment.validate_attachment_json(attachment)
else:
13 changes: 8 additions & 5 deletions libs/labelbox/tests/data/annotation_import/conftest.py
@@ -622,12 +622,15 @@ def configured_project(client, initial_dataset, ontology, rand_gen, image_url):
data_row_ids = []

ontologies = ontology["tools"] + ontology["classifications"]
data_row_data = []
for ind in range(len(ontologies)):
data_row_ids.append(
dataset.create_data_row(
row_data=image_url,
global_key=f"gk_{ontologies[ind]['name']}_{rand_gen(str)}",
).uid)
data_row_data.append({
"row_data": image_url,
"global_key": f"gk_{ontologies[ind]['name']}_{rand_gen(str)}"
})
task = dataset.create_data_rows(data_row_data)
task.wait_till_done()
data_row_ids = [row['id'] for row in task.result]
project._wait_until_data_rows_are_processed(data_row_ids=data_row_ids,
sleep_interval=3)
