4
4
import sys
5
5
from datetime import datetime , timezone
6
6
from types import MappingProxyType
7
+ from typing import Callable , Dict , Optional
7
8
8
9
import requests
9
10
import requests .exceptions
10
11
from google .api_core import retry
11
- from lbox import exceptions
12
+ from lbox import exceptions # type: ignore
12
13
13
14
logger = logging .getLogger (__name__ )
14
15
@@ -52,9 +53,7 @@ def __init__(
52
53
"""
53
54
if api_key is None :
54
55
if _LABELBOX_API_KEY not in os .environ :
55
- raise exceptions .AuthenticationError (
56
- "Labelbox API key not provided"
57
- )
56
+ raise exceptions .AuthenticationError ("Labelbox API key not provided" )
58
57
api_key = os .environ [_LABELBOX_API_KEY ]
59
58
self .api_key = api_key
60
59
@@ -70,9 +69,7 @@ def __init__(
70
69
self ._connection : requests .Session = self ._init_connection ()
71
70
72
71
def _init_connection (self ) -> requests .Session :
73
- connection = (
74
- requests .Session ()
75
- ) # using default connection pool size of 10
72
+ connection = requests .Session () # using default connection pool size of 10
76
73
connection .headers .update (self ._default_headers ())
77
74
78
75
return connection
@@ -106,6 +103,9 @@ def execute(
106
103
experimental = False ,
107
104
error_log_key = "message" ,
108
105
raise_return_resource_not_found = False ,
106
+ error_handlers : Optional [
107
+ Dict [str , Callable [[requests .models .Response ], None ]]
108
+ ] = None ,
109
109
):
110
110
"""Sends a request to the server for the execution of the
111
111
given query.
@@ -120,6 +120,27 @@ def execute(
120
120
files (dict): file arguments for request
121
121
timeout (float): Max allowed time for query execution,
122
122
in seconds.
123
+ raise_return_resource_not_found: By default the client relies on the caller to raise the correct exception when a resource is not found.
124
+ If this is set to True, the client will raise a ResourceNotFoundError exception automatically.
125
+ This simplifies processing.
126
+ We recommend to use it only of api returns a clear and well-formed error when a resource not found for a given query.
127
+ error_handlers (dict): A dictionary mapping graphql error code to handler functions.
128
+ Allows a caller to handle specific errors reporting in a custom way or produce more user-friendly readable messages.
129
+
130
+ Example - custom error handler:
131
+ >>> def _raise_readable_errors(self, response):
132
+ >>> errors = response.json().get('errors', [])
133
+ >>> if errors:
134
+ >>> message = errors[0].get(
135
+ >>> 'message', json.dumps([{
136
+ >>> "errorMessage": "Unknown error"
137
+ >>> }]))
138
+ >>> errors = json.loads(message)
139
+ >>> error_messages = [error['errorMessage'] for error in errors]
140
+ >>> else:
141
+ >>> error_messages = ["Uknown error"]
142
+ >>> raise LabelboxError(". ".join(error_messages))
143
+
123
144
Returns:
124
145
dict, parsed JSON response.
125
146
Raises:
@@ -149,12 +170,8 @@ def convert_value(value):
149
170
150
171
if query is not None :
151
172
if params is not None :
152
- params = {
153
- key : convert_value (value ) for key , value in params .items ()
154
- }
155
- data = json .dumps ({"query" : query , "variables" : params }).encode (
156
- "utf-8"
157
- )
173
+ params = {key : convert_value (value ) for key , value in params .items ()}
174
+ data = json .dumps ({"query" : query , "variables" : params }).encode ("utf-8" )
158
175
elif data is None :
159
176
raise ValueError ("query and data cannot both be none" )
160
177
@@ -207,9 +224,7 @@ def convert_value(value):
207
224
"upstream connect error or disconnect/reset before headers"
208
225
in response .text
209
226
):
210
- raise exceptions .InternalServerError (
211
- "Connection reset"
212
- )
227
+ raise exceptions .InternalServerError ("Connection reset" )
213
228
elif response .status_code == 502 :
214
229
error_502 = "502 Bad Gateway"
215
230
raise exceptions .InternalServerError (error_502 )
@@ -234,22 +249,17 @@ def check_errors(keywords, *path):
234
249
def get_error_status_code (error : dict ) -> int :
235
250
try :
236
251
return int (error ["extensions" ].get ("exception" ).get ("status" ))
237
- except :
252
+ except Exception :
238
253
return 500
239
254
240
- if (
241
- check_errors (["AUTHENTICATION_ERROR" ], "extensions" , "code" )
242
- is not None
243
- ):
255
+ if check_errors (["AUTHENTICATION_ERROR" ], "extensions" , "code" ) is not None :
244
256
raise exceptions .AuthenticationError ("Invalid API key" )
245
257
246
258
authorization_error = check_errors (
247
259
["AUTHORIZATION_ERROR" ], "extensions" , "code"
248
260
)
249
261
if authorization_error is not None :
250
- raise exceptions .AuthorizationError (
251
- authorization_error ["message" ]
252
- )
262
+ raise exceptions .AuthorizationError (authorization_error ["message" ])
253
263
254
264
validation_error = check_errors (
255
265
["GRAPHQL_VALIDATION_FAILED" ], "extensions" , "code"
@@ -262,13 +272,9 @@ def get_error_status_code(error: dict) -> int:
262
272
else :
263
273
raise exceptions .InvalidQueryError (message )
264
274
265
- graphql_error = check_errors (
266
- ["GRAPHQL_PARSE_FAILED" ], "extensions" , "code"
267
- )
275
+ graphql_error = check_errors (["GRAPHQL_PARSE_FAILED" ], "extensions" , "code" )
268
276
if graphql_error is not None :
269
- raise exceptions .InvalidQueryError (
270
- graphql_error ["message" ]
271
- )
277
+ raise exceptions .InvalidQueryError (graphql_error ["message" ])
272
278
273
279
# Check if API limit was exceeded
274
280
response_msg = r_json .get ("message" , "" )
@@ -293,9 +299,7 @@ def get_error_status_code(error: dict) -> int:
293
299
["RESOURCE_CONFLICT" ], "extensions" , "code"
294
300
)
295
301
if resource_conflict_error is not None :
296
- raise exceptions .ResourceConflict (
297
- resource_conflict_error ["message" ]
298
- )
302
+ raise exceptions .ResourceConflict (resource_conflict_error ["message" ])
299
303
300
304
malformed_request_error = check_errors (
301
305
["MALFORMED_REQUEST" ], "extensions" , "code"
@@ -311,7 +315,13 @@ def get_error_status_code(error: dict) -> int:
311
315
internal_server_error = check_errors (
312
316
["INTERNAL_SERVER_ERROR" ], "extensions" , "code"
313
317
)
318
+ error_code = "INTERNAL_SERVER_ERROR"
319
+
314
320
if internal_server_error is not None :
321
+ if error_handlers and error_code in error_handlers :
322
+ handler = error_handlers [error_code ]
323
+ handler (response )
324
+ return None
315
325
message = internal_server_error .get ("message" )
316
326
error_status_code = get_error_status_code (internal_server_error )
317
327
if error_status_code == 400 :
@@ -343,9 +353,7 @@ def get_error_status_code(error: dict) -> int:
343
353
errors ,
344
354
)
345
355
)
346
- raise exceptions .LabelboxError (
347
- "Unknown error: %s" % str (messages )
348
- )
356
+ raise exceptions .LabelboxError ("Unknown error: %s" % str (messages ))
349
357
350
358
# if we do return a proper error code, and didn't catch this above
351
359
# reraise
0 commit comments