Skip to content

Commit 203e22a

Browse files
authored
Merge pull request #101 from scaleapi/add_getters_to_scene_class
Add getters to scene class
2 parents 16dda47 + 54fca37 commit 203e22a

File tree

6 files changed

+405
-86
lines changed

6 files changed

+405
-86
lines changed

nucleus/annotation.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -295,19 +295,23 @@ def from_json(cls, payload: dict):
295295
)
296296

297297
def to_payload(self) -> dict:
    """Serialize this cuboid annotation into a payload dict.

    The label, the cuboid type marker and the geometry (position,
    dimensions, yaw) are always emitted. ``reference_id``,
    ``annotation_id`` and ``metadata`` are only added when truthy.
    """
    geometry = {
        POSITION_KEY: self.position.to_payload(),
        DIMENSIONS_KEY: self.dimensions.to_payload(),
        YAW_KEY: self.yaw,
    }
    payload = {
        LABEL_KEY: self.label,
        TYPE_KEY: CUBOID_TYPE,
        GEOMETRY_KEY: geometry,
    }

    # Optional fields are emitted in a fixed order and skipped when falsy.
    optional_fields = (
        (REFERENCE_ID_KEY, self.reference_id),
        (ANNOTATION_ID_KEY, self.annotation_id),
        (METADATA_KEY, self.metadata),
    )
    for key, value in optional_fields:
        if value:
            payload[key] = value

    return payload
311315

312316

313317
def is_local_path(path: str) -> bool:

nucleus/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
ITEM_METADATA_SCHEMA_KEY = "item_metadata_schema"
5353
JOB_ID_KEY = "job_id"
5454
KEEP_HISTORY_KEY = "keep_history"
55+
LENGTH_KEY = "length"
5556
JOB_STATUS_KEY = "job_status"
5657
JOB_LAST_KNOWN_STATUS_KEY = "job_last_known_status"
5758
JOB_TYPE_KEY = "job_type"
@@ -65,7 +66,9 @@
6566
NAME_KEY = "name"
6667
NEW_ITEMS = "new_items"
6768
NUCLEUS_ENDPOINT = "https://api.scale.com/v1/nucleus"
69+
NUM_SENSORS_KEY = "num_sensors"
6870
ORIGINAL_IMAGE_URL_KEY = "original_image_url"
71+
POINTCLOUD_KEY = "pointcloud"
6972
POINTCLOUD_LOCATION_KEY = "pointcloud_location"
7073
POINTCLOUD_URL_KEY = "pointcloud_url"
7174
POSITION_KEY = "position"

nucleus/dataset_item.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,12 @@ class DatasetItem: # pylint: disable=R0902
9696
pointcloud_location: Optional[str] = None
9797

