Skip to content

Commit fd633b5

Browse files
authored
Merge pull request #132 from scaleapi/add_multicategory_type_to_groundtruth
MultiCategory Annotation API Support
2 parents 6a0e6e0 + 68ec7dd commit fd633b5

File tree

8 files changed

+249
-32
lines changed

8 files changed

+249
-32
lines changed

nucleus/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
Point3D,
2727
PolygonAnnotation,
2828
CategoryAnnotation,
29+
MultiCategoryAnnotation,
2930
Segment,
3031
SegmentationAnnotation,
3132
)
@@ -583,6 +584,7 @@ def annotate_dataset(
583584
PolygonAnnotation,
584585
CuboidAnnotation,
585586
CategoryAnnotation,
587+
MultiCategoryAnnotation,
586588
SegmentationAnnotation,
587589
]
588590
],

nucleus/annotation.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99
ANNOTATIONS_KEY,
1010
BOX_TYPE,
1111
CATEGORY_TYPE,
12+
MULTICATEGORY_TYPE,
1213
CUBOID_TYPE,
1314
DIMENSIONS_KEY,
1415
GEOMETRY_KEY,
1516
HEIGHT_KEY,
1617
INDEX_KEY,
1718
LABEL_KEY,
19+
LABELS_KEY,
1820
MASK_TYPE,
1921
MASK_URL_KEY,
2022
METADATA_KEY,
@@ -45,6 +47,8 @@ def from_json(cls, payload: dict):
4547
return CuboidAnnotation.from_json(payload)
4648
elif payload.get(TYPE_KEY, None) == CATEGORY_TYPE:
4749
return CategoryAnnotation.from_json(payload)
50+
elif payload.get(TYPE_KEY, None) == MULTICATEGORY_TYPE:
51+
return MultiCategoryAnnotation.from_json(payload)
4852
else:
4953
return SegmentationAnnotation.from_json(payload)
5054

@@ -125,6 +129,7 @@ class AnnotationTypes(Enum):
125129
POLYGON = POLYGON_TYPE
126130
CUBOID = CUBOID_TYPE
127131
CATEGORY = CATEGORY_TYPE
132+
MULTICATEGORY = MULTICATEGORY_TYPE
128133

129134

130135
@dataclass # pylint: disable=R0902
@@ -326,6 +331,36 @@ def to_payload(self) -> dict:
326331
}
327332

328333

334+
@dataclass
335+
class MultiCategoryAnnotation(Annotation):
336+
labels: List[str]
337+
taxonomy_name: str
338+
reference_id: str
339+
metadata: Optional[Dict] = None
340+
341+
def __post_init__(self):
342+
self.metadata = self.metadata if self.metadata else {}
343+
344+
@classmethod
345+
def from_json(cls, payload: dict):
346+
return cls(
347+
labels=payload[LABELS_KEY],
348+
taxonomy_name=payload[TAXONOMY_NAME_KEY],
349+
reference_id=payload[REFERENCE_ID_KEY],
350+
metadata=payload.get(METADATA_KEY, {}),
351+
)
352+
353+
def to_payload(self) -> dict:
354+
return {
355+
LABELS_KEY: self.labels,
356+
TAXONOMY_NAME_KEY: self.taxonomy_name,
357+
TYPE_KEY: MULTICATEGORY_TYPE,
358+
GEOMETRY_KEY: {},
359+
REFERENCE_ID_KEY: self.reference_id,
360+
METADATA_KEY: self.metadata,
361+
}
362+
363+
329364
def is_local_path(path: str) -> bool:
330365
return urlparse(path).scheme not in {"https", "http", "s3", "gs"}
331366

nucleus/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
SEGMENTATION_TYPE,
1717
CUBOID_TYPE,
1818
CATEGORY_TYPE,
19+
MULTICATEGORY_TYPE,
1920
)
2021
ANNOTATION_UPDATE_KEY = "update"
2122
AUTOTAGS_KEY = "autotags"

