Skip to content

Commit 1a2eaa8

Browse files
authored
Merge pull request #77 from scaleapi/cuboid_upload
added CuboidAnnotation dataclass
2 parents 947be07 + ca69891 commit 1a2eaa8

File tree

3 files changed

+130
-1
lines changed

3 files changed

+130
-1
lines changed

nucleus/annotation.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
ANNOTATION_ID_KEY,
99
ANNOTATIONS_KEY,
1010
BOX_TYPE,
11+
CUBOID_TYPE,
1112
DATASET_ITEM_ID_KEY,
13+
DIMENSIONS_KEY,
1214
GEOMETRY_KEY,
1315
HEIGHT_KEY,
1416
INDEX_KEY,
@@ -18,12 +20,15 @@
1820
MASK_URL_KEY,
1921
METADATA_KEY,
2022
POLYGON_TYPE,
23+
POSITION_KEY,
2124
REFERENCE_ID_KEY,
2225
TYPE_KEY,
2326
VERTICES_KEY,
2427
WIDTH_KEY,
2528
X_KEY,
29+
YAW_KEY,
2630
Y_KEY,
31+
Z_KEY,
2732
)
2833

2934

@@ -43,6 +48,8 @@ def from_json(cls, payload: dict):
4348
return BoxAnnotation.from_json(payload)
4449
elif payload.get(TYPE_KEY, None) == POLYGON_TYPE:
4550
return PolygonAnnotation.from_json(payload)
51+
elif payload.get(TYPE_KEY, None) == CUBOID_TYPE:
52+
return CuboidAnnotation.from_json(payload)
4653
else:
4754
return SegmentationAnnotation.from_json(payload)
4855

@@ -125,6 +132,7 @@ def to_payload(self) -> dict:
125132
class AnnotationTypes(Enum):
126133
BOX = BOX_TYPE
127134
POLYGON = POLYGON_TYPE
135+
CUBOID = CUBOID_TYPE
128136

129137

130138
@dataclass # pylint: disable=R0902
@@ -187,6 +195,20 @@ def to_payload(self) -> dict:
187195
return {X_KEY: self.x, Y_KEY: self.y}
188196

189197

198+
@dataclass
199+
class Point3D:
200+
x: float
201+
y: float
202+
z: float
203+
204+
@classmethod
205+
def from_json(cls, payload: Dict[str, float]):
206+
return cls(payload[X_KEY], payload[Y_KEY], payload[Z_KEY])
207+
208+
def to_payload(self) -> dict:
209+
return {X_KEY: self.x, Y_KEY: self.y, Z_KEY: self.z}
210+
211+
190212
@dataclass
191213
class PolygonAnnotation(Annotation):
192214
label: str
@@ -241,6 +263,50 @@ def to_payload(self) -> dict:
241263
return payload
242264

243265

266+
@dataclass # pylint: disable=R0902
267+
class CuboidAnnotation(Annotation): # pylint: disable=R0902
268+
label: str
269+
position: Point3D
270+
dimensions: Point3D
271+
yaw: float
272+
reference_id: Optional[str] = None
273+
item_id: Optional[str] = None
274+
annotation_id: Optional[str] = None
275+
metadata: Optional[Dict] = None
276+
277+
def __post_init__(self):
278+
self._check_ids()
279+
self.metadata = self.metadata if self.metadata else {}
280+
281+
@classmethod
282+
def from_json(cls, payload: dict):
283+
geometry = payload.get(GEOMETRY_KEY, {})
284+
return cls(
285+
label=payload.get(LABEL_KEY, 0),
286+
position=Point3D.from_json(geometry.get(POSITION_KEY, {})),
287+
dimensions=Point3D.from_json(geometry.get(DIMENSIONS_KEY, {})),
288+
yaw=payload.get(YAW_KEY, 0),
289+
reference_id=payload.get(REFERENCE_ID_KEY, None),
290+
item_id=payload.get(DATASET_ITEM_ID_KEY, None),
291+
annotation_id=payload.get(ANNOTATION_ID_KEY, None),
292+
metadata=payload.get(METADATA_KEY, {}),
293+
)
294+
295+
def to_payload(self) -> dict:
296+
return {
297+
LABEL_KEY: self.label,
298+
TYPE_KEY: CUBOID_TYPE,
299+
GEOMETRY_KEY: {
300+
POSITION_KEY: self.position,
301+
DIMENSIONS_KEY: self.dimensions,
302+
YAW_KEY: self.yaw,
303+
},
304+
REFERENCE_ID_KEY: self.reference_id,
305+
ANNOTATION_ID_KEY: self.annotation_id,
306+
METADATA_KEY: self.metadata,
307+
}
308+
309+
244310
def check_all_annotation_paths_remote(
245311
annotations: Sequence[Union[Annotation]],
246312
):

