Skip to content

Commit 37f1a92

Browse files
committed
address PR comments
1 parent a383ab0 commit 37f1a92

File tree

4 files changed

+66
-22
lines changed

4 files changed

+66
-22
lines changed

nucleus/annotation.py

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,54 @@
2222
)
2323

2424

25+
class Annotation:
26+
@classmethod
27+
def from_json(cls, payload: dict):
28+
if payload.get(TYPE_KEY, None) == BOX_TYPE:
29+
geometry = payload.get(GEOMETRY_KEY, {})
30+
return BoxAnnotation.from_json(payload)
31+
elif payload.get(TYPE_KEY, None) == POLYGON_TYPE:
32+
geometry = payload.get(GEOMETRY_KEY, {})
33+
return PolygonAnnotation.from_json(payload)
34+
else:
35+
return SegmentationAnnotation.from_json(payload)
36+
37+
# def from_json(cls, payload: dict):
38+
# if payload.get(TYPE_KEY, None) == BOX_TYPE:
39+
# geometry = payload.get(GEOMETRY_KEY, {})
40+
# return BoxAnnotation(
41+
# label=payload.get(LABEL_KEY, 0),
42+
# x=geometry.get(X_KEY, 0),
43+
# y=geometry.get(Y_KEY, 0),
44+
# width=geometry.get(WIDTH_KEY, 0),
45+
# height=geometry.get(HEIGHT_KEY, 0),
46+
# reference_id=payload.get(REFERENCE_ID_KEY, None),
47+
# item_id=payload.get(DATASET_ITEM_ID_KEY, None),
48+
# annotation_id=payload.get(ANNOTATION_ID_KEY, None),
49+
# metadata=payload.get(METADATA_KEY, {}),
50+
# )
51+
# elif payload.get(TYPE_KEY, None) == POLYGON_TYPE:
52+
# geometry = payload.get(GEOMETRY_KEY, {})
53+
# return PolygonAnnotation(
54+
# label=payload.get(LABEL_KEY, 0),
55+
# vertices=geometry.get(VERTICES_KEY, []),
56+
# reference_id=payload.get(REFERENCE_ID_KEY, None),
57+
# item_id=payload.get(DATASET_ITEM_ID_KEY, None),
58+
# annotation_id=payload.get(ANNOTATION_ID_KEY, None),
59+
# metadata=payload.get(METADATA_KEY, {}),
60+
# )
61+
# else:
62+
# return SegmentationAnnotation(
63+
# mask_url=payload.get(MASK_URL_KEY),
64+
# annotations=[
65+
# Segment.from_json(ann)
66+
# for ann in payload.get(ANNOTATIONS_KEY, [])
67+
# ],
68+
# reference_id=payload.get(REFERENCE_ID_KEY, None),
69+
# item_id=payload.get(ITEM_ID_KEY, None),
70+
# )
71+
72+
2573
class Segment:
2674
def __init__(
2775
self, label: str, index: int, metadata: Optional[dict] = None
@@ -51,14 +99,17 @@ def to_payload(self) -> dict:
5199
return payload
52100

53101

54-
class SegmentationAnnotation:
102+
class SegmentationAnnotation(Annotation):
55103
def __init__(
56104
self,
57105
mask_url: str,
58106
annotations: List[Segment],
59107
reference_id: Optional[str] = None,
60108
item_id: Optional[str] = None,
61109
):
110+
super().__init__()
111+
if not mask_url:
112+
raise Exception("You must specify a mask_url.")
62113
if bool(reference_id) == bool(item_id):
63114
raise Exception(
64115
"You must specify either a reference_id or an item_id for an annotation."
@@ -74,7 +125,7 @@ def __str__(self):
74125
@classmethod
75126
def from_json(cls, payload: dict):
76127
return cls(
77-
mask_url=payload.get(MASK_URL_KEY, None),
128+
mask_url=payload.get(MASK_URL_KEY),
78129
annotations=[
79130
Segment.from_json(ann)
80131
for ann in payload.get(ANNOTATIONS_KEY, [])
@@ -101,7 +152,7 @@ class AnnotationTypes(Enum):
101152

102153

103154
# TODO: Add base annotation class to reduce repeated code here
104-
class BoxAnnotation:
155+
class BoxAnnotation(Annotation):
105156
# pylint: disable=too-many-instance-attributes
106157
def __init__(
107158
self,
@@ -115,6 +166,7 @@ def __init__(
115166
annotation_id: Optional[str] = None,
116167
metadata: Optional[Dict] = None,
117168
):
169+
super().__init__()
118170
if bool(reference_id) == bool(item_id):
119171
raise Exception(
120172
"You must specify either a reference_id or an item_id for an annotation."
@@ -164,7 +216,7 @@ def __str__(self):
164216

165217

166218
# TODO: Add Generic type for 2D point
167-
class PolygonAnnotation:
219+
class PolygonAnnotation(Annotation):
168220
def __init__(
169221
self,
170222
label: str,
@@ -174,6 +226,7 @@ def __init__(
174226
annotation_id: Optional[str] = None,
175227
metadata: Optional[Dict] = None,
176228
):
229+
super().__init__()
177230
if bool(reference_id) == bool(item_id):
178231
raise Exception(
179232
"You must specify either a reference_id or an item_id for an annotation."

nucleus/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
BOX_TYPE = "box"
4949
POLYGON_TYPE = "polygon"
5050
SEGMENTATION_TYPE = "segmentation"
51+
ANNOTATION_TYPES = (BOX_TYPE, POLYGON_TYPE, SEGMENTATION_TYPE)
5152
GEOMETRY_KEY = "geometry"
5253
AUTOTAGS_KEY = "autotags"
5354
MASK_URL_KEY = "mask_url"

nucleus/dataset.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import List, Dict, Any, Optional, Union
22
from .dataset_item import DatasetItem
33
from .annotation import (
4+
Annotation,
45
BoxAnnotation,
56
PolygonAnnotation,
67
SegmentationAnnotation,
@@ -19,6 +20,7 @@
1920
BOX_TYPE,
2021
POLYGON_TYPE,
2122
SEGMENTATION_TYPE,
23+
ANNOTATION_TYPES,
2224
)
2325
from .payload_constructor import construct_model_run_creation_payload
2426

@@ -178,7 +180,6 @@ def refloc(self, reference_id: str) -> dict:
178180
}
179181
"""
180182
response = self._client.dataitem_ref_id(self.id, reference_id)
181-
print("RESPONSE" + str(response))
182183
return self._format_dataset_item_response(response)
183184

184185
def loc(self, dataset_item_id: str) -> dict:
@@ -242,22 +243,12 @@ def _format_dataset_item_response(self, response: dict) -> dict:
242243
return response
243244

244245
annotation_response = {}
245-
if BOX_TYPE in annotation_payload:
246-
annotation_response[BOX_TYPE] = [
247-
BoxAnnotation.from_json(ann)
248-
for ann in annotation_payload[BOX_TYPE]
249-
]
250-
if POLYGON_TYPE in annotation_payload:
251-
annotation_response[POLYGON_TYPE] = [
252-
PolygonAnnotation.from_json(ann)
253-
for ann in annotation_payload[POLYGON_TYPE]
254-
]
255-
if SEGMENTATION_TYPE in annotation_payload:
256-
annotation_response[
257-
SEGMENTATION_TYPE
258-
] = SegmentationAnnotation.from_json(
259-
annotation_payload[SEGMENTATION_TYPE]
260-
)
246+
for annotation_type in ANNOTATION_TYPES:
247+
if annotation_type in annotation_payload:
248+
annotation_response[annotation_type] = [
249+
Annotation.from_json(ann)
250+
for ann in annotation_payload[annotation_type]
251+
]
261252
return {
262253
ITEM_KEY: DatasetItem.from_json(item),
263254
ANNOTATIONS_KEY: annotation_response,

nucleus/model_run.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,6 @@ def _format_prediction_response(
127127
self, response: dict
128128
) -> Union[dict, List[Union[BoxPrediction, PolygonPrediction]]]:
129129
annotations = response.get(ANNOTATIONS_KEY, None)
130-
print(annotations)
131130
if annotations:
132131
annotation_response = {}
133132
if BOX_TYPE in annotations:

0 commit comments

Comments
 (0)