|
1 | 1 | import base64
|
| 2 | +import json |
2 | 3 | import logging
|
3 | 4 | import urllib.parse
|
4 |
| -from typing import Dict, Union, Optional |
| 5 | +from typing import Dict, Union, Optional, Any |
5 | 6 |
|
6 | 7 | import six
|
7 |
| -import thrift |
| 8 | +import thrift.transport.THttpClient |
8 | 9 |
|
9 | 10 | import ssl
|
10 | 11 | import warnings
|
11 | 12 | from http.client import HTTPResponse
|
12 | 13 | from io import BytesIO
|
13 | 14 |
|
| 15 | +import urllib3 |
14 | 16 | from urllib3 import HTTPConnectionPool, HTTPSConnectionPool, ProxyManager
|
15 | 17 | from urllib3.util import make_headers
|
16 | 18 | from databricks.sql.auth.retry import CommandType, DatabricksRetryPolicy
|
@@ -222,3 +224,151 @@ def set_retry_command_type(self, value: CommandType):
|
222 | 224 | logger.warning(
|
223 | 225 | "DatabricksRetryPolicy is currently bypassed. The CommandType cannot be set."
|
224 | 226 | )
|
| 227 | + |
| 228 | + def make_rest_request( |
| 229 | + self, |
| 230 | + method: str, |
| 231 | + endpoint_path: str, |
| 232 | + data: Optional[Dict[str, Any]] = None, |
| 233 | + params: Optional[Dict[str, Any]] = None, |
| 234 | + headers: Optional[Dict[str, str]] = None, |
| 235 | + ) -> Dict[str, Any]: |
| 236 | + """ |
| 237 | + Make a REST API request using the existing connection pool. |
| 238 | +
|
| 239 | + Args: |
| 240 | + method (str): HTTP method (GET, POST, DELETE, etc.) |
| 241 | + endpoint_path (str): API endpoint path (e.g., "sessions" or "statements/123") |
| 242 | + data (dict, optional): Request payload data |
| 243 | + params (dict, optional): Query parameters |
| 244 | + headers (dict, optional): Additional headers |
| 245 | +
|
| 246 | + Returns: |
| 247 | + dict: Response data parsed from JSON |
| 248 | +
|
| 249 | + Raises: |
| 250 | + RequestError: If the request fails |
| 251 | + """ |
| 252 | + # Ensure the transport is open |
| 253 | + if not self.isOpen(): |
| 254 | + self.open() |
| 255 | + |
| 256 | + # Prepare headers |
| 257 | + request_headers = { |
| 258 | + "Content-Type": "application/json", |
| 259 | + } |
| 260 | + |
| 261 | + # Add authentication headers |
| 262 | + auth_headers: Dict[str, str] = {} |
| 263 | + self.__auth_provider.add_headers(auth_headers) |
| 264 | + request_headers.update(auth_headers) |
| 265 | + |
| 266 | + # Add custom headers if provided |
| 267 | + if headers: |
| 268 | + request_headers.update(headers) |
| 269 | + |
| 270 | + # Prepare request body |
| 271 | + body = json.dumps(data).encode("utf-8") if data else None |
| 272 | + |
| 273 | + # Build query string for params |
| 274 | + query_string = "" |
| 275 | + if params: |
| 276 | + query_string = "?" + urllib.parse.urlencode(params) |
| 277 | + |
| 278 | + # Determine full path |
| 279 | + full_path = ( |
| 280 | + self.path.rstrip("/") + "/" + endpoint_path.lstrip("/") + query_string |
| 281 | + ) |
| 282 | + |
| 283 | + # Log request details (debug level) |
| 284 | + logger.debug(f"Making {method} request to {full_path}") |
| 285 | + |
| 286 | + try: |
| 287 | + # Make request using the connection pool |
| 288 | + self.__resp = self.__pool.request( |
| 289 | + method, |
| 290 | + url=full_path, |
| 291 | + body=body, |
| 292 | + headers=request_headers, |
| 293 | + preload_content=False, |
| 294 | + timeout=self.__timeout, |
| 295 | + retries=self.retry_policy, |
| 296 | + ) |
| 297 | + |
| 298 | + # Store response status and headers |
| 299 | + if self.__resp is not None: |
| 300 | + self.code = self.__resp.status |
| 301 | + self.message = self.__resp.reason |
| 302 | + self.headers = self.__resp.headers |
| 303 | + |
| 304 | + # Log response status |
| 305 | + logger.debug(f"Response status: {self.code}, message: {self.message}") |
| 306 | + |
| 307 | + # Read and parse response data |
| 308 | + # Note: urllib3's HTTPResponse has a data attribute, but it's not in the type stubs |
| 309 | + response_data = getattr(self.__resp, "data", None) |
| 310 | + |
| 311 | + # Check for HTTP errors |
| 312 | + self._check_rest_response_for_error(self.code, response_data) |
| 313 | + |
| 314 | + # Parse JSON response if there is content |
| 315 | + if response_data: |
| 316 | + result = json.loads(response_data.decode("utf-8")) |
| 317 | + |
| 318 | + # Log response content (truncated for large responses) |
| 319 | + content_str = json.dumps(result) |
| 320 | + if len(content_str) > 1000: |
| 321 | + logger.debug( |
| 322 | + f"Response content (truncated): {content_str[:1000]}..." |
| 323 | + ) |
| 324 | + else: |
| 325 | + logger.debug(f"Response content: {content_str}") |
| 326 | + |
| 327 | + return result |
| 328 | + |
| 329 | + return {} |
| 330 | + else: |
| 331 | + raise ValueError("No response received from server") |
| 332 | + |
| 333 | + except urllib3.exceptions.HTTPError as e: |
| 334 | + error_message = f"REST HTTP request failed: {str(e)}" |
| 335 | + logger.error(error_message) |
| 336 | + from databricks.sql.exc import RequestError |
| 337 | + |
| 338 | + raise RequestError(error_message, e) |
| 339 | + |
| 340 | + def _check_rest_response_for_error( |
| 341 | + self, status_code: int, response_data: Optional[bytes] |
| 342 | + ) -> None: |
| 343 | + """ |
| 344 | + Check if the REST response indicates an error and raise an appropriate exception. |
| 345 | +
|
| 346 | + Args: |
| 347 | + status_code: HTTP status code |
| 348 | + response_data: Raw response data |
| 349 | +
|
| 350 | + Raises: |
| 351 | + RequestError: If the response indicates an error |
| 352 | + """ |
| 353 | + if status_code >= 400: |
| 354 | + error_message = f"REST HTTP request failed with status {status_code}" |
| 355 | + |
| 356 | + # Try to extract error details from JSON response |
| 357 | + if response_data: |
| 358 | + try: |
| 359 | + error_details = json.loads(response_data.decode("utf-8")) |
| 360 | + if isinstance(error_details, dict) and "message" in error_details: |
| 361 | + error_message = f"{error_message}: {error_details['message']}" |
| 362 | + logger.error( |
| 363 | + f"Request failed (status {status_code}): {error_details}" |
| 364 | + ) |
| 365 | + except (ValueError, KeyError): |
| 366 | + # If we can't parse JSON, log raw content |
| 367 | + content = response_data.decode("utf-8", errors="replace") |
| 368 | + logger.error(f"Request failed (status {status_code}): {content}") |
| 369 | + else: |
| 370 | + logger.error(f"Request failed (status {status_code}): No response data") |
| 371 | + |
| 372 | + from databricks.sql.exc import RequestError |
| 373 | + |
| 374 | + raise RequestError(error_message) |
0 commit comments