Skip to content

Commit ef23e95

Browse files
author
Diego Ardila
committed
Point fix and most tests passing
1 parent 08c2f05 commit ef23e95

File tree

6 files changed

+88
-27
lines changed

6 files changed

+88
-27
lines changed

nucleus/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
PolygonAnnotation,
6868
Segment,
6969
SegmentationAnnotation,
70+
Point,
7071
)
7172
from .constants import (
7273
ANNOTATION_METADATA_SCHEMA_KEY,

nucleus/annotation.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
from dataclasses import dataclass
33
from enum import Enum
4-
from typing import Any, Dict, List, Optional, Sequence, Union
4+
from typing import Dict, List, Optional, Sequence, Union
55
from nucleus.dataset_item import is_local_path
66

77
from .constants import (
@@ -175,10 +175,25 @@ def to_payload(self) -> dict:
175175

176176

177177
# TODO: Add Generic type for 2D point
178+
179+
180+
@dataclass
181+
class Point:
182+
x: float
183+
y: float
184+
185+
@classmethod
186+
def from_json(cls, payload: Dict[str, float]):
187+
return cls(payload[X_KEY], payload[Y_KEY])
188+
189+
def to_payload(self) -> dict:
190+
return {X_KEY: self.x, Y_KEY: self.y}
191+
192+
178193
@dataclass
179194
class PolygonAnnotation(Annotation):
180195
label: str
181-
vertices: List[Any]
196+
vertices: List[Point]
182197
reference_id: Optional[str] = None
183198
item_id: Optional[str] = None
184199
annotation_id: Optional[str] = None
@@ -187,28 +202,40 @@ class PolygonAnnotation(Annotation):
187202
def __post_init__(self):
188203
self._check_ids()
189204
self.metadata = self.metadata if self.metadata else {}
205+
if len(self.vertices) > 0:
206+
if not hasattr(self.vertices[0], X_KEY) or not hasattr(
207+
self.vertices[0], "to_payload"
208+
):
209+
raise ValueError(
210+
"Use the Point object, not a dictionary for vertices"
211+
)
190212

191213
@classmethod
192214
def from_json(cls, payload: dict):
193215
geometry = payload.get(GEOMETRY_KEY, {})
194216
return cls(
195217
label=payload.get(LABEL_KEY, 0),
196-
vertices=geometry.get(VERTICES_KEY, []),
218+
vertices=[
219+
Point.from_json(_) for _ in geometry.get(VERTICES_KEY, [])
220+
],
197221
reference_id=payload.get(REFERENCE_ID_KEY, None),
198222
item_id=payload.get(DATASET_ITEM_ID_KEY, None),
199223
annotation_id=payload.get(ANNOTATION_ID_KEY, None),
200224
metadata=payload.get(METADATA_KEY, {}),
201225
)
202226

203227
def to_payload(self) -> dict:
204-
return {
228+
payload = {
205229
LABEL_KEY: self.label,
206230
TYPE_KEY: POLYGON_TYPE,
207-
GEOMETRY_KEY: {VERTICES_KEY: self.vertices},
231+
GEOMETRY_KEY: {
232+
VERTICES_KEY: [_.to_payload() for _ in self.vertices]
233+
},
208234
REFERENCE_ID_KEY: self.reference_id,
209235
ANNOTATION_ID_KEY: self.annotation_id,
210236
METADATA_KEY: self.metadata,
211237
}
238+
return payload
212239

213240

214241
def check_all_annotation_paths_remote(

nucleus/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def convert_export_payload(api_payload):
9292
return_payload_row = {}
9393
return_payload_row[ITEM_KEY] = DatasetItem.from_json(row[ITEM_KEY])
9494
annotations = defaultdict(list)
95-
if row[SEGMENTATION_TYPE] is not None:
95+
if row.get(SEGMENTATION_TYPE) is not None:
9696
segmentation = row[SEGMENTATION_TYPE]
9797
segmentation[REFERENCE_ID_KEY] = row[ITEM_KEY][REFERENCE_ID_KEY]
9898
annotations[SEGMENTATION_TYPE] = SegmentationAnnotation.from_json(

tests/helpers.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,15 @@ def reference_id_from_url(url):
5757
TEST_POLYGON_ANNOTATIONS = [
5858
{
5959
"label": f"[Pytest] Polygon Annotation ${i}",
60-
"vertices": [
61-
{
62-
"x": 50 + i * 10 + j,
63-
"y": 60 + i * 10 + j,
64-
}
65-
for j in range(3)
66-
],
60+
"geometry": {
61+
"vertices": [
62+
{
63+
"x": 50 + i * 10 + j,
64+
"y": 60 + i * 10 + j,
65+
}
66+
for j in range(3)
67+
],
68+
},
6769
"reference_id": reference_id_from_url(TEST_IMG_URLS[i]),
6870
"annotation_id": f"[Pytest] Polygon Annotation Annotation Id{i}",
6971
}
@@ -139,10 +141,10 @@ def assert_polygon_annotation_matches_dict(
139141
annotation_instance.annotation_id == annotation_dict["annotation_id"]
140142
)
141143
for instance_pt, dict_pt in zip(
142-
annotation_instance.vertices, annotation_dict["vertices"]
144+
annotation_instance.vertices, annotation_dict["geometry"]["vertices"]
143145
):
144-
assert instance_pt["x"] == dict_pt["x"]
145-
assert instance_pt["y"] == dict_pt["y"]
146+
assert instance_pt.x == dict_pt["x"]
147+
assert instance_pt.y == dict_pt["y"]
146148

147149

148150
def assert_segmentation_annotation_matches_dict(

tests/test_annotation.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
SegmentationAnnotation,
1919
DatasetItem,
2020
Segment,
21+
Point,
2122
)
2223
from nucleus.constants import ERROR_PAYLOAD
2324

@@ -77,7 +78,7 @@ def test_box_gt_upload(dataset):
7778

7879

7980
def test_polygon_gt_upload(dataset):
80-
annotation = PolygonAnnotation(**TEST_POLYGON_ANNOTATIONS[0])
81+
annotation = PolygonAnnotation.from_json(TEST_POLYGON_ANNOTATIONS[0])
8182
response = dataset.annotate(annotations=[annotation])
8283

8384
assert response["dataset_id"] == dataset.id
@@ -241,7 +242,7 @@ def test_box_gt_upload_ignore(dataset):
241242

242243

243244
def test_polygon_gt_upload_update(dataset):
244-
annotation = PolygonAnnotation(**TEST_POLYGON_ANNOTATIONS[0])
245+
annotation = PolygonAnnotation.from_json(TEST_POLYGON_ANNOTATIONS[0])
245246
response = dataset.annotate(annotations=[annotation])
246247

247248
assert response["annotations_processed"] == 1
@@ -255,7 +256,7 @@ def test_polygon_gt_upload_update(dataset):
255256
"reference_id"
256257
]
257258

258-
annotation_update = PolygonAnnotation(**annotation_update_params)
259+
annotation_update = PolygonAnnotation.from_json(annotation_update_params)
259260
response = dataset.annotate(annotations=[annotation_update], update=True)
260261

261262
assert response["annotations_processed"] == 1
@@ -272,7 +273,7 @@ def test_polygon_gt_upload_update(dataset):
272273

273274

274275
def test_polygon_gt_upload_ignore(dataset):
275-
annotation = PolygonAnnotation(**TEST_POLYGON_ANNOTATIONS[0])
276+
annotation = PolygonAnnotation.from_json(TEST_POLYGON_ANNOTATIONS[0])
276277
response = dataset.annotate(annotations=[annotation])
277278

278279
assert response["annotations_processed"] == 1
@@ -286,7 +287,7 @@ def test_polygon_gt_upload_ignore(dataset):
286287
"reference_id"
287288
]
288289

289-
annotation_update = PolygonAnnotation(**annotation_update_params)
290+
annotation_update = PolygonAnnotation.from_json(annotation_update_params)
290291
# Default behavior is ignore.
291292
response = dataset.annotate(annotations=[annotation_update])
292293

tests/test_dataset.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
ERROR_PAYLOAD,
2222
IGNORED_ITEMS,
2323
NEW_ITEMS,
24+
POLYGON_TYPE,
25+
SEGMENTATION_TYPE,
2426
UPDATED_ITEMS,
2527
ITEM_KEY,
2628
ANNOTATIONS_KEY,
@@ -340,7 +342,13 @@ def test_annotate_async_with_error(dataset: Dataset):
340342
def test_append_and_export(dataset):
341343
# Dataset upload
342344
url = TEST_IMG_URLS[0]
343-
annotation = BoxAnnotation(**TEST_BOX_ANNOTATIONS[0])
345+
box_annotation = BoxAnnotation(**TEST_BOX_ANNOTATIONS[0])
346+
segmentation_annotation = SegmentationAnnotation.from_json(
347+
TEST_SEGMENTATION_ANNOTATIONS[0]
348+
)
349+
polygon_annotation = PolygonAnnotation.from_json(
350+
TEST_POLYGON_ANNOTATIONS[0]
351+
)
344352

345353
ds_items = [
346354
DatasetItem(
@@ -352,12 +360,34 @@ def test_append_and_export(dataset):
352360
response = dataset.append(ds_items)
353361
assert ERROR_PAYLOAD not in response.json()
354362

355-
dataset.annotate(annotations=[annotation])
363+
dataset.annotate(
364+
annotations=[
365+
box_annotation,
366+
polygon_annotation,
367+
segmentation_annotation,
368+
]
369+
)
356370

357-
expected_box_annotation = copy.deepcopy(annotation)
358-
expected_box_annotation.annotation_id = None
359-
expected_box_annotation.metadata = {}
371+
# We don't export everything on the annotations in order to speed up export.
372+
def clear_fields(annotation):
373+
cleared_annotation = copy.deepcopy(annotation)
374+
cleared_annotation.annotation_id = None
375+
cleared_annotation.metadata = {}
376+
return cleared_annotation
377+
378+
def sort_labelmap(segmentation_annotation):
379+
segmentation_annotation.annotations = sorted(
380+
segmentation_annotation.annotations, key=lambda x: x.index
381+
)
360382

361383
exported = dataset.items_and_annotations()
362384
assert exported[0][ITEM_KEY] == ds_items[0]
363-
assert exported[0][ANNOTATIONS_KEY][BOX_TYPE][0] == expected_box_annotation
385+
assert exported[0][ANNOTATIONS_KEY][BOX_TYPE][0] == clear_fields(
386+
box_annotation
387+
)
388+
assert sort_labelmap(
389+
exported[0][ANNOTATIONS_KEY][SEGMENTATION_TYPE]
390+
) == sort_labelmap(clear_fields(segmentation_annotation))
391+
assert exported[0][ANNOTATIONS_KEY][POLYGON_TYPE][0] == clear_fields(
392+
polygon_annotation
393+
)

0 commit comments

Comments
 (0)