@@ -1,4 +1,5 @@
 import json
+from abc import ABC
 from dataclasses import dataclass, field
 from typing import Optional, Union, Dict, List, Set
 from enum import Enum
@@ -96,7 +97,7 @@ def to_payload(self) -> dict:


 @dataclass
-class Scene:
+class Scene(ABC):
     reference_id: str
     frames: List[Frame] = field(default_factory=list)
     metadata: Optional[dict] = None
@@ -106,7 +107,11 @@ def __post_init__(self):
         if all((frame.index is not None for frame in self.frames)):
             self.frames_dict = {frame.index: frame for frame in self.frames}
         else:
-            self.frames_dict = dict(enumerate(self.frames))
+            indexed_frames = [
+                Frame(index=i, items=frame.items)
+                for i, frame in enumerate(self.frames)
+            ]
+            self.frames_dict = dict(enumerate(indexed_frames))

     def check_valid_frame_indices(self):
         infer_from_list_position = all(
@@ -117,21 +122,18 @@ def check_valid_frame_indices(self):
         )
         assert (
             infer_from_list_position or explicit_frame_order
-        ), "Must specify index explicitly for all frames or implicitly for all frames (inferred from list position)"
-
-    # TODO: move validation to scene upload
-    def validate_scene(self):
-        assert isinstance(self.frames, List), "frames must be a list"
-        assert len(self.frames) > 0, "frames must have length of at least 1"
-        for frame in self.frames:
+        ), "Must specify index explicitly for all frames or infer from list position for all frames"
+
+    def validate(self):
+        assert (
+            len(self.frames_dict) > 0
+        ), "Must have at least 1 frame in a scene"
+        for frame in self.frames_dict.values():
             assert isinstance(
                 frame, Frame
-            ), "each element of frames must be a Frame object"
-            assert isinstance(
-                self.reference_id, str
-            ), "reference_id must be a string"
+            ), "Each frame in a scene must be a Frame object"

-    def add_item(self, item: SceneDatasetItem, index: int, sensor_name: str):
+    def add_item(self, index: int, sensor_name: str, item: SceneDatasetItem):
         if index not in self.frames_dict:
             new_frame = Frame(index, {sensor_name: item})
             self.frames_dict[index] = new_frame
@@ -150,13 +152,13 @@ def add_frame(self, frame: Frame, update: bool = False):
             self.frames_dict[frame.index] = frame

     def to_payload(self) -> dict:
-        frames_payload = [frame.to_payload() for frame in self.frames]
-        if len(frames_payload) > 0 and frames_payload[0][INDEX_KEY] is None:
-            for i, _ in enumerate(frames_payload):
-                frames_payload[i][INDEX_KEY] = i
-        else:
-            frames_payload.sort(key=lambda x: x[INDEX_KEY])
-
+        ordered_frames = [
+            frame
+            for _, frame in sorted(
+                self.frames_dict.items(), key=lambda x: x[0]
+            )
+        ]
+        frames_payload = [frame.to_payload() for frame in ordered_frames]
         return {
             REFERENCE_ID_KEY: self.reference_id,
             FRAMES_KEY: frames_payload,
@@ -166,22 +168,24 @@ def to_payload(self) -> dict:

 @dataclass
 class LidarScene(Scene):
+    # TODO: call validate in scene upload
     def validate(self):
+        super().validate()
         lidar_sources = flatten(
             [
                 [
                     source
                     for source in frame.items.keys()
                     if frame.items[source].type == DatasetItemType.POINTCLOUD
                 ]
-                for frame in self.frames
+                for frame in self.frames_dict.values()
             ]
         )
         assert (
             len(Set(lidar_sources)) == 1
         ), "Each lidar scene must have exactly one lidar source"

-        for frame in self.frames:
+        for frame in self.frames_dict.values():
             num_pointclouds = sum(
                 [
                     int(item.type == DatasetItemType.POINTCLOUD)
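Taken together, the hunks above rework the Scene API: Scene becomes an abstract base class whose frames are tracked in frames_dict, validate() replaces validate_scene(), add_item() now takes the frame index and sensor name before the item, and to_payload() serializes frames in ascending index order. A minimal usage sketch of the post-change API follows; the import path and the SceneDatasetItem constructor arguments are assumptions, since neither appears in this diff.

# Hypothetical usage sketch of the API after this change. The import path and
# the SceneDatasetItem constructor arguments are assumptions; only the method
# signatures (add_item, add_frame, validate, to_payload) come from the diff.
from nucleus.scene import Frame, LidarScene, SceneDatasetItem

scene = LidarScene(reference_id="scene_0001")

# New argument order: index and sensor name come before the item.
pointcloud_item = SceneDatasetItem(...)  # constructor args not shown in this diff
scene.add_item(index=0, sensor_name="lidar_top", item=pointcloud_item)

# Frames can also be added wholesale; Frame(index, {sensor_name: item}) is the
# constructor add_item itself uses in the diff.
scene.add_frame(Frame(1, {"lidar_top": pointcloud_item}))

# validate() now runs the base-class checks (non-empty frames_dict, every value
# is a Frame) before the lidar-specific assertions.
scene.validate()

payload = scene.to_payload()  # frames serialized in ascending index order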