Skip to content

Commit 89c6b12

Browse files
Drew Kaul
authored and committed
refactor SceneDatasetItem to DatasetItem
1 parent 8d1f2d1 commit 89c6b12

File tree

6 files changed

+150
-152
lines changed

6 files changed

+150
-152
lines changed

nucleus/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -109,7 +109,7 @@
109109
UPDATE_KEY,
110110
)
111111
from .dataset import Dataset
112-
from .dataset_item import DatasetItem
112+
from .dataset_item import DatasetItem, CameraParams
113113
from .errors import (
114114
DatasetItemRetrievalError,
115115
ModelCreationError,
@@ -135,7 +135,7 @@
135135
)
136136
from .slice import Slice
137137
from .upload_response import UploadResponse
138-
from .scene import Scene, LidarScene, Frame, SceneDatasetItem, CameraParams
138+
from .scene import Frame, Scene, LidarScene
139139

140140
# pylint: disable=E1101
141141
# TODO: refactor to reduce this file to under 1000 lines.

nucleus/annotation.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
from dataclasses import dataclass
33
from enum import Enum
44
from typing import Dict, List, Optional, Sequence, Union
5-
from nucleus.dataset_item import is_local_path
5+
from urllib.parse import urlparse
66

77
from .constants import (
88
ANNOTATION_ID_KEY,
@@ -310,6 +310,10 @@ def to_payload(self) -> dict:
310310
}
311311

312312

313+
def is_local_path(path: str) -> bool:
314+
return urlparse(path).scheme not in {"https", "http", "s3", "gs"}
315+
316+
313317
def check_all_mask_paths_remote(
314318
annotations: Sequence[Union[Annotation]],
315319
):

nucleus/constants.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -64,6 +64,8 @@
6464
NEW_ITEMS = "new_items"
6565
NUCLEUS_ENDPOINT = "https://api.scale.com/v1/nucleus"
6666
ORIGINAL_IMAGE_URL_KEY = "original_image_url"
67+
POINTCLOUD_LOCATION_KEY = "pointcloud_location"
68+
POINTCLOUD_URL_KEY = "pointcloud_url"
6769
POSITION_KEY = "position"
6870
PREDICTIONS_IGNORED_KEY = "predictions_ignored"
6971
PREDICTIONS_PROCESSED_KEY = "predictions_processed"

nucleus/dataset_item.py

Lines changed: 125 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -2,36 +2,135 @@
22
import json
33
import os.path
44
from dataclasses import dataclass
5-
from typing import Optional, Sequence
6-
from urllib.parse import urlparse
5+
from typing import Optional, Sequence, Dict, Any
6+
from enum import Enum
77

8+
from .annotation import is_local_path, Point3D
89
from .constants import (
910
DATASET_ITEM_ID_KEY,
1011
IMAGE_URL_KEY,
1112
METADATA_KEY,
1213
ORIGINAL_IMAGE_URL_KEY,
1314
REFERENCE_ID_KEY,
15+
TYPE_KEY,
16+
URL_KEY,
17+
CAMERA_PARAMS_KEY,
18+
POINTCLOUD_URL_KEY,
19+
X_KEY,
20+
Y_KEY,
21+
Z_KEY,
22+
W_KEY,
23+
POSITION_KEY,
24+
HEADING_KEY,
25+
FX_KEY,
26+
FY_KEY,
27+
CX_KEY,
28+
CY_KEY,
1429
)
1530

1631

1732
@dataclass
18-
class DatasetItem:
33+
class Quaternion:
34+
x: float
35+
y: float
36+
z: float
37+
w: float
1938

20-
image_location: str
39+
@classmethod
40+
def from_json(cls, payload: Dict[str, float]):
41+
return cls(
42+
payload[X_KEY], payload[Y_KEY], payload[Z_KEY], payload[W_KEY]
43+
)
44+
45+
def to_payload(self) -> dict:
46+
return {
47+
X_KEY: self.x,
48+
Y_KEY: self.y,
49+
Z_KEY: self.z,
50+
W_KEY: self.w,
51+
}
52+
53+
54+
@dataclass
55+
class CameraParams:
56+
position: Point3D
57+
heading: Quaternion
58+
fx: float
59+
fy: float
60+
cx: float
61+
cy: float
62+
63+
@classmethod
64+
def from_json(cls, payload: Dict[str, Any]):
65+
return cls(
66+
Point3D.from_json(payload[POSITION_KEY]),
67+
Quaternion.from_json(payload[HEADING_KEY]),
68+
payload[FX_KEY],
69+
payload[FY_KEY],
70+
payload[CX_KEY],
71+
payload[CY_KEY],
72+
)
73+
74+
def to_payload(self) -> dict:
75+
return {
76+
POSITION_KEY: self.position.to_payload(),
77+
HEADING_KEY: self.heading.to_payload(),
78+
FX_KEY: self.fx,
79+
FY_KEY: self.fy,
80+
CX_KEY: self.cx,
81+
CY_KEY: self.cy,
82+
}
83+
84+
85+
class DatasetItemType(Enum):
86+
IMAGE = "image"
87+
POINTCLOUD = "pointcloud"
88+
89+
90+
@dataclass # pylint: disable=R0902
91+
class DatasetItem: # pylint: disable=R0902
92+
image_location: Optional[str] = None
2193
reference_id: Optional[str] = None
2294
item_id: Optional[str] = None
2395
metadata: Optional[dict] = None
96+
pointcloud_location: Optional[str] = None
2497

