1
1
from typing import Dict , Iterable , Union
2
2
from pathlib import Path
3
3
import os
4
+ import time
4
5
5
6
from labelbox .pagination import PaginatedCollection
6
7
from labelbox .schema .annotation_import import MEAPredictionImport
@@ -16,20 +17,54 @@ class ModelRun(DbObject):
16
17
created_by_id = Field .String ("created_by_id" , "createdBy" )
17
18
model_id = Field .String ("model_id" )
18
19
19
- def upsert_labels (self , label_ids ):
20
+ def upsert_labels (self , label_ids , timeout_seconds = 60 ):
21
+ """ Calls GraphQL API to start the MEA labels registration process
22
+ Args:
23
+ label_ids (list): label ids to insert
24
+ timeout_seconds (float): Max waiting time, in seconds.
25
+ Returns:
26
+ ID of newly generated async task
27
+ """
20
28
21
29
if len (label_ids ) < 1 :
22
30
raise ValueError ("Must provide at least one label id" )
23
31
24
- query_str = """mutation upsertModelRunLabelsPyApi($modelRunId: ID!, $labelIds : [ID!]!) {
25
- upsertModelRunLabels(where : { id : $modelRunId}, data : {labelIds: $labelIds})}
26
- """
27
- res = self .client .execute (query_str , {
32
+ sleep_time = 5
33
+
34
+ mutation_name = 'createMEAModelRunLabelRegistrationTask'
35
+ create_task_query_str = """mutation createMEAModelRunLabelRegistrationTaskPyApi($modelRunId: ID!, $labelIds : [ID!]!) {
36
+ %s(where : { id : $modelRunId}, data : {labelIds: $labelIds})}
37
+ """ % (mutation_name )
38
+
39
+ res = self .client .execute (create_task_query_str , {
28
40
'modelRunId' : self .uid ,
29
41
'labelIds' : label_ids
30
42
})
31
- # TODO: Return a task
32
- return True
43
+ task_id = res [mutation_name ]
44
+
45
+ status_query_str = """query MEALabelRegistrationTaskStatusPyApi($where: WhereUniqueIdInput!){
46
+ MEALabelRegistrationTaskStatus(where: $where) {status errorMessage}
47
+ }
48
+ """
49
+
50
+ while True :
51
+ res = self .client .execute (status_query_str ,
52
+ {'where' : {
53
+ 'id' : task_id
54
+ }})['MEALabelRegistrationTaskStatus' ]
55
+ if res ['status' ] == 'COMPLETE' :
56
+ return res
57
+ elif res ['status' ] == 'FAILED' :
58
+ raise Exception (
59
+ f"MEA Label Import Failed. Details : { res ['errorMessage' ]} " )
60
+
61
+ timeout_seconds -= sleep_time
62
+ if timeout_seconds <= 0 :
63
+ raise TimeoutError (
64
+ f"Unable to complete import within { timeout_seconds } seconds."
65
+ )
66
+
67
+ time .sleep (sleep_time )
33
68
34
69
def add_predictions (
35
70
self ,
0 commit comments