4
4
from typing import Optional , Union , Any , Dict , List
5
5
from nucleus .constants import (
6
6
FRAMES_KEY ,
7
+ LENGTH_KEY ,
7
8
METADATA_KEY ,
9
+ NUM_SENSORS_KEY ,
8
10
REFERENCE_ID_KEY ,
9
11
POINTCLOUD_LOCATION_KEY ,
10
12
IMAGE_LOCATION_KEY ,
@@ -25,9 +27,28 @@ def __post_init__(self):
25
27
value , DatasetItem
26
28
), "All values must be DatasetItems"
27
29
30
+ def __repr__ (self ) -> str :
31
+ return f"Frame(index={ self .index } , items={ self .items } )"
32
+
28
33
def add_item (self , item : DatasetItem , sensor_name : str ):
29
34
self .items [sensor_name ] = item
30
35
36
+ def get_item (self , sensor_name : str ):
37
+ if sensor_name not in self .items :
38
+ raise ValueError (
39
+ f"This frame does not have a { sensor_name } sensor"
40
+ )
41
+ return self .items [sensor_name ]
42
+
43
+ def get_items (self ):
44
+ return list (self .items .values ())
45
+
46
+ def get_sensors (self ):
47
+ return list (self .items .keys ())
48
+
49
+ def get_index (self ):
50
+ return self .index
51
+
31
52
@classmethod
32
53
def from_json (cls , payload : dict ):
33
54
items = {
@@ -51,6 +72,9 @@ class Scene(ABC):
51
72
52
73
def __post_init__ (self ):
53
74
self .check_valid_frame_indices ()
75
+ self .sensors = set (
76
+ flatten ([frame .get_sensors () for frame in self .frames ])
77
+ )
54
78
if all ((frame .index is not None for frame in self .frames )):
55
79
self .frames_dict = {frame .index : frame for frame in self .frames }
56
80
else :
@@ -60,6 +84,14 @@ def __post_init__(self):
60
84
]
61
85
self .frames_dict = dict (enumerate (indexed_frames ))
62
86
87
+ @property
88
+ def length (self ) -> int :
89
+ return len (self .frames_dict )
90
+
91
+ @property
92
+ def num_sensors (self ) -> int :
93
+ return len (self .get_sensors ())
94
+
63
95
def check_valid_frame_indices (self ):
64
96
infer_from_list_position = all (
65
97
(frame .index is None for frame in self .frames )
@@ -72,15 +104,14 @@ def check_valid_frame_indices(self):
72
104
), "Must specify index explicitly for all frames or infer from list position for all frames"
73
105
74
106
def validate (self ):
75
- assert (
76
- len (self .frames_dict ) > 0
77
- ), "Must have at least 1 frame in a scene"
107
+ assert self .length > 0 , "Must have at least 1 frame in a scene"
78
108
for frame in self .frames_dict .values ():
79
109
assert isinstance (
80
110
frame , Frame
81
111
), "Each frame in a scene must be a Frame object"
82
112
83
113
def add_item (self , index : int , sensor_name : str , item : DatasetItem ):
114
+ self .sensors .add (sensor_name )
84
115
if index not in self .frames_dict :
85
116
new_frame = Frame (index = index , items = {sensor_name : item })
86
117
self .frames_dict [index ] = new_frame
@@ -97,6 +128,54 @@ def add_frame(self, frame: Frame, update: bool = False):
97
128
and update
98
129
):
99
130
self .frames_dict [frame .index ] = frame
131
+ self .sensors .update (frame .get_sensors ())
132
+
133
+ def get_frame (self , index : int ):
134
+ if index not in self .frames_dict :
135
+ raise ValueError (
136
+ f"This scene does not have a frame at index { index } "
137
+ )
138
+ return self .frames_dict [index ]
139
+
140
+ def get_frames (self ):
141
+ return [
142
+ frame
143
+ for _ , frame in sorted (
144
+ self .frames_dict .items (), key = lambda x : x [0 ]
145
+ )
146
+ ]
147
+
148
+ def get_sensors (self ):
149
+ return list (self .sensors )
150
+
151
+ def get_item (self , index : int , sensor_name : str ):
152
+ frame = self .get_frame (index )
153
+ return frame .get_item (sensor_name )
154
+
155
+ def get_items_from_sensor (self , sensor_name : str ):
156
+ if sensor_name not in self .sensors :
157
+ raise ValueError (
158
+ f"This scene does not have a { sensor_name } sensor"
159
+ )
160
+ items_from_sensor = []
161
+ for frame in self .frames_dict .values ():
162
+ try :
163
+ sensor_item = frame .get_item (sensor_name )
164
+ items_from_sensor .append (sensor_item )
165
+ except ValueError :
166
+ # This sensor is not present at current frame
167
+ items_from_sensor .append (None )
168
+ return items_from_sensor
169
+
170
+ def get_items (self ):
171
+ return flatten ([frame .get_items () for frame in self .get_frames ()])
172
+
173
+ def info (self ):
174
+ return {
175
+ REFERENCE_ID_KEY : self .reference_id ,
176
+ LENGTH_KEY : self .length ,
177
+ NUM_SENSORS_KEY : self .num_sensors ,
178
+ }
100
179
101
180
def validate_frames_dict (self ):
102
181
is_continuous = set (list (range (len (self .frames_dict )))) == set (
@@ -118,12 +197,7 @@ def from_json(cls, payload: dict):
118
197
119
198
def to_payload (self ) -> dict :
120
199
self .validate_frames_dict ()
121
- ordered_frames = [
122
- frame
123
- for _ , frame in sorted (
124
- self .frames_dict .items (), key = lambda x : x [0 ]
125
- )
126
- ]
200
+ ordered_frames = self .get_frames ()
127
201
frames_payload = [frame .to_payload () for frame in ordered_frames ]
128
202
payload : Dict [str , Any ] = {
129
203
REFERENCE_ID_KEY : self .reference_id ,
@@ -139,27 +213,30 @@ def to_json(self) -> str:
139
213
140
214
@dataclass
141
215
class LidarScene (Scene ):
216
+ def __repr__ (self ) -> str :
217
+ return f"LidarScene(reference_id='{ self .reference_id } ', frames={ self .get_frames ()} , metadata={ self .metadata } )"
218
+
142
219
def validate (self ):
143
220
super ().validate ()
144
- lidar_sources = flatten (
221
+ lidar_sensors = flatten (
145
222
[
146
223
[
147
- source
148
- for source in frame .items .keys ()
149
- if frame .items [source ].type == DatasetItemType .POINTCLOUD
224
+ sensor
225
+ for sensor in frame .items .keys ()
226
+ if frame .items [sensor ].type == DatasetItemType .POINTCLOUD
150
227
]
151
228
for frame in self .frames_dict .values ()
152
229
]
153
230
)
154
231
assert (
155
- len (set (lidar_sources )) == 1
156
- ), "Each lidar scene must have exactly one lidar source "
232
+ len (set (lidar_sensors )) == 1
233
+ ), "Each lidar scene must have exactly one lidar sensor "
157
234
158
235
for frame in self .frames_dict .values ():
159
236
num_pointclouds = sum (
160
237
[
161
238
int (item .type == DatasetItemType .POINTCLOUD )
162
- for item in frame .items . values ()
239
+ for item in frame .get_items ()
163
240
]
164
241
)
165
242
assert (
@@ -173,17 +250,16 @@ def flatten(t):
173
250
174
251
def check_all_scene_paths_remote (scenes : List [LidarScene ]):
175
252
for scene in scenes :
176
- for frame in scene .frames_dict .values ():
177
- for item in frame .items .values ():
178
- pointcloud_location = getattr (item , POINTCLOUD_LOCATION_KEY )
179
- if pointcloud_location and is_local_path (pointcloud_location ):
180
- raise ValueError (
181
- f"All paths for DatasetItems in a Scene must be remote, but { item .pointcloud_location } is either "
182
- "local, or a remote URL type that is not supported."
183
- )
184
- image_location = getattr (item , IMAGE_LOCATION_KEY )
185
- if image_location and is_local_path (image_location ):
186
- raise ValueError (
187
- f"All paths for DatasetItems in a Scene must be remote, but { item .image_location } is either "
188
- "local, or a remote URL type that is not supported."
189
- )
253
+ for item in scene .get_items ():
254
+ pointcloud_location = getattr (item , POINTCLOUD_LOCATION_KEY )
255
+ if pointcloud_location and is_local_path (pointcloud_location ):
256
+ raise ValueError (
257
+ f"All paths for DatasetItems in a Scene must be remote, but { item .pointcloud_location } is either "
258
+ "local, or a remote URL type that is not supported."
259
+ )
260
+ image_location = getattr (item , IMAGE_LOCATION_KEY )
261
+ if image_location and is_local_path (image_location ):
262
+ raise ValueError (
263
+ f"All paths for DatasetItems in a Scene must be remote, but { item .image_location } is either "
264
+ "local, or a remote URL type that is not supported."
265
+ )
0 commit comments