1
+ import json
1
2
from dataclasses import dataclass
2
3
from typing import Optional , Dict , List , Set
3
4
from enum import Enum
5
+ from nucleus .constants import (
6
+ CAMERA_PARAMS_KEY ,
7
+ METADATA_KEY ,
8
+ REFERENCE_ID_KEY ,
9
+ TYPE_KEY ,
10
+ URL_KEY ,
11
+ )
4
12
from .annotation import Point3D
5
13
from .utils import flatten
6
14
7
15
8
16
class DatasetItemType (Enum ):
9
17
IMAGE = "image"
10
18
POINTCLOUD = "pointcloud"
11
- VIDEO = "video"
12
19
13
20
14
21
@dataclass
@@ -37,6 +44,28 @@ class SceneDatasetItem:
37
44
metadata : Optional [dict ] = None
38
45
camera_params : Optional [CameraParams ] = None
39
46
47
+ @classmethod
48
+ def from_json (cls , payload : dict ):
49
+ return cls (
50
+ url = payload .get (URL_KEY , "" ),
51
+ type = payload .get (TYPE_KEY , "" ),
52
+ reference_id = payload .get (REFERENCE_ID_KEY , None ),
53
+ metadata = payload .get (METADATA_KEY , None ),
54
+ camera_params = payload .get (CAMERA_PARAMS_KEY , None ),
55
+ )
56
+
57
+ def to_payload (self ) -> dict :
58
+ return {
59
+ URL_KEY : self .url ,
60
+ TYPE_KEY : self .type ,
61
+ REFERENCE_ID_KEY : self .reference_id ,
62
+ METADATA_KEY : self .metadata ,
63
+ CAMERA_PARAMS_KEY : self .camera_params ,
64
+ }
65
+
66
+ def to_json (self ) -> str :
67
+ return json .dumps (self .to_payload (), allow_nan = False )
68
+
40
69
41
70
@dataclass
42
71
class Frame :
@@ -49,18 +78,31 @@ def __post_init__(self):
49
78
value , SceneDatasetItem
50
79
), "All values must be SceneDatasetItems"
51
80
81
+ def add_item (self , item : SceneDatasetItem , sensor_name : str ):
82
+ self .items [sensor_name ] = item
83
+
52
84
53
85
@dataclass
54
86
class Scene :
55
87
frames : List [Frame ]
56
88
reference_id : str
57
89
metadata : Optional [dict ] = None
58
90
91
+ def __post_init__ (self ):
92
+ assert isinstance (self .frames , List ), "frames must be a list"
93
+ for frame in self .frames :
94
+ assert isinstance (
95
+ frame , Frame
96
+ ), "each element of frames must be a Frame object"
97
+ assert len (self .frames ) > 0 , "frames must have length of at least 1"
98
+ assert isinstance (
99
+ self .reference_id , str
100
+ ), "reference_id must be a string"
101
+
59
102
60
103
@dataclass
61
104
class LidarScene (Scene ):
62
- def __post_init__ (self ):
63
- # do validation here for lidar scene
105
+ def validate (self ):
64
106
lidar_sources = flatten (
65
107
[
66
108
[
@@ -75,4 +117,13 @@ def __post_init__(self):
75
117
len (Set (lidar_sources )) == 1
76
118
), "Each lidar scene must have exactly one lidar source"
77
119
78
- # TODO: check single pointcloud per frame
120
+ for frame in self .frames :
121
+ num_pointclouds = sum (
122
+ [
123
+ int (item .type == DatasetItemType .POINTCLOUD )
124
+ for item in frame .values ()
125
+ ]
126
+ )
127
+ assert (
128
+ num_pointclouds == 1
129
+ ), "Each frame of a lidar scene must have exactly 1 pointcloud"
0 commit comments