|
1 |
| -from typing import TYPE_CHECKING, Dict, Iterable, Union |
| 1 | +from typing import TYPE_CHECKING, Dict, Iterable, Union, List, Optional, Any |
2 | 2 | from pathlib import Path
|
3 | 3 | import os
|
4 | 4 | import time
|
| 5 | +import logging |
| 6 | +import requests |
| 7 | +import ndjson |
5 | 8 |
|
6 | 9 | from labelbox.pagination import PaginatedCollection
|
7 | 10 | from labelbox.orm.query import results_query_part
|
8 | 11 | from labelbox.orm.model import Field, Relationship, Entity
|
9 |
| -from labelbox.orm.db_object import DbObject |
| 12 | +from labelbox.orm.db_object import DbObject, experimental |
10 | 13 |
|
11 | 14 | if TYPE_CHECKING:
|
12 | 15 | from labelbox import MEAPredictionImport
|
13 | 16 |
|
| 17 | +logger = logging.getLogger(__name__) |
| 18 | + |
14 | 19 |
|
15 | 20 | class ModelRun(DbObject):
|
16 | 21 | name = Field.String("name")
|
@@ -175,6 +180,54 @@ def delete_model_run_data_rows(self, data_row_ids):
|
175 | 180 | data_row_ids_param: data_row_ids
|
176 | 181 | })
|
177 | 182 |
|
| 183 | + @experimental |
| 184 | + def export_labels( |
| 185 | + self, |
| 186 | + download: bool = False, |
| 187 | + timeout_seconds: int = 600 |
| 188 | + ) -> Optional[Union[str, List[Dict[Any, Any]]]]: |
| 189 | + """ |
| 190 | + Experimental. To use, make sure client has enable_experimental=True. |
| 191 | +
|
| 192 | + Fetches Labels from the ModelRun |
| 193 | +
|
| 194 | + Args: |
| 195 | + download (bool): Returns the url if False |
| 196 | + Returns: |
| 197 | + URL of the data file with this ModelRun's labels. |
| 198 | + If download=True, this instead returns the contents as NDJSON format. |
| 199 | + If the server didn't generate during the `timeout_seconds` period, |
| 200 | + None is returned. |
| 201 | + """ |
| 202 | + sleep_time = 2 |
| 203 | + query_str = """mutation exportModelRunAnnotationsPyApi($modelRunId: ID!) { |
| 204 | + exportModelRunAnnotations(data: {modelRunId: $modelRunId}) { |
| 205 | + downloadUrl createdAt status |
| 206 | + } |
| 207 | + } |
| 208 | + """ |
| 209 | + |
| 210 | + while True: |
| 211 | + url = self.client.execute( |
| 212 | + query_str, {'modelRunId': self.uid}, |
| 213 | + experimental=True)['exportModelRunAnnotations']['downloadUrl'] |
| 214 | + |
| 215 | + if url: |
| 216 | + if not download: |
| 217 | + return url |
| 218 | + else: |
| 219 | + response = requests.get(url) |
| 220 | + response.raise_for_status() |
| 221 | + return ndjson.loads(response.content) |
| 222 | + |
| 223 | + timeout_seconds -= sleep_time |
| 224 | + if timeout_seconds <= 0: |
| 225 | + return None |
| 226 | + |
| 227 | + logger.debug("ModelRun '%s' label export, waiting for server...", |
| 228 | + self.uid) |
| 229 | + time.sleep(sleep_time) |
| 230 | + |
178 | 231 |
|
179 | 232 | class ModelRunDataRow(DbObject):
|
180 | 233 | label_id = Field.String("label_id")
|
|
0 commit comments