Skip to content

Commit 05d6e54

Browse files
authored
Merge pull request #314 from Labelbox/tpeharda/DIAG-779-update-model-run-upsert-labels
[DIAG-779] Update SDK model_run.upsert_labels() method
2 parents 5c7d22b + d753923 commit 05d6e54

File tree

1 file changed

+42
-7
lines changed

1 file changed

+42
-7
lines changed

labelbox/schema/model_run.py

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Dict, Iterable, Union
22
from pathlib import Path
33
import os
4+
import time
45

56
from labelbox.pagination import PaginatedCollection
67
from labelbox.schema.annotation_import import MEAPredictionImport
@@ -16,20 +17,54 @@ class ModelRun(DbObject):
1617
created_by_id = Field.String("created_by_id", "createdBy")
1718
model_id = Field.String("model_id")
1819

19-
def upsert_labels(self, label_ids):
20+
def upsert_labels(self, label_ids, timeout_seconds=60):
21+
""" Calls GraphQL API to start the MEA labels registration process
22+
Args:
23+
label_ids (list): label ids to insert
24+
timeout_seconds (float): Max waiting time, in seconds.
25+
Returns:
26+
ID of newly generated async task
27+
"""
2028

2129
if len(label_ids) < 1:
2230
raise ValueError("Must provide at least one label id")
2331

24-
query_str = """mutation upsertModelRunLabelsPyApi($modelRunId: ID!, $labelIds : [ID!]!) {
25-
upsertModelRunLabels(where : { id : $modelRunId}, data : {labelIds: $labelIds})}
26-
"""
27-
res = self.client.execute(query_str, {
32+
sleep_time = 5
33+
34+
mutation_name = 'createMEAModelRunLabelRegistrationTask'
35+
create_task_query_str = """mutation createMEAModelRunLabelRegistrationTaskPyApi($modelRunId: ID!, $labelIds : [ID!]!) {
36+
%s(where : { id : $modelRunId}, data : {labelIds: $labelIds})}
37+
""" % (mutation_name)
38+
39+
res = self.client.execute(create_task_query_str, {
2840
'modelRunId': self.uid,
2941
'labelIds': label_ids
3042
})
31-
# TODO: Return a task
32-
return True
43+
task_id = res[mutation_name]
44+
45+
status_query_str = """query MEALabelRegistrationTaskStatusPyApi($where: WhereUniqueIdInput!){
46+
MEALabelRegistrationTaskStatus(where: $where) {status errorMessage}
47+
}
48+
"""
49+
50+
while True:
51+
res = self.client.execute(status_query_str,
52+
{'where': {
53+
'id': task_id
54+
}})['MEALabelRegistrationTaskStatus']
55+
if res['status'] == 'COMPLETE':
56+
return res
57+
elif res['status'] == 'FAILED':
58+
raise Exception(
59+
f"MEA Label Import Failed. Details : {res['errorMessage']}")
60+
61+
timeout_seconds -= sleep_time
62+
if timeout_seconds <= 0:
63+
raise TimeoutError(
64+
f"Unable to complete import within {timeout_seconds} seconds."
65+
)
66+
67+
time.sleep(sleep_time)
3368

3469
def add_predictions(
3570
self,

0 commit comments

Comments
 (0)