2598
def __post_init__(self):
2699
self.local = is_local_path(self.image_location)
100+
assert bool(self.image_location) != bool(
101+
self.pointcloud_location
102+
), "Must specify exactly one of the image_location, pointcloud_location parameters"
103+
self.type = (
104+
DatasetItemType.IMAGE
105+
if self.image_location
106+
else DatasetItemType.POINTCLOUD
107+
)
108+
camera_params = (
109+
self.metadata.get(CAMERA_PARAMS_KEY, None)
110+
if self.metadata
111+
else None
112+
)
113+
self.camera_params = (
114+
CameraParams.from_json(camera_params) if camera_params else None
115+
)
27116

28117
@classmethod
29-
def from_json(cls, payload: dict):
30-
url = payload.get(IMAGE_URL_KEY, "") or payload.get(
118+
def from_json(cls, payload: dict, is_scene=False):
119+
image_url = payload.get(IMAGE_URL_KEY, "") or payload.get(
31120
ORIGINAL_IMAGE_URL_KEY, ""
32121
)
122+
123+
if is_scene:
124+
return cls(
125+
image_location=image_url,
126+
pointcloud_location=payload.get(POINTCLOUD_URL_KEY, ""),
127+
reference_id=payload.get(REFERENCE_ID_KEY, None),
128+
item_id=payload.get(DATASET_ITEM_ID_KEY, None),
129+
metadata=payload.get(METADATA_KEY, {}),
130+
)
131+
33132
return cls(
34-
image_location=url,
133+
image_location=image_url,
35134
reference_id=payload.get(REFERENCE_ID_KEY, None),
36135
item_id=payload.get(DATASET_ITEM_ID_KEY, None),
37136
metadata=payload.get(METADATA_KEY, {}),
@@ -40,28 +139,39 @@ def from_json(cls, payload: dict):
40139
def local_file_exists(self):
41140
return os.path.isfile(self.image_location)
42141

43-
def to_payload(self) -> dict:
44-
payload = {
45-
IMAGE_URL_KEY: self.image_location,
142+
def to_payload(self, is_scene=False) -> dict:
143+
payload: Dict[str, Any] = {
46144
METADATA_KEY: self.metadata or {},
47145
}
146+
147+
if is_scene:
148+
if self.image_location:
149+
payload[URL_KEY] = self.image_location
150+
elif self.pointcloud_location:
151+
payload[URL_KEY] = self.pointcloud_location
152+
payload[TYPE_KEY] = self.type.value
153+
else:
154+
assert (
155+
self.image_location
156+
), "Must specify image_location for DatasetItems not in a Scene"
157+
payload[IMAGE_URL_KEY] = self.image_location
158+
48159
if self.reference_id:
49160
payload[REFERENCE_ID_KEY] = self.reference_id
50161
if self.item_id:
51162
payload[DATASET_ITEM_ID_KEY] = self.item_id
163+
if self.camera_params:
164+
payload[CAMERA_PARAMS_KEY] = self.camera_params
165+
52166
return payload
53167

54168
def to_json(self) -> str:
55169
return json.dumps(self.to_payload(), allow_nan=False)
56170

57171

58-
def is_local_path(path: str) -> bool:
59-
return urlparse(path).scheme not in {"https", "http", "s3", "gs"}
60-
61-
62172
def check_all_paths_remote(dataset_items: Sequence[DatasetItem]):
63173
for item in dataset_items:
64-
if is_local_path(item.image_location):
174+
if item.image_location and is_local_path(item.image_location):
65175
raise ValueError(
66176
f"All paths must be remote, but {item.image_location} is either "
67177
"local, or a remote URL type that is not supported."

0 commit comments

Comments
 (0)