Skip to content

Commit e688fdb

Browse files
committed
nits
1 parent 2f147a6 commit e688fdb

File tree

2 files changed

+51
-7
lines changed

2 files changed

+51
-7
lines changed

nucleus/__init__.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,11 @@
7474
PolygonAnnotation,
7575
SegmentationAnnotation,
7676
)
77-
from .prediction import BoxPrediction, PolygonPrediction
77+
from .prediction import (
78+
BoxPrediction,
79+
PolygonPrediction,
80+
SegmentationPrediction,
81+
)
7882
from .model_run import ModelRun
7983
from .slice import Slice
8084
from .upload_response import UploadResponse
@@ -622,7 +626,9 @@ def create_model_run(self, dataset_id: str, payload: dict) -> ModelRun:
622626
def predict(
623627
self,
624628
model_run_id: str,
625-
annotations: List[Union[BoxPrediction, PolygonPrediction]],
629+
annotations: List[
630+
Union[BoxPrediction, PolygonPrediction, SegmentationPrediction]
631+
],
626632
update: bool,
627633
batch_size: int = 100,
628634
):
@@ -638,9 +644,26 @@ def predict(
638644
"predictions_ignored": int,
639645
}
640646
"""
647+
segmentations = [
648+
ann
649+
for ann in annotations
650+
if isinstance(ann, SegmentationPrediction)
651+
]
652+
653+
other_predictions = [
654+
ann
655+
for ann in annotations
656+
if not isinstance(ann, SegmentationPrediction)
657+
]
658+
659+
s_batches = [
660+
segmentations[i : i + batch_size]
661+
for i in range(0, len(segmentations), batch_size)
662+
]
663+
641664
batches = [
642-
annotations[i : i + batch_size]
643-
for i in range(0, len(annotations), batch_size)
665+
other_predictions[i : i + batch_size]
666+
for i in range(0, len(other_predictions), batch_size)
644667
]
645668

646669
agg_response = {
@@ -669,8 +692,23 @@ def predict(
669692
PREDICTIONS_IGNORED_KEY
670693
]
671694

695+
for s_batch in s_batches:
696+
payload = construct_segmentation_payload(s_batch, update)
697+
response = self._make_request(
698+
payload, f"modelRun/{model_run_id}/predict_segmentation"
699+
)
700+
# pbar.update(1)
701+
if STATUS_CODE_KEY in response:
702+
agg_response[ERRORS_KEY] = response
703+
else:
704+
agg_response[PREDICTIONS_PROCESSED_KEY] += response[
705+
PREDICTIONS_PROCESSED_KEY
706+
]
707+
agg_response[PREDICTIONS_IGNORED_KEY] += response[
708+
PREDICTIONS_IGNORED_KEY
709+
]
710+
672711
return agg_response
673-
# return self._make_request(payload, f"modelRun/{model_run_id}/predict")
674712

675713
def commit_model_run(
676714
self, model_run_id: str, payload: Optional[dict] = None

nucleus/payload_constructor.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
PolygonAnnotation,
66
SegmentationAnnotation,
77
)
8-
from .prediction import BoxPrediction, PolygonPrediction
8+
from .prediction import (
9+
BoxPrediction,
10+
PolygonPrediction,
11+
SegmentationPrediction,
12+
)
913
from .constants import (
1014
ANNOTATION_UPDATE_KEY,
1115
NAME_KEY,
@@ -45,7 +49,9 @@ def construct_annotation_payload(
4549

4650

4751
def construct_segmentation_payload(
48-
annotation_items: List[SegmentationAnnotation],
52+
annotation_items: List[
53+
Union[SegmentationAnnotation, SegmentationPrediction]
54+
],
4955
update: bool,
5056
) -> dict:
5157
annotations = []

0 commit comments

Comments
 (0)