Skip to content

Commit 0d1ffd1

Browse files
author
Claire Pajot
committed
Initial classification support
1 parent 38913d6 commit 0d1ffd1

File tree

6 files changed

+80
-0
lines changed

6 files changed

+80
-0
lines changed

nucleus/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
Point,
7474
Point3D,
7575
PolygonAnnotation,
76+
CategoryAnnotation,
7677
Segment,
7778
SegmentationAnnotation,
7879
)
@@ -622,6 +623,7 @@ def annotate_dataset(
622623
BoxAnnotation,
623624
PolygonAnnotation,
624625
CuboidAnnotation,
626+
CategoryAnnotation,
625627
SegmentationAnnotation,
626628
]
627629
],

nucleus/annotation.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
ANNOTATION_ID_KEY,
99
ANNOTATIONS_KEY,
1010
BOX_TYPE,
11+
CATEGORY_TYPE,
1112
CUBOID_TYPE,
1213
DATASET_ITEM_ID_KEY,
1314
DIMENSIONS_KEY,
@@ -135,6 +136,7 @@ class AnnotationTypes(Enum):
135136
BOX = BOX_TYPE
136137
POLYGON = POLYGON_TYPE
137138
CUBOID = CUBOID_TYPE
139+
CATEGORY = CATEGORY_TYPE
138140

139141

140142
@dataclass # pylint: disable=R0902
@@ -314,6 +316,40 @@ def to_payload(self) -> dict:
314316
return payload
315317

316318

319+
@dataclass
320+
class CategoryAnnotation(Annotation):
321+
label: str
322+
reference_id: Optional[str] = None
323+
item_id: Optional[str] = None
324+
annotation_id: Optional[str] = None
325+
metadata: Optional[Dict] = None
326+
327+
def __post_init__(self):
328+
self._check_ids()
329+
self.metadata = self.metadata if self.metadata else {}
330+
331+
@classmethod
332+
def from_json(cls, payload: dict):
333+
# TODO: Remove? geometry = payload.get(GEOMETRY_KEY, {})
334+
return cls(
335+
label=payload.get(LABEL_KEY, 0),
336+
reference_id=payload.get(REFERENCE_ID_KEY, None),
337+
item_id=payload.get(DATASET_ITEM_ID_KEY, None),
338+
annotation_id=payload.get(ANNOTATION_ID_KEY, None),
339+
metadata=payload.get(METADATA_KEY, {}),
340+
)
341+
342+
def to_payload(self) -> dict:
343+
return {
344+
LABEL_KEY: self.label,
345+
TYPE_KEY: CATEGORY_TYPE,
346+
GEOMETRY_KEY: {},
347+
REFERENCE_ID_KEY: self.reference_id,
348+
ANNOTATION_ID_KEY: self.annotation_id,
349+
METADATA_KEY: self.metadata,
350+
}
351+
352+
317353
def is_local_path(path: str) -> bool:
318354
return urlparse(path).scheme not in {"https", "http", "s3", "gs"}
319355

nucleus/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
MASK_TYPE = "mask"
99
SEGMENTATION_TYPE = "segmentation"
1010
CUBOID_TYPE = "cuboid"
11+
CATEGORY_TYPE = "category"
1112
ANNOTATION_TYPES = (BOX_TYPE, POLYGON_TYPE, SEGMENTATION_TYPE, CUBOID_TYPE)
1213
ANNOTATION_UPDATE_KEY = "update"
1314
AUTOTAGS_KEY = "autotags"

nucleus/payload_constructor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
BoxAnnotation,
66
CuboidAnnotation,
77
PolygonAnnotation,
8+
CategoryAnnotation,
89
SegmentationAnnotation,
910
)
1011
from .prediction import (
@@ -57,6 +58,7 @@ def construct_annotation_payload(
5758
BoxAnnotation,
5859
PolygonAnnotation,
5960
CuboidAnnotation,
61+
CategoryAnnotation,
6062
SegmentationAnnotation,
6163
]
6264
],

tests/helpers.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,14 @@ def reference_id_from_url(url):
155155
for i in range(len(TEST_POINTCLOUD_URLS))
156156
]
157157

158+
TEST_CATEGORY_ANNOTATIONS = [
159+
{
160+
"label": f"[Pytest] Category Annotation ${i}",
161+
"reference_id": reference_id_from_url(TEST_IMG_URLS[i]),
162+
"annotation_id": f"[Pytest] Category Annotation Annotation Id{i}",
163+
}
164+
for i in range(len(TEST_IMG_URLS))
165+
]
158166

159167
TEST_MASK_URL = "https://raw.githubusercontent.com/scaleapi/nucleus-python-client/master/tests/testdata/000000000285.png"
160168

@@ -263,6 +271,16 @@ def assert_cuboid_annotation_matches_dict(
263271
assert annotation_instance.yaw == annotation_dict["geometry"]["yaw"]
264272

265273

274+
def assert_category_annotation_matches_dict(
275+
annotation_instance, annotation_dict
276+
):
277+
assert annotation_instance.label == annotation_dict["label"]
278+
279+
assert (
280+
annotation_instance.annotation_id == annotation_dict["annotation_id"]
281+
)
282+
283+
266284
def assert_segmentation_annotation_matches_dict(
267285
annotation_instance, annotation_dict
268286
):

tests/test_annotation.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,19 @@
55
TEST_IMG_URLS,
66
TEST_BOX_ANNOTATIONS,
77
TEST_POLYGON_ANNOTATIONS,
8+
TEST_CATEGORY_ANNOTATIONS,
89
TEST_SEGMENTATION_ANNOTATIONS,
910
reference_id_from_url,
1011
assert_box_annotation_matches_dict,
1112
assert_polygon_annotation_matches_dict,
13+
assert_category_annotation_matches_dict,
1214
assert_segmentation_annotation_matches_dict,
1315
)
1416

1517
from nucleus import (
1618
BoxAnnotation,
1719
PolygonAnnotation,
20+
CategoryAnnotation,
1821
SegmentationAnnotation,
1922
DatasetItem,
2023
Segment,
@@ -95,6 +98,24 @@ def test_polygon_gt_upload(dataset):
9598
)
9699

97100

101+
def test_category_gt_upload(dataset):
102+
annotation = CategoryAnnotation.from_json(TEST_CATEGORY_ANNOTATIONS[0])
103+
response = dataset.annotate(annotations=[annotation])
104+
105+
assert response["dataset_id"] == dataset.id
106+
assert response["annotations_processed"] == 1
107+
assert response["annotations_ignored"] == 0
108+
109+
response = dataset.refloc(annotation.reference_id)["annotations"][
110+
"category"
111+
]
112+
assert len(response) == 1
113+
response_annotation = response[0]
114+
assert_category_annotation_matches_dict(
115+
response_annotation, TEST_CATEGORY_ANNOTATIONS[0]
116+
)
117+
118+
98119
def test_single_semseg_gt_upload(dataset):
99120
annotation = SegmentationAnnotation.from_json(
100121
TEST_SEGMENTATION_ANNOTATIONS[0]

0 commit comments

Comments
 (0)