Skip to content

Commit 3293108

Browse files
authored
Merge pull request #108 from scaleapi/add_classification_type_to_groundtruth
Category Annotation API Support
2 parents b587628 + b8df106 commit 3293108

File tree

9 files changed

+199
-12
lines changed

9 files changed

+199
-12
lines changed

nucleus/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
Point,
2626
Point3D,
2727
PolygonAnnotation,
28+
CategoryAnnotation,
2829
Segment,
2930
SegmentationAnnotation,
3031
)
@@ -581,6 +582,7 @@ def annotate_dataset(
581582
BoxAnnotation,
582583
PolygonAnnotation,
583584
CuboidAnnotation,
585+
CategoryAnnotation,
584586
SegmentationAnnotation,
585587
]
586588
],

nucleus/annotation.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
ANNOTATION_ID_KEY,
99
ANNOTATIONS_KEY,
1010
BOX_TYPE,
11+
CATEGORY_TYPE,
1112
CUBOID_TYPE,
1213
DIMENSIONS_KEY,
1314
GEOMETRY_KEY,
@@ -20,6 +21,7 @@
2021
POLYGON_TYPE,
2122
POSITION_KEY,
2223
REFERENCE_ID_KEY,
24+
TAXONOMY_NAME_KEY,
2325
TYPE_KEY,
2426
VERTICES_KEY,
2527
WIDTH_KEY,
@@ -41,6 +43,8 @@ def from_json(cls, payload: dict):
4143
return PolygonAnnotation.from_json(payload)
4244
elif payload.get(TYPE_KEY, None) == CUBOID_TYPE:
4345
return CuboidAnnotation.from_json(payload)
46+
elif payload.get(TYPE_KEY, None) == CATEGORY_TYPE:
47+
return CategoryAnnotation.from_json(payload)
4448
else:
4549
return SegmentationAnnotation.from_json(payload)
4650

@@ -120,6 +124,7 @@ class AnnotationTypes(Enum):
120124
BOX = BOX_TYPE
121125
POLYGON = POLYGON_TYPE
122126
CUBOID = CUBOID_TYPE
127+
CATEGORY = CATEGORY_TYPE
123128

124129

125130
@dataclass # pylint: disable=R0902
@@ -291,6 +296,36 @@ def to_payload(self) -> dict:
291296
return payload
292297

293298

299+
@dataclass
300+
class CategoryAnnotation(Annotation):
301+
label: str
302+
taxonomy_name: str
303+
reference_id: str
304+
metadata: Optional[Dict] = None
305+
306+
def __post_init__(self):
307+
self.metadata = self.metadata if self.metadata else {}
308+
309+
@classmethod
310+
def from_json(cls, payload: dict):
311+
return cls(
312+
label=payload[LABEL_KEY],
313+
taxonomy_name=payload[TAXONOMY_NAME_KEY],
314+
reference_id=payload[REFERENCE_ID_KEY],
315+
metadata=payload.get(METADATA_KEY, {}),
316+
)
317+
318+
def to_payload(self) -> dict:
319+
return {
320+
LABEL_KEY: self.label,
321+
TAXONOMY_NAME_KEY: self.taxonomy_name,
322+
TYPE_KEY: CATEGORY_TYPE,
323+
GEOMETRY_KEY: {},
324+
REFERENCE_ID_KEY: self.reference_id,
325+
METADATA_KEY: self.metadata,
326+
}
327+
328+
294329
def is_local_path(path: str) -> bool:
295330
return urlparse(path).scheme not in {"https", "http", "s3", "gs"}
296331

nucleus/constants.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,13 @@
1010
CUBOID_TYPE = "cuboid"
1111
CATEGORY_TYPE = "category"
1212
MULTICATEGORY_TYPE = "multicategory"
13-
ANNOTATION_TYPES = (BOX_TYPE, POLYGON_TYPE, SEGMENTATION_TYPE, CUBOID_TYPE)
13+
ANNOTATION_TYPES = (
14+
BOX_TYPE,
15+
POLYGON_TYPE,
16+
SEGMENTATION_TYPE,
17+
CUBOID_TYPE,
18+
CATEGORY_TYPE,
19+
)
1420
ANNOTATION_UPDATE_KEY = "update"
1521
AUTOTAGS_KEY = "autotags"
1622
AUTOTAG_SCORE_THRESHOLD = "score_threshold"

nucleus/dataset.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,6 @@ def annotate(
217217
route=f"dataset/{self.id}/annotate?async=1",
218218
)
219219
return AsyncJob.from_json(response, self._client)
220-
221220
return self._client.annotate_dataset(
222221
self.id, annotations, update=update, batch_size=batch_size
223222
)

