Skip to content

Commit b587628

Browse files
authored
Merge pull request #126 from scaleapi/da-retry
Add retries and make custom indexing return a proper object
2 parents 019817a + 0abbdf7 commit b587628

File tree

5 files changed

+82
-61
lines changed

5 files changed

+82
-61
lines changed

nucleus/__init__.py

Lines changed: 59 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
import json
88
import logging
99
import os
10-
import urllib.request
11-
from asyncio.tasks import Task
10+
import time
1211
from typing import Any, Dict, List, Optional, Union
1312

1413
import aiohttp
@@ -41,15 +40,15 @@
4140
ERROR_ITEMS,
4241
ERROR_PAYLOAD,
4342
ERRORS_KEY,
44-
JOB_ID_KEY,
45-
JOB_LAST_KNOWN_STATUS_KEY,
46-
JOB_TYPE_KEY,
47-
JOB_CREATION_TIME_KEY,
4843
IMAGE_KEY,
4944
IMAGE_URL_KEY,
5045
INDEX_CONTINUOUS_ENABLE_KEY,
5146
ITEM_METADATA_SCHEMA_KEY,
5247
ITEMS_KEY,
48+
JOB_CREATION_TIME_KEY,
49+
JOB_ID_KEY,
50+
JOB_LAST_KNOWN_STATUS_KEY,
51+
JOB_TYPE_KEY,
5352
KEEP_HISTORY_KEY,
5453
MESSAGE_KEY,
5554
MODEL_RUN_ID_KEY,
@@ -63,7 +62,7 @@
6362
UPDATE_KEY,
6463
)
6564
from .dataset import Dataset
66-
from .dataset_item import DatasetItem, CameraParams, Quaternion
65+
from .dataset_item import CameraParams, DatasetItem, Quaternion
6766
from .errors import (
6867
DatasetItemRetrievalError,
6968
ModelCreationError,
@@ -87,9 +86,9 @@
8786
PolygonPrediction,
8887
SegmentationPrediction,
8988
)
89+
from .scene import Frame, LidarScene
9090
from .slice import Slice
9191
from .upload_response import UploadResponse
92-
from .scene import Frame, LidarScene
9392

9493
# pylint: disable=E1101
9594
# TODO: refactor to reduce this file to under 1000 lines.
@@ -105,6 +104,11 @@
105104
)
106105

107106

107+
class RetryStrategy:
108+
statuses = {503, 504}
109+
sleep_times = [1, 3, 9]
110+
111+
108112
class NucleusClient:
109113
"""
110114
Nucleus client.
@@ -511,28 +515,41 @@ async def _make_files_request(
511515
content_type=file[1][2],
512516
)
513517

514-
async with session.post(
515-
endpoint,
516-
data=form,
517-
auth=aiohttp.BasicAuth(self.api_key, ""),
518-
timeout=DEFAULT_NETWORK_TIMEOUT_SEC,
519-
) as response:
520-
logger.info("API request has response code %s", response.status)
521-
522-
try:
523-
data = await response.json()
524-
except aiohttp.client_exceptions.ContentTypeError:
525-
# In case of 404, the server returns text
526-
data = await response.text()
527-
528-
if not response.ok:
529-
self.handle_bad_response(
530-
endpoint,
531-
session.post,
532-
aiohttp_response=(response.status, response.reason, data),
518+
for sleep_time in RetryStrategy.sleep_times + [-1]:
519+
async with session.post(
520+
endpoint,
521+
data=form,
522+
auth=aiohttp.BasicAuth(self.api_key, ""),
523+
timeout=DEFAULT_NETWORK_TIMEOUT_SEC,
524+
) as response:
525+
logger.info(
526+
"API request has response code %s", response.status
533527
)
534528

535-
return data
529+
try:
530+
data = await response.json()
531+
except aiohttp.client_exceptions.ContentTypeError:
532+
# In case of 404, the server returns text
533+
data = await response.text()
534+
if (
535+
response.status in RetryStrategy.statuses
536+
and sleep_time != -1
537+
):
538+
time.sleep(sleep_time)
539+
continue
540+
541+
if not response.ok:
542+
self.handle_bad_response(
543+
endpoint,
544+
session.post,
545+
aiohttp_response=(
546+
response.status,
547+
response.reason,
548+
data,
549+
),
550+
)
551+
552+
return data
536553

537554
def _process_append_requests(
538555
self,
@@ -1130,13 +1147,6 @@ def create_custom_index(
11301147
requests_command=requests.post,
11311148
)
11321149

1133-
def check_index_status(self, job_id: str):
1134-
return self.make_request(
1135-
{},
1136-
f"indexing/{job_id}",
1137-
requests_command=requests.get,
1138-
)
1139-
11401150
def delete_custom_index(self, dataset_id: str):
11411151
return self.make_request(
11421152
{},
@@ -1191,14 +1201,20 @@ def make_request(
11911201

11921202
logger.info("Posting to %s", endpoint)
11931203

1194-
response = requests_command(
1195-
endpoint,
1196-
json=payload,
1197-
headers={"Content-Type": "application/json"},
1198-
auth=(self.api_key, ""),
1199-
timeout=DEFAULT_NETWORK_TIMEOUT_SEC,
1200-
)
1201-
logger.info("API request has response code %s", response.status_code)
1204+
for retry_wait_time in RetryStrategy.sleep_times:
1205+
response = requests_command(
1206+
endpoint,
1207+
json=payload,
1208+
headers={"Content-Type": "application/json"},
1209+
auth=(self.api_key, ""),
1210+
timeout=DEFAULT_NETWORK_TIMEOUT_SEC,
1211+
)
1212+
logger.info(
1213+
"API request has response code %s", response.status_code
1214+
)
1215+
if response.status_code not in RetryStrategy.statuses:
1216+
break
1217+
time.sleep(retry_wait_time)
12021218

12031219
if not response.ok:
12041220
self.handle_bad_response(endpoint, requests_command, response)

nucleus/dataset.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -428,10 +428,13 @@ def list_autotags(self):
428428
return self._client.list_autotags(self.id)
429429

430430
def create_custom_index(self, embeddings_urls: list, embedding_dim: int):
431-
return self._client.create_custom_index(
432-
self.id,
433-
embeddings_urls,
434-
embedding_dim,
431+
return AsyncJob.from_json(
432+
self._client.create_custom_index(
433+
self.id,
434+
embeddings_urls,
435+
embedding_dim,
436+
),
437+
self._client,
435438
)
436439

437440
def delete_custom_index(self):
@@ -463,9 +466,6 @@ def add_taxonomy(
463466
requests_command=requests.post,
464467
)
465468

466-
def check_index_status(self, job_id: str):
467-
return self._client.check_index_status(job_id)
468-
469469
def items_and_annotations(
470470
self,
471471
) -> List[Dict[str, Union[DatasetItem, Dict[str, List[Annotation]]]]]:

nucleus/errors.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44
"scale-nucleus"
55
).version
66

7+
INFRA_FLAKE_MESSAGES = [
8+
"downstream duration timeout",
9+
"upstream connect error or disconnect/reset before headers. reset reason: local reset",
10+
]
11+
712

813
class ModelCreationError(Exception):
914
def __init__(self, message="Could not create the model"):
@@ -35,7 +40,7 @@ class NucleusAPIError(Exception):
3540
def __init__(
3641
self, endpoint, command, requests_response=None, aiohttp_response=None
3742
):
38-
message = f"Your client is on version {nucleus_client_version}. Before reporting this error, please make sure you update to the latest version of the client by running pip install --upgrade scale-nucleus\n"
43+
message = f"Your client is on version {nucleus_client_version}. If you have not recently done so, please make sure you have updated to the latest version of the client by running pip install --upgrade scale-nucleus\n"
3944
if requests_response is not None:
4045
message += f"Tried to {command.__name__} {endpoint}, but received {requests_response.status_code}: {requests_response.reason}."
4146
if hasattr(requests_response, "text"):
@@ -50,4 +55,10 @@ def __init__(
5055
if data:
5156
message += f"\nThe detailed error is:\n{data}"
5257

58+
if any(
59+
infra_flake_message in message
60+
for infra_flake_message in INFRA_FLAKE_MESSAGES
61+
):
62+
message += "\n This likely indicates temporary downtime of the API, please try again in a minute or two"
63+
5364
super().__init__(message)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ exclude = '''
2121

2222
[tool.poetry]
2323
name = "scale-nucleus"
24-
version = "0.1.22"
24+
version = "0.1.23"
2525
description = "The official Python client library for Nucleus, the Data Platform for AI"
2626
license = "MIT"
2727
authors = ["Scale AI Nucleus Team <nucleusapi@scaleapi.com>"]

tests/test_indexing.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,20 +41,14 @@ def dataset(CLIENT):
4141
@pytest.mark.integration
4242
def test_index_integration(dataset):
4343
signed_embeddings_url = TEST_INDEX_EMBEDDINGS_FILE
44-
create_response = dataset.create_custom_index(
45-
[signed_embeddings_url], embedding_dim=3
46-
)
47-
job = AsyncJob.from_json(create_response, client="Nucleus Client")
44+
job = dataset.create_custom_index([signed_embeddings_url], embedding_dim=3)
4845
assert job.job_id
4946
assert job.job_last_known_status
5047
assert job.job_type
5148
assert job.job_creation_time
52-
assert MESSAGE_KEY in create_response
53-
job_id = create_response[JOB_ID_KEY]
49+
job.sleep_until_complete()
5450

55-
# Job can error because pytest dataset fixture gets deleted
56-
# As a workaround, we'll just check that we got some response
57-
job_status_response = dataset.check_index_status(job_id)
51+
job_status_response = job.status()
5852
assert STATUS_KEY in job_status_response
5953
assert JOB_ID_KEY in job_status_response
6054
assert MESSAGE_KEY in job_status_response

0 commit comments

Comments
 (0)