Skip to content

Commit 6492f47

Browse files
author
Matt Sokoloff
committed
upsert data rows
1 parent e3eb576 commit 6492f47

File tree

2 files changed

+65
-10
lines changed

2 files changed

+65
-10
lines changed

labelbox/schema/model_run.py

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class ModelRun(DbObject):
1818
model_id = Field.String("model_id")
1919

2020
def upsert_labels(self, label_ids, timeout_seconds=60):
21-
""" Calls GraphQL API to start the MEA labels registration process
21+
""" Adds data rows and labels to a model run
2222
Args:
2323
label_ids (list): label ids to insert
2424
timeout_seconds (float): Max waiting time, in seconds.
@@ -29,8 +29,6 @@ def upsert_labels(self, label_ids, timeout_seconds=60):
2929
if len(label_ids) < 1:
3030
raise ValueError("Must provide at least one label id")
3131

32-
sleep_time = 5
33-
3432
mutation_name = 'createMEAModelRunLabelRegistrationTask'
3533
create_task_query_str = """mutation createMEAModelRunLabelRegistrationTaskPyApi($modelRunId: ID!, $labelIds : [ID!]!) {
3634
%s(where : { id : $modelRunId}, data : {labelIds: $labelIds})}
@@ -46,18 +44,54 @@ def upsert_labels(self, label_ids, timeout_seconds=60):
4644
MEALabelRegistrationTaskStatus(where: $where) {status errorMessage}
4745
}
4846
"""
47+
return self._wait_until_done(lambda: self.client.execute(
48+
status_query_str, {'where': {
49+
'id': task_id
50+
}})['MEALabelRegistrationTaskStatus'],
51+
timeout_seconds=timeout_seconds)
52+
53+
def upsert_data_rows(self, data_row_ids, timeout_seconds=60):
54+
""" Adds data rows to a model run without any associated labels
55+
Args:
56+
data_row_ids (list): data row ids to add to mea
57+
timeout_seconds (float): Max waiting time, in seconds.
58+
Returns:
59+
ID of newly generated async task
60+
"""
61+
62+
if len(data_row_ids) < 1:
63+
raise ValueError("Must provide at least one data row id")
64+
65+
mutation_name = 'createMEAModelRunDataRowRegistrationTask'
66+
create_task_query_str = """mutation createMEAModelRunDataRowRegistrationTaskPyApi($modelRunId: ID!, $dataRowIds : [ID!]!) {
67+
%s(where : { id : $modelRunId}, data : {dataRowIds: $dataRowIds})}
68+
""" % (mutation_name)
4969

70+
res = self.client.execute(create_task_query_str, {
71+
'modelRunId': self.uid,
72+
'dataRowIds': data_row_ids
73+
})
74+
task_id = res[mutation_name]
75+
76+
status_query_str = """query MEADataRowRegistrationTaskStatusPyApi($where: WhereUniqueIdInput!){
77+
MEADataRowRegistrationTaskStatus(where: $where) {status errorMessage}
78+
}
79+
"""
80+
return self._wait_until_done(lambda: self.client.execute(
81+
status_query_str, {'where': {
82+
'id': task_id
83+
}})['MEADataRowRegistrationTaskStatus'],
84+
timeout_seconds=timeout_seconds)
85+
86+
def _wait_until_done(self, status_fn, timeout_seconds=60, sleep_time=5):
87+
# Do not use this function outside of the scope of upsert_data_rows or upsert_labels. It could change.
5088
while True:
51-
res = self.client.execute(status_query_str,
52-
{'where': {
53-
'id': task_id
54-
}})['MEALabelRegistrationTaskStatus']
89+
res = status_fn()
5590
if res['status'] == 'COMPLETE':
56-
return res
91+
return True
5792
elif res['status'] == 'FAILED':
5893
raise Exception(
59-
f"MEA Label Import Failed. Details : {res['errorMessage']}")
60-
94+
f"MEA Import Failed. Details : {res['errorMessage']}")
6195
timeout_seconds -= sleep_time
6296
if timeout_seconds <= 0:
6397
raise TimeoutError(

tests/integration/annotation_import/test_model_run.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,24 @@ def test_model_run_annotation_groups_delete(client,
5454
after = list(model_run.annotation_groups())
5555

5656
assert len(before) == len(after) + 1
57+
58+
59+
def test_model_run_upsert_data_rows(dataset, model_run):
60+
n_annotation_groups = len(list(model_run.annotation_groups()))
61+
assert n_annotation_groups == 0
62+
data_row = dataset.create_data_row(row_data="test row data")
63+
model_run.upsert_data_rows([data_row.uid])
64+
n_annotation_groups = len(list(model_run.annotation_groups()))
65+
assert n_annotation_groups == 1
66+
67+
68+
def test_model_run_upsert_data_rows_with_existing_labels(
69+
model_run_annotation_groups):
70+
annotation_groups = list(model_run_annotation_groups.annotation_groups())
71+
n_annotation_groups = len(annotation_groups)
72+
model_run_annotation_groups.upsert_data_rows([
73+
annotation_group.data_row().uid
74+
for annotation_group in annotation_groups
75+
])
76+
assert n_annotation_groups == len(
77+
list(model_run_annotation_groups.annotation_groups()))

0 commit comments

Comments
 (0)