
Commit b664115

Merge pull request #69 from scaleapi/da/predictions_async

Async prediction upload

2 parents 37afc9f + 0ce4ca2

12 files changed: +211 additions, −83 deletions

nucleus/__init__.py

Lines changed: 52 additions & 52 deletions

@@ -50,89 +50,83 @@
     geometry | dict | Representation of the bounding box in the Box2DGeometry format.\n
     metadata | dict | An arbitrary metadata blob for the annotation.\n
 """
-__version__ = "0.1.0"
-
 import json
 import logging
-import warnings
 import os
-from typing import List, Union, Dict, Callable, Any, Optional
-
-import tqdm
-import tqdm.notebook as tqdm_notebook
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import grequests
+import pkg_resources
 import requests
+import tqdm
+import tqdm.notebook as tqdm_notebook
 from requests.adapters import HTTPAdapter
 
 # pylint: disable=E1101
 # TODO: refactor to reduce this file to under 1000 lines.
 # pylint: disable=C0302
 from requests.packages.urllib3.util.retry import Retry
 
-from .constants import REFERENCE_IDS_KEY, DATASET_ITEM_IDS_KEY, UPDATE_KEY
-from .dataset import Dataset
-from .dataset_item import DatasetItem
 from .annotation import (
     BoxAnnotation,
     PolygonAnnotation,
-    SegmentationAnnotation,
     Segment,
-)
-from .prediction import (
-    BoxPrediction,
-    PolygonPrediction,
-    SegmentationPrediction,
-)
-from .model_run import ModelRun
-from .slice import Slice
-from .upload_response import UploadResponse
-from .payload_constructor import (
-    construct_append_payload,
-    construct_annotation_payload,
-    construct_model_creation_payload,
-    construct_box_predictions_payload,
-    construct_segmentation_payload,
+    SegmentationAnnotation,
 )
 from .constants import (
-    NUCLEUS_ENDPOINT,
+    ANNOTATION_METADATA_SCHEMA_KEY,
+    ANNOTATIONS_IGNORED_KEY,
+    ANNOTATIONS_PROCESSED_KEY,
+    AUTOTAGS_KEY,
+    DATASET_ID_KEY,
+    DATASET_ITEM_IDS_KEY,
     DEFAULT_NETWORK_TIMEOUT_SEC,
-    ERRORS_KEY,
+    EMBEDDINGS_URL_KEY,
     ERROR_ITEMS,
     ERROR_PAYLOAD,
-    ITEMS_KEY,
-    ITEM_KEY,
+    ERRORS_KEY,
     IMAGE_KEY,
     IMAGE_URL_KEY,
-    DATASET_ID_KEY,
+    ITEM_METADATA_SCHEMA_KEY,
+    ITEMS_KEY,
     MODEL_RUN_ID_KEY,
-    DATASET_ITEM_ID_KEY,
-    SLICE_ID_KEY,
-    ANNOTATIONS_PROCESSED_KEY,
-    ANNOTATIONS_IGNORED_KEY,
-    PREDICTIONS_PROCESSED_KEY,
+    NAME_KEY,
+    NUCLEUS_ENDPOINT,
     PREDICTIONS_IGNORED_KEY,
+    PREDICTIONS_PROCESSED_KEY,
+    REFERENCE_IDS_KEY,
+    SLICE_ID_KEY,
     STATUS_CODE_KEY,
-    SUCCESS_STATUS_CODES,
-    DATASET_NAME_KEY,
-    DATASET_MODEL_RUNS_KEY,
-    DATASET_SLICES_KEY,
-    DATASET_LENGTH_KEY,
-    NAME_KEY,
-    ANNOTATIONS_KEY,
-    AUTOTAGS_KEY,
-    ANNOTATION_METADATA_SCHEMA_KEY,
-    ITEM_METADATA_SCHEMA_KEY,
-    EMBEDDINGS_URL_KEY,
+    UPDATE_KEY,
 )
-from .model import Model
+from .dataset import Dataset
+from .dataset_item import DatasetItem
 from .errors import (
+    DatasetItemRetrievalError,
     ModelCreationError,
     ModelRunCreationError,
-    DatasetItemRetrievalError,
     NotFoundError,
     NucleusAPIError,
 )
+from .model import Model
+from .model_run import ModelRun
+from .payload_constructor import (
+    construct_annotation_payload,
+    construct_append_payload,
+    construct_box_predictions_payload,
+    construct_model_creation_payload,
+    construct_segmentation_payload,
+)
+from .prediction import (
+    BoxPrediction,
+    PolygonPrediction,
+    SegmentationPrediction,
+)
+from .slice import Slice
+from .upload_response import UploadResponse
+
+__version__ = pkg_resources.get_distribution("scale-nucleus").version
 
 logger = logging.getLogger(__name__)
 logging.basicConfig()
@@ -158,6 +152,8 @@ def __init__(
             self.endpoint = os.environ.get(
                 "NUCLEUS_ENDPOINT", NUCLEUS_ENDPOINT
             )
+        else:
+            self.endpoint = endpoint
         self._use_notebook = use_notebook
         if use_notebook:
             self.tqdm_bar = tqdm_notebook.tqdm
@@ -230,13 +226,15 @@ def get_dataset(self, dataset_id: str) -> Dataset:
         """
         return Dataset(dataset_id, self)
 
-    def get_model_run(self, model_run_id: str) -> ModelRun:
+    def get_model_run(self, model_run_id: str, dataset_id: str) -> ModelRun:
         """
        Fetches a model_run for given id
        :param model_run_id: internally controlled model_run_id
+        :param dataset_id: the dataset id which may determine the prediction schema
+            for this model run if present on the dataset.
        :return: model_run
        """
-        return ModelRun(model_run_id, self)
+        return ModelRun(model_run_id, dataset_id, self)
 
     def delete_model_run(self, model_run_id: str):
         """
@@ -674,7 +672,9 @@ def create_model_run(self, dataset_id: str, payload: dict) -> ModelRun:
         if response.get(STATUS_CODE_KEY, None):
             raise ModelRunCreationError(response.get("error"))
 
-        return ModelRun(response[MODEL_RUN_ID_KEY], self)
+        return ModelRun(
+            response[MODEL_RUN_ID_KEY], dataset_id=dataset_id, client=self
+        )
 
     def predict(
         self,
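
For context on the signature changes above, a minimal sketch of the new call shape; the API key and both IDs are placeholders, not real values:

    import nucleus

    client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")

    # get_model_run now requires the dataset id as well, since a prediction
    # schema present on the dataset applies to the model run.
    model_run = client.get_model_run("model_run_id", "dataset_id")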

nucleus/annotation.py

Lines changed: 15 additions & 1 deletion

@@ -1,7 +1,8 @@
 import json
 from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Sequence, Union
+from nucleus.dataset_item import is_local_path
 
 from .constants import (
     ANNOTATION_ID_KEY,
@@ -13,6 +14,7 @@
     INDEX_KEY,
     ITEM_ID_KEY,
     LABEL_KEY,
+    MASK_TYPE,
     MASK_URL_KEY,
     METADATA_KEY,
     POLYGON_TYPE,
@@ -108,6 +110,7 @@ def from_json(cls, payload: dict):
 
     def to_payload(self) -> dict:
         payload = {
+            TYPE_KEY: MASK_TYPE,
             MASK_URL_KEY: self.mask_url,
             ANNOTATIONS_KEY: [ann.to_payload() for ann in self.annotations],
             ANNOTATION_ID_KEY: self.annotation_id,
@@ -206,3 +209,14 @@ def to_payload(self) -> dict:
             ANNOTATION_ID_KEY: self.annotation_id,
             METADATA_KEY: self.metadata,
         }
+
+
+def check_all_annotation_paths_remote(
+    annotations: Sequence[Union[Annotation]],
+):
+    for annotation in annotations:
+        if hasattr(annotation, MASK_URL_KEY):
+            if is_local_path(getattr(annotation, MASK_URL_KEY)):
+                raise ValueError(
+                    f"Found an annotation with a local path, which cannot be uploaded asynchronously. Use a remote path instead. {annotation}"
+                )
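
The guard only inspects a mask_url attribute, so its behavior can be shown with a stand-in object. A sketch, assuming MASK_URL_KEY is the string "mask_url" and that is_local_path treats scheme-less paths as local; FakeSegmentation is a hypothetical stand-in, not a class from this repo:

    from dataclasses import dataclass

    from nucleus.annotation import check_all_annotation_paths_remote

    @dataclass
    class FakeSegmentation:
        mask_url: str  # the only attribute the guard looks at

    # A remote URL passes silently.
    check_all_annotation_paths_remote([FakeSegmentation("https://example.com/mask.png")])

    # A local path is rejected before any async upload is attempted.
    try:
        check_all_annotation_paths_remote([FakeSegmentation("/tmp/mask.png")])
    except ValueError as err:
        print(err)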

nucleus/constants.py

Lines changed: 1 addition & 0 deletions

@@ -5,6 +5,7 @@
 ANNOTATION_METADATA_SCHEMA_KEY = "annotation_metadata_schema"
 BOX_TYPE = "box"
 POLYGON_TYPE = "polygon"
+MASK_TYPE = "mask"
 SEGMENTATION_TYPE = "segmentation"
 ANNOTATION_TYPES = (BOX_TYPE, POLYGON_TYPE, SEGMENTATION_TYPE)
 ANNOTATION_UPDATE_KEY = "update"
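
With MASK_TYPE defined, SegmentationAnnotation.to_payload (see nucleus/annotation.py above) now emits an explicit type discriminator. A hedged sketch of the resulting payload shape; every key string except "mask" is assumed from the constant names rather than confirmed, and the field values are illustrative:

    payload = {
        "type": "mask",  # TYPE_KEY: MASK_TYPE, new in this PR
        "mask_url": "https://example.com/mask.png",
        "annotations": [{"label": "car", "index": 1}],
        "annotation_id": None,
    }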

nucleus/dataset.py

Lines changed: 2 additions & 9 deletions

@@ -1,4 +1,3 @@
-import uuid
 from typing import Any, Dict, List, Optional, Union
 
 import requests
@@ -199,14 +198,8 @@ def append(
 
         if asynchronous:
             check_all_paths_remote(dataset_items)
-            request_id = uuid.uuid4().hex
-            response = self._client.make_request(
-                payload={},
-                route=f"dataset/{self.id}/signedUrl/{request_id}",
-                requests_command=requests.get,
-            )
-            serialize_and_write_to_presigned_url(
-                dataset_items, response["signed_url"]
+            request_id = serialize_and_write_to_presigned_url(
+                dataset_items, self.id, self._client
             )
             response = self._client.make_request(
                 payload={REQUEST_ID_KEY: request_id, UPDATE_KEY: update},
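
On the caller side, the async append path now reads roughly like this; a sketch, assuming a dataset object in hand, with a placeholder remote URL and reference id. The returned job exposes id and errors(), as exercised in tests/test_dataset.py below:

    from nucleus.dataset_item import DatasetItem

    items = [
        DatasetItem(
            image_location="https://example.com/image_1.jpg",  # must be remote
            reference_id="image_1",
        )
    ]

    # asynchronous=True returns a job handle instead of blocking on upload.
    job = dataset.append(items, asynchronous=True)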

nucleus/model.py

Lines changed: 2 additions & 1 deletion

@@ -45,6 +45,7 @@ def create_run(
             Union[BoxPrediction, PolygonPrediction, SegmentationPrediction]
         ],
         metadata: Optional[Dict] = None,
+        asynchronous: bool = False,
     ) -> ModelRun:
         payload: dict = {
             NAME_KEY: name,
@@ -56,6 +57,6 @@ def create_run(
             dataset.id, payload
         )
 
-        model_run.predict(predictions)
+        model_run.predict(predictions, asynchronous=asynchronous)
 
         return model_run
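
This threads the flag through from run creation; a sketch, assuming model and dataset objects and a predictions list with remote asset URLs are already in hand:

    model_run = model.create_run(
        name="my-model-run",
        dataset=dataset,
        predictions=predictions,
        asynchronous=True,  # forwarded straight to ModelRun.predict
    )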

nucleus/model_run.py

Lines changed: 29 additions & 6 deletions

@@ -1,10 +1,18 @@
-from typing import Dict, Optional, List, Union, Type
+from typing import Dict, List, Optional, Type, Union
+
+from nucleus.annotation import check_all_annotation_paths_remote
+from nucleus.job import AsyncJob
+from nucleus.utils import serialize_and_write_to_presigned_url
+
 from .constants import (
     ANNOTATIONS_KEY,
-    DEFAULT_ANNOTATION_UPDATE_MODE,
     BOX_TYPE,
+    DEFAULT_ANNOTATION_UPDATE_MODE,
+    JOB_ID_KEY,
     POLYGON_TYPE,
+    REQUEST_ID_KEY,
     SEGMENTATION_TYPE,
+    UPDATE_KEY,
 )
 from .prediction import (
     BoxPrediction,
@@ -19,12 +27,13 @@ class ModelRun:
     Having an open model run is a prerequisite for uploading predictions to your dataset.
     """
 
-    def __init__(self, model_run_id: str, client):
+    def __init__(self, model_run_id: str, dataset_id: str, client):
         self.model_run_id = model_run_id
         self._client = client
+        self._dataset_id = dataset_id
 
     def __repr__(self):
-        return f"ModelRun(model_run_id='{self.model_run_id}', client={self._client})"
+        return f"ModelRun(model_run_id='{self.model_run_id}', dataset_id='{self._dataset_id}', client={self._client})"
 
     def __eq__(self, other):
         if self.model_run_id == other.model_run_id:
@@ -84,7 +93,8 @@ def predict(
             Union[BoxPrediction, PolygonPrediction, SegmentationPrediction]
         ],
         update: Optional[bool] = DEFAULT_ANNOTATION_UPDATE_MODE,
-    ) -> dict:
+        asynchronous: bool = False,
+    ) -> Union[dict, AsyncJob]:
         """
         Uploads model outputs as predictions for a model_run. Returns info about the upload.
         :param annotations: List[Union[BoxPrediction, PolygonPrediction]],
@@ -95,7 +105,20 @@ def predict(
             "predictions_ignored": int,
         }
         """
-        return self._client.predict(self.model_run_id, annotations, update)
+        if asynchronous:
+            check_all_annotation_paths_remote(annotations)
+
+            request_id = serialize_and_write_to_presigned_url(
+                annotations, self._dataset_id, self._client
+            )
+            response = self._client.make_request(
+                payload={REQUEST_ID_KEY: request_id, UPDATE_KEY: update},
+                route=f"modelRun/{self.model_run_id}/predict?async=1",
+            )
+
+            return AsyncJob(response[JOB_ID_KEY], self._client)
+        else:
+            return self._client.predict(self.model_run_id, annotations, update)
 
     def iloc(self, i: int):
         """

nucleus/utils.py

Lines changed: 13 additions & 2 deletions

@@ -2,6 +2,7 @@
 
 
 import io
+import uuid
 from typing import IO, Dict, List, Sequence, Union
 
 import requests
@@ -104,9 +105,19 @@ def upload_to_presigned_url(presigned_url: str, file_pointer: IO):
 
 
 def serialize_and_write_to_presigned_url(
-    upload_units: Sequence[Union[DatasetItem, Annotation]], presigned_url
+    upload_units: Sequence[Union["DatasetItem", Annotation]],
+    dataset_id: str,
+    client,
 ):
+    request_id = uuid.uuid4().hex
+    response = client.make_request(
+        payload={},
+        route=f"dataset/{dataset_id}/signedUrl/{request_id}",
+        requests_command=requests.get,
+    )
+
     strio = io.StringIO()
     serialize_and_write(upload_units, strio)
     strio.seek(0)
-    upload_to_presigned_url(presigned_url, strio)
+    upload_to_presigned_url(response["signed_url"], strio)
+    return request_id
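
The helper now owns the full presigned-URL handshake, which is why Dataset.append (above) and ModelRun.predict collapse to the same two steps. A sketch of the consuming pattern, assuming predictions, dataset_id, model_run_id, and client are already in scope; the route string is the one added in nucleus/model_run.py:

    from nucleus.constants import REQUEST_ID_KEY, UPDATE_KEY
    from nucleus.utils import serialize_and_write_to_presigned_url

    # 1. Serialize locally, upload to a presigned URL, get back a request id.
    request_id = serialize_and_write_to_presigned_url(
        predictions, dataset_id, client
    )

    # 2. Hand the request id (not a URL) to the async endpoint.
    response = client.make_request(
        payload={REQUEST_ID_KEY: request_id, UPDATE_KEY: False},
        route=f"modelRun/{model_run_id}/predict?async=1",
    )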

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -21,7 +21,7 @@ exclude = '''
 
 [tool.poetry]
 name = "scale-nucleus"
-version = "0.1.4"
+version = "0.1.5"
 description = "The official Python client library for Nucleus, the Data Platform for AI"
 license = "MIT"
 authors = ["Scale AI Nucleus Team <nucleusapi@scaleapi.com>"]

tests/helpers.py

Lines changed: 5 additions & 5 deletions

@@ -13,11 +13,11 @@
 
 
 TEST_IMG_URLS = [
-    "http://farm1.staticflickr.com/107/309278012_7a1f67deaa_z.jpg",
-    "http://farm9.staticflickr.com/8001/7679588594_4e51b76472_z.jpg",
-    "http://farm6.staticflickr.com/5295/5465771966_76f9773af1_z.jpg",
-    "http://farm4.staticflickr.com/3449/4002348519_8ddfa4f2fb_z.jpg",
-    "http://farm1.staticflickr.com/6/7617223_d84fcbce0e_z.jpg",
+    "https://homepages.cae.wisc.edu/~ece533/images/airplane.png",
+    "https://homepages.cae.wisc.edu/~ece533/images/arctichare.png",
+    "https://homepages.cae.wisc.edu/~ece533/images/baboon.png",
+    "https://homepages.cae.wisc.edu/~ece533/images/barbara.png",
+    "https://homepages.cae.wisc.edu/~ece533/images/cat.png",
 ]
 
 TEST_DATASET_ITEMS = [

tests/test_dataset.py

Lines changed: 5 additions & 5 deletions

@@ -190,11 +190,11 @@ def test_dataset_append_async_with_1_bad_url(dataset: Dataset):
             "started_image_processing": f"Dataset: {dataset.id}, Job: {job.id}",
         },
     }
-    assert job.errors() == [
-        "One or more of the images you attempted to upload did not process correctly. Please see the status for an overview and the errors for more detailed messages.",
-        # Todo: figure out why this error isn't propagating from image upload.
-        'Failure when processing the image "https://looks.ok.but.is.not.accessible": {}',
-    ]
+    # The error is fairly detailed and subject to change. What's important is we surface which URLs failed.
+    assert (
+        'Failure when processing the image "https://looks.ok.but.is.not.accessible"'
+        in str(job.errors())
+    )
 
 
 def test_dataset_list_autotags(CLIENT, dataset):
