|
9 | 9 | import urllib.parse
|
10 | 10 | from collections import defaultdict
|
11 | 11 | from datetime import datetime, timezone
|
12 |
| -from typing import Any, List, Dict, Union, Optional, overload |
| 12 | +from typing import Any, List, Dict, Union, Optional, overload, Callable |
13 | 13 |
|
14 | 14 | import requests
|
15 | 15 | import requests.exceptions
|
@@ -138,15 +138,19 @@ def _default_headers(self):
|
138 | 138 | @retry.Retry(predicate=retry.if_exception_type(
|
139 | 139 | labelbox.exceptions.InternalServerError,
|
140 | 140 | labelbox.exceptions.TimeoutError))
|
141 |
| - def execute(self, |
142 |
| - query=None, |
143 |
| - params=None, |
144 |
| - data=None, |
145 |
| - files=None, |
146 |
| - timeout=60.0, |
147 |
| - experimental=False, |
148 |
| - error_log_key="message", |
149 |
| - raise_return_resource_not_found=False): |
| 141 | + def execute( |
| 142 | + self, |
| 143 | + query=None, |
| 144 | + params=None, |
| 145 | + data=None, |
| 146 | + files=None, |
| 147 | + timeout=60.0, |
| 148 | + experimental=False, |
| 149 | + error_log_key="message", |
| 150 | + raise_return_resource_not_found=False, |
| 151 | + error_handlers: Optional[Dict[str, Callable[[Dict[str, Any]], |
| 152 | + None]]] = None |
| 153 | + ) -> Dict[str, Any]: |
150 | 154 | """ Sends a request to the server for the execution of the
|
151 | 155 | given query.
|
152 | 156 |
|
@@ -323,7 +327,12 @@ def get_error_status_code(error: dict) -> int:
|
323 | 327 | # TODO: fix this in the server API
|
324 | 328 | internal_server_error = check_errors(["INTERNAL_SERVER_ERROR"],
|
325 | 329 | "extensions", "code")
|
| 330 | + error_code = "INTERNAL_SERVER_ERROR" |
| 331 | + |
326 | 332 | if internal_server_error is not None:
|
| 333 | + if error_handlers and error_code in error_handlers: |
| 334 | + handler = error_handlers[error_code] |
| 335 | + handler(response) |
327 | 336 | message = internal_server_error.get("message")
|
328 | 337 | error_status_code = get_error_status_code(internal_server_error)
|
329 | 338 | if error_status_code == 400:
|
|
0 commit comments