Skip to content

Commit 9418d42

Browse files
author
Diego Ardila
committed
Integration tests passing
1 parent f111c33 commit 9418d42

File tree

6 files changed

+103
-18
lines changed

nucleus/__init__.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,8 @@ def __init__(
158158
self.endpoint = os.environ.get(
159159
"NUCLEUS_ENDPOINT", NUCLEUS_ENDPOINT
160160
)
161+
else:
162+
self.endpoint = endpoint
161163
self._use_notebook = use_notebook
162164
if use_notebook:
163165
self.tqdm_bar = tqdm_notebook.tqdm
@@ -230,13 +232,13 @@ def get_dataset(self, dataset_id: str) -> Dataset:
230232
"""
231233
return Dataset(dataset_id, self)
232234

233-
def get_model_run(self, model_run_id: str) -> ModelRun:
235+
def get_model_run(self, model_run_id: str, dataset_id: str) -> ModelRun:
234236
"""
235237
Fetches a model_run for given id
236238
:param model_run_id: internally controlled model_run_id
237239
:return: model_run
238240
"""
239-
return ModelRun(model_run_id, self)
241+
return ModelRun(model_run_id, dataset_id, self)
240242

241243
def delete_model_run(self, model_run_id: str):
242244
"""
@@ -674,7 +676,9 @@ def create_model_run(self, dataset_id: str, payload: dict) -> ModelRun:
674676
if response.get(STATUS_CODE_KEY, None):
675677
raise ModelRunCreationError(response.get("error"))
676678

677-
return ModelRun(response[MODEL_RUN_ID_KEY], self)
679+
return ModelRun(
680+
response[MODEL_RUN_ID_KEY], dataset_id=dataset_id, client=self
681+
)
678682

679683
def predict(
680684
self,

nucleus/model_run.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,13 @@ class ModelRun:
2626
Having an open model run is a prerequisite for uploading predictions to your dataset.
2727
"""
2828

29-
def __init__(self, model_run_id: str, client):
29+
def __init__(self, model_run_id: str, dataset_id: str, client):
3030
self.model_run_id = model_run_id
3131
self._client = client
32+
self._dataset_id = dataset_id
3233

3334
def __repr__(self):
34-
return f"ModelRun(model_run_id='{self.model_run_id}', client={self._client})"
35+
return f"ModelRun(model_run_id='{self.model_run_id}', dataset_id='{self._dataset_id}', client={self._client})"
3536

3637
def __eq__(self, other):
3738
if self.model_run_id == other.model_run_id:
@@ -92,7 +93,6 @@ def predict(
9293
],
9394
update: Optional[bool] = DEFAULT_ANNOTATION_UPDATE_MODE,
9495
asynchronous: bool = False,
95-
dataset_id: Optional[str] = None,
9696
) -> Union[dict, AsyncJob]:
9797
"""
9898
Uploads model outputs as predictions for a model_run. Returns info about the upload.
@@ -107,12 +107,8 @@ def predict(
107107
if asynchronous:
108108
check_all_annotation_paths_remote(annotations)
109109

110-
assert (
111-
dataset_id is not None
112-
), "For now, you must pass a dataset id to predict for asynchronous uploads."
113-
114110
request_id = serialize_and_write_to_presigned_url(
115-
annotations, dataset_id, self._client
111+
annotations, self._dataset_id, self._client
116112
)
117113
response = self._client.make_request(
118114
payload={REQUEST_ID_KEY: request_id, UPDATE_KEY: update},

nucleus/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,5 +119,5 @@ def serialize_and_write_to_presigned_url(
119119
strio = io.StringIO()
120120
serialize_and_write(upload_units, strio)
121121
strio.seek(0)
122-
upload_to_presigned_url(response["presigned_url"], strio)
122+
upload_to_presigned_url(response["signed_url"], strio)
123123
return request_id

tests/helpers.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@
1313

1414

1515
TEST_IMG_URLS = [
16-
"http://farm1.staticflickr.com/107/309278012_7a1f67deaa_z.jpg",
17-
"http://farm9.staticflickr.com/8001/7679588594_4e51b76472_z.jpg",
18-
"http://farm6.staticflickr.com/5295/5465771966_76f9773af1_z.jpg",
19-
"http://farm4.staticflickr.com/3449/4002348519_8ddfa4f2fb_z.jpg",
20-
"http://farm1.staticflickr.com/6/7617223_d84fcbce0e_z.jpg",
16+
"https://homepages.cae.wisc.edu/~ece533/images/airplane.png",
17+
"https://homepages.cae.wisc.edu/~ece533/images/arctichare.png",
18+
"https://homepages.cae.wisc.edu/~ece533/images/baboon.png",
19+
"https://homepages.cae.wisc.edu/~ece533/images/barbara.png",
20+
"https://homepages.cae.wisc.edu/~ece533/images/cat.png",
2121
]
2222

2323
TEST_DATASET_ITEMS = [

tests/test_models.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,13 @@ def test_repr(test_object: any):
4040
metadata={"fake": "metadata"},
4141
)
4242
)
43-
test_repr(ModelRun(client=client, model_run_id="fake_model_run_id"))
43+
test_repr(
44+
ModelRun(
45+
client=client,
46+
dataset_id="fake_dataset_id",
47+
model_run_id="fake_model_run_id",
48+
)
49+
)
4450

4551

4652
def test_model_creation_and_listing(CLIENT, dataset):

tests/test_prediction.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from nucleus.job import AsyncJob
12
import pytest
23
import time
34
from .helpers import (
@@ -20,9 +21,12 @@
2021
SegmentationPrediction,
2122
DatasetItem,
2223
Segment,
24+
ModelRun,
2325
)
2426
from nucleus.constants import ERROR_PAYLOAD
2527

28+
from nucleus import utils
29+
2630

2731
def test_reprs():
2832
def test_repr(test_object: any):
@@ -54,6 +58,7 @@ def model_run(CLIENT):
5458
)
5559

5660
response = ds.append(ds_items)
61+
5762
assert ERROR_PAYLOAD not in response.json()
5863

5964
model = CLIENT.add_model(
@@ -264,3 +269,77 @@ def test_mixed_pred_upload(model_run):
264269
assert_segmentation_annotation_matches_dict(
265270
response_refloc["segmentation"][0], TEST_SEGMENTATION_PREDICTIONS[0]
266271
)
272+
273+
274+
def test_mixed_pred_upload_async(model_run: ModelRun):
275+
prediction_semseg = SegmentationPrediction.from_json(
276+
TEST_SEGMENTATION_PREDICTIONS[0]
277+
)
278+
prediction_polygon = PolygonPrediction(**TEST_POLYGON_PREDICTIONS[0])
279+
prediction_bbox = BoxPrediction(**TEST_BOX_PREDICTIONS[0])
280+
job: AsyncJob = model_run.predict(
281+
annotations=[prediction_semseg, prediction_polygon, prediction_bbox],
282+
asynchronous=True,
283+
)
284+
job.sleep_until_complete()
285+
print(job.status())
286+
print(job.errors())
287+
288+
assert job.status() == {
289+
"job_id": job.id,
290+
"status": "Completed",
291+
"message": {
292+
"annotation_upload": {
293+
"epoch": 1,
294+
"total": 2,
295+
"errored": 0,
296+
"ignored": 0,
297+
"datasetId": model_run._dataset_id,
298+
"processed": 2,
299+
},
300+
"segmentation_upload": {
301+
"errors": [],
302+
"ignored": 0,
303+
"n_errors": 0,
304+
"processed": 1,
305+
},
306+
},
307+
}
308+
309+
310+
def test_mixed_pred_upload_async_with_error(model_run: ModelRun):
311+
prediction_semseg = SegmentationPrediction.from_json(
312+
TEST_SEGMENTATION_PREDICTIONS[0]
313+
)
314+
prediction_polygon = PolygonPrediction(**TEST_POLYGON_PREDICTIONS[0])
315+
prediction_bbox = BoxPrediction(**TEST_BOX_PREDICTIONS[0])
316+
prediction_bbox.reference_id = "fake_garbage"
317+
318+
job: AsyncJob = model_run.predict(
319+
annotations=[prediction_semseg, prediction_polygon, prediction_bbox],
320+
asynchronous=True,
321+
)
322+
job.sleep_until_complete()
323+
324+
assert job.status() == {
325+
"job_id": job.id,
326+
"status": "Completed",
327+
"message": {
328+
"annotation_upload": {
329+
"epoch": 1,
330+
"total": 2,
331+
"errored": 0,
332+
"ignored": 0,
333+
"datasetId": model_run._dataset_id,
334+
"processed": 1,
335+
},
336+
"segmentation_upload": {
337+
"errors": [],
338+
"ignored": 0,
339+
"n_errors": 0,
340+
"processed": 1,
341+
},
342+
},
343+
}
344+
345+
assert "Item with id fake_garbage doesn" in str(job.errors())

0 commit comments

Comments
 (0)