Skip to content

Commit f111c33

Browse files
author
Diego Ardila
committed
This just might work
1 parent 37afc9f commit f111c33

File tree

6 files changed

+65
-21
lines changed

6 files changed

+65
-21
lines changed

nucleus/annotation.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import json
22
from dataclasses import dataclass
33
from enum import Enum
4-
from typing import Any, Dict, List, Optional, Union
4+
from typing import Any, Dict, List, Optional, Sequence, Union
5+
from nucleus.dataset_item import is_local_path
56

67
from .constants import (
78
ANNOTATION_ID_KEY,
@@ -13,6 +14,7 @@
1314
INDEX_KEY,
1415
ITEM_ID_KEY,
1516
LABEL_KEY,
17+
MASK_TYPE,
1618
MASK_URL_KEY,
1719
METADATA_KEY,
1820
POLYGON_TYPE,
@@ -108,6 +110,7 @@ def from_json(cls, payload: dict):
108110

109111
def to_payload(self) -> dict:
110112
payload = {
113+
TYPE_KEY: MASK_TYPE,
111114
MASK_URL_KEY: self.mask_url,
112115
ANNOTATIONS_KEY: [ann.to_payload() for ann in self.annotations],
113116
ANNOTATION_ID_KEY: self.annotation_id,
@@ -206,3 +209,14 @@ def to_payload(self) -> dict:
206209
ANNOTATION_ID_KEY: self.annotation_id,
207210
METADATA_KEY: self.metadata,
208211
}
212+
213+
214+
def check_all_annotation_paths_remote(
215+
annotations: Sequence[Union[Annotation]],
216+
):
217+
for annotation in annotations:
218+
if hasattr(annotation, "mask_url"):
219+
if is_local_path(getattr(annotation, "mask_url")):
220+
raise ValueError(
221+
f"Found an annotation with a local path, which cannot be uploaded asynchronously. Use a remote path instead. {annotation}"
222+
)

nucleus/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
ANNOTATION_METADATA_SCHEMA_KEY = "annotation_metadata_schema"
66
BOX_TYPE = "box"
77
POLYGON_TYPE = "polygon"
8+
MASK_TYPE = "mask"
89
SEGMENTATION_TYPE = "segmentation"
910
ANNOTATION_TYPES = (BOX_TYPE, POLYGON_TYPE, SEGMENTATION_TYPE)
1011
ANNOTATION_UPDATE_KEY = "update"

nucleus/dataset.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import uuid
21
from typing import Any, Dict, List, Optional, Union
32

43
import requests
@@ -24,7 +23,6 @@
2423
)
2524
from .dataset_item import (
2625
DatasetItem,
27-
check_all_paths_remote,
2826
check_for_duplicate_reference_ids,
2927
)
3028
from .payload_constructor import construct_model_run_creation_payload
@@ -198,15 +196,8 @@ def append(
198196
check_for_duplicate_reference_ids(dataset_items)
199197

200198
if asynchronous:
201-
check_all_paths_remote(dataset_items)
202-
request_id = uuid.uuid4().hex
203-
response = self._client.make_request(
204-
payload={},
205-
route=f"dataset/{self.id}/signedUrl/{request_id}",
206-
requests_command=requests.get,
207-
)
208-
serialize_and_write_to_presigned_url(
209-
dataset_items, response["signed_url"]
199+
request_id = serialize_and_write_to_presigned_url(
200+
dataset_items, self.id, self._client
210201
)
211202
response = self._client.make_request(
212203
payload={REQUEST_ID_KEY: request_id, UPDATE_KEY: update},

nucleus/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def create_run(
4545
Union[BoxPrediction, PolygonPrediction, SegmentationPrediction]
4646
],
4747
metadata: Optional[Dict] = None,
48+
asynchronous=False,
4849
) -> ModelRun:
4950
payload: dict = {
5051
NAME_KEY: name,
@@ -56,6 +57,6 @@ def create_run(
5657
dataset.id, payload
5758
)
5859

59-
model_run.predict(predictions)
60+
model_run.predict(predictions, asynchronous=asynchronous)
6061

6162
return model_run

nucleus/model_run.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
1-
from typing import Dict, Optional, List, Union, Type
1+
from typing import Dict, List, Optional, Type, Union
2+
3+
from nucleus.annotation import check_all_annotation_paths_remote
4+
from nucleus.job import AsyncJob
5+
from nucleus.utils import serialize_and_write_to_presigned_url
6+
27
from .constants import (
38
ANNOTATIONS_KEY,
4-
DEFAULT_ANNOTATION_UPDATE_MODE,
59
BOX_TYPE,
10+
DEFAULT_ANNOTATION_UPDATE_MODE,
11+
MASK_TYPE,
612
POLYGON_TYPE,
7-
SEGMENTATION_TYPE,
13+
REQUEST_ID_KEY,
14+
UPDATE_KEY,
815
)
916
from .prediction import (
1017
BoxPrediction,
@@ -84,7 +91,9 @@ def predict(
8491
Union[BoxPrediction, PolygonPrediction, SegmentationPrediction]
8592
],
8693
update: Optional[bool] = DEFAULT_ANNOTATION_UPDATE_MODE,
87-
) -> dict:
94+
asynchronous: bool = False,
95+
dataset_id: Optional[str] = None,
96+
) -> Union[dict, AsyncJob]:
8897
"""
8998
Uploads model outputs as predictions for a model_run. Returns info about the upload.
9099
:param annotations: List[Union[BoxPrediction, PolygonPrediction]],
@@ -95,7 +104,24 @@ def predict(
95104
"predictions_ignored": int,
96105
}
97106
"""
98-
return self._client.predict(self.model_run_id, annotations, update)
107+
if asynchronous:
108+
check_all_annotation_paths_remote(annotations)
109+
110+
assert (
111+
dataset_id is not None
112+
), "For now, you must pass a dataset id to predict for asynchronous uploads."
113+
114+
request_id = serialize_and_write_to_presigned_url(
115+
annotations, dataset_id, self._client
116+
)
117+
response = self._client.make_request(
118+
payload={REQUEST_ID_KEY: request_id, UPDATE_KEY: update},
119+
route=f"modelRun/{self.model_run_id}/predict?async=1",
120+
)
121+
122+
return AsyncJob(response["job_id"], self._client)
123+
else:
124+
return self._client.predict(self.model_run_id, annotations, update)
99125

100126
def iloc(self, i: int):
101127
"""
@@ -153,7 +179,7 @@ def _format_prediction_response(
153179
] = {
154180
BOX_TYPE: BoxPrediction,
155181
POLYGON_TYPE: PolygonPrediction,
156-
SEGMENTATION_TYPE: SegmentationPrediction,
182+
MASK_TYPE: SegmentationPrediction,
157183
}
158184
for type_key in annotation_payload:
159185
type_class = type_key_to_class[type_key]

nucleus/utils.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33

44
import io
5+
import uuid
56
from typing import IO, Dict, List, Sequence, Union
67

78
import requests
@@ -104,9 +105,19 @@ def upload_to_presigned_url(presigned_url: str, file_pointer: IO):
104105

105106

106107
def serialize_and_write_to_presigned_url(
107-
upload_units: Sequence[Union[DatasetItem, Annotation]], presigned_url
108+
upload_units: Sequence[Union["DatasetItem", Annotation]],
109+
dataset_id: str,
110+
client,
108111
):
112+
request_id = uuid.uuid4().hex
113+
response = client.make_request(
114+
payload={},
115+
route=f"dataset/{dataset_id}/signedUrl/{request_id}",
116+
requests_command=requests.get,
117+
)
118+
109119
strio = io.StringIO()
110120
serialize_and_write(upload_units, strio)
111121
strio.seek(0)
112-
upload_to_presigned_url(presigned_url, strio)
122+
upload_to_presigned_url(response["presigned_url"], strio)
123+
return request_id

0 commit comments

Comments
 (0)