Skip to content

Commit c84bb4c

Browse files
author
Claire Pajot
committed
Initial multicategory support
1 parent 58b83ab commit c84bb4c

File tree

8 files changed

+253
-33
lines changed

8 files changed

+253
-33
lines changed

nucleus/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
Point3D,
2727
PolygonAnnotation,
2828
CategoryAnnotation,
29+
MultiCategoryAnnotation,
2930
Segment,
3031
SegmentationAnnotation,
3132
)
@@ -583,6 +584,7 @@ def annotate_dataset(
583584
PolygonAnnotation,
584585
CuboidAnnotation,
585586
CategoryAnnotation,
587+
MultiCategoryAnnotation,
586588
SegmentationAnnotation,
587589
]
588590
],

nucleus/annotation.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99
ANNOTATIONS_KEY,
1010
BOX_TYPE,
1111
CATEGORY_TYPE,
12+
MULTICATEGORY_TYPE,
1213
CUBOID_TYPE,
1314
DIMENSIONS_KEY,
1415
GEOMETRY_KEY,
1516
HEIGHT_KEY,
1617
INDEX_KEY,
1718
LABEL_KEY,
19+
LABELS_KEY,
1820
MASK_TYPE,
1921
MASK_URL_KEY,
2022
METADATA_KEY,
@@ -45,6 +47,8 @@ def from_json(cls, payload: dict):
4547
return CuboidAnnotation.from_json(payload)
4648
elif payload.get(TYPE_KEY, None) == CATEGORY_TYPE:
4749
return CategoryAnnotation.from_json(payload)
50+
elif payload.get(TYPE_KEY, None) == MULTICATEGORY_TYPE:
51+
return MultiCategoryAnnotation.from_json(payload)
4852
else:
4953
return SegmentationAnnotation.from_json(payload)
5054

@@ -125,6 +129,7 @@ class AnnotationTypes(Enum):
125129
POLYGON = POLYGON_TYPE
126130
CUBOID = CUBOID_TYPE
127131
CATEGORY = CATEGORY_TYPE
132+
MULTICATEGORY = MULTICATEGORY_TYPE
128133

129134

130135
@dataclass # pylint: disable=R0902
@@ -326,6 +331,36 @@ def to_payload(self) -> dict:
326331
}
327332

328333

334+
@dataclass
335+
class MultiCategoryAnnotation(Annotation):
336+
labels: List[str]
337+
taxonomy_name: str
338+
reference_id: str
339+
metadata: Optional[Dict] = None
340+
341+
def __post_init__(self):
342+
self.metadata = self.metadata if self.metadata else {}
343+
344+
@classmethod
345+
def from_json(cls, payload: dict):
346+
return cls(
347+
labels=payload[LABELS_KEY],
348+
taxonomy_name=payload[TAXONOMY_NAME_KEY],
349+
reference_id=payload[REFERENCE_ID_KEY],
350+
metadata=payload.get(METADATA_KEY, {}),
351+
)
352+
353+
def to_payload(self) -> dict:
354+
return {
355+
LABELS_KEY: self.labels,
356+
TAXONOMY_NAME_KEY: self.taxonomy_name,
357+
TYPE_KEY: MULTICATEGORY_TYPE,
358+
GEOMETRY_KEY: {},
359+
REFERENCE_ID_KEY: self.reference_id,
360+
METADATA_KEY: self.metadata,
361+
}
362+
363+
329364
def is_local_path(path: str) -> bool:
330365
return urlparse(path).scheme not in {"https", "http", "s3", "gs"}
331366

nucleus/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
SEGMENTATION_TYPE,
1717
CUBOID_TYPE,
1818
CATEGORY_TYPE,
19+
MULTICATEGORY_TYPE,
1920
)
2021
ANNOTATION_UPDATE_KEY = "update"
2122
AUTOTAGS_KEY = "autotags"

