
Commit f2cb018

Merge pull request #35 from scaleapi/sasha/segmentation
Sasha/segmentation
2 parents b928eaf + 824c1fd commit f2cb018

File tree

10 files changed: +628 −159 lines changed

nucleus/__init__.py

Lines changed: 107 additions & 24 deletions
@@ -69,8 +69,17 @@
 
 from .dataset import Dataset
 from .dataset_item import DatasetItem
-from .annotation import BoxAnnotation, PolygonAnnotation
-from .prediction import BoxPrediction, PolygonPrediction
+from .annotation import (
+    BoxAnnotation,
+    PolygonAnnotation,
+    SegmentationAnnotation,
+    Segment,
+)
+from .prediction import (
+    BoxPrediction,
+    PolygonPrediction,
+    SegmentationPrediction,
+)
 from .model_run import ModelRun
 from .slice import Slice
 from .upload_response import UploadResponse
@@ -79,6 +88,7 @@
     construct_annotation_payload,
     construct_model_creation_payload,
     construct_box_predictions_payload,
+    construct_segmentation_payload,
 )
 from .constants import (
     NUCLEUS_ENDPOINT,
@@ -465,7 +475,9 @@ def exception_handler(request, exception):
     def annotate_dataset(
         self,
         dataset_id: str,
-        annotations: List[Union[BoxAnnotation, PolygonAnnotation]],
+        annotations: List[
+            Union[BoxAnnotation, PolygonAnnotation, SegmentationAnnotation]
+        ],
         update: bool,
         batch_size: int = 100,
     ):
@@ -477,9 +489,26 @@ def annotate_dataset(
         :return: {"dataset_id: str, "annotations_processed": int}
         """
 
+        # Split payload into segmentations and Box/Polygon
+        segmentations = [
+            ann
+            for ann in annotations
+            if isinstance(ann, SegmentationAnnotation)
+        ]
+        other_annotations = [
+            ann
+            for ann in annotations
+            if not isinstance(ann, SegmentationAnnotation)
+        ]
+
         batches = [
-            annotations[i : i + batch_size]
-            for i in range(0, len(annotations), batch_size)
+            other_annotations[i : i + batch_size]
+            for i in range(0, len(other_annotations), batch_size)
+        ]
+
+        semseg_batches = [
+            segmentations[i : i + batch_size]
+            for i in range(0, len(segmentations), batch_size)
         ]
 
         agg_response = {
@@ -488,22 +517,42 @@ def annotate_dataset(
             ANNOTATIONS_IGNORED_KEY: 0,
         }
 
+        total_batches = len(batches) + len(semseg_batches)
+
         tqdm_batches = self.tqdm_bar(batches)
 
-        for batch in tqdm_batches:
-            payload = construct_annotation_payload(batch, update)
-            response = self._make_request(
-                payload, f"dataset/{dataset_id}/annotate"
-            )
-            if STATUS_CODE_KEY in response:
-                agg_response[ERRORS_KEY] = response
-            else:
-                agg_response[ANNOTATIONS_PROCESSED_KEY] += response[
-                    ANNOTATIONS_PROCESSED_KEY
-                ]
-                agg_response[ANNOTATIONS_IGNORED_KEY] += response[
-                    ANNOTATIONS_IGNORED_KEY
-                ]
+        with self.tqdm_bar(total=total_batches) as pbar:
+            for batch in tqdm_batches:
+                payload = construct_annotation_payload(batch, update)
+                response = self._make_request(
+                    payload, f"dataset/{dataset_id}/annotate"
+                )
+                pbar.update(1)
+                if STATUS_CODE_KEY in response:
+                    agg_response[ERRORS_KEY] = response
+                else:
+                    agg_response[ANNOTATIONS_PROCESSED_KEY] += response[
+                        ANNOTATIONS_PROCESSED_KEY
+                    ]
+                    agg_response[ANNOTATIONS_IGNORED_KEY] += response[
+                        ANNOTATIONS_IGNORED_KEY
+                    ]
+
+            for s_batch in semseg_batches:
+                payload = construct_segmentation_payload(s_batch, update)
+                response = self._make_request(
+                    payload, f"dataset/{dataset_id}/annotate_segmentation"
+                )
+                pbar.update(1)
+                if STATUS_CODE_KEY in response:
+                    agg_response[ERRORS_KEY] = response
+                else:
+                    agg_response[ANNOTATIONS_PROCESSED_KEY] += response[
+                        ANNOTATIONS_PROCESSED_KEY
+                    ]
+                    agg_response[ANNOTATIONS_IGNORED_KEY] += response[
+                        ANNOTATIONS_IGNORED_KEY
+                    ]
 
         return agg_response
 
@@ -579,7 +628,9 @@ def create_model_run(self, dataset_id: str, payload: dict) -> ModelRun:
     def predict(
         self,
         model_run_id: str,
-        annotations: List[Union[BoxPrediction, PolygonPrediction]],
+        annotations: List[
+            Union[BoxPrediction, PolygonPrediction, SegmentationPrediction]
+        ],
         update: bool,
         batch_size: int = 100,
     ):
@@ -595,9 +646,26 @@ def predict(
             "predictions_ignored": int,
         }
         """
+        segmentations = [
+            ann
+            for ann in annotations
+            if isinstance(ann, SegmentationPrediction)
+        ]
+
+        other_predictions = [
+            ann
+            for ann in annotations
+            if not isinstance(ann, SegmentationPrediction)
+        ]
+
+        s_batches = [
+            segmentations[i : i + batch_size]
+            for i in range(0, len(segmentations), batch_size)
+        ]
+
         batches = [
-            annotations[i : i + batch_size]
-            for i in range(0, len(annotations), batch_size)
+            other_predictions[i : i + batch_size]
+            for i in range(0, len(other_predictions), batch_size)
        ]
 
         agg_response = {
@@ -610,7 +678,7 @@ def predict(
 
         for batch in tqdm_batches:
             batch_payload = construct_box_predictions_payload(
-                annotations,
+                batch,
                 update,
             )
             response = self._make_request(
@@ -626,8 +694,23 @@ def predict(
                     PREDICTIONS_IGNORED_KEY
                 ]
 
+        for s_batch in s_batches:
+            payload = construct_segmentation_payload(s_batch, update)
+            response = self._make_request(
+                payload, f"modelRun/{model_run_id}/predict_segmentation"
+            )
+            # pbar.update(1)
+            if STATUS_CODE_KEY in response:
+                agg_response[ERRORS_KEY] = response
+            else:
+                agg_response[PREDICTIONS_PROCESSED_KEY] += response[
+                    PREDICTIONS_PROCESSED_KEY
+                ]
+                agg_response[PREDICTIONS_IGNORED_KEY] += response[
+                    PREDICTIONS_IGNORED_KEY
+                ]
+
         return agg_response
-        # return self._make_request(payload, f"modelRun/{model_run_id}/predict")
 
     def commit_model_run(
         self, model_run_id: str, payload: Optional[dict] = None
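
For orientation, here is a minimal usage sketch of the updated NucleusClient.annotate_dataset flow (not part of this commit). It assumes the client constructor NucleusClient("API_KEY") and the existing BoxAnnotation signature from the rest of the SDK, neither of which is shown in this diff; the dataset id and mask URL are placeholders.

import nucleus
from nucleus import BoxAnnotation, Segment, SegmentationAnnotation

# Assumed client constructor from the existing SDK (not shown in this diff).
client = nucleus.NucleusClient("YOUR_API_KEY")

annotations = [
    # Existing box path; signature assumed from the current SDK.
    BoxAnnotation(
        label="car", x=50, y=60, width=120, height=80, reference_id="img_1"
    ),
    # New segmentation path added by this PR.
    SegmentationAnnotation(
        mask_url="s3://your-bucket/masks/img_1.png",  # hypothetical mask location
        annotations=[
            Segment(label="background", index=0),
            Segment(label="car", index=1, metadata={"source": "human"}),
        ],
        reference_id="img_1",
    ),
]

# annotate_dataset splits the list internally: box/polygon batches go to
# dataset/{dataset_id}/annotate, segmentation batches to
# dataset/{dataset_id}/annotate_segmentation, and both advance one progress bar.
response = client.annotate_dataset("ds_hypothetical_id", annotations, update=False)
print(response)  # aggregated annotations_processed / annotations_ignored counts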

nucleus/annotation.py

Lines changed: 78 additions & 1 deletion
@@ -15,9 +15,86 @@
     LABEL_KEY,
     TYPE_KEY,
     VERTICES_KEY,
+    ITEM_ID_KEY,
+    MASK_URL_KEY,
+    INDEX_KEY,
+    ANNOTATIONS_KEY,
 )
 
 
+class Segment:
+    def __init__(
+        self, label: str, index: int, metadata: Optional[dict] = None
+    ):
+        self.label = label
+        self.index = index
+        self.metadata = metadata
+
+    def __str__(self):
+        return str(self.to_payload())
+
+    @classmethod
+    def from_json(cls, payload: dict):
+        return cls(
+            label=payload.get(LABEL_KEY, ""),
+            index=payload.get(INDEX_KEY, None),
+            metadata=payload.get(METADATA_KEY, None),
+        )
+
+    def to_payload(self) -> dict:
+        payload = {
+            LABEL_KEY: self.label,
+            INDEX_KEY: self.index,
+        }
+        if self.metadata is not None:
+            payload[METADATA_KEY] = self.metadata
+        return payload
+
+
+class SegmentationAnnotation:
+    def __init__(
+        self,
+        mask_url: str,
+        annotations: List[Segment],
+        reference_id: Optional[str] = None,
+        item_id: Optional[str] = None,
+    ):
+        if bool(reference_id) == bool(item_id):
+            raise Exception(
+                "You must specify either a reference_id or an item_id for an annotation."
+            )
+        self.mask_url = mask_url
+        self.annotations = annotations
+        self.reference_id = reference_id
+        self.item_id = item_id
+
+    def __str__(self):
+        return str(self.to_payload())
+
+    @classmethod
+    def from_json(cls, payload: dict):
+        return cls(
+            mask_url=payload[MASK_URL_KEY],
+            annotations=[
+                Segment.from_json(ann)
+                for ann in payload.get(ANNOTATIONS_KEY, [])
+            ],
+            reference_id=payload.get(REFERENCE_ID_KEY, None),
+            item_id=payload.get(ITEM_ID_KEY, None),
+        )
+
+    def to_payload(self) -> dict:
+        payload = {
+            MASK_URL_KEY: self.mask_url,
+            ANNOTATIONS_KEY: [ann.to_payload() for ann in self.annotations],
+        }
+        if self.reference_id:
+            payload[REFERENCE_ID_KEY] = self.reference_id
+        else:
+            payload[ITEM_ID_KEY] = self.item_id
+        return payload
+
+
 class AnnotationTypes(Enum):
     BOX = BOX_TYPE
     POLYGON = POLYGON_TYPE
@@ -131,4 +208,4 @@ def to_payload(self) -> dict:
         }
 
     def __str__(self):
-        return str(self.to_payload())
+        return str(self.to_payload())
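
A short sketch of how the new classes compose (not part of the diff). MASK_URL_KEY, INDEX_KEY, and REFERENCE_ID_KEY are visible in nucleus/constants.py below; the values of LABEL_KEY and ANNOTATIONS_KEY are not shown in this diff and are assumed in the printed output.

from nucleus.annotation import Segment, SegmentationAnnotation

segmentation = SegmentationAnnotation(
    mask_url="s3://your-bucket/masks/img_1.png",  # hypothetical mask location
    annotations=[
        Segment(label="road", index=1),
        Segment(label="pedestrian", index=2, metadata={"occluded": False}),
    ],
    reference_id="img_1",  # exactly one of reference_id / item_id must be given
)

print(segmentation.to_payload())
# Assuming LABEL_KEY == "label" and ANNOTATIONS_KEY == "annotations",
# this prints roughly:
# {'mask_url': 's3://your-bucket/masks/img_1.png',
#  'annotations': [{'label': 'road', 'index': 1},
#                  {'label': 'pedestrian', 'index': 2,
#                   'metadata': {'occluded': False}}],
#  'reference_id': 'img_1'}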

nucleus/constants.py

Lines changed: 6 additions & 1 deletion
@@ -1,4 +1,5 @@
-NUCLEUS_ENDPOINT = "https://api.scale.com/v1/nucleus"
+# NUCLEUS_ENDPOINT = "https://api.scale.com/v1/nucleus"
+NUCLEUS_ENDPOINT = "http://localhost:3000/v1/nucleus"
 ITEMS_KEY = "items"
 ITEM_KEY = "item"
 REFERENCE_ID_KEY = "reference_id"
@@ -26,6 +27,7 @@
 MODEL_RUN_ID_KEY = "model_run_id"
 MODEL_ID_KEY = "model_id"
 DATASET_ITEM_ID_KEY = "dataset_item_id"
+ITEM_ID_KEY = "item_id"
 DATASET_ITEM_IDS_KEY = "dataset_item_ids"
 SLICE_ID_KEY = "slice_id"
 DATASET_NAME_KEY = "name"
@@ -48,3 +50,6 @@
 POLYGON_TYPE = "polygon"
 GEOMETRY_KEY = "geometry"
 AUTOTAGS_KEY = "autotags"
+MASK_URL_KEY = "mask_url"
+INDEX_KEY = "index"
+SEGMENTATIONS_KEY = "segmentations"

nucleus/model.py

Lines changed: 3 additions & 1 deletion
@@ -37,7 +37,9 @@ def create_run(
         self,
         name: str,
         dataset: Dataset,
-        predictions: List[Union[BoxPrediction, PolygonPrediction]],
+        predictions: List[
+            Union[BoxPrediction, PolygonPrediction, SegmentationPrediction]
+        ],
         metadata: Optional[Dict] = None,
     ) -> ModelRun:
         payload: dict = {
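
A hedged sketch of the widened create_run signature in use (not part of the diff). SegmentationPrediction lives in nucleus/prediction.py, which is not shown here; the sketch assumes its constructor mirrors SegmentationAnnotation (mask_url, annotations, reference_id/item_id). model and dataset stand in for handles obtained elsewhere from the client.

from nucleus import Segment, SegmentationPrediction

# `model` and `dataset` are assumed to be existing Model / Dataset handles.
# SegmentationPrediction's constructor is assumed, not shown in this diff.
run = model.create_run(
    name="semseg-baseline-v0",  # hypothetical run name
    dataset=dataset,
    predictions=[
        SegmentationPrediction(
            mask_url="s3://your-bucket/pred_masks/img_1.png",
            annotations=[Segment(label="road", index=1)],
            reference_id="img_1",
        )
    ],
)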

nucleus/model_run.py

Lines changed: 8 additions & 2 deletions
@@ -1,6 +1,10 @@
 from typing import Optional, List, Dict, Any, Union
 from .constants import ANNOTATIONS_KEY, DEFAULT_ANNOTATION_UPDATE_MODE
-from .prediction import BoxPrediction, PolygonPrediction
+from .prediction import (
+    BoxPrediction,
+    PolygonPrediction,
+    SegmentationPrediction,
+)
 from .payload_constructor import construct_box_predictions_payload
 
 
@@ -62,7 +66,9 @@ def commit(self, payload: Optional[dict] = None) -> dict:
 
     def predict(
         self,
-        annotations: List[Union[BoxPrediction, PolygonPrediction]],
+        annotations: List[
+            Union[BoxPrediction, PolygonPrediction, SegmentationPrediction]
+        ],
         update: Optional[bool] = DEFAULT_ANNOTATION_UPDATE_MODE,
     ) -> dict:
         """

nucleus/payload_constructor.py

Lines changed: 24 additions & 2 deletions
@@ -1,7 +1,15 @@
 from typing import List, Optional, Dict, Union
 from .dataset_item import DatasetItem
-from .annotation import BoxAnnotation, PolygonAnnotation
-from .prediction import BoxPrediction, PolygonPrediction
+from .annotation import (
+    BoxAnnotation,
+    PolygonAnnotation,
+    SegmentationAnnotation,
+)
+from .prediction import (
+    BoxPrediction,
+    PolygonPrediction,
+    SegmentationPrediction,
+)
 from .constants import (
     ANNOTATION_UPDATE_KEY,
     NAME_KEY,
@@ -11,6 +19,7 @@
     ITEMS_KEY,
     FORCE_KEY,
     MODEL_ID_KEY,
+    SEGMENTATIONS_KEY,
 )
 
 
@@ -39,6 +48,19 @@ def construct_annotation_payload(
     return {ANNOTATIONS_KEY: annotations, ANNOTATION_UPDATE_KEY: update}
 
 
+def construct_segmentation_payload(
+    annotation_items: List[
+        Union[SegmentationAnnotation, SegmentationPrediction]
+    ],
+    update: bool,
+) -> dict:
+    annotations = []
+    for annotation_item in annotation_items:
+        annotations.append(annotation_item.to_payload())
+
+    return {SEGMENTATIONS_KEY: annotations, ANNOTATION_UPDATE_KEY: update}
+
+
 def construct_box_predictions_payload(
     box_predictions: List[Union[BoxPrediction, PolygonPrediction]],
     update: bool,