nucleus/constants.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
POLYGON_TYPE = "polygon"
88
MASK_TYPE = "mask"
99
SEGMENTATION_TYPE = "segmentation"
10-
ANNOTATION_TYPES = (BOX_TYPE, POLYGON_TYPE, SEGMENTATION_TYPE)
10+
CUBOID_TYPE = "cuboid"
11+
ANNOTATION_TYPES = (BOX_TYPE, POLYGON_TYPE, SEGMENTATION_TYPE, CUBOID_TYPE)
1112
ANNOTATION_UPDATE_KEY = "update"
1213
AUTOTAGS_KEY = "autotags"
1314
EXPORTED_ROWS = "exportedRows"
@@ -22,6 +23,7 @@
2223
DATASET_SLICES_KEY = "slice_ids"
2324
DEFAULT_ANNOTATION_UPDATE_MODE = False
2425
DEFAULT_NETWORK_TIMEOUT_SEC = 120
26+
DIMENSIONS_KEY = "dimensions"
2527
EMBEDDINGS_URL_KEY = "embeddings_url"
2628
ERRORS_KEY = "errors"
2729
ERROR_CODES = "error_codes"
@@ -48,6 +50,7 @@
4850
NEW_ITEMS = "new_items"
4951
NUCLEUS_ENDPOINT = "https://api.scale.com/v1/nucleus"
5052
ORIGINAL_IMAGE_URL_KEY = "original_image_url"
53+
POSITION_KEY = "position"
5154
PREDICTIONS_IGNORED_KEY = "predictions_ignored"
5255
PREDICTIONS_PROCESSED_KEY = "predictions_processed"
5356
REFERENCE_IDS_KEY = "reference_ids"
@@ -63,5 +66,7 @@
6366
UPDATE_KEY = "update"
6467
VERTICES_KEY = "vertices"
6568
WIDTH_KEY = "width"
69+
YAW_KEY = "yaw"
6670
X_KEY = "x"
6771
Y_KEY = "y"
72+
Z_KEY = "z"

nucleus/prediction.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
PolygonAnnotation,
66
Segment,
77
SegmentationAnnotation,
8+
CuboidAnnotation,
9+
Point3D,
810
)
911
from .constants import (
1012
ANNOTATION_ID_KEY,
@@ -23,6 +25,9 @@
2325
ANNOTATIONS_KEY,
2426
ITEM_ID_KEY,
2527
MASK_URL_KEY,
28+
POSITION_KEY,
29+
DIMENSIONS_KEY,
30+
YAW_KEY,
2631
)
2732

2833

@@ -146,3 +151,56 @@ def from_json(cls, payload: dict):
146151
metadata=payload.get(METADATA_KEY, {}),
147152
class_pdf=payload.get(CLASS_PDF_KEY, None),
148153
)
154+
155+
156+
class CuboidPrediction(CuboidAnnotation):
157+
def __init__(
158+
self,
159+
label: str,
160+
position: Point3D,
161+
dimensions: Point3D,
162+
yaw: float,
163+
reference_id: Optional[str] = None,
164+
item_id: Optional[str] = None,
165+
confidence: Optional[float] = None,
166+
annotation_id: Optional[str] = None,
167+
metadata: Optional[Dict] = None,
168+
class_pdf: Optional[Dict] = None,
169+
):
170+
super().__init__(
171+
label=label,
172+
position=position,
173+
dimensions=dimensions,
174+
yaw=yaw,
175+
reference_id=reference_id,
176+
item_id=item_id,
177+
annotation_id=annotation_id,
178+
metadata=metadata,
179+
)
180+
self.confidence = confidence
181+
self.class_pdf = class_pdf
182+
183+
def to_payload(self) -> dict:
184+
payload = super().to_payload()
185+
if self.confidence is not None:
186+
payload[CONFIDENCE_KEY] = self.confidence
187+
if self.class_pdf is not None:
188+
payload[CLASS_PDF_KEY] = self.class_pdf
189+
190+
return payload
191+
192+
@classmethod
193+
def from_json(cls, payload: dict):
194+
geometry = payload.get(GEOMETRY_KEY, {})
195+
return cls(
196+
label=payload.get(LABEL_KEY, 0),
197+
position=Point3D.from_json(geometry.get(POSITION_KEY, {})),
198+
dimensions=Point3D.from_json(geometry.get(DIMENSIONS_KEY, {})),
199+
yaw=payload.get(YAW_KEY, 0),
200+
reference_id=payload.get(REFERENCE_ID_KEY, None),
201+
item_id=payload.get(DATASET_ITEM_ID_KEY, None),
202+
confidence=payload.get(CONFIDENCE_KEY, None),
203+
annotation_id=payload.get(ANNOTATION_ID_KEY, None),
204+
metadata=payload.get(METADATA_KEY, {}),
205+
class_pdf=payload.get(CLASS_PDF_KEY, None),
206+
)

0 commit comments

Comments
 (0)