Skip to content

Commit 729cd7f

Browse files
Drew KaulDrew Kaul
authored andcommitted
fix mutable dataclass defaults
1 parent af158c1 commit 729cd7f

File tree

1 file changed

+26
-14
lines changed

1 file changed

+26
-14
lines changed

nucleus/scene.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from dataclasses import dataclass
2+
from dataclasses import dataclass, field
33
from typing import Optional, Union, Dict, List, Set
44
from enum import Enum
55
from nucleus.constants import (
@@ -73,7 +73,7 @@ def to_json(self) -> str:
7373
@dataclass
7474
class Frame:
7575
index: Union[int, None] = None
76-
items: Dict[str, SceneDatasetItem] = {}
76+
items: Dict[str, SceneDatasetItem] = field(default_factory=dict)
7777

7878
def __post_init__(self):
7979
for key, value in self.items.items():
@@ -98,15 +98,15 @@ def to_payload(self) -> dict:
9898
@dataclass
9999
class Scene:
100100
reference_id: str
101-
frames: List[Frame] = []
101+
frames: List[Frame] = field(default_factory=list)
102102
metadata: Optional[dict] = None
103103

104104
def __post_init__(self):
105105
self.check_valid_frame_indices()
106-
if(all([frame.index is not None for frame in self.frames])):
106+
if all((frame.index is not None for frame in self.frames)):
107107
self.frames_dict = {frame.index: frame for frame in self.frames}
108108
else:
109-
self.frames_dict = {i: frame for i, frame in enumerate(self.frames)}
109+
self.frames_dict = dict(enumerate(self.frames))
110110

111111
# TODO: move validation to scene upload
112112
assert isinstance(self.frames, List), "frames must be a list"
@@ -120,9 +120,15 @@ def __post_init__(self):
120120
), "reference_id must be a string"
121121

122122
def check_valid_frame_indices(self):
123-
infer_from_list_position = all([frame.index is None for frame in self.frames])
124-
explicit_frame_order = all([frame.index is not None for frame in self.frames])
125-
assert infer_from_list_position or explicit_frame_order, "Must specify index explicitly for all frames or implicitly for all frames (inferred from list position)"
123+
infer_from_list_position = all(
124+
(frame.index is None for frame in self.frames)
125+
)
126+
explicit_frame_order = all(
127+
(frame.index is not None for frame in self.frames)
128+
)
129+
assert (
130+
infer_from_list_position or explicit_frame_order
131+
), "Must specify index explicitly for all frames or implicitly for all frames (inferred from list position)"
126132

127133
def add_item(self, item: SceneDatasetItem, index: int, sensor_name: str):
128134
if index not in self.frames_dict:
@@ -132,17 +138,23 @@ def add_item(self, item: SceneDatasetItem, index: int, sensor_name: str):
132138
self.frames_dict[index].items[sensor_name] = item
133139

134140
def add_frame(self, frame: Frame, update: bool = False):
135-
assert frame.index is not None, "Must specify index explicitly when calling add_frame"
136-
if frame.index not in self.frames_dict or frame.index in self.frames_dict and update:
141+
assert (
142+
frame.index is not None
143+
), "Must specify index explicitly when calling add_frame"
144+
if (
145+
frame.index not in self.frames_dict
146+
or frame.index in self.frames_dict
147+
and update
148+
):
137149
self.frames_dict[frame.index] = frame
138150

139151
def to_payload(self) -> dict:
140152
frames_payload = [frame.to_payload() for frame in self.frames]
141-
if len(frames_payload) > 0 and frames_payload[0].index is None:
142-
for i in range(len(frames_payload)):
143-
frames_payload[INDEX_KEY] = i
153+
if len(frames_payload) > 0 and frames_payload[0][INDEX_KEY] is None:
154+
for i, _ in enumerate(frames_payload):
155+
frames_payload[i][INDEX_KEY] = i
144156
else:
145-
frames_payload.sort(lambda x: x[INDEX_KEY])
157+
frames_payload.sort(key=lambda x: x[INDEX_KEY])
146158

147159
return {
148160
REFERENCE_ID_KEY: self.reference_id,

0 commit comments

Comments
 (0)