nucleus/payload_constructor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
BoxAnnotation,
66
CuboidAnnotation,
77
PolygonAnnotation,
8+
CategoryAnnotation,
89
SegmentationAnnotation,
910
)
1011
from .prediction import (
@@ -60,6 +61,7 @@ def construct_annotation_payload(
6061
BoxAnnotation,
6162
PolygonAnnotation,
6263
CuboidAnnotation,
64+
CategoryAnnotation,
6365
SegmentationAnnotation,
6466
]
6567
],

nucleus/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
BoxAnnotation,
1515
CuboidAnnotation,
1616
PolygonAnnotation,
17+
CategoryAnnotation,
1718
SegmentationAnnotation,
1819
)
1920

@@ -22,6 +23,7 @@
2223
ANNOTATIONS_KEY,
2324
BOX_TYPE,
2425
CUBOID_TYPE,
26+
CATEGORY_TYPE,
2527
ITEM_KEY,
2628
POLYGON_TYPE,
2729
REFERENCE_ID_KEY,
@@ -116,6 +118,11 @@ def convert_export_payload(api_payload):
116118
for cuboid in row[CUBOID_TYPE]:
117119
cuboid[REFERENCE_ID_KEY] = row[ITEM_KEY][REFERENCE_ID_KEY]
118120
annotations[CUBOID_TYPE].append(CuboidAnnotation.from_json(cuboid))
121+
for category in row[CATEGORY_TYPE]:
122+
category[REFERENCE_ID_KEY] = row[ITEM_KEY][REFERENCE_ID_KEY]
123+
annotations[CATEGORY_TYPE].append(
124+
CategoryAnnotation.from_json(category)
125+
)
119126
return_payload_row[ANNOTATIONS_KEY] = annotations
120127
return_payload.append(return_payload_row)
121128
return return_payload

tests/helpers.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,14 @@ def reference_id_from_url(url):
158158
for i in range(len(TEST_POINTCLOUD_URLS))
159159
]
160160

161+
TEST_CATEGORY_ANNOTATIONS = [
162+
{
163+
"label": f"[Pytest] Category Label ${i}",
164+
"reference_id": reference_id_from_url(TEST_IMG_URLS[i]),
165+
"taxonomy_name": "[Pytest] Category Taxonomy 1",
166+
}
167+
for i in range(len(TEST_IMG_URLS))
168+
]
161169

162170
TEST_MASK_URL = "https://raw.githubusercontent.com/scaleapi/nucleus-python-client/master/tests/testdata/000000000285.png"
163171

@@ -266,6 +274,15 @@ def assert_cuboid_annotation_matches_dict(
266274
assert annotation_instance.yaw == annotation_dict["geometry"]["yaw"]
267275

268276

277+
def assert_category_annotation_matches_dict(
278+
annotation_instance, annotation_dict
279+
):
280+
assert annotation_instance.label == annotation_dict["label"]
281+
assert (
282+
annotation_instance.taxonomy_name == annotation_dict["taxonomy_name"]
283+
)
284+
285+
269286
def assert_segmentation_annotation_matches_dict(
270287
annotation_instance, annotation_dict
271288
):

tests/test_annotation.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,19 @@
55
TEST_IMG_URLS,
66
TEST_BOX_ANNOTATIONS,
77
TEST_POLYGON_ANNOTATIONS,
8+
TEST_CATEGORY_ANNOTATIONS,
89
TEST_SEGMENTATION_ANNOTATIONS,
910
reference_id_from_url,
1011
assert_box_annotation_matches_dict,
1112
assert_polygon_annotation_matches_dict,
13+
assert_category_annotation_matches_dict,
1214
assert_segmentation_annotation_matches_dict,
1315
)
1416

1517
from nucleus import (
1618
BoxAnnotation,
1719
PolygonAnnotation,
20+
CategoryAnnotation,
1821
SegmentationAnnotation,
1922
DatasetItem,
2023
Segment,
@@ -55,6 +58,12 @@ def dataset(CLIENT):
5558

5659
response = ds.append(ds_items)
5760
assert ERROR_PAYLOAD not in response.json()
61+
62+
response = ds.add_taxonomy(
63+
"[Pytest] Category Taxonomy 1",
64+
"category",
65+
[f"[Pytest] Category Label ${i}" for i in range((len(TEST_IMG_URLS)))],
66+
)
5867
yield ds
5968

6069
response = CLIENT.delete_dataset(ds.id)
@@ -100,6 +109,24 @@ def test_polygon_gt_upload(dataset):
100109
)
101110

