|
5 | 5 | PolygonAnnotation,
|
6 | 6 | Segment,
|
7 | 7 | SegmentationAnnotation,
|
| 8 | + CuboidAnnotation, |
| 9 | + Point3D, |
8 | 10 | )
|
9 | 11 | from .constants import (
|
10 | 12 | ANNOTATION_ID_KEY,
|
|
23 | 25 | ANNOTATIONS_KEY,
|
24 | 26 | ITEM_ID_KEY,
|
25 | 27 | MASK_URL_KEY,
|
| 28 | + POSITION_KEY, |
| 29 | + DIMENSIONS_KEY, |
| 30 | + YAW_KEY, |
26 | 31 | )
|
27 | 32 |
|
28 | 33 |
|
@@ -146,3 +151,56 @@ def from_json(cls, payload: dict):
|
146 | 151 | metadata=payload.get(METADATA_KEY, {}),
|
147 | 152 | class_pdf=payload.get(CLASS_PDF_KEY, None),
|
148 | 153 | )
|
| 154 | + |
| 155 | + |
| 156 | +class CuboidPrediction(CuboidAnnotation): |
| 157 | + def __init__( |
| 158 | + self, |
| 159 | + label: str, |
| 160 | + position: Point3D, |
| 161 | + dimensions: Point3D, |
| 162 | + yaw: float, |
| 163 | + reference_id: Optional[str] = None, |
| 164 | + item_id: Optional[str] = None, |
| 165 | + confidence: Optional[float] = None, |
| 166 | + annotation_id: Optional[str] = None, |
| 167 | + metadata: Optional[Dict] = None, |
| 168 | + class_pdf: Optional[Dict] = None, |
| 169 | + ): |
| 170 | + super().__init__( |
| 171 | + label=label, |
| 172 | + position=position, |
| 173 | + dimensions=dimensions, |
| 174 | + yaw=yaw, |
| 175 | + reference_id=reference_id, |
| 176 | + item_id=item_id, |
| 177 | + annotation_id=annotation_id, |
| 178 | + metadata=metadata, |
| 179 | + ) |
| 180 | + self.confidence = confidence |
| 181 | + self.class_pdf = class_pdf |
| 182 | + |
| 183 | + def to_payload(self) -> dict: |
| 184 | + payload = super().to_payload() |
| 185 | + if self.confidence is not None: |
| 186 | + payload[CONFIDENCE_KEY] = self.confidence |
| 187 | + if self.class_pdf is not None: |
| 188 | + payload[CLASS_PDF_KEY] = self.class_pdf |
| 189 | + |
| 190 | + return payload |
| 191 | + |
| 192 | + @classmethod |
| 193 | + def from_json(cls, payload: dict): |
| 194 | + geometry = payload.get(GEOMETRY_KEY, {}) |
| 195 | + return cls( |
| 196 | + label=payload.get(LABEL_KEY, 0), |
| 197 | + position=Point3D.from_json(geometry.get(POSITION_KEY, {})), |
| 198 | + dimensions=Point3D.from_json(geometry.get(DIMENSIONS_KEY, {})), |
| 199 | + yaw=payload.get(YAW_KEY, 0), |
| 200 | + reference_id=payload.get(REFERENCE_ID_KEY, None), |
| 201 | + item_id=payload.get(DATASET_ITEM_ID_KEY, None), |
| 202 | + confidence=payload.get(CONFIDENCE_KEY, None), |
| 203 | + annotation_id=payload.get(ANNOTATION_ID_KEY, None), |
| 204 | + metadata=payload.get(METADATA_KEY, {}), |
| 205 | + class_pdf=payload.get(CLASS_PDF_KEY, None), |
| 206 | + ) |
0 commit comments