Skip to content

Commit 9305353

Browse files
authored
Merge pull request #39 from scaleapi/sasha/refloc_response_change
adjust refloc payload format
2 parents 4954aac + aa576bf commit 9305353

File tree

7 files changed

+147
-50
lines changed

7 files changed

+147
-50
lines changed

nucleus/annotation.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,19 @@
2222
)
2323

2424

25+
class Annotation:
26+
@classmethod
27+
def from_json(cls, payload: dict):
28+
if payload.get(TYPE_KEY, None) == BOX_TYPE:
29+
geometry = payload.get(GEOMETRY_KEY, {})
30+
return BoxAnnotation.from_json(payload)
31+
elif payload.get(TYPE_KEY, None) == POLYGON_TYPE:
32+
geometry = payload.get(GEOMETRY_KEY, {})
33+
return PolygonAnnotation.from_json(payload)
34+
else:
35+
return SegmentationAnnotation.from_json(payload)
36+
37+
2538
class Segment:
2639
def __init__(
2740
self, label: str, index: int, metadata: Optional[dict] = None
@@ -51,14 +64,17 @@ def to_payload(self) -> dict:
5164
return payload
5265

5366

54-
class SegmentationAnnotation:
67+
class SegmentationAnnotation(Annotation):
5568
def __init__(
5669
self,
5770
mask_url: str,
5871
annotations: List[Segment],
5972
reference_id: Optional[str] = None,
6073
item_id: Optional[str] = None,
6174
):
75+
super().__init__()
76+
if not mask_url:
77+
raise Exception("You must specify a mask_url.")
6278
if bool(reference_id) == bool(item_id):
6379
raise Exception(
6480
"You must specify either a reference_id or an item_id for an annotation."
@@ -74,7 +90,7 @@ def __str__(self):
7490
@classmethod
7591
def from_json(cls, payload: dict):
7692
return cls(
77-
mask_url=payload[MASK_URL_KEY],
93+
mask_url=payload.get(MASK_URL_KEY),
7894
annotations=[
7995
Segment.from_json(ann)
8096
for ann in payload.get(ANNOTATIONS_KEY, [])
@@ -101,7 +117,7 @@ class AnnotationTypes(Enum):
101117

102118

103119
# TODO: Add base annotation class to reduce repeated code here
104-
class BoxAnnotation:
120+
class BoxAnnotation(Annotation):
105121
# pylint: disable=too-many-instance-attributes
106122
def __init__(
107123
self,
@@ -115,6 +131,7 @@ def __init__(
115131
annotation_id: Optional[str] = None,
116132
metadata: Optional[Dict] = None,
117133
):
134+
super().__init__()
118135
if bool(reference_id) == bool(item_id):
119136
raise Exception(
120137
"You must specify either a reference_id or an item_id for an annotation."
@@ -164,7 +181,7 @@ def __str__(self):
164181

165182

166183
# TODO: Add Generic type for 2D point
167-
class PolygonAnnotation:
184+
class PolygonAnnotation(Annotation):
168185
def __init__(
169186
self,
170187
label: str,
@@ -174,6 +191,7 @@ def __init__(
174191
annotation_id: Optional[str] = None,
175192
metadata: Optional[Dict] = None,
176193
):
194+
super().__init__()
177195
if bool(reference_id) == bool(item_id):
178196
raise Exception(
179197
"You must specify either a reference_id or an item_id for an annotation."

nucleus/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
VERTICES_KEY = "vertices"
4848
BOX_TYPE = "box"
4949
POLYGON_TYPE = "polygon"
50+
SEGMENTATION_TYPE = "segmentation"
51+
ANNOTATION_TYPES = (BOX_TYPE, POLYGON_TYPE, SEGMENTATION_TYPE)
5052
GEOMETRY_KEY = "geometry"
5153
AUTOTAGS_KEY = "autotags"
5254
MASK_URL_KEY = "mask_url"

nucleus/dataset.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
from typing import List, Dict, Any, Optional, Union
22
from .dataset_item import DatasetItem
3-
from .annotation import BoxAnnotation, PolygonAnnotation
3+
from .annotation import (
4+
Annotation,
5+
BoxAnnotation,
6+
PolygonAnnotation,
7+
SegmentationAnnotation,
8+
)
49
from .constants import (
510
DATASET_NAME_KEY,
611
DATASET_MODEL_RUNS_KEY,
@@ -12,9 +17,14 @@
1217
ITEM_KEY,
1318
DEFAULT_ANNOTATION_UPDATE_MODE,
1419
ANNOTATIONS_KEY,
20+
BOX_TYPE,
21+
POLYGON_TYPE,
22+
SEGMENTATION_TYPE,
23+
ANNOTATION_TYPES,
1524
)
1625
from .payload_constructor import construct_model_run_creation_payload
1726

27+
1828
class Dataset:
1929
"""
2030
Nucleus Dataset. You can append images with metadata to your dataset,
@@ -227,17 +237,19 @@ def list_autotags(self):
227237

228238
def _format_dataset_item_response(self, response: dict) -> dict:
229239
item = response.get(ITEM_KEY, None)
230-
annotation_payload = response.get(ANNOTATIONS_KEY, [])
240+
annotation_payload = response.get(ANNOTATIONS_KEY, {})
231241
if not item or not annotation_payload:
232242
# An error occured
233243
return response
234-
annotations = [
235-
BoxAnnotation.from_json(ann)
236-
if ann["type"] == "box"
237-
else PolygonAnnotation.from_json(ann)
238-
for ann in annotation_payload
239-
]
244+
245+
annotation_response = {}
246+
for annotation_type in ANNOTATION_TYPES:
247+
if annotation_type in annotation_payload:
248+
annotation_response[annotation_type] = [
249+
Annotation.from_json(ann)
250+
for ann in annotation_payload[annotation_type]
251+
]
240252
return {
241253
ITEM_KEY: DatasetItem.from_json(item),
242-
ANNOTATIONS_KEY: annotations,
254+
ANNOTATIONS_KEY: annotation_response,
243255
}

nucleus/model_run.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
1-
from typing import Optional, List, Dict, Any, Union
2-
from .constants import ANNOTATIONS_KEY, DEFAULT_ANNOTATION_UPDATE_MODE
1+
from typing import Optional, List, Union
2+
from .constants import (
3+
ANNOTATIONS_KEY,
4+
DEFAULT_ANNOTATION_UPDATE_MODE,
5+
BOX_TYPE,
6+
POLYGON_TYPE,
7+
SEGMENTATION_TYPE,
8+
)
39
from .prediction import (
410
BoxPrediction,
511
PolygonPrediction,
612
SegmentationPrediction,
713
)
8-
from .payload_constructor import construct_box_predictions_payload
914

1015

1116
class ModelRun:
@@ -121,13 +126,18 @@ def loc(self, dataset_item_id: str):
121126
def _format_prediction_response(
122127
self, response: dict
123128
) -> Union[dict, List[Union[BoxPrediction, PolygonPrediction]]]:
124-
annotations = response.get(ANNOTATIONS_KEY, None)
125-
if annotations:
126-
return [
127-
BoxPrediction.from_json(ann)
128-
if ann["type"] == "box"
129-
else PolygonPrediction.from_json(ann)
130-
for ann in annotations
131-
]
129+
annotation_payload = response.get(ANNOTATIONS_KEY, None)
130+
if annotation_payload:
131+
annotation_response = {}
132+
for (type_key, type_cls) in zip(
133+
[BOX_TYPE, POLYGON_TYPE, SEGMENTATION_TYPE],
134+
[BoxPrediction, PolygonPrediction, SegmentationPrediction],
135+
):
136+
if type_key in annotation_payload:
137+
annotation_response[type_key] = [
138+
type_cls.from_json(ann)
139+
for ann in annotation_payload[type_key]
140+
]
141+
return annotation_response
132142
else: # An error occurred
133143
return response

tests/helpers.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
TEST_DATASET_NAME = "[PyTest] Test Dataset"
1212
TEST_SLICE_NAME = "[PyTest] Test Slice"
1313

14-
TEST_MODEL_NAME = '[PyTest] Test Model Name'
15-
TEST_MODEL_REFERENCE = '[PyTest] Test Model Reference'
16-
TEST_MODEL_RUN = '[PyTest] Test Model Run Reference'
17-
TEST_DATASET_NAME = '[PyTest] Test Dataset'
18-
TEST_SLICE_NAME = '[PyTest] Test Slice'
14+
TEST_MODEL_NAME = "[PyTest] Test Model Name"
15+
TEST_MODEL_REFERENCE = "[PyTest] Test Model Reference"
16+
TEST_MODEL_RUN = "[PyTest] Test Model Run Reference"
17+
TEST_DATASET_NAME = "[PyTest] Test Dataset"
18+
TEST_SLICE_NAME = "[PyTest] Test Slice"
1919
TEST_IMG_URLS = [
2020
"s3://scaleapi-attachments/BDD/BDD/bdd100k/images/100k/train/6dd63871-831611a6.jpg",
2121
"s3://scaleapi-attachments/BDD/BDD/bdd100k/images/100k/train/82c1005c-e2d1d94f.jpg",
@@ -24,16 +24,16 @@
2424
"s3://scaleapi-attachments/BDD/BDD/bdd100k/images/100k/train/89b42832-10d662f4.jpg",
2525
]
2626
TEST_DATASET_ITEMS = [
27-
DatasetItem(TEST_IMG_URLS[0], '1'),
28-
DatasetItem(TEST_IMG_URLS[1], '2'),
29-
DatasetItem(TEST_IMG_URLS[2], '3'),
30-
DatasetItem(TEST_IMG_URLS[3], '4')
27+
DatasetItem(TEST_IMG_URLS[0], "1"),
28+
DatasetItem(TEST_IMG_URLS[1], "2"),
29+
DatasetItem(TEST_IMG_URLS[2], "3"),
30+
DatasetItem(TEST_IMG_URLS[3], "4"),
3131
]
3232
TEST_PREDS = [
33-
BoxPrediction('[Pytest Box Prediction 1]', 0, 0, 100, 100, '1'),
34-
BoxPrediction('[Pytest Box Prediction 2]', 0, 0, 100, 100, '2'),
35-
BoxPrediction('[Pytest Box Prediction 3]', 0, 0, 100, 100, '3'),
36-
BoxPrediction('[Pytest Box Prediction 4]', 0, 0, 100, 100, '4')
33+
BoxPrediction("[Pytest Box Prediction 1]", 0, 0, 100, 100, "1"),
34+
BoxPrediction("[Pytest Box Prediction 2]", 0, 0, 100, 100, "2"),
35+
BoxPrediction("[Pytest Box Prediction 3]", 0, 0, 100, 100, "3"),
36+
BoxPrediction("[Pytest Box Prediction 4]", 0, 0, 100, 100, "4"),
3737
]
3838

3939

@@ -147,6 +147,23 @@ def assert_polygon_annotation_matches_dict(
147147
assert instance_pt["y"] == dict_pt["y"]
148148

149149

150+
def assert_segmentation_annotation_matches_dict(
151+
annotation_instance, annotation_dict
152+
):
153+
assert annotation_instance.mask_url == annotation_dict["mask_url"]
154+
# Cannot guarantee segments are in same order
155+
assert len(annotation_instance.annotations) == len(
156+
annotation_dict["annotations"]
157+
)
158+
159+
for instance_segment, dict_segment in zip(
160+
sorted(annotation_instance.annotations, key=lambda i: i.index),
161+
sorted(annotation_dict["annotations"], key=lambda i: i["index"]),
162+
):
163+
assert instance_segment.index == dict_segment["index"]
164+
assert instance_segment.label == dict_segment["label"]
165+
166+
150167
# Asserts that a box prediction instance matches a dict representing its properties.
151168
# Useful to check prediction uploads/updates match.
152169
def assert_box_prediction_matches_dict(prediction_instance, prediction_dict):

tests/test_annotation.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
reference_id_from_url,
1010
assert_box_annotation_matches_dict,
1111
assert_polygon_annotation_matches_dict,
12+
assert_segmentation_annotation_matches_dict,
1213
)
1314

1415
from nucleus import (
@@ -48,7 +49,7 @@ def test_box_gt_upload(dataset):
4849
assert response["annotations_processed"] == 1
4950
assert response["annotations_ignored"] == 0
5051

51-
response = dataset.refloc(annotation.reference_id)["annotations"]
52+
response = dataset.refloc(annotation.reference_id)["annotations"]["box"]
5253
assert len(response) == 1
5354
response_annotation = response[0]
5455
assert_box_annotation_matches_dict(
@@ -64,7 +65,9 @@ def test_polygon_gt_upload(dataset):
6465
assert response["annotations_processed"] == 1
6566
assert response["annotations_ignored"] == 0
6667

67-
response = dataset.refloc(annotation.reference_id)["annotations"]
68+
response = dataset.refloc(annotation.reference_id)["annotations"][
69+
"polygon"
70+
]
6871
assert len(response) == 1
6972
response_annotation = response[0]
7073
assert_polygon_annotation_matches_dict(
@@ -80,7 +83,13 @@ def test_single_semseg_gt_upload(dataset):
8083
assert response["dataset_id"] == dataset.id
8184
assert response["annotations_processed"] == 1
8285
assert response["annotations_ignored"] == 0
83-
# assert_box_annotation_matches_dict(response_annotation, TEST_BOX_ANNOTATIONS[0])
86+
87+
response_annotation = dataset.refloc(annotation.reference_id)[
88+
"annotations"
89+
]["segmentation"][0]
90+
assert_segmentation_annotation_matches_dict(
91+
response_annotation, TEST_SEGMENTATION_ANNOTATIONS[0]
92+
)
8493

8594

8695
def test_batch_semseg_gt_upload(dataset):
@@ -142,6 +151,12 @@ def test_mixed_annotation_upload(dataset):
142151
assert response["dataset_id"] == dataset.id
143152
assert response["annotations_processed"] == 10
144153
assert response["annotations_ignored"] == 0
154+
response_annotations = dataset.refloc(bbox_annotations[0].reference_id)[
155+
"annotations"
156+
]
157+
assert len(response_annotations) == 2
158+
assert len(response_annotations["box"]) == 1
159+
assert "segmentation" in response_annotations
145160

146161

147162
def test_box_gt_upload_update(dataset):
@@ -165,7 +180,7 @@ def test_box_gt_upload_update(dataset):
165180
assert response["annotations_processed"] == 1
166181
assert response["annotations_ignored"] == 0
167182

168-
response = dataset.refloc(annotation.reference_id)["annotations"]
183+
response = dataset.refloc(annotation.reference_id)["annotations"]["box"]
169184
assert len(response) == 1
170185
response_annotation = response[0]
171186
assert_box_annotation_matches_dict(
@@ -194,7 +209,7 @@ def test_box_gt_upload_ignore(dataset):
194209
assert response["annotations_processed"] == 1
195210
assert response["annotations_ignored"] == 1
196211

197-
response = dataset.refloc(annotation.reference_id)["annotations"]
212+
response = dataset.refloc(annotation.reference_id)["annotations"]["box"]
198213
assert len(response) == 1
199214
response_annotation = response[0]
200215
assert_box_annotation_matches_dict(
@@ -223,7 +238,9 @@ def test_polygon_gt_upload_update(dataset):
223238
assert response["annotations_processed"] == 1
224239
assert response["annotations_ignored"] == 0
225240

226-
response = dataset.refloc(annotation.reference_id)["annotations"]
241+
response = dataset.refloc(annotation.reference_id)["annotations"][
242+
"polygon"
243+
]
227244
assert len(response) == 1
228245
response_annotation = response[0]
229246
assert_polygon_annotation_matches_dict(
@@ -253,7 +270,9 @@ def test_polygon_gt_upload_ignore(dataset):
253270
assert response["annotations_processed"] == 1
254271
assert response["annotations_ignored"] == 1
255272

256-
response = dataset.refloc(annotation.reference_id)["annotations"]
273+
response = dataset.refloc(annotation.reference_id)["annotations"][
274+
"polygon"
275+
]
257276
assert len(response) == 1
258277
response_annotation = response[0]
259278
assert_polygon_annotation_matches_dict(

0 commit comments

Comments
 (0)