Skip to content

Commit 5ba43f2

Browse files
author
Ubuntu
committed
Added test for new endpoint, and some eq methods to make that test easier
1 parent 0e89988 commit 5ba43f2

File tree

4 files changed

+44
-15
lines changed

4 files changed

+44
-15
lines changed

nucleus/dataset.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
check_all_paths_remote,
3333
check_for_duplicate_reference_ids,
3434
)
35-
from .scene import LidarScene, check_all_scene_paths_remote
35+
from .scene import LidarScene, Scene, check_all_scene_paths_remote
3636
from .payload_constructor import (
3737
construct_append_scenes_payload,
3838
construct_model_run_creation_payload,
@@ -511,3 +511,17 @@ def delete_annotations(
511511
self.id, reference_ids, keep_history
512512
)
513513
return AsyncJob.from_json(response, self._client)
514+
515+
def get_scene(self, reference_id) -> Scene:
516+
"""Returns a scene by reference id
517+
518+
Returns:
519+
a Scene object representing all dataset items organized into frames
520+
"""
521+
return Scene.from_json(
522+
self._client.make_request(
523+
payload=None,
524+
route=f"dataset/{self.id}/scene/{reference_id}",
525+
requests_command=requests.get,
526+
)
527+
)

nucleus/dataset_item.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -122,24 +122,16 @@ def __post_init__(self):
122122
)
123123

124124
@classmethod
125-
def from_json(cls, payload: dict, is_scene=False):
125+
def from_json(cls, payload: dict):
126126
image_url = payload.get(IMAGE_URL_KEY, None) or payload.get(
127127
ORIGINAL_IMAGE_URL_KEY, None
128128
)
129-
130-
if is_scene:
131-
return cls(
132-
image_location=image_url,
133-
pointcloud_location=payload.get(POINTCLOUD_URL_KEY, None),
134-
reference_id=payload.get(REFERENCE_ID_KEY, None),
135-
metadata=payload.get(METADATA_KEY, {}),
136-
)
137-
138129
return cls(
139130
image_location=image_url,
131+
pointcloud_location=payload.get(POINTCLOUD_URL_KEY, None),
140132
reference_id=payload.get(REFERENCE_ID_KEY, None),
141133
metadata=payload.get(METADATA_KEY, {}),
142-
upload_to_scale=payload.get(UPLOAD_TO_SCALE_KEY, None),
134+
upload_to_scale=payload.get(UPLOAD_TO_SCALE_KEY, True),
143135
)
144136

145137
def local_file_exists(self):

nucleus/scene.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,14 @@ def __post_init__(self):
3131
def __repr__(self) -> str:
3232
return f"Frame(items={self.items})"
3333

34+
def __eq__(self, other):
35+
for key, value in self.items.items():
36+
if key not in other.items:
37+
return False
38+
if value != other.items[key]:
39+
return False
40+
return True
41+
3442
def add_item(self, item: DatasetItem, sensor_name: str):
3543
self.items[sensor_name] = item
3644

@@ -50,7 +58,7 @@ def get_sensors(self):
5058
@classmethod
5159
def from_json(cls, payload: dict):
5260
items = {
53-
sensor: DatasetItem.from_json(item, is_scene=True)
61+
sensor: DatasetItem.from_json(item)
5462
for sensor, item in payload.items()
5563
}
5664
return cls(**items)
@@ -66,13 +74,24 @@ def to_payload(self) -> dict:
6674
class Scene(ABC):
6775
reference_id: str
6876
frames: List[Frame] = field(default_factory=list)
69-
metadata: Optional[dict] = None
77+
metadata: Optional[dict] = field(default_factory=dict)
7078

7179
def __post_init__(self):
7280
self.sensors = set(
7381
flatten([frame.get_sensors() for frame in self.frames])
7482
)
7583
self.frames_dict = dict(enumerate(self.frames))
84+
if self.metadata is None:
85+
self.metadata = {}
86+
87+
def __eq__(self, other):
88+
return all(
89+
[
90+
self.reference_id == other.reference_id,
91+
self.frames == other.frames,
92+
self.metadata == other.metadata,
93+
]
94+
)
7695

7796
@property
7897
def length(self) -> int:
@@ -168,7 +187,7 @@ def from_json(cls, payload: dict):
168187
return cls(
169188
reference_id=payload[REFERENCE_ID_KEY],
170189
frames=frames,
171-
metadata=payload.get(METADATA_KEY, None),
190+
metadata=payload.get(METADATA_KEY, {}),
172191
)
173192

174193
def to_payload(self) -> dict:

tests/test_scene.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,10 @@ def test_scene_upload_sync(dataset):
260260

261261
response = dataset.append(scenes, update=update)
262262

263+
first_scene = dataset.get_scene(scenes[0].reference_id)
264+
265+
assert first_scene == scenes[0]
266+
263267
assert response["dataset_id"] == dataset.id
264268
assert response["new_scenes"] == len(scenes)
265269

0 commit comments

Comments
 (0)