Skip to content

Commit 7d68200

Browse files
authored
Merge pull request #462 from Labelbox/modelrunlabels
[AL-1623] - ModelRun Export Annotations
2 parents cdc42c8 + 9c20ea7 commit 7d68200

File tree

2 files changed

+60
-2
lines changed

2 files changed

+60
-2
lines changed

labelbox/schema/model_run.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
1-
from typing import TYPE_CHECKING, Dict, Iterable, Union
1+
from typing import TYPE_CHECKING, Dict, Iterable, Union, List, Optional, Any
22
from pathlib import Path
33
import os
44
import time
5+
import logging
6+
import requests
7+
import ndjson
58

69
from labelbox.pagination import PaginatedCollection
710
from labelbox.orm.query import results_query_part
811
from labelbox.orm.model import Field, Relationship, Entity
9-
from labelbox.orm.db_object import DbObject
12+
from labelbox.orm.db_object import DbObject, experimental
1013

1114
if TYPE_CHECKING:
1215
from labelbox import MEAPredictionImport
1316

17+
logger = logging.getLogger(__name__)
18+
1419

1520
class ModelRun(DbObject):
1621
name = Field.String("name")
@@ -175,6 +180,54 @@ def delete_model_run_data_rows(self, data_row_ids):
175180
data_row_ids_param: data_row_ids
176181
})
177182

183+
@experimental
184+
def export_labels(
185+
self,
186+
download: bool = False,
187+
timeout_seconds: int = 600
188+
) -> Optional[Union[str, List[Dict[Any, Any]]]]:
189+
"""
190+
Experimental. To use, make sure client has enable_experimental=True.
191+
192+
Fetches Labels from the ModelRun
193+
194+
Args:
195+
download (bool): Returns the url if False
196+
Returns:
197+
URL of the data file with this ModelRun's labels.
198+
If download=True, this instead returns the contents as NDJSON format.
199+
If the server didn't generate during the `timeout_seconds` period,
200+
None is returned.
201+
"""
202+
sleep_time = 2
203+
query_str = """mutation exportModelRunAnnotationsPyApi($modelRunId: ID!) {
204+
exportModelRunAnnotations(data: {modelRunId: $modelRunId}) {
205+
downloadUrl createdAt status
206+
}
207+
}
208+
"""
209+
210+
while True:
211+
url = self.client.execute(
212+
query_str, {'modelRunId': self.uid},
213+
experimental=True)['exportModelRunAnnotations']['downloadUrl']
214+
215+
if url:
216+
if not download:
217+
return url
218+
else:
219+
response = requests.get(url)
220+
response.raise_for_status()
221+
return ndjson.loads(response.content)
222+
223+
timeout_seconds -= sleep_time
224+
if timeout_seconds <= 0:
225+
return None
226+
227+
logger.debug("ModelRun '%s' label export, waiting for server...",
228+
self.uid)
229+
time.sleep(sleep_time)
230+
178231

179232
class ModelRunDataRow(DbObject):
180233
label_id = Field.String("label_id")

tests/integration/annotation_import/test_model_run.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,8 @@ def test_model_run_upsert_data_rows_with_existing_labels(
7373
])
7474
assert n_data_rows == len(
7575
list(model_run_with_model_run_data_rows.model_run_data_rows()))
76+
77+
78+
def test_model_run_export_labels(model_run_with_model_run_data_rows):
79+
labels = model_run_with_model_run_data_rows.export_labels(download=True)
80+
assert len(labels) == 3

0 commit comments

Comments
 (0)