Skip to content

Commit 0a8f921

Browse files
author
Val Brodsky
committed
Add support for custom error handler
1 parent fed3d3a commit 0a8f921

File tree

3 files changed

+82
-36
lines changed

3 files changed

+82
-36
lines changed

libs/lbox-clients/src/lbox/exceptions.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@ def __init__(self, message, cause=None):
1616
self.cause = cause
1717

1818
def __str__(self):
19-
return self.message + str(self.args)
19+
exception_message = self.message
20+
if self.cause is not None:
21+
exception_message += " (caused by: %s)" % self.cause
22+
return exception_message
2023

2124

2225
class AuthenticationError(LabelboxError):

libs/lbox-clients/src/lbox/request_client.py

Lines changed: 42 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import sys
55
from datetime import datetime, timezone
66
from types import MappingProxyType
7+
from typing import Callable, Dict, Optional
78

89
import requests
910
import requests.exceptions
@@ -52,9 +53,7 @@ def __init__(
5253
"""
5354
if api_key is None:
5455
if _LABELBOX_API_KEY not in os.environ:
55-
raise exceptions.AuthenticationError(
56-
"Labelbox API key not provided"
57-
)
56+
raise exceptions.AuthenticationError("Labelbox API key not provided")
5857
api_key = os.environ[_LABELBOX_API_KEY]
5958
self.api_key = api_key
6059

@@ -70,9 +69,7 @@ def __init__(
7069
self._connection: requests.Session = self._init_connection()
7170

7271
def _init_connection(self) -> requests.Session:
73-
connection = (
74-
requests.Session()
75-
) # using default connection pool size of 10
72+
connection = requests.Session() # using default connection pool size of 10
7673
connection.headers.update(self._default_headers())
7774

7875
return connection
@@ -106,6 +103,9 @@ def execute(
106103
experimental=False,
107104
error_log_key="message",
108105
raise_return_resource_not_found=False,
106+
error_handlers: Optional[
107+
Dict[str, Callable[[requests.models.Response], None]]
108+
] = None,
109109
):
110110
"""Sends a request to the server for the execution of the
111111
given query.
@@ -120,6 +120,27 @@ def execute(
120120
files (dict): file arguments for request
121121
timeout (float): Max allowed time for query execution,
122122
in seconds.
123+
raise_return_resource_not_found: By default the client relies on the caller to raise the correct exception when a resource is not found.
124+
If this is set to True, the client will raise a ResourceNotFoundError exception automatically.
125+
This simplifies processing.
126+
We recommend to use it only of api returns a clear and well-formed error when a resource not found for a given query.
127+
error_handlers (dict): A dictionary mapping graphql error code to handler functions.
128+
Allows a caller to handle specific errors reporting in a custom way or produce more user-friendly readable messages.
129+
130+
Example - custom error handler:
131+
>>> def _raise_readable_errors(self, response):
132+
>>> errors = response.json().get('errors', [])
133+
>>> if errors:
134+
>>> message = errors[0].get(
135+
>>> 'message', json.dumps([{
136+
>>> "errorMessage": "Unknown error"
137+
>>> }]))
138+
>>> errors = json.loads(message)
139+
>>> error_messages = [error['errorMessage'] for error in errors]
140+
>>> else:
141+
>>> error_messages = ["Uknown error"]
142+
>>> raise LabelboxError(". ".join(error_messages))
143+
123144
Returns:
124145
dict, parsed JSON response.
125146
Raises:
@@ -149,12 +170,8 @@ def convert_value(value):
149170

150171
if query is not None:
151172
if params is not None:
152-
params = {
153-
key: convert_value(value) for key, value in params.items()
154-
}
155-
data = json.dumps({"query": query, "variables": params}).encode(
156-
"utf-8"
157-
)
173+
params = {key: convert_value(value) for key, value in params.items()}
174+
data = json.dumps({"query": query, "variables": params}).encode("utf-8")
158175
elif data is None:
159176
raise ValueError("query and data cannot both be none")
160177

@@ -207,9 +224,7 @@ def convert_value(value):
207224
"upstream connect error or disconnect/reset before headers"
208225
in response.text
209226
):
210-
raise exceptions.InternalServerError(
211-
"Connection reset"
212-
)
227+
raise exceptions.InternalServerError("Connection reset")
213228
elif response.status_code == 502:
214229
error_502 = "502 Bad Gateway"
215230
raise exceptions.InternalServerError(error_502)
@@ -237,19 +252,14 @@ def get_error_status_code(error: dict) -> int:
237252
except:
238253
return 500
239254

240-
if (
241-
check_errors(["AUTHENTICATION_ERROR"], "extensions", "code")
242-
is not None
243-
):
255+
if check_errors(["AUTHENTICATION_ERROR"], "extensions", "code") is not None:
244256
raise exceptions.AuthenticationError("Invalid API key")
245257

246258
authorization_error = check_errors(
247259
["AUTHORIZATION_ERROR"], "extensions", "code"
248260
)
249261
if authorization_error is not None:
250-
raise exceptions.AuthorizationError(
251-
authorization_error["message"]
252-
)
262+
raise exceptions.AuthorizationError(authorization_error["message"])
253263

254264
validation_error = check_errors(
255265
["GRAPHQL_VALIDATION_FAILED"], "extensions", "code"
@@ -262,13 +272,9 @@ def get_error_status_code(error: dict) -> int:
262272
else:
263273
raise exceptions.InvalidQueryError(message)
264274

265-
graphql_error = check_errors(
266-
["GRAPHQL_PARSE_FAILED"], "extensions", "code"
267-
)
275+
graphql_error = check_errors(["GRAPHQL_PARSE_FAILED"], "extensions", "code")
268276
if graphql_error is not None:
269-
raise exceptions.InvalidQueryError(
270-
graphql_error["message"]
271-
)
277+
raise exceptions.InvalidQueryError(graphql_error["message"])
272278

273279
# Check if API limit was exceeded
274280
response_msg = r_json.get("message", "")
@@ -293,9 +299,7 @@ def get_error_status_code(error: dict) -> int:
293299
["RESOURCE_CONFLICT"], "extensions", "code"
294300
)
295301
if resource_conflict_error is not None:
296-
raise exceptions.ResourceConflict(
297-
resource_conflict_error["message"]
298-
)
302+
raise exceptions.ResourceConflict(resource_conflict_error["message"])
299303

300304
malformed_request_error = check_errors(
301305
["MALFORMED_REQUEST"], "extensions", "code"
@@ -311,7 +315,13 @@ def get_error_status_code(error: dict) -> int:
311315
internal_server_error = check_errors(
312316
["INTERNAL_SERVER_ERROR"], "extensions", "code"
313317
)
318+
error_code = "INTERNAL_SERVER_ERROR"
319+
314320
if internal_server_error is not None:
321+
if error_handlers and error_code in error_handlers:
322+
handler = error_handlers[error_code]
323+
handler(response)
324+
return None
315325
message = internal_server_error.get("message")
316326
error_status_code = get_error_status_code(internal_server_error)
317327
if error_status_code == 400:
@@ -343,9 +353,7 @@ def get_error_status_code(error: dict) -> int:
343353
errors,
344354
)
345355
)
346-
raise exceptions.LabelboxError(
347-
"Unknown error: %s" % str(messages)
348-
)
356+
raise exceptions.LabelboxError("Unknown error: %s" % str(messages))
349357