nucleus/payload_constructor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
CuboidAnnotation,
77
PolygonAnnotation,
88
CategoryAnnotation,
9+
MultiCategoryAnnotation,
910
SegmentationAnnotation,
1011
)
1112
from .prediction import (
@@ -62,6 +63,7 @@ def construct_annotation_payload(
6263
PolygonAnnotation,
6364
CuboidAnnotation,
6465
CategoryAnnotation,
66+
MultiCategoryAnnotation,
6567
SegmentationAnnotation,
6668
]
6769
],

nucleus/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
CuboidAnnotation,
1616
PolygonAnnotation,
1717
CategoryAnnotation,
18+
MultiCategoryAnnotation,
1819
SegmentationAnnotation,
1920
)
2021

@@ -24,6 +25,7 @@
2425
BOX_TYPE,
2526
CUBOID_TYPE,
2627
CATEGORY_TYPE,
28+
MULTICATEGORY_TYPE,
2729
ITEM_KEY,
2830
POLYGON_TYPE,
2931
REFERENCE_ID_KEY,
@@ -123,6 +125,11 @@ def convert_export_payload(api_payload):
123125
annotations[CATEGORY_TYPE].append(
124126
CategoryAnnotation.from_json(category)
125127
)
128+
for multicategory in row[MULTICATEGORY_TYPE]:
129+
multicategory[REFERENCE_ID_KEY] = row[ITEM_KEY][REFERENCE_ID_KEY]
130+
annotations[MULTICATEGORY_TYPE].append(
131+
MultiCategoryAnnotation.from_json(multicategory)
132+
)
126133
return_payload_row[ANNOTATIONS_KEY] = annotations
127134
return_payload.append(return_payload_row)
128135
return return_payload

tests/helpers.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,18 @@ def reference_id_from_url(url):
167167
for i in range(len(TEST_IMG_URLS))
168168
]
169169

170+
TEST_MULTICATEGORY_ANNOTATIONS = [
171+
{
172+
"labels": [
173+
f"[Pytest] MultiCategory Label ${i}",
174+
f"[Pytest] MultiCategory Label ${i+1}",
175+
],
176+
"reference_id": reference_id_from_url(TEST_IMG_URLS[i]),
177+
"taxonomy_name": "[Pytest] MultiCategory Taxonomy 1",
178+
}
179+
for i in range(len(TEST_IMG_URLS))
180+
]
181+
170182
TEST_MASK_URL = "https://raw.githubusercontent.com/scaleapi/nucleus-python-client/master/tests/testdata/000000000285.png"
171183

172184
TEST_SEGMENTATION_ANNOTATIONS = [
@@ -283,6 +295,15 @@ def assert_category_annotation_matches_dict(
283295
)
284296

285297

298+
def assert_multicategory_annotation_matches_dict(
299+
annotation_instance, annotation_dict
300+
):
301+
assert set(annotation_instance.labels) == set(annotation_dict["labels"])
302+
assert (
303+
annotation_instance.taxonomy_name == annotation_dict["taxonomy_name"]
304+
)
305+
306+
286307
def assert_segmentation_annotation_matches_dict(
287308
annotation_instance, annotation_dict
288309
):

tests/test_annotation.py

Lines changed: 143 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,21 @@
66
TEST_BOX_ANNOTATIONS,
77
TEST_POLYGON_ANNOTATIONS,
88
TEST_CATEGORY_ANNOTATIONS,
9+
TEST_MULTICATEGORY_ANNOTATIONS,
910
TEST_SEGMENTATION_ANNOTATIONS,
1011
reference_id_from_url,
1112
assert_box_annotation_matches_dict,
1213
assert_polygon_annotation_matches_dict,
1314
assert_category_annotation_matches_dict,
15+
assert_multicategory_annotation_matches_dict,
1416
assert_segmentation_annotation_matches_dict,
1517
)
1618

