Skip to content

Commit 70b04f4

Browse files
authored
Merge pull request #104 from scaleapi/nucleus_frame_args
Small tweaks to improve Frame UX
2 parents e52488e + c72a4dd commit 70b04f4

File tree

2 files changed

+27
-50
lines changed

2 files changed

+27
-50
lines changed

nucleus/scene.py

Lines changed: 13 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
from abc import ABC
33
from dataclasses import dataclass, field
4-
from typing import Optional, Union, Any, Dict, List
4+
from typing import Optional, Any, Dict, List
55
from nucleus.constants import (
66
FRAMES_KEY,
77
LENGTH_KEY,
@@ -15,10 +15,11 @@
1515
from .dataset_item import DatasetItemType, DatasetItem
1616

1717

18-
@dataclass
1918
class Frame:
20-
items: Dict[str, DatasetItem] = field(default_factory=dict)
21-
index: Union[int, None] = None
19+
def __init__(self, **kwargs):
20+
self.items = {}
21+
for key, value in kwargs.items():
22+
self.items[key] = value
2223

2324
def __post_init__(self):
2425
for key, value in self.items.items():
@@ -28,7 +29,7 @@ def __post_init__(self):
2829
), "All values must be DatasetItems"
2930

3031
def __repr__(self) -> str:
31-
return f"Frame(index={self.index}, items={self.items})"
32+
return f"Frame(items={self.items})"
3233

3334
def add_item(self, item: DatasetItem, sensor_name: str):
3435
self.items[sensor_name] = item
@@ -46,16 +47,13 @@ def get_items(self):
4647
def get_sensors(self):
4748
return list(self.items.keys())
4849

49-
def get_index(self):
50-
return self.index
51-
5250
@classmethod
5351
def from_json(cls, payload: dict):
5452
items = {
5553
sensor: DatasetItem.from_json(item, is_scene=True)
5654
for sensor, item in payload.items()
5755
}
58-
return cls(items=items)
56+
return cls(**items)
5957

6058
def to_payload(self) -> dict:
6159
return {
@@ -71,18 +69,10 @@ class Scene(ABC):
7169
metadata: Optional[dict] = None
7270

7371
def __post_init__(self):
74-
self.check_valid_frame_indices()
7572
self.sensors = set(
7673
flatten([frame.get_sensors() for frame in self.frames])
7774
)
78-
if all((frame.index is not None for frame in self.frames)):
79-
self.frames_dict = {frame.index: frame for frame in self.frames}
80-
else:
81-
indexed_frames = [
82-
Frame(index=i, items=frame.items)
83-
for i, frame in enumerate(self.frames)
84-
]
85-
self.frames_dict = dict(enumerate(indexed_frames))
75+
self.frames_dict = dict(enumerate(self.frames))
8676

8777
@property
8878
def length(self) -> int:
@@ -92,17 +82,6 @@ def length(self) -> int:
9282
def num_sensors(self) -> int:
9383
return len(self.get_sensors())
9484

95-
def check_valid_frame_indices(self):
96-
infer_from_list_position = all(
97-
(frame.index is None for frame in self.frames)
98-
)
99-
explicit_frame_order = all(
100-
(frame.index is not None for frame in self.frames)
101-
)
102-
assert (
103-
infer_from_list_position or explicit_frame_order
104-
), "Must specify index explicitly for all frames or infer from list position for all frames"
105-
10685
def validate(self):
10786
assert self.length > 0, "Must have at least 1 frame in a scene"
10887
for frame in self.frames_dict.values():
@@ -113,21 +92,18 @@ def validate(self):
11392
def add_item(self, index: int, sensor_name: str, item: DatasetItem):
11493
self.sensors.add(sensor_name)
11594
if index not in self.frames_dict:
116-
new_frame = Frame(index=index, items={sensor_name: item})
95+
new_frame = Frame(**{sensor_name: item})
11796
self.frames_dict[index] = new_frame
11897
else:
11998
self.frames_dict[index].items[sensor_name] = item
12099

121-
def add_frame(self, frame: Frame, update: bool = False):
122-
assert (
123-
frame.index is not None
124-
), "Must specify index explicitly when calling add_frame"
100+
def add_frame(self, frame: Frame, index: int, update: bool = False):
125101
if (
126-
frame.index not in self.frames_dict
127-
or frame.index in self.frames_dict
102+
index not in self.frames_dict
103+
or index in self.frames_dict
128104
and update
129105
):
130-
self.frames_dict[frame.index] = frame
106+
self.frames_dict[index] = frame
131107
self.sensors.update(frame.get_sensors())
132108

133109
def get_frame(self, index: int):

tests/test_scene.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,10 @@ def dataset(CLIENT):
4747

4848

4949
def test_frame_add_item():
50-
frame = Frame(index=0)
50+
frame = Frame()
5151
frame.add_item(TEST_DATASET_ITEMS[0], "camera")
5252
frame.add_item(TEST_LIDAR_ITEMS[0], "lidar")
5353

54-
assert frame.get_index() == 0
5554
assert frame.get_sensors() == ["camera", "lidar"]
5655
for item in frame.get_items():
5756
assert item in [TEST_DATASET_ITEMS[0], TEST_LIDAR_ITEMS[0]]
@@ -104,14 +103,19 @@ def test_scene_from_json():
104103
"lidar": lidar_item_f2,
105104
}
106105