nucleus/payload_constructor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
CuboidAnnotation,
77
PolygonAnnotation,
88
CategoryAnnotation,
9+
MultiCategoryAnnotation,
910
SegmentationAnnotation,
1011
)
1112
from .prediction import (
@@ -62,6 +63,7 @@ def construct_annotation_payload(
6263
PolygonAnnotation,
6364
CuboidAnnotation,
6465
CategoryAnnotation,
66+
MultiCategoryAnnotation,
6567
SegmentationAnnotation,
6668
]
6769
],

nucleus/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
CuboidAnnotation,
1616
PolygonAnnotation,
1717
CategoryAnnotation,
18+
MultiCategoryAnnotation,
1819
SegmentationAnnotation,
1920
)
2021

@@ -24,6 +25,7 @@
2425
BOX_TYPE,
2526
CUBOID_TYPE,
2627
CATEGORY_TYPE,
28+
MULTICATEGORY_TYPE,
2729
ITEM_KEY,
2830
POLYGON_TYPE,
2931
REFERENCE_ID_KEY,
@@ -169,6 +171,11 @@ def convert_export_payload(api_payload):
169171
annotations[CATEGORY_TYPE].append(
170172
CategoryAnnotation.from_json(category)
171173
)
174+
for multicategory in row[MULTICATEGORY_TYPE]:
175+
multicategory[REFERENCE_ID_KEY] = row[ITEM_KEY][REFERENCE_ID_KEY]
176+
annotations[MULTICATEGORY_TYPE].append(
177+
MultiCategoryAnnotation.from_json(multicategory)
178+
)
172179
return_payload_row[ANNOTATIONS_KEY] = annotations
173180
return_payload.append(return_payload_row)
174181
return return_payload

tests/helpers.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,18 @@ def reference_id_from_url(url):
168168
for i in range(len(TEST_IMG_URLS))
169169
]
170170

171+
TEST_MULTICATEGORY_ANNOTATIONS = [
172+
{
173+
"labels": [
174+
f"[Pytest] MultiCategory Label ${i}",
175+
f"[Pytest] MultiCategory Label ${i+1}",
176+
],
177+
"reference_id": reference_id_from_url(TEST_IMG_URLS[i]),
178+
"taxonomy_name": "[Pytest] MultiCategory Taxonomy 1",
179+
}
180+
for i in range(len(TEST_IMG_URLS))
181+
]
182+
171183
TEST_MASK_URL = "https://raw.githubusercontent.com/scaleapi/nucleus-python-client/master/tests/testdata/000000000285.png"
172184

173185
TEST_SEGMENTATION_ANNOTATIONS = [
@@ -284,6 +296,15 @@ def assert_category_annotation_matches_dict(
284296
)
285297

286298

299+
def assert_multicategory_annotation_matches_dict(
300+
annotation_instance, annotation_dict
301+
):
302+
assert set(annotation_instance.labels) == set(annotation_dict["labels"])
303+
assert (
304+
annotation_instance.taxonomy_name == annotation_dict["taxonomy_name"]
305+
)
306+
307+
287308
def assert_segmentation_annotation_matches_dict(
288309
annotation_instance, annotation_dict
289310
):

tests/test_annotation.py

Lines changed: 141 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,21 @@
66
TEST_BOX_ANNOTATIONS,
77
TEST_POLYGON_ANNOTATIONS,
88
TEST_CATEGORY_ANNOTATIONS,
9+
TEST_MULTICATEGORY_ANNOTATIONS,
910
TEST_SEGMENTATION_ANNOTATIONS,
1011
reference_id_from_url,
1112
assert_box_annotation_matches_dict,
1213
assert_polygon_annotation_matches_dict,
1314
assert_category_annotation_matches_dict,
15+
assert_multicategory_annotation_matches_dict,
1416
assert_segmentation_annotation_matches_dict,
1517
)
1618

