Skip to content

Commit 824c1fd

Browse files
authored
Merge pull request #36 from scaleapi/sasha/segmentation_prediction
Sasha/segmentation prediction
2 parents 74e619f + 117fc25 commit 824c1fd

File tree

8 files changed

+215
-54
lines changed

8 files changed

+215
-54
lines changed

nucleus/__init__.py

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,13 @@
7373
BoxAnnotation,
7474
PolygonAnnotation,
7575
SegmentationAnnotation,
76+
Segment,
77+
)
78+
from .prediction import (
79+
BoxPrediction,
80+
PolygonPrediction,
81+
SegmentationPrediction,
7682
)
77-
from .prediction import BoxPrediction, PolygonPrediction
7883
from .model_run import ModelRun
7984
from .slice import Slice
8085
from .upload_response import UploadResponse
@@ -623,7 +628,9 @@ def create_model_run(self, dataset_id: str, payload: dict) -> ModelRun:
623628
def predict(
624629
self,
625630
model_run_id: str,
626-
annotations: List[Union[BoxPrediction, PolygonPrediction]],
631+
annotations: List[
632+
Union[BoxPrediction, PolygonPrediction, SegmentationPrediction]
633+
],
627634
update: bool,
628635
batch_size: int = 100,
629636
):
@@ -639,9 +646,26 @@ def predict(
639646
"predictions_ignored": int,
640647
}
641648
"""
649+
segmentations = [
650+
ann
651+
for ann in annotations
652+
if isinstance(ann, SegmentationPrediction)
653+
]
654+
655+
other_predictions = [
656+
ann
657+
for ann in annotations
658+
if not isinstance(ann, SegmentationPrediction)
659+
]
660+
661+
s_batches = [
662+
segmentations[i : i + batch_size]
663+
for i in range(0, len(segmentations), batch_size)
664+
]
665+
642666
batches = [
643-
annotations[i : i + batch_size]
644-
for i in range(0, len(annotations), batch_size)
667+
other_predictions[i : i + batch_size]
668+
for i in range(0, len(other_predictions), batch_size)
645669
]
646670

647671
agg_response = {
@@ -670,8 +694,23 @@ def predict(
670694
PREDICTIONS_IGNORED_KEY
671695
]
672696

697+
for s_batch in s_batches:
698+
payload = construct_segmentation_payload(s_batch, update)
699+
response = self._make_request(
700+
payload, f"modelRun/{model_run_id}/predict_segmentation"
701+
)
702+
# pbar.update(1)
703+
if STATUS_CODE_KEY in response:
704+
agg_response[ERRORS_KEY] = response
705+
else:
706+
agg_response[PREDICTIONS_PROCESSED_KEY] += response[
707+
PREDICTIONS_PROCESSED_KEY
708+
]
709+
agg_response[PREDICTIONS_IGNORED_KEY] += response[
710+
PREDICTIONS_IGNORED_KEY
711+
]
712+
673713
return agg_response
674-
# return self._make_request(payload, f"modelRun/{model_run_id}/predict")
675714

676715
def commit_model_run(
677716
self, model_run_id: str, payload: Optional[dict] = None

nucleus/constants.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
NUCLEUS_ENDPOINT = "https://api.scale.com/v1/nucleus"
1+
# NUCLEUS_ENDPOINT = "https://api.scale.com/v1/nucleus"
2+
NUCLEUS_ENDPOINT = "http://localhost:3000/v1/nucleus"
23
ITEMS_KEY = "items"
34
ITEM_KEY = "item"
45
REFERENCE_ID_KEY = "reference_id"

nucleus/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ def create_run(
3737
self,
3838
name: str,
3939
dataset: Dataset,
40-
predictions: List[Union[BoxPrediction, PolygonPrediction]],
40+
predictions: List[
41+
Union[BoxPrediction, PolygonPrediction, SegmentationPrediction]
42+
],
4143
metadata: Optional[Dict] = None,
4244
) -> ModelRun:
4345
payload: dict = {

nucleus/model_run.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
from typing import Optional, List, Dict, Any, Union
22
from .constants import ANNOTATIONS_KEY, DEFAULT_ANNOTATION_UPDATE_MODE
3-
from .prediction import BoxPrediction, PolygonPrediction
3+
from .prediction import (
4+
BoxPrediction,
5+
PolygonPrediction,
6+
SegmentationPrediction,
7+
)
48
from .payload_constructor import construct_box_predictions_payload
59

610

@@ -62,7 +66,9 @@ def commit(self, payload: Optional[dict] = None) -> dict:
6266

6367
def predict(
6468
self,
65-
annotations: List[Union[BoxPrediction, PolygonPrediction]],
69+
annotations: List[
70+
Union[BoxPrediction, PolygonPrediction, SegmentationPrediction]
71+
],
6672
update: Optional[bool] = DEFAULT_ANNOTATION_UPDATE_MODE,
6773
) -> dict:
6874
"""

