4
4
from typing import Optional , Union , Any , Dict , List
5
5
from nucleus .constants import (
6
6
FRAMES_KEY ,
7
+ LENGTH_KEY ,
7
8
METADATA_KEY ,
9
+ NUM_SENSORS_KEY ,
8
10
REFERENCE_ID_KEY ,
9
11
POINTCLOUD_LOCATION_KEY ,
10
12
IMAGE_LOCATION_KEY ,
@@ -28,6 +30,22 @@ def __post_init__(self):
28
30
def add_item (self , item : DatasetItem , sensor_name : str ):
29
31
self .items [sensor_name ] = item
30
32
33
+ def get_item (self , sensor_name : str ):
34
+ if sensor_name not in self .items :
35
+ raise ValueError (
36
+ f"This frame does not have a { sensor_name } sensor"
37
+ )
38
+ return self .items [sensor_name ]
39
+
40
+ def get_items (self ):
41
+ return self .items .values ()
42
+
43
+ def get_sensors (self ):
44
+ return self .items .keys ()
45
+
46
+ def get_index (self ):
47
+ return self .index
48
+
31
49
@classmethod
32
50
def from_json (cls , payload : dict ):
33
51
items = {
@@ -51,6 +69,9 @@ class Scene(ABC):
51
69
52
70
def __post_init__ (self ):
53
71
self .check_valid_frame_indices ()
72
+ self .sensors = set (
73
+ flatten ([frame .get_sensors () for frame in self .frames ])
74
+ )
54
75
if all ((frame .index is not None for frame in self .frames )):
55
76
self .frames_dict = {frame .index : frame for frame in self .frames }
56
77
else :
@@ -60,6 +81,14 @@ def __post_init__(self):
60
81
]
61
82
self .frames_dict = dict (enumerate (indexed_frames ))
62
83
84
+ @property
85
+ def length (self ) -> int :
86
+ return len (self .frames_dict )
87
+
88
+ @property
89
+ def num_sensors (self ) -> int :
90
+ return len (self .get_sensors ())
91
+
63
92
def check_valid_frame_indices (self ):
64
93
infer_from_list_position = all (
65
94
(frame .index is None for frame in self .frames )
@@ -72,15 +101,14 @@ def check_valid_frame_indices(self):
72
101
), "Must specify index explicitly for all frames or infer from list position for all frames"
73
102
74
103
def validate (self ):
75
- assert (
76
- len (self .frames_dict ) > 0
77
- ), "Must have at least 1 frame in a scene"
104
+ assert self .length () > 0 , "Must have at least 1 frame in a scene"
78
105
for frame in self .frames_dict .values ():
79
106
assert isinstance (
80
107
frame , Frame
81
108
), "Each frame in a scene must be a Frame object"
82
109
83
110
def add_item (self , index : int , sensor_name : str , item : DatasetItem ):
111
+ self .sensors .add (sensor_name )
84
112
if index not in self .frames_dict :
85
113
new_frame = Frame (index = index , items = {sensor_name : item })
86
114
self .frames_dict [index ] = new_frame
@@ -97,6 +125,50 @@ def add_frame(self, frame: Frame, update: bool = False):
97
125
and update
98
126
):
99
127
self .frames_dict [frame .index ] = frame
128
+ self .sensors .update (frame .get_sensors ())
129
+
130
+ def get_frame (self , index : int ):
131
+ if index not in self .frames_dict :
132
+ raise ValueError (
133
+ f"This scene does not have a frame at index { index } "
134
+ )
135
+ return self .frames_dict [index ]
136
+
137
+ def get_frames (self ):
138
+ return [
139
+ frame
140
+ for _ , frame in sorted (
141
+ self .frames_dict .items (), key = lambda x : x [0 ]
142
+ )
143
+ ]
144
+
145
+ def get_sensors (self ):
146
+ return list (self .sensors )
147
+
148
+ def get_items_from_sensor (self , sensor_name : str ):
149
+ if sensor_name not in self .sensors :
150
+ raise ValueError (
151
+ f"This scene does not have a { sensor_name } sensor"
152
+ )
153
+ items_from_sensor = []
154
+ for frame in self .frames_dict .values ():
155
+ try :
156
+ sensor_item = frame .get_item (sensor_name )
157
+ items_from_sensor .append (sensor_item )
158
+ except ValueError :
159
+ # This sensor is not present at current frame
160
+ items_from_sensor .append (None )
161
+ return items_from_sensor
162
+
163
+ def get_items (self ):
164
+ return flatten ([frame .get_items () for frame in self .get_frames ()])
165
+
166
+ def info (self ):
167
+ return {
168
+ REFERENCE_ID_KEY : self .reference_id ,
169
+ LENGTH_KEY : self .length (),
170
+ NUM_SENSORS_KEY : self .num_sensors (),
171
+ }
100
172
101
173
def validate_frames_dict (self ):
102
174
is_continuous = set (list (range (len (self .frames_dict )))) == set (
@@ -118,12 +190,7 @@ def from_json(cls, payload: dict):
118
190
119
191
def to_payload (self ) -> dict :
120
192
self .validate_frames_dict ()
121
- ordered_frames = [
122
- frame
123
- for _ , frame in sorted (
124
- self .frames_dict .items (), key = lambda x : x [0 ]
125
- )
126
- ]
193
+ ordered_frames = self .get_frames ()
127
194
frames_payload = [frame .to_payload () for frame in ordered_frames ]
128
195
payload : Dict [str , Any ] = {
129
196
REFERENCE_ID_KEY : self .reference_id ,
@@ -141,25 +208,25 @@ def to_json(self) -> str:
141
208
class LidarScene (Scene ):
142
209
def validate (self ):
143
210
super ().validate ()
144
- lidar_sources = flatten (
211
+ lidar_sensors = flatten (
145
212
[
146
213
[
147
- source
148
- for source in frame .items .keys ()
149
- if frame .items [source ].type == DatasetItemType .POINTCLOUD
214
+ sensor
215
+ for sensor in frame .items .keys ()
216
+ if frame .items [sensor ].type == DatasetItemType .POINTCLOUD
150
217
]
151
218
for frame in self .frames_dict .values ()
152
219
]
153
220
)
154
221
assert (
155
- len (set (lidar_sources )) == 1
156
- ), "Each lidar scene must have exactly one lidar source "
222
+ len (set (lidar_sensors )) == 1
223
+ ), "Each lidar scene must have exactly one lidar sensor "
157
224
158
225
for frame in self .frames_dict .values ():
159
226
num_pointclouds = sum (
160
227
[
161
228
int (item .type == DatasetItemType .POINTCLOUD )
162
- for item in frame .items . values ()
229
+ for item in frame .get_items ()
163
230
]
164
231
)
165
232
assert (
@@ -173,17 +240,16 @@ def flatten(t):
173
240
174
241
def check_all_scene_paths_remote (scenes : List [LidarScene ]):
175
242
for scene in scenes :
176
- for frame in scene .frames_dict .values ():
177
- for item in frame .items .values ():
178
- pointcloud_location = getattr (item , POINTCLOUD_LOCATION_KEY )
179
- if pointcloud_location and is_local_path (pointcloud_location ):
180
- raise ValueError (
181
- f"All paths for DatasetItems in a Scene must be remote, but { item .pointcloud_location } is either "
182
- "local, or a remote URL type that is not supported."
183
- )
184
- image_location = getattr (item , IMAGE_LOCATION_KEY )
185
- if image_location and is_local_path (image_location ):
186
- raise ValueError (
187
- f"All paths for DatasetItems in a Scene must be remote, but { item .image_location } is either "
188
- "local, or a remote URL type that is not supported."
189
- )
243
+ for item in scene .get_items ():
244
+ pointcloud_location = getattr (item , POINTCLOUD_LOCATION_KEY )
245
+ if pointcloud_location and is_local_path (pointcloud_location ):
246
+ raise ValueError (
247
+ f"All paths for DatasetItems in a Scene must be remote, but { item .pointcloud_location } is either "
248
+ "local, or a remote URL type that is not supported."
249
+ )
250
+ image_location = getattr (item , IMAGE_LOCATION_KEY )
251
+ if image_location and is_local_path (image_location ):
252
+ raise ValueError (
253
+ f"All paths for DatasetItems in a Scene must be remote, but { item .image_location } is either "
254
+ "local, or a remote URL type that is not supported."
255
+ )
0 commit comments