1719
from nucleus import (
1820
BoxAnnotation,
1921
PolygonAnnotation,
2022
CategoryAnnotation,
23+
MultiCategoryAnnotation,
2124
SegmentationAnnotation,
2225
DatasetItem,
2326
Segment,
@@ -64,6 +67,14 @@ def dataset(CLIENT):
6467
"category",
6568
[f"[Pytest] Category Label ${i}" for i in range((len(TEST_IMG_URLS)))],
6669
)
70+
response = ds.add_taxonomy(
71+
"[Pytest] MultiCategory Taxonomy 1",
72+
"multicategory",
73+
[
74+
f"[Pytest] MultiCategory Label ${i}"
75+
for i in range((len(TEST_IMG_URLS) + 1))
76+
],
77+
)
6778
yield ds
6879

6980
response = CLIENT.delete_dataset(ds.id)
@@ -127,6 +138,27 @@ def test_category_gt_upload(dataset):
127138
)
128139

129140

141+
def test_multicategory_gt_upload(dataset):
142+
annotation = MultiCategoryAnnotation.from_json(
143+
TEST_MULTICATEGORY_ANNOTATIONS[0]
144+
)
145+
response = dataset.annotate(annotations=[annotation])
146+
147+
assert response["dataset_id"] == dataset.id
148+
assert response["annotations_processed"] == 1
149+
assert response["annotations_ignored"] == 0
150+
151+
response = dataset.refloc(annotation.reference_id)["annotations"][
152+
"multicategory"
153+
]
154+
155+
assert len(response) == 1
156+
response_annotation = response[0]
157+
assert_multicategory_annotation_matches_dict(
158+
response_annotation, TEST_MULTICATEGORY_ANNOTATIONS[0]
159+
)
160+
161+
130162
def test_single_semseg_gt_upload(dataset):
131163
annotation = SegmentationAnnotation.from_json(
132164
TEST_SEGMENTATION_ANNOTATIONS[0]
@@ -206,6 +238,7 @@ def test_mixed_annotation_upload(dataset):
206238
response_annotations = dataset.refloc(bbox_annotations[0].reference_id)[
207239
"annotations"
208240
]
241+
209242
assert len(response_annotations) == 2
210243
assert len(response_annotations["box"]) == 1
211244
assert "segmentation" in response_annotations
@@ -392,34 +425,120 @@ def test_category_gt_upload_ignore(dataset):
392425
response_annotation, TEST_CATEGORY_ANNOTATIONS[0]
393426
)
394427

395-
@pytest.mark.integration
396-
def test_box_gt_deletion(dataset):
397-
annotation = BoxAnnotation(**TEST_BOX_ANNOTATIONS[0])
398428

