Skip to content

Commit e18f313

Browse files
authored
return raw response (#340)
1 parent 636282f commit e18f313

File tree

3 files changed

+21
-20
lines changed

3 files changed

+21
-20
lines changed

nucleus/__init__.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,6 +1021,7 @@ def make_request(
10211021
payload: Optional[dict],
10221022
route: str,
10231023
requests_command=requests.post,
1024+
return_raw_response: bool = False,
10241025
) -> dict:
10251026
"""Makes a request to a Nucleus API endpoint.
10261027
@@ -1030,9 +1031,10 @@ def make_request(
10301031
payload: Given request payload.
10311032
route: Route for the request.
10321033
Requests command: ``requests.post``, ``requests.get``, or ``requests.delete``.
1034+
return_raw_response: return the request's response object entirely
10331035
10341036
Returns:
1035-
Response payload as JSON dict.
1037+
Response payload as JSON dict or request object.
10361038
"""
10371039
if payload is None:
10381040
payload = {}
@@ -1042,18 +1044,7 @@ def make_request(
10421044
"Received defined payload with GET request! Will ignore payload"
10431045
)
10441046
payload = None
1045-
return self._connection.make_request(payload, route, requests_command) # type: ignore
1046-
1047-
def handle_bad_response(
1048-
self,
1049-
endpoint,
1050-
requests_command,
1051-
requests_response=None,
1052-
aiohttp_response=None,
1053-
):
1054-
self._connection.handle_bad_response(
1055-
endpoint, requests_command, requests_response, aiohttp_response
1056-
)
1047+
return self._connection.make_request(payload, route, requests_command, return_raw_response) # type: ignore
10571048

10581049
def _set_api_key(self, api_key):
10591050
"""Fetch API key from environment variable NUCLEUS_API_KEY if not set"""

nucleus/connection.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,11 @@ def put(self, payload: dict, route: str):
4040
return self.make_request(payload, route, requests_command=requests.put)
4141

4242
def make_request(
43-
self, payload: dict, route: str, requests_command=requests.post
43+
self,
44+
payload: dict,
45+
route: str,
46+
requests_command=requests.post,
47+
return_raw_response: bool = False,
4448
) -> dict:
4549
"""
4650
Makes a request to Nucleus endpoint and logs a warning if not
@@ -49,6 +53,7 @@ def make_request(
4953
:param payload: given payload
5054
:param route: route for the request
5155
:param requests_command: requests.post, requests.get, requests.delete
56+
:param return_raw_response: return the request's response object entirely
5257
:return: response JSON
5358
"""
5459
endpoint = f"{self.endpoint}/{route}"
@@ -73,6 +78,9 @@ def make_request(
7378
if not response.ok:
7479
self.handle_bad_response(endpoint, requests_command, response)
7580

81+
if return_raw_response:
82+
return response
83+
7684
return response.json()
7785

7886
def handle_bad_response(

nucleus/model.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -234,16 +234,17 @@ def add_tags(self, tags: List[str]):
234234
Args:
235235
tags: list of tag names
236236
"""
237-
response = self._client.make_request(
237+
response: requests.Response = self._client.make_request(
238238
{MODEL_TAGS_KEY: tags},
239239
f"model/{self.id}/tag",
240240
requests_command=requests.post,
241+
return_raw_response=True,
241242
)
242243

243-
if response.get("msg", False):
244+
if response.ok:
244245
self.tags.extend(tags)
245246

246-
return response
247+
return response.json()
247248

248249
def remove_tags(self, tags: List[str]):
249250
"""Remove tag(s) from the model. ::
@@ -257,13 +258,14 @@ def remove_tags(self, tags: List[str]):
257258
Args:
258259
tags: list of tag names to remove
259260
"""
260-
response = self._client.make_request(
261+
response: requests.Response = self._client.make_request(
261262
{MODEL_TAGS_KEY: tags},
262263
f"model/{self.id}/tag",
263264
requests_command=requests.delete,
265+
return_raw_response=True,
264266
)
265267

266-
if response.get("msg", False):
268+
if response.ok:
267269
self.tags = list(filter(lambda t: t not in tags, self.tags))
268270

269-
return response
271+
return response.json()

0 commit comments

Comments
 (0)