Skip to content

Commit f2dbd31

Browse files
authored
allow for concurrent task fetches (#411)
1 parent 2044e72 commit f2dbd31

File tree

6 files changed

+188
-59
lines changed

6 files changed

+188
-59
lines changed

CHANGELOG.md

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,23 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

88

9+
## [0.16.11](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.16.11) - 2023-11-22
10+
11+
### Added
12+
13+
- Method to allow for concurrent task fetches for pointcloud data
14+
15+
Example:
16+
```python
17+
>>> task_ids = ['task_1', 'task_2']
18+
>>> resp = client.download_pointcloud_tasks(task_ids=task_ids, frame_num=1)
19+
>>> resp
20+
{
21+
'task_1': [Point3D(x=5, y=10.7, z=-2.3), ...],
22+
'task_2': [Point3D(x=1.3, y=11.1, z=1.5), ...],
23+
}
24+
```
25+
926
## [0.16.10](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.16.10) - 2023-11-22
1027

1128
Allow creating a dataset by crawling all images in a directory, recursively. Also supports privacy mode datasets.
@@ -53,7 +70,7 @@ This would create a dataset `my-dataset`, and when opened in Nucleus, the images
5370

5471
### Fixes
5572

56-
- Minor fixes to video scene upload on privacy moce
73+
- Minor fixes to video scene upload on privacy mode
5774

5875
## [0.16.8](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.16.8) - 2023-11-16
5976

nucleus/__init__.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
SegmentationAnnotation,
7171
)
7272
from .async_job import AsyncJob, EmbeddingsExportJob
73+
from .async_utils import make_multiple_requests_concurrently
7374
from .camera_params import CameraParams
7475
from .connection import Connection
7576
from .constants import (
@@ -1065,6 +1066,50 @@ def download_pointcloud_task(
10651066

10661067
return [Point3D.from_json(pt) for pt in points]
10671068

1069+
def download_pointcloud_tasks(
    self, task_ids: List[str], frame_num: int
) -> Dict[str, List[Union[Point3D, LidarPoint]]]:
    """
    Download the lidar point cloud data for a given set of tasks and frame number.

    Parameters:
        task_ids: list of task ids to fetch data from
        frame_num: download point cloud for this particular frame

    Returns:
        A dictionary from task_id to its list of points. Points are
        LidarPoint objects when the payload carries intensity data
        (I_KEY present), plain Point3D objects otherwise.

    Raises:
        Exception: if any task's response contains no points.
    """
    endpoints = [
        f"task/{task_id}/frame/{frame_num}" for task_id in task_ids
    ]
    progressbar = self.tqdm_bar(
        total=len(endpoints),
        desc="Downloading pointcloud tasks",
    )
    results = make_multiple_requests_concurrently(
        client=self,
        requests=endpoints,
        route=None,
        progressbar=progressbar,
    )
    resp: Dict[str, List[Union[Point3D, LidarPoint]]] = {}

    for req, data in results:
        # Endpoint shape is task/<task id>/frame/<frame num>.
        task_id = req.split("/")[1]
        points = data.get(POINTS_KEY, None)
        if points is None or len(points) == 0:
            raise Exception("Response has invalid payload")

        # BUG FIX: the original assigned the Point3D list unconditionally
        # after the I_KEY check, clobbering the LidarPoint parse. Use
        # if/else so intensity data is actually preserved.
        sample_point = points[0]
        if I_KEY in sample_point:
            resp[task_id] = [LidarPoint.from_json(pt) for pt in points]
        else:
            resp[task_id] = [Point3D.from_json(pt) for pt in points]

    return resp
1112+
10681113
@deprecated("Prefer calling Dataset.create_custom_index instead.")
10691114
def create_custom_index(
10701115
self, dataset_id: str, embeddings_urls: list, embedding_dim: int

nucleus/annotation_uploader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from nucleus.async_utils import (
77
FileFormField,
88
FormDataContextHandler,
9-
make_many_form_data_requests_concurrently,
9+
make_multiple_requests_concurrently,
1010
)
1111
from nucleus.constants import MASK_TYPE, SERIALIZED_REQUEST_KEY
1212
from nucleus.errors import DuplicateIDError
@@ -150,7 +150,7 @@ def make_batched_file_form_data_requests(
150150
desc="Local segmentation mask file batches",
151151
)
152152

153-
return make_many_form_data_requests_concurrently(
153+
return make_multiple_requests_concurrently(
154154
client=self._client,
155155
requests=requests,
156156
route=self._route,

nucleus/async_utils.py

Lines changed: 120 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
11
import asyncio
22
import time
33
from dataclasses import dataclass
4-
from typing import TYPE_CHECKING, BinaryIO, Callable, Sequence, Tuple
4+
from typing import (
5+
TYPE_CHECKING,
6+
BinaryIO,
7+
Callable,
8+
Optional,
9+
Sequence,
10+
Tuple,
11+
Union,
12+
)
513

614
import aiohttp
715
import nest_asyncio
@@ -84,56 +92,71 @@ def get_event_loop():
8492
return loop
8593

8694

87-
def make_many_form_data_requests_concurrently(
95+
def make_multiple_requests_concurrently(
    client: "NucleusClient",
    requests: Sequence[Union[FormDataContextHandler, str]],
    route: Optional[str],
    progressbar: tqdm,
):
    """
    Makes an async post request with form data to a Nucleus endpoint.

    Args:
        client: The client to use for the request.
        requests: a list of requests to make. This list either comprises a string of endpoints to request,
            or a list of FormDataContextHandler object which will handle generating form data, and opening/closing files for each request.
        route: A route is required when requests are for Form Data Post requests
        progressbar: A tqdm progress bar to use for showing progress to the user.
    """
    event_loop = get_event_loop()
    pending = _request_helper(client, requests, route, progressbar)
    return event_loop.run_until_complete(pending)
)
107115

108116

109-
async def form_data_request_helper(
117+
async def _request_helper(
    client: "NucleusClient",
    requests: Sequence[Union[FormDataContextHandler, str]],
    route: Optional[str],
    progressbar: tqdm,
):
    """
    Makes async requests to a Nucleus endpoint over one shared session.

    Args:
        client: The client to use for the requests.
        requests: a list of requests to make. Each entry is either an
            endpoint string (issued as a GET via _make_request) or a
            FormDataContextHandler, which handles generating form data and
            opening/closing files for each request (issued as a POST via
            _post_form_data).
        route: route for form-data POST requests; may be None when every
            request is an endpoint string.
        progressbar: tqdm progress bar updated as requests complete.

    Returns:
        Gathered results, in the same order as ``requests``.

    Raises:
        ValueError: if a FormDataContextHandler request is given without a route.
    """
    async with aiohttp.ClientSession() as session:
        tasks = []
        for request in requests:
            if isinstance(request, FormDataContextHandler):
                # `assert` is stripped under `python -O`, so raise
                # explicitly to guarantee this misuse is always caught.
                if not route:
                    raise ValueError(
                        "A route must be specified for FormDataContextHandler requests"
                    )
                coro = _post_form_data(
                    client=client,
                    request=request,
                    route=route,
                    session=session,
                    progressbar=progressbar,
                )
            else:
                coro = _make_request(
                    client=client,
                    request=request,
                    session=session,
                    progressbar=progressbar,
                )
            tasks.append(asyncio.ensure_future(coro))

        return await asyncio.gather(*tasks)
138161

139162

@@ -165,36 +188,80 @@ async def _post_form_data(
165188
auth=aiohttp.BasicAuth(client.api_key, ""),
166189
timeout=DEFAULT_NETWORK_TIMEOUT_SEC,
167190
) as response:
168-
logger.info(
169-
"API request has response code %s", response.status
191+
data = await _parse_async_response(
192+
endpoint, session, response, sleep_time
170193
)
171-
172-
try:
173-
data = await response.json()
174-
except aiohttp.client_exceptions.ContentTypeError:
175-
# In case of 404, the server returns text
176-
data = await response.text()
177-
if (
178-
response.status in RetryStrategy.statuses
179-
and sleep_time != -1
180-
):
181-
time.sleep(sleep_time)
194+
if data is None:
182195
continue
183196

184-
if response.status == 503:
185-
raise TimeoutError(
186-
"The request to upload your max is timing out, please lower local_files_per_upload_request in your api call."
187-
)
188-
189-
if not response.ok:
190-
raise NucleusAPIError(
191-
endpoint,
192-
session.post,
193-
aiohttp_response=(
194-
response.status,
195-
response.reason,
196-
data,
197-
),
198-
)
199197
progressbar.update(1)
200198
return data
199+
200+
201+
async def _make_request(
    client: "NucleusClient",
    request: str,
    session: aiohttp.ClientSession,
    progressbar: tqdm,
):
    """
    Makes an async GET request to a Nucleus endpoint, with retries.

    Args:
        client: The client to use for the request.
        request: The endpoint path to request, appended to client.endpoint.
        session: The session to use for the request.
        progressbar: tqdm progress bar updated once the request succeeds.

    Returns:
        A tuple (endpoint request string, response from endpoint), or an
        implicit None if every retry attempt was consumed without a
        terminal response.
    """
    endpoint = f"{client.endpoint}/{request}"
    logger.info("GET %s", endpoint)

    async with UPLOAD_SEMAPHORE:
        # sleep_time == -1 marks the final attempt (no further retry).
        for sleep_time in RetryStrategy.sleep_times() + [-1]:
            async with session.get(
                endpoint,
                auth=aiohttp.BasicAuth(client.api_key, ""),
                timeout=DEFAULT_NETWORK_TIMEOUT_SEC,
            ) as response:
                data = await _parse_async_response(
                    endpoint, session, response, sleep_time
                )
                if data is None:
                    # Retryable status; _parse_async_response handled the backoff.
                    continue

                progressbar.update(1)
                return (request, data)
237+
238+
239+
async def _parse_async_response(endpoint, session, response, sleep_time):
    """
    Parses an aiohttp response, returning None when the caller should retry.

    Args:
        endpoint: full URL that was requested (used in error reporting).
        session: the aiohttp session (used in error reporting).
        response: the aiohttp response to parse.
        sleep_time: seconds to sleep before the caller retries; -1 means
            this was the final attempt, so retryable statuses fall through
            to the error handling below.

    Returns:
        The parsed response payload, or None when the caller should retry.

    Raises:
        TimeoutError: on a 503 response.
        NucleusAPIError: on any other non-OK response once retries are exhausted.
    """
    logger.info("API request has response code %s", response.status)

    try:
        data = await response.json()
    except aiohttp.client_exceptions.ContentTypeError:
        # In case of 404, the server returns text
        data = await response.text()

    if response.status in RetryStrategy.statuses and sleep_time != -1:
        # BUG FIX: time.sleep() here blocked the event loop, stalling every
        # in-flight concurrent request during the backoff. asyncio.sleep
        # yields control so other tasks keep running.
        await asyncio.sleep(sleep_time)
        return None

    if response.status == 503:
        raise TimeoutError(
            "The request to upload your max is timing out, please lower local_files_per_upload_request in your api call."
        )

    if not response.ok:
        raise NucleusAPIError(
            endpoint,
            # NOTE(review): session.get is reported even when the response
            # came from a form-data POST — confirm intended.
            session.get,
            aiohttp_response=(
                response.status,
                response.reason,
                data,
            ),
        )

    return data

nucleus/dataset_item_uploader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
FileFormData,
1515
FileFormField,
1616
FormDataContextHandler,
17-
make_many_form_data_requests_concurrently,
17+
make_multiple_requests_concurrently,
1818
)
1919

2020
from .constants import DATASET_ID_KEY, IMAGE_KEY, ITEMS_KEY, UPDATE_KEY
@@ -125,7 +125,7 @@ def _process_append_requests_local(
125125
desc=f"Uploading {len(items)} items in {len(requests)} batches",
126126
)
127127

128-
return make_many_form_data_requests_concurrently(
128+
return make_multiple_requests_concurrently(
129129
self._client,
130130
requests,
131131
f"dataset/{dataset_id}/append",

tests/test_video_scene.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def test_video_scene_upload_async(dataset_video_scene):
158158

159159

160160
@pytest.mark.integration
161-
@pytest.mark.xfail(reason="SFN doesn't throw on validation error - 05.10.2023")
161+
@pytest.mark.skip(reason="SFN doesn't throw on validation error - 05.10.2023")
162162
def test_repeat_refid_video_scene_upload_async(dataset_video_scene):
163163
payload = TEST_VIDEO_SCENES_REPEAT_REF_IDS
164164
scenes = [

0 commit comments

Comments
 (0)