399-
print(annotation)
429+
def test_multicategory_gt_upload_update(dataset):
430+
annotation = MultiCategoryAnnotation.from_json(
431+
TEST_MULTICATEGORY_ANNOTATIONS[0]
432+
)
433+
response = dataset.annotate(annotations=[annotation])
434+
435+
assert response["annotations_processed"] == 1
436+
437+
# Copy so we don't modify the original.
438+
annotation_update_params = dict(TEST_MULTICATEGORY_ANNOTATIONS[1])
439+
annotation_update_params["reference_id"] = TEST_MULTICATEGORY_ANNOTATIONS[
440+
0
441+
]["reference_id"]
442+
443+
annotation_update = MultiCategoryAnnotation.from_json(
444+
annotation_update_params
445+
)
446+
response = dataset.annotate(annotations=[annotation_update], update=True)
447+
448+
assert response["annotations_processed"] == 1
449+
assert response["annotations_ignored"] == 0
450+
451+
response = dataset.refloc(annotation.reference_id)["annotations"][
452+
"multicategory"
453+
]
454+
assert len(response) == 1
455+
response_annotation = response[0]
456+
assert_multicategory_annotation_matches_dict(
457+
response_annotation, annotation_update_params
458+
)
459+
460+
461+
def test_multicategory_gt_upload_ignore(dataset):
462+
annotation = MultiCategoryAnnotation.from_json(
463+
TEST_MULTICATEGORY_ANNOTATIONS[0]
464+
)
465+
response = dataset.annotate(annotations=[annotation])
466+
467+
assert response["annotations_processed"] == 1
468+
469+
# Copy so we don't modify the original.
470+
annotation_update_params = dict(TEST_MULTICATEGORY_ANNOTATIONS[1])
471+
annotation_update_params["reference_id"] = TEST_MULTICATEGORY_ANNOTATIONS[
472+
0
473+
]["reference_id"]
474+
475+
annotation_update = MultiCategoryAnnotation.from_json(
476+
annotation_update_params
477+
)
478+
# Default behavior is ignore.
479+
response = dataset.annotate(annotations=[annotation_update])
480+
481+
assert response["annotations_processed"] == 0
482+
assert response["annotations_ignored"] == 1
483+
484+
response = dataset.refloc(annotation.reference_id)["annotations"][
485+
"multicategory"
486+
]
487+
assert len(response) == 1
488+
response_annotation = response[0]
489+
assert_multicategory_annotation_matches_dict(
490+
response_annotation, TEST_MULTICATEGORY_ANNOTATIONS[0]
491+
)
492+
493+
494+
@pytest.mark.integration
495+
def test_box_gt_deletion(dataset):
496+
annotation = BoxAnnotation(**TEST_BOX_ANNOTATIONS[0])
497+
498+
print(annotation)
499+
500+
response = dataset.annotate(annotations=[annotation])
501+
502+
assert response["annotations_processed"] == 1
503+
504+
job = dataset.delete_annotations()
505+
job.sleep_until_complete()
506+
job_status = job.status()
507+
assert job_status["status"] == "Completed"
508+
assert job_status["job_id"] == job.job_id
509+
510+
511+
@pytest.mark.integration
512+
def test_category_gt_deletion(dataset):
513+
annotation = CategoryAnnotation.from_json(TEST_CATEGORY_ANNOTATIONS[0])
514+
515+
print(annotation)
516+
517+
response = dataset.annotate(annotations=[annotation])
400518

401-
response = dataset.annotate(annotations=[annotation])
519+
assert response["annotations_processed"] == 1
402520

403-
assert response["annotations_processed"] == 1
521+
job = dataset.delete_annotations()
522+
job.sleep_until_complete()
523+
job_status = job.status()
524+
assert job_status["status"] == "Completed"
525+
assert job_status["job_id"] == job.job_id
404526

405-
job = dataset.delete_annotations()
406-
job.sleep_until_complete()
407-
job_status = job.status()
408-
assert job_status["status"] == "Completed"
409-
assert job_status["job_id"] == job.id
410527

411-
@pytest.mark.integration
412-
def test_category_gt_deletion(dataset):
413-
annotation = CategoryAnnotation.from_json(TEST_CATEGORY_ANNOTATIONS[0])
528+
@pytest.mark.integration
529+
def test_multicategory_gt_deletion(dataset):
530+
annotation = MultiCategoryAnnotation.from_json(
531+
TEST_MULTICATEGORY_ANNOTATIONS[0]
532+
)
414533

415-
print(annotation)
534+
print(annotation)
416535

417-
response = dataset.annotate(annotations=[annotation])
536+
response = dataset.annotate(annotations=[annotation])
418537

419-
assert response["annotations_processed"] == 1
538+
assert response["annotations_processed"] == 1
420539

421-
job = dataset.delete_annotations()
422-
job.sleep_until_complete()
423-
job_status = job.status()
424-
assert job_status["status"] == "Completed"
425-
assert job_status["job_id"] == job.id
540+
job = dataset.delete_annotations()
541+
job.sleep_until_complete()
542+
job_status = job.status()
543+
assert job_status["status"] == "Completed"
544+
assert job_status["job_id"] == job.job_id

0 commit comments

Comments
 (0)