Skip to content

Commit c93dbf2

Browse files
Drew KaulDrew Kaul
authored andcommitted
added CuboidPrediction dataclass
1 parent fb4671f commit c93dbf2

File tree

1 file changed

+58
-0
lines changed

1 file changed

+58
-0
lines changed

nucleus/prediction.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
PolygonAnnotation,
66
Segment,
77
SegmentationAnnotation,
8+
CuboidAnnotation,
9+
Point3D,
810
)
911
from .constants import (
1012
ANNOTATION_ID_KEY,
@@ -23,6 +25,9 @@
2325
ANNOTATIONS_KEY,
2426
ITEM_ID_KEY,
2527
MASK_URL_KEY,
28+
POSITION_KEY,
29+
DIMENSIONS_KEY,
30+
YAW_KEY,
2631
)
2732

2833

@@ -146,3 +151,56 @@ def from_json(cls, payload: dict):
146151
metadata=payload.get(METADATA_KEY, {}),
147152
class_pdf=payload.get(CLASS_PDF_KEY, None),
148153
)
154+
155+
156+
class CuboidPrediction(CuboidAnnotation):
157+
def __init__(
158+
self,
159+
label: str,
160+
position: Point3D,
161+
dimensions: Point3D,
162+
yaw: float,
163+
reference_id: Optional[str] = None,
164+
item_id: Optional[str] = None,
165+
confidence: Optional[float] = None,
166+
annotation_id: Optional[str] = None,
167+
metadata: Optional[Dict] = None,
168+
class_pdf: Optional[Dict] = None,
169+
):
170+
super().__init__(
171+
label=label,
172+
position=position,
173+
dimensions=dimensions,
174+
yaw=yaw,
175+
reference_id=reference_id,
176+
item_id=item_id,
177+
annotation_id=annotation_id,
178+
metadata=metadata,
179+
)
180+
self.confidence = confidence
181+
self.class_pdf = class_pdf
182+
183+
def to_payload(self) -> dict:
184+
payload = super().to_payload()
185+
if self.confidence is not None:
186+
payload[CONFIDENCE_KEY] = self.confidence
187+
if self.class_pdf is not None:
188+
payload[CLASS_PDF_KEY] = self.class_pdf
189+
190+
return payload
191+
192+
@classmethod
193+
def from_json(cls, payload: dict):
194+
geometry = payload.get(GEOMETRY_KEY, {})
195+
return cls(
196+
label=payload.get(LABEL_KEY, 0),
197+
position=Point3D.from_json(geometry.get(POSITION_KEY, {})),
198+
dimensions=Point3D.from_json(geometry.get(DIMENSIONS_KEY, {})),
199+
yaw=payload.get(YAW_KEY, 0),
200+
reference_id=payload.get(REFERENCE_ID_KEY, None),
201+
item_id=payload.get(DATASET_ITEM_ID_KEY, None),
202+
confidence=payload.get(CONFIDENCE_KEY, None),
203+
annotation_id=payload.get(ANNOTATION_ID_KEY, None),
204+
metadata=payload.get(METADATA_KEY, {}),
205+
class_pdf=payload.get(CLASS_PDF_KEY, None),
206+
)

0 commit comments

Comments
 (0)