Skip to content

Commit 722ea60

Browse files
author
Ubuntu
committed
Add retries and make custom indexing return a proper object
1 parent 019817a commit 722ea60

File tree

2 files changed

+61
-35
lines changed

2 files changed

+61
-35
lines changed

nucleus/__init__.py

Lines changed: 54 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
import json
88
import logging
99
import os
10-
import urllib.request
11-
from asyncio.tasks import Task
1210
from typing import Any, Dict, List, Optional, Union
1311

1412
import aiohttp
@@ -17,6 +15,7 @@
1715
import requests
1816
import tqdm
1917
import tqdm.notebook as tqdm_notebook
18+
import time
2019

2120
from nucleus.url_utils import sanitize_string_args
2221

@@ -105,6 +104,11 @@
105104
)
106105

107106

107+
class RetryStrategy:
108+
statuses = {503, 504}
109+
sleep_times = [1, 3, 9]
110+
111+
108112
class NucleusClient:
109113
"""
110114
Nucleus client.
@@ -511,28 +515,41 @@ async def _make_files_request(
511515
content_type=file[1][2],
512516
)
513517

514-
async with session.post(
515-
endpoint,
516-
data=form,
517-
auth=aiohttp.BasicAuth(self.api_key, ""),
518-
timeout=DEFAULT_NETWORK_TIMEOUT_SEC,
519-
) as response:
520-
logger.info("API request has response code %s", response.status)
521-
522-
try:
523-
data = await response.json()
524-
except aiohttp.client_exceptions.ContentTypeError:
525-
# In case of 404, the server returns text
526-
data = await response.text()
527-
528-
if not response.ok:
529-
self.handle_bad_response(
530-
endpoint,
531-
session.post,
532-
aiohttp_response=(response.status, response.reason, data),
518+
for sleep_time in RetryStrategy.sleep_times + [""]:
519+
async with session.post(
520+
endpoint,
521+
data=form,
522+
auth=aiohttp.BasicAuth(self.api_key, ""),
523+
timeout=DEFAULT_NETWORK_TIMEOUT_SEC,
524+
) as response:
525+
logger.info(
526+
"API request has response code %s", response.status
533527
)
534528

535-
return data
529+
try:
530+
data = await response.json()
531+
except aiohttp.client_exceptions.ContentTypeError:
532+
# In case of 404, the server returns text
533+
data = await response.text()
534+
if (
535+
response.status in RetryStrategy.statuses
536+
and sleep_time != ""
537+
):
538+
time.sleep(sleep_time)
539+
continue
540+
541+
if not response.ok:
542+
self.handle_bad_response(
543+
endpoint,
544+
session.post,
545+
aiohttp_response=(
546+
response.status,
547+
response.reason,
548+
data,
549+
),
550+
)
551+
552+
return data
536553

537554
def _process_append_requests(
538555
self,
@@ -1191,14 +1208,20 @@ def make_request(
11911208

11921209
logger.info("Posting to %s", endpoint)
11931210

1194-
response = requests_command(
1195-
endpoint,
1196-
json=payload,
1197-
headers={"Content-Type": "application/json"},
1198-
auth=(self.api_key, ""),
1199-
timeout=DEFAULT_NETWORK_TIMEOUT_SEC,
1200-
)
1201-
logger.info("API request has response code %s", response.status_code)
1211+
for retry_wait_time in RetryStrategy.sleep_times:
1212+
response = requests_command(
1213+
endpoint,
1214+
json=payload,
1215+
headers={"Content-Type": "application/json"},
1216+
auth=(self.api_key, ""),
1217+
timeout=DEFAULT_NETWORK_TIMEOUT_SEC,
1218+
)
1219+
logger.info(
1220+
"API request has response code %s", response.status_code
1221+
)
1222+
if response.status_code not in RetryStrategy.statuses:
1223+
break
1224+
time.sleep(retry_wait_time)
12021225

12031226
if not response.ok:
12041227
self.handle_bad_response(endpoint, requests_command, response)
@@ -1214,4 +1237,4 @@ def handle_bad_response(
12141237
):
12151238
raise NucleusAPIError(
12161239
endpoint, requests_command, requests_response, aiohttp_response
1217-
)
1240+
)

nucleus/dataset.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -428,10 +428,13 @@ def list_autotags(self):
428428
return self._client.list_autotags(self.id)
429429

430430
def create_custom_index(self, embeddings_urls: list, embedding_dim: int):
431-
return self._client.create_custom_index(
432-
self.id,
433-
embeddings_urls,
434-
embedding_dim,
431+
return AsyncJob.from_json(
432+
self._client.create_custom_index(
433+
self.id,
434+
embeddings_urls,
435+
embedding_dim,
436+
),
437+
self._client,
435438
)
436439

437440
def delete_custom_index(self):

0 commit comments

Comments
 (0)