nucleus/payload_constructor.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
PolygonAnnotation,
66
SegmentationAnnotation,
77
)
8-
from .prediction import BoxPrediction, PolygonPrediction
8+
from .prediction import (
9+
BoxPrediction,
10+
PolygonPrediction,
11+
SegmentationPrediction,
12+
)
913
from .constants import (
1014
ANNOTATION_UPDATE_KEY,
1115
NAME_KEY,
@@ -45,7 +49,9 @@ def construct_annotation_payload(
4549

4650

4751
def construct_segmentation_payload(
48-
annotation_items: List[SegmentationAnnotation],
52+
annotation_items: List[
53+
Union[SegmentationAnnotation, SegmentationPrediction]
54+
],
4955
update: bool,
5056
) -> dict:
5157
annotations = []

nucleus/prediction.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
from typing import Dict, Optional, List, Any
2-
from .annotation import BoxAnnotation, PolygonAnnotation
2+
from .annotation import (
3+
BoxAnnotation,
4+
PolygonAnnotation,
5+
Segment,
6+
SegmentationAnnotation,
7+
)
38
from .constants import (
49
ANNOTATION_ID_KEY,
510
DATASET_ITEM_ID_KEY,
@@ -13,9 +18,30 @@
1318
HEIGHT_KEY,
1419
CONFIDENCE_KEY,
1520
VERTICES_KEY,
21+
ANNOTATIONS_KEY,
22+
ITEM_ID_KEY,
23+
MASK_URL_KEY,
1624
)
1725

1826

27+
class SegmentationPrediction(SegmentationAnnotation):
28+
@classmethod
29+
def from_json(cls, payload: dict):
30+
return cls(
31+
mask_url=payload[MASK_URL_KEY],
32+
annotations=[
33+
Segment.from_json(ann)
34+
for ann in payload.get(ANNOTATIONS_KEY, [])
35+
],
36+
reference_id=payload.get(REFERENCE_ID_KEY, None),
37+
item_id=payload.get(ITEM_ID_KEY, None),
38+
)
39+
40+
def to_payload(self) -> dict:
41+
payload = super().to_payload()
42+
return payload
43+
44+
1945
class BoxPrediction(BoxAnnotation):
2046
def __init__(
2147
self,
@@ -31,7 +57,15 @@ def __init__(
3157
metadata: Optional[Dict] = None,
3258
):
3359
super().__init__(
34-
label, x, y, width, height, reference_id, item_id, annotation_id, metadata
60+
label,
61+
x,
62+
y,
63+
width,
64+
height,
65+
reference_id,
66+
item_id,
67+
annotation_id,
68+
metadata,
3569
)
3670
self.confidence = confidence
3771

@@ -73,7 +107,9 @@ def __init__(
73107
annotation_id: Optional[str] = None,
74108
metadata: Optional[Dict] = None,
75109
):
76-
super().__init__(label, vertices, reference_id, item_id, annotation_id, metadata)
110+
super().__init__(
111+
label, vertices, reference_id, item_id, annotation_id, metadata
112+
)
77113
self.confidence = confidence
78114

79115
def to_payload(self) -> dict:

tests/helpers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ def reference_id_from_url(url):
107107
for i in range(len(TEST_IMG_URLS))
108108
]
109109

110+
TEST_SEGMENTATION_PREDICTIONS = TEST_SEGMENTATION_ANNOTATIONS
111+
110112
TEST_BOX_PREDICTIONS = [
111113
{**TEST_BOX_ANNOTATIONS[i], "confidence": 0.10 * i}
112114
for i in range(len(TEST_BOX_ANNOTATIONS))

0 commit comments

Comments
 (0)