1719
from nucleus import (
1820
BoxAnnotation,
1921
PolygonAnnotation,
2022
CategoryAnnotation,
23+
MultiCategoryAnnotation,
2124
SegmentationAnnotation,
2225
DatasetItem,
2326
Segment,
@@ -64,6 +67,14 @@ def dataset(CLIENT):
6467
"category",
6568
[f"[Pytest] Category Label ${i}" for i in range((len(TEST_IMG_URLS)))],
6669
)
70+
response = ds.add_taxonomy(
71+
"[Pytest] MultiCategory Taxonomy 1",
72+
"multicategory",
73+
[
74+
f"[Pytest] MultiCategory Label ${i}"
75+
for i in range((len(TEST_IMG_URLS) + 1))
76+
],
77+
)
6778
yield ds
6879

6980
response = CLIENT.delete_dataset(ds.id)
@@ -73,6 +84,7 @@ def dataset(CLIENT):
7384
def test_box_gt_upload(dataset):
7485
annotation = BoxAnnotation(**TEST_BOX_ANNOTATIONS[0])
7586
response = dataset.annotate(annotations=[annotation])
87+
print(response)
7688

7789
assert response["dataset_id"] == dataset.id
7890
assert response["annotations_processed"] == 1
@@ -127,6 +139,28 @@ def test_category_gt_upload(dataset):
127139
)
128140

129141

142+
def test_multicategory_gt_upload(dataset):
143+
annotation = MultiCategoryAnnotation.from_json(
144+
TEST_MULTICATEGORY_ANNOTATIONS[0]
145+
)
146+
response = dataset.annotate(annotations=[annotation])
147+
148+
assert response["dataset_id"] == dataset.id
149+
assert response["annotations_processed"] == 1
150+
assert response["annotations_ignored"] == 0
151+
print("HERE")
152+
response = dataset.refloc(annotation.reference_id)["annotations"][
153+
"multicategory"
154+
]
155+
print("RESPONSE: ", response)
156+
# TODO: Weirdness here also in the response from refloc?
157+
assert len(response) == 1
158+
response_annotation = response[0]
159+
assert_multicategory_annotation_matches_dict(
160+
response_annotation, TEST_MULTICATEGORY_ANNOTATIONS[0]
161+
)
162+
163+
130164
def test_single_semseg_gt_upload(dataset):
131165
annotation = SegmentationAnnotation.from_json(
132166
TEST_SEGMENTATION_ANNOTATIONS[0]
@@ -206,6 +240,7 @@ def test_mixed_annotation_upload(dataset):
206240
response_annotations = dataset.refloc(bbox_annotations[0].reference_id)[
207241
"annotations"
208242
]
243+
print(response_annotations)
209244
assert len(response_annotations) == 2
210245
assert len(response_annotations["box"]) == 1
211246
assert "segmentation" in response_annotations
@@ -392,34 +427,120 @@ def test_category_gt_upload_ignore(dataset):
392427
response_annotation, TEST_CATEGORY_ANNOTATIONS[0]
393428
)
394429

395-
@pytest.mark.integration
396-
def test_box_gt_deletion(dataset):
397-
annotation = BoxAnnotation(**TEST_BOX_ANNOTATIONS[0])
398430

399-
print(annotation)
431+
def test_multicategory_gt_upload_update(dataset):
432+
annotation = MultiCategoryAnnotation.from_json(
433+
TEST_MULTICATEGORY_ANNOTATIONS[0]
434+
)
435+
response = dataset.annotate(annotations=[annotation])
400436

401-
response = dataset.annotate(annotations=[annotation])
437+
assert response["annotations_processed"] == 1
402438

403-
assert response["annotations_processed"] == 1
439+
# Copy so we don't modify the original.
440+
annotation_update_params = dict(TEST_MULTICATEGORY_ANNOTATIONS[1])
441+
annotation_update_params["reference_id"] = TEST_MULTICATEGORY_ANNOTATIONS[
442+
0
443+
]["reference_id"]
404444

405-
job = dataset.delete_annotations()
406-
job.sleep_until_complete()
407-
job_status = job.status()
408-
assert job_status["status"] == "Completed"
409-
assert job_status["job_id"] == job.id
445+
annotation_update = MultiCategoryAnnotation.from_json(
446+
annotation_update_params
447+
)
448+
response = dataset.annotate(annotations=[annotation_update], update=True)
410449

411-
@pytest.mark.integration
412-
def test_category_gt_deletion(dataset):
413-
annotation = CategoryAnnotation.from_json(TEST_CATEGORY_ANNOTATIONS[0])
450+
assert response["annotations_processed"] == 1
451+
assert response["annotations_ignored"] == 0
414452

415-
print(annotation)
453+
response = dataset.refloc(annotation.reference_id)["annotations"][
454+
"multicategory"
455+
]
456+
assert len(response) == 1
457+
response_annotation = response[0]
458+
assert_multicategory_annotation_matches_dict(
459+
response_annotation, annotation_update_params
460+
)
416461

417-
response = dataset.annotate(annotations=[annotation])
418462

419-
assert response["annotations_processed"] == 1
463+
def test_multicategory_gt_upload_ignore(dataset):
464+
annotation = MultiCategoryAnnotation.from_json(
465+
TEST_MULTICATEGORY_ANNOTATIONS[0]
466+
)
467+
response = dataset.annotate(annotations=[annotation])
468+
469+
assert response["annotations_processed"] == 1
470+
471+
# Copy so we don't modify the original.
472+
annotation_update_params = dict(TEST_MULTICATEGORY_ANNOTATIONS[1])
473+
annotation_update_params["reference_id"] = TEST_MULTICATEGORY_ANNOTATIONS[
474+
0
475+
]["reference_id"]
476+
477+
annotation_update = MultiCategoryAnnotation.from_json(
478+
annotation_update_params
479+
)
480+
# Default behavior is ignore.
481+
response = dataset.annotate(annotations=[annotation_update])
482+
483+
assert response["annotations_processed"] == 0
484+
assert response["annotations_ignored"] == 1
485+
486+
response = dataset.refloc(annotation.reference_id)["annotations"][
487+
"multicategory"
488+
]
489+
assert len(response) == 1
490+
response_annotation = response[0]
491+
assert_multicategory_annotation_matches_dict(
492+
response_annotation, TEST_MULTICATEGORY_ANNOTATIONS[0]
493+
)
494+
495+
496+
@pytest.mark.integration
497+
def test_box_gt_deletion(dataset):
498+
annotation = BoxAnnotation(**TEST_BOX_ANNOTATIONS[0])
499+
500+
print(annotation)
501+
502+
response = dataset.annotate(annotations=[annotation])
503+
504+
assert response["annotations_processed"] == 1
505+
506+
job = dataset.delete_annotations()
507+
job.sleep_until_complete()
508+
job_status = job.status()
509+
assert job_status["status"] == "Completed"
510+
assert job_status["job_id"] == job.job_id
511+
512+
513+
@pytest.mark.integration
514+
def test_category_gt_deletion(dataset):
515+
annotation = CategoryAnnotation.from_json(TEST_CATEGORY_ANNOTATIONS[0])
516+
517+
print(annotation)
518+
519+
response = dataset.annotate(annotations=[annotation])
520+
521+
assert response["annotations_processed"] == 1
522+
523+
job = dataset.delete_annotations()
524+
job.sleep_until_complete()
525+
job_status = job.status()
526+
assert job_status["status"] == "Completed"
527+
assert job_status["job_id"] == job.job_id
528+
529+
530+
@pytest.mark.integration
531+
def test_multicategory_gt_deletion(dataset):
532+
annotation = MultiCategoryAnnotation.from_json(
533+
TEST_MULTICATEGORY_ANNOTATIONS[0]
534+
)
535+
536+
print(annotation)
537+
538+
response = dataset.annotate(annotations=[annotation])
539+
540+
assert response["annotations_processed"] == 1
420541

421-
job = dataset.delete_annotations()
422-
job.sleep_until_complete()
423-
job_status = job.status()
424-
assert job_status["status"] == "Completed"
425-
assert job_status["job_id"] == job.id
542+
job = dataset.delete_annotations()
543+
job.sleep_until_complete()
544+
job_status = job.status()
545+
assert job_status["status"] == "Completed"
546+
assert job_status["job_id"] == job.job_id

0 commit comments

Comments
 (0)