
Commit 05240ec

ardila, Ubuntu, and drakejwong authored
Da item pagination (#263)
* Unit test and large scale test pass
* specific tests pass
* remove unwanted prints
* missing type annotation
* Update nucleus/constants.py
  Co-authored-by: Drake Wong <40375132+drakejwong@users.noreply.github.com>
* fix slice tests

Co-authored-by: Ubuntu <diego.ardila@scale.com>
Co-authored-by: Drake Wong <40375132+drakejwong@users.noreply.github.com>
1 parent b2a97a4 commit 05240ec

File tree: 9 files changed, +223 -177 lines changed


conftest.py

Lines changed: 9 additions & 5 deletions
@@ -1,10 +1,14 @@
 import os
+from typing import TYPE_CHECKING
 
 import pytest
 
 import nucleus
 from tests.helpers import TEST_DATASET_ITEMS, TEST_DATASET_NAME
 
+if TYPE_CHECKING:
+    from nucleus import NucleusClient
+
 assert "NUCLEUS_PYTEST_API_KEY" in os.environ, (
     "You must set the 'NUCLEUS_PYTEST_API_KEY' environment variable to a valid "
     "Nucleus API key to run the test suite"
@@ -20,12 +24,12 @@ def CLIENT():
 
 
 @pytest.fixture()
-def dataset(CLIENT):
-    ds = CLIENT.create_dataset(TEST_DATASET_NAME)
-    ds.append(TEST_DATASET_ITEMS)
-    yield ds
+def dataset(CLIENT: "NucleusClient"):
+    test_dataset = CLIENT.create_dataset(TEST_DATASET_NAME, is_scene=False)
+    test_dataset.append(TEST_DATASET_ITEMS)
+    yield test_dataset
 
-    CLIENT.delete_dataset(ds.id)
+    CLIENT.delete_dataset(test_dataset.id)
 
 
 @pytest.fixture()
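Note: the fixture above follows the create/yield/delete pattern, so every test gets a fresh dataset and cleanup runs even if the test fails. A minimal sketch of a test consuming it (hypothetical test, not part of this commit):

# Hypothetical test consuming the `dataset` fixture defined above.
# Pytest injects the fixture by parameter name; the code after `yield`
# (delete_dataset) runs as teardown whether the test passes or fails.
from tests.helpers import TEST_DATASET_ITEMS


def test_dataset_has_seeded_items(dataset):
    assert len(dataset.items) == len(TEST_DATASET_ITEMS)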

nucleus/constants.py

Lines changed: 4 additions & 0 deletions
@@ -75,6 +75,7 @@
 JOB_LAST_KNOWN_STATUS_KEY = "job_last_known_status"
 JOB_TYPE_KEY = "job_type"
 JOB_CREATION_TIME_KEY = "job_creation_time"
+LAST_PAGE = "lastPage"
 LABEL_KEY = "label"
 LABELS_KEY = "labels"
 MASK_URL_KEY = "mask_url"
@@ -87,6 +88,8 @@
 NUCLEUS_ENDPOINT = "https://api.scale.com/v1/nucleus"
 NUM_SENSORS_KEY = "num_sensors"
 ORIGINAL_IMAGE_URL_KEY = "original_image_url"
+PAGE_SIZE = "pageSize"
+PAGE_TOKEN = "pageToken"
 P1_KEY = "p1"
 P2_KEY = "p2"
 POINTCLOUD_KEY = "pointcloud"
@@ -97,6 +100,7 @@
 PREDICTIONS_PROCESSED_KEY = "predictions_processed"
 REFERENCE_IDS_KEY = "reference_ids"
 REFERENCE_ID_KEY = "reference_id"
+BACKEND_REFERENCE_ID_KEY = "ref_id"  # TODO(355762): Our backend returns this instead of the "proper" key sometimes.
 REQUEST_ID_KEY = "requestId"
 SCENES_KEY = "scenes"
 SERIALIZED_REQUEST_KEY = "serialized_request"
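Note: LAST_PAGE, PAGE_SIZE, and PAGE_TOKEN describe a cursor-pagination wire format. A rough sketch of the exchange these keys imply, with invented values (illustrative only, not captured API traffic):

# Illustrative request/response shapes for the itemsPage endpoints.
# All values here are made up; only the key names come from constants.py.
request_body = {"pageToken": None, "pageSize": 100000}  # None fetches page one

response_body = {
    "items": [{"ref_id": "img-001", "metadata": {}}],  # hypothetical item
    "pageToken": "opaque-cursor-xyz",  # cursor to send on the next request
    "lastPage": False,                 # True terminates the client loop
}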

nucleus/dataset.py

Lines changed: 47 additions & 19 deletions
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Union
 
 import requests
 
@@ -18,6 +18,7 @@
     convert_export_payload,
     format_dataset_item_response,
     format_prediction_response,
+    paginate_generator,
     serialize_and_write_to_presigned_url,
 )
 
@@ -32,6 +33,7 @@
     EMBEDDING_DIMENSION_KEY,
     EMBEDDINGS_URL_KEY,
     EXPORTED_ROWS,
+    ITEMS_KEY,
     KEEP_HISTORY_KEY,
     MESSAGE_KEY,
     NAME_KEY,
@@ -51,7 +53,7 @@
 )
 from .dataset_item_uploader import DatasetItemUploader
 from .deprecation_warning import deprecated
-from .errors import DatasetItemRetrievalError
+from .errors import NucleusAPIError
 from .metadata_manager import ExportMetadataType, MetadataManager
 from .payload_constructor import (
     construct_append_scenes_payload,
@@ -160,25 +162,51 @@ def size(self) -> int:
         dataset_size = DatasetSize.parse_obj(response)
         return dataset_size.count
 
+    def items_generator(self, page_size=100000) -> Iterable[DatasetItem]:
+        """Generator yielding all dataset items in the dataset.
+
+        ::
+
+            sum_example_field = 0
+            for item in dataset.items_generator():
+                sum_example_field += item.metadata["example_field"]
+
+        Args:
+            page_size (int, optional): Number of items to return per page. If you are
+                experiencing timeouts while using this generator, you can try lowering
+                the page size.
+
+        Yields:
+            an iterable of DatasetItem objects.
+        """
+        json_generator = paginate_generator(
+            client=self._client,
+            endpoint=f"dataset/{self.id}/itemsPage",
+            result_key=ITEMS_KEY,
+            page_size=page_size,
+        )
+        for item_json in json_generator:
+            yield DatasetItem.from_json(item_json)
+
     @property
     def items(self) -> List[DatasetItem]:
-        """List of all DatasetItem objects in the Dataset."""
-        response = self._client.make_request(
-            {}, f"dataset/{self.id}/datasetItems", requests.get
-        )
-        dataset_items = response.get("dataset_items", None)
-        error = response.get("error", None)
-        constructed_dataset_items = []
-        if dataset_items:
-            for item in dataset_items:
-                image_url = item.get("original_image_url")
-                metadata = item.get("metadata", None)
-                ref_id = item.get("ref_id", None)
-                dataset_item = DatasetItem(image_url, ref_id, metadata)
-                constructed_dataset_items.append(dataset_item)
-        elif error:
-            raise DatasetItemRetrievalError(message=error)
-        return constructed_dataset_items
+        """List of all DatasetItem objects in the Dataset.
+
+        For fetching more than 200k items see :meth:`NucleusDataset.items_generator`.
+        """
+        try:
+            response = self._client.make_request(
+                {}, f"dataset/{self.id}/datasetItems", requests.get
+            )
+        except NucleusAPIError as e:
+            if e.status_code == 503:
+                e.message += "\nThe server timed out while trying to load your items. Please try iterating over dataset.items_generator() instead."
+            raise e
+        dataset_item_jsons = response.get("dataset_items", None)
+        return [
+            DatasetItem.from_json(item_json)
+            for item_json in dataset_item_jsons
+        ]
 
     @property
     def scenes(self) -> List[ScenesListEntry]:
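Note: a short usage sketch of the new accessor (the API key and dataset ID are placeholders):

# Illustrative usage: stream a large dataset without holding every item
# in memory. "YOUR_API_KEY" and "ds_123" are placeholders.
import nucleus

client = nucleus.NucleusClient("YOUR_API_KEY")
dataset = client.get_dataset("ds_123")

count = 0
for item in dataset.items_generator(page_size=50000):  # lower on 503 timeouts
    count += 1
print(f"streamed {count} items")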

nucleus/dataset_item.py

Lines changed: 3 additions & 0 deletions
@@ -7,6 +7,7 @@
 
 from .annotation import Point3D, is_local_path
 from .constants import (
+    BACKEND_REFERENCE_ID_KEY,
     CAMERA_MODEL_KEY,
     CAMERA_PARAMS_KEY,
     CX_KEY,
@@ -290,6 +291,8 @@ def from_json(cls, payload: dict):
         image_url = payload.get(IMAGE_URL_KEY, None) or payload.get(
             ORIGINAL_IMAGE_URL_KEY, None
         )
+        if BACKEND_REFERENCE_ID_KEY in payload:
+            payload[REFERENCE_ID_KEY] = payload[BACKEND_REFERENCE_ID_KEY]
         return cls(
             image_location=image_url,
             pointcloud_location=payload.get(POINTCLOUD_URL_KEY, None),
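Note: with this shim, from_json accepts payloads keyed either as reference_id or the backend's legacy ref_id. A standalone re-implementation of just the normalization step, for illustration (not the library code itself):

# Standalone sketch of the ref_id -> reference_id normalization above.
REFERENCE_ID_KEY = "reference_id"
BACKEND_REFERENCE_ID_KEY = "ref_id"


def normalize_reference_id(payload: dict) -> dict:
    # Mirror the legacy backend key onto the canonical one.
    if BACKEND_REFERENCE_ID_KEY in payload:
        payload[REFERENCE_ID_KEY] = payload[BACKEND_REFERENCE_ID_KEY]
    return payload


assert normalize_reference_id({"ref_id": "abc"})["reference_id"] == "abc"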

nucleus/errors.py

Lines changed: 10 additions & 9 deletions
@@ -40,28 +40,29 @@ class NucleusAPIError(Exception):
     def __init__(
         self, endpoint, command, requests_response=None, aiohttp_response=None
     ):
-        message = f"Your client is on version {nucleus_client_version}. If you have not recently done so, please make sure you have updated to the latest version of the client by running pip install --upgrade scale-nucleus\n"
+        self.message = f"Your client is on version {nucleus_client_version}. If you have not recently done so, please make sure you have updated to the latest version of the client by running pip install --upgrade scale-nucleus\n"
         if requests_response is not None:
-            message += f"Tried to {command.__name__} {endpoint}, but received {requests_response.status_code}: {requests_response.reason}."
+            self.status_code = requests_response.status_code
+            self.message += f"Tried to {command.__name__} {endpoint}, but received {requests_response.status_code}: {requests_response.reason}."
             if hasattr(requests_response, "text"):
                 if requests_response.text:
-                    message += (
+                    self.message += (
                         f"\nThe detailed error is:\n{requests_response.text}"
                     )
 
         if aiohttp_response is not None:
             status, reason, data = aiohttp_response
-            message += f"Tried to {command.__name__} {endpoint}, but received {status}: {reason}."
+            self.status_code = status
+            self.message += f"Tried to {command.__name__} {endpoint}, but received {status}: {reason}."
             if data:
-                message += f"\nThe detailed error is:\n{data}"
+                self.message += f"\nThe detailed error is:\n{data}"
 
         if any(
-            infra_flake_message in message
+            infra_flake_message in self.message
            for infra_flake_message in INFRA_FLAKE_MESSAGES
         ):
-            message += "\n This likely indicates temporary downtime of the API, please try again in a minute or two"
-
-        super().__init__(message)
+            self.message += "\n This likely indicates temporary downtime of the API, please try again in a minute or two"
+        super().__init__(self.message)
 
 
 class NoAPIKey(Exception):
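Note: status_code is only assigned when a requests or aiohttp response was passed in, so it may be absent on errors raised without one. A hedged sketch of a defensive catch site (the surrounding dataset object is assumed to exist):

# Sketch of reading the new attribute defensively. `dataset` is assumed;
# getattr avoids an AttributeError when no response object was attached.
from nucleus.errors import NucleusAPIError

try:
    all_items = dataset.items
except NucleusAPIError as e:
    if getattr(e, "status_code", None) == 503:
        all_items = list(dataset.items_generator(page_size=10000))
    else:
        raise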

nucleus/slice.py

Lines changed: 47 additions & 27 deletions
@@ -4,13 +4,15 @@
 import requests
 
 from nucleus.annotation import Annotation
-from nucleus.constants import EXPORTED_ROWS
+from nucleus.constants import EXPORTED_ROWS, ITEMS_KEY
 from nucleus.dataset_item import DatasetItem
+from nucleus.errors import NucleusAPIError
 from nucleus.job import AsyncJob
 from nucleus.utils import (
     KeyErrorDict,
     convert_export_payload,
     format_dataset_item_response,
+    paginate_generator,
 )
 
 
@@ -57,30 +59,6 @@ def __eq__(self, other):
             return True
         return False
 
-    def _fetch_all(self) -> dict:
-        """Retrieves info and all items of the Slice.
-
-        Returns:
-            A dict mapping keys to the corresponding info retrieved.
-            ::
-
-                {
-                    "name": Union[str, int],
-                    "slice_id": str,
-                    "dataset_id": str,
-                    "dataset_items": List[{
-                        "id": str,
-                        "metadata": Dict[str, Union[str, int, float]],
-                        "ref_id": str,
-                        "original_image_url": str
-                    }]
-                }
-        """
-        response = self._client.make_request(
-            {}, f"slice/{self.id}", requests_command=requests.get
-        )
-        return response
-
     @property
     def slice_id(self):
         warnings.warn(
@@ -103,10 +81,52 @@ def dataset_id(self):
             self._dataset_id = self.info()["dataset_id"]
         return self._dataset_id
 
+    def items_generator(self, page_size=100000):
+        """Generator yielding all dataset items in the slice.
+
+        ::
+
+            sum_example_field = 0
+            for item in slice.items_generator():
+                sum_example_field += item.metadata["example_field"]
+
+        Args:
+            page_size (int, optional): Number of items to return per page. If you are
+                experiencing timeouts while using this generator, you can try lowering
+                the page size.
+
+        Yields:
+            an iterable of DatasetItem objects.
+        """
+        json_generator = paginate_generator(
+            client=self._client,
+            endpoint=f"slice/{self.id}/itemsPage",
+            result_key=ITEMS_KEY,
+            page_size=page_size,
+        )
+        for item_json in json_generator:
+            yield DatasetItem.from_json(item_json)
+
     @property
     def items(self):
-        """All DatasetItems contained in the Slice."""
-        return self._fetch_all()["dataset_items"]
+        """All DatasetItems contained in the Slice.
+
+        For fetching more than 200k items see :meth:`Slice.items_generator`.
+        """
+        try:
+            dataset_item_jsons = self._client.make_request(
+                {}, f"slice/{self.id}", requests_command=requests.get
+            )[
+                "dataset_items"
+            ]  # Unfortunately, we didn't use a standard value here, so not using a constant for the key
+            return [
+                DatasetItem.from_json(dataset_item_json)
+                for dataset_item_json in dataset_item_jsons
+            ]
+        except NucleusAPIError as e:
+            if e.status_code == 503:
+                e.message += "\nYour request timed out while trying to get all the items in the slice. Please try slice.items_generator() instead."
+            raise e
 
     def info(self) -> dict:
         """Retrieves the name, slice_id, and dataset_id of the Slice.

nucleus/utils.py

Lines changed: 32 additions & 1 deletion
@@ -4,7 +4,7 @@
 import json
 import uuid
 from collections import defaultdict
-from typing import IO, Dict, List, Sequence, Type, Union
+from typing import IO, TYPE_CHECKING, Dict, List, Sequence, Type, Union
 
 import requests
 from requests.models import HTTPError
@@ -19,6 +19,7 @@
     PolygonAnnotation,
     SegmentationAnnotation,
 )
+from nucleus.errors import NucleusAPIError
 
 from .constants import (
     ANNOTATION_TYPES,
@@ -27,8 +28,11 @@
     CATEGORY_TYPE,
     CUBOID_TYPE,
     ITEM_KEY,
+    LAST_PAGE,
     LINE_TYPE,
     MULTICATEGORY_TYPE,
+    PAGE_SIZE,
+    PAGE_TOKEN,
     POLYGON_TYPE,
     REFERENCE_ID_KEY,
     SEGMENTATION_TYPE,
@@ -50,6 +54,9 @@
     '\\\\"': '"',
 }
 
+if TYPE_CHECKING:
+    from . import NucleusClient
+
 
 class KeyErrorDict(dict):
     """Wrapper for response dicts with deprecated keys.
@@ -292,3 +299,27 @@ def replace_double_slashes(s: str) -> str:
     for key, val in STRING_REPLACEMENTS.items():
         s = s.replace(key, val)
     return s
+
+
+def paginate_generator(
+    client: "NucleusClient",
+    endpoint: str,
+    result_key: str,
+    page_size: int = 100000,
+):
+    last_page = False
+    page_token = None
+    while not last_page:
+        try:
+            response = client.make_request(
+                {PAGE_TOKEN: page_token, PAGE_SIZE: page_size},
+                endpoint,
+                requests.post,
+            )
+        except NucleusAPIError as e:
+            if e.status_code == 503:
+                e.message += f"\nYour request timed out while trying to get a page size of {page_size}. Try lowering the page_size."
+            raise e
+        page_token, last_page = response[PAGE_TOKEN], response[LAST_PAGE]
+        for json_value in response[result_key]:
+            yield json_value
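Note: the contract of paginate_generator — send pageToken/pageSize, read back items plus the next token and a lastPage flag — is easy to exercise against a fake client. A self-contained sketch of the same control flow (FakeClient and its canned pages are invented for illustration):

# Fake client demonstrating the cursor protocol paginate_generator expects.
class FakeClient:
    def __init__(self):
        self.pages = [
            {"items": [1, 2], "pageToken": "t1", "lastPage": False},
            {"items": [3], "pageToken": None, "lastPage": True},
        ]
        self.calls = 0

    def make_request(self, payload, endpoint, command=None):
        page = self.pages[self.calls]
        self.calls += 1
        return page


def paginate(client, endpoint, result_key, page_size=100000):
    # Same loop shape as paginate_generator above, minus the 503 handling.
    last_page, page_token = False, None
    while not last_page:
        response = client.make_request(
            {"pageToken": page_token, "pageSize": page_size}, endpoint
        )
        page_token, last_page = response["pageToken"], response["lastPage"]
        yield from response[result_key]


assert list(paginate(FakeClient(), "fake/endpoint", "items")) == [1, 2, 3]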
