Skip to content

Commit 51955b5

Browse files
committed
All predictions and annotations have better repr
1 parent 9013806 commit 51955b5

File tree

7 files changed

+95
-109
lines changed

7 files changed

+95
-109
lines changed

nucleus/annotation.py

Lines changed: 45 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,16 @@
2121
ANNOTATIONS_KEY,
2222
)
2323

24+
from dataclasses import dataclass
2425

2526
class Annotation:
27+
28+
def _check_ids(self):
29+
if not bool(self.reference_id) and not bool(self.item_id):
30+
raise Exception(
31+
"You must specify either a reference_id or an item_id for an annotation."
32+
)
33+
2634
@classmethod
2735
def from_json(cls, payload: dict):
2836
if payload.get(TYPE_KEY, None) == BOX_TYPE:
@@ -32,17 +40,11 @@ def from_json(cls, payload: dict):
3240
else:
3341
return SegmentationAnnotation.from_json(payload)
3442

35-
43+
@dataclass
3644
class Segment:
37-
def __init__(
38-
self, label: str, index: int, metadata: Optional[dict] = None
39-
):
40-
self.label = label
41-
self.index = index
42-
self.metadata = metadata
43-
44-
def __str__(self):
45-
return str(self.to_payload())
45+
label: str
46+
index: int
47+
metadata: Optional[dict] = None
4648

4749
@classmethod
4850
def from_json(cls, payload: dict):
@@ -62,30 +64,18 @@ def to_payload(self) -> dict:
6264
return payload
6365

6466

67+
@dataclass
6568
class SegmentationAnnotation(Annotation):
66-
def __init__(
67-
self,
68-
mask_url: str,
69-
annotations: List[Segment],
70-
reference_id: Optional[str] = None,
71-
item_id: Optional[str] = None,
72-
annotation_id: Optional[str] = None,
73-
):
74-
super().__init__()
75-
if not mask_url:
69+
mask_url: str
70+
annotations: List[Segment]
71+
reference_id: Optional[str] = None
72+
item_id: Optional[str] = None
73+
annotation_id: Optional[str] = None
74+
def __post_init__(self):
75+
if not self.mask_url:
7676
raise Exception("You must specify a mask_url.")
77-
if bool(reference_id) == bool(item_id):
78-
raise Exception(
79-
"You must specify either a reference_id or an item_id for an annotation."
80-
)
81-
self.mask_url = mask_url
82-
self.annotations = annotations
83-
self.reference_id = reference_id
84-
self.item_id = item_id
85-
self.annotation_id = annotation_id
86-
87-
def __str__(self):
88-
return str(self.to_payload())
77+
self._check_ids()
78+
8979

9080
@classmethod
9181
def from_json(cls, payload: dict):
@@ -120,35 +110,20 @@ class AnnotationTypes(Enum):
120110
POLYGON = POLYGON_TYPE
121111

122112

123-
# TODO: Add base annotation class to reduce repeated code here
113+
@dataclass
124114
class BoxAnnotation(Annotation):
125-
# pylint: disable=too-many-instance-attributes
126-
def __init__(
127-
self,
128-
label: str,
129-
x: Union[float, int],
130-
y: Union[float, int],
131-
width: Union[float, int],
132-
height: Union[float, int],
133-
reference_id: Optional[str] = None,
134-
item_id: Optional[str] = None,
135-
annotation_id: Optional[str] = None,
136-
metadata: Optional[Dict] = None,
137-
):
138-
super().__init__()
139-
if bool(reference_id) == bool(item_id):
140-
raise Exception(
141-
"You must specify either a reference_id or an item_id for an annotation."
142-
)
143-
self.label = label
144-
self.x = x
145-
self.y = y
146-
self.width = width
147-
self.height = height
148-
self.reference_id = reference_id
149-
self.item_id = item_id
150-
self.annotation_id = annotation_id
151-
self.metadata = metadata if metadata else {}
115+
label: str
116+
x: Union[float, int]
117+
y: Union[float, int]
118+
width: Union[float, int]
119+
height: Union[float, int]
120+
reference_id: Optional[str] = None
121+
item_id: Optional[str] = None
122+
annotation_id: Optional[str] = None
123+
metadata: Optional[Dict] = None
124+
def __post_init__(self):
125+
self._check_ids()
126+
self.metadata = self.metadata if self.metadata else {}
152127

153128
@classmethod
154129
def from_json(cls, payload: dict):
@@ -180,32 +155,19 @@ def to_payload(self) -> dict:
180155
METADATA_KEY: self.metadata,
181156
}
182157

183-
def __str__(self):
184-
return str(self.to_payload())
185-
186158

187159
# TODO: Add Generic type for 2D point
160+
@dataclass
188161
class PolygonAnnotation(Annotation):
189-
def __init__(
190-
self,
191-
label: str,
192-
vertices: List[Any],
193-
reference_id: Optional[str] = None,
194-
item_id: Optional[str] = None,
195-
annotation_id: Optional[str] = None,
196-
metadata: Optional[Dict] = None,
197-
):
198-
super().__init__()
199-
if bool(reference_id) == bool(item_id):
200-
raise Exception(
201-
"You must specify either a reference_id or an item_id for an annotation."
202-
)
203-
self.label = label
204-
self.vertices = vertices
205-
self.reference_id = reference_id
206-
self.item_id = item_id
207-
self.annotation_id = annotation_id
208-
self.metadata = metadata if metadata else {}
162+
label: str
163+
vertices: List[Any]
164+
reference_id: Optional[str] = None
165+
item_id: Optional[str] = None
166+
annotation_id: Optional[str] = None
167+
metadata: Optional[Dict] = None
168+
def __post_init__(self):
169+
self._check_ids()
170+
self.metadata = self.metadata if self.metadata else {}
209171

