Skip to content

Commit cf9d74d

Browse files
Drew KaulDrew Kaul
authored andcommitted
finish SceneDatasetItem class
1 parent 5ce34d6 commit cf9d74d

File tree

2 files changed

+57
-4
lines changed

2 files changed

+57
-4
lines changed

nucleus/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
ANNOTATION_UPDATE_KEY = "update"
1313
AUTOTAGS_KEY = "autotags"
1414
EXPORTED_ROWS = "exportedRows"
15+
CAMERA_PARAMS_KEY = "camera_params"
1516
CLASS_PDF_KEY = "class_pdf"
1617
CONFIDENCE_KEY = "confidence"
1718
DATASET_ID_KEY = "dataset_id"
@@ -70,6 +71,7 @@
7071
TYPE_KEY = "type"
7172
UPDATED_ITEMS = "updated_items"
7273
UPDATE_KEY = "update"
74+
URL_KEY = "url"
7375
VERTICES_KEY = "vertices"
7476
WIDTH_KEY = "width"
7577
YAW_KEY = "yaw"

nucleus/scene.py

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
1+
import json
12
from dataclasses import dataclass
23
from typing import Optional, Dict, List, Set
34
from enum import Enum
5+
from nucleus.constants import (
6+
CAMERA_PARAMS_KEY,
7+
METADATA_KEY,
8+
REFERENCE_ID_KEY,
9+
TYPE_KEY,
10+
URL_KEY,
11+
)
412
from .annotation import Point3D
513
from .utils import flatten
614

715

816
class DatasetItemType(Enum):
917
IMAGE = "image"
1018
POINTCLOUD = "pointcloud"
11-
VIDEO = "video"
1219

1320

1421
@dataclass
@@ -37,6 +44,28 @@ class SceneDatasetItem:
3744
metadata: Optional[dict] = None
3845
camera_params: Optional[CameraParams] = None
3946

47+
@classmethod
48+
def from_json(cls, payload: dict):
49+
return cls(
50+
url=payload.get(URL_KEY, ""),
51+
type=payload.get(TYPE_KEY, ""),
52+
reference_id=payload.get(REFERENCE_ID_KEY, None),
53+
metadata=payload.get(METADATA_KEY, None),
54+
camera_params=payload.get(CAMERA_PARAMS_KEY, None),
55+
)
56+
57+
def to_payload(self) -> dict:
58+
return {
59+
URL_KEY: self.url,
60+
TYPE_KEY: self.type,
61+
REFERENCE_ID_KEY: self.reference_id,
62+
METADATA_KEY: self.metadata,
63+
CAMERA_PARAMS_KEY: self.camera_params,
64+
}
65+
66+
def to_json(self) -> str:
67+
return json.dumps(self.to_payload(), allow_nan=False)
68+
4069

4170
@dataclass
4271
class Frame:
@@ -49,18 +78,31 @@ def __post_init__(self):
4978
value, SceneDatasetItem
5079
), "All values must be SceneDatasetItems"
5180

81+
def add_item(self, item: SceneDatasetItem, sensor_name: str):
82+
self.items[sensor_name] = item
83+
5284

5385
@dataclass
5486
class Scene:
5587
frames: List[Frame]
5688
reference_id: str
5789
metadata: Optional[dict] = None
5890

91+
def __post_init__(self):
92+
assert isinstance(self.frames, List), "frames must be a list"
93+
for frame in self.frames:
94+
assert isinstance(
95+
frame, Frame
96+
), "each element of frames must be a Frame object"
97+
assert len(self.frames) > 0, "frames must have length of at least 1"
98+
assert isinstance(
99+
self.reference_id, str
100+
), "reference_id must be a string"
101+
59102

60103
@dataclass
61104
class LidarScene(Scene):
62-
def __post_init__(self):
63-
# do validation here for lidar scene
105+
def validate(self):
64106
lidar_sources = flatten(
65107
[
66108
[
@@ -75,4 +117,13 @@ def __post_init__(self):
75117
len(Set(lidar_sources)) == 1
76118
), "Each lidar scene must have exactly one lidar source"
77119

78-
# TODO: check single pointcloud per frame
120+
for frame in self.frames:
121+
num_pointclouds = sum(
122+
[
123+
int(item.type == DatasetItemType.POINTCLOUD)
124+
for item in frame.values()
125+
]
126+
)
127+
assert (
128+
num_pointclouds == 1
129+
), "Each frame of a lidar scene must have exactly 1 pointcloud"

0 commit comments

Comments
 (0)