Skip to content

Commit 29d0f3e

Browse files
authored
Add LineAnnotation and LinePrediction support (#224)
* Add LineAnnotation and LinePrediction support * Fix top level imports * Update test case to check right thing * Add Line types to utils.py
1 parent ed0a796 commit 29d0f3e

File tree

8 files changed

+352
-26
lines changed

8 files changed

+352
-26
lines changed

nucleus/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
"DatasetItemRetrievalError",
1616
"Frame",
1717
"LidarScene",
18-
"VideoScene",
18+
"LineAnnotation",
19+
"LinePrediction",
1920
"Model",
2021
"ModelCreationError",
2122
# "MultiCategoryAnnotation", # coming soon!
@@ -31,6 +32,7 @@
3132
"SegmentationAnnotation",
3233
"SegmentationPrediction",
3334
"Slice",
35+
"VideoScene",
3436
]
3537

3638
import os
@@ -49,6 +51,7 @@
4951
BoxAnnotation,
5052
CategoryAnnotation,
5153
CuboidAnnotation,
54+
LineAnnotation,
5255
MultiCategoryAnnotation,
5356
Point,
5457
Point3D,
@@ -119,6 +122,7 @@
119122
BoxPrediction,
120123
CategoryPrediction,
121124
CuboidPrediction,
125+
LinePrediction,
122126
PolygonPrediction,
123127
SegmentationPrediction,
124128
)

nucleus/annotation.py

Lines changed: 100 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
from dataclasses import dataclass, field
33
from enum import Enum
4-
from typing import Dict, List, Optional, Sequence, Union
4+
from typing import Dict, List, Optional, Sequence, Type, Union
55
from urllib.parse import urlparse
66

