Skip to content

Commit b02b57a

Browse files
Adjust fn logic
1 parent 3a699da commit b02b57a

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

labelbox/schema/model_run.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Dict, Iterable, Union
22
from pathlib import Path
33
import os
4+
import time
45

56
from labelbox.pagination import PaginatedCollection
67
from labelbox.schema.annotation_import import MEAPredictionImport
@@ -16,6 +17,38 @@ class ModelRun(DbObject):
1617
created_by_id = Field.String("created_by_id", "createdBy")
1718
model_id = Field.String("model_id")
1819

20+
def upsert_labels(self, label_ids, timeout_seconds=600):
21+
""" Calls GraphQL API to start the MEA labels registration process
22+
Args:
23+
label_ids (list): label ids to insert
24+
timeout_seconds (float): Max waiting time, in seconds.
25+
Returns:
26+
ID of newly generated async task
27+
"""
28+
29+
sleep_time = 5
30+
31+
mutation_name = 'createMEAModelRunLabelRegistrationTask'
32+
query_str = """mutation createMEAModelRunLabelRegistrationTaskByApi($modelRunId: ID!, $labelIds : [ID!]!) {
33+
%s(where : { id : $modelRunId}, data : {labelIds: $labelIds})}
34+
""" (mutation_name)
35+
36+
while True:
37+
res = self.client.execute(query_str, {
38+
'modelRunId': self.uid,
39+
'labelIds': label_ids
40+
})
41+
42+
res = res[mutation_name]
43+
if res:
44+
return res
45+
46+
timeout_seconds -= sleep_time
47+
if timeout_seconds <= 0:
48+
return None
49+
50+
time.sleep(sleep_time)
51+
1952
def upsert_labels(self, label_ids):
2053

2154
if len(label_ids) < 1:

0 commit comments

Comments
 (0)