Commit 6a0e6e0

Merge pull request #131 from scaleapi/da-model-only
Model-only endpoints
2 parents: dd5f0b3 + d61286b

8 files changed: +289 −114 lines


nucleus/__init__.py

Lines changed: 31 additions & 28 deletions
@@ -734,7 +734,6 @@ def create_model_run(self, dataset_id: str, payload: dict) -> ModelRun:
 
     def predict(
         self,
-        model_run_id: str,
         annotations: List[
             Union[
                 BoxPrediction,
@@ -743,7 +742,10 @@ def predict(
                 SegmentationPrediction,
             ]
         ],
-        update: bool,
+        model_run_id: Optional[str] = None,
+        model_id: Optional[str] = None,
+        dataset_id: Optional[str] = None,
+        update: bool = False,
         batch_size: int = 5000,
     ):
         """
@@ -758,6 +760,16 @@ def predict(
             "predictions_ignored": int,
         }
         """
+        if model_run_id is not None:
+            assert model_id is None and dataset_id is None
+            endpoint = f"modelRun/{model_run_id}/predict"
+        else:
+            assert (
+                model_id is not None and dataset_id is not None
+            ), "Model ID and dataset ID are required if not using model run id."
+            endpoint = (
+                f"dataset/{dataset_id}/model/{model_id}/uploadPredictions"
+            )
         segmentations = [
             ann
             for ann in annotations
@@ -780,11 +792,9 @@ def predict(
             for i in range(0, len(other_predictions), batch_size)
         ]
 
-        agg_response = {
-            MODEL_RUN_ID_KEY: model_run_id,
-            PREDICTIONS_PROCESSED_KEY: 0,
-            PREDICTIONS_IGNORED_KEY: 0,
-        }
+        errors = []
+        predictions_processed = 0
+        predictions_ignored = 0
 
         tqdm_batches = self.tqdm_bar(batches)
 
@@ -793,36 +803,29 @@ def predict(
                 batch,
                 update,
             )
-            response = self.make_request(
-                batch_payload, f"modelRun/{model_run_id}/predict"
-            )
+            response = self.make_request(batch_payload, endpoint)
             if STATUS_CODE_KEY in response:
-                agg_response[ERRORS_KEY] = response
+                errors.append(response)
             else:
-                agg_response[PREDICTIONS_PROCESSED_KEY] += response[
-                    PREDICTIONS_PROCESSED_KEY
-                ]
-                agg_response[PREDICTIONS_IGNORED_KEY] += response[
-                    PREDICTIONS_IGNORED_KEY
-                ]
+                predictions_processed += response[PREDICTIONS_PROCESSED_KEY]
+                predictions_ignored += response[PREDICTIONS_IGNORED_KEY]
 
         for s_batch in s_batches:
             payload = construct_segmentation_payload(s_batch, update)
-            response = self.make_request(
-                payload, f"modelRun/{model_run_id}/predict_segmentation"
-            )
+            response = self.make_request(payload, endpoint)
             # pbar.update(1)
             if STATUS_CODE_KEY in response:
-                agg_response[ERRORS_KEY] = response
+                errors.append(response)
             else:
-                agg_response[PREDICTIONS_PROCESSED_KEY] += response[
-                    PREDICTIONS_PROCESSED_KEY
-                ]
-                agg_response[PREDICTIONS_IGNORED_KEY] += response[
-                    PREDICTIONS_IGNORED_KEY
-                ]
+                predictions_processed += response[PREDICTIONS_PROCESSED_KEY]
+                predictions_ignored += response[PREDICTIONS_IGNORED_KEY]
 
-        return agg_response
+        return {
+            MODEL_RUN_ID_KEY: model_run_id,
+            PREDICTIONS_PROCESSED_KEY: predictions_processed,
+            PREDICTIONS_IGNORED_KEY: predictions_ignored,
+            ERRORS_KEY: errors,
+        }
 
     def commit_model_run(
         self, model_run_id: str, payload: Optional[dict] = None
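
For context, the updated predict signature supports two calling styles: the original model-run path and the new model-only path that posts to dataset/{dataset_id}/model/{model_id}/uploadPredictions. A minimal usage sketch, assuming the usual NucleusClient and BoxPrediction constructors; the API key and all IDs below are placeholders, not values from this commit:

import nucleus
from nucleus import BoxPrediction

client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")  # placeholder key
predictions = [
    BoxPrediction(
        label="car", x=10, y=10, width=50, height=40, reference_id="img_1"
    )
]

# Existing style: upload against an existing model run.
client.predict(annotations=predictions, model_run_id="run_abc123")

# New model-only style introduced by this PR: no model run required.
client.predict(
    annotations=predictions,
    model_id="prj_model_123",
    dataset_id="ds_abc123",
    update=False,
)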

nucleus/dataset.py

Lines changed: 146 additions & 6 deletions
@@ -3,18 +3,25 @@
 import requests
 
 from nucleus.job import AsyncJob
+from nucleus.prediction import (
+    BoxPrediction,
+    CuboidPrediction,
+    PolygonPrediction,
+    SegmentationPrediction,
+    from_json,
+)
 from nucleus.url_utils import sanitize_string_args
 from nucleus.utils import (
     convert_export_payload,
     format_dataset_item_response,
+    format_prediction_response,
     serialize_and_write_to_presigned_url,
 )
 
-from .annotation import (
-    Annotation,
-    check_all_mask_paths_remote,
-)
+from .annotation import Annotation, check_all_mask_paths_remote
 from .constants import (
+    ANNOTATIONS_KEY,
+    AUTOTAG_SCORE_THRESHOLD,
     DATASET_LENGTH_KEY,
     DATASET_MODEL_RUNS_KEY,
     DATASET_NAME_KEY,
@@ -24,20 +31,19 @@
     NAME_KEY,
     REFERENCE_IDS_KEY,
     REQUEST_ID_KEY,
-    AUTOTAG_SCORE_THRESHOLD,
     UPDATE_KEY,
 )
 from .dataset_item import (
     DatasetItem,
     check_all_paths_remote,
     check_for_duplicate_reference_ids,
 )
-from .scene import LidarScene, Scene, check_all_scene_paths_remote
 from .payload_constructor import (
     construct_append_scenes_payload,
     construct_model_run_creation_payload,
     construct_taxonomy_payload,
 )
+from .scene import LidarScene, Scene, check_all_scene_paths_remote
 
 WARN_FOR_LARGE_UPLOAD = 50000
 WARN_FOR_LARGE_SCENES_UPLOAD = 5
@@ -525,3 +531,137 @@ def get_scene(self, reference_id) -> Scene:
                 requests_command=requests.get,
             )
         )
+
+    def export_predictions(self, model):
+        """Exports all predictions from a model on this dataset"""
+        json_response = self._client.make_request(
+            payload=None,
+            route=f"dataset/{self.id}/model/{model.id}/export",
+            requests_command=requests.get,
+        )
+        return format_prediction_response({ANNOTATIONS_KEY: json_response})
+
+    def calculate_evaluation_metrics(self, model, options=None):
+        """
+        :param model: the model to calculate evaluation metrics for
+        :param options: Dict with keys:
+            class_agnostic -- A flag to specify whether the matching algorithm should be class-agnostic or not.
+                Default value: True
+
+            allowed_label_matches -- An optional list of AllowedMatch objects to specify allowed matches
+                for ground truth and model predictions.
+                If specified, the 'class_agnostic' flag is assumed to be False.
+
+            Type 'AllowedMatch':
+            {
+                ground_truth_label: string,      # A label for a ground truth annotation.
+                model_prediction_label: string,  # A label for a model prediction that can be matched with
+                                                 # the corresponding ground truth label.
+            }
+
+        payload:
+            {
+                "class_agnostic": boolean,
+                "allowed_label_matches": List[AllowedMatch],
+            }
+        """
+        if options is None:
+            options = {}
+        return self._client.make_request(
+            payload=options,
+            route=f"dataset/{self.id}/model/{model.id}/calculateEvaluationMetrics",
+        )
+
+    def upload_predictions(
+        self,
+        model,
+        predictions: List[
+            Union[
+                BoxPrediction,
+                PolygonPrediction,
+                CuboidPrediction,
+                SegmentationPrediction,
+            ]
+        ],
+        update=False,
+        asynchronous=False,
+    ):
+        """
+        Uploads model outputs as predictions for a model run. Returns info about the upload.
+        :param predictions: List of prediction objects to ingest
+        :param update: Whether to update (if True) or ignore (if False) on a conflicting reference_id/annotation_id
+        :param asynchronous: If True, launch the upload and return a reference to an asynchronous job object. This is recommended for large ingests.
+        :return:
+        If synchronous:
+        {
+            "model_run_id": str,
+            "predictions_processed": int,
+            "predictions_ignored": int,
+        }
+        """
+        if asynchronous:
+            check_all_mask_paths_remote(predictions)
+
+            request_id = serialize_and_write_to_presigned_url(
+                predictions, self.id, self._client
+            )
+            response = self._client.make_request(
+                payload={REQUEST_ID_KEY: request_id, UPDATE_KEY: update},
+                route=f"dataset/{self.id}/model/{model.id}/uploadPredictions?async=1",
+            )
+            return AsyncJob.from_json(response, self._client)
+        else:
+            return self._client.predict(
+                model_run_id=None,
+                dataset_id=self.id,
+                model_id=model.id,
+                annotations=predictions,
+                update=update,
+            )
+
+    def predictions_iloc(self, model, index):
+        """
+        Returns predictions for a Dataset Item by its absolute index.
+        :param model: model object to get predictions from.
+        :param index: absolute index of the Dataset Item within the dataset.
+        :return: List[Union[BoxPrediction, PolygonPrediction, CuboidPrediction, SegmentationPrediction]]
+        """
+        return format_prediction_response(
+            self._client.make_request(
+                payload=None,
+                route=f"dataset/{self.id}/model/{model.id}/iloc/{index}",
+                requests_command=requests.get,
+            )
+        )
+
+    def predictions_refloc(self, model, reference_id):
+        """
+        Returns predictions for a Dataset Item by its reference_id.
+        :param model: model object to get predictions from.
+        :param reference_id: reference_id of a dataset item.
+        :return: List[Union[BoxPrediction, PolygonPrediction, CuboidPrediction, SegmentationPrediction]]
+        """
+        return format_prediction_response(
+            self._client.make_request(
+                payload=None,
+                route=f"dataset/{self.id}/model/{model.id}/referenceId/{reference_id}",
+                requests_command=requests.get,
+            )
+        )
+
+    def prediction_loc(self, model, reference_id, annotation_id):
+        """
+        Returns info for a single Prediction by its reference id and annotation id. Not yet supported for segmentation predictions.
+        :param reference_id: the user-specified id for the image
+        :param annotation_id: the user-specified id for the prediction or, if none was provided, the Scale-generated id for the prediction
+        :return:
+            BoxPrediction | PolygonPrediction | CuboidPrediction
+        """
+        return from_json(
+            self._client.make_request(
+                payload=None,
+                route=f"dataset/{self.id}/model/{model.id}/loc/{reference_id}/{annotation_id}",
+                requests_command=requests.get,
+            )
+        )
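
Taken together, the new Dataset methods make it possible to upload and retrieve predictions without creating a model run. A usage sketch, assuming a Dataset and Model already exist; the key, IDs, and add_model arguments are placeholders, and AsyncJob.sleep_until_complete is used on the assumption that it is available for polling:

import nucleus

client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")  # placeholder key
dataset = client.get_dataset("ds_abc123")             # placeholder dataset id
model = client.add_model(name="my-model", reference_id="my-model-ref")  # or an existing Model

predictions = [
    nucleus.BoxPrediction(
        label="car", x=10, y=10, width=50, height=40, reference_id="img_1"
    )
]

# Synchronous upload: routed through client.predict(model_id=..., dataset_id=...).
dataset.upload_predictions(model, predictions, update=False)

# Asynchronous upload: returns an AsyncJob, recommended for large ingests.
job = dataset.upload_predictions(model, predictions, asynchronous=True)
job.sleep_until_complete()

# Retrieval helpers added in this commit.
all_preds = dataset.export_predictions(model)
item_preds = dataset.predictions_refloc(model, reference_id="img_1")
one_pred = dataset.prediction_loc(
    model, reference_id="img_1", annotation_id="ann_1"
)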

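The options accepted by calculate_evaluation_metrics mirror the payload documented in its docstring above. A sketch continuing from the dataset and model in the previous example; the labels are illustrative:

# class_agnostic is treated as False once allowed_label_matches is supplied.
options = {
    "class_agnostic": False,
    "allowed_label_matches": [
        {
            "ground_truth_label": "car",
            "model_prediction_label": "vehicle",
        }
    ],
}
dataset.calculate_evaluation_metrics(model, options=options)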