Skip to content

Commit bb4e865

Browse files
Drew KaulDrew Kaul
authored andcommitted
finish Scene class and make abstract
1 parent 7371bb6 commit bb4e865

File tree

1 file changed

+27
-23
lines changed

1 file changed

+27
-23
lines changed

nucleus/scene.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
from abc import ABC
23
from dataclasses import dataclass, field
34
from typing import Optional, Union, Dict, List, Set
45
from enum import Enum
@@ -96,7 +97,7 @@ def to_payload(self) -> dict:
9697

9798

9899
@dataclass
99-
class Scene:
100+
class Scene(ABC):
100101
reference_id: str
101102
frames: List[Frame] = field(default_factory=list)
102103
metadata: Optional[dict] = None
@@ -106,7 +107,11 @@ def __post_init__(self):
106107
if all((frame.index is not None for frame in self.frames)):
107108
self.frames_dict = {frame.index: frame for frame in self.frames}
108109
else:
109-
self.frames_dict = dict(enumerate(self.frames))
110+
indexed_frames = [
111+
Frame(index=i, items=frame.items)
112+
for i, frame in enumerate(self.frames)
113+
]
114+
self.frames_dict = dict(enumerate(indexed_frames))
110115

111116
def check_valid_frame_indices(self):
112117
infer_from_list_position = all(
@@ -117,21 +122,18 @@ def check_valid_frame_indices(self):
117122
)
118123
assert (
119124
infer_from_list_position or explicit_frame_order
120-
), "Must specify index explicitly for all frames or implicitly for all frames (inferred from list position)"
121-
122-
# TODO: move validation to scene upload
123-
def validate_scene(self):
124-
assert isinstance(self.frames, List), "frames must be a list"
125-
assert len(self.frames) > 0, "frames must have length of at least 1"
126-
for frame in self.frames:
125+
), "Must specify index explicitly for all frames or infer from list position for all frames"
126+
127+
def validate(self):
128+
assert (
129+
len(self.frames_dict) > 0
130+
), "Must have at least 1 frame in a scene"
131+
for frame in self.frames_dict.values():
127132
assert isinstance(
128133
frame, Frame
129-
), "each element of frames must be a Frame object"
130-
assert isinstance(
131-
self.reference_id, str
132-
), "reference_id must be a string"
134+
), "Each frame in a scene must be a Frame object"
133135

134-
def add_item(self, item: SceneDatasetItem, index: int, sensor_name: str):
136+
def add_item(self, index: int, sensor_name: str, item: SceneDatasetItem):
135137
if index not in self.frames_dict:
136138
new_frame = Frame(index, {sensor_name: item})
137139
self.frames_dict[index] = new_frame
@@ -150,13 +152,13 @@ def add_frame(self, frame: Frame, update: bool = False):
150152
self.frames_dict[frame.index] = frame
151153

152154
def to_payload(self) -> dict:
153-
frames_payload = [frame.to_payload() for frame in self.frames]
154-
if len(frames_payload) > 0 and frames_payload[0][INDEX_KEY] is None:
155-
for i, _ in enumerate(frames_payload):
156-
frames_payload[i][INDEX_KEY] = i
157-
else:
158-
frames_payload.sort(key=lambda x: x[INDEX_KEY])
159-
155+
ordered_frames = [
156+
frame
157+
for _, frame in sorted(
158+
self.frames_dict.items(), key=lambda x: x[0]
159+
)
160+
]
161+
frames_payload = [frame.to_payload() for frame in ordered_frames]
160162
return {
161163
REFERENCE_ID_KEY: self.reference_id,
162164
FRAMES_KEY: frames_payload,
@@ -166,22 +168,24 @@ def to_payload(self) -> dict:
166168

167169
@dataclass
168170
class LidarScene(Scene):
171+
# TODO: call validate in scene upload
169172
def validate(self):
173+
super().validate()
170174
lidar_sources = flatten(
171175
[
172176
[
173177
source
174178
for source in frame.items.keys()
175179
if frame.items[source].type == DatasetItemType.POINTCLOUD
176180
]
177-
for frame in self.frames
181+
for frame in self.frames_dict.values()
178182
]
179183
)
180184
assert (
181185
len(Set(lidar_sources)) == 1
182186
), "Each lidar scene must have exactly one lidar source"
183187

184-
for frame in self.frames:
188+
for frame in self.frames_dict.values():
185189
num_pointclouds = sum(
186190
[
187191
int(item.type == DatasetItemType.POINTCLOUD)

0 commit comments

Comments
 (0)