
Commit e2b7e3b

Author: Val Brodsky (committed)
Refactor Dataset create_data_row to use upsert
1 parent e6c480a commit e2b7e3b
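
For reviewers, the externally visible contract is unchanged: create_data_row still accepts a dictionary or kwargs containing `row_data` at minimum and still returns a DataRow; only the transport moves from a one-off createDataRow GraphQL mutation to the shared upsert task pipeline. A minimal usage sketch follows (the API key, dataset id, global key, and asset URL are placeholders, not values from this commit):

    from labelbox import Client

    client = Client(api_key="<YOUR_API_KEY>")      # placeholder credentials
    dataset = client.get_dataset("<DATASET_UID>")  # placeholder dataset id

    # Same call shape as before the refactor; internally it now builds a
    # DataRowCreateItem spec and runs it through _exec_upsert_data_rows.
    data_row = dataset.create_data_row(
        row_data="https://example.com/image.png",  # placeholder asset URL
        global_key="example-key-001",              # placeholder global key
    )
    print(data_row.uid)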

File tree: 1 file changed, +22 -69 lines

libs/labelbox/src/labelbox/schema/dataset.py (22 additions & 69 deletions)
@@ -141,78 +141,31 @@ def create_data_row(self, items=None, **kwargs) -> "DataRow":
                 any of the field names given in `kwargs`.

         """
-        invalid_argument_error = "Argument to create_data_row() must be either a dictionary, or kwargs containing `row_data` at minimum"
-
-        def convert_field_keys(items):
-            if not isinstance(items, dict):
-                raise InvalidQueryError(invalid_argument_error)
-            return {
-                key.name if isinstance(key, Field) else key: value
-                for key, value in items.items()
-            }
+        file_upload_thread_count = 1
+        args = items if items is not None else kwargs

-        if items is not None and len(kwargs) > 0:
-            raise InvalidQueryError(invalid_argument_error)
+        upload_items = self._separate_and_process_items([args])
+        specs = DataRowCreateItem.build(self.uid, upload_items)
+        task: DataUpsertTask = self._exec_upsert_data_rows(
+            specs, file_upload_thread_count)

-        DataRow = Entity.DataRow
-        args = convert_field_keys(items) if items is not None else kwargs
-
-        if DataRow.row_data.name not in args:
-            raise InvalidQueryError(
-                "DataRow.row_data missing when creating DataRow.")
-
-        row_data = args[DataRow.row_data.name]
-
-        if isinstance(row_data, str) and row_data.startswith("s3:/"):
-            raise InvalidQueryError(
-                "row_data: s3 assets must start with 'https'.")
-
-        if not isinstance(row_data, str):
-            # If the row data is an object, upload as a string
-            args[DataRow.row_data.name] = json.dumps(row_data)
-        elif os.path.exists(row_data):
-            # If row data is a local file path, upload it to server.
-            args[DataRow.row_data.name] = self.client.upload_file(row_data)
-
-        # Parse metadata fields, if they are provided
-        if DataRow.metadata_fields.name in args:
-            mdo = self.client.get_data_row_metadata_ontology()
-            args[DataRow.metadata_fields.name] = mdo.parse_upsert_metadata(
-                args[DataRow.metadata_fields.name])
-
-        if "embeddings" in args:
-            args["embeddings"] = [
-                EmbeddingVector(**e).to_gql() for e in args["embeddings"]
-            ]
+        task.wait_till_done()

-        query_str = """mutation CreateDataRowPyApi(
-            $row_data: String!,
-            $metadata_fields: [DataRowCustomMetadataUpsertInput!],
-            $attachments: [DataRowAttachmentInput!],
-            $media_type : MediaType,
-            $external_id : String,
-            $global_key : String,
-            $dataset: ID!,
-            $embeddings: [DataRowEmbeddingVectorInput!]
-        ){
-            createDataRow(
-                data:
-                  {
-                    rowData: $row_data
-                    mediaType: $media_type
-                    metadataFields: $metadata_fields
-                    externalId: $external_id
-                    globalKey: $global_key
-                    attachments: $attachments
-                    dataset: {connect: {id: $dataset}}
-                    embeddings: $embeddings
-                  }
-            )
-            {%s}
-        }
-        """ % query.results_query_part(Entity.DataRow)
-        res = self.client.execute(query_str, {**args, 'dataset': self.uid})
-        return DataRow(self.client, res['createDataRow'])
+        if task.has_errors():
+            raise ResourceCreationError(
+                f"Data row upload errors: {task.errors}", cause=task.uid)
+        if task.status != "COMPLETE":
+            raise ResourceCreationError(
+                f"Data row upload did not complete, task status {task.status} task id {task.uid}"
+            )
+
+        res = task.result
+        if res is None or len(res) == 0:
+            raise ResourceCreationError(
+                f"Data row upload did not complete, task status {task.status} task id {task.uid}"
+            )
+
+        return self.client.get_data_row(res[0]['id'])

     def create_data_rows_sync(
         self,