
Commit ffb2c08

Author: Ubuntu (committed)
Model-only endpoints
1 parent 58b83ab commit ffb2c08

9 files changed: +225 −94 lines changed

nucleus/__init__.py

Lines changed: 15 additions & 7 deletions
@@ -734,7 +734,9 @@ def create_model_run(self, dataset_id: str, payload: dict) -> ModelRun:
 
     def predict(
         self,
-        model_run_id: str,
+        model_run_id: Optional[str],
+        model_id: Optional[str],
+        dataset_id: Optional[str],
         annotations: List[
             Union[
                 BoxPrediction,
@@ -758,6 +760,16 @@ def predict(
             "predictions_ignored": int,
         }
         """
+        if model_run_id is not None:
+            assert model_id is None and dataset_id is None
+            endpoint = f"modelRun/{model_run_id}/predict"
+        else:
+            assert (
+                model_id is not None and dataset_id is not None
+            ), "Model ID and dataset ID are required if not using model run id."
+            endpoint = (
+                f"dataset/{dataset_id}/model/{model_id}/uploadPredictions"
+            )
         segmentations = [
             ann
             for ann in annotations
@@ -793,9 +805,7 @@ def predict(
                 batch,
                 update,
             )
-            response = self.make_request(
-                batch_payload, f"modelRun/{model_run_id}/predict"
-            )
+            response = self.make_request(batch_payload, endpoint)
             if STATUS_CODE_KEY in response:
                 agg_response[ERRORS_KEY] = response
             else:
@@ -808,9 +818,7 @@ def predict(
 
         for s_batch in s_batches:
            payload = construct_segmentation_payload(s_batch, update)
-           response = self.make_request(
-               payload, f"modelRun/{model_run_id}/predict_segmentation"
-           )
+           response = self.make_request(payload, endpoint)
            # pbar.update(1)
            if STATUS_CODE_KEY in response:
                agg_response[ERRORS_KEY] = response
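
With this change, predict resolves its upload route from whichever identifier is supplied: a model run ID keeps the existing modelRun/{model_run_id}/predict route, while a model ID plus dataset ID targets the new dataset/{dataset_id}/model/{model_id}/uploadPredictions route. A minimal usage sketch against the new signature (the API key, IDs, and the predictions list are illustrative placeholders, not values from this commit):

from nucleus import NucleusClient

client = NucleusClient("YOUR_API_KEY")  # placeholder API key
predictions = []  # fill with BoxPrediction / PolygonPrediction / CuboidPrediction / SegmentationPrediction objects

# Model-run path: routed to modelRun/{model_run_id}/predict.
client.predict(
    model_run_id="run_123",  # placeholder model run ID
    model_id=None,
    dataset_id=None,
    annotations=predictions,
    update=False,
)

# Model-only path: routed to dataset/{dataset_id}/model/{model_id}/uploadPredictions.
client.predict(
    model_run_id=None,
    model_id="model_456",  # placeholder model ID
    dataset_id="ds_789",   # placeholder dataset ID
    annotations=predictions,
    update=False,
)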

nucleus/dataset.py

Lines changed: 92 additions & 6 deletions
@@ -3,18 +3,19 @@
 import requests
 
 from nucleus.job import AsyncJob
+from nucleus.prediction import from_json
 from nucleus.url_utils import sanitize_string_args
 from nucleus.utils import (
     convert_export_payload,
     format_dataset_item_response,
+    format_prediction_response,
     serialize_and_write_to_presigned_url,
 )
 
-from .annotation import (
-    Annotation,
-    check_all_mask_paths_remote,
-)
+from .annotation import Annotation, check_all_mask_paths_remote
 from .constants import (
+    ANNOTATIONS_KEY,
+    AUTOTAG_SCORE_THRESHOLD,
     DATASET_LENGTH_KEY,
     DATASET_MODEL_RUNS_KEY,
     DATASET_NAME_KEY,
@@ -24,20 +25,19 @@
     NAME_KEY,
     REFERENCE_IDS_KEY,
     REQUEST_ID_KEY,
-    AUTOTAG_SCORE_THRESHOLD,
     UPDATE_KEY,
 )
 from .dataset_item import (
     DatasetItem,
     check_all_paths_remote,
     check_for_duplicate_reference_ids,
 )
-from .scene import LidarScene, Scene, check_all_scene_paths_remote
 from .payload_constructor import (
     construct_append_scenes_payload,
     construct_model_run_creation_payload,
     construct_taxonomy_payload,
 )
+from .scene import LidarScene, Scene, check_all_scene_paths_remote
 
 WARN_FOR_LARGE_UPLOAD = 50000
 WARN_FOR_LARGE_SCENES_UPLOAD = 5
@@ -525,3 +525,89 @@ def get_scene(self, reference_id) -> Scene:
                 requests_command=requests.get,
             )
         )
+
+    def export_predictions(self, model):
+        json_response = self._client.make_request(
+            payload=None,
+            route=f"dataset/{self.id}/model/{model.id}/export",
+            requests_command=requests.get,
+        )
+        return format_prediction_response({ANNOTATIONS_KEY: json_response})
+
+    def calculate_evaluation_metrics(self, model, options=None):
+        """
+        class_agnostic -- A flag to specify if the matching algorithm should be class-agnostic or not.
+                          Default value: True
+
+        allowed_label_matches -- An optional list of AllowedMatch objects to specify allowed matches
+                                 for ground truth and model predictions.
+                                 If specified, the 'class_agnostic' flag is assumed to be False.
+
+        Type 'AllowedMatch':
+        {
+            ground_truth_label: string,      # A label for a ground truth annotation.
+            model_prediction_label: string,  # A label for a model prediction that can be matched with
+                                             # the corresponding ground truth label.
+        }
+
+        payload:
+        {
+            "class_agnostic": boolean,
+            "allowed_label_matches": List[AllowedMatch],
+        }"""
+        if options is None:
+            options = {}
+        return self._client.make_request(
+            payload=options,
+            route=f"dataset/{self.id}/model/{model.id}/calculateEvaluationMetrics",
+        )
+
+    def upload_predictions(
+        self, model, predictions, update=False, asynchronous=False
+    ):
+        if asynchronous:
+            check_all_mask_paths_remote(predictions)
+
+            request_id = serialize_and_write_to_presigned_url(
+                predictions, self.id, self._client
+            )
+            response = self._client.make_request(
+                payload={REQUEST_ID_KEY: request_id, UPDATE_KEY: update},
+                route=f"dataset/{self.id}/model/{model.id}/uploadPredictions?async=1",
+            )
+            return AsyncJob.from_json(response, self._client)
+        else:
+            return self._client.predict(
+                model_run_id=None,
+                dataset_id=self.id,
+                model_id=model.id,
+                annotations=predictions,
+                update=update,
+            )
+
+    def predictions_iloc(self, model, index):
+        return format_prediction_response(
+            self._client.make_request(
+                payload=None,
+                route=f"dataset/{self.id}/model/{model.id}/iloc/{index}",
+                requests_command=requests.get,
+            )
+        )
+
+    def predictions_refloc(self, model, reference_id):
+        return format_prediction_response(
+            self._client.make_request(
+                payload=None,
+                route=f"dataset/{self.id}/model/{model.id}/referenceId/{reference_id}",
+                requests_command=requests.get,
+            )
+        )
+
+    def prediction_loc(self, model, reference_id, annotation_id):
+        return from_json(
+            self._client.make_request(
+                payload=None,
+                route=f"dataset/{self.id}/model/{model.id}/loc/{reference_id}/{annotation_id}",
+                requests_command=requests.get,
+            )
+        )
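
The new Dataset methods wrap these model-keyed routes so predictions can be uploaded and read back without creating a model run. A hedged sketch of how they might be called (the API key, dataset and reference IDs, the model object, and the predictions list are placeholders, not values from this commit):

from nucleus import NucleusClient

client = NucleusClient("YOUR_API_KEY")   # placeholder API key
dataset = client.get_dataset("ds_789")   # placeholder dataset ID
model = ...                               # an existing nucleus Model object, however your workflow obtains it
predictions = []                          # BoxPrediction / PolygonPrediction / CuboidPrediction / SegmentationPrediction

# Synchronous upload: dispatches through client.predict with model_id and dataset_id.
dataset.upload_predictions(model, predictions, update=False, asynchronous=False)

# Asynchronous upload: serializes to a presigned URL and returns an AsyncJob.
job = dataset.upload_predictions(model, predictions, asynchronous=True)

# Read predictions back and score them, keyed by model.
dataset.export_predictions(model)                 # all predictions for the model
dataset.predictions_refloc(model, "img_1")        # by item reference ID (placeholder)
dataset.prediction_loc(model, "img_1", "ann_1")   # single prediction (placeholder IDs)
dataset.calculate_evaluation_metrics(model)       # default options: class-agnostic matching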

nucleus/model.py

Lines changed: 6 additions & 1 deletion
@@ -1,4 +1,4 @@
-from typing import List, Optional, Dict, Union
+from typing import List, Optional, Dict, Type, Union
 from .dataset import Dataset
 from .prediction import (
     BoxPrediction,
@@ -8,9 +8,14 @@
 )
 from .model_run import ModelRun
 from .constants import (
+    ANNOTATIONS_KEY,
+    BOX_TYPE,
+    CUBOID_TYPE,
     NAME_KEY,
+    POLYGON_TYPE,
     REFERENCE_ID_KEY,
     METADATA_KEY,
+    SEGMENTATION_TYPE,
 )
 
 

nucleus/model_run.py

Lines changed: 16 additions & 53 deletions
@@ -1,17 +1,18 @@
-from typing import Dict, List, Optional, Type, Union
+from typing import List, Optional, Union
+
 import requests
+
 from nucleus.annotation import check_all_mask_paths_remote
 from nucleus.job import AsyncJob
-from nucleus.utils import serialize_and_write_to_presigned_url
+from nucleus.utils import (
+    format_prediction_response,
+    serialize_and_write_to_presigned_url,
+)
 
 from .constants import (
     ANNOTATIONS_KEY,
-    BOX_TYPE,
-    CUBOID_TYPE,
     DEFAULT_ANNOTATION_UPDATE_MODE,
-    POLYGON_TYPE,
     REQUEST_ID_KEY,
-    SEGMENTATION_TYPE,
     UPDATE_KEY,
 )
 from .prediction import (
@@ -124,7 +125,11 @@ def predict(
             )
             return AsyncJob.from_json(response, self._client)
         else:
-            return self._client.predict(self.model_run_id, annotations, update)
+            return self._client.predict(
+                model_run_id=self.model_run_id,
+                annotations=annotations,
+                update=update,
+            )
 
     def iloc(self, i: int):
         """
@@ -134,7 +139,7 @@ def iloc(self, i: int):
         }
         """
         response = self._client.predictions_iloc(self.model_run_id, i)
-        return self._format_prediction_response(response)
+        return format_prediction_response(response)
 
     def refloc(self, reference_id: str):
         """
@@ -145,7 +150,7 @@ def refloc(self, reference_id: str):
         """
         response = self._client.predictions_ref_id(
             self.model_run_id, reference_id
         )
-        return self._format_prediction_response(response)
+        return format_prediction_response(response)
 
     def loc(self, dataset_item_id: str):
         """
@@ -159,7 +164,7 @@ def loc(self, dataset_item_id: str):
         """
         response = self._client.predictions_loc(
             self.model_run_id, dataset_item_id
         )
-        return self._format_prediction_response(response)
+        return format_prediction_response(response)
 
     def prediction_loc(self, reference_id: str, annotation_id: str):
         """
@@ -184,46 +189,4 @@ def ungrouped_export(self):
             route=f"modelRun/{self.model_run_id}/ungrouped",
             requests_command=requests.get,
         )
-        return self._format_prediction_response(
-            {ANNOTATIONS_KEY: json_response}
-        )
-
-    def _format_prediction_response(
-        self, response: dict
-    ) -> Union[
-        dict,
-        List[
-            Union[
-                BoxPrediction,
-                PolygonPrediction,
-                CuboidPrediction,
-                SegmentationPrediction,
-            ]
-        ],
-    ]:
-        annotation_payload = response.get(ANNOTATIONS_KEY, None)
-        if not annotation_payload:
-            # An error occurred
-            return response
-        annotation_response = {}
-        type_key_to_class: Dict[
-            str,
-            Union[
-                Type[BoxPrediction],
-                Type[PolygonPrediction],
-                Type[CuboidPrediction],
-                Type[SegmentationPrediction],
-            ],
-        ] = {
-            BOX_TYPE: BoxPrediction,
-            POLYGON_TYPE: PolygonPrediction,
-            CUBOID_TYPE: CuboidPrediction,
-            SEGMENTATION_TYPE: SegmentationPrediction,
-        }
-        for type_key in annotation_payload:
-            type_class = type_key_to_class[type_key]
-            annotation_response[type_key] = [
-                type_class.from_json(annotation)
-                for annotation in annotation_payload[type_key]
-            ]
-        return annotation_response
+        return format_prediction_response({ANNOTATIONS_KEY: json_response})
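
ModelRun keeps its public surface, but now calls client.predict with keyword arguments and delegates response parsing to the shared helper in nucleus.utils instead of its private _format_prediction_response. A brief sketch, assuming an existing model_run object and placeholder IDs:

predictions = []  # BoxPrediction / PolygonPrediction / CuboidPrediction / SegmentationPrediction

model_run.predict(predictions)   # synchronous path; forwards to client.predict(model_run_id=...)
model_run.refloc("img_1")        # placeholder reference ID; result parsed by format_prediction_response
model_run.ungrouped_export()     # raw export, also run through format_prediction_response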

nucleus/utils.py

Lines changed: 48 additions & 2 deletions
@@ -4,7 +4,7 @@
 import io
 import uuid
 import json
-from typing import IO, Dict, List, Sequence, Union
+from typing import IO, Dict, List, Sequence, Type, Union
 
 import requests
 from requests.models import HTTPError
@@ -30,10 +30,56 @@
     SEGMENTATION_TYPE,
 )
 from .dataset_item import DatasetItem
-from .prediction import BoxPrediction, CuboidPrediction, PolygonPrediction
+from .prediction import (
+    BoxPrediction,
+    CuboidPrediction,
+    PolygonPrediction,
+    SegmentationPrediction,
+)
 from .scene import LidarScene
 
 
+def format_prediction_response(
+    response: dict,
+) -> Union[
+    dict,
+    List[
+        Union[
+            BoxPrediction,
+            PolygonPrediction,
+            CuboidPrediction,
+            SegmentationPrediction,
+        ]
+    ],
+]:
+    annotation_payload = response.get(ANNOTATIONS_KEY, None)
+    if not annotation_payload:
+        # An error occurred
+        return response
+    annotation_response = {}
+    type_key_to_class: Dict[
+        str,
+        Union[
+            Type[BoxPrediction],
+            Type[PolygonPrediction],
+            Type[CuboidPrediction],
+            Type[SegmentationPrediction],
+        ],
+    ] = {
+        BOX_TYPE: BoxPrediction,
+        POLYGON_TYPE: PolygonPrediction,
+        CUBOID_TYPE: CuboidPrediction,
+        SEGMENTATION_TYPE: SegmentationPrediction,
+    }
+    for type_key in annotation_payload:
+        type_class = type_key_to_class[type_key]
+        annotation_response[type_key] = [
+            type_class.from_json(annotation)
+            for annotation in annotation_payload[type_key]
+        ]
+    return annotation_response
+
+
 def _get_all_field_values(metadata_list: List[dict], key: str):
     return {metadata[key] for metadata in metadata_list if key in metadata}
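
format_prediction_response is now the single place where raw export payloads become typed prediction objects: each type key in the annotations payload is mapped to its prediction class and every entry is parsed with from_json, while responses without an annotations payload are returned unchanged. A small sketch of that pass-through behaviour (the payload contents are illustrative, not a documented schema):

from nucleus.utils import format_prediction_response

# An error response has no annotations payload, so it is returned as-is
# and the caller can still inspect the status code.
error_response = {"status_code": 400, "error": "bad request"}  # illustrative payload
assert format_prediction_response(error_response) == error_response

# A successful export payload is keyed by prediction type (the box, polygon,
# cuboid, and segmentation constants); each entry is converted via the matching
# *Prediction.from_json constructor. The per-prediction JSON schema is defined
# by the prediction classes and is not reproduced here.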
