Skip to content

Commit 2169a0e

Browse files
committed
init frame from named parameters
1 parent e52488e commit 2169a0e

File tree

1 file changed

+11
-35
lines changed

1 file changed

+11
-35
lines changed

nucleus/scene.py

Lines changed: 11 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@
1515
from .dataset_item import DatasetItemType, DatasetItem
1616

1717

18-
@dataclass
1918
class Frame:
20-
items: Dict[str, DatasetItem] = field(default_factory=dict)
21-
index: Union[int, None] = None
19+
def __init__(self, **kwargs):
20+
self.items = {}
21+
for key, value in kwargs.items():
22+
self.items[key] = value
2223

2324
def __post_init__(self):
2425
for key, value in self.items.items():
@@ -28,7 +29,7 @@ def __post_init__(self):
2829
), "All values must be DatasetItems"
2930

3031
def __repr__(self) -> str:
31-
return f"Frame(index={self.index}, items={self.items})"
32+
return f"Frame(items={self.items})"
3233

3334
def add_item(self, item: DatasetItem, sensor_name: str):
3435
self.items[sensor_name] = item
@@ -46,9 +47,6 @@ def get_items(self):
4647
def get_sensors(self):
4748
return list(self.items.keys())
4849

49-
def get_index(self):
50-
return self.index
51-
5250
@classmethod
5351
def from_json(cls, payload: dict):
5452
items = {
@@ -71,18 +69,10 @@ class Scene(ABC):
7169
metadata: Optional[dict] = None
7270

7371
def __post_init__(self):
74-
self.check_valid_frame_indices()
7572
self.sensors = set(
7673
flatten([frame.get_sensors() for frame in self.frames])
7774
)
78-
if all((frame.index is not None for frame in self.frames)):
79-
self.frames_dict = {frame.index: frame for frame in self.frames}
80-
else:
81-
indexed_frames = [
82-
Frame(index=i, items=frame.items)
83-
for i, frame in enumerate(self.frames)
84-
]
85-
self.frames_dict = dict(enumerate(indexed_frames))
75+
self.frames_dict = dict(enumerate(self.frames))
8676

8777
@property
8878
def length(self) -> int:
@@ -92,17 +82,6 @@ def length(self) -> int:
9282
def num_sensors(self) -> int:
9383
return len(self.get_sensors())
9484

95-
def check_valid_frame_indices(self):
96-
infer_from_list_position = all(
97-
(frame.index is None for frame in self.frames)
98-
)
99-
explicit_frame_order = all(
100-
(frame.index is not None for frame in self.frames)
101-
)
102-
assert (
103-
infer_from_list_position or explicit_frame_order
104-
), "Must specify index explicitly for all frames or infer from list position for all frames"
105-
10685
def validate(self):
10786
assert self.length > 0, "Must have at least 1 frame in a scene"
10887
for frame in self.frames_dict.values():
@@ -113,21 +92,18 @@ def validate(self):
11392
def add_item(self, index: int, sensor_name: str, item: DatasetItem):
11493
self.sensors.add(sensor_name)
11594
if index not in self.frames_dict:
116-
new_frame = Frame(index=index, items={sensor_name: item})
95+
new_frame = Frame(items={sensor_name: item})
11796
self.frames_dict[index] = new_frame
11897
else:
11998
self.frames_dict[index].items[sensor_name] = item
12099

121-
def add_frame(self, frame: Frame, update: bool = False):
122-
assert (
123-
frame.index is not None
124-
), "Must specify index explicitly when calling add_frame"
100+
def add_frame(self, frame: Frame, index: int, update: bool = False):
125101
if (
126-
frame.index not in self.frames_dict
127-
or frame.index in self.frames_dict
102+
index not in self.frames_dict
103+
or index in self.frames_dict
128104
and update
129105
):
130-
self.frames_dict[frame.index] = frame
106+
self.frames_dict[index] = frame
131107
self.sensors.update(frame.get_sensors())
132108

133109
def get_frame(self, index: int):

0 commit comments

Comments
 (0)