77
from .constants import (
@@ -16,6 +16,7 @@
1616
INDEX_KEY,
1717
LABEL_KEY,
1818
LABELS_KEY,
19+
LINE_TYPE,
1920
MASK_TYPE,
2021
MASK_URL_KEY,
2122
METADATA_KEY,
@@ -46,18 +47,17 @@ class Annotation:
4647
@classmethod
4748
def from_json(cls, payload: dict):
4849
"""Instantiates annotation object from schematized JSON dict payload."""
49-
if payload.get(TYPE_KEY, None) == BOX_TYPE:
50-
return BoxAnnotation.from_json(payload)
51-
elif payload.get(TYPE_KEY, None) == POLYGON_TYPE:
52-
return PolygonAnnotation.from_json(payload)
53-
elif payload.get(TYPE_KEY, None) == CUBOID_TYPE:
54-
return CuboidAnnotation.from_json(payload)
55-
elif payload.get(TYPE_KEY, None) == CATEGORY_TYPE:
56-
return CategoryAnnotation.from_json(payload)
57-
elif payload.get(TYPE_KEY, None) == MULTICATEGORY_TYPE:
58-
return MultiCategoryAnnotation.from_json(payload)
59-
else:
60-
return SegmentationAnnotation.from_json(payload)
50+
type_key_to_type: Dict[str, Type[Annotation]] = {
51+
BOX_TYPE: BoxAnnotation,
52+
LINE_TYPE: LineAnnotation,
53+
POLYGON_TYPE: PolygonAnnotation,
54+
CUBOID_TYPE: CuboidAnnotation,
55+
CATEGORY_TYPE: CategoryAnnotation,
56+
MULTICATEGORY_TYPE: MultiCategoryAnnotation,
57+
}
58+
type_key = payload.get(TYPE_KEY, None)
59+
AnnotationCls = type_key_to_type.get(type_key, SegmentationAnnotation)
60+
return AnnotationCls.from_json(payload)
6161

6262
def to_payload(self) -> dict:
6363
"""Serializes annotation object to schematized JSON dict."""
@@ -177,6 +177,88 @@ def to_payload(self) -> dict:
177177
return {X_KEY: self.x, Y_KEY: self.y}
178178

179179

180+
@dataclass
181+
class LineAnnotation(Annotation):
182+
"""A polyline annotation consisting of an ordered list of 2D points.
183+
A LineAnnotation differs from a PolygonAnnotation by not forming a closed
184+
loop, and by having zero area.
185+
186+
::
187+
188+
from nucleus import LineAnnotation
189+
190+
line = LineAnnotation(
191+
label="face",
192+
vertices=[Point(100, 100), Point(200, 300), Point(300, 200)],
193+
reference_id="person_image_1",
194+
annotation_id="person_image_1_line_1",
195+
metadata={"camera_mode": "portrait"},
196+
)
197+
198+
Parameters:
199+
label (str): The label for this annotation.
200+
vertices (List[:class:`Point`]): The list of points making up the line.
201+
reference_id (str): User-defined ID of the image to which to apply this
202+
annotation.
203+
annotation_id (Optional[str]): The annotation ID that uniquely identifies
204+
this annotation within its target dataset item. Upon ingest, a matching
205+
annotation id will be ignored by default, and updated if update=True
206+
for dataset.annotate.
207+
metadata (Optional[Dict]): Arbitrary key/value dictionary of info to
208+
attach to this annotation. Strings, floats and ints are supported best
209+
by querying and insights features within Nucleus. For more details see
210+
our `metadata guide <https://nucleus.scale.com/docs/upload-metadata>`_.
211+
"""
212+
213+
label: str
214+
vertices: List[Point]
215+
reference_id: str
216+
annotation_id: Optional[str] = None
217+
metadata: Optional[Dict] = None
218+
219+
def __post_init__(self):
220+
self.metadata = self.metadata if self.metadata else {}
221+
if len(self.vertices) > 0:
222+
if not hasattr(self.vertices[0], X_KEY) or not hasattr(
223+
self.vertices[0], "to_payload"
224+
):
225+
try:
226+
self.vertices = [
227+
Point(x=vertex[X_KEY], y=vertex[Y_KEY])
228+
for vertex in self.vertices
229+
]
230+
except KeyError as ke:
231+
raise ValueError(
232+
"Use a point object to pass in vertices. For example, vertices=[nucleus.Point(x=1, y=2)]"
233+
) from ke
234+
235+
@classmethod
236+
def from_json(cls, payload: dict):
237+
geometry = payload.get(GEOMETRY_KEY, {})
238+
return cls(
239+
label=payload.get(LABEL_KEY, 0),
240+
vertices=[
241+
Point.from_json(_) for _ in geometry.get(VERTICES_KEY, [])
242+
],
243+
reference_id=payload[REFERENCE_ID_KEY],
244+
annotation_id=payload.get(ANNOTATION_ID_KEY, None),
245+
metadata=payload.get(METADATA_KEY, {}),
246+
)
247+
248+
def to_payload(self) -> dict:
249+
payload = {
250+
LABEL_KEY: self.label,
251+
TYPE_KEY: LINE_TYPE,
252+
GEOMETRY_KEY: {
253+
VERTICES_KEY: [_.to_payload() for _ in self.vertices]
254+
},
255+
REFERENCE_ID_KEY: self.reference_id,
256+
ANNOTATION_ID_KEY: self.annotation_id,
257+
METADATA_KEY: self.metadata,
258+
}
259+
return payload
260+
261+
180262
@dataclass
181263
class PolygonAnnotation(Annotation):
182264
"""A polygon annotation consisting of an ordered list of 2D points.
@@ -499,6 +581,7 @@ def to_payload(self) -> dict:
499581

500582
class AnnotationTypes(Enum):
501583
BOX = BOX_TYPE
584+
LINE = LINE_TYPE
502585
POLYGON = POLYGON_TYPE
503586
CUBOID = CUBOID_TYPE
504587
CATEGORY = CATEGORY_TYPE
@@ -600,6 +683,7 @@ class AnnotationList:
600683
"""Wrapper class separating a list of annotations by type."""
601684

602685
box_annotations: List[BoxAnnotation] = field(default_factory=list)
686+
line_annotations: List[LineAnnotation] = field(default_factory=list)
603687
polygon_annotations: List[PolygonAnnotation] = field(default_factory=list)
604688
cuboid_annotations: List[CuboidAnnotation] = field(default_factory=list)
605689
category_annotations: List[CategoryAnnotation] = field(
@@ -620,6 +704,8 @@ def add_annotations(self, annotations: List[Annotation]):
620704

621705
if isinstance(annotation, BoxAnnotation):
622706
self.box_annotations.append(annotation)
707+
elif isinstance(annotation, LineAnnotation):
708+
self.line_annotations.append(annotation)
623709
elif isinstance(annotation, PolygonAnnotation):
624710
self.polygon_annotations.append(annotation)
625711
elif isinstance(annotation, CuboidAnnotation):
@@ -637,6 +723,7 @@ def add_annotations(self, annotations: List[Annotation]):
637723
def __len__(self):
638724
return (
639725
len(self.box_annotations)
726+
+ len(self.line_annotations)
640727
+ len(self.polygon_annotations)
641728
+ len(self.cuboid_annotations)
642729
+ len(self.category_annotations)

nucleus/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
ANNOTATION_METADATA_SCHEMA_KEY = "annotation_metadata_schema"
66
BACKFILL_JOB_KEY = "backfill_job"
77
BOX_TYPE = "box"
8+
LINE_TYPE = "line"
89
POLYGON_TYPE = "polygon"
910
MASK_TYPE = "mask"
1011
SEGMENTATION_TYPE = "segmentation"
@@ -13,6 +14,7 @@
1314
MULTICATEGORY_TYPE = "multicategory"
1415
ANNOTATION_TYPES = (
1516
BOX_TYPE,
17+
LINE_TYPE,
1618
POLYGON_TYPE,
1719
SEGMENTATION_TYPE,
1820
CUBOID_TYPE,

nucleus/prediction.py

Lines changed: 86 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44
such as confidence or probability distributions.
55
"""
66
from dataclasses import dataclass, field
7-
from typing import Dict, List, Optional, Union
7+
from typing import Dict, List, Optional, Type, Union
88

99
from .annotation import (
1010
BoxAnnotation,
1111
CategoryAnnotation,
1212
CuboidAnnotation,
13+
LineAnnotation,
1314
Point,
1415
Point3D,
1516
PolygonAnnotation,
@@ -28,6 +29,7 @@
2829
GEOMETRY_KEY,
2930
HEIGHT_KEY,
3031
LABEL_KEY,
32+
LINE_TYPE,
3133
MASK_URL_KEY,
3234
METADATA_KEY,
3335
POLYGON_TYPE,
@@ -45,16 +47,16 @@
4547

4648
def from_json(payload: dict):
4749
"""Instantiates prediction object from schematized JSON dict payload."""
48-
if payload.get(TYPE_KEY, None) == BOX_TYPE:
49-
return BoxPrediction.from_json(payload)
50-
elif payload.get(TYPE_KEY, None) == POLYGON_TYPE:
51-
return PolygonPrediction.from_json(payload)
52-
elif payload.get(TYPE_KEY, None) == CUBOID_TYPE:
53-
return CuboidPrediction.from_json(payload)
54-
elif payload.get(TYPE_KEY, None) == CATEGORY_TYPE:
55-
return CategoryPrediction.from_json(payload)
56-
else:
57-
return SegmentationPrediction.from_json(payload)
50+
type_key_to_type: Dict[str, Type[Prediction]] = {
51+
BOX_TYPE: BoxPrediction,
52+
LINE_TYPE: LinePrediction,
53+
POLYGON_TYPE: PolygonPrediction,
54+
CUBOID_TYPE: CuboidPrediction,
55+
CATEGORY_TYPE: CategoryPrediction,
56+
}
57+
type_key = payload.get(TYPE_KEY, None)
58+
PredictionCls = type_key_to_type.get(type_key, SegmentationPrediction)
59+
return PredictionCls.from_json(payload)
5860

5961

6062
class SegmentationPrediction(SegmentationAnnotation):
@@ -203,6 +205,74 @@ def from_json(cls, payload: dict):
203205
)
204206

205207

208+
class LinePrediction(LineAnnotation):
209+
"""Prediction of a line.
210+
211+
Parameters:
212+
label (str): The label for this prediction (e.g. car, pedestrian, bicycle).
213+
vertices List[:class:`Point`]: The list of points making up the line.
214+
reference_id (str): User-defined ID of the image to which to apply this
215+
annotation.
216+
confidence: 0-1 indicating the confidence of the prediction.
217+
annotation_id (Optional[str]): The annotation ID that uniquely identifies
218+
this annotation within its target dataset item. Upon ingest, a matching
219+
annotation id will be ignored by default, and updated if update=True
220+
for dataset.annotate.
221+
metadata (Optional[Dict]): Arbitrary key/value dictionary of info to
222+
attach to this prediction. Strings, floats and ints are supported best
223+
by querying and insights features within Nucleus. For more details see
224+
our `metadata guide <https://nucleus.scale.com/docs/upload-metadata>`_.
225+
class_pdf: An optional complete class probability distribution on this
226+
annotation. Each value should be between 0 and 1 (inclusive), and sum up to
227+
1 as a complete distribution. This can be useful for computing entropy to
228+
surface places where the model is most uncertain.
229+
"""
230+
231+
def __init__(
232+
self,
233+
label: str,
234+
vertices: List[Point],
235+
reference_id: str,
236+
confidence: Optional[float] = None,
237+
annotation_id: Optional[str] = None,
238+
metadata: Optional[Dict] = None,
239+
class_pdf: Optional[Dict] = None,
240+
):
241+
super().__init__(
242+
label=label,
243+
vertices=vertices,
244+
reference_id=reference_id,
245+
annotation_id=annotation_id,
246+
metadata=metadata,
247+
)
248+
self.confidence = confidence
249+
self.class_pdf = class_pdf
250+
251+
def to_payload(self) -> dict:
252+
payload = super().to_payload()
253+
if self.confidence is not None:
254+
payload[CONFIDENCE_KEY] = self.confidence
255+
if self.class_pdf is not None:
256+
payload[CLASS_PDF_KEY] = self.class_pdf
257+
258+
return payload
259+
260+
@classmethod
261+
def from_json(cls, payload: dict):
262+
geometry = payload.get(GEOMETRY_KEY, {})
263+
return cls(
264+
label=payload.get(LABEL_KEY, 0),
265+
vertices=[
266+
Point.from_json(_) for _ in geometry.get(VERTICES_KEY, [])
267+
],
268+
reference_id=payload[REFERENCE_ID_KEY],
269+
confidence=payload.get(CONFIDENCE_KEY, None),
270+
annotation_id=payload.get(ANNOTATION_ID_KEY, None),
271+
metadata=payload.get(METADATA_KEY, {}),
272+
class_pdf=payload.get(CLASS_PDF_KEY, None),
273+
)
274+
275+
206276
class PolygonPrediction(PolygonAnnotation):
207277
"""Prediction of a polygon.
208278
@@ -404,6 +474,7 @@ def from_json(cls, payload: dict):
404474

405475
Prediction = Union[
406476
BoxPrediction,
477+
LinePrediction,
407478
PolygonPrediction,
408479
CuboidPrediction,
409480
CategoryPrediction,
@@ -416,6 +487,7 @@ class PredictionList:
416487
"""Wrapper class separating a list of predictions by type."""
417488

418489
box_predictions: List[BoxPrediction] = field(default_factory=list)
490+
line_predictions: List[LinePrediction] = field(default_factory=list)
419491
polygon_predictions: List[PolygonPrediction] = field(default_factory=list)
420492
cuboid_predictions: List[CuboidPrediction] = field(default_factory=list)
421493
category_predictions: List[CategoryPrediction] = field(
@@ -429,6 +501,8 @@ def add_predictions(self, predictions: List[Prediction]):
429501
for prediction in predictions:
430502
if isinstance(prediction, BoxPrediction):
431503
self.box_predictions.append(prediction)
504+
elif isinstance(prediction, LinePrediction):
505+
self.line_predictions.append(prediction)
432506
elif isinstance(prediction, PolygonPrediction):
433507
self.polygon_predictions.append(prediction)
434508
elif isinstance(prediction, CuboidPrediction):
@@ -444,6 +518,7 @@ def add_predictions(self, predictions: List[Prediction]):
444518
def __len__(self):
445519
return (
446520
len(self.box_predictions)
521+
+ len(self.line_predictions)
447522
+ len(self.polygon_predictions)
448523
+ len(self.cuboid_predictions)
449524
+ len(self.category_predictions)

0 commit comments

Comments
 (0)