Skip to content

Commit 196d2df

Browse files
author
Claire Pajot
committed
Merge branch 'master' into add_multicategory_type_to_groundtruth
2 parents c84bb4c + 6a0e6e0 commit 196d2df

File tree

10 files changed

+301
-116
lines changed

10 files changed

+301
-116
lines changed

conftest.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,13 @@ def dataset(CLIENT):
3030
yield ds
3131

3232
CLIENT.delete_dataset(ds.id)
33+
34+
35+
if __name__ == "__main__":
36+
client = nucleus.NucleusClient(API_KEY)
37+
# ds = client.create_dataset("Test Dataset With Autotags")
38+
# ds.append(TEST_DATASET_ITEMS)
39+
ds = client.get_dataset("ds_c5jwptkgfsqg0cs503z0")
40+
job = ds.create_image_index()
41+
job.sleep_until_complete()
42+
print(ds.id)

nucleus/__init__.py

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -736,7 +736,6 @@ def create_model_run(self, dataset_id: str, payload: dict) -> ModelRun:
736736

737737
def predict(
738738
self,
739-
model_run_id: str,
740739
annotations: List[
741740
Union[
742741
BoxPrediction,
@@ -745,7 +744,10 @@ def predict(
745744
SegmentationPrediction,
746745
]
747746
],
748-
update: bool,
747+
model_run_id: Optional[str] = None,
748+
model_id: Optional[str] = None,
749+
dataset_id: Optional[str] = None,
750+
update: bool = False,
749751
batch_size: int = 5000,
750752
):
751753
"""
@@ -760,6 +762,16 @@ def predict(
760762
"predictions_ignored": int,
761763
}
762764
"""
765+
if model_run_id is not None:
766+
assert model_id is None and dataset_id is None
767+
endpoint = f"modelRun/{model_run_id}/predict"
768+
else:
769+
assert (
770+
model_id is not None and dataset_id is not None
771+
), "Model ID and dataset ID are required if not using model run id."
772+
endpoint = (
773+
f"dataset/{dataset_id}/model/{model_id}/uploadPredictions"
774+
)
763775
segmentations = [
764776
ann
765777
for ann in annotations
@@ -782,11 +794,9 @@ def predict(
782794
for i in range(0, len(other_predictions), batch_size)
783795
]
784796

785-
agg_response = {
786-
MODEL_RUN_ID_KEY: model_run_id,
787-
PREDICTIONS_PROCESSED_KEY: 0,
788-
PREDICTIONS_IGNORED_KEY: 0,
789-
}
797+
errors = []
798+
predictions_processed = 0
799+
predictions_ignored = 0
790800

791801
tqdm_batches = self.tqdm_bar(batches)
792802

@@ -795,36 +805,29 @@ def predict(
795805
batch,
796806
update,
797807
)
798-
response = self.make_request(
799-
batch_payload, f"modelRun/{model_run_id}/predict"
800-
)
808+
response = self.make_request(batch_payload, endpoint)
801809
if STATUS_CODE_KEY in response:
802-
agg_response[ERRORS_KEY] = response
810+
errors.append(response)
803811
else:
804-
agg_response[PREDICTIONS_PROCESSED_KEY] += response[
805-
PREDICTIONS_PROCESSED_KEY
806-
]
807-
agg_response[PREDICTIONS_IGNORED_KEY] += response[
808-
PREDICTIONS_IGNORED_KEY
809-
]
812+
predictions_processed += response[PREDICTIONS_PROCESSED_KEY]
813+
predictions_ignored += response[PREDICTIONS_IGNORED_KEY]
810814

811815
for s_batch in s_batches:
812816
payload = construct_segmentation_payload(s_batch, update)
813-
response = self.make_request(
814-
payload, f"modelRun/{model_run_id}/predict_segmentation"
815-
)
817+
response = self.make_request(payload, endpoint)
816818
# pbar.update(1)
817819
if STATUS_CODE_KEY in response:
818-
agg_response[ERRORS_KEY] = response
820+
errors.append(response)
819821
else:
820-
agg_response[PREDICTIONS_PROCESSED_KEY] += response[
821-
PREDICTIONS_PROCESSED_KEY
822-
]
823-
agg_response[PREDICTIONS_IGNORED_KEY] += response[
824-
PREDICTIONS_IGNORED_KEY
825-
]
822+
predictions_processed += response[PREDICTIONS_PROCESSED_KEY]
823+
predictions_ignored += response[PREDICTIONS_IGNORED_KEY]
826824

827-
return agg_response
825+
return {
826+
MODEL_RUN_ID_KEY: model_run_id,
827+
PREDICTIONS_PROCESSED_KEY: predictions_processed,
828+
PREDICTIONS_IGNORED_KEY: predictions_ignored,
829+
ERRORS_KEY: errors,
830+
}
828831

829832
def commit_model_run(
830833
self, model_run_id: str, payload: Optional[dict] = None

nucleus/dataset.py

Lines changed: 146 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,25 @@
33
import requests
44

55
from nucleus.job import AsyncJob
6+
from nucleus.prediction import (
7+
BoxPrediction,
8+
CuboidPrediction,
9+
PolygonPrediction,
10+
SegmentationPrediction,
11+
from_json,
12+
)
613
from nucleus.url_utils import sanitize_string_args
714
from nucleus.utils import (
815
convert_export_payload,
916
format_dataset_item_response,
17+
format_prediction_response,
1018
serialize_and_write_to_presigned_url,
1119
)
1220

13-
from .annotation import (
14-
Annotation,
15-
check_all_mask_paths_remote,
16-
)
21+
from .annotation import Annotation, check_all_mask_paths_remote
1722
from .constants import (
23+
ANNOTATIONS_KEY,
24+
AUTOTAG_SCORE_THRESHOLD,
1825
DATASET_LENGTH_KEY,
1926
DATASET_MODEL_RUNS_KEY,
2027
DATASET_NAME_KEY,
@@ -24,20 +31,19 @@
2431
NAME_KEY,
2532
REFERENCE_IDS_KEY,
2633
REQUEST_ID_KEY,
27-
AUTOTAG_SCORE_THRESHOLD,
2834
UPDATE_KEY,
2935
)
3036
from .dataset_item import (
3137
DatasetItem,
3238
check_all_paths_remote,
3339
check_for_duplicate_reference_ids,
3440
)
35-
from .scene import LidarScene, Scene, check_all_scene_paths_remote
3641
from .payload_constructor import (
3742
construct_append_scenes_payload,
3843
construct_model_run_creation_payload,
3944
construct_taxonomy_payload,
4045
)
46+
from .scene import LidarScene, Scene, check_all_scene_paths_remote
4147

4248
WARN_FOR_LARGE_UPLOAD = 50000
4349
WARN_FOR_LARGE_SCENES_UPLOAD = 5
@@ -525,3 +531,137 @@ def get_scene(self, reference_id) -> Scene:
525531
requests_command=requests.get,
526532
)
527533
)
534+
535+
def export_predictions(self, model):
536+
"""Exports all predictions from a model on this dataset"""
537+
json_response = self._client.make_request(
538+
payload=None,
539+
route=f"dataset/{self.id}/model/{model.id}/export",
540+
requests_command=requests.get,
541+
)
542+
return format_prediction_response({ANNOTATIONS_KEY: json_response})
543+
544+
def calculate_evaluation_metrics(self, model, options=None):
545+
"""
546+
547+
:param model: the model to calculate eval metrics for
548+
:param options: Dict with keys:
549+
class_agnostic -- A flag to specify if matching algorithm should be class-agnostic or not.
550+
Default value: True
551+
552+
allowed_label_matches -- An optional list of AllowedMatch objects to specify allowed matches
553+
for ground truth and model predictions.
554+
If specified, 'class_agnostic' flag is assumed to be False
555+
556+
Type 'AllowedMatch':
557+
{
558+
ground_truth_label: string, # A label for ground truth annotation.
559+
model_prediction_label: string, # A label for model prediction that can be matched with
560+
# corresponding ground truth label.
561+
}
562+
563+
payload:
564+
{
565+
"class_agnostic": boolean,
566+
"allowed_label_matches": List[AllowedMatch],
567+
}"""
568+
if options is None:
569+
options = {}
570+
return self._client.make_request(
571+
payload=options,
572+
route=f"dataset/{self.id}/model/{model.id}/calculateEvaluationMetrics",
573+
)
574+
575+
def upload_predictions(
576+
self,
577+
model,
578+
predictions: List[
579+
Union[
580+
BoxPrediction,
581+
PolygonPrediction,
582+
CuboidPrediction,
583+
SegmentationPrediction,
584+
]
585+
],
586+
update=False,
587+
asynchronous=False,
588+
):
589+
"""
590+
Uploads model outputs as predictions for a model_run. Returns info about the upload.
591+
:param predictions: List of prediction objects to ingest
592+
:param update: Whether to update (if true) or ignore (if false) on conflicting reference_id/annotation_id
593+
:param asynchronous: If true, launch the upload and return a reference to an asynchronous job object. This is recommended for large ingests.
594+
:return:
595+
If synchronous
596+
{
597+
"model_run_id": str,
598+
"predictions_processed": int,
599+
"predictions_ignored": int,
600+
}
601+
"""
602+
if asynchronous:
603+
check_all_mask_paths_remote(predictions)
604+
605+
request_id = serialize_and_write_to_presigned_url(
606+
predictions, self.id, self._client
607+
)
608+
response = self._client.make_request(
609+
payload={REQUEST_ID_KEY: request_id, UPDATE_KEY: update},
610+
route=f"dataset/{self.id}/model/{model.id}/uploadPredictions?async=1",
611+
)
612+
return AsyncJob.from_json(response, self._client)
613+
else:
614+
return self._client.predict(
615+
model_run_id=None,
616+
dataset_id=self.id,
617+
model_id=model.id,
618+
annotations=predictions,
619+
update=update,
620+
)
621+
622+
def predictions_iloc(self, model, index):
623+
"""
624+
Returns predictions for a Dataset Item by index.
625+
:param model: model object to get predictions from.
626+
:param index: absolute number of Dataset Item for a dataset corresponding to the model run.
627+
:return: List[Union[BoxPrediction, PolygonPrediction, CuboidPrediction, SegmentationPrediction]],
628+
}
629+
"""
630+
return format_prediction_response(
631+
self._client.make_request(
632+
payload=None,
633+
route=f"dataset/{self.id}/model/{model.id}/iloc/{index}",
634+
requests_command=requests.get,
635+
)
636+
)
637+
638+
def predictions_refloc(self, model, reference_id):
639+
"""
640+
Returns predictions for a Dataset Item by its reference_id.
641+
:param model: model object to get predictions from.
642+
:param reference_id: reference_id of a dataset item.
643+
:return: List[Union[BoxPrediction, PolygonPrediction, CuboidPrediction, SegmentationPrediction]],
644+
"""
645+
return format_prediction_response(
646+
self._client.make_request(
647+
payload=None,
648+
route=f"dataset/{self.id}/model/{model.id}/referenceId/{reference_id}",
649+
requests_command=requests.get,
650+
)
651+
)
652+
653+
def prediction_loc(self, model, reference_id, annotation_id):
654+
"""
655+
Returns info for a single Prediction by its reference id and annotation id. Not supported for segmentation predictions yet.
656+
:param reference_id: the user specified id for the image
657+
:param annotation_id: the user specified id for the prediction, or if one was not provided, the Scale internally generated id for the prediction
658+
:return:
659+
BoxPrediction | PolygonPrediction | CuboidPrediction
660+
"""
661+
return from_json(
662+
self._client.make_request(
663+
payload=None,
664+
route=f"dataset/{self.id}/model/{model.id}/loc/{reference_id}/{annotation_id}",
665+
requests_command=requests.get,
666+
)
667+
)

0 commit comments

Comments
 (0)