210172
@classmethod
211173
def from_json(cls, payload: dict):
@@ -228,6 +190,3 @@ def to_payload(self) -> dict:
228190
ANNOTATION_ID_KEY: self.annotation_id,
229191
METADATA_KEY: self.metadata,
230192
}
231-
232-
def __str__(self):
233-
return str(self.to_payload())

nucleus/dataset_item.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,19 @@
77
DATASET_ITEM_ID_KEY,
88
)
99

10+
from dataclasses import dataclass
1011

12+
13+
@dataclass
1114
class DatasetItem:
12-
def __init__(
13-
self,
14-
image_location: str,
15-
reference_id: str = None,
16-
item_id: str = None,
17-
metadata: dict = None,
18-
):
19-
20-
self.image_url = image_location
21-
self.local = self._is_local_path(image_location)
22-
self.item_id = item_id
23-
self.reference_id = reference_id
24-
self.metadata = metadata
15+
16+
image_location: str
17+
reference_id: str = None
18+
item_id: str = None
19+
metadata: dict = None
20+
21+
def __post_init__(self):
22+
self.local = self._is_local_path(self.image_location)
2523

2624
@classmethod
2725
def from_json(cls, payload: dict):
@@ -35,9 +33,6 @@ def from_json(cls, payload: dict):
3533
metadata=payload.get(METADATA_KEY, {}),
3634
)
3735

38-
def __str__(self):
39-
return str(self.to_payload())
40-
4136
def _is_local_path(self, path: str) -> bool:
4237
path_components = [comp.lower() for comp in path.split("/")]
4338
return not (

nucleus/model_run.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,6 @@ def _format_prediction_response(
130130
List[Union[BoxPrediction, PolygonPrediction, SegmentationPrediction]],
131131
]:
132132
annotation_payload = response.get(ANNOTATIONS_KEY, None)
133-
SegmentationPrediction.from_json({"asdf": "asdf"})
134133
if not annotation_payload:
135134
# An error occurred
136135
return response

nucleus/prediction.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,6 @@ def from_json(cls, payload: dict) -> BoxPrediction:
9393
metadata=payload.get(METADATA_KEY, {}),
9494
)
9595

96-
def __str__(self):
97-
return str(self.to_payload())
98-
9996

10097
class PolygonPrediction(PolygonAnnotation):
10198
def __init__(
@@ -132,6 +129,3 @@ def from_json(cls, payload: dict) -> PolygonPrediction:
132129
annotation_id=payload.get(ANNOTATION_ID_KEY, None),
133130
metadata=payload.get(METADATA_KEY, {}),
134131
)
135-
136-
def __str__(self):
137-
return str(self.to_payload())

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ tqdm = "^4.60.0"
4141
poetry = "^1.1.5"
4242
pytest = "^6.2.3"
4343
pylint = "^2.7.4"
44+
boto3 = "^1.17.51"
4445

4546
[build-system]
4647
requires = ["poetry-core>=1.0.0"]

tests/test_annotation.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,29 @@
1717
PolygonAnnotation,
1818
SegmentationAnnotation,
1919
DatasetItem,
20+
Segment,
2021
)
2122
from nucleus.constants import ERROR_PAYLOAD
2223

2324

25+
def test_repr(test_object: any):
26+
assert eval(str(test_object)) == test_object
27+
28+
29+
def test_reprs():
30+
[
31+
test_repr(SegmentationAnnotation.from_json(_))
32+
for _ in TEST_SEGMENTATION_ANNOTATIONS
33+
]
34+
35+
[test_repr(BoxAnnotation.from_json(_)) for _ in TEST_BOX_ANNOTATIONS]
36+
37+
[
38+
test_repr(PolygonAnnotation.from_json(_))
39+
for _ in TEST_POLYGON_ANNOTATIONS
40+
]
41+
42+
2443
@pytest.fixture()
2544
def dataset(CLIENT):
2645
ds = CLIENT.create_dataset(TEST_DATASET_NAME)

tests/test_prediction.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,29 @@
2020
PolygonPrediction,
2121
SegmentationPrediction,
2222
DatasetItem,
23+
Segment
2324
)
2425
from nucleus.constants import ERROR_PAYLOAD
2526

2627

28+
def test_repr(test_object: any):
29+
assert eval(str(test_object)) == test_object
30+
31+
32+
def test_reprs():
33+
[
34+
test_repr(SegmentationPrediction.from_json(_))
35+
for _ in TEST_SEGMENTATION_PREDICTIONS
36+
]
37+
38+
[test_repr(BoxPrediction.from_json(_)) for _ in TEST_BOX_PREDICTIONS]
39+
40+
[
41+
test_repr(PolygonPrediction.from_json(_))
42+
for _ in TEST_POLYGON_PREDICTIONS
43+
]
44+
45+
2746
@pytest.fixture()
2847
def model_run(CLIENT):
2948
ds = CLIENT.create_dataset(TEST_DATASET_NAME)

0 commit comments

Comments
 (0)