Commit 31b90eb
Author: Matt Sokoloff

    attachments working

1 parent ebc9bf5 commit 31b90eb

File tree

2 files changed: +59 -15 lines


labelbox/schema/dataset.py

Lines changed: 58 additions & 15 deletions
@@ -1,3 +1,4 @@
+from labelbox import utils
 import os
 import json
 import logging
@@ -74,6 +75,13 @@ def create_data_row(self, **kwargs):
         return self.client._create(DataRow, kwargs)

     def create_data_rows(self, items):
+
+        ## NOTE TODOS
+        """
+        Add attachments (works with all types)
+        Add external ids to bulk imports
+        improved error handling (why job was accepted or not)
+        """
         """ Creates multiple DataRow objects based on the given `items`.

         Each element in `items` can be either a `str` or a `dict`. If
@@ -117,47 +125,82 @@ def create_data_rows(self, items):
         def upload_if_necessary(item):
             if isinstance(item, str):
                 item_url = self.client.upload_file(item)
-                # Convert item from str into a dict so it gets processed
-                # like all other dicts.
                 item = {DataRow.row_data: item_url, DataRow.external_id: item}
+            elif isinstance(item, dict):
+                if os.path.exists(item['row_data']):
+                    item_url = self.client.upload_file(item['row_data'])
+                    parts = {
+                        DataRow.row_data:
+                            item_url,
+                        DataRow.external_id:
+                            item.get('external_id', item['row_data'])
+                    }
+                    attachments = item.get('attachments')
+                    if attachments:
+                        item = {**parts, **{'attachments': attachments}}
+                    else:
+                        item = parts
             return item

-        with ThreadPoolExecutor(file_upload_thread_count) as executor:
-            futures = [
-                executor.submit(upload_if_necessary, item) for item in items
-            ]
-            items = [future.result() for future in as_completed(futures)]
+        def validate_attachments(item):
+            attachments = item.get('attachments')
+            if attachments:
+                if isinstance(attachments, list):
+                    for attachment in attachments:
+                        for required_key in ['type', 'value']:
+                            if required_key not in attachment:
+                                raise ValueError(
+                                    f"Must provide a `{required_key}` key for each attachment. Found {attachment}."
+                                )
+                        attachment_type = attachment.get('type')
+                        if attachment_type not in DataRow.supported_attachment_types:
+                            raise ValueError(
+                                f"meta_type must be one of {DataRow.supported_attachment_types}. Found {attachment_type}"
+                            )
+                else:
+                    raise ValueError(
+                        f"Attachments must be a list. Found {type(attachments)}"
+                    )
+            return attachments

         def convert_item(item):
             # Don't make any changes to tms data
+            validate_attachments(item)
             if "tileLayerUrl" in item:
                 return item
-            # Convert string names to fields.
+
+            item = upload_if_necessary(item)
+            # Convert fields to string names.
             item = {
-                key if isinstance(key, Field) else DataRow.field(key): value
+                key.name if isinstance(key, Field) else key: value
                 for key, value in item.items()
             }

-            if DataRow.row_data not in item:
+            if 'row_data' not in item:
                 raise InvalidQueryError(
-                    "DataRow.row_data missing when creating DataRow.")
+                    "`row_data` missing when creating DataRow.")

-            invalid_keys = set(item) - set(DataRow.fields())
+            # TODO: This is technically breaking. but also idt anyone is using the other fields.
+            invalid_keys = set(item) - {
+                'row_data', 'external_id', 'attachments'
+            }
             if invalid_keys:
                 raise InvalidAttributeError(DataRow, invalid_keys)

             # Item is valid, convert it to a dict {graphql_field_name: value}
             # Need to change the name of DataRow.row_data to "data"
             return {
-                "data" if key == DataRow.row_data else key.graphql_name: value
+                "data" if key == "row_data" else utils.camel_case(key): value
                 for key, value in item.items()
             }

+        with ThreadPoolExecutor(file_upload_thread_count) as executor:
+            futures = [executor.submit(convert_item, item) for item in items]
+            items = [future.result() for future in as_completed(futures)]
+
         # Prepare and upload the desciptor file
-        items = [convert_item(item) for item in items]
         data = json.dumps(items)
         descriptor_url = self.client.upload_data(data)
-
         # Create data source
         dataset_param = "datasetId"
         url_param = "jsonUrl"
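Taken together, create_data_rows now accepts dict items with an optional `attachments` list, uploads local files referenced by `row_data`, and serializes everything into a JSON descriptor. A minimal usage sketch follows; the client/dataset setup, the returned-task variable, and the "TEXT" attachment type are illustrative assumptions, not taken from this commit:

    from labelbox import Client

    client = Client(api_key="<YOUR_API_KEY>")        # assumed setup
    dataset = client.create_dataset(name="example")  # assumed setup

    # Dict items may now carry an `attachments` list; each attachment needs
    # a `type` (one of DataRow.supported_attachment_types) and a `value`.
    task = dataset.create_data_rows([
        {
            "row_data": "https://example.com/image-1.jpg",
            "external_id": "image-1",
            "attachments": [
                # "TEXT" is an assumed member of supported_attachment_types
                {"type": "TEXT", "value": "note shown alongside the data row"},
            ],
        },
        # Plain strings still work: local paths are uploaded first.
        "/tmp/local_image.jpg",
    ])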

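To make the final payload concrete: convert_item renames `row_data` to `data` and camelCases the remaining keys via utils.camel_case. A self-contained sketch, with a stand-in camel_case that assumes the usual snake_case-to-camelCase behavior:

    def camel_case(s):
        # Stand-in for labelbox.utils.camel_case (assumed behavior).
        first, *rest = s.split("_")
        return first + "".join(word.capitalize() for word in rest)

    item = {
        "row_data": "https://example.com/image-1.jpg",
        "external_id": "image-1",
        "attachments": [{"type": "TEXT", "value": "note"}],
    }

    # Mirrors the dict comprehension in convert_item's return statement.
    converted = {
        "data" if key == "row_data" else camel_case(key): value
        for key, value in item.items()
    }
    # converted == {"data": ..., "externalId": "image-1", "attachments": [...]}

This converted form is what gets json.dumps'ed into the descriptor file and uploaded via self.client.upload_data.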
tests/integration/test_dataset.py

Lines changed: 1 addition & 0 deletions
@@ -106,3 +106,4 @@ def test_data_row_export(dataset):
     result = list(dataset.export_data_rows())
     assert len(result) == n_data_rows
     assert set(result) == ids
+
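The test change above is only a trailing newline; the TODO block in dataset.py flags error handling as future work. A hypothetical test along these lines (fixture name and scenarios inferred from the new validate_attachments logic, not part of this commit) could pin the validation behavior down:

    import pytest

    def test_create_data_rows_rejects_bad_attachments(dataset):
        # `attachments` must be a list...
        with pytest.raises(ValueError):
            dataset.create_data_rows([{
                "row_data": "https://example.com/image.jpg",
                "attachments": {"type": "TEXT", "value": "not in a list"},
            }])

        # ...and every attachment needs both `type` and `value` keys.
        with pytest.raises(ValueError):
            dataset.create_data_rows([{
                "row_data": "https://example.com/image.jpg",
                "attachments": [{"value": "missing the `type` key"}],
            }])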