102111

112+
def test_category_gt_upload(dataset):
113+
annotation = CategoryAnnotation.from_json(TEST_CATEGORY_ANNOTATIONS[0])
114+
response = dataset.annotate(annotations=[annotation])
115+
116+
assert response["dataset_id"] == dataset.id
117+
assert response["annotations_processed"] == 1
118+
assert response["annotations_ignored"] == 0
119+
120+
response = dataset.refloc(annotation.reference_id)["annotations"][
121+
"category"
122+
]
123+
assert len(response) == 1
124+
response_annotation = response[0]
125+
assert_category_annotation_matches_dict(
126+
response_annotation, TEST_CATEGORY_ANNOTATIONS[0]
127+
)
128+
129+
103130
def test_single_semseg_gt_upload(dataset):
104131
annotation = SegmentationAnnotation.from_json(
105132
TEST_SEGMENTATION_ANNOTATIONS[0]
@@ -308,6 +335,63 @@ def test_polygon_gt_upload_ignore(dataset):
308335
response_annotation, TEST_POLYGON_ANNOTATIONS[0]
309336
)
310337

338+
339+
def test_category_gt_upload_update(dataset):
340+
annotation = CategoryAnnotation.from_json(TEST_CATEGORY_ANNOTATIONS[0])
341+
response = dataset.annotate(annotations=[annotation])
342+
343+
assert response["annotations_processed"] == 1
344+
345+
# Copy so we don't modify the original.
346+
annotation_update_params = dict(TEST_CATEGORY_ANNOTATIONS[1])
347+
annotation_update_params["reference_id"] = TEST_CATEGORY_ANNOTATIONS[0][
348+
"reference_id"
349+
]
350+
351+
annotation_update = CategoryAnnotation.from_json(annotation_update_params)
352+
response = dataset.annotate(annotations=[annotation_update], update=True)
353+
354+
assert response["annotations_processed"] == 1
355+
assert response["annotations_ignored"] == 0
356+
357+
response = dataset.refloc(annotation.reference_id)["annotations"][
358+
"category"
359+
]
360+
assert len(response) == 1
361+
response_annotation = response[0]
362+
assert_category_annotation_matches_dict(
363+
response_annotation, annotation_update_params
364+
)
365+
366+
367+
def test_category_gt_upload_ignore(dataset):
368+
annotation = CategoryAnnotation.from_json(TEST_CATEGORY_ANNOTATIONS[0])
369+
response = dataset.annotate(annotations=[annotation])
370+
371+
assert response["annotations_processed"] == 1
372+
373+
# Copy so we don't modify the original.
374+
annotation_update_params = dict(TEST_CATEGORY_ANNOTATIONS[1])
375+
annotation_update_params["reference_id"] = TEST_CATEGORY_ANNOTATIONS[0][
376+
"reference_id"
377+
]
378+
379+
annotation_update = CategoryAnnotation.from_json(annotation_update_params)
380+
# Default behavior is ignore.
381+
response = dataset.annotate(annotations=[annotation_update])
382+
383+
assert response["annotations_processed"] == 0
384+
assert response["annotations_ignored"] == 1
385+
386+
response = dataset.refloc(annotation.reference_id)["annotations"][
387+
"category"
388+
]
389+
assert len(response) == 1
390+
response_annotation = response[0]
391+
assert_category_annotation_matches_dict(
392+
response_annotation, TEST_CATEGORY_ANNOTATIONS[0]
393+
)
394+
311395
@pytest.mark.integration
312396
def test_box_gt_deletion(dataset):
313397
annotation = BoxAnnotation(**TEST_BOX_ANNOTATIONS[0])
@@ -323,3 +407,19 @@ def test_box_gt_deletion(dataset):
323407
job_status = job.status()
324408
assert job_status["status"] == "Completed"
325409
assert job_status["job_id"] == job.id
410+
411+
@pytest.mark.integration
412+
def test_category_gt_deletion(dataset):
413+
annotation = CategoryAnnotation.from_json(TEST_CATEGORY_ANNOTATIONS[0])
414+
415+
print(annotation)
416+
417+
response = dataset.annotate(annotations=[annotation])
418+
419+
assert response["annotations_processed"] == 1
420+
421+
job = dataset.delete_annotations()
422+
job.sleep_until_complete()
423+
job_status = job.status()
424+
assert job_status["status"] == "Completed"
425+
assert job_status["job_id"] == job.id

0 commit comments

Comments
 (0)