Skip to content

Commit 24158e8

Browse files
authored
Merge pull request #40 from scaleapi/sasha/annotation_id_fix
Add annotation_id field to segmentation annotations
2 parents 9305353 + 49b9982 commit 24158e8

File tree

4 files changed

+13
-9
lines changed

4 files changed

+13
-9
lines changed

nucleus/annotation.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,8 @@ class Annotation:
2626
@classmethod
2727
def from_json(cls, payload: dict):
2828
if payload.get(TYPE_KEY, None) == BOX_TYPE:
29-
geometry = payload.get(GEOMETRY_KEY, {})
3029
return BoxAnnotation.from_json(payload)
3130
elif payload.get(TYPE_KEY, None) == POLYGON_TYPE:
32-
geometry = payload.get(GEOMETRY_KEY, {})
3331
return PolygonAnnotation.from_json(payload)
3432
else:
3533
return SegmentationAnnotation.from_json(payload)
@@ -71,6 +69,7 @@ def __init__(
7169
annotations: List[Segment],
7270
reference_id: Optional[str] = None,
7371
item_id: Optional[str] = None,
72+
annotation_id: Optional[str] = None,
7473
):
7574
super().__init__()
7675
if not mask_url:
@@ -83,6 +82,7 @@ def __init__(
8382
self.annotations = annotations
8483
self.reference_id = reference_id
8584
self.item_id = item_id
85+
self.annotation_id = annotation_id
8686

8787
def __str__(self):
8888
return str(self.to_payload())
@@ -97,12 +97,14 @@ def from_json(cls, payload: dict):
9797
],
9898
reference_id=payload.get(REFERENCE_ID_KEY, None),
9999
item_id=payload.get(ITEM_ID_KEY, None),
100+
annotation_id=payload.get(ANNOTATION_ID_KEY, None),
100101
)
101102

102103
def to_payload(self) -> dict:
103104
payload = {
104105
MASK_URL_KEY: self.mask_url,
105106
ANNOTATIONS_KEY: [ann.to_payload() for ann in self.annotations],
107+
ANNOTATION_ID_KEY: self.annotation_id,
106108
}
107109
if self.reference_id:
108110
payload[REFERENCE_ID_KEY] = self.reference_id

nucleus/payload_constructor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ def construct_annotation_payload(
4949

5050

5151
def construct_segmentation_payload(
52-
annotation_items: List[
53-
Union[SegmentationAnnotation, SegmentationPrediction]
52+
annotation_items: Union[
53+
List[SegmentationAnnotation], List[SegmentationPrediction]
5454
],
5555
update: bool,
5656
) -> dict:

nucleus/prediction.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525

2626

2727
class SegmentationPrediction(SegmentationAnnotation):
28+
# No need to define init or to_payload methods because
29+
# we default to functions defined in the parent class
2830
@classmethod
2931
def from_json(cls, payload: dict):
3032
return cls(
@@ -37,10 +39,6 @@ def from_json(cls, payload: dict):
3739
item_id=payload.get(ITEM_ID_KEY, None),
3840
)
3941

40-
def to_payload(self) -> dict:
41-
payload = super().to_payload()
42-
return payload
43-
4442

4543
class BoxPrediction(BoxAnnotation):
4644
def __init__(

tests/helpers.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,11 @@ def reference_id_from_url(url):
9494
for i in range(len(TEST_IMG_URLS))
9595
]
9696

97-
TEST_MASK_URL = "https://scale-temp.s3.amazonaws.com/scale-select/nucleus/mscoco_semseg_masks_uint8/000000000285.png"
97+
TEST_MASK_URL = "https://scale-ml.s3.amazonaws.com/tmp/nucleus/mscoco_semseg_masks/000000000285.png"
9898
TEST_SEGMENTATION_ANNOTATIONS = [
9999
{
100100
"reference_id": reference_id_from_url(TEST_IMG_URLS[i]),
101+
"annotation_id": f"[Pytest] Segmentation Annotation Id{i}",
101102
"mask_url": get_signed_url(TEST_MASK_URL),
102103
"annotations": [
103104
{"label": "bear", "index": 2},
@@ -151,6 +152,9 @@ def assert_segmentation_annotation_matches_dict(
151152
annotation_instance, annotation_dict
152153
):
153154
assert annotation_instance.mask_url == annotation_dict["mask_url"]
155+
assert (
156+
annotation_instance.annotation_id == annotation_dict["annotation_id"]
157+
)
154158
# Cannot guarantee segments are in same order
155159
assert len(annotation_instance.annotations) == len(
156160
annotation_dict["annotations"]

0 commit comments

Comments
 (0)