Skip to content

Commit 58b83ab

Browse files
authored
Merge pull request #130 from scaleapi/da-get-scene
Get Scene API
2 parents 0e89988 + 9594b37 commit 58b83ab

File tree

5 files changed

+49
-16
lines changed

5 files changed

+49
-16
lines changed

nucleus/dataset.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
check_all_paths_remote,
3333
check_for_duplicate_reference_ids,
3434
)
35-
from .scene import LidarScene, check_all_scene_paths_remote
35+
from .scene import LidarScene, Scene, check_all_scene_paths_remote
3636
from .payload_constructor import (
3737
construct_append_scenes_payload,
3838
construct_model_run_creation_payload,
@@ -511,3 +511,17 @@ def delete_annotations(
511511
self.id, reference_ids, keep_history
512512
)
513513
return AsyncJob.from_json(response, self._client)
514+
515+
def get_scene(self, reference_id) -> Scene:
516+
"""Returns a scene by reference id
517+
518+
Returns:
519+
a Scene object representing all dataset items organized into frames
520+
"""
521+
return Scene.from_json(
522+
self._client.make_request(
523+
payload=None,
524+
route=f"dataset/{self.id}/scene/{reference_id}",
525+
requests_command=requests.get,
526+
)
527+
)

nucleus/dataset_item.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -122,24 +122,16 @@ def __post_init__(self):
122122
)
123123

124124
@classmethod
125-
def from_json(cls, payload: dict, is_scene=False):
125+
def from_json(cls, payload: dict):
126126
image_url = payload.get(IMAGE_URL_KEY, None) or payload.get(
127127
ORIGINAL_IMAGE_URL_KEY, None
128128
)
129-
130-
if is_scene:
131-
return cls(
132-
image_location=image_url,
133-
pointcloud_location=payload.get(POINTCLOUD_URL_KEY, None),
134-
reference_id=payload.get(REFERENCE_ID_KEY, None),
135-
metadata=payload.get(METADATA_KEY, {}),
136-
)
137-
138129
return cls(
139130
image_location=image_url,
131+
pointcloud_location=payload.get(POINTCLOUD_URL_KEY, None),
140132
reference_id=payload.get(REFERENCE_ID_KEY, None),
141133
metadata=payload.get(METADATA_KEY, {}),
142-
upload_to_scale=payload.get(UPLOAD_TO_SCALE_KEY, None),
134+
upload_to_scale=payload.get(UPLOAD_TO_SCALE_KEY, True),
143135
)
144136

145137
def local_file_exists(self):

nucleus/scene.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,14 @@ def __post_init__(self):
3131
def __repr__(self) -> str:
3232
return f"Frame(items={self.items})"
3333

34+
def __eq__(self, other):
35+
for key, value in self.items.items():
36+
if key not in other.items:
37+
return False
38+
if value != other.items[key]:
39+
return False
40+
return True
41+
3442
def add_item(self, item: DatasetItem, sensor_name: str):
3543
self.items[sensor_name] = item
3644

@@ -50,7 +58,7 @@ def get_sensors(self):
5058
@classmethod
5159
def from_json(cls, payload: dict):
5260
items = {
53-
sensor: DatasetItem.from_json(item, is_scene=True)
61+
sensor: DatasetItem.from_json(item)
5462
for sensor, item in payload.items()
5563
}
5664
return cls(**items)
@@ -66,13 +74,24 @@ def to_payload(self) -> dict:
6674
class Scene(ABC):
6775
reference_id: str
6876
frames: List[Frame] = field(default_factory=list)
69-
metadata: Optional[dict] = None
77+
metadata: Optional[dict] = field(default_factory=dict)
7078

7179
def __post_init__(self):
7280
self.sensors = set(
7381
flatten([frame.get_sensors() for frame in self.frames])
7482
)
7583
self.frames_dict = dict(enumerate(self.frames))
84+
if self.metadata is None:
85+
self.metadata = {}
86+
87+
def __eq__(self, other):
88+
return all(
89+
[
90+
self.reference_id == other.reference_id,
91+
self.frames == other.frames,
92+
self.metadata == other.metadata,
93+
]
94+
)
7695

7796
@property
7897
def length(self) -> int:
@@ -168,7 +187,7 @@ def from_json(cls, payload: dict):
168187
return cls(
169188
reference_id=payload[REFERENCE_ID_KEY],
170189
frames=frames,
171-
metadata=payload.get(METADATA_KEY, None),
190+
metadata=payload.get(METADATA_KEY, {}),
172191
)
173192

174193
def to_payload(self) -> dict:

tests/test_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ def test_dataset_append_async_with_1_bad_url(dataset: Dataset):
293293
"PayloadUrl": "",
294294
"final_error": (
295295
"One or more of the images you attempted to upload did not process"
296-
" correctly. Please see the status for an overview and the errors for "
296+
" correctly. Please see the status for an overview and the errors (job.errors()) for "
297297
"more detailed messages."
298298
),
299299
"image_upload_step": {"errored": 1, "pending": 0, "completed": 4},

tests/test_scene.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
import pytest
23
from nucleus.constants import (
34
ANNOTATIONS_KEY,
@@ -260,6 +261,13 @@ def test_scene_upload_sync(dataset):
260261

261262
response = dataset.append(scenes, update=update)
262263

264+
first_scene = dataset.get_scene(scenes[0].reference_id)
265+
266+
assert first_scene == scenes[0]
267+
first_scene_modified = copy.deepcopy(first_scene)
268+
first_scene_modified.reference_id = "WRONG!"
269+
assert first_scene_modified != scenes[0]
270+
263271
assert response["dataset_id"] == dataset.id
264272
assert response["new_scenes"] == len(scenes)
265273

0 commit comments

Comments
 (0)