Skip to content

Commit 830e225

Browse files
authored
Update results returned by slice.items() (#368)
1 parent adda9a9 commit 830e225

File tree

8 files changed

+97
-23
lines changed

8 files changed

+97
-23
lines changed

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,15 @@ All notable changes to the [Nucleus Python Client](https://github.com/scaleapi/n
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
## [0.14.25](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.14.25) - 2022-10-20
9+
10+
### Updated
11+
- Items of a slice can be retrieved by the Slice property `.items`
12+
- The type of items returned from `.items` is based on the slice `type`:
13+
- `slice.type == 'dataset_item'` => list of `DatasetItem` objects
14+
- `slice.type == 'object'` => list of `Annotation`/`Prediction` objects
15+
- `slice.type == 'scene'` => list of `Scene` objects
16+
817
## [0.14.24](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.14.24) - 2022-10-19
918

1019
### Fixed

nucleus/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
DATASET_ID_KEY = "dataset_id"
3939
DATASET_IS_SCENE_KEY = "is_scene"
4040
DATASET_ITEM_ID_KEY = "dataset_item_id"
41+
DATASET_ITEMS_KEY = "dataset_items"
4142
DATASET_LENGTH_KEY = "length"
4243
DATASET_MODEL_RUNS_KEY = "model_run_ids"
4344
DATASET_NAME_KEY = "name"

nucleus/dataset.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
BACKFILL_JOB_KEY,
3232
DATASET_ID_KEY,
3333
DATASET_IS_SCENE_KEY,
34+
DATASET_ITEMS_KEY,
3435
DEFAULT_ANNOTATION_UPDATE_MODE,
3536
EMBEDDING_DIMENSION_KEY,
3637
EMBEDDINGS_URL_KEY,
@@ -212,7 +213,8 @@ def items(self) -> List[DatasetItem]:
212213
if e.status_code == 503:
213214
e.message += "\nThe server timed out while trying to load your items. Please try iterating over dataset.items_generator() instead."
214215
raise e
215-
dataset_item_jsons = response.get("dataset_items", None)
216+
dataset_item_jsons = response.get(DATASET_ITEMS_KEY, None)
217+
216218
return [
217219
DatasetItem.from_json(item_json)
218220
for item_json in dataset_item_jsons

nucleus/scene.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ class Scene(ABC):
124124
reference_id: str
125125
frames: List[Frame] = field(default_factory=list)
126126
metadata: Optional[dict] = field(default_factory=dict)
127+
skip_validate: Optional[bool] = False
127128

128129
def __post_init__(self):
129130
self.sensors = set(
@@ -133,7 +134,8 @@ def __post_init__(self):
133134
if self.metadata is None:
134135
self.metadata = {}
135136

136-
self.validate()
137+
if not self.skip_validate:
138+
self.validate()
137139

138140
def __eq__(self, other):
139141
return all(
@@ -310,14 +312,15 @@ def validate_frames_dict(self):
310312
), "frames must be 0-indexed and continuous (no missing frames)"
311313

312314
@classmethod
313-
def from_json(cls, payload: dict):
315+
def from_json(cls, payload: dict, skip_validate: Optional[bool] = False):
314316
"""Instantiates scene object from schematized JSON dict payload."""
315317
frames_payload = payload.get(FRAMES_KEY, [])
316318
frames = [Frame.from_json(frame) for frame in frames_payload]
317319
return cls(
318320
reference_id=payload[REFERENCE_ID_KEY],
319321
frames=frames,
320322
metadata=payload.get(METADATA_KEY, {}),
323+
skip_validate=skip_validate,
321324
)
322325

323326
def to_payload(self) -> dict:

nucleus/slice.py

Lines changed: 70 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from nucleus.dataset_item import DatasetItem
1212
from nucleus.errors import NucleusAPIError
1313
from nucleus.job import AsyncJob
14+
from nucleus.prediction import from_json as prediction_from_json
15+
from nucleus.scene import Scene
1416
from nucleus.utils import (
1517
KeyErrorDict,
1618
convert_export_payload,
@@ -113,6 +115,7 @@ def __init__(self, slice_id: str, client):
113115
self._dataset_id = None
114116
self._created_at = None
115117
self._pending_job_count = None
118+
self._type = None
116119

117120
def __repr__(self):
118121
return f"Slice(slice_id='{self.id}', name={self._name}, dataset_id={self._dataset_id})"
@@ -182,6 +185,13 @@ def dataset_id(self):
182185
self._dataset_id = self.info()["dataset_id"]
183186
return self._dataset_id
184187

188+
@property
189+
def type(self):
190+
"""The type of the Slice."""
191+
if self._type is None:
192+
self._type = self.info()["type"]
193+
return self._type
194+
185195
def items_generator(self, page_size=100000):
186196
"""Generator yielding all dataset items in the dataset.
187197
@@ -209,28 +219,69 @@ def items_generator(self, page_size=100000):
209219
for item_json in json_generator:
210220
yield DatasetItem.from_json(item_json)
211221

222+
def dataset_items(self):
223+
"""Fetch all DatasetItems contained in the Slice.
224+
225+
We recommend using :meth:`Slice.items_generator` if the Slice has more than 200k items.
226+
227+
Returns: list of DatasetItem objects
228+
229+
"""
230+
try:
231+
response = self._client.make_request(
232+
{}, f"slice/{self.id}", requests_command=requests.get
233+
)
234+
except NucleusAPIError as e:
235+
if e.status_code == 503:
236+
e.message += "/n Your request timed out while trying to get all the items in the slice. Please try slice.items_generator() instead."
237+
raise e
238+
239+
dataset_item_jsons = response.get(ITEMS_KEY, [])
240+
return [
241+
DatasetItem.from_json(dataset_item_json)
242+
for dataset_item_json in dataset_item_jsons
243+
]
244+
212245
@property
213246
def items(self):
214-
"""All DatasetItems contained in the Slice.
247+
"""Fetch all items belonging to this slice, the type of items returned depends on the type of the slice.
248+
The type of the slice can be one of { dataset_item, object, scene }.
215249
216-
We recommend using :meth:`Slice.items_generator` if the Slice has more than 200k items.
217250
251+
Returns: List of DatasetItems for a `dataset_item` slice,
252+
list of Annotations/Predictions for an `object` slice,
253+
or a list of Scenes for a `scene` slice.
218254
"""
219255
try:
220-
dataset_item_jsons = self._client.make_request(
256+
response = self._client.make_request(
221257
{}, f"slice/{self.id}", requests_command=requests.get
222-
)[
223-
"dataset_items"
224-
] # Unfortunately, we didn't use a standard value here, so not using a constant for the key
225-
return [
226-
DatasetItem.from_json(dataset_item_json)
227-
for dataset_item_json in dataset_item_jsons
228-
]
258+
)
229259
except NucleusAPIError as e:
230260
if e.status_code == 503:
231261
e.message += "/n Your request timed out while trying to get all the items in the slice. Please try slice.items_generator() instead."
232262
raise e
233263

264+
items = response.get(ITEMS_KEY, [])
265+
266+
formatted_items = []
267+
for item in items:
268+
item_id_prefix = item["id"].split("_")[0]
269+
if item_id_prefix == "di":
270+
formatted_items.append(DatasetItem.from_json(item))
271+
elif item_id_prefix == "ann":
272+
formatted_items.append(Annotation.from_json(item))
273+
elif item_id_prefix == "pred":
274+
formatted_items.append(prediction_from_json(item))
275+
elif item_id_prefix == "scn":
276+
# here we skip validate since no frames for the scene is fetched
277+
formatted_items.append(
278+
Scene.from_json(item, skip_validate=True)
279+
)
280+
else:
281+
raise ValueError("Unknown prefix", item_id_prefix)
282+
283+
return formatted_items
284+
234285
def info(self) -> dict:
235286
"""Retrieves the name, slice_id, and dataset_id of the Slice.
236287
@@ -251,6 +302,11 @@ def info(self) -> dict:
251302
{}, f"slice/{self.id}/info", requests_command=requests.get
252303
)
253304
info.update(res)
305+
self._name = info["name"]
306+
self._dataset_id = info["dataset_id"]
307+
self._created_at = info["created_at"]
308+
self._pending_job_count = info["pending_job_count"]
309+
self._type = info["type"]
254310
return info
255311

256312
def append(
@@ -552,7 +608,10 @@ def check_annotations_are_in_slice(
552608
for annotation in annotations
553609
if annotation.reference_id is not None
554610
}.difference(
555-
{item_metadata["ref_id"] for item_metadata in slice_to_check.items}
611+
{
612+
item_metadata["ref_id"]
613+
for item_metadata in slice_to_check.dataset_items()
614+
}
556615
)
557616
if reference_ids_not_found_in_slice:
558617
annotations_are_in_slice = False

nucleus/validate/scenario_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from typing import List, Optional
99

1010
from ..connection import Connection
11-
from ..constants import NAME_KEY, SLICE_ID_KEY
11+
from ..constants import DATASET_ITEMS_KEY, NAME_KEY, SLICE_ID_KEY
1212
from ..dataset_item import DatasetItem
1313
from .constants import (
1414
EVAL_FUNCTION_ID_KEY,
@@ -27,8 +27,6 @@
2727
from .scenario_test_evaluation import ScenarioTestEvaluation
2828
from .scenario_test_metric import ScenarioTestMetric
2929

30-
DATASET_ITEMS_KEY = "dataset_items"
31-
3230

3331
@dataclass
3432
class ScenarioTest:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ exclude = '''
2121

2222
[tool.poetry]
2323
name = "scale-nucleus"
24-
version = "0.14.24"
24+
version = "0.14.25"
2525
description = "The official Python client library for Nucleus, the Data Platform for AI"
2626
license = "MIT"
2727
authors = ["Scale AI Nucleus Team <nucleusapi@scaleapi.com>"]

tests/test_slice.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def test_slice_create_and_delete_and_list(dataset: Dataset):
5353
assert slc.name == TEST_SLICE_NAME
5454
assert slc.dataset_id == dataset.id
5555

56-
assert {item.reference_id for item in slc.items} == {
56+
assert {item.reference_id for item in slc.dataset_items()} == {
5757
item.reference_id for item in ds_items[:2]
5858
}
5959

@@ -122,7 +122,7 @@ def test_slice_create_and_prediction_export(dataset, slc, model):
122122

123123
assert response
124124

125-
slice_reference_ids = [item.reference_id for item in slc.items]
125+
slice_reference_ids = [item.reference_id for item in slc.dataset_items()]
126126

127127
def get_expected_box_prediction(reference_id):
128128
for prediction in predictions:
@@ -156,7 +156,7 @@ def test_slice_append(dataset):
156156

157157
# Insert duplicate first item
158158
slc.append(reference_ids=[item.reference_id for item in ds_items[:3]])
159-
slice_items = slc.items
159+
slice_items = slc.dataset_items()
160160

161161
assert len(slice_items) == 3
162162

@@ -176,7 +176,7 @@ def test_slice_send_to_labeling(dataset):
176176
reference_ids=[ds_items[0].reference_id, ds_items[1].reference_id],
177177
)
178178

179-
items = slc.items
179+
items = slc.dataset_items()
180180
assert len(items) == 2
181181

182182
response = slc.send_to_labeling(TEST_PROJECT_ID)
@@ -210,7 +210,9 @@ def test_slice_dataset_item_iterator(dataset):
210210
name=TEST_SLICE_NAME + get_uuid(),
211211
reference_ids=[item.reference_id for item in all_items[:1]],
212212
)
213-
expected_items = {item.reference_id: item for item in test_slice.items}
213+
expected_items = {
214+
item.reference_id: item for item in test_slice.dataset_items()
215+
}
214216
actual_items = {
215217
item.reference_id: item
216218
for item in test_slice.items_generator(page_size=1)

0 commit comments

Comments
 (0)