Skip to content

Commit 5f6dd68

Browse files
authored
Merge pull request #81 from scaleapi/da-batch-model-export
Batch export for model predictions
2 parents a1f87ef + 7fa31a9 commit 5f6dd68

File tree

3 files changed

+21
-7
lines changed

3 files changed

+21
-7
lines changed

nucleus/annotation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ class Annotation:
3737
item_id: Optional[str] = None
3838

3939
def _check_ids(self):
40-
if bool(self.reference_id) == bool(self.item_id):
40+
if self.reference_id and self.item_id:
41+
self.item_id = None # Prefer reference id to item id.
42+
if not (self.reference_id or self.item_id):
4143
raise Exception(
4244
"You must specify either a reference_id or an item_id for an annotation."
4345
)

nucleus/model_run.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Dict, List, Optional, Type, Union
22

3+
import requests
4+
35
from nucleus.annotation import check_all_annotation_paths_remote
46
from nucleus.job import AsyncJob
57
from nucleus.utils import serialize_and_write_to_presigned_url
@@ -155,6 +157,16 @@ def loc(self, dataset_item_id: str):
155157
)
156158
return self._format_prediction_response(response)
157159

160+
def ungrouped_export(self):
161+
json_response = self._client.make_request(
162+
payload={},
163+
route=f"modelRun/{self.model_run_id}/ungrouped",
164+
requests_command=requests.get,
165+
)
166+
return self._format_prediction_response(
167+
{ANNOTATIONS_KEY: json_response}
168+
)
169+
158170
def _format_prediction_response(
159171
self, response: dict
160172
) -> Union[

tests/test_prediction.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def test_repr(test_object: any):
5050
def model_run(CLIENT):
5151
ds = CLIENT.create_dataset(TEST_DATASET_NAME)
5252
ds_items = []
53-
for url in TEST_IMG_URLS:
53+
for url in TEST_IMG_URLS[:2]:
5454
ds_items.append(
5555
DatasetItem(
5656
image_location=url,
@@ -246,7 +246,7 @@ def test_polygon_pred_upload_ignore(model_run):
246246
)
247247

248248

249-
def test_mixed_pred_upload(model_run):
249+
def test_mixed_pred_upload(model_run: ModelRun):
250250
prediction_semseg = SegmentationPrediction.from_json(
251251
TEST_SEGMENTATION_PREDICTIONS[0]
252252
)
@@ -262,15 +262,15 @@ def test_mixed_pred_upload(model_run):
262262
assert response["predictions_processed"] == 3
263263
assert response["predictions_ignored"] == 0
264264

265-
response_refloc = model_run.refloc(prediction_polygon.reference_id)
265+
all_predictions = model_run.ungrouped_export()
266266
assert_box_prediction_matches_dict(
267-
response_refloc["box"][0], TEST_BOX_PREDICTIONS[0]
267+
all_predictions["box"][0], TEST_BOX_PREDICTIONS[0]
268268
)
269269
assert_polygon_prediction_matches_dict(
270-
response_refloc["polygon"][0], TEST_POLYGON_PREDICTIONS[0]
270+
all_predictions["polygon"][0], TEST_POLYGON_PREDICTIONS[0]
271271
)
272272
assert_segmentation_annotation_matches_dict(
273-
response_refloc["segmentation"][0], TEST_SEGMENTATION_PREDICTIONS[0]
273+
all_predictions["segmentation"][0], TEST_SEGMENTATION_PREDICTIONS[0]
274274
)
275275

276276

0 commit comments

Comments
 (0)