Skip to content

Commit a578917

Browse files
Drew KaulDrew Kaul
authored andcommitted
add getters to scene class
1 parent 2413465 commit a578917

File tree

2 files changed

+98
-30
lines changed

2 files changed

+98
-30
lines changed

nucleus/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
ITEM_METADATA_SCHEMA_KEY = "item_metadata_schema"
5252
JOB_ID_KEY = "job_id"
5353
KEEP_HISTORY_KEY = "keep_history"
54+
LENGTH_KEY = "length"
5455
JOB_STATUS_KEY = "job_status"
5556
JOB_LAST_KNOWN_STATUS_KEY = "job_last_known_status"
5657
JOB_TYPE_KEY = "job_type"
@@ -64,6 +65,7 @@
6465
NAME_KEY = "name"
6566
NEW_ITEMS = "new_items"
6667
NUCLEUS_ENDPOINT = "https://api.scale.com/v1/nucleus"
68+
NUM_SENSORS_KEY = "num_sensors"
6769
ORIGINAL_IMAGE_URL_KEY = "original_image_url"
6870
POINTCLOUD_LOCATION_KEY = "pointcloud_location"
6971
POINTCLOUD_URL_KEY = "pointcloud_url"

nucleus/scene.py

Lines changed: 96 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
from typing import Optional, Union, Any, Dict, List
55
from nucleus.constants import (
66
FRAMES_KEY,
7+
LENGTH_KEY,
78
METADATA_KEY,
9+
NUM_SENSORS_KEY,
810
REFERENCE_ID_KEY,
911
POINTCLOUD_LOCATION_KEY,
1012
IMAGE_LOCATION_KEY,
@@ -28,6 +30,22 @@ def __post_init__(self):
2830
def add_item(self, item: DatasetItem, sensor_name: str):
2931
self.items[sensor_name] = item
3032

33+
def get_item(self, sensor_name: str):
34+
if sensor_name not in self.items:
35+
raise ValueError(
36+
f"This frame does not have a {sensor_name} sensor"
37+
)
38+
return self.items[sensor_name]
39+
40+
def get_items(self):
41+
return self.items.values()
42+
43+
def get_sensors(self):
44+
return self.items.keys()
45+
46+
def get_index(self):
47+
return self.index
48+
3149
@classmethod
3250
def from_json(cls, payload: dict):
3351
items = {
@@ -51,6 +69,9 @@ class Scene(ABC):
5169

5270
def __post_init__(self):
5371
self.check_valid_frame_indices()
72+
self.sensors = set(
73+
flatten([frame.get_sensors() for frame in self.frames])
74+
)
5475
if all((frame.index is not None for frame in self.frames)):
5576
self.frames_dict = {frame.index: frame for frame in self.frames}
5677
else:
@@ -60,6 +81,14 @@ def __post_init__(self):
6081
]
6182
self.frames_dict = dict(enumerate(indexed_frames))
6283

84+
@property
85+
def length(self) -> int:
86+
return len(self.frames_dict)
87+
88+
@property
89+
def num_sensors(self) -> int:
90+
return len(self.get_sensors())
91+
6392
def check_valid_frame_indices(self):
6493
infer_from_list_position = all(
6594
(frame.index is None for frame in self.frames)
@@ -72,15 +101,14 @@ def check_valid_frame_indices(self):
72101
), "Must specify index explicitly for all frames or infer from list position for all frames"
73102

74103
def validate(self):
75-
assert (
76-
len(self.frames_dict) > 0
77-
), "Must have at least 1 frame in a scene"
104+
assert self.length() > 0, "Must have at least 1 frame in a scene"
78105
for frame in self.frames_dict.values():
79106
assert isinstance(
80107
frame, Frame
81108
), "Each frame in a scene must be a Frame object"
82109

83110
def add_item(self, index: int, sensor_name: str, item: DatasetItem):
111+
self.sensors.add(sensor_name)
84112
if index not in self.frames_dict:
85113
new_frame = Frame(index=index, items={sensor_name: item})
86114
self.frames_dict[index] = new_frame
@@ -97,6 +125,50 @@ def add_frame(self, frame: Frame, update: bool = False):
97125
and update
98126
):
99127
self.frames_dict[frame.index] = frame
128+
self.sensors.update(frame.get_sensors())
129+
130+
def get_frame(self, index: int):
131+
if index not in self.frames_dict:
132+
raise ValueError(
133+
f"This scene does not have a frame at index {index}"
134+
)
135+
return self.frames_dict[index]
136+
137+
def get_frames(self):
138+
return [
139+
frame
140+
for _, frame in sorted(
141+
self.frames_dict.items(), key=lambda x: x[0]
142+
)
143+
]
144+
145+
def get_sensors(self):
146+
return list(self.sensors)
147+
148+
def get_items_from_sensor(self, sensor_name: str):
149+
if sensor_name not in self.sensors:
150+
raise ValueError(
151+
f"This scene does not have a {sensor_name} sensor"
152+
)
153+
items_from_sensor = []
154+
for frame in self.frames_dict.values():
155+
try:
156+
sensor_item = frame.get_item(sensor_name)
157+
items_from_sensor.append(sensor_item)
158+
except ValueError:
159+
# This sensor is not present at current frame
160+
items_from_sensor.append(None)
161+
return items_from_sensor
162+
163+
def get_items(self):
164+
return flatten([frame.get_items() for frame in self.get_frames()])
165+
166+
def info(self):
167+
return {
168+
REFERENCE_ID_KEY: self.reference_id,
169+
LENGTH_KEY: self.length(),
170+
NUM_SENSORS_KEY: self.num_sensors(),
171+
}
100172

101173
def validate_frames_dict(self):
102174
is_continuous = set(list(range(len(self.frames_dict)))) == set(
@@ -118,12 +190,7 @@ def from_json(cls, payload: dict):
118190

119191
def to_payload(self) -> dict:
120192
self.validate_frames_dict()
121-
ordered_frames = [
122-
frame
123-
for _, frame in sorted(
124-
self.frames_dict.items(), key=lambda x: x[0]
125-
)
126-
]
193+
ordered_frames = self.get_frames()
127194
frames_payload = [frame.to_payload() for frame in ordered_frames]
128195
payload: Dict[str, Any] = {
129196
REFERENCE_ID_KEY: self.reference_id,
@@ -141,25 +208,25 @@ def to_json(self) -> str:
141208
class LidarScene(Scene):
142209
def validate(self):
143210
super().validate()
144-
lidar_sources = flatten(
211+
lidar_sensors = flatten(
145212
[
146213
[
147-
source
148-
for source in frame.items.keys()
149-
if frame.items[source].type == DatasetItemType.POINTCLOUD
214+
sensor
215+
for sensor in frame.items.keys()
216+
if frame.items[sensor].type == DatasetItemType.POINTCLOUD
150217
]
151218
for frame in self.frames_dict.values()
152219
]
153220
)
154221
assert (
155-
len(set(lidar_sources)) == 1
156-
), "Each lidar scene must have exactly one lidar source"
222+
len(set(lidar_sensors)) == 1
223+
), "Each lidar scene must have exactly one lidar sensor"
157224

158225
for frame in self.frames_dict.values():
159226
num_pointclouds = sum(
160227
[
161228
int(item.type == DatasetItemType.POINTCLOUD)
162-
for item in frame.items.values()
229+
for item in frame.get_items()
163230
]
164231
)
165232
assert (
@@ -173,17 +240,16 @@ def flatten(t):
173240

174241
def check_all_scene_paths_remote(scenes: List[LidarScene]):
175242
for scene in scenes:
176-
for frame in scene.frames_dict.values():
177-
for item in frame.items.values():
178-
pointcloud_location = getattr(item, POINTCLOUD_LOCATION_KEY)
179-
if pointcloud_location and is_local_path(pointcloud_location):
180-
raise ValueError(
181-
f"All paths for DatasetItems in a Scene must be remote, but {item.pointcloud_location} is either "
182-
"local, or a remote URL type that is not supported."
183-
)
184-
image_location = getattr(item, IMAGE_LOCATION_KEY)
185-
if image_location and is_local_path(image_location):
186-
raise ValueError(
187-
f"All paths for DatasetItems in a Scene must be remote, but {item.image_location} is either "
188-
"local, or a remote URL type that is not supported."
189-
)
243+
for item in scene.get_items():
244+
pointcloud_location = getattr(item, POINTCLOUD_LOCATION_KEY)
245+
if pointcloud_location and is_local_path(pointcloud_location):
246+
raise ValueError(
247+
f"All paths for DatasetItems in a Scene must be remote, but {item.pointcloud_location} is either "
248+
"local, or a remote URL type that is not supported."
249+
)
250+
image_location = getattr(item, IMAGE_LOCATION_KEY)
251+
if image_location and is_local_path(image_location):
252+
raise ValueError(
253+
f"All paths for DatasetItems in a Scene must be remote, but {item.image_location} is either "
254+
"local, or a remote URL type that is not supported."
255+
)

0 commit comments

Comments
 (0)