diff --git a/src/databricks/sql/auth/retry.py b/src/databricks/sql/auth/retry.py index 432ac687d..4d7e78489 100755 --- a/src/databricks/sql/auth/retry.py +++ b/src/databricks/sql/auth/retry.py @@ -127,7 +127,7 @@ def __init__( total=_attempts_remaining, respect_retry_after_header=True, backoff_factor=self.delay_min, - allowed_methods=["POST"], + allowed_methods=["POST", "DELETE", "GET"], status_forcelist=[429, 503, *self.force_dangerous_codes], ) diff --git a/src/databricks/sql/auth/thrift_http_client.py b/src/databricks/sql/auth/thrift_http_client.py index f0daae162..d2a62d73f 100644 --- a/src/databricks/sql/auth/thrift_http_client.py +++ b/src/databricks/sql/auth/thrift_http_client.py @@ -1,16 +1,18 @@ import base64 +import json import logging import urllib.parse -from typing import Dict, Union, Optional +from typing import Dict, Union, Optional, Any import six -import thrift +import thrift.transport.THttpClient import ssl import warnings from http.client import HTTPResponse from io import BytesIO +import urllib3 from urllib3 import HTTPConnectionPool, HTTPSConnectionPool, ProxyManager from urllib3.util import make_headers from databricks.sql.auth.retry import CommandType, DatabricksRetryPolicy @@ -222,3 +224,105 @@ def set_retry_command_type(self, value: CommandType): logger.warning( "DatabricksRetryPolicy is currently bypassed. The CommandType cannot be set." ) + + def make_rest_request( + self, + method: str, + endpoint_path: str, + data: Optional[Dict[str, Any]] = None, + params: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, str]] = None, + ) -> Dict[str, Any]: + """ + Make a REST API request using the existing connection pool. + + Args: + method (str): HTTP method (GET, POST, DELETE, etc.) 
+ endpoint_path (str): API endpoint path (e.g., "sessions" or "statements/123") + data (dict, optional): Request payload data + params (dict, optional): Query parameters + headers (dict, optional): Additional headers + + Returns: + dict: Response data parsed from JSON + + Raises: + RequestError: If the request fails + """ + # Ensure the transport is open + if not self.isOpen(): + self.open() + + # Prepare headers + request_headers = { + "Content-Type": "application/json", + } + + # Add authentication headers + auth_headers: Dict[str, str] = {} + self.__auth_provider.add_headers(auth_headers) + request_headers.update(auth_headers) + + # Add custom headers if provided + if headers: + request_headers.update(headers) + + # Prepare request body + body = json.dumps(data).encode("utf-8") if data else None + + # Build query string for params + query_string = "" + if params: + query_string = "?" + urllib.parse.urlencode(params) + + # Determine full path + full_path = ( + self.path.rstrip("/") + "/" + endpoint_path.lstrip("/") + query_string + ) + + # Log request details (debug level) + logger.debug(f"Making {method} request to {full_path}") + + # Make request using the connection pool - let urllib3 exceptions propagate + logger.debug(f"making request to {full_path}") + logger.debug(f"\trequest headers: {request_headers}") + logger.debug(f"\trequest body: {body.decode('utf-8') if body else None}") + logger.debug(f"\trequest params: {params}") + logger.debug(f"\trequest full path: {full_path}") + self.__resp = self.__pool.request( + method, + url=full_path, + body=body, + headers=request_headers, + preload_content=False, + timeout=self.__timeout, + retries=self.retry_policy, + ) + logger.debug(f"Response: {self.__resp}") + + # Store response status and headers + if self.__resp is not None: + self.code = self.__resp.status + self.message = self.__resp.reason + self.headers = self.__resp.headers + + # Log response status + logger.debug(f"Response status: {self.code}, message: 
{self.message}") + + # Read and parse response data + # Note: urllib3's HTTPResponse has a data attribute, but it's not in the type stubs + response_data = getattr(self.__resp, "data", None) + + # Parse JSON response if there is content + if response_data: + result = json.loads(response_data.decode("utf-8")) + + # Log response content (truncated for large responses) + content_str = json.dumps(result) + logger.debug(f"Response content: {content_str}") + + return result + + return {} + else: + raise ValueError("No response received from server") diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py index 468fb4d4c..b578ba30c 100644 --- a/src/databricks/sql/backend/filters.py +++ b/src/databricks/sql/backend/filters.py @@ -13,10 +13,12 @@ cast, ) -from databricks.sql.backend.sea.backend import SeaDatabricksClient -from databricks.sql.backend.types import ExecuteResponse +from databricks.sql.backend.types import ExecuteResponse, CommandId +from databricks.sql.backend.sea.models.base import ResultData -from databricks.sql.result_set import ResultSet, SeaResultSet +if TYPE_CHECKING: + from databricks.sql.backend.sea.backend import SeaDatabricksClient + from databricks.sql.result_set import ResultSet, SeaResultSet logger = logging.getLogger(__name__) @@ -70,6 +72,8 @@ def _filter_sea_result_set( from databricks.sql.result_set import SeaResultSet # Create a new SeaResultSet with the filtered data + from databricks.sql.backend.sea.backend import SeaDatabricksClient + filtered_result_set = SeaResultSet( connection=result_set.connection, execute_response=execute_response, diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index cc188f917..90fd13bc4 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,8 +1,16 @@ +import errno import logging +import math +import threading +import uuid import time import re from typing import Any, Dict, 
Tuple, List, Optional, Union, TYPE_CHECKING, Set +import urllib3 + +import databricks +from databricks.sql.auth.retry import CommandType from databricks.sql.backend.sea.models.base import ExternalLink, ResultManifest from databricks.sql.backend.sea.utils.constants import ( ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, @@ -13,6 +21,17 @@ MetadataCommands, ) +from databricks.sql.backend.thrift_backend import ( + DATABRICKS_ERROR_OR_REDIRECT_HEADER, + DATABRICKS_REASON_HEADER, + THRIFT_ERROR_MESSAGE_HEADER, + DEFAULT_SOCKET_TIMEOUT, +) +from databricks.sql.utils import NoRetryReason, RequestErrorInfo, _bound +from databricks.sql.thrift_api.TCLIService.TCLIService import ( + Client as TCLIServiceClient, +) + if TYPE_CHECKING: from databricks.sql.client import Cursor from databricks.sql.result_set import ResultSet @@ -25,8 +44,10 @@ BackendType, ExecuteResponse, ) -from databricks.sql.exc import DatabaseError, ServerOperationError -from databricks.sql.backend.sea.utils.http_client import SeaHttpClient +from databricks.sql.exc import DatabaseError, RequestError, ServerOperationError +from databricks.sql.auth.thrift_http_client import THttpClient +from databricks.sql.backend.sea.utils.http_client_adapter import SeaHttpClientAdapter +from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions from databricks.sql.backend.sea.models import ( @@ -45,6 +66,27 @@ logger = logging.getLogger(__name__) +unsafe_logger = logging.getLogger("databricks.sql.unsafe") +unsafe_logger.setLevel(logging.DEBUG) + +# To capture these logs in client code, add a non-NullHandler. +# See our e2e test suite for an example with logging.FileHandler +unsafe_logger.addHandler(logging.NullHandler()) + +# Disable propagation so that handlers for `databricks.sql` don't pick up these messages +unsafe_logger.propagate = False + +# see Connection.__init__ for parameter descriptions. 
+# - Min/Max avoids unsustainable configs (sane values are far more constrained) +# - 900s attempts-duration lines up w ODBC/JDBC drivers (for cluster startup > 10 mins) +_retry_policy = { # (type, default, min, max) + "_retry_delay_min": (float, 1, 0.1, 60), + "_retry_delay_max": (float, 60, 5, 3600), + "_retry_stop_after_attempts_count": (int, 30, 1, 60), + "_retry_stop_after_attempts_duration": (float, 900, 1, 86400), + "_retry_delay_default": (float, 5, 1, 60), +} + def _filter_session_configuration( session_configuration: Optional[Dict[str, str]] @@ -88,8 +130,11 @@ class SeaDatabricksClient(DatabricksClient): CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel" CHUNK_PATH_WITH_ID_AND_INDEX = STATEMENT_PATH + "/{}/result/chunks/{}" - # SEA constants - POLL_INTERVAL_SECONDS = 0.2 + _retry_delay_min: float + _retry_delay_max: float + _retry_stop_after_attempts_count: int + _retry_stop_after_attempts_duration: float + _retry_delay_default: float def __init__( self, @@ -126,17 +171,337 @@ def __init__( # Extract warehouse ID from http_path self.warehouse_id = self._extract_warehouse_id(http_path) - # Initialize HTTP client - self.http_client = SeaHttpClient( - server_hostname=server_hostname, - port=port, - http_path=http_path, - http_headers=http_headers, - auth_provider=auth_provider, + self._ssl_options = ssl_options + + # Extract retry policy parameters + self._initialize_retry_args(kwargs) + self._auth_provider = auth_provider + + self.enable_v3_retries = kwargs.get("_enable_v3_retries", True) + if not self.enable_v3_retries: + logger.warning( + "Legacy retry behavior is enabled for this connection." + " This behaviour is deprecated and will be removed in a future release." 
+ ) + + self.force_dangerous_codes = kwargs.get("_retry_dangerous_codes", []) + + additional_transport_args = {} + _max_redirects: Union[None, int] = kwargs.get("_retry_max_redirects") + if _max_redirects: + if _max_redirects > self._retry_stop_after_attempts_count: + logger.warn( + "_retry_max_redirects > _retry_stop_after_attempts_count so it will have no affect!" + ) + urllib3_kwargs = {"redirect": _max_redirects} + else: + urllib3_kwargs = {} + + if self.enable_v3_retries: + self.retry_policy = databricks.sql.auth.thrift_http_client.DatabricksRetryPolicy( + delay_min=self._retry_delay_min, + delay_max=self._retry_delay_max, + stop_after_attempts_count=self._retry_stop_after_attempts_count, + stop_after_attempts_duration=self._retry_stop_after_attempts_duration, + delay_default=self._retry_delay_default, + force_dangerous_codes=self.force_dangerous_codes, + urllib3_kwargs=urllib3_kwargs, + ) + + additional_transport_args["retry_policy"] = self.retry_policy + + # Initialize ThriftHttpClient + self._transport = databricks.sql.auth.thrift_http_client.THttpClient( + auth_provider=self._auth_provider, + uri_or_host=f"https://{server_hostname}:{port}", ssl_options=self._ssl_options, - **kwargs, + **additional_transport_args, # type: ignore ) + timeout = kwargs.get("_socket_timeout", DEFAULT_SOCKET_TIMEOUT) + # setTimeout defaults to 15 minutes and is expected in ms + self._transport.setTimeout(timeout and (float(timeout) * 1000.0)) + + self._transport.setCustomHeaders(dict(http_headers)) + + # protocol = thrift.protocol.TBinaryProtocol.TBinaryProtocol(self._transport) + # self._client = TCLIServiceClient(protocol) + + try: + self._transport.open() + except: + self._transport.close() + raise + + self._request_lock = threading.RLock() + + # Initialize HTTP client adapter + self.http_client = SeaHttpClientAdapter(thrift_client=self._transport) + + # TODO: Move this bounding logic into DatabricksRetryPolicy for v3 (PECO-918) + def _initialize_retry_args(self, kwargs): + # 
Configure retries & timing: use user-settings or defaults, and bound + # by policy. Log.warn when given param gets restricted. + for key, (type_, default, min, max) in _retry_policy.items(): + given_or_default = type_(kwargs.get(key, default)) + bound = _bound(min, max, given_or_default) + setattr(self, key, bound) + logger.debug( + "retry parameter: {} given_or_default {}".format(key, given_or_default) + ) + if bound != given_or_default: + logger.warning( + "Override out of policy retry parameter: " + + "{} given {}, restricted to {}".format( + key, given_or_default, bound + ) + ) + + # Fail on retry delay min > max; consider later adding fail on min > duration? + if ( + self._retry_stop_after_attempts_count > 1 + and self._retry_delay_min > self._retry_delay_max + ): + raise ValueError( + "Invalid configuration enables retries with retry delay min(={}) > max(={})".format( + self._retry_delay_min, self._retry_delay_max + ) + ) + + @staticmethod + def _check_response_for_error(response): + pass # TODO: implement + + @staticmethod + def _extract_error_message_from_headers(headers): + err_msg = "" + if THRIFT_ERROR_MESSAGE_HEADER in headers: + err_msg = headers[THRIFT_ERROR_MESSAGE_HEADER] + if DATABRICKS_ERROR_OR_REDIRECT_HEADER in headers: + if ( + err_msg + ): # We don't expect both to be set, but log both here just in case + err_msg = "Thriftserver error: {}, Databricks error: {}".format( + err_msg, headers[DATABRICKS_ERROR_OR_REDIRECT_HEADER] + ) + else: + err_msg = headers[DATABRICKS_ERROR_OR_REDIRECT_HEADER] + if DATABRICKS_REASON_HEADER in headers: + err_msg += ": " + headers[DATABRICKS_REASON_HEADER] + + if not err_msg: + # if authentication token is invalid we need this branch + if DATABRICKS_REASON_HEADER in headers: + err_msg += ": " + headers[DATABRICKS_REASON_HEADER] + + return err_msg + + def _handle_request_error(self, error_info, attempt, elapsed): + max_attempts = self._retry_stop_after_attempts_count + max_duration_s = 
self._retry_stop_after_attempts_duration + + if ( + error_info.retry_delay is not None + and elapsed + error_info.retry_delay > max_duration_s + ): + no_retry_reason = NoRetryReason.OUT_OF_TIME + elif error_info.retry_delay is not None and attempt >= max_attempts: + no_retry_reason = NoRetryReason.OUT_OF_ATTEMPTS + elif error_info.retry_delay is None: + no_retry_reason = NoRetryReason.NOT_RETRYABLE + else: + no_retry_reason = None + + full_error_info_context = error_info.full_info_logging_context( + no_retry_reason, attempt, max_attempts, elapsed, max_duration_s + ) + + if no_retry_reason is not None: + user_friendly_error_message = error_info.user_friendly_error_message( + no_retry_reason, attempt, elapsed + ) + logger.info(f"User friendly error message: {user_friendly_error_message}") + network_request_error = RequestError( + user_friendly_error_message, full_error_info_context, error_info.error + ) + logger.info(network_request_error.message_with_context()) + + raise network_request_error + + logger.info( + "Retrying request after error in {} seconds: {}".format( + error_info.retry_delay, full_error_info_context + ) + ) + time.sleep(error_info.retry_delay) + + # FUTURE: Consider moving to https://github.com/litl/backoff or + # https://github.com/jd/tenacity for retry logic. + def make_request(self, method_name, path, data, params, headers, retryable=True): + """Execute given request, attempting retries when + 1. Receiving HTTP 429/503 from server + 2. OSError is raised during a GetOperationStatus + + For delay between attempts, honor the given Retry-After header, but with bounds. + Use lower bound of exponential-backoff based on _retry_delay_min, + and upper bound of _retry_delay_max. + Will stop retry attempts if total elapsed time + next retry delay would exceed + _retry_stop_after_attempts_duration. + """ + + # basic strategy: build range iterator rep'ing number of available + # retries. bounds can be computed from there. 
iterate over it with + # retries until success or final failure achieved. + + t0 = time.time() + + def get_elapsed(): + return time.time() - t0 + + def bound_retry_delay(attempt, proposed_delay): + """bound delay (seconds) by [min_delay*1.5^(attempt-1), max_delay]""" + delay = int(proposed_delay) + delay = max(delay, self._retry_delay_min * math.pow(1.5, attempt - 1)) + delay = min(delay, self._retry_delay_max) + return delay + + def extract_retry_delay(attempt): + # encapsulate retry checks, returns None || delay-in-secs + # Retry IFF 429/503 code + Retry-After header set + http_code = getattr(self._transport, "code", None) + retry_after = getattr(self._transport, "headers", {}).get("Retry-After", 1) + if http_code in [429, 503]: + # bound delay (seconds) by [min_delay*1.5^(attempt-1), max_delay] + return bound_retry_delay(attempt, int(retry_after)) + return None + + def attempt_request(attempt): + # splits out lockable attempt, from delay & retry loop + # returns tuple: (method_return, delay_fn(), error, error_message) + # - non-None method_return -> success, return and be done + # - non-None retry_delay -> sleep delay before retry + # - error, error_message always set when available + + error, error_message, retry_delay = None, None, None + try: + logger.debug("Sending request: {}()".format(method_name)) + unsafe_logger.debug("Sending request: {}".format(path)) + + # These three lines are no-ops if the v3 retry policy is not in use + if self.enable_v3_retries: + command_type = self.http_client._determine_command_type( + path, method_name, data + ) + self.http_client.thrift_client.set_retry_command_type(command_type) + self.http_client.thrift_client.startRetryTimer() + + if method_name == "GET": + response = self.http_client.get(path, params, headers) + elif method_name == "POST": + response = self.http_client.post(path, data, params, headers) + elif method_name == "DELETE": + response = self.http_client.delete(path, data, params, headers) + else: + raise 
ValueError(f"Unsupported method: {method_name}") + + return response + + except urllib3.exceptions.HTTPError as err: + # retry on timeout. Happens a lot in Azure and it is safe as data has not been sent to server yet + + # TODO: don't use exception handling for GOS polling... + + logger.error("ThriftBackend.attempt_request: HTTPError: %s", err) + + if command_type == CommandType.GET_OPERATION_STATUS: + delay_default = ( + self.enable_v3_retries + and self.retry_policy.delay_default + or self._retry_delay_default + ) + retry_delay = bound_retry_delay(attempt, delay_default) + logger.info( + f"GetOperationStatus failed with HTTP error and will be retried: {str(err)}" + ) + else: + raise err + except OSError as err: + error = err + error_message = str(err) + # fmt: off + # The built-in errno package encapsulates OSError codes, which are OS-specific. + # log.info for errors we believe are not unusual or unexpected. log.warn for + # others like EEXIST, EBADF, ERANGE which are not expected in this context. + # + # I manually tested this retry behaviour using mitmweb and confirmed that + # GetOperationStatus requests are retried when I forced network connection + # interruptions / timeouts / reconnects. See #24 for more info. + # | Debian | Darwin | + info_errs = [ # |--------|--------| + errno.ESHUTDOWN, # | 32 | 32 | + errno.EAFNOSUPPORT, # | 97 | 47 | + errno.ECONNRESET, # | 104 | 54 | + errno.ETIMEDOUT, # | 110 | 60 | + ] + + # retry on timeout. 
Happens a lot in Azure and it is safe as data has not been sent to server yet + if command_type == CommandType.GET_OPERATION_STATUS or err.errno == errno.ETIMEDOUT: + retry_delay = bound_retry_delay(attempt, self._retry_delay_default) + + # fmt: on + log_string = f"{command_type} failed with code {err.errno} and will attempt to retry" + if err.errno in info_errs: + logger.info(log_string) + else: + logger.warning(log_string) + except Exception as err: + logger.error("ThriftBackend.attempt_request: Exception: %s", err) + error = err + retry_delay = extract_retry_delay(attempt) + error_message = SeaDatabricksClient._extract_error_message_from_headers( + getattr(self._transport, "headers", {}) + ) + finally: + # Calling `close()` here releases the active HTTP connection back to the pool + self._transport.close() + + return RequestErrorInfo( + error=error, + error_message=error_message, + retry_delay=retry_delay, + http_code=getattr(self._transport, "code", None), + method=method_name, + request=data, + ) + + # The real work: + # - for each available attempt: + # lock-and-attempt + # return on success + # if available: bounded delay and retry + # if not: raise error + max_attempts = self._retry_stop_after_attempts_count if retryable else 1 + + # use index-1 counting for logging/human consistency + for attempt in range(1, max_attempts + 1): + # We have a lock here because .cancel can be called from a separate thread. 
+ # We do not want threads to be simultaneously sharing the Thrift Transport + # because we use its state to determine retries + with self._request_lock: + response_or_error_info = attempt_request(attempt) + elapsed = get_elapsed() + + # conditions: success, non-retry-able, no-attempts-left, no-time-left, delay+retry + if not isinstance(response_or_error_info, RequestErrorInfo): + # log nothing here, presume that main request logging covers + response = response_or_error_info + SeaDatabricksClient._check_response_for_error(response) + return response + + error_info = response_or_error_info + # The error handler will either sleep or throw an exception + self._handle_request_error(error_info, attempt, elapsed) + def _extract_warehouse_id(self, http_path: str) -> str: """ Extract the warehouse ID from the HTTP path. @@ -219,22 +584,22 @@ def open_session( schema=schema, ) - response = self.http_client._make_request( - method="POST", path=self.SESSION_PATH, data=request_data.to_dict() - ) - - session_response = CreateSessionResponse.from_dict(response) - session_id = session_response.session_id - if not session_id: - raise ServerOperationError( - "Failed to create session: No session ID returned", - { - "operation-id": None, - "diagnostic-info": None, - }, + try: + response = self.make_request( + method_name="POST", + path=self.SESSION_PATH, + data=request_data.to_dict(), + params=None, + headers=None, ) - return SessionId.from_sea_session_id(session_id) + session_response = CreateSessionResponse.from_dict(response) + session_id = session_response.session_id + + return SessionId.from_sea_session_id(session_id) + except Exception as e: + logger.error("SeaDatabricksClient.open_session: Exception: %s", e) + raise def close_session(self, session_id: SessionId) -> None: """ @@ -259,11 +624,17 @@ def close_session(self, session_id: SessionId) -> None: session_id=sea_session_id, ) - self.http_client._make_request( - method="DELETE", - 
path=self.SESSION_PATH_WITH_ID.format(sea_session_id), - data=request_data.to_dict(), - ) + try: + self.make_request( + method_name="DELETE", + path=self.SESSION_PATH_WITH_ID.format(sea_session_id), + data=request_data.to_dict(), + params=None, + headers=None, + ) + except Exception as e: + logger.error("SeaDatabricksClient.close_session: Exception: %s", e) + raise @staticmethod def get_default_session_configuration_value(name: str) -> Optional[str]: @@ -325,8 +696,38 @@ def _extract_description_from_manifest( return columns if columns else None + def get_chunk_link(self, statement_id: str, chunk_index: int) -> ExternalLink: + """ + Get links for chunks starting from the specified index. + + Args: + statement_id: The statement ID + chunk_index: The starting chunk index + + Returns: + ExternalLink: External link for the chunk + """ + + response_data = self.http_client.get( + path=self.CHUNK_PATH_WITH_ID_AND_INDEX.format(statement_id, chunk_index), + ) + response = GetChunksResponse.from_dict(response_data) + + links = response.external_links + link = next((l for l in links if l.chunk_index == chunk_index), None) + if not link: + raise ServerOperationError( + f"No link found for chunk index {chunk_index}", + { + "operation-id": statement_id, + "diagnostic-info": None, + }, + ) + + return link + def _results_message_to_execute_response( - self, response: GetStatementResponse + self, response: GetStatementResponse, command_id: CommandId ) -> ExecuteResponse: """ Convert a SEA response to an ExecuteResponse and extract result data. 
@@ -405,7 +806,7 @@ def execute_command( lz4_compression: bool, cursor: "Cursor", use_cloud_fetch: bool, - parameters: List[Dict[str, Any]], + parameters: List[ttypes.TSparkParameter], async_op: bool, enforce_embedded_schema_correctness: bool, ) -> Union["ResultSet", None]: @@ -439,9 +840,9 @@ def execute_command( for param in parameters: sea_parameters.append( StatementParameter( - name=param["name"], - value=param["value"], - type=param["type"] if "type" in param else None, + name=param.name, + value=param.value.stringValue, + type=param.type, ) ) @@ -470,31 +871,56 @@ def execute_command( result_compression=result_compression, ) - response_data = self.http_client._make_request( - method="POST", path=self.STATEMENT_PATH, data=request.to_dict() - ) - response = ExecuteStatementResponse.from_dict(response_data) - statement_id = response.statement_id - if not statement_id: - raise ServerOperationError( - "Failed to execute command: No statement ID returned", - { - "operation-id": None, - "diagnostic-info": None, - }, + try: + response_data = self.make_request( + method_name="POST", + path=self.STATEMENT_PATH, + data=request.to_dict(), + params=None, + headers=None, ) + response = ExecuteStatementResponse.from_dict(response_data) + statement_id = response.statement_id + if not statement_id: + raise ServerOperationError( + "Failed to execute command: No statement ID returned", + { + "operation-id": None, + "diagnostic-info": None, + }, + ) - command_id = CommandId.from_sea_statement_id(statement_id) + command_id = CommandId.from_sea_statement_id(statement_id) - # Store the command ID in the cursor - cursor.active_command_id = command_id + # Store the command ID in the cursor + cursor.active_command_id = command_id - # If async operation, return and let the client poll for results - if async_op: - return None + # If async operation, return and let the client poll for results + if async_op: + return None - self._wait_until_command_done(response) - return 
self.get_execution_result(command_id, cursor) + # For synchronous operation, wait for the statement to complete + status = response.status + state = status.state + + # Keep polling until we reach a terminal state + while state in [CommandState.PENDING, CommandState.RUNNING]: + time.sleep(0.5) # add a small delay to avoid excessive API calls + state = self.get_query_state(command_id) + + if state in [CommandState.FAILED, CommandState.CLOSED]: + raise DatabaseError( + f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}", + { + "operation-id": command_id.to_sea_statement_id(), + "diagnostic-info": None, + }, + ) + + return self.get_execution_result(command_id, cursor) + except Exception as e: + logger.error("SeaDatabricksClient.execute_command: Exception: %s", e) + raise def cancel_command(self, command_id: CommandId) -> None: """ @@ -513,11 +939,17 @@ def cancel_command(self, command_id: CommandId) -> None: sea_statement_id = command_id.to_sea_statement_id() request = CancelStatementRequest(statement_id=sea_statement_id) - self.http_client._make_request( - method="POST", - path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), - ) + try: + self.make_request( + method_name="POST", + path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), + params=None, + headers=None, + ) + except Exception as e: + logger.error("SeaDatabricksClient.cancel_command: Exception: %s", e) + raise def close_command(self, command_id: CommandId) -> None: """ @@ -536,11 +968,17 @@ def close_command(self, command_id: CommandId) -> None: sea_statement_id = command_id.to_sea_statement_id() request = CloseStatementRequest(statement_id=sea_statement_id) - self.http_client._make_request( - method="DELETE", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), - ) + try: + self.make_request( + method_name="DELETE", + 
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), + params=None, + headers=None, + ) + except Exception as e: + logger.error("SeaDatabricksClient.close_command: Exception: %s", e) + raise def get_query_state(self, command_id: CommandId) -> CommandState: """ @@ -562,15 +1000,21 @@ def get_query_state(self, command_id: CommandId) -> CommandState: sea_statement_id = command_id.to_sea_statement_id() request = GetStatementRequest(statement_id=sea_statement_id) - response_data = self.http_client._make_request( - method="GET", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), - ) + try: + response_data = self.make_request( + method_name="GET", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=None, + params=None, + headers=None, + ) - # Parse the response - response = GetStatementResponse.from_dict(response_data) - return response.status.state + # Parse the response + response = GetStatementResponse.from_dict(response_data) + return response.status.state + except Exception as e: + logger.error("SeaDatabricksClient.get_query_state: Exception: %s", e) + raise def get_execution_result( self, @@ -599,28 +1043,37 @@ def get_execution_result( # Create the request model request = GetStatementRequest(statement_id=sea_statement_id) - # Get the statement result - response_data = self.http_client._make_request( - method="GET", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), - ) - response = GetStatementResponse.from_dict(response_data) + try: + # Get the statement result + response_data = self.make_request( + method_name="GET", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=None, + params=None, + headers=None, + ) + response = GetStatementResponse.from_dict(response_data) - # Create and return a SeaResultSet - from databricks.sql.result_set import SeaResultSet + # Create and return a SeaResultSet + from databricks.sql.result_set import 
SeaResultSet - execute_response = self._results_message_to_execute_response(response) + # Convert the response to an ExecuteResponse and extract result data + execute_response = self._results_message_to_execute_response( + response, command_id + ) - return SeaResultSet( - connection=cursor.connection, - execute_response=execute_response, - sea_client=self, - buffer_size_bytes=cursor.buffer_size_bytes, - arraysize=cursor.arraysize, - result_data=response.result, - manifest=response.manifest, - ) + return SeaResultSet( + connection=cursor.connection, + execute_response=execute_response, + sea_client=self, + buffer_size_bytes=cursor.buffer_size_bytes, + arraysize=cursor.arraysize, + result_data=response.result, + manifest=response.manifest, + ) + except Exception as e: + logger.error("SeaDatabricksClient.get_execution_result: Exception: %s", e) + raise def get_chunk_link(self, statement_id: str, chunk_index: int) -> ExternalLink: """ @@ -632,9 +1085,12 @@ def get_chunk_link(self, statement_id: str, chunk_index: int) -> ExternalLink: ExternalLink: External link for the chunk """ - response_data = self.http_client._make_request( - method="GET", + response_data = self.make_request( + method_name="GET", path=self.CHUNK_PATH_WITH_ID_AND_INDEX.format(statement_id, chunk_index), + data=None, + params=None, + headers=None, ) response = GetChunksResponse.from_dict(response_data) @@ -721,9 +1177,6 @@ def get_tables( table_types: Optional[List[str]] = None, ) -> "ResultSet": """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_tables") - operation = ( MetadataCommands.SHOW_TABLES_ALL_CATALOGS.value if catalog_name in [None, "*", "%"] diff --git a/src/databricks/sql/backend/sea/utils/__init__.py b/src/databricks/sql/backend/sea/utils/__init__.py new file mode 100644 index 000000000..6e83dfe25 --- /dev/null +++ b/src/databricks/sql/backend/sea/utils/__init__.py @@ 
-0,0 +1,21 @@ +""" +Utility modules for the Statement Execution API (SEA) backend. +""" + +from databricks.sql.backend.sea.utils.http_client_adapter import SeaHttpClientAdapter +from databricks.sql.backend.sea.utils.constants import ( + ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, + ResultFormat, + ResultDisposition, + ResultCompression, + WaitTimeout, +) + +__all__ = [ + "SeaHttpClientAdapter", + "ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP", + "ResultFormat", + "ResultDisposition", + "ResultCompression", + "WaitTimeout", +] diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py deleted file mode 100644 index fe292919c..000000000 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ /dev/null @@ -1,186 +0,0 @@ -import json -import logging -import requests -from typing import Callable, Dict, Any, Optional, List, Tuple -from urllib.parse import urljoin - -from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.types import SSLOptions - -logger = logging.getLogger(__name__) - - -class SeaHttpClient: - """ - HTTP client for Statement Execution API (SEA). - - This client handles the HTTP communication with the SEA endpoints, - including authentication, request formatting, and response parsing. - """ - - def __init__( - self, - server_hostname: str, - port: int, - http_path: str, - http_headers: List[Tuple[str, str]], - auth_provider: AuthProvider, - ssl_options: SSLOptions, - **kwargs, - ): - """ - Initialize the SEA HTTP client. 
- - Args: - server_hostname: Hostname of the Databricks server - port: Port number for the connection - http_path: HTTP path for the connection - http_headers: List of HTTP headers to include in requests - auth_provider: Authentication provider - ssl_options: SSL configuration options - **kwargs: Additional keyword arguments - """ - - self.server_hostname = server_hostname - self.port = port - self.http_path = http_path - self.auth_provider = auth_provider - self.ssl_options = ssl_options - - self.base_url = f"https://{server_hostname}:{port}" - - self.headers: Dict[str, str] = dict(http_headers) - self.headers.update({"Content-Type": "application/json"}) - - self.max_retries = kwargs.get("_retry_stop_after_attempts_count", 30) - - # Create a session for connection pooling - self.session = requests.Session() - - # Configure SSL verification - if ssl_options.tls_verify: - self.session.verify = ssl_options.tls_trusted_ca_file or True - else: - self.session.verify = False - - # Configure client certificates if provided - if ssl_options.tls_client_cert_file: - client_cert = ssl_options.tls_client_cert_file - client_key = ssl_options.tls_client_cert_key_file - client_key_password = ssl_options.tls_client_cert_key_password - - if client_key: - self.session.cert = (client_cert, client_key) - else: - self.session.cert = client_cert - - if client_key_password: - # Note: requests doesn't directly support key passwords - # This would require more complex handling with libraries like pyOpenSSL - logger.warning( - "Client key password provided but not supported by requests library" - ) - - def _get_auth_headers(self) -> Dict[str, str]: - """Get authentication headers from the auth provider.""" - headers: Dict[str, str] = {} - self.auth_provider.add_headers(headers) - return headers - - def _get_call(self, method: str) -> Callable: - """Get the appropriate HTTP method function.""" - method = method.upper() - if method == "GET": - return self.session.get - if method == "POST": - 
return self.session.post - if method == "DELETE": - return self.session.delete - raise ValueError(f"Unsupported HTTP method: {method}") - - def _make_request( - self, - method: str, - path: str, - data: Optional[Dict[str, Any]] = None, - params: Optional[Dict[str, Any]] = None, - ) -> Dict[str, Any]: - """ - Make an HTTP request to the SEA endpoint. - - Args: - method: HTTP method (GET, POST, DELETE) - path: API endpoint path - data: Request payload data - params: Query parameters - - Returns: - Dict[str, Any]: Response data parsed from JSON - - Raises: - RequestError: If the request fails - """ - - url = urljoin(self.base_url, path) - headers: Dict[str, str] = {**self.headers, **self._get_auth_headers()} - - logger.debug(f"making {method} request to {url}") - - try: - call = self._get_call(method) - response = call( - url=url, - headers=headers, - json=data, - params=params, - ) - - # Check for HTTP errors - response.raise_for_status() - - # Log response details - logger.debug(f"Response status: {response.status_code}") - - # Parse JSON response - if response.content: - result = response.json() - # Log response content (but limit it for large responses) - content_str = json.dumps(result) - if len(content_str) > 1000: - logger.debug( - f"Response content (truncated): {content_str[:1000]}..." 
class SeaHttpClientAdapter:
    """
    Adapter for using ThriftHttpClient with SEA API.

    This class provides a simplified interface for HTTP methods (GET, POST,
    DELETE) while delegating the actual HTTP work to a ``THttpClient``
    instance via its ``make_rest_request`` method.

    Before every request the adapter classifies the call into a
    ``CommandType`` and hands that to the Thrift client, then starts the
    client's retry timer, so the retry policy is told which kind of command
    is in flight.
    """

    # SEA API paths
    BASE_PATH = "/api/2.0/sql/"

    def __init__(
        self,
        thrift_client: THttpClient,
    ):
        """
        Initialize the SEA HTTP client adapter.

        Args:
            thrift_client: ThriftHttpClient instance to use for HTTP operations
        """
        self.thrift_client = thrift_client

    def _determine_command_type(
        self, path: str, method: str, data: Optional[Dict[str, Any]] = None
    ) -> CommandType:
        """
        Determine the CommandType based on the request path and method.

        Classification is by substring match on ``path`` ("statements",
        "sessions", "cancel") combined with an exact, case-sensitive match on
        ``method``.

        Args:
            path: API endpoint path
            method: HTTP method (GET, POST, DELETE), upper-case
            data: Request payload data (accepted for interface symmetry;
                not consulted by the current classification rules)

        Returns:
            CommandType: The appropriate CommandType enum value;
            ``CommandType.OTHER`` when no rule matches.
        """
        if "statements" in path:
            if method == "POST":
                # POST .../statements/{id}/cancel ends an operation; any
                # other POST on a statements path submits one.
                if "cancel" in path:
                    return CommandType.CLOSE_OPERATION
                return CommandType.EXECUTE_STATEMENT
            if method == "GET":
                return CommandType.GET_OPERATION_STATUS
            if method == "DELETE":
                return CommandType.CLOSE_OPERATION
        elif "sessions" in path:
            # Creating a session (POST) intentionally falls through to OTHER.
            if method == "DELETE":
                return CommandType.CLOSE_SESSION

        # Default for any other operations
        return CommandType.OTHER

    def _prepare_retry(
        self, path: str, method: str, data: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Tell the Thrift client which command is about to run and start its retry timer.

        Without this step the request is issued with whatever command type
        the client was last given, so the retry policy cannot apply
        per-command rules to SEA requests.

        Args:
            path: API endpoint path of the upcoming request
            method: HTTP method of the upcoming request
            data: Request payload data, forwarded to the classifier
        """
        command_type = self._determine_command_type(path, method, data)
        self.thrift_client.set_retry_command_type(command_type)
        # NOTE(review): startRetryTimer presumably records the request start
        # for duration-based retry limits - confirm against THttpClient.
        self.thrift_client.startRetryTimer()

    def get(
        self,
        path: str,
        params: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None,
    ) -> Dict[str, Any]:
        """
        Convenience method for GET requests with retry support.

        Args:
            path: API endpoint path
            params: Query parameters
            headers: Additional headers

        Returns:
            Response data parsed from JSON
        """
        self._prepare_retry(path, "GET")
        return self.thrift_client.make_rest_request(
            "GET", path, params=params, headers=headers
        )

    def post(
        self,
        path: str,
        data: Optional[Dict[str, Any]] = None,
        params: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None,
    ) -> Dict[str, Any]:
        """
        Convenience method for POST requests with retry support.

        Args:
            path: API endpoint path
            data: Request payload data
            params: Query parameters
            headers: Additional headers

        Returns:
            Response data parsed from JSON
        """
        self._prepare_retry(path, "POST", data)
        return self.thrift_client.make_rest_request(
            "POST", path, data=data, params=params, headers=headers
        )

    def delete(
        self,
        path: str,
        data: Optional[Dict[str, Any]] = None,
        params: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None,
    ) -> Dict[str, Any]:
        """
        Convenience method for DELETE requests with retry support.

        Args:
            path: API endpoint path
            data: Request payload data
            params: Query parameters
            headers: Additional headers

        Returns:
            Response data parsed from JSON
        """
        self._prepare_retry(path, "DELETE", data)
        return self.thrift_client.make_rest_request(
            "DELETE", path, data=data, params=params, headers=headers
        )
""" minutes = 60 - min_duration = 5 * minutes + min_duration = 3 * minutes duration = -1 scale0 = 10000 diff --git a/tests/e2e/common/retry_test_mixins.py b/tests/e2e/common/retry_test_mixins.py index dd509c062..169885adc 100755 --- a/tests/e2e/common/retry_test_mixins.py +++ b/tests/e2e/common/retry_test_mixins.py @@ -1,4 +1,5 @@ from contextlib import contextmanager +import json import time from typing import Optional, List from unittest.mock import MagicMock, PropertyMock, patch @@ -13,6 +14,7 @@ RequestError, SessionAlreadyClosedError, UnsafeToRetryError, + DatabaseError, ) @@ -75,6 +77,42 @@ def mocked_server_response( False if redirect_location is None else redirect_location ) + # For the SEA backend, we need to provide proper JSON response data + if status >= 400: + # For error responses, provide proper SEA error structure + # This could be either a session creation error or statement execution error + error_response = { + "statement_id": "test-statement-123", + "status": { + "state": "FAILED", + "error": { + "message": f"HTTP {status} error - Server unavailable", + "error_code": f"HTTP_{status}_ERROR", + }, + }, + } + mock_response.data = json.dumps(error_response).encode("utf-8") + else: + # For success responses, provide proper SEA session response + success_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "schema": [ + { + "name": "col1", + "type_name": "STRING", + "type_text": "string", + "nullable": True, + } + ], + "total_row_count": 1, + "total_byte_count": 100, + }, + "result": {"data": [["value1"]]}, + } + mock_response.data = json.dumps(success_response).encode("utf-8") + with patch("urllib3.connectionpool.HTTPSConnectionPool._get_conn") as getconn_mock: getconn_mock.return_value.getresponse.return_value = mock_response try: @@ -105,6 +143,43 @@ def mock_sequential_server_responses(responses: List[dict]): _mock.get_redirect_location.return_value = ( False if resp["redirect_location"] is 
None else resp["redirect_location"] ) + + # Add proper SEA response data based on the status code + status = resp["status"] + if status >= 400: + # For error responses, provide proper SEA error structure + error_response = { + "statement_id": "test-statement-123", + "status": { + "state": "FAILED", + "error": { + "message": f"HTTP {status} error - Server unavailable", + "error_code": f"HTTP_{status}_ERROR", + }, + }, + } + _mock.data = json.dumps(error_response).encode("utf-8") + else: + # For success responses, provide proper SEA session response + success_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "schema": [ + { + "name": "col1", + "type_name": "STRING", + "type_text": "string", + "nullable": True, + } + ], + "total_row_count": 1, + "total_byte_count": 100, + }, + "result": {"data": [["value1"]]}, + } + _mock.data = json.dumps(success_response).encode("utf-8") + mock_responses.append(_mock) with patch("urllib3.connectionpool.HTTPSConnectionPool._get_conn") as getconn_mock: @@ -257,9 +332,13 @@ def test_retry_dangerous_codes(self): with conn.cursor() as cursor: for dangerous_code in DANGEROUS_CODES: with mocked_server_response(status=dangerous_code): - with pytest.raises(RequestError) as cm: + with pytest.raises((RequestError, DatabaseError)) as cm: cursor.execute("Not a real query") - assert isinstance(cm.value.args[1], UnsafeToRetryError) + # For SEA backend, dangerous codes result in DatabaseError + # when the statement execution fails with proper error response + # For Thrift backend, it should raise RequestError with UnsafeToRetryError + if isinstance(cm.value, RequestError): + assert isinstance(cm.value.args[1], UnsafeToRetryError) # Prove that these codes are retried if forced by the user with self.connection( diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 8cfed7c28..80d7d05cc 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -90,6 +90,7 @@ def 
connection_params(self): params = { "server_hostname": self.arguments["host"], "http_path": self.arguments["http_path"], + "use_sea": True, **self.auth_params(), } @@ -116,7 +117,7 @@ def connection(self, extra_params=()): def cursor(self, extra_params=()): with self.connection(extra_params) as conn: cursor = conn.cursor( - arraysize=self.arraysize, buffer_size_bytes=self.buffer_size_bytes + arraysize=None, buffer_size_bytes=self.buffer_size_bytes ) try: yield cursor diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index d75359f2f..65171f123 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -30,9 +30,12 @@ class TestSeaBackend: def mock_http_client(self): """Create a mock HTTP client.""" with patch( - "databricks.sql.backend.sea.backend.SeaHttpClient" + "databricks.sql.backend.sea.backend.SeaHttpClientAdapter" ) as mock_client_class: mock_client = mock_client_class.return_value + mock_client.get.return_value = {} + mock_client.post.return_value = {} + mock_client.delete.return_value = {} yield mock_client @pytest.fixture @@ -143,18 +146,18 @@ def test_initialization(self, mock_http_client): def test_session_management(self, sea_client, mock_http_client, thrift_session_id): """Test session management methods.""" # Test open_session with minimal parameters - mock_http_client._make_request.return_value = {"session_id": "test-session-123"} + mock_http_client.post.return_value = {"session_id": "test-session-123"} session_id = sea_client.open_session(None, None, None) assert isinstance(session_id, SessionId) assert session_id.backend_type == BackendType.SEA assert session_id.guid == "test-session-123" - mock_http_client._make_request.assert_called_with( - method="POST", path=sea_client.SESSION_PATH, data={"warehouse_id": "abc123"} + mock_http_client.post.assert_called_with( + path=sea_client.SESSION_PATH, data={"warehouse_id": "abc123"} ) # Test open_session with all parameters mock_http_client.reset_mock() - 
mock_http_client._make_request.return_value = {"session_id": "test-session-456"} + mock_http_client.post.return_value = {"session_id": "test-session-456"} session_config = { "ANSI_MODE": "FALSE", # Supported parameter "STATEMENT_TIMEOUT": "3600", # Supported parameter @@ -173,13 +176,13 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i "catalog": catalog, "schema": schema, } - mock_http_client._make_request.assert_called_with( - method="POST", path=sea_client.SESSION_PATH, data=expected_data + mock_http_client.post.assert_called_with( + path=sea_client.SESSION_PATH, data=expected_data ) # Test open_session error handling mock_http_client.reset_mock() - mock_http_client._make_request.return_value = {} + mock_http_client.post.return_value = {} with pytest.raises(Error) as excinfo: sea_client.open_session(None, None, None) assert "Failed to create session" in str(excinfo.value) @@ -188,8 +191,7 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i mock_http_client.reset_mock() session_id = SessionId.from_sea_session_id("test-session-789") sea_client.close_session(session_id) - mock_http_client._make_request.assert_called_with( - method="DELETE", + mock_http_client.delete.assert_called_with( path=sea_client.SESSION_PATH_WITH_ID.format("test-session-789"), data={"session_id": "test-session-789", "warehouse_id": "abc123"}, ) @@ -221,7 +223,7 @@ def test_command_execution_sync( }, "result": {"data": [["value1"]]}, } - mock_http_client._make_request.return_value = execute_response + mock_http_client.post.return_value = execute_response with patch.object( sea_client, "get_execution_result", return_value="mock_result_set" @@ -273,7 +275,7 @@ def test_command_execution_async( "statement_id": "test-statement-456", "status": {"state": "PENDING"}, } - mock_http_client._make_request.return_value = execute_response + mock_http_client.post.return_value = execute_response result = sea_client.execute_command( 
operation="SELECT 1", @@ -293,7 +295,7 @@ def test_command_execution_async( # Test async with missing statement ID mock_http_client.reset_mock() - mock_http_client._make_request.return_value = {"status": {"state": "PENDING"}} + mock_http_client.post.return_value = {"status": {"state": "PENDING"}} with pytest.raises(ServerOperationError) as excinfo: sea_client.execute_command( operation="SELECT 1", @@ -326,7 +328,8 @@ def test_command_execution_advanced( "manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0}, "result": {"data": []}, } - mock_http_client._make_request.side_effect = [initial_response, poll_response] + mock_http_client.post.side_effect = [initial_response] + mock_http_client.get.side_effect = [poll_response] with patch.object( sea_client, "get_execution_result", return_value="mock_result_set" @@ -348,13 +351,16 @@ def test_command_execution_advanced( # Test with parameters mock_http_client.reset_mock() - mock_http_client._make_request.side_effect = None # Reset side_effect + mock_http_client.post.side_effect = None # Reset side_effect execute_response = { "statement_id": "test-statement-123", "status": {"state": "SUCCEEDED"}, } - mock_http_client._make_request.return_value = execute_response - param = {"name": "param1", "value": "value1", "type": "STRING"} + mock_http_client.post.return_value = execute_response + param = MagicMock() + param.name = "param1" + param.value = "value1" + param.type = "STRING" with patch.object(sea_client, "get_execution_result"): sea_client.execute_command( @@ -369,7 +375,7 @@ def test_command_execution_advanced( async_op=False, enforce_embedded_schema_correctness=False, ) - args, kwargs = mock_http_client._make_request.call_args + args, kwargs = mock_http_client.post.call_args assert "parameters" in kwargs["data"] assert len(kwargs["data"]["parameters"]) == 1 assert kwargs["data"]["parameters"][0]["name"] == "param1" @@ -388,7 +394,7 @@ def test_command_execution_advanced( }, }, } - 
mock_http_client._make_request.return_value = error_response + mock_http_client.post.return_value = error_response with patch("time.sleep"): with patch.object( @@ -411,7 +417,7 @@ def test_command_execution_advanced( # Test missing statement ID mock_http_client.reset_mock() - mock_http_client._make_request.return_value = {"status": {"state": "SUCCEEDED"}} + mock_http_client.post.return_value = {"status": {"state": "SUCCEEDED"}} with pytest.raises(ServerOperationError) as excinfo: sea_client.execute_command( operation="SELECT 1", @@ -439,10 +445,9 @@ def test_command_management( ): """Test command management methods.""" # Test cancel_command - mock_http_client._make_request.return_value = {} + mock_http_client.post.return_value = {} sea_client.cancel_command(sea_command_id) - mock_http_client._make_request.assert_called_with( - method="POST", + mock_http_client.post.assert_called_with( path=sea_client.CANCEL_STATEMENT_PATH_WITH_ID.format("test-statement-123"), data={"statement_id": "test-statement-123"}, ) @@ -455,8 +460,7 @@ def test_command_management( # Test close_command mock_http_client.reset_mock() sea_client.close_command(sea_command_id) - mock_http_client._make_request.assert_called_with( - method="DELETE", + mock_http_client.delete.assert_called_with( path=sea_client.STATEMENT_PATH_WITH_ID.format("test-statement-123"), data={"statement_id": "test-statement-123"}, ) @@ -468,16 +472,14 @@ def test_command_management( # Test get_query_state mock_http_client.reset_mock() - mock_http_client._make_request.return_value = { + mock_http_client.get.return_value = { "statement_id": "test-statement-123", "status": {"state": "RUNNING"}, } state = sea_client.get_query_state(sea_command_id) assert state == CommandState.RUNNING - mock_http_client._make_request.assert_called_with( - method="GET", + mock_http_client.get.assert_called_with( path=sea_client.STATEMENT_PATH_WITH_ID.format("test-statement-123"), - data={"statement_id": "test-statement-123"}, ) # Test 
class TestSeaHttpClientAdapter:
    """Unit tests for SeaHttpClientAdapter: command classification and request forwarding."""

    @pytest.fixture
    def mock_thrift_client(self):
        """Provide a MagicMock standing in for the Thrift HTTP client."""
        thrift_stub = MagicMock()
        thrift_stub.make_rest_request.return_value = {}
        return thrift_stub

    @pytest.fixture
    def adapter(self, mock_thrift_client):
        """Build the adapter under test on top of the mocked Thrift client."""
        return SeaHttpClientAdapter(thrift_client=mock_thrift_client)

    def test_determine_command_type(self, adapter):
        """Each (path, method) pair maps to the expected CommandType."""
        # (path, HTTP method, expected classification), checked in order.
        expectations = [
            # statement execution
            ("/api/2.0/sql/statements", "POST", CommandType.EXECUTE_STATEMENT),
            # get operation status
            ("/api/2.0/sql/statements/123", "GET", CommandType.GET_OPERATION_STATUS),
            # cancel operation
            (
                "/api/2.0/sql/statements/123/cancel",
                "POST",
                CommandType.CLOSE_OPERATION,
            ),
            # close operation
            ("/api/2.0/sql/statements/123", "DELETE", CommandType.CLOSE_OPERATION),
            # close session
            ("/api/2.0/sql/sessions/123", "DELETE", CommandType.CLOSE_SESSION),
            # everything else
            ("/api/2.0/sql/sessions", "POST", CommandType.OTHER),
            ("/api/2.0/sql/other", "GET", CommandType.OTHER),
        ]
        for request_path, http_method, expected_type in expectations:
            assert (
                adapter._determine_command_type(request_path, http_method)
                == expected_type
            )

    def test_http_methods_set_command_type(self, adapter, mock_thrift_client):
        """Each HTTP helper sets the command type, starts the retry timer, then forwards."""
        # (adapter call, path, expected command type, expected forwarded args/kwargs)
        scenarios = [
            (
                adapter.get,
                "/api/2.0/sql/statements/123",
                CommandType.GET_OPERATION_STATUS,
                ("GET", "/api/2.0/sql/statements/123"),
                {"params": None, "headers": None},
            ),
            (
                adapter.post,
                "/api/2.0/sql/statements",
                CommandType.EXECUTE_STATEMENT,
                ("POST", "/api/2.0/sql/statements"),
                {"data": None, "params": None, "headers": None},
            ),
            (
                adapter.delete,
                "/api/2.0/sql/statements/123",
                CommandType.CLOSE_OPERATION,
                ("DELETE", "/api/2.0/sql/statements/123"),
                {"data": None, "params": None, "headers": None},
            ),
        ]
        for position, scenario in enumerate(scenarios):
            invoke, request_path, expected_type, sent_args, sent_kwargs = scenario
            if position:
                # Clear recorded calls left over from the previous scenario.
                mock_thrift_client.reset_mock()

            invoke(request_path)

            mock_thrift_client.set_retry_command_type.assert_called_with(expected_type)
            mock_thrift_client.startRetryTimer.assert_called_once()
            mock_thrift_client.make_rest_request.assert_called_with(
                *sent_args, **sent_kwargs
            )

    def test_http_methods_with_parameters(self, adapter, mock_thrift_client):
        """Query parameters, headers and payload are forwarded unchanged."""
        params = {"param1": "value1"}
        headers = {"header1": "value1"}
        statement_path = "/api/2.0/sql/statements/123"

        # GET carries only params and headers.
        adapter.get(statement_path, params=params, headers=headers)
        mock_thrift_client.make_rest_request.assert_called_with(
            "GET", statement_path, params=params, headers=headers
        )

        mock_thrift_client.reset_mock()

        # POST additionally carries a JSON payload.
        data = {"key": "value"}
        adapter.post(
            "/api/2.0/sql/statements", data=data, params=params, headers=headers
        )
        mock_thrift_client.make_rest_request.assert_called_with(
            "POST", "/api/2.0/sql/statements", data=data, params=params, headers=headers
        )

        mock_thrift_client.reset_mock()

        # DELETE accepts the same payload, params and headers.
        adapter.delete(statement_path, data=data, params=params, headers=headers)
        mock_thrift_client.make_rest_request.assert_called_with(
            "DELETE",
            statement_path,
            data=data,
            params=params,
            headers=headers,
        )
fixtures.""" + # Mock proxy functions + mock_getproxies.return_value = {} + mock_proxy_bypass.return_value = True + + # Create auth provider mock + self.mock_auth_provider = Mock() + self.mock_auth_provider.add_headers = Mock() + + # Create HTTP client + self.uri = "https://example.com/path" + self.http_client = THttpClient( + auth_provider=self.mock_auth_provider, + uri_or_host=self.uri, + ssl_options=SSLOptions(), + ) + + # Mock the connection pool + self.mock_pool = Mock() + self.http_client._THttpClient__pool = self.mock_pool + + # Set custom headers to include User-Agent (required by the class) + self.http_client._headers = {"User-Agent": "test-agent"} + self.http_client.__custom_headers = {"User-Agent": "test-agent"} + + # Set timeout + self.http_client._THttpClient__timeout = None + + def test_check_rest_response_for_error_success(self): + """Test _check_rest_response_for_error with success status.""" + # No exception should be raised for status codes < 400 + self.http_client._check_rest_response_for_error(200, None) + self.http_client._check_rest_response_for_error(201, None) + self.http_client._check_rest_response_for_error(302, None) + # No assertion needed - test passes if no exception is raised + + def test_check_rest_response_for_error_client_error(self): + """Test _check_rest_response_for_error with client error status.""" + # Setup response data with error message + response_data = json.dumps({"message": "Bad request"}).encode("utf-8") + + # Check that exception is raised for client error + with self.assertRaises(RequestError) as context: + self.http_client._check_rest_response_for_error(400, response_data) + + # Verify the exception message + self.assertIn( + "REST HTTP request failed with status 400", str(context.exception) + ) + self.assertIn("Bad request", str(context.exception)) + + def test_check_rest_response_for_error_server_error(self): + """Test _check_rest_response_for_error with server error status.""" + # Setup response data with error 
message + response_data = json.dumps({"message": "Internal server error"}).encode("utf-8") + + # Check that exception is raised for server error + with self.assertRaises(RequestError) as context: + self.http_client._check_rest_response_for_error(500, response_data) + + # Verify the exception message + self.assertIn( + "REST HTTP request failed with status 500", str(context.exception) + ) + self.assertIn("Internal server error", str(context.exception)) + + def test_check_rest_response_for_error_no_message(self): + """Test _check_rest_response_for_error with error but no message.""" + # Check that exception is raised with generic message + with self.assertRaises(RequestError) as context: + self.http_client._check_rest_response_for_error(404, None) + + # Verify the exception message + self.assertIn( + "REST HTTP request failed with status 404", str(context.exception) + ) + + def test_check_rest_response_for_error_invalid_json(self): + """Test _check_rest_response_for_error with invalid JSON response.""" + # Setup invalid JSON response + response_data = "Not a JSON response".encode("utf-8") + + # Check that exception is raised with generic message + with self.assertRaises(RequestError) as context: + self.http_client._check_rest_response_for_error(500, response_data) + + # Verify the exception message + self.assertIn( + "REST HTTP request failed with status 500", str(context.exception) + ) + + @patch("databricks.sql.auth.thrift_http_client.THttpClient.isOpen") + @patch("databricks.sql.auth.thrift_http_client.THttpClient.open") + @patch( + "databricks.sql.auth.thrift_http_client.THttpClient._check_rest_response_for_error" + ) + def test_make_rest_request_success(self, mock_check_error, mock_open, mock_is_open): + """Test the make_rest_request method with a successful response.""" + # Setup mocks + mock_is_open.return_value = False # To trigger open() call + + # Create a mock response + mock_response = Mock() + mock_response.status = 200 + mock_response.reason = "OK" + 
mock_response.headers = {"Content-Type": "application/json"} + mock_response.data = json.dumps({"result": "success"}).encode("utf-8") + + # Configure the mock pool to return our mock response + self.mock_pool.request.return_value = mock_response + + # Call the method under test + result = self.http_client.make_rest_request( + method="GET", endpoint_path="test/endpoint", params={"param": "value"} + ) + + # Verify the result + self.assertEqual(result, {"result": "success"}) + + # Verify open was called + mock_open.assert_called_once() + + # Verify the request was made with correct parameters + self.mock_pool.request.assert_called_once() + + # Check URL contains the parameters + args, kwargs = self.mock_pool.request.call_args + self.assertIn("test/endpoint?param=value", kwargs["url"]) + + # Verify error check was called + mock_check_error.assert_called_once_with(200, mock_response.data) + + # Verify auth headers were added + self.mock_auth_provider.add_headers.assert_called_once() + + @patch("databricks.sql.auth.thrift_http_client.THttpClient.isOpen") + @patch("databricks.sql.auth.thrift_http_client.THttpClient.open") + @patch( + "databricks.sql.auth.thrift_http_client.THttpClient._check_rest_response_for_error" + ) + def test_make_rest_request_with_data( + self, mock_check_error, mock_open, mock_is_open + ): + """Test the make_rest_request method with data payload.""" + # Setup mocks + mock_is_open.return_value = True # Connection is already open + + # Create a mock response + mock_response = Mock() + mock_response.status = 200 + mock_response.reason = "OK" + mock_response.headers = {"Content-Type": "application/json"} + mock_response.data = json.dumps({"result": "success"}).encode("utf-8") + + # Configure the mock pool to return our mock response + self.mock_pool.request.return_value = mock_response + + # Call the method under test with data + data = {"key": "value"} + result = self.http_client.make_rest_request( + method="POST", endpoint_path="test/endpoint", 
data=data + ) + + # Verify the result + self.assertEqual(result, {"result": "success"}) + + # Verify open was not called (connection already open) + mock_open.assert_not_called() + + # Verify the request was made with correct parameters + self.mock_pool.request.assert_called_once() + + # Check body contains the JSON data + args, kwargs = self.mock_pool.request.call_args + self.assertEqual(kwargs["body"], json.dumps(data).encode("utf-8")) + + # Verify error check was called + mock_check_error.assert_called_once_with(200, mock_response.data) + + @patch("databricks.sql.auth.thrift_http_client.THttpClient.isOpen") + @patch("databricks.sql.auth.thrift_http_client.THttpClient.open") + @patch( + "databricks.sql.auth.thrift_http_client.THttpClient._check_rest_response_for_error" + ) + def test_make_rest_request_with_custom_headers( + self, mock_check_error, mock_open, mock_is_open + ): + """Test the make_rest_request method with custom headers.""" + # Setup mocks + mock_is_open.return_value = True # Connection is already open + + # Create a mock response + mock_response = Mock() + mock_response.status = 200 + mock_response.reason = "OK" + mock_response.headers = {"Content-Type": "application/json"} + mock_response.data = json.dumps({"result": "success"}).encode("utf-8") + + # Configure the mock pool to return our mock response + self.mock_pool.request.return_value = mock_response + + # Call the method under test with custom headers + custom_headers = {"X-Custom-Header": "custom-value"} + result = self.http_client.make_rest_request( + method="GET", endpoint_path="test/endpoint", headers=custom_headers + ) + + # Verify the result + self.assertEqual(result, {"result": "success"}) + + # Verify the request was made with correct headers + self.mock_pool.request.assert_called_once() + + # Check headers contain the custom header + args, kwargs = self.mock_pool.request.call_args + headers = kwargs["headers"] + self.assertIn("X-Custom-Header", headers) + 
self.assertEqual(headers["X-Custom-Header"], "custom-value") + self.assertEqual(headers["Content-Type"], "application/json") + + # Verify error check was called + mock_check_error.assert_called_once_with(200, mock_response.data) + + @patch("databricks.sql.auth.thrift_http_client.THttpClient.isOpen") + @patch("databricks.sql.auth.thrift_http_client.THttpClient.open") + def test_make_rest_request_http_error(self, mock_open, mock_is_open): + """Test the make_rest_request method with an HTTP error.""" + # Setup mocks + mock_is_open.return_value = True # Connection is already open + + # Configure the mock pool to raise an HTTP error + http_error = urllib3.exceptions.HTTPError("HTTP Error") + self.mock_pool.request.side_effect = http_error + + # Call the method under test and expect an exception + with self.assertRaises(RequestError) as context: + self.http_client.make_rest_request( + method="GET", endpoint_path="test/endpoint" + ) + + # Verify the exception message + self.assertIn("REST HTTP request failed", str(context.exception)) + self.assertIn("HTTP Error", str(context.exception)) + + @patch("databricks.sql.auth.thrift_http_client.THttpClient.isOpen") + @patch("databricks.sql.auth.thrift_http_client.THttpClient.open") + @patch( + "databricks.sql.auth.thrift_http_client.THttpClient._check_rest_response_for_error" + ) + def test_make_rest_request_empty_response( + self, mock_check_error, mock_open, mock_is_open + ): + """Test the make_rest_request method with an empty response.""" + # Setup mocks + mock_is_open.return_value = True # Connection is already open + + # Create a mock response with empty data + mock_response = Mock() + mock_response.status = 204 # No Content + mock_response.reason = "No Content" + mock_response.headers = {"Content-Type": "application/json"} + mock_response.data = None # Empty response + + # Configure the mock pool to return our mock response + self.mock_pool.request.return_value = mock_response + + # Call the method under test + result = 
self.http_client.make_rest_request( + method="DELETE", endpoint_path="test/endpoint/123" + ) + + # Verify the result is an empty dict + self.assertEqual(result, {}) + + # Verify error check was called with None data + mock_check_error.assert_called_once_with(204, None) + + @patch("databricks.sql.auth.thrift_http_client.THttpClient.isOpen") + @patch("databricks.sql.auth.thrift_http_client.THttpClient.open") + def test_make_rest_request_no_response(self, mock_open, mock_is_open): + """Test the make_rest_request method with no response.""" + # Setup mocks + mock_is_open.return_value = True # Connection is already open + + # Configure the mock pool to return None + self.mock_pool.request.return_value = None + + # Call the method under test and expect an exception + with self.assertRaises(ValueError) as context: + self.http_client.make_rest_request( + method="GET", endpoint_path="test/endpoint" + ) + + # Verify the exception message + self.assertEqual(str(context.exception), "No response received from server") + + @patch("databricks.sql.auth.thrift_http_client.THttpClient.isOpen") + @patch("databricks.sql.auth.thrift_http_client.THttpClient.open") + @patch( + "databricks.sql.auth.thrift_http_client.THttpClient._check_rest_response_for_error" + ) + def test_make_rest_request_with_retry_policy( + self, mock_check_error, mock_open, mock_is_open + ): + """Test the make_rest_request method with a retry policy.""" + # Setup mocks + mock_is_open.return_value = True # Connection is already open + + # Create a mock response + mock_response = Mock() + mock_response.status = 200 + mock_response.reason = "OK" + mock_response.headers = {"Content-Type": "application/json"} + mock_response.data = json.dumps({"result": "success"}).encode("utf-8") + + # Configure the mock pool to return our mock response + self.mock_pool.request.return_value = mock_response + + # Create a retry policy mock + mock_retry_policy = Mock(spec=DatabricksRetryPolicy) + + # Set the retry policy on the client + 
self.http_client.retry_policy = mock_retry_policy + + # Call the method under test + result = self.http_client.make_rest_request( + method="GET", endpoint_path="test/endpoint" + ) + + # Verify the result + self.assertEqual(result, {"result": "success"}) + + # Verify the request was made with the retry policy + self.mock_pool.request.assert_called_once() + + # Check retries parameter + args, kwargs = self.mock_pool.request.call_args + self.assertEqual(kwargs["retries"], mock_retry_policy) + + @patch("databricks.sql.auth.thrift_http_client.THttpClient.isOpen") + @patch("databricks.sql.auth.thrift_http_client.THttpClient.open") + @patch( + "databricks.sql.auth.thrift_http_client.THttpClient._check_rest_response_for_error" + ) + def test_make_rest_request_invalid_json_response( + self, mock_check_error, mock_open, mock_is_open + ): + """Test the make_rest_request method with invalid JSON response.""" + # Setup mocks + mock_is_open.return_value = True # Connection is already open + + # Create a mock response with invalid JSON + mock_response = Mock() + mock_response.status = 200 + mock_response.reason = "OK" + mock_response.headers = {"Content-Type": "application/json"} + mock_response.data = "Not a valid JSON".encode("utf-8") + + # Configure the mock pool to return our mock response + self.mock_pool.request.return_value = mock_response + + # Call the method under test and expect a JSON decode error + with self.assertRaises(json.JSONDecodeError): + self.http_client.make_rest_request( + method="GET", endpoint_path="test/endpoint" + ) + + # Verify error check was called before the JSON parsing + mock_check_error.assert_called_once_with(200, mock_response.data) + + +if __name__ == "__main__": + unittest.main()