7
7
import json
8
8
import logging
9
9
import os
10
- import urllib .request
11
- from asyncio .tasks import Task
12
10
from typing import Any , Dict , List , Optional , Union
13
11
14
12
import aiohttp
17
15
import requests
18
16
import tqdm
19
17
import tqdm .notebook as tqdm_notebook
18
+ import time
20
19
21
20
from nucleus .url_utils import sanitize_string_args
22
21
105
104
)
106
105
107
106
107
+ class RetryStrategy :
108
+ statuses = {503 , 504 }
109
+ sleep_times = [1 , 3 , 9 ]
110
+
111
+
108
112
class NucleusClient :
109
113
"""
110
114
Nucleus client.
@@ -511,28 +515,41 @@ async def _make_files_request(
511
515
content_type = file [1 ][2 ],
512
516
)
513
517
514
- async with session .post (
515
- endpoint ,
516
- data = form ,
517
- auth = aiohttp .BasicAuth (self .api_key , "" ),
518
- timeout = DEFAULT_NETWORK_TIMEOUT_SEC ,
519
- ) as response :
520
- logger .info ("API request has response code %s" , response .status )
521
-
522
- try :
523
- data = await response .json ()
524
- except aiohttp .client_exceptions .ContentTypeError :
525
- # In case of 404, the server returns text
526
- data = await response .text ()
527
-
528
- if not response .ok :
529
- self .handle_bad_response (
530
- endpoint ,
531
- session .post ,
532
- aiohttp_response = (response .status , response .reason , data ),
518
+ for sleep_time in RetryStrategy .sleep_times + ["" ]:
519
+ async with session .post (
520
+ endpoint ,
521
+ data = form ,
522
+ auth = aiohttp .BasicAuth (self .api_key , "" ),
523
+ timeout = DEFAULT_NETWORK_TIMEOUT_SEC ,
524
+ ) as response :
525
+ logger .info (
526
+ "API request has response code %s" , response .status
533
527
)
534
528
535
- return data
529
+ try :
530
+ data = await response .json ()
531
+ except aiohttp .client_exceptions .ContentTypeError :
532
+ # In case of 404, the server returns text
533
+ data = await response .text ()
534
+ if (
535
+ response .status in RetryStrategy .statuses
536
+ and sleep_time != ""
537
+ ):
538
+ time .sleep (sleep_time )
539
+ continue
540
+
541
+ if not response .ok :
542
+ self .handle_bad_response (
543
+ endpoint ,
544
+ session .post ,
545
+ aiohttp_response = (
546
+ response .status ,
547
+ response .reason ,
548
+ data ,
549
+ ),
550
+ )
551
+
552
+ return data
536
553
537
554
def _process_append_requests (
538
555
self ,
@@ -1191,14 +1208,20 @@ def make_request(
1191
1208
1192
1209
logger .info ("Posting to %s" , endpoint )
1193
1210
1194
- response = requests_command (
1195
- endpoint ,
1196
- json = payload ,
1197
- headers = {"Content-Type" : "application/json" },
1198
- auth = (self .api_key , "" ),
1199
- timeout = DEFAULT_NETWORK_TIMEOUT_SEC ,
1200
- )
1201
- logger .info ("API request has response code %s" , response .status_code )
1211
+ for retry_wait_time in RetryStrategy .sleep_times :
1212
+ response = requests_command (
1213
+ endpoint ,
1214
+ json = payload ,
1215
+ headers = {"Content-Type" : "application/json" },
1216
+ auth = (self .api_key , "" ),
1217
+ timeout = DEFAULT_NETWORK_TIMEOUT_SEC ,
1218
+ )
1219
+ logger .info (
1220
+ "API request has response code %s" , response .status_code
1221
+ )
1222
+ if response .status_code not in RetryStrategy .statuses :
1223
+ break
1224
+ time .sleep (retry_wait_time )
1202
1225
1203
1226
if not response .ok :
1204
1227
self .handle_bad_response (endpoint , requests_command , response )
@@ -1214,4 +1237,4 @@ def handle_bad_response(
1214
1237
):
1215
1238
raise NucleusAPIError (
1216
1239
endpoint , requests_command , requests_response , aiohttp_response
1217
- )
1240
+ )
0 commit comments