9898
def __post_init__(self):
99-
self.local = is_local_path(self.image_location)
10099
assert bool(self.image_location) != bool(
101100
self.pointcloud_location
102101
), "Must specify exactly one of the image_location, pointcloud_location parameters"
102+
self.local = (
103+
is_local_path(self.image_location) if self.image_location else None
104+
)
103105
self.type = (
104106
DatasetItemType.IMAGE
105107
if self.image_location
@@ -116,14 +118,14 @@ def __post_init__(self):
116118

117119
@classmethod
118120
def from_json(cls, payload: dict, is_scene=False):
119-
image_url = payload.get(IMAGE_URL_KEY, "") or payload.get(
120-
ORIGINAL_IMAGE_URL_KEY, ""
121+
image_url = payload.get(IMAGE_URL_KEY, None) or payload.get(
122+
ORIGINAL_IMAGE_URL_KEY, None
121123
)
122124

123125
if is_scene:
124126
return cls(
125127
image_location=image_url,
126-
pointcloud_location=payload.get(POINTCLOUD_URL_KEY, ""),
128+
pointcloud_location=payload.get(POINTCLOUD_URL_KEY, None),
127129
reference_id=payload.get(REFERENCE_ID_KEY, None),
128130
item_id=payload.get(DATASET_ITEM_ID_KEY, None),
129131
metadata=payload.get(METADATA_KEY, {}),
@@ -143,26 +145,25 @@ def to_payload(self, is_scene=False) -> dict:
143145
payload: Dict[str, Any] = {
144146
METADATA_KEY: self.metadata or {},
145147
}
148+
if self.reference_id:
149+
payload[REFERENCE_ID_KEY] = self.reference_id
150+
if self.item_id:
151+
payload[DATASET_ITEM_ID_KEY] = self.item_id
146152

147153
if is_scene:
148154
if self.image_location:
149155
payload[URL_KEY] = self.image_location
150156
elif self.pointcloud_location:
151157
payload[URL_KEY] = self.pointcloud_location
152158
payload[TYPE_KEY] = self.type.value
159+
if self.camera_params:
160+
payload[CAMERA_PARAMS_KEY] = self.camera_params.to_payload()
153161
else:
154162
assert (
155163
self.image_location
156-
), "Must specify image_location for DatasetItems not in a Scene"
164+
), "Must specify image_location for DatasetItems not in a LidarScene"
157165
payload[IMAGE_URL_KEY] = self.image_location
158166

159-
if self.reference_id:
160-
payload[REFERENCE_ID_KEY] = self.reference_id
161-
if self.item_id:
162-
payload[DATASET_ITEM_ID_KEY] = self.item_id
163-
if self.camera_params:
164-
payload[CAMERA_PARAMS_KEY] = self.camera_params.to_payload()
165-
166167
return payload
167168

168169
def to_json(self) -> str:

nucleus/scene.py

Lines changed: 106 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
from typing import Optional, Union, Any, Dict, List
55
from nucleus.constants import (
66
FRAMES_KEY,
7+
LENGTH_KEY,
78
METADATA_KEY,
9+
NUM_SENSORS_KEY,
810
REFERENCE_ID_KEY,
911
POINTCLOUD_LOCATION_KEY,
1012
IMAGE_LOCATION_KEY,
@@ -25,9 +27,28 @@ def __post_init__(self):
2527
value, DatasetItem
2628
), "All values must be DatasetItems"
2729

30+
def __repr__(self) -> str:
    """Return a debug string showing the frame's index and items mapping."""
    index, items = self.index, self.items
    return f"Frame(index={index}, items={items})"
32+
2833
def add_item(self, item: DatasetItem, sensor_name: str):
    """Store ``item`` under ``sensor_name``, replacing any existing entry."""
    items_by_sensor = self.items
    items_by_sensor[sensor_name] = item
3035

36+
def get_item(self, sensor_name: str):
    """Return the item recorded for ``sensor_name`` in this frame.

    Raises:
        ValueError: if the frame has no entry for ``sensor_name``.
    """
    if sensor_name in self.items:
        return self.items[sensor_name]
    raise ValueError(f"This frame does not have a {sensor_name} sensor")
42+
43+
def get_items(self):
    """Return a new list with every item stored in this frame."""
    return [*self.items.values()]
45+
46+
def get_sensors(self):
    """Return a new list with the sensor names present in this frame."""
    return [*self.items.keys()]
48+
49+
def get_index(self):
    """Return this frame's ``index`` attribute unchanged."""
    frame_index = self.index
    return frame_index
51+
3152
@classmethod
3253
def from_json(cls, payload: dict):
3354
items = {
@@ -51,6 +72,9 @@ class Scene(ABC):
5172

5273
def __post_init__(self):
5374
self.check_valid_frame_indices()
75+
self.sensors = set(
76+
flatten([frame.get_sensors() for frame in self.frames])
77+
)
5478
if all((frame.index is not None for frame in self.frames)):
5579
self.frames_dict = {frame.index: frame for frame in self.frames}
5680
else:
@@ -60,6 +84,14 @@ def __post_init__(self):
6084
]
6185
self.frames_dict = dict(enumerate(indexed_frames))
6286

87+
@property
def length(self) -> int:
    """Number of frames currently held in ``frames_dict``."""
    frame_count = len(self.frames_dict)
    return frame_count
90+
91+
@property
def num_sensors(self) -> int:
    """Number of sensor names reported by ``get_sensors()``."""
    sensor_names = self.get_sensors()
    return len(sensor_names)
94+
6395
def check_valid_frame_indices(self):
6496
infer_from_list_position = all(
6597
(frame.index is None for frame in self.frames)
@@ -72,15 +104,14 @@ def check_valid_frame_indices(self):
72104
), "Must specify index explicitly for all frames or infer from list position for all frames"
73105

74106
def validate(self):
    """Check that the scene is non-empty and holds only ``Frame`` objects.

    Raises:
        AssertionError: if the scene has no frames, or if any value in
            ``frames_dict`` is not a ``Frame`` instance.
    """
    assert self.length > 0, "Must have at least 1 frame in a scene"
    stored_frames = self.frames_dict.values()
    for candidate in stored_frames:
        assert isinstance(
            candidate, Frame
        ), "Each frame in a scene must be a Frame object"
82112

83113
def add_item(self, index: int, sensor_name: str, item: DatasetItem):
114+
self.sensors.add(sensor_name)
84115
if index not in self.frames_dict:
85116
new_frame = Frame(index=index, items={sensor_name: item})
86117
self.frames_dict[index] = new_frame
@@ -97,6 +128,54 @@ def add_frame(self, frame: Frame, update: bool = False):
97128
and update
98129
):
99130
self.frames_dict[frame.index] = frame
131+
self.sensors.update(frame.get_sensors())
132+
133+
def get_frame(self, index: int):
    """Return the frame stored at ``index``.

    Raises:
        ValueError: if ``frames_dict`` has no entry for ``index``.
    """
    if index in self.frames_dict:
        return self.frames_dict[index]
    raise ValueError(f"This scene does not have a frame at index {index}")
139+
140+
def get_frames(self):
    """Return the scene's frames as a list ordered by ascending frame index."""
    frames_by_index = self.frames_dict
    # Dict keys are unique, so sorting the keys gives the same order as
    # sorting the (index, frame) pairs by index.
    return [frames_by_index[index] for index in sorted(frames_by_index)]
147+
148+
def get_sensors(self):
    """Return the scene's sensor names as a new list copied from ``sensors``."""
    return [*self.sensors]
150+
151+
def get_item(self, index: int, sensor_name: str):
    """Return the item for ``sensor_name`` in the frame at ``index``.

    Delegates to ``get_frame`` and then to that frame's ``get_item``; any
    exception either of them raises propagates to the caller.
    """
    return self.get_frame(index).get_item(sensor_name)
154+
155+
def get_items_from_sensor(self, sensor_name: str):
    """Return one entry per frame for ``sensor_name``, in frame-index order.

    Frames are visited through ``get_frames()`` so that position ``i`` of
    the result corresponds to the ``i``-th frame by ascending index, the
    same ordering ``get_items()`` uses. Iterating ``frames_dict`` directly
    would follow insertion order, which can differ when frames are added
    out of order.

    Returns:
        A list with the sensor's item for each frame, or ``None`` where a
        frame has no item for that sensor.

    Raises:
        ValueError: if ``sensor_name`` is not one of the scene's sensors.
    """
    if sensor_name not in self.sensors:
        raise ValueError(f"This scene does not have a {sensor_name} sensor")

    items_from_sensor = []
    for frame in self.get_frames():
        try:
            sensor_item = frame.get_item(sensor_name)
        except ValueError:
            # This sensor is not present at the current frame.
            sensor_item = None
        items_from_sensor.append(sensor_item)
    return items_from_sensor
169+
170+
def get_items(self):
    """Return every item in the scene, gathered frame by frame.

    Frames are taken from ``get_frames()`` and their per-frame item lists
    are merged with ``flatten``.
    """
    items_per_frame = [frame.get_items() for frame in self.get_frames()]
    return flatten(items_per_frame)
172+
173+
def info(self):
    """Return a summary dict with the reference id, length and sensor count."""
    summary = {}
    summary[REFERENCE_ID_KEY] = self.reference_id
    summary[LENGTH_KEY] = self.length
    summary[NUM_SENSORS_KEY] = self.num_sensors
    return summary
100179

101180
def validate_frames_dict(self):
102181
is_continuous = set(list(range(len(self.frames_dict)))) == set(
@@ -118,12 +197,7 @@ def from_json(cls, payload: dict):
118197

119198
def to_payload(self) -> dict:
120199
self.validate_frames_dict()
121-
ordered_frames = [
122-
frame
123-
for _, frame in sorted(
124-
self.frames_dict.items(), key=lambda x: x[0]
125-
)
126-
]
200+
ordered_frames = self.get_frames()
127201
frames_payload = [frame.to_payload() for frame in ordered_frames]
128202
payload: Dict[str, Any] = {
129203
REFERENCE_ID_KEY: self.reference_id,
@@ -139,27 +213,30 @@ def to_json(self) -> str:
139213

140214
@dataclass
141215
class LidarScene(Scene):
216+
def __repr__(self) -> str:
    """Return a debug string with the reference id, ordered frames and metadata."""
    return (
        f"LidarScene(reference_id='{self.reference_id}', "
        f"frames={self.get_frames()}, "
        f"metadata={self.metadata})"
    )
218+
142219
def validate(self):
143220
super().validate()
144-
lidar_sources = flatten(
221+
lidar_sensors = flatten(
145222
[
146223
[
147-
source
148-
for source in frame.items.keys()
149-
if frame.items[source].type == DatasetItemType.POINTCLOUD
224+
sensor
225+
for sensor in frame.items.keys()
226+
if frame.items[sensor].type == DatasetItemType.POINTCLOUD
150227
]
151228
for frame in self.frames_dict.values()
152229
]
153230
)
154231
assert (
155-
len(set(lidar_sources)) == 1
156-
), "Each lidar scene must have exactly one lidar source"
232+
len(set(lidar_sensors)) == 1
233+
), "Each lidar scene must have exactly one lidar sensor"
157234

158235
for frame in self.frames_dict.values():
159236
num_pointclouds = sum(
160237
[
161238
int(item.type == DatasetItemType.POINTCLOUD)
162-
for item in frame.items.values()
239+
for item in frame.get_items()
163240
]
164241
)
165242
assert (
@@ -173,17 +250,16 @@ def flatten(t):
173250

174251
def check_all_scene_paths_remote(scenes: List[LidarScene]):
    """Raise if any item in the given scenes has a local path.

    Each item's pointcloud location is checked first, then its image
    location. A location that is set and classified as local by
    ``is_local_path`` is rejected.

    Raises:
        ValueError: on the first offending location found.
    """
    prefix = "All paths for DatasetItems in a Scene must be remote, but "
    suffix = "local, or a remote URL type that is not supported."

    for scene in scenes:
        for item in scene.get_items():
            lidar_path = getattr(item, POINTCLOUD_LOCATION_KEY)
            if lidar_path and is_local_path(lidar_path):
                raise ValueError(
                    f"{prefix}{item.pointcloud_location} is either {suffix}"
                )

            image_path = getattr(item, IMAGE_LOCATION_KEY)
            if image_path and is_local_path(image_path):
                raise ValueError(
                    f"{prefix}{item.image_location} is either {suffix}"
                )

tests/helpers.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@
2424

2525
# Pointcloud fixture URLs: frame-0.json through frame-4.json.
TEST_POINTCLOUD_URLS = [
    "https://scaleapi-cust-lidar.s3.us-west-1.amazonaws.com/test-scale/"
    f"frame-{frame_number}.json"
    for frame_number in range(5)
]
2832

2933
TEST_LIDAR_SCENES = {
@@ -56,7 +60,14 @@
5660
}
5761
},
5862
},
59-
}
63+
},
64+
{
65+
"lidar": {
66+
"pointcloud_url": TEST_POINTCLOUD_URLS[0],
67+
"reference_id": "lidar_frame_2",
68+
"metadata": {},
69+
},
70+
},
6071
],
6172
"metadata": {},
6273
},
@@ -71,6 +82,14 @@
7182
DatasetItem(TEST_IMG_URLS[3], "4"),
7283
]
7384

85+
# One pointcloud DatasetItem per fixture URL; reference ids are "1" through "5".
TEST_LIDAR_ITEMS = [
    DatasetItem(
        pointcloud_location=TEST_POINTCLOUD_URLS[position],
        reference_id=str(position + 1),
    )
    for position in range(5)
]
92+
7493
LOCAL_FILENAME = "tests/test_img.jpg"
7594
TEST_PREDS = [
7695
BoxPrediction("[Pytest Box Prediction 1]", 0, 0, 100, 100, "1"),

0 commit comments

Comments
 (0)