Commit 905f718

Author: Matt Sokoloff (committed)
Merge branch 'develop' of https://github.com/Labelbox/labelbox-python into ms/custom-scalar-metrics
2 parents e27918b + 1254354

16 files changed: +321 −103 lines

CHANGELOG.md

Lines changed: 13 additions & 0 deletions
@@ -1,5 +1,18 @@
 # Changelog
 
+# Version 3.3.0 (2021-09-02)
+## Added
+* `Dataset.create_data_rows_sync()` for synchronous bulk uploads of data rows
+* `Model.delete()`, `ModelRun.delete()`, and `ModelRun.delete_annotation_groups()` to
+  clean up models, model runs, and annotation groups.
+
+## Fix
+* Increased the timeout for label exports since projects with many segmentation masks weren't finishing quickly enough.
+
+# Version 3.2.1 (2021-08-31)
+## Fix
+* Resolved an issue where `create_data_rows()` was not working on Amazon Linux
+
 # Version 3.2.0 (2021-08-26)
 ## Added
 * List `BulkImportRequest`s for a project with `Project.bulk_import_requests()`

labelbox/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 name = "labelbox"
-__version__ = "3.2.0"
+__version__ = "3.3.0"
 
 from labelbox.schema.project import Project
 from labelbox.client import Client

labelbox/schema/dataset.py

Lines changed: 114 additions & 41 deletions
@@ -69,13 +69,111 @@ def create_data_row(self, **kwargs):
         row_data = kwargs[DataRow.row_data.name]
         if os.path.exists(row_data):
             kwargs[DataRow.row_data.name] = self.client.upload_file(row_data)
-
         kwargs[DataRow.dataset.name] = self
-
         return self.client._create(DataRow, kwargs)
 
+    def create_data_rows_sync(self, items):
+        """ Synchronously bulk upload data rows.
+
+        Use this instead of `Dataset.create_data_rows` for smaller batches of data rows that need to be uploaded quickly.
+        Cannot be used for uploads containing more than 1000 data rows.
+        Each data row is also limited to 5 attachments.
+
+        Args:
+            items (iterable of (dict or str)):
+                See the docstring for `Dataset._create_descriptor_file` for more information.
+        Returns:
+            None. If the function doesn't raise an exception, the import was successful.
+
+        Raises:
+            InvalidQueryError: If the `items` parameter does not conform to
+                the specification in `Dataset._create_descriptor_file` or if the server did not accept the
+                DataRow creation request (unknown reason).
+            InvalidAttributeError: If there are fields in `items` not valid for
+                a DataRow.
+            ValueError: When the upload parameters are invalid.
+        """
+        max_data_rows_supported = 1000
+        max_attachments_per_data_row = 5
+        if len(items) > max_data_rows_supported:
+            raise ValueError(
+                f"Dataset.create_data_rows_sync() supports a max of {max_data_rows_supported} data rows."
+                " For larger imports use the async function Dataset.create_data_rows()"
+            )
+        descriptor_url = self._create_descriptor_file(
+            items, max_attachments_per_data_row=max_attachments_per_data_row)
+        dataset_param = "datasetId"
+        url_param = "jsonUrl"
+        query_str = """mutation AppendRowsToDatasetSyncPyApi($%s: ID!, $%s: String!){
+            appendRowsToDatasetSync(data:{datasetId: $%s, jsonFileUrl: $%s}
+            ){dataset{id}}} """ % (dataset_param, url_param, dataset_param,
+                                   url_param)
+        self.client.execute(query_str, {
+            dataset_param: self.uid,
+            url_param: descriptor_url
+        })
+
     def create_data_rows(self, items):
-        """ Creates multiple DataRow objects based on the given `items`.
+        """ Asynchronously bulk upload data rows.
+
+        Use this instead of `Dataset.create_data_rows_sync` for batches that contain more than 1000 data rows.
+
+        Args:
+            items (iterable of (dict or str)): See the docstring for `Dataset._create_descriptor_file` for more information.
+
+        Returns:
+            Task representing the data import on the server side. The Task
+            can be used for inspecting task progress and waiting until it's done.
+
+        Raises:
+            InvalidQueryError: If the `items` parameter does not conform to
+                the specification above or if the server did not accept the
+                DataRow creation request (unknown reason).
+            ResourceNotFoundError: If unable to retrieve the Task for the
+                import process. This could imply that the import failed.
+            InvalidAttributeError: If there are fields in `items` not valid for
+                a DataRow.
+            ValueError: When the upload parameters are invalid.
+        """
+        descriptor_url = self._create_descriptor_file(items)
+        # Create data source
+        dataset_param = "datasetId"
+        url_param = "jsonUrl"
+        query_str = """mutation AppendRowsToDatasetPyApi($%s: ID!, $%s: String!){
+            appendRowsToDataset(data:{datasetId: $%s, jsonFileUrl: $%s}
+            ){ taskId accepted errorMessage } } """ % (dataset_param, url_param,
+                                                       dataset_param, url_param)
+
+        res = self.client.execute(query_str, {
+            dataset_param: self.uid,
+            url_param: descriptor_url
+        })
+        res = res["appendRowsToDataset"]
+        if not res["accepted"]:
+            msg = res['errorMessage']
+            raise InvalidQueryError(
+                f"Server did not accept DataRow creation request. {msg}")
+
+        # Fetch and return the task.
+        task_id = res["taskId"]
+        user = self.client.get_user()
+        task = list(user.created_tasks(where=Entity.Task.uid == task_id))
+        # Cache user in a private variable as the relationship can't be
+        # resolved due to server-side limitations (see Task.created_by
+        # for more info).
+        if len(task) != 1:
+            raise ResourceNotFoundError(Entity.Task, task_id)
+        task = task[0]
+        task._user = user
+        return task
+
+    def _create_descriptor_file(self, items, max_attachments_per_data_row=None):
+        """
+        This function is shared by both `Dataset.create_data_rows` and `Dataset.create_data_rows_sync`
+        to prepare the input file. The user-defined input is validated, processed, and JSON stringified.
+        Finally the JSON data is uploaded to GCS and a URI is returned. This URI can be passed to
+        the server-side mutations that create the data rows.
+
 
         Each element in `items` can be either a `str` or a `dict`. If
         it is a `str`, then it is interpreted as a local file path. The file
@@ -102,22 +200,23 @@ def create_data_rows(self, items):
 
         Args:
             items (iterable of (dict or str)): See above for details.
+            max_attachments_per_data_row (Optional[int]): Param used during attachment validation to determine
+                if the user has provided too many attachments.
 
         Returns:
-            Task representing the data import on the server side. The Task
-            can be used for inspecting task progress and waiting until it's done.
+            uri (string): A reference to the uploaded json data.
 
         Raises:
             InvalidQueryError: If the `items` parameter does not conform to
                 the specification above or if the server did not accept the
                 DataRow creation request (unknown reason).
-            ResourceNotFoundError: If unable to retrieve the Task for the
-                import process. This could imply that the import failed.
             InvalidAttributeError: If there are fields in `items` not valid for
                 a DataRow.
+            ValueError: When the upload parameters are invalid.
         """
         file_upload_thread_count = 20
         DataRow = Entity.DataRow
+        AssetAttachment = Entity.AssetAttachment
 
         def upload_if_necessary(item):
             row_data = item['row_data']
@@ -134,9 +233,14 @@ def validate_attachments(item):
             attachments = item.get('attachments')
             if attachments:
                 if isinstance(attachments, list):
+                    if max_attachments_per_data_row and len(
+                            attachments) > max_attachments_per_data_row:
+                        raise ValueError(
+                            f"Max number of supported attachments per data row is {max_attachments_per_data_row}."
+                            f" Found {len(attachments)}. Condense multiple attachments into one with the HTML attachment type if necessary."
+                        )
                     for attachment in attachments:
-                        Entity.AssetAttachment.validate_attachment_json(
-                            attachment)
+                        AssetAttachment.validate_attachment_json(attachment)
                 else:
                     raise ValueError(
                         f"Attachments must be a list. Found {type(attachments)}"
@@ -198,40 +302,9 @@ def convert_item(item):
         with ThreadPoolExecutor(file_upload_thread_count) as executor:
             futures = [executor.submit(convert_item, item) for item in items]
             items = [future.result() for future in as_completed(futures)]
-
         # Prepare and upload the descriptor file
         data = json.dumps(items)
-        descriptor_url = self.client.upload_data(data)
-        # Create data source
-        dataset_param = "datasetId"
-        url_param = "jsonUrl"
-        query_str = """mutation AppendRowsToDatasetPyApi($%s: ID!, $%s: String!){
-            appendRowsToDataset(data:{datasetId: $%s, jsonFileUrl: $%s}
-            ){ taskId accepted errorMessage } } """ % (dataset_param, url_param,
-                                                       dataset_param, url_param)
-
-        res = self.client.execute(query_str, {
-            dataset_param: self.uid,
-            url_param: descriptor_url
-        })
-        res = res["appendRowsToDataset"]
-        if not res["accepted"]:
-            msg = res['errorMessage']
-            raise InvalidQueryError(
-                f"Server did not accept DataRow creation request. {msg}")
-
-        # Fetch and return the task.
-        task_id = res["taskId"]
-        user = self.client.get_user()
-        task = list(user.created_tasks(where=Entity.Task.uid == task_id))
-        # Cache user in a private variable as the relationship can't be
-        # resolved due to server-side limitations (see Task.created_by
-        # for more info).
-        if len(task) != 1:
-            raise ResourceNotFoundError(Entity.Task, task_id)
-        task = task[0]
-        task._user = user
-        return task
+        return self.client.upload_data(data)
 
     def data_rows_for_external_id(self, external_id, limit=10):
         """ Convenience method for getting a single `DataRow` belonging to this

labelbox/schema/model.py

Lines changed: 11 additions & 0 deletions
@@ -34,3 +34,14 @@ def create_model_run(self, name):
             model_id_param: self.uid
         })
         return ModelRun(self.client, res["createModelRun"])
+
+    def delete(self):
+        """ Deletes the specified model.
+
+        Returns:
+            Query execution success.
+        """
+        ids_param = "ids"
+        query_str = """mutation DeleteModelPyApi($%s: ID!) {
+            deleteModels(where: {ids: [$%s]})}""" % (ids_param, ids_param)
+        self.client.execute(query_str, {ids_param: str(self.uid)})
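
A short sketch of the new cleanup method; `client.create_model` already exists in the SDK, and the ontology UID below is a placeholder:

# Assumes `client` is an authenticated labelbox Client.
model = client.create_model("scratch-model", "<ONTOLOGY_UID>")  # placeholder ontology UID

# ... create model runs, experiment ...

model.delete()  # runs the deleteModels mutation for this model's uid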

labelbox/schema/model_run.py

Lines changed: 30 additions & 0 deletions
@@ -74,6 +74,36 @@ def annotation_groups(self):
             lambda client, res: AnnotationGroup(client, self.model_id, res),
             ['annotationGroups', 'pageInfo', 'endCursor'])
 
+    def delete(self):
+        """ Deletes the specified model run.
+
+        Returns:
+            Query execution success.
+        """
+        ids_param = "ids"
+        query_str = """mutation DeleteModelRunPyApi($%s: ID!) {
+            deleteModelRuns(where: {ids: [$%s]})}""" % (ids_param, ids_param)
+        self.client.execute(query_str, {ids_param: str(self.uid)})
+
+    def delete_annotation_groups(self, data_row_ids):
+        """ Deletes annotation groups by data row ids for a model run.
+
+        Args:
+            data_row_ids (list): List of data row ids whose annotation groups will be deleted.
+        Returns:
+            Query execution success.
+        """
+        model_run_id_param = "modelRunId"
+        data_row_ids_param = "dataRowIds"
+        query_str = """mutation DeleteModelRunDataRowsPyApi($%s: ID!, $%s: [ID!]!) {
+            deleteModelRunDataRows(where: {modelRunId: $%s, dataRowIds: $%s})}""" % (
+            model_run_id_param, data_row_ids_param, model_run_id_param,
+            data_row_ids_param)
+        self.client.execute(query_str, {
+            model_run_id_param: self.uid,
+            data_row_ids_param: data_row_ids
+        })
+
 
 class AnnotationGroup(DbObject):
     label_id = Field.String("label_id")
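
A sketch of the model run cleanup calls; the run name and data row ids are placeholders:

# Assumes `model` is a labelbox Model (e.g. returned by client.create_model).
model_run = model.create_model_run("experiment-1")

# Remove annotation groups for specific data rows from this run.
model_run.delete_annotation_groups(["<DATA_ROW_ID_1>", "<DATA_ROW_ID_2>"])

# Remove the run itself once it is no longer needed.
model_run.delete()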

labelbox/schema/project.py

Lines changed: 3 additions & 3 deletions
@@ -166,7 +166,7 @@ def export_queued_data_rows(self, timeout_seconds=120):
                 self.uid)
             time.sleep(sleep_time)
 
-    def video_label_generator(self, timeout_seconds=120):
+    def video_label_generator(self, timeout_seconds=600):
         """
         Download video annotations
 
@@ -190,7 +190,7 @@ def video_label_generator(self, timeout_seconds=120):
                 "Or use project.label_generator() for text and imagery data.")
         return LBV1Converter.deserialize_video(json_data, self.client)
 
-    def label_generator(self, timeout_seconds=60):
+    def label_generator(self, timeout_seconds=600):
         """
         Download text and image annotations
 
@@ -214,7 +214,7 @@ def label_generator(self, timeout_seconds=60):
                 "Or use project.video_label_generator() for video data.")
         return LBV1Converter.deserialize(json_data)
 
-    def export_labels(self, download=False, timeout_seconds=60):
+    def export_labels(self, download=False, timeout_seconds=600):
         """ Calls the server-side Label exporting that generates a JSON
         payload, and returns the URL to that payload.

labelbox/schema/task.py

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ def refresh(self):
         for field in self.fields():
             setattr(self, field.name, getattr(tasks[0], field.name))
 
-    def wait_till_done(self, timeout_seconds=60):
+    def wait_till_done(self, timeout_seconds=300):
         """ Waits until the task is completed. Periodically queries the server
         to update the task attributes.
 
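
Likewise for tasks, the new 300-second default can still be overridden, as in this sketch (the task is assumed to come from `Dataset.create_data_rows`):

task = dataset.create_data_rows(items)    # `dataset` and `items` as in the dataset.py sketch
task.wait_till_done(timeout_seconds=600)  # override the new 300-second default
task.refresh()                            # re-query the server for the latest task fields
print(task.status)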

tests/integration/bulk_import/conftest.py renamed to tests/integration/mal_and_mea/conftest.py

Lines changed: 19 additions & 10 deletions
@@ -297,11 +297,25 @@ def predictions(object_predictions, classification_predictions):
 
 
 @pytest.fixture
-def model_run(client, rand_gen, configured_project, annotation_submit_fn,
-              model_run_predictions):
-    configured_project.enable_model_assisted_labeling()
+def model(client, rand_gen, configured_project):
     ontology = configured_project.ontology()
 
+    data = {"name": rand_gen(str), "ontology_id": ontology.uid}
+    return client.create_model(data["name"], data["ontology_id"])
+
+
+@pytest.fixture
+def model_run(rand_gen, model):
+    name = rand_gen(str)
+    return model.create_model_run(name)
+
+
+@pytest.fixture
+def model_run_annotation_groups(client, configured_project,
+                                annotation_submit_fn, model_run_predictions,
+                                model_run):
+    configured_project.enable_model_assisted_labeling()
+
     upload_task = MALPredictionImport.create_from_objects(
         client, configured_project.uid, f'mal-import-{uuid.uuid4()}',
         model_run_predictions)
@@ -310,15 +324,10 @@ def model_run(client, rand_gen, configured_project, annotation_submit_fn,
     for data_row_id in {x['dataRow']['id'] for x in model_run_predictions}:
         annotation_submit_fn(configured_project.uid, data_row_id)
 
-    data = {"name": rand_gen(str), "ontology_id": ontology.uid}
-    model = client.create_model(data["name"], data["ontology_id"])
-    name = rand_gen(str)
-    model_run_s = model.create_model_run(name)
-
     time.sleep(3)
     labels = configured_project.export_labels(download=True)
-    model_run_s.upsert_labels([label['ID'] for label in labels])
+    model_run.upsert_labels([label['ID'] for label in labels])
    time.sleep(3)
 
-    yield model_run_s
+    yield model_run
     # TODO: Delete resources when that is possible ..