350358
# if we do return a proper error code, and didn't catch this above
351359
# reraise
Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,46 @@
1+
from unittest.mock import MagicMock
2+
13
from lbox.request_client import RequestClient
24

35

46
# @patch.dict(os.environ, {'LABELBOX_API_KEY': 'bar'})
57
def test_headers():
6-
client = RequestClient(sdk_version="foo", api_key="api_key", endpoint="http://localhost:8080/_gql")
8+
client = RequestClient(
9+
sdk_version="foo", api_key="api_key", endpoint="http://localhost:8080/_gql"
10+
)
711
assert client.headers
812
assert client.headers["Authorization"] == "Bearer api_key"
913
assert client.headers["Content-Type"] == "application/json"
1014
assert client.headers["User-Agent"]
1115
assert client.headers["X-Python-Version"]
16+
17+
18+
def test_custom_error_handling():
19+
mock_raise_error = MagicMock()
20+
21+
response_dict = {
22+
"errors": [
23+
{
24+
"message": "Internal server error",
25+
"extensions": {"code": "INTERNAL_SERVER_ERROR"},
26+
}
27+
],
28+
}
29+
response = MagicMock()
30+
response.json.return_value = response_dict
31+
response.status_code = 200
32+
33+
client = RequestClient(
34+
sdk_version="foo", api_key="api_key", endpoint="http://localhost:8080/_gql"
35+
)
36+
connection_mock = MagicMock()
37+
connection_mock.send.return_value = response
38+
client._connection = connection_mock
39+
40+
client.execute(
41+
"query_str",
42+
{"projectId": "project_id"},
43+
raise_return_resource_not_found=True,
44+
error_handlers={"INTERNAL_SERVER_ERROR": mock_raise_error},
45+
)
46+
mock_raise_error.assert_called_once_with(response)

0 commit comments

Comments
 (0)