|
1 | 1 | import asyncio
|
2 | 2 | import time
|
3 | 3 | from dataclasses import dataclass
|
4 |
| -from typing import TYPE_CHECKING, BinaryIO, Callable, Sequence, Tuple |
| 4 | +from typing import ( |
| 5 | + TYPE_CHECKING, |
| 6 | + BinaryIO, |
| 7 | + Callable, |
| 8 | + Optional, |
| 9 | + Sequence, |
| 10 | + Tuple, |
| 11 | + Union, |
| 12 | +) |
5 | 13 |
|
6 | 14 | import aiohttp
|
7 | 15 | import nest_asyncio
|
@@ -84,56 +92,71 @@ def get_event_loop():
|
84 | 92 | return loop
|
85 | 93 |
|
86 | 94 |
|
87 |
| -def make_many_form_data_requests_concurrently( |
| 95 | +def make_multiple_requests_concurrently( |
88 | 96 | client: "NucleusClient",
|
89 |
| - requests: Sequence[FormDataContextHandler], |
90 |
| - route: str, |
| 97 | + requests: Sequence[Union[FormDataContextHandler, str]], |
| 98 | + route: Optional[str], |
91 | 99 | progressbar: tqdm,
|
92 | 100 | ):
|
93 | 101 | """
|
94 | 102 | Makes an async post request with form data to a Nucleus endpoint.
|
95 | 103 |
|
96 | 104 | Args:
|
97 | 105 | client: The client to use for the request.
|
98 |
| - requests: Each requst should be a FormDataContextHandler object which will |
99 |
| - handle generating form data, and opening/closing files for each request. |
100 |
| - route: route for the request. |
| 106 | + requests: a list of requests to make. This list either comprises a string of endpoints to request, |
| 107 | + or a list of FormDataContextHandler object which will handle generating form data, and opening/closing files for each request. |
| 108 | + route: A route is required when requests are for Form Data Post requests |
101 | 109 | progressbar: A tqdm progress bar to use for showing progress to the user.
|
102 | 110 | """
|
103 | 111 | loop = get_event_loop()
|
104 | 112 | return loop.run_until_complete(
|
105 |
| - form_data_request_helper(client, requests, route, progressbar) |
| 113 | + _request_helper(client, requests, route, progressbar) |
106 | 114 | )
|
107 | 115 |
|
108 | 116 |
|
109 |
| -async def form_data_request_helper( |
| 117 | +async def _request_helper( |
110 | 118 | client: "NucleusClient",
|
111 |
| - requests: Sequence[FormDataContextHandler], |
112 |
| - route: str, |
| 119 | + requests: Sequence[Union[FormDataContextHandler, str]], |
| 120 | + route: Optional[str], |
113 | 121 | progressbar: tqdm,
|
114 | 122 | ):
|
115 | 123 | """
|
116 |
| - Makes an async post request with files to a Nucleus endpoint. |
| 124 | + Makes an async requests to a Nucleus endpoint. |
117 | 125 |
|
118 | 126 | Args:
|
119 | 127 | client: The client to use for the request.
|
120 |
| - requests: Each request should be a FormDataContextHandler object which will |
121 |
| - handle generating form data, and opening/closing files for each request. |
| 128 | + requests: a list of requests to make. This list either comprises a string of endpoints to request, |
| 129 | + or a list of FormDataContextHandler object which will handle generating form data, and opening/closing files for each request. |
122 | 130 | route: route for the request.
|
123 | 131 | """
|
124 | 132 | async with aiohttp.ClientSession() as session:
|
125 |
| - tasks = [ |
126 |
| - asyncio.ensure_future( |
127 |
| - _post_form_data( |
128 |
| - client=client, |
129 |
| - request=request, |
130 |
| - route=route, |
131 |
| - session=session, |
132 |
| - progressbar=progressbar, |
| 133 | + tasks = [] |
| 134 | + for request in requests: |
| 135 | + if isinstance(request, FormDataContextHandler): |
| 136 | + assert ( |
| 137 | + route |
| 138 | + ), "A route must be specified for FormDataContextHandler requests" |
| 139 | + req = asyncio.ensure_future( |
| 140 | + _post_form_data( |
| 141 | + client=client, |
| 142 | + request=request, |
| 143 | + route=route, |
| 144 | + session=session, |
| 145 | + progressbar=progressbar, |
| 146 | + ) |
133 | 147 | )
|
134 |
| - ) |
135 |
| - for request in requests |
136 |
| - ] |
| 148 | + tasks.append(req) |
| 149 | + else: |
| 150 | + req = asyncio.ensure_future( |
| 151 | + _make_request( |
| 152 | + client=client, |
| 153 | + request=request, |
| 154 | + session=session, |
| 155 | + progressbar=progressbar, |
| 156 | + ) |
| 157 | + ) |
| 158 | + tasks.append(req) |
| 159 | + |
137 | 160 | return await asyncio.gather(*tasks)
|
138 | 161 |
|
139 | 162 |
|
@@ -165,36 +188,80 @@ async def _post_form_data(
|
165 | 188 | auth=aiohttp.BasicAuth(client.api_key, ""),
|
166 | 189 | timeout=DEFAULT_NETWORK_TIMEOUT_SEC,
|
167 | 190 | ) as response:
|
168 |
| - logger.info( |
169 |
| - "API request has response code %s", response.status |
| 191 | + data = await _parse_async_response( |
| 192 | + endpoint, session, response, sleep_time |
170 | 193 | )
|
171 |
| - |
172 |
| - try: |
173 |
| - data = await response.json() |
174 |
| - except aiohttp.client_exceptions.ContentTypeError: |
175 |
| - # In case of 404, the server returns text |
176 |
| - data = await response.text() |
177 |
| - if ( |
178 |
| - response.status in RetryStrategy.statuses |
179 |
| - and sleep_time != -1 |
180 |
| - ): |
181 |
| - time.sleep(sleep_time) |
| 194 | + if data is None: |
182 | 195 | continue
|
183 | 196 |
|
184 |
| - if response.status == 503: |
185 |
| - raise TimeoutError( |
186 |
| - "The request to upload your max is timing out, please lower local_files_per_upload_request in your api call." |
187 |
| - ) |
188 |
| - |
189 |
| - if not response.ok: |
190 |
| - raise NucleusAPIError( |
191 |
| - endpoint, |
192 |
| - session.post, |
193 |
| - aiohttp_response=( |
194 |
| - response.status, |
195 |
| - response.reason, |
196 |
| - data, |
197 |
| - ), |
198 |
| - ) |
199 | 197 | progressbar.update(1)
|
200 | 198 | return data
|
| 199 | + |
| 200 | + |
| 201 | +async def _make_request( |
| 202 | + client: "NucleusClient", |
| 203 | + request: str, |
| 204 | + session: aiohttp.ClientSession, |
| 205 | + progressbar: tqdm, |
| 206 | +): |
| 207 | + """ |
| 208 | + Makes an async post request with files to a Nucleus endpoint. |
| 209 | +
|
| 210 | + Args: |
| 211 | + client: The client to use for the request. |
| 212 | + request: The request to make (See FormDataContextHandler for more details.) |
| 213 | + route: route for the request. |
| 214 | + session: The session to use for the request. |
| 215 | +
|
| 216 | + Returns: |
| 217 | + A tuple (endpoint request string, response from endpoint) |
| 218 | + """ |
| 219 | + endpoint = f"{client.endpoint}/{request}" |
| 220 | + logger.info("GET %s", endpoint) |
| 221 | + |
| 222 | + async with UPLOAD_SEMAPHORE: |
| 223 | + for sleep_time in RetryStrategy.sleep_times() + [-1]: |
| 224 | + async with session.get( |
| 225 | + endpoint, |
| 226 | + auth=aiohttp.BasicAuth(client.api_key, ""), |
| 227 | + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, |
| 228 | + ) as response: |
| 229 | + data = await _parse_async_response( |
| 230 | + endpoint, session, response, sleep_time |
| 231 | + ) |
| 232 | + if data is None: |
| 233 | + continue |
| 234 | + |
| 235 | + progressbar.update(1) |
| 236 | + return (request, data) |
| 237 | + |
| 238 | + |
| 239 | +async def _parse_async_response(endpoint, session, response, sleep_time): |
| 240 | + logger.info("API request has response code %s", response.status) |
| 241 | + |
| 242 | + try: |
| 243 | + data = await response.json() |
| 244 | + except aiohttp.client_exceptions.ContentTypeError: |
| 245 | + # In case of 404, the server returns text |
| 246 | + data = await response.text() |
| 247 | + if response.status in RetryStrategy.statuses and sleep_time != -1: |
| 248 | + time.sleep(sleep_time) |
| 249 | + return None |
| 250 | + |
| 251 | + if response.status == 503: |
| 252 | + raise TimeoutError( |
| 253 | + "The request to upload your max is timing out, please lower local_files_per_upload_request in your api call." |
| 254 | + ) |
| 255 | + |
| 256 | + if not response.ok: |
| 257 | + raise NucleusAPIError( |
| 258 | + endpoint, |
| 259 | + session.get, |
| 260 | + aiohttp_response=( |
| 261 | + response.status, |
| 262 | + response.reason, |
| 263 | + data, |
| 264 | + ), |
| 265 | + ) |
| 266 | + |
| 267 | + return data |
0 commit comments