107-
expected_frames = [Frame(expected_items_1), Frame(expected_items_2)]
106+
expected_frames = [Frame(**expected_items_1), Frame(**expected_items_2)]
108107
expected_scene = LidarScene(
109108
scene_json[REFERENCE_ID_KEY], expected_frames, metadata={}
110109
)
110+
111111
assert sorted(
112112
scene.get_items(), key=lambda item: item.reference_id
113113
) == sorted(expected_scene.get_items(), key=lambda item: item.reference_id)
114-
assert scene.get_frames() == expected_scene.get_frames()
114+
scene_frames = [frame.to_payload() for frame in scene.get_frames()]
115+
expected_scene_frames = [
116+
frame.to_payload() for frame in expected_scene.get_frames()
117+
]
118+
assert scene_frames == expected_scene_frames
115119
assert set(scene.get_sensors()) == set(expected_scene.get_sensors())
116120
assert scene.to_payload() == expected_scene.to_payload()
117121

@@ -195,29 +199,26 @@ def test_scene_add_frame():
195199
frames = [frame_1]
196200
scene = LidarScene(scene_ref_id, frames=frames)
197201

198-
frame_2 = Frame(index=1)
202+
frame_2 = Frame()
199203
frame_2.add_item(TEST_LIDAR_ITEMS[1], "lidar")
200-
scene.add_frame(frame_2)
204+
scene.add_frame(frame_2, index=1)
201205
frames.append(frame_2)
202206

203207
assert scene.length == len(frames)
204208
assert set(scene.get_sensors()) == set(["camera", "lidar"])
205209
expected_frame_1 = Frame(
206-
index=0,
207-
items={
210+
**{
208211
"camera": TEST_DATASET_ITEMS[0],
209212
"lidar": TEST_LIDAR_ITEMS[0],
210213
},
211214
)
212-
assert scene.get_frame(0) == expected_frame_1
215+
assert scene.get_frame(0).to_payload() == expected_frame_1.to_payload()
213216
expected_frame_2 = Frame(
214-
index=1,
215-
items={
217+
**{
216218
"lidar": TEST_LIDAR_ITEMS[1],
217219
},
218220
)
219-
expected_frames = [expected_frame_1, expected_frame_2]
220-
assert scene.get_frames() == expected_frames
221+
assert scene.get_frame(1).to_payload() == expected_frame_2.to_payload()
221222
for item in scene.get_items_from_sensor("lidar"):
222223
assert item in [TEST_LIDAR_ITEMS[0], TEST_LIDAR_ITEMS[1]]
223224

0 commit comments

Comments
 (0)