Skip to content

Commit b9d87c3

Browse files
authored
Support for default category taxonomy (#164)
* Made taxonomy names optional * Added tests for default taxonomy * bug fixes * Fixed doc string in annotation.py * New version number
1 parent 1d9b237 commit b9d87c3

File tree

5 files changed

+371
-25
lines changed

5 files changed

+371
-25
lines changed

nucleus/annotation.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -515,25 +515,25 @@ class CategoryAnnotation(Annotation):
515515
516516
category = CategoryAnnotation(
517517
label="dress",
518-
taxonomy_name="clothing_type",
519518
reference_id="image_1",
519+
taxonomy_name="clothing_type",
520520
metadata={"dress_color": "navy"}
521521
)
522522
523523
Parameters:
524524
label (str): The label for this annotation.
525-
taxonomy_name (str): The name of the taxonomy this annotation conforms to.
526-
See :meth:`Dataset.add_taxonomy`.
527525
reference_id (str): User-defined ID of the image to which to apply this annotation.
526+
taxonomy_name (Optional[str]): The name of the taxonomy this annotation conforms to.
527+
See :meth:`Dataset.add_taxonomy`.
528528
metadata (Optional[Dict]): Arbitrary key/value dictionary of info to attach to this annotation.
529529
Strings, floats and ints are supported best by querying and insights
530530
features within Nucleus. For more details see our `metadata guide
531531
<https://nucleus.scale.com/docs/upload-metadata>`_.
532532
"""
533533

534534
label: str
535-
taxonomy_name: str
536535
reference_id: str
536+
taxonomy_name: Optional[str] = None
537537
metadata: Optional[Dict] = None
538538

539539
def __post_init__(self):
@@ -543,29 +543,31 @@ def __post_init__(self):
543543
def from_json(cls, payload: dict):
544544
return cls(
545545
label=payload[LABEL_KEY],
546-
taxonomy_name=payload[TAXONOMY_NAME_KEY],
547546
reference_id=payload[REFERENCE_ID_KEY],
547+
taxonomy_name=payload.get(TAXONOMY_NAME_KEY, None),
548548
metadata=payload.get(METADATA_KEY, {}),
549549
)
550550

551551
def to_payload(self) -> dict:
552-
return {
552+
payload = {
553553
LABEL_KEY: self.label,
554-
TAXONOMY_NAME_KEY: self.taxonomy_name,
555554
TYPE_KEY: CATEGORY_TYPE,
556555
GEOMETRY_KEY: {},
557556
REFERENCE_ID_KEY: self.reference_id,
558557
METADATA_KEY: self.metadata,
559558
}
559+
if self.taxonomy_name is not None:
560+
payload[TAXONOMY_NAME_KEY] = self.taxonomy_name
561+
return payload
560562

561563

562564
@dataclass
563565
class MultiCategoryAnnotation(Annotation):
564566
"""This class is not yet supported: MultiCategory annotation support coming soon!"""
565567

566568
labels: List[str]
567-
taxonomy_name: str
568569
reference_id: str
570+
taxonomy_name: Optional[str] = None
569571
metadata: Optional[Dict] = None
570572

571573
def __post_init__(self):
@@ -575,20 +577,22 @@ def __post_init__(self):
575577
def from_json(cls, payload: dict):
576578
return cls(
577579
labels=payload[LABELS_KEY],
578-
taxonomy_name=payload[TAXONOMY_NAME_KEY],
579580
reference_id=payload[REFERENCE_ID_KEY],
581+
taxonomy_name=payload.get(TAXONOMY_NAME_KEY, None),
580582
metadata=payload.get(METADATA_KEY, {}),
581583
)
582584

583585
def to_payload(self) -> dict:
584-
return {
586+
payload = {
585587
LABELS_KEY: self.labels,
586-
TAXONOMY_NAME_KEY: self.taxonomy_name,
587588
TYPE_KEY: MULTICATEGORY_TYPE,
588589
GEOMETRY_KEY: {},
589590
REFERENCE_ID_KEY: self.reference_id,
590591
METADATA_KEY: self.metadata,
591592
}
593+
if self.taxonomy_name is not None:
594+
payload[TAXONOMY_NAME_KEY] = self.taxonomy_name
595+
return payload
592596

593597

594598
def is_local_path(path: str) -> bool:

nucleus/prediction.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,9 +348,9 @@ class CategoryPrediction(CategoryAnnotation):
348348
349349
Parameters:
350350
label: The label for this annotation (e.g. car, pedestrian, bicycle).
351+
reference_id: The reference ID of the image you wish to apply this annotation to.
351352
taxonomy_name: The name of the taxonomy this annotation conforms to.
352353
See :meth:`Dataset.add_taxonomy`.
353-
reference_id: The reference ID of the image you wish to apply this annotation to.
354354
confidence: 0-1 indicating the confidence of the prediction.
355355
class_pdf: An optional complete class probability distribution on this
356356
prediction. Each value should be between 0 and 1 (inclusive), and sum up to
@@ -365,8 +365,8 @@ class CategoryPrediction(CategoryAnnotation):
365365
def __init__(
366366
self,
367367
label: str,
368-
taxonomy_name: str,
369368
reference_id: str,
369+
taxonomy_name: Optional[str] = None,
370370
confidence: Optional[float] = None,
371371
metadata: Optional[Dict] = None,
372372
class_pdf: Optional[Dict] = None,

tests/helpers.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,14 @@ def reference_id_from_url(url):
168168
for i in range(len(TEST_IMG_URLS))
169169
]
170170

171+
TEST_DEFAULT_CATEGORY_ANNOTATIONS = [
172+
{
173+
"label": f"[Pytest] Category Label ${i}",
174+
"reference_id": reference_id_from_url(TEST_IMG_URLS[i]),
175+
}
176+
for i in range(len(TEST_IMG_URLS))
177+
]
178+
171179
TEST_MULTICATEGORY_ANNOTATIONS = [
172180
{
173181
"labels": [
@@ -180,6 +188,17 @@ def reference_id_from_url(url):
180188
for i in range(len(TEST_IMG_URLS))
181189
]
182190

191+
TEST_DEFAULT_MULTICATEGORY_ANNOTATIONS = [
192+
{
193+
"labels": [
194+
f"[Pytest] MultiCategory Label ${i}",
195+
f"[Pytest] MultiCategory Label ${i+1}",
196+
],
197+
"reference_id": reference_id_from_url(TEST_IMG_URLS[i]),
198+
}
199+
for i in range(len(TEST_IMG_URLS))
200+
]
201+
183202
TEST_MASK_URL = "https://raw.githubusercontent.com/scaleapi/nucleus-python-client/master/tests/testdata/000000000285.png"
184203

185204
TEST_SEGMENTATION_ANNOTATIONS = [
@@ -253,6 +272,20 @@ def reference_id_from_url(url):
253272
for i in range(len(TEST_CATEGORY_ANNOTATIONS))
254273
]
255274

275+
TEST_DEFAULT_CATEGORY_PREDICTIONS = [
276+
{
277+
**TEST_DEFAULT_CATEGORY_ANNOTATIONS[i],
278+
"confidence": 0.10 * i,
279+
"class_pdf": TEST_CATEGORY_MODEL_PDF,
280+
}
281+
if i != 0
282+
else {
283+
**TEST_DEFAULT_CATEGORY_ANNOTATIONS[i],
284+
"confidence": 0.10 * i,
285+
}
286+
for i in range(len(TEST_DEFAULT_CATEGORY_ANNOTATIONS))
287+
]
288+
256289
TEST_INDEX_EMBEDDINGS_FILE = "https://raw.githubusercontent.com/scaleapi/nucleus-python-client/master/tests/testdata/pytest_embeddings_payload.json"
257290

258291

@@ -310,18 +343,20 @@ def assert_category_annotation_matches_dict(
310343
annotation_instance, annotation_dict
311344
):
312345
assert annotation_instance.label == annotation_dict["label"]
313-
assert (
314-
annotation_instance.taxonomy_name == annotation_dict["taxonomy_name"]
315-
)
346+
if annotation_instance.taxonomy_name:
347+
assert annotation_instance.taxonomy_name == annotation_dict.get(
348+
"taxonomy_name", None
349+
)
316350

317351

318352
def assert_multicategory_annotation_matches_dict(
319353
annotation_instance, annotation_dict
320354
):
321355
assert set(annotation_instance.labels) == set(annotation_dict["labels"])
322-
assert (
323-
annotation_instance.taxonomy_name == annotation_dict["taxonomy_name"]
324-
)
356+
if annotation_instance.taxonomy_name:
357+
assert annotation_instance.taxonomy_name == annotation_dict.get(
358+
"taxonomy_name", None
359+
)
325360

326361

327362
def assert_segmentation_annotation_matches_dict(

0 commit comments

Comments
 (0)