Skip to content

Commit a8469ee

Browse files
author
Diego Ardila
committed
simplified error handling to surface more errors, added test that checks errors
1 parent e6e4d85 commit a8469ee

File tree

4 files changed

+54
-44
lines changed

4 files changed

+54
-44
lines changed

conftest.py

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -29,34 +29,8 @@
2929

3030

3131
@pytest.fixture(scope="session")
32-
def monkeypatch_session(request):
33-
"""This workaround is needed to allow monkeypatching in session-scoped fixtures.
34-
35-
See https://github.com/pytest-dev/pytest/issues/363
36-
"""
37-
from _pytest.monkeypatch import MonkeyPatch
38-
39-
mpatch = MonkeyPatch()
40-
yield mpatch
41-
mpatch.undo()
42-
43-
44-
@pytest.fixture(scope="session")
45-
def CLIENT(monkeypatch_session):
32+
def CLIENT():
4633
client = nucleus.NucleusClient(API_KEY)
47-
48-
# Change _make_request to raise AsssertionErrors when the
49-
# HTTP status code is not successful, so that tests fail if
50-
# the request was unsuccessful.
51-
def _make_request_patch(
52-
payload: dict, route: str, requests_command=requests.post
53-
) -> dict:
54-
response = client._make_request_raw(payload, route, requests_command)
55-
if response.status_code not in SUCCESS_STATUS_CODES:
56-
response.raise_for_status()
57-
return response.json()
58-
59-
monkeypatch_session.setattr(client, "make_request", _make_request_patch)
6034
return client
6135

6236

nucleus/__init__.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@
132132
ModelRunCreationError,
133133
DatasetItemRetrievalError,
134134
NotFoundError,
135+
NucleusAPIError,
135136
)
136137

137138
logger = logging.getLogger(__name__)
@@ -146,15 +147,21 @@ class NucleusClient:
146147
Nucleus client.
147148
"""
148149

149-
def __init__(self, api_key: str, use_notebook: bool = False):
150+
def __init__(
151+
self,
152+
api_key: str,
153+
use_notebook: bool = False,
154+
endpoint=NUCLEUS_ENDPOINT,
155+
):
150156
self.api_key = api_key
151157
self.tqdm_bar = tqdm.tqdm
158+
self.endpoint = endpoint
152159
self._use_notebook = use_notebook
153160
if use_notebook:
154161
self.tqdm_bar = tqdm_notebook.tqdm
155162

156163
def __repr__(self):
157-
return f"NucleusClient(api_key='{self.api_key}', use_notebook={self._use_notebook})"
164+
return f"NucleusClient(api_key='{self.api_key}', use_notebook={self._use_notebook}, endpoint='{self.endpoint}'')"
158165

159166
def __eq__(self, other):
160167
if self.api_key == other.api_key:
@@ -1080,7 +1087,7 @@ def _make_grequest(
10801087
sess.mount("https://", adapter)
10811088
sess.mount("http://", adapter)
10821089

1083-
endpoint = f"{NUCLEUS_ENDPOINT}/{route}"
1090+
endpoint = f"{self.endpoint}/{route}"
10841091
logger.info("Posting to %s", endpoint)
10851092

10861093
if local:
@@ -1103,18 +1110,17 @@ def _make_grequest(
11031110
return post
11041111

11051112
def _make_request_raw(
1106-
self, payload: dict, route: str, requests_command=requests.post
1113+
self, payload: dict, endpoint: str, requests_command=requests.post
11071114
):
11081115
"""
11091116
Makes a request to Nucleus endpoint. This method returns the raw
11101117
requests.Response object which is useful for unit testing.
11111118
11121119
:param payload: given payload
1113-
:param route: route for the request
1120+
:param endpoint: endpoint + route for the request
11141121
:param requests_command: requests.post, requests.get, requests.delete
11151122
:return: response
11161123
"""
1117-
endpoint = f"{NUCLEUS_ENDPOINT}/{route}"
11181124
logger.info("Posting to %s", endpoint)
11191125

11201126
response = requests_command(
@@ -1140,12 +1146,14 @@ def make_request(
11401146
:param requests_command: requests.post, requests.get, requests.delete
11411147
:return: response JSON
11421148
"""
1143-
response = self._make_request_raw(payload, route, requests_command)
1149+
endpoint = f"{self.endpoint}/{route}"
11441150

1145-
if getattr(response, "status_code") not in SUCCESS_STATUS_CODES:
1146-
logger.warning(response)
1151+
response = self._make_request_raw(payload, endpoint, requests_command)
11471152

1148-
if response.status_code == 404:
1149-
raise response.raise_for_status()
1153+
if not response.ok:
1154+
self.handle_bad_response(endpoint, requests_command, response)
11501155

11511156
return response.json()
1157+
1158+
def handle_bad_response(self, endpoint, requests_command, response):
1159+
raise NucleusAPIError(endpoint, requests_command, response)

nucleus/errors.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,12 @@ class DatasetItemRetrievalError(Exception):
2222
def __init__(self, message="Could not retrieve dataset items"):
2323
self.message = message
2424
super().__init__(self.message)
25+
26+
27+
class NucleusAPIError(Exception):
28+
def __init__(self, endpoint, command, response):
29+
message = f"Tried to {command.__name__} {endpoint}, but received {response.status_code}: {response.reason}."
30+
if hasattr(response, "text"):
31+
if response.text:
32+
message += f"\nThe detailed error is:\n{response.text}"
33+
super().__init__(message)

tests/test_dataset.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
import os
23

34
from helpers import (
45
TEST_SLICE_NAME,
@@ -8,7 +9,13 @@
89
reference_id_from_url,
910
)
1011

11-
from nucleus import Dataset, DatasetItem, UploadResponse, NucleusClient
12+
from nucleus import (
13+
Dataset,
14+
DatasetItem,
15+
UploadResponse,
16+
NucleusClient,
17+
NucleusAPIError,
18+
)
1219
from nucleus.constants import (
1320
NEW_ITEMS,
1421
UPDATED_ITEMS,
@@ -18,6 +25,8 @@
1825
DATASET_ID_KEY,
1926
)
2027

28+
TEST_AUTOTAG_DATASET = "ds_bz43jm2jwm70060b3890"
29+
2130

2231
def test_reprs():
2332
# Have to define here in order to have access to all relevant objects
@@ -130,10 +139,20 @@ def test_dataset_list_autotags(CLIENT, dataset):
130139

131140

132141
def test_dataset_export_autotag_scores(CLIENT):
133-
# Pandoc dataset.
134-
client.get_dataset("ds_bwhjbyfb8mjj0ykagxf0")
135-
142+
# This test can only run for the test user who has an indexed dataset.
136143
# TODO: if/when we can create autotags via api, create one instead.
137-
dataset.autotag_scores(autotag_name="red_car_v2")
144+
if os.environ.get("HAS_ACCESS_TO_TEST_DATA", False):
145+
dataset = CLIENT.get_dataset(TEST_AUTOTAG_DATASET)
146+
147+
with pytest.raises(NucleusAPIError) as api_error:
148+
dataset.autotag_scores(autotag_name="NONSENSE_GARBAGE")
149+
assert (
150+
f"The autotag NONSENSE_GARBAGE was not found in dataset {TEST_AUTOTAG_DATASET}"
151+
in str(api_error.value)
152+
)
153+
154+
scores = dataset.autotag_scores(autotag_name="TestTag")
138155

139-
# TODO: add some asserts?
156+
for column in ["dataset_item_ids", "ref_ids", "scores"]:
157+
assert column in scores
158+
assert len(scores[column]) > 0

0 commit comments

Comments
 (0)