diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py new file mode 100644 index 000000000..0337d8d06 --- /dev/null +++ b/src/databricks/sql/backend/databricks_client.py @@ -0,0 +1,342 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Dict, List, Optional, Any, Union, TYPE_CHECKING + +if TYPE_CHECKING: + from databricks.sql.client import Cursor + from databricks.sql.result_set import ResultSet + +from databricks.sql.thrift_api.TCLIService import ttypes +from databricks.sql.backend.types import SessionId, CommandId, CommandState + + +class DatabricksClient(ABC): + """ + Abstract client interface for interacting with Databricks SQL services. + + Implementations of this class are responsible for: + - Managing connections to Databricks SQL services + - Executing SQL queries and commands + - Retrieving query results + - Fetching metadata about catalogs, schemas, tables, and columns + """ + + # == Connection and Session Management == + @abstractmethod + def open_session( + self, + session_configuration: Optional[Dict[str, Any]], + catalog: Optional[str], + schema: Optional[str], + ) -> SessionId: + """ + Opens a new session with the Databricks SQL service. + + This method establishes a new session with the server and returns a session + identifier that can be used for subsequent operations. 
+ + Args: + session_configuration: Optional dictionary of configuration parameters for the session + catalog: Optional catalog name to use as the initial catalog for the session + schema: Optional schema name to use as the initial schema for the session + + Returns: + SessionId: A session identifier object that can be used for subsequent operations + + Raises: + Error: If the session configuration is invalid + OperationalError: If there's an error establishing the session + InvalidServerResponseError: If the server response is invalid or unexpected + """ + pass + + @abstractmethod + def close_session(self, session_id: SessionId) -> None: + """ + Closes an existing session with the Databricks SQL service. + + This method terminates the session identified by the given session ID and + releases any resources associated with it. + + Args: + session_id: The session identifier returned by open_session() + + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error closing the session + """ + pass + + # == Query Execution, Command Management == + @abstractmethod + def execute_command( + self, + operation: str, + session_id: SessionId, + max_rows: int, + max_bytes: int, + lz4_compression: bool, + cursor: Cursor, + use_cloud_fetch: bool, + parameters: List[ttypes.TSparkParameter], + async_op: bool, + enforce_embedded_schema_correctness: bool, + ) -> Union[ResultSet, None]: + """ + Executes a SQL command or query within the specified session. + + This method sends a SQL command to the server for execution and handles + the response. It can operate in both synchronous and asynchronous modes. 
+ + Args: + operation: The SQL command or query to execute + session_id: The session identifier in which to execute the command + max_rows: Maximum number of rows to fetch in a single fetch batch + max_bytes: Maximum number of bytes to fetch in a single fetch batch + lz4_compression: Whether to use LZ4 compression for result data + cursor: The cursor object that will handle the results + use_cloud_fetch: Whether to use cloud fetch for retrieving large result sets + parameters: List of parameters to bind to the query + async_op: Whether to execute the command asynchronously + enforce_embedded_schema_correctness: Whether to enforce schema correctness + + Returns: + If async_op is False, returns a ResultSet object containing the + query results and metadata. If async_op is True, returns None and the + results must be fetched later using get_execution_result(). + + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error executing the command + ServerOperationError: If the server encounters an error during execution + """ + pass + + @abstractmethod + def cancel_command(self, command_id: CommandId) -> None: + """ + Cancels a running command or query. + + This method attempts to cancel a command that is currently being executed. + It can be called from a different thread than the one executing the command. + + Args: + command_id: The command identifier to cancel + + Raises: + ValueError: If the command ID is invalid + OperationalError: If there's an error canceling the command + """ + pass + + @abstractmethod + def close_command(self, command_id: CommandId) -> None: + """ + Closes a command and releases associated resources. + + This method informs the server that the client is done with the command + and any resources associated with it can be released. 
+ + Args: + command_id: The command identifier to close + + Raises: + ValueError: If the command ID is invalid + OperationalError: If there's an error closing the command + """ + pass + + @abstractmethod + def get_query_state(self, command_id: CommandId) -> CommandState: + """ + Gets the current state of a query or command. + + This method retrieves the current execution state of a command from the server. + + Args: + command_id: The command identifier to check + + Returns: + CommandState: The current state of the command + + Raises: + ValueError: If the command ID is invalid + OperationalError: If there's an error retrieving the state + ServerOperationError: If the command is in an error state + DatabaseError: If the command has been closed unexpectedly + """ + pass + + @abstractmethod + def get_execution_result( + self, + command_id: CommandId, + cursor: Cursor, + ) -> ResultSet: + """ + Retrieves the results of a previously executed command. + + This method fetches the results of a command that was executed asynchronously + or retrieves additional results from a command that has more rows available. + + Args: + command_id: The command identifier for which to retrieve results + cursor: The cursor object that will handle the results + + Returns: + ResultSet: An object containing the query results and metadata + + Raises: + ValueError: If the command ID is invalid + OperationalError: If there's an error retrieving the results + """ + pass + + # == Metadata Operations == + @abstractmethod + def get_catalogs( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: Cursor, + ) -> ResultSet: + """ + Retrieves a list of available catalogs. + + This method fetches metadata about all catalogs available in the current + session's context. 
+ + Args: + session_id: The session identifier + max_rows: Maximum number of rows to fetch in a single batch + max_bytes: Maximum number of bytes to fetch in a single batch + cursor: The cursor object that will handle the results + + Returns: + ResultSet: An object containing the catalog metadata + + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error retrieving the catalogs + """ + pass + + @abstractmethod + def get_schemas( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: Cursor, + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + ) -> ResultSet: + """ + Retrieves a list of schemas, optionally filtered by catalog and schema name patterns. + + This method fetches metadata about schemas available in the specified catalog + or all catalogs if no catalog is specified. + + Args: + session_id: The session identifier + max_rows: Maximum number of rows to fetch in a single batch + max_bytes: Maximum number of bytes to fetch in a single batch + cursor: The cursor object that will handle the results + catalog_name: Optional catalog name pattern to filter by + schema_name: Optional schema name pattern to filter by + + Returns: + ResultSet: An object containing the schema metadata + + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error retrieving the schemas + """ + pass + + @abstractmethod + def get_tables( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: Cursor, + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + table_name: Optional[str] = None, + table_types: Optional[List[str]] = None, + ) -> ResultSet: + """ + Retrieves a list of tables, optionally filtered by catalog, schema, table name, and table types. + + This method fetches metadata about tables available in the specified catalog + and schema, or all catalogs and schemas if not specified. 
+ + Args: + session_id: The session identifier + max_rows: Maximum number of rows to fetch in a single batch + max_bytes: Maximum number of bytes to fetch in a single batch + cursor: The cursor object that will handle the results + catalog_name: Optional catalog name pattern to filter by + schema_name: Optional schema name pattern to filter by + table_name: Optional table name pattern to filter by + table_types: Optional list of table types to filter by (e.g., ['TABLE', 'VIEW']) + + Returns: + ResultSet: An object containing the table metadata + + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error retrieving the tables + """ + pass + + @abstractmethod + def get_columns( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: Cursor, + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + table_name: Optional[str] = None, + column_name: Optional[str] = None, + ) -> ResultSet: + """ + Retrieves a list of columns, optionally filtered by catalog, schema, table, and column name patterns. + + This method fetches metadata about columns available in the specified table, + or all tables if not specified. 
+ + Args: + session_id: The session identifier + max_rows: Maximum number of rows to fetch in a single batch + max_bytes: Maximum number of bytes to fetch in a single batch + cursor: The cursor object that will handle the results + catalog_name: Optional catalog name pattern to filter by + schema_name: Optional schema name pattern to filter by + table_name: Optional table name pattern to filter by + column_name: Optional column name pattern to filter by + + Returns: + ResultSet: An object containing the column metadata + + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error retrieving the columns + """ + pass + + @property + @abstractmethod + def max_download_threads(self) -> int: + """ + Gets the maximum number of download threads for cloud fetch operations. + + Returns: + int: The maximum number of download threads + """ + pass diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py similarity index 81% rename from src/databricks/sql/thrift_backend.py rename to src/databricks/sql/backend/thrift_backend.py index e3dc38ad5..514d937d8 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -1,13 +1,24 @@ -from decimal import Decimal +from __future__ import annotations + import errno import logging import math import time -import uuid import threading -from typing import List, Union +from typing import Union, TYPE_CHECKING + +from databricks.sql.result_set import ThriftResultSet -from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState +if TYPE_CHECKING: + from databricks.sql.client import Cursor + from databricks.sql.result_set import ResultSet + +from databricks.sql.backend.types import ( + CommandState, + SessionId, + CommandId, +) +from databricks.sql.backend.utils.guid_utils import guid_to_hex_id try: import pyarrow @@ -41,6 +52,7 @@ convert_column_based_set_to_arrow_table, ) from databricks.sql.types import SSLOptions 
+from databricks.sql.backend.databricks_client import DatabricksClient logger = logging.getLogger(__name__) @@ -73,9 +85,9 @@ } -class ThriftBackend: - CLOSED_OP_STATE = ttypes.TOperationState.CLOSED_STATE - ERROR_OP_STATE = ttypes.TOperationState.ERROR_STATE +class ThriftDatabricksClient(DatabricksClient): + CLOSED_OP_STATE = CommandState.CLOSED + ERROR_OP_STATE = CommandState.FAILED _retry_delay_min: float _retry_delay_max: float @@ -91,7 +103,6 @@ def __init__( http_headers, auth_provider: AuthProvider, ssl_options: SSLOptions, - staging_allowed_local_path: Union[None, str, List[str]] = None, **kwargs, ): # Internal arguments in **kwargs: @@ -150,7 +161,6 @@ def __init__( else: raise ValueError("No valid connection settings.") - self.staging_allowed_local_path = staging_allowed_local_path self._initialize_retry_args(kwargs) self._use_arrow_native_complex_types = kwargs.get( "_use_arrow_native_complex_types", True @@ -161,7 +171,7 @@ def __init__( ) # Cloud fetch - self.max_download_threads = kwargs.get("max_download_threads", 10) + self._max_download_threads = kwargs.get("max_download_threads", 10) self._ssl_options = ssl_options @@ -224,6 +234,10 @@ def __init__( self._request_lock = threading.RLock() + @property + def max_download_threads(self) -> int: + return self._max_download_threads + # TODO: Move this bounding logic into DatabricksRetryPolicy for v3 (PECO-918) def _initialize_retry_args(self, kwargs): # Configure retries & timing: use user-settings or defaults, and bound @@ -337,6 +351,7 @@ def make_request(self, method, request, retryable=True): Will stop retry attempts if total elapsed time + next retry delay would exceed _retry_stop_after_attempts_duration. """ + # basic strategy: build range iterator rep'ing number of available # retries. bounds can be computed from there. iterate over it with # retries until success or final failure achieved. 
@@ -446,8 +461,10 @@ def attempt_request(attempt): logger.error("ThriftBackend.attempt_request: Exception: %s", err) error = err retry_delay = extract_retry_delay(attempt) - error_message = ThriftBackend._extract_error_message_from_headers( - getattr(self._transport, "headers", {}) + error_message = ( + ThriftDatabricksClient._extract_error_message_from_headers( + getattr(self._transport, "headers", {}) + ) ) finally: # Calling `close()` here releases the active HTTP connection back to the pool @@ -483,7 +500,7 @@ def attempt_request(attempt): if not isinstance(response_or_error_info, RequestErrorInfo): # log nothing here, presume that main request logging covers response = response_or_error_info - ThriftBackend._check_response_for_error(response) + ThriftDatabricksClient._check_response_for_error(response) return response error_info = response_or_error_info @@ -534,7 +551,7 @@ def _check_session_configuration(self, session_configuration): ) ) - def open_session(self, session_configuration, catalog, schema): + def open_session(self, session_configuration, catalog, schema) -> SessionId: try: self._transport.open() session_configuration = { @@ -562,13 +579,22 @@ def open_session(self, session_configuration, catalog, schema): response = self.make_request(self._client.OpenSession, open_session_req) self._check_initial_namespace(catalog, schema, response) self._check_protocol_version(response) - return response + properties = ( + {"serverProtocolVersion": response.serverProtocolVersion} + if response.serverProtocolVersion + else {} + ) + return SessionId.from_thrift_handle(response.sessionHandle, properties) except: self._transport.close() raise - def close_session(self, session_handle) -> None: - req = ttypes.TCloseSessionReq(sessionHandle=session_handle) + def close_session(self, session_id: SessionId) -> None: + thrift_handle = session_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift session ID") + + req = 
ttypes.TCloseSessionReq(sessionHandle=thrift_handle) try: self.make_request(self._client.CloseSession, req) finally: @@ -583,7 +609,7 @@ def _check_command_not_in_error_or_closed_state( get_operations_resp.displayMessage, { "operation-id": op_handle - and self.guid_to_hex_id(op_handle.operationId.guid), + and guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": get_operations_resp.diagnosticInfo, }, ) @@ -592,18 +618,18 @@ def _check_command_not_in_error_or_closed_state( get_operations_resp.errorMessage, { "operation-id": op_handle - and self.guid_to_hex_id(op_handle.operationId.guid), + and guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": None, }, ) elif get_operations_resp.operationState == ttypes.TOperationState.CLOSED_STATE: raise DatabaseError( "Command {} unexpectedly closed server side".format( - op_handle and self.guid_to_hex_id(op_handle.operationId.guid) + op_handle and guid_to_hex_id(op_handle.operationId.guid) ), { "operation-id": op_handle - and self.guid_to_hex_id(op_handle.operationId.guid) + and guid_to_hex_id(op_handle.operationId.guid) }, ) @@ -707,7 +733,8 @@ def _col_to_description(col): @staticmethod def _hive_schema_to_description(t_table_schema): return [ - ThriftBackend._col_to_description(col) for col in t_table_schema.columns + ThriftDatabricksClient._col_to_description(col) + for col in t_table_schema.columns ] def _results_message_to_execute_response(self, resp, operation_state): @@ -767,28 +794,36 @@ def _results_message_to_execute_response(self, resp, operation_state): ) else: arrow_queue_opt = None + + command_id = CommandId.from_thrift_handle(resp.operationHandle) + if command_id is None: + raise ValueError(f"Invalid Thrift handle: {resp.operationHandle}") + return ExecuteResponse( arrow_queue=arrow_queue_opt, - status=operation_state, + status=CommandState.from_thrift_state(operation_state), has_been_closed_server_side=has_been_closed_server_side, has_more_rows=has_more_rows, lz4_compressed=lz4_compressed, 
is_staging_operation=is_staging_operation, - command_handle=resp.operationHandle, + command_id=command_id, description=description, arrow_schema_bytes=schema_bytes, ) - def get_execution_result(self, op_handle, cursor): - - assert op_handle is not None + def get_execution_result( + self, command_id: CommandId, cursor: Cursor + ) -> "ResultSet": + thrift_handle = command_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift command ID") req = ttypes.TFetchResultsReq( operationHandle=ttypes.TOperationHandle( - op_handle.operationId, - op_handle.operationType, + thrift_handle.operationId, + thrift_handle.operationType, False, - op_handle.modifiedRowCount, + thrift_handle.modifiedRowCount, ), maxRows=cursor.arraysize, maxBytes=cursor.buffer_size_bytes, @@ -827,18 +862,27 @@ def get_execution_result(self, op_handle, cursor): ssl_options=self._ssl_options, ) - return ExecuteResponse( + execute_response = ExecuteResponse( arrow_queue=queue, - status=resp.status, + status=CommandState.from_thrift_state(resp.status), has_been_closed_server_side=False, has_more_rows=has_more_rows, lz4_compressed=lz4_compressed, is_staging_operation=is_staging_operation, - command_handle=op_handle, + command_id=command_id, description=description, arrow_schema_bytes=schema_bytes, ) + return ThriftResultSet( + connection=cursor.connection, + execute_response=execute_response, + thrift_client=self, + buffer_size_bytes=cursor.buffer_size_bytes, + arraysize=cursor.arraysize, + use_cloud_fetch=cursor.connection.use_cloud_fetch, + ) + def _wait_until_command_done(self, op_handle, initial_operation_status_resp): if initial_operation_status_resp: self._check_command_not_in_error_or_closed_state( @@ -857,51 +901,60 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp): self._check_command_not_in_error_or_closed_state(op_handle, poll_resp) return operation_state - def get_query_state(self, op_handle) -> "TOperationState": - poll_resp = 
self._poll_for_status(op_handle) + def get_query_state(self, command_id: CommandId) -> CommandState: + thrift_handle = command_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift command ID") + + poll_resp = self._poll_for_status(thrift_handle) operation_state = poll_resp.operationState - self._check_command_not_in_error_or_closed_state(op_handle, poll_resp) - return operation_state + self._check_command_not_in_error_or_closed_state(thrift_handle, poll_resp) + state = CommandState.from_thrift_state(operation_state) + if state is None: + raise ValueError(f"Unknown command state: {operation_state}") + return state @staticmethod def _check_direct_results_for_error(t_spark_direct_results): if t_spark_direct_results: if t_spark_direct_results.operationStatus: - ThriftBackend._check_response_for_error( + ThriftDatabricksClient._check_response_for_error( t_spark_direct_results.operationStatus ) if t_spark_direct_results.resultSetMetadata: - ThriftBackend._check_response_for_error( + ThriftDatabricksClient._check_response_for_error( t_spark_direct_results.resultSetMetadata ) if t_spark_direct_results.resultSet: - ThriftBackend._check_response_for_error( + ThriftDatabricksClient._check_response_for_error( t_spark_direct_results.resultSet ) if t_spark_direct_results.closeOperation: - ThriftBackend._check_response_for_error( + ThriftDatabricksClient._check_response_for_error( t_spark_direct_results.closeOperation ) def execute_command( self, - operation, - session_handle, - max_rows, - max_bytes, - lz4_compression, - cursor, + operation: str, + session_id: SessionId, + max_rows: int, + max_bytes: int, + lz4_compression: bool, + cursor: Cursor, use_cloud_fetch=True, parameters=[], async_op=False, enforce_embedded_schema_correctness=False, - ): - assert session_handle is not None + ) -> Union["ResultSet", None]: + thrift_handle = session_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift session ID") 
logger.debug( "ThriftBackend.execute_command(operation=%s, session_handle=%s)", operation, - session_handle, + thrift_handle, ) spark_arrow_types = ttypes.TSparkArrowTypes( @@ -913,7 +966,7 @@ def execute_command( intervalTypesAsArrow=False, ) req = ttypes.TExecuteStatementReq( - sessionHandle=session_handle, + sessionHandle=thrift_handle, statement=operation, runAsync=True, # For async operation we don't want the direct results @@ -938,34 +991,68 @@ def execute_command( if async_op: self._handle_execute_response_async(resp, cursor) + return None else: - return self._handle_execute_response(resp, cursor) + execute_response = self._handle_execute_response(resp, cursor) + + return ThriftResultSet( + connection=cursor.connection, + execute_response=execute_response, + thrift_client=self, + buffer_size_bytes=max_bytes, + arraysize=max_rows, + use_cloud_fetch=use_cloud_fetch, + ) + + def get_catalogs( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: Cursor, + ) -> "ResultSet": + from databricks.sql.result_set import ThriftResultSet - def get_catalogs(self, session_handle, max_rows, max_bytes, cursor): - assert session_handle is not None + thrift_handle = session_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift session ID") req = ttypes.TGetCatalogsReq( - sessionHandle=session_handle, + sessionHandle=thrift_handle, getDirectResults=ttypes.TSparkGetDirectResults( maxRows=max_rows, maxBytes=max_bytes ), ) resp = self.make_request(self._client.GetCatalogs, req) - return self._handle_execute_response(resp, cursor) + + execute_response = self._handle_execute_response(resp, cursor) + + return ThriftResultSet( + connection=cursor.connection, + execute_response=execute_response, + thrift_client=self, + buffer_size_bytes=max_bytes, + arraysize=max_rows, + use_cloud_fetch=cursor.connection.use_cloud_fetch, + ) def get_schemas( self, - session_handle, - max_rows, - max_bytes, - cursor, + session_id: SessionId, + 
max_rows: int, + max_bytes: int, + cursor: Cursor, catalog_name=None, schema_name=None, - ): - assert session_handle is not None + ) -> "ResultSet": + from databricks.sql.result_set import ThriftResultSet + + thrift_handle = session_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift session ID") req = ttypes.TGetSchemasReq( - sessionHandle=session_handle, + sessionHandle=thrift_handle, getDirectResults=ttypes.TSparkGetDirectResults( maxRows=max_rows, maxBytes=max_bytes ), @@ -973,23 +1060,37 @@ def get_schemas( schemaName=schema_name, ) resp = self.make_request(self._client.GetSchemas, req) - return self._handle_execute_response(resp, cursor) + + execute_response = self._handle_execute_response(resp, cursor) + + return ThriftResultSet( + connection=cursor.connection, + execute_response=execute_response, + thrift_client=self, + buffer_size_bytes=max_bytes, + arraysize=max_rows, + use_cloud_fetch=cursor.connection.use_cloud_fetch, + ) def get_tables( self, - session_handle, - max_rows, - max_bytes, - cursor, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: Cursor, catalog_name=None, schema_name=None, table_name=None, table_types=None, - ): - assert session_handle is not None + ) -> "ResultSet": + from databricks.sql.result_set import ThriftResultSet + + thrift_handle = session_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift session ID") req = ttypes.TGetTablesReq( - sessionHandle=session_handle, + sessionHandle=thrift_handle, getDirectResults=ttypes.TSparkGetDirectResults( maxRows=max_rows, maxBytes=max_bytes ), @@ -999,23 +1100,37 @@ def get_tables( tableTypes=table_types, ) resp = self.make_request(self._client.GetTables, req) - return self._handle_execute_response(resp, cursor) + + execute_response = self._handle_execute_response(resp, cursor) + + return ThriftResultSet( + connection=cursor.connection, + execute_response=execute_response, + thrift_client=self, + 
buffer_size_bytes=max_bytes, + arraysize=max_rows, + use_cloud_fetch=cursor.connection.use_cloud_fetch, + ) def get_columns( self, - session_handle, - max_rows, - max_bytes, - cursor, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: Cursor, catalog_name=None, schema_name=None, table_name=None, column_name=None, - ): - assert session_handle is not None + ) -> "ResultSet": + from databricks.sql.result_set import ThriftResultSet + + thrift_handle = session_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift session ID") req = ttypes.TGetColumnsReq( - sessionHandle=session_handle, + sessionHandle=thrift_handle, getDirectResults=ttypes.TSparkGetDirectResults( maxRows=max_rows, maxBytes=max_bytes ), @@ -1025,10 +1140,24 @@ def get_columns( columnName=column_name, ) resp = self.make_request(self._client.GetColumns, req) - return self._handle_execute_response(resp, cursor) + + execute_response = self._handle_execute_response(resp, cursor) + + return ThriftResultSet( + connection=cursor.connection, + execute_response=execute_response, + thrift_client=self, + buffer_size_bytes=max_bytes, + arraysize=max_rows, + use_cloud_fetch=cursor.connection.use_cloud_fetch, + ) def _handle_execute_response(self, resp, cursor): - cursor.active_op_handle = resp.operationHandle + command_id = CommandId.from_thrift_handle(resp.operationHandle) + if command_id is None: + raise ValueError(f"Invalid Thrift handle: {resp.operationHandle}") + + cursor.active_command_id = command_id self._check_direct_results_for_error(resp.directResults) final_operation_state = self._wait_until_command_done( @@ -1039,28 +1168,34 @@ def _handle_execute_response(self, resp, cursor): return self._results_message_to_execute_response(resp, final_operation_state) def _handle_execute_response_async(self, resp, cursor): - cursor.active_op_handle = resp.operationHandle + command_id = CommandId.from_thrift_handle(resp.operationHandle) + if command_id is None: + raise 
ValueError(f"Invalid Thrift handle: {resp.operationHandle}") + + cursor.active_command_id = command_id self._check_direct_results_for_error(resp.directResults) def fetch_results( self, - op_handle, - max_rows, - max_bytes, - expected_row_start_offset, - lz4_compressed, + command_id: CommandId, + max_rows: int, + max_bytes: int, + expected_row_start_offset: int, + lz4_compressed: bool, arrow_schema_bytes, description, use_cloud_fetch=True, ): - assert op_handle is not None + thrift_handle = command_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift command ID") req = ttypes.TFetchResultsReq( operationHandle=ttypes.TOperationHandle( - op_handle.operationId, - op_handle.operationType, + thrift_handle.operationId, + thrift_handle.operationType, False, - op_handle.modifiedRowCount, + thrift_handle.modifiedRowCount, ), maxRows=max_rows, maxBytes=max_bytes, @@ -1089,46 +1224,20 @@ def fetch_results( return queue, resp.hasMoreRows - def close_command(self, op_handle): - logger.debug("ThriftBackend.close_command(op_handle=%s)", op_handle) - req = ttypes.TCloseOperationReq(operationHandle=op_handle) - resp = self.make_request(self._client.CloseOperation, req) - return resp.status + def cancel_command(self, command_id: CommandId) -> None: + thrift_handle = command_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift command ID") - def cancel_command(self, active_op_handle): - logger.debug( - "Cancelling command {}".format( - self.guid_to_hex_id(active_op_handle.operationId.guid) - ) - ) - req = ttypes.TCancelOperationReq(active_op_handle) + logger.debug("Cancelling command %s", command_id.to_hex_guid()) + req = ttypes.TCancelOperationReq(thrift_handle) self.make_request(self._client.CancelOperation, req) - @staticmethod - def handle_to_id(session_handle): - return session_handle.sessionId.guid + def close_command(self, command_id: CommandId) -> None: + thrift_handle = command_id.to_thrift_handle() + if not 
thrift_handle: + raise ValueError("Not a valid Thrift command ID") - @staticmethod - def handle_to_hex_id(session_handle: TCLIService.TSessionHandle): - this_uuid = uuid.UUID(bytes=session_handle.sessionId.guid) - return str(this_uuid) - - @staticmethod - def guid_to_hex_id(guid: bytes) -> str: - """Return a hexadecimal string instead of bytes - - Example: - IN b'\x01\xee\x1d)\xa4\x19\x1d\xb6\xa9\xc0\x8d\xf1\xfe\xbaB\xdd' - OUT '01ee1d29-a419-1db6-a9c0-8df1feba42dd' - - If conversion to hexadecimal fails, the original bytes are returned - """ - - this_uuid: Union[bytes, uuid.UUID] - - try: - this_uuid = uuid.UUID(bytes=guid) - except Exception as e: - logger.debug(f"Unable to convert bytes to UUID: {bytes} -- {str(e)}") - this_uuid = guid - return str(this_uuid) + logger.debug("ThriftBackend.close_command(command_id=%s)", command_id) + req = ttypes.TCloseOperationReq(operationHandle=thrift_handle) + self.make_request(self._client.CloseOperation, req) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py new file mode 100644 index 000000000..ddeac474a --- /dev/null +++ b/src/databricks/sql/backend/types.py @@ -0,0 +1,391 @@ +from enum import Enum +from typing import Dict, Optional, Any +import logging + +from databricks.sql.backend.utils.guid_utils import guid_to_hex_id +from databricks.sql.thrift_api.TCLIService import ttypes + +logger = logging.getLogger(__name__) + + +class CommandState(Enum): + """ + Enum representing the execution state of a command in Databricks SQL. + + This enum maps Thrift operation states to normalized command states, + providing a consistent interface for tracking command execution status + across different backend implementations. 
+ + Attributes: + PENDING: Command is queued or initialized but not yet running + RUNNING: Command is currently executing + SUCCEEDED: Command completed successfully + FAILED: Command failed due to error, timeout, or unknown state + CLOSED: Command has been closed + CANCELLED: Command was cancelled before completion + """ + + PENDING = "PENDING" + RUNNING = "RUNNING" + SUCCEEDED = "SUCCEEDED" + FAILED = "FAILED" + CLOSED = "CLOSED" + CANCELLED = "CANCELLED" + + @classmethod + def from_thrift_state( + cls, state: ttypes.TOperationState + ) -> Optional["CommandState"]: + """ + Convert a Thrift TOperationState to a normalized CommandState. + + Args: + state: A TOperationState from the Thrift API representing the current + state of an operation + + Returns: + CommandState: The corresponding normalized command state + + Note: + Returns None if the state is not a recognized TOperationState + + State Mappings: + - INITIALIZED_STATE, PENDING_STATE -> PENDING + - RUNNING_STATE -> RUNNING + - FINISHED_STATE -> SUCCEEDED + - ERROR_STATE, TIMEDOUT_STATE, UKNOWN_STATE -> FAILED + - CLOSED_STATE -> CLOSED + - CANCELED_STATE -> CANCELLED + """ + + if state in ( + ttypes.TOperationState.INITIALIZED_STATE, + ttypes.TOperationState.PENDING_STATE, + ): + return cls.PENDING + elif state == ttypes.TOperationState.RUNNING_STATE: + return cls.RUNNING + elif state == ttypes.TOperationState.FINISHED_STATE: + return cls.SUCCEEDED + elif state in ( + ttypes.TOperationState.ERROR_STATE, + ttypes.TOperationState.TIMEDOUT_STATE, + ttypes.TOperationState.UKNOWN_STATE, + ): + return cls.FAILED + elif state == ttypes.TOperationState.CLOSED_STATE: + return cls.CLOSED + elif state == ttypes.TOperationState.CANCELED_STATE: + return cls.CANCELLED + else: + return None + + +class BackendType(Enum): + """ + Enum representing the type of backend + """ + + THRIFT = "thrift" + SEA = "sea" + + +class SessionId: + """ + A normalized session identifier that works with both Thrift and SEA backends. 
+ + This class abstracts away the differences between Thrift's TSessionHandle and + SEA's session ID string, providing a consistent interface for the connector. + """ + + def __init__( + self, + backend_type: BackendType, + guid: Any, + secret: Optional[Any] = None, + properties: Optional[Dict[str, Any]] = None, + ): + """ + Initialize a SessionId. + + Args: + backend_type: The type of backend (THRIFT or SEA) + guid: The primary identifier for the session + secret: The secret part of the identifier (only used for Thrift) + properties: Additional information about the session + """ + + self.backend_type = backend_type + self.guid = guid + self.secret = secret + self.properties = properties or {} + + def __str__(self) -> str: + """ + Return a string representation of the SessionId. + + For SEA backend, returns the guid. + For Thrift backend, returns a format like "guid|secret". + + Returns: + A string representation of the session ID + """ + + if self.backend_type == BackendType.SEA: + return str(self.guid) + elif self.backend_type == BackendType.THRIFT: + secret_hex = ( + guid_to_hex_id(self.secret) + if isinstance(self.secret, bytes) + else str(self.secret) + ) + return f"{self.hex_guid}|{secret_hex}" + return str(self.guid) + + @classmethod + def from_thrift_handle( + cls, session_handle, properties: Optional[Dict[str, Any]] = None + ): + """ + Create a SessionId from a Thrift session handle. 
+ + Args: + session_handle: A TSessionHandle object from the Thrift API + + Returns: + A SessionId instance + """ + + if session_handle is None: + return None + + guid_bytes = session_handle.sessionId.guid + secret_bytes = session_handle.sessionId.secret + + if session_handle.serverProtocolVersion is not None: + if properties is None: + properties = {} + properties["serverProtocolVersion"] = session_handle.serverProtocolVersion + + return cls(BackendType.THRIFT, guid_bytes, secret_bytes, properties) + + @classmethod + def from_sea_session_id( + cls, session_id: str, properties: Optional[Dict[str, Any]] = None + ): + """ + Create a SessionId from a SEA session ID. + + Args: + session_id: The SEA session ID string + + Returns: + A SessionId instance + """ + + return cls(BackendType.SEA, session_id, properties=properties) + + def to_thrift_handle(self): + """ + Convert this SessionId to a Thrift TSessionHandle. + + Returns: + A TSessionHandle object or None if this is not a Thrift session ID + """ + + if self.backend_type != BackendType.THRIFT: + return None + + from databricks.sql.thrift_api.TCLIService import ttypes + + handle_identifier = ttypes.THandleIdentifier(guid=self.guid, secret=self.secret) + server_protocol_version = self.properties.get("serverProtocolVersion") + return ttypes.TSessionHandle( + sessionId=handle_identifier, serverProtocolVersion=server_protocol_version + ) + + def to_sea_session_id(self): + """ + Get the SEA session ID string. + + Returns: + The session ID string or None if this is not a SEA session ID + """ + + if self.backend_type != BackendType.SEA: + return None + + return self.guid + + @property + def hex_guid(self) -> str: + """ + Get a hexadecimal string representation of the session ID. 
+ + Returns: + A hexadecimal string representation + """ + + if isinstance(self.guid, bytes): + return guid_to_hex_id(self.guid) + else: + return str(self.guid) + + @property + def protocol_version(self): + """ + Get the server protocol version for this session. + + Returns: + The server protocol version or None if it does not exist + It is not expected to exist for SEA sessions. + """ + + return self.properties.get("serverProtocolVersion") + + +class CommandId: + """ + A normalized command identifier that works with both Thrift and SEA backends. + + This class abstracts away the differences between Thrift's TOperationHandle and + SEA's statement ID string, providing a consistent interface for the connector. + """ + + def __init__( + self, + backend_type: BackendType, + guid: Any, + secret: Optional[Any] = None, + operation_type: Optional[int] = None, + has_result_set: bool = False, + modified_row_count: Optional[int] = None, + ): + """ + Initialize a CommandId. + + Args: + backend_type: The type of backend (THRIFT or SEA) + guid: The primary identifier for the command + secret: The secret part of the identifier (only used for Thrift) + operation_type: The operation type (only used for Thrift) + has_result_set: Whether the command has a result set + modified_row_count: The number of rows modified by the command + """ + + self.backend_type = backend_type + self.guid = guid + self.secret = secret + self.operation_type = operation_type + self.has_result_set = has_result_set + self.modified_row_count = modified_row_count + + def __str__(self) -> str: + """ + Return a string representation of the CommandId. + + For SEA backend, returns the guid. + For Thrift backend, returns a format like "guid|secret". 
+ + Returns: + A string representation of the command ID + """ + + if self.backend_type == BackendType.SEA: + return str(self.guid) + elif self.backend_type == BackendType.THRIFT: + secret_hex = ( + guid_to_hex_id(self.secret) + if isinstance(self.secret, bytes) + else str(self.secret) + ) + return f"{self.to_hex_guid()}|{secret_hex}" + return str(self.guid) + + @classmethod + def from_thrift_handle(cls, operation_handle): + """ + Create a CommandId from a Thrift operation handle. + + Args: + operation_handle: A TOperationHandle object from the Thrift API + + Returns: + A CommandId instance + """ + + if operation_handle is None: + return None + + guid_bytes = operation_handle.operationId.guid + secret_bytes = operation_handle.operationId.secret + + return cls( + BackendType.THRIFT, + guid_bytes, + secret_bytes, + operation_handle.operationType, + operation_handle.hasResultSet, + operation_handle.modifiedRowCount, + ) + + @classmethod + def from_sea_statement_id(cls, statement_id: str): + """ + Create a CommandId from a SEA statement ID. + + Args: + statement_id: The SEA statement ID string + + Returns: + A CommandId instance + """ + + return cls(BackendType.SEA, statement_id) + + def to_thrift_handle(self): + """ + Convert this CommandId to a Thrift TOperationHandle. + + Returns: + A TOperationHandle object or None if this is not a Thrift command ID + """ + + if self.backend_type != BackendType.THRIFT: + return None + + from databricks.sql.thrift_api.TCLIService import ttypes + + handle_identifier = ttypes.THandleIdentifier(guid=self.guid, secret=self.secret) + return ttypes.TOperationHandle( + operationId=handle_identifier, + operationType=self.operation_type, + hasResultSet=self.has_result_set, + modifiedRowCount=self.modified_row_count, + ) + + def to_sea_statement_id(self): + """ + Get the SEA statement ID string. 
+ + Returns: + The statement ID string or None if this is not a SEA statement ID + """ + + if self.backend_type != BackendType.SEA: + return None + + return self.guid + + def to_hex_guid(self) -> str: + """ + Get a hexadecimal string representation of the command ID. + + Returns: + A hexadecimal string representation + """ + + if isinstance(self.guid, bytes): + return guid_to_hex_id(self.guid) + else: + return str(self.guid) diff --git a/src/databricks/sql/backend/utils/__init__.py b/src/databricks/sql/backend/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/databricks/sql/backend/utils/guid_utils.py b/src/databricks/sql/backend/utils/guid_utils.py new file mode 100644 index 000000000..a6cb0e0db --- /dev/null +++ b/src/databricks/sql/backend/utils/guid_utils.py @@ -0,0 +1,23 @@ +import uuid +import logging + +logger = logging.getLogger(__name__) + + +def guid_to_hex_id(guid: bytes) -> str: + """Return a hexadecimal string instead of bytes + + Example: + IN b'\x01\xee\x1d)\xa4\x19\x1d\xb6\xa9\xc0\x8d\xf1\xfe\xbaB\xdd' + OUT '01ee1d29-a419-1db6-a9c0-8df1feba42dd' + + If conversion to hexadecimal fails, a string representation of the original + bytes is returned + """ + + try: + this_uuid = uuid.UUID(bytes=guid) + except Exception as e: + logger.debug("Unable to convert bytes to UUID: %r -- %s", guid, str(e)) + return str(guid) + return str(this_uuid) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index b81416e15..7886c2f6f 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -21,7 +21,8 @@ CursorAlreadyClosedError, ) from databricks.sql.thrift_api.TCLIService import ttypes -from databricks.sql.thrift_backend import ThriftBackend +from databricks.sql.backend.thrift_backend import ThriftDatabricksClient +from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.utils import ( ExecuteResponse, ParamEscaper, @@ -41,12 +42,15 @@ ParameterApproach, ) - 
+from databricks.sql.result_set import ResultSet, ThriftResultSet from databricks.sql.types import Row, SSLOptions from databricks.sql.auth.auth import get_python_sql_connector_auth_provider from databricks.sql.experimental.oauth_persistence import OAuthPersistence +from databricks.sql.session import Session +from databricks.sql.backend.types import CommandId, BackendType, CommandState, SessionId from databricks.sql.thrift_api.TCLIService.ttypes import ( + TOpenSessionResp, TSparkParameter, TOperationState, ) @@ -224,70 +228,27 @@ def read(self) -> Optional[OAuthToken]: access_token_kv = {"access_token": access_token} kwargs = {**kwargs, **access_token_kv} - self.open = False - self.host = server_hostname - self.port = kwargs.get("_port", 443) self.disable_pandas = kwargs.get("_disable_pandas", False) self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True) + self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True) + self._cursors = [] # type: List[Cursor] - auth_provider = get_python_sql_connector_auth_provider( - server_hostname, **kwargs - ) - - user_agent_entry = kwargs.get("user_agent_entry") - if user_agent_entry is None: - user_agent_entry = kwargs.get("_user_agent_entry") - if user_agent_entry is not None: - logger.warning( - "[WARN] Parameter '_user_agent_entry' is deprecated; use 'user_agent_entry' instead. " - "This parameter will be removed in the upcoming releases." 
- ) - - if user_agent_entry: - useragent_header = "{}/{} ({})".format( - USER_AGENT_NAME, __version__, user_agent_entry - ) - else: - useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__) - - base_headers = [("User-Agent", useragent_header)] - - self._ssl_options = SSLOptions( - # Double negation is generally a bad thing, but we have to keep backward compatibility - tls_verify=not kwargs.get( - "_tls_no_verify", False - ), # by default - verify cert and host - tls_verify_hostname=kwargs.get("_tls_verify_hostname", True), - tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"), - tls_client_cert_file=kwargs.get("_tls_client_cert_file"), - tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"), - tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), - ) - - self.thrift_backend = ThriftBackend( - self.host, - self.port, + self.session = Session( + server_hostname, http_path, - (http_headers or []) + base_headers, - auth_provider, - ssl_options=self._ssl_options, - _use_arrow_native_complex_types=_use_arrow_native_complex_types, + http_headers, + session_configuration, + catalog, + schema, + _use_arrow_native_complex_types, **kwargs, ) - - self._open_session_resp = self.thrift_backend.open_session( - session_configuration, catalog, schema - ) - self._session_handle = self._open_session_resp.sessionHandle - self.protocol_version = self.get_protocol_version(self._open_session_resp) - self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True) - self.open = True - logger.info("Successfully opened session " + str(self.get_session_id_hex())) - self._cursors = [] # type: List[Cursor] + self.session.open() self.use_inline_params = self._set_use_inline_params_with_warning( kwargs.get("use_inline_params", False) ) + self.staging_allowed_local_path = kwargs.get("staging_allowed_local_path", None) def _set_use_inline_params_with_warning(self, value: Union[bool, str]): """Valid values are True, False, and "silent" @@ -336,34 +297,40 @@ def 
__del__(self): logger.debug("Couldn't close unclosed connection: {}".format(e.message)) def get_session_id(self): - return self.thrift_backend.handle_to_id(self._session_handle) + """Get the raw session ID (backend-specific)""" + return self.session.guid - @staticmethod - def get_protocol_version(openSessionResp): - """ - Since the sessionHandle will sometimes have a serverProtocolVersion, it takes - precedence over the serverProtocolVersion defined in the OpenSessionResponse. - """ - if ( - openSessionResp.sessionHandle - and hasattr(openSessionResp.sessionHandle, "serverProtocolVersion") - and openSessionResp.sessionHandle.serverProtocolVersion - ): - return openSessionResp.sessionHandle.serverProtocolVersion - return openSessionResp.serverProtocolVersion + def get_session_id_hex(self): + """Get the session ID in hex format""" + return self.session.guid_hex @staticmethod def server_parameterized_queries_enabled(protocolVersion): - if ( - protocolVersion - and protocolVersion >= ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8 - ): - return True - else: - return False + """Delegate to Session class static method""" + return Session.server_parameterized_queries_enabled(protocolVersion) - def get_session_id_hex(self): - return self.thrift_backend.handle_to_hex_id(self._session_handle) + @property + def protocol_version(self): + """Get the protocol version from the Session object""" + return self.session.protocol_version + + @staticmethod + def get_protocol_version(openSessionResp: TOpenSessionResp): + """Delegate to Session class static method""" + properties = ( + {"serverProtocolVersion": openSessionResp.serverProtocolVersion} + if openSessionResp.serverProtocolVersion + else {} + ) + session_id = SessionId.from_thrift_handle( + openSessionResp.sessionHandle, properties + ) + return Session.get_protocol_version(session_id) + + @property + def open(self) -> bool: + """Return whether the connection is open by checking if the session is open.""" + return 
self.session.is_open def cursor( self, @@ -380,7 +347,7 @@ def cursor( cursor = Cursor( self, - self.thrift_backend, + self.session.backend, arraysize=arraysize, result_buffer_size_bytes=buffer_size_bytes, ) @@ -396,28 +363,10 @@ def _close(self, close_cursors=True) -> None: for cursor in self._cursors: cursor.close() - logger.info(f"Closing session {self.get_session_id_hex()}") - if not self.open: - logger.debug("Session appears to have been closed already") - try: - self.thrift_backend.close_session(self._session_handle) - except RequestError as e: - if isinstance(e.args[1], SessionAlreadyClosedError): - logger.info("Session was closed by a prior request") - except DatabaseError as e: - if "Invalid SessionHandle" in str(e): - logger.warning( - f"Attempted to close session that was already closed: {e}" - ) - else: - logger.warning( - f"Attempt to close session raised an exception at the server: {e}" - ) + self.session.close() except Exception as e: - logger.error(f"Attempt to close session raised a local exception: {e}") - - self.open = False + logger.error(f"Attempt to close session raised an exception: {e}") def commit(self): """No-op because Databricks does not support transactions""" @@ -431,7 +380,7 @@ class Cursor: def __init__( self, connection: Connection, - thrift_backend: ThriftBackend, + backend: DatabricksClient, result_buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES, arraysize: int = DEFAULT_ARRAY_SIZE, ) -> None: @@ -442,6 +391,7 @@ def __init__( Cursors are not isolated, i.e., any changes done to the database by a cursor are immediately visible by other cursors or connections. 
""" + self.connection = connection self.rowcount = -1 # Return -1 as this is not supported self.buffer_size_bytes = result_buffer_size_bytes @@ -450,8 +400,8 @@ def __init__( # Note that Cursor closed => active result set closed, but not vice versa self.open = True self.executing_command_id = None - self.thrift_backend = thrift_backend - self.active_op_handle = None + self.backend = backend + self.active_command_id = None self.escaper = ParamEscaper() self.lastrowid = None @@ -793,6 +743,7 @@ def execute( :returns self """ + logger.debug( "Cursor.execute(operation=%s, parameters=%s)", operation, parameters ) @@ -818,9 +769,9 @@ def execute( self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.thrift_backend.execute_command( + self.active_result_set = self.backend.execute_command( operation=prepared_operation, - session_handle=self.connection._session_handle, + session_id=self.connection.session.session_id, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, lz4_compression=self.connection.lz4_compression, @@ -830,18 +781,10 @@ def execute( async_op=False, enforce_embedded_schema_correctness=enforce_embedded_schema_correctness, ) - self.active_result_set = ResultSet( - self.connection, - execute_response, - self.thrift_backend, - self.buffer_size_bytes, - self.arraysize, - self.connection.use_cloud_fetch, - ) - if execute_response.is_staging_operation: + if self.active_result_set and self.active_result_set.is_staging_operation: self._handle_staging_operation( - staging_allowed_local_path=self.thrift_backend.staging_allowed_local_path + staging_allowed_local_path=self.connection.staging_allowed_local_path ) return self @@ -860,6 +803,7 @@ def execute_async( :param parameters: :return: """ + param_approach = self._determine_parameter_approach(parameters) if param_approach == ParameterApproach.NONE: prepared_params = NO_NATIVE_PARAMS @@ -881,9 +825,9 @@ def execute_async( self._check_not_closed() 
self._close_and_clear_active_result_set() - self.thrift_backend.execute_command( + self.backend.execute_command( operation=prepared_operation, - session_handle=self.connection._session_handle, + session_id=self.connection.session.session_id, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, lz4_compression=self.connection.lz4_compression, @@ -896,14 +840,16 @@ def execute_async( return self - def get_query_state(self) -> "TOperationState": + def get_query_state(self) -> CommandState: """ Get the state of the async executing query or basically poll the status of the query :return: """ self._check_not_closed() - return self.thrift_backend.get_query_state(self.active_op_handle) + if self.active_command_id is None: + raise Error("No active command to get state for") + return self.backend.get_query_state(self.active_command_id) def is_query_pending(self): """ @@ -912,11 +858,7 @@ def is_query_pending(self): :return: """ operation_state = self.get_query_state() - - return not operation_state or operation_state in [ - ttypes.TOperationState.RUNNING_STATE, - ttypes.TOperationState.PENDING_STATE, - ] + return operation_state in [CommandState.PENDING, CommandState.RUNNING] def get_async_execution_result(self): """ @@ -932,21 +874,14 @@ def get_async_execution_result(self): time.sleep(self.ASYNC_DEFAULT_POLLING_INTERVAL) operation_state = self.get_query_state() - if operation_state == ttypes.TOperationState.FINISHED_STATE: - execute_response = self.thrift_backend.get_execution_result( - self.active_op_handle, self - ) - self.active_result_set = ResultSet( - self.connection, - execute_response, - self.thrift_backend, - self.buffer_size_bytes, - self.arraysize, + if operation_state == CommandState.SUCCEEDED: + self.active_result_set = self.backend.get_execution_result( + self.active_command_id, self ) - if execute_response.is_staging_operation: + if self.active_result_set and self.active_result_set.is_staging_operation: self._handle_staging_operation( - 
staging_allowed_local_path=self.thrift_backend.staging_allowed_local_path + staging_allowed_local_path=self.connection.staging_allowed_local_path ) return self @@ -978,19 +913,12 @@ def catalogs(self) -> "Cursor": """ self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.thrift_backend.get_catalogs( - session_handle=self.connection._session_handle, + self.active_result_set = self.backend.get_catalogs( + session_id=self.connection.session.session_id, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, ) - self.active_result_set = ResultSet( - self.connection, - execute_response, - self.thrift_backend, - self.buffer_size_bytes, - self.arraysize, - ) return self def schemas( @@ -1004,21 +932,14 @@ def schemas( """ self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.thrift_backend.get_schemas( - session_handle=self.connection._session_handle, + self.active_result_set = self.backend.get_schemas( + session_id=self.connection.session.session_id, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, catalog_name=catalog_name, schema_name=schema_name, ) - self.active_result_set = ResultSet( - self.connection, - execute_response, - self.thrift_backend, - self.buffer_size_bytes, - self.arraysize, - ) return self def tables( @@ -1037,8 +958,8 @@ def tables( self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.thrift_backend.get_tables( - session_handle=self.connection._session_handle, + self.active_result_set = self.backend.get_tables( + session_id=self.connection.session.session_id, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -1047,13 +968,6 @@ def tables( table_name=table_name, table_types=table_types, ) - self.active_result_set = ResultSet( - self.connection, - execute_response, - self.thrift_backend, - self.buffer_size_bytes, - self.arraysize, - ) return self def columns( @@ -1072,8 
+986,8 @@ def columns( self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.thrift_backend.get_columns( - session_handle=self.connection._session_handle, + self.active_result_set = self.backend.get_columns( + session_id=self.connection.session.session_id, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -1082,13 +996,6 @@ def columns( table_name=table_name, column_name=column_name, ) - self.active_result_set = ResultSet( - self.connection, - execute_response, - self.thrift_backend, - self.buffer_size_bytes, - self.arraysize, - ) return self def fetchall(self) -> List[Row]: @@ -1161,8 +1068,8 @@ def cancel(self) -> None: The command should be closed to free resources from the server. This method can be called from another thread. """ - if self.active_op_handle is not None: - self.thrift_backend.cancel_command(self.active_op_handle) + if self.active_command_id is not None: + self.backend.cancel_command(self.active_command_id) else: logger.warning( "Attempting to cancel a command, but there is no " @@ -1172,7 +1079,7 @@ def cancel(self) -> None: def close(self) -> None: """Close cursor""" self.open = False - self.active_op_handle = None + self.active_command_id = None if self.active_result_set: self._close_and_clear_active_result_set() @@ -1184,8 +1091,8 @@ def query_id(self) -> Optional[str]: This attribute will be ``None`` if the cursor has not had an operation invoked via the execute method yet, or if cursor was closed. 
""" - if self.active_op_handle is not None: - return str(UUID(bytes=self.active_op_handle.operationId.guid)) + if self.active_command_id is not None: + return self.active_command_id.to_hex_guid() return None @property @@ -1230,301 +1137,3 @@ def setinputsizes(self, sizes): def setoutputsize(self, size, column=None): """Does nothing by default""" pass - - -class ResultSet: - def __init__( - self, - connection: Connection, - execute_response: ExecuteResponse, - thrift_backend: ThriftBackend, - result_buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES, - arraysize: int = 10000, - use_cloud_fetch: bool = True, - ): - """ - A ResultSet manages the results of a single command. - - :param connection: The parent connection that was used to execute this command - :param execute_response: A `ExecuteResponse` class returned by a command execution - :param result_buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch - amount :param arraysize: The max number of rows to fetch at a time (PEP-249) - """ - self.connection = connection - self.command_id = execute_response.command_handle - self.op_state = execute_response.status - self.has_been_closed_server_side = execute_response.has_been_closed_server_side - self.has_more_rows = execute_response.has_more_rows - self.buffer_size_bytes = result_buffer_size_bytes - self.lz4_compressed = execute_response.lz4_compressed - self.arraysize = arraysize - self.thrift_backend = thrift_backend - self.description = execute_response.description - self._arrow_schema_bytes = execute_response.arrow_schema_bytes - self._next_row_index = 0 - self._use_cloud_fetch = use_cloud_fetch - - if execute_response.arrow_queue: - # In this case the server has taken the fast path and returned an initial batch of - # results - self.results = execute_response.arrow_queue - else: - # In this case, there are results waiting on the server so we fetch now for simplicity - self._fill_results_buffer() - - def __iter__(self): - while True: - 
row = self.fetchone() - if row: - yield row - else: - break - - def _fill_results_buffer(self): - # At initialization or if the server does not have cloud fetch result links available - results, has_more_rows = self.thrift_backend.fetch_results( - op_handle=self.command_id, - max_rows=self.arraysize, - max_bytes=self.buffer_size_bytes, - expected_row_start_offset=self._next_row_index, - lz4_compressed=self.lz4_compressed, - arrow_schema_bytes=self._arrow_schema_bytes, - description=self.description, - use_cloud_fetch=self._use_cloud_fetch, - ) - self.results = results - self.has_more_rows = has_more_rows - - def _convert_columnar_table(self, table): - column_names = [c[0] for c in self.description] - ResultRow = Row(*column_names) - result = [] - for row_index in range(table.num_rows): - curr_row = [] - for col_index in range(table.num_columns): - curr_row.append(table.get_item(col_index, row_index)) - result.append(ResultRow(*curr_row)) - - return result - - def _convert_arrow_table(self, table): - column_names = [c[0] for c in self.description] - ResultRow = Row(*column_names) - - if self.connection.disable_pandas is True: - return [ - ResultRow(*[v.as_py() for v in r]) for r in zip(*table.itercolumns()) - ] - - # Need to use nullable types, as otherwise type can change when there are missing values. 
- # See https://arrow.apache.org/docs/python/pandas.html#nullable-types - # NOTE: This api is epxerimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html - dtype_mapping = { - pyarrow.int8(): pandas.Int8Dtype(), - pyarrow.int16(): pandas.Int16Dtype(), - pyarrow.int32(): pandas.Int32Dtype(), - pyarrow.int64(): pandas.Int64Dtype(), - pyarrow.uint8(): pandas.UInt8Dtype(), - pyarrow.uint16(): pandas.UInt16Dtype(), - pyarrow.uint32(): pandas.UInt32Dtype(), - pyarrow.uint64(): pandas.UInt64Dtype(), - pyarrow.bool_(): pandas.BooleanDtype(), - pyarrow.float32(): pandas.Float32Dtype(), - pyarrow.float64(): pandas.Float64Dtype(), - pyarrow.string(): pandas.StringDtype(), - } - - # Need to rename columns, as the to_pandas function cannot handle duplicate column names - table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)]) - df = table_renamed.to_pandas( - types_mapper=dtype_mapping.get, - date_as_object=True, - timestamp_as_object=True, - ) - - res = df.to_numpy(na_value=None, dtype="object") - return [ResultRow(*v) for v in res] - - @property - def rownumber(self): - return self._next_row_index - - def fetchmany_arrow(self, size: int) -> "pyarrow.Table": - """ - Fetch the next set of rows of a query result, returning a PyArrow table. - - An empty sequence is returned when no more rows are available. 
- """ - if size < 0: - raise ValueError("size argument for fetchmany is %s but must be >= 0", size) - results = self.results.next_n_rows(size) - n_remaining_rows = size - results.num_rows - self._next_row_index += results.num_rows - - while ( - n_remaining_rows > 0 - and not self.has_been_closed_server_side - and self.has_more_rows - ): - self._fill_results_buffer() - partial_results = self.results.next_n_rows(n_remaining_rows) - results = pyarrow.concat_tables([results, partial_results]) - n_remaining_rows -= partial_results.num_rows - self._next_row_index += partial_results.num_rows - - return results - - def merge_columnar(self, result1, result2): - """ - Function to merge / combining the columnar results into a single result - :param result1: - :param result2: - :return: - """ - - if result1.column_names != result2.column_names: - raise ValueError("The columns in the results don't match") - - merged_result = [ - result1.column_table[i] + result2.column_table[i] - for i in range(result1.num_columns) - ] - return ColumnTable(merged_result, result1.column_names) - - def fetchmany_columnar(self, size: int): - """ - Fetch the next set of rows of a query result, returning a Columnar Table. - An empty sequence is returned when no more rows are available. 
- """ - if size < 0: - raise ValueError("size argument for fetchmany is %s but must be >= 0", size) - - results = self.results.next_n_rows(size) - n_remaining_rows = size - results.num_rows - self._next_row_index += results.num_rows - - while ( - n_remaining_rows > 0 - and not self.has_been_closed_server_side - and self.has_more_rows - ): - self._fill_results_buffer() - partial_results = self.results.next_n_rows(n_remaining_rows) - results = self.merge_columnar(results, partial_results) - n_remaining_rows -= partial_results.num_rows - self._next_row_index += partial_results.num_rows - - return results - - def fetchall_arrow(self) -> "pyarrow.Table": - """Fetch all (remaining) rows of a query result, returning them as a PyArrow table.""" - results = self.results.remaining_rows() - self._next_row_index += results.num_rows - - while not self.has_been_closed_server_side and self.has_more_rows: - self._fill_results_buffer() - partial_results = self.results.remaining_rows() - if isinstance(results, ColumnTable) and isinstance( - partial_results, ColumnTable - ): - results = self.merge_columnar(results, partial_results) - else: - results = pyarrow.concat_tables([results, partial_results]) - self._next_row_index += partial_results.num_rows - - # If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table - # Valid only for metadata commands result set - if isinstance(results, ColumnTable) and pyarrow: - data = { - name: col - for name, col in zip(results.column_names, results.column_table) - } - return pyarrow.Table.from_pydict(data) - return results - - def fetchall_columnar(self): - """Fetch all (remaining) rows of a query result, returning them as a Columnar table.""" - results = self.results.remaining_rows() - self._next_row_index += results.num_rows - - while not self.has_been_closed_server_side and self.has_more_rows: - self._fill_results_buffer() - partial_results = self.results.remaining_rows() - results = self.merge_columnar(results, 
partial_results) - self._next_row_index += partial_results.num_rows - - return results - - def fetchone(self) -> Optional[Row]: - """ - Fetch the next row of a query result set, returning a single sequence, - or None when no more data is available. - """ - - if isinstance(self.results, ColumnQueue): - res = self._convert_columnar_table(self.fetchmany_columnar(1)) - else: - res = self._convert_arrow_table(self.fetchmany_arrow(1)) - - if len(res) > 0: - return res[0] - else: - return None - - def fetchall(self) -> List[Row]: - """ - Fetch all (remaining) rows of a query result, returning them as a list of rows. - """ - if isinstance(self.results, ColumnQueue): - return self._convert_columnar_table(self.fetchall_columnar()) - else: - return self._convert_arrow_table(self.fetchall_arrow()) - - def fetchmany(self, size: int) -> List[Row]: - """ - Fetch the next set of rows of a query result, returning a list of rows. - - An empty sequence is returned when no more rows are available. - """ - if isinstance(self.results, ColumnQueue): - return self._convert_columnar_table(self.fetchmany_columnar(size)) - else: - return self._convert_arrow_table(self.fetchmany_arrow(size)) - - def close(self) -> None: - """ - Close the cursor. - - If the connection has not been closed, and the cursor has not already - been closed on the server for some other reason, issue a request to the server to close it. 
- """ - try: - if ( - self.op_state != self.thrift_backend.CLOSED_OP_STATE - and not self.has_been_closed_server_side - and self.connection.open - ): - self.thrift_backend.close_command(self.command_id) - except RequestError as e: - if isinstance(e.args[1], CursorAlreadyClosedError): - logger.info("Operation was canceled by a prior request") - finally: - self.has_been_closed_server_side = True - self.op_state = self.thrift_backend.CLOSED_OP_STATE - - @staticmethod - def _get_schema_description(table_schema_message): - """ - Takes a TableSchema message and returns a description 7-tuple as specified by PEP-249 - """ - - def map_col_type(type_): - if type_.startswith("decimal"): - return "decimal" - else: - return type_ - - return [ - (column.name, map_col_type(column.datatype), None, None, None, None, None) - for column in table_schema_message.columns - ] diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py new file mode 100644 index 000000000..2ffc3f257 --- /dev/null +++ b/src/databricks/sql/result_set.py @@ -0,0 +1,415 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import List, Optional, TYPE_CHECKING + +import logging +import pandas + +from databricks.sql.backend.databricks_client import DatabricksClient +from databricks.sql.backend.types import CommandId, CommandState + +try: + import pyarrow +except ImportError: + pyarrow = None + +if TYPE_CHECKING: + from databricks.sql.backend.thrift_backend import ThriftDatabricksClient + from databricks.sql.client import Connection + +from databricks.sql.types import Row +from databricks.sql.exc import RequestError, CursorAlreadyClosedError +from databricks.sql.utils import ExecuteResponse, ColumnTable, ColumnQueue + +logger = logging.getLogger(__name__) + + +class ResultSet(ABC): + """ + Abstract base class for result sets returned by different backend implementations. 
+ + This class defines the interface that all concrete result set implementations must follow. + """ + + def __init__( + self, + connection: Connection, + backend: DatabricksClient, + command_id: CommandId, + op_state: Optional[CommandState], + has_been_closed_server_side: bool, + arraysize: int, + buffer_size_bytes: int, + ): + """ + A ResultSet manages the results of a single command. + + :param connection: The parent connection that was used to execute this command + :param backend: The specialised backend client to be invoked in the fetch phase + :param execute_response: A `ExecuteResponse` class returned by a command execution + :param result_buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch + amount :param arraysize: The max number of rows to fetch at a time (PEP-249) + """ + + self.command_id = command_id + self.op_state = op_state + self.has_been_closed_server_side = has_been_closed_server_side + self.connection = connection + self.backend = backend + self.arraysize = arraysize + self.buffer_size_bytes = buffer_size_bytes + self._next_row_index = 0 + self.description = None + + def __iter__(self): + while True: + row = self.fetchone() + if row: + yield row + else: + break + + @property + def rownumber(self): + return self._next_row_index + + @property + @abstractmethod + def is_staging_operation(self) -> bool: + """Whether this result set represents a staging operation.""" + pass + + # Define abstract methods that concrete implementations must implement + @abstractmethod + def _fill_results_buffer(self): + """Fill the results buffer from the backend.""" + pass + + @abstractmethod + def fetchone(self) -> Optional[Row]: + """Fetch the next row of a query result set.""" + pass + + @abstractmethod + def fetchmany(self, size: int) -> List[Row]: + """Fetch the next set of rows of a query result.""" + pass + + @abstractmethod + def fetchall(self) -> List[Row]: + """Fetch all remaining rows of a query result.""" + pass + + @abstractmethod + 
def fetchmany_arrow(self, size: int) -> "pyarrow.Table": + """Fetch the next set of rows as an Arrow table.""" + pass + + @abstractmethod + def fetchall_arrow(self) -> "pyarrow.Table": + """Fetch all remaining rows as an Arrow table.""" + pass + + def close(self) -> None: + """ + Close the result set. + + If the connection has not been closed, and the result set has not already + been closed on the server for some other reason, issue a request to the server to close it. + """ + + try: + if ( + self.op_state != CommandState.CLOSED + and not self.has_been_closed_server_side + and self.connection.open + ): + self.backend.close_command(self.command_id) + except RequestError as e: + if isinstance(e.args[1], CursorAlreadyClosedError): + logger.info("Operation was canceled by a prior request") + finally: + self.has_been_closed_server_side = True + self.op_state = CommandState.CLOSED + + +class ThriftResultSet(ResultSet): + """ResultSet implementation for the Thrift backend.""" + + def __init__( + self, + connection: Connection, + execute_response: ExecuteResponse, + thrift_client: ThriftDatabricksClient, + buffer_size_bytes: int = 104857600, + arraysize: int = 10000, + use_cloud_fetch: bool = True, + ): + """ + Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient. 
+ + Args: + connection: The parent connection + execute_response: Response from the execute command + thrift_client: The ThriftDatabricksClient instance for direct access + buffer_size_bytes: Buffer size for fetching results + arraysize: Default number of rows to fetch + use_cloud_fetch: Whether to use cloud fetch for retrieving results + """ + + super().__init__( + connection, + thrift_client, + execute_response.command_id, + execute_response.status, + execute_response.has_been_closed_server_side, + arraysize, + buffer_size_bytes, + ) + + # Initialize ThriftResultSet-specific attributes + self.has_been_closed_server_side = execute_response.has_been_closed_server_side + self.has_more_rows = execute_response.has_more_rows + self.lz4_compressed = execute_response.lz4_compressed + self.description = execute_response.description + self._arrow_schema_bytes = execute_response.arrow_schema_bytes + self._use_cloud_fetch = use_cloud_fetch + self._is_staging_operation = execute_response.is_staging_operation + + # Initialize results queue + if execute_response.arrow_queue: + # In this case the server has taken the fast path and returned an initial batch of + # results + self.results = execute_response.arrow_queue + else: + # In this case, there are results waiting on the server so we fetch now for simplicity + self._fill_results_buffer() + + def _fill_results_buffer(self): + # At initialization or if the server does not have cloud fetch result links available + results, has_more_rows = self.backend.fetch_results( + command_id=self.command_id, + max_rows=self.arraysize, + max_bytes=self.buffer_size_bytes, + expected_row_start_offset=self._next_row_index, + lz4_compressed=self.lz4_compressed, + arrow_schema_bytes=self._arrow_schema_bytes, + description=self.description, + use_cloud_fetch=self._use_cloud_fetch, + ) + self.results = results + self.has_more_rows = has_more_rows + + def _convert_columnar_table(self, table): + column_names = [c[0] for c in self.description] + 
ResultRow = Row(*column_names) + result = [] + for row_index in range(table.num_rows): + curr_row = [] + for col_index in range(table.num_columns): + curr_row.append(table.get_item(col_index, row_index)) + result.append(ResultRow(*curr_row)) + + return result + + def _convert_arrow_table(self, table): + column_names = [c[0] for c in self.description] + ResultRow = Row(*column_names) + + if self.connection.disable_pandas is True: + return [ + ResultRow(*[v.as_py() for v in r]) for r in zip(*table.itercolumns()) + ] + + # Need to use nullable types, as otherwise type can change when there are missing values. + # See https://arrow.apache.org/docs/python/pandas.html#nullable-types + # NOTE: This api is experimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html + dtype_mapping = { + pyarrow.int8(): pandas.Int8Dtype(), + pyarrow.int16(): pandas.Int16Dtype(), + pyarrow.int32(): pandas.Int32Dtype(), + pyarrow.int64(): pandas.Int64Dtype(), + pyarrow.uint8(): pandas.UInt8Dtype(), + pyarrow.uint16(): pandas.UInt16Dtype(), + pyarrow.uint32(): pandas.UInt32Dtype(), + pyarrow.uint64(): pandas.UInt64Dtype(), + pyarrow.bool_(): pandas.BooleanDtype(), + pyarrow.float32(): pandas.Float32Dtype(), + pyarrow.float64(): pandas.Float64Dtype(), + pyarrow.string(): pandas.StringDtype(), + } + + # Need to rename columns, as the to_pandas function cannot handle duplicate column names + table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)]) + df = table_renamed.to_pandas( + types_mapper=dtype_mapping.get, + date_as_object=True, + timestamp_as_object=True, + ) + + res = df.to_numpy(na_value=None, dtype="object") + return [ResultRow(*v) for v in res] + + def merge_columnar(self, result1, result2) -> "ColumnTable": + """ + Function to merge / combining the columnar results into a single result + :param result1: + :param result2: + :return: + """ + + if result1.column_names != result2.column_names: + raise ValueError("The columns in the results
don't match") + + merged_result = [ + result1.column_table[i] + result2.column_table[i] + for i in range(result1.num_columns) + ] + return ColumnTable(merged_result, result1.column_names) + + def fetchmany_arrow(self, size: int) -> "pyarrow.Table": + """ + Fetch the next set of rows of a query result, returning a PyArrow table. + + An empty sequence is returned when no more rows are available. + """ + if size < 0: + raise ValueError("size argument for fetchmany is %s but must be >= 0", size) + results = self.results.next_n_rows(size) + n_remaining_rows = size - results.num_rows + self._next_row_index += results.num_rows + + while ( + n_remaining_rows > 0 + and not self.has_been_closed_server_side + and self.has_more_rows + ): + self._fill_results_buffer() + partial_results = self.results.next_n_rows(n_remaining_rows) + results = pyarrow.concat_tables([results, partial_results]) + n_remaining_rows -= partial_results.num_rows + self._next_row_index += partial_results.num_rows + + return results + + def fetchmany_columnar(self, size: int): + """ + Fetch the next set of rows of a query result, returning a Columnar Table. + An empty sequence is returned when no more rows are available. 
+ """ + if size < 0: + raise ValueError("size argument for fetchmany is %s but must be >= 0", size) + + results = self.results.next_n_rows(size) + n_remaining_rows = size - results.num_rows + self._next_row_index += results.num_rows + + while ( + n_remaining_rows > 0 + and not self.has_been_closed_server_side + and self.has_more_rows + ): + self._fill_results_buffer() + partial_results = self.results.next_n_rows(n_remaining_rows) + results = self.merge_columnar(results, partial_results) + n_remaining_rows -= partial_results.num_rows + self._next_row_index += partial_results.num_rows + + return results + + def fetchall_arrow(self) -> "pyarrow.Table": + """Fetch all (remaining) rows of a query result, returning them as a PyArrow table.""" + results = self.results.remaining_rows() + self._next_row_index += results.num_rows + + while not self.has_been_closed_server_side and self.has_more_rows: + self._fill_results_buffer() + partial_results = self.results.remaining_rows() + if isinstance(results, ColumnTable) and isinstance( + partial_results, ColumnTable + ): + results = self.merge_columnar(results, partial_results) + else: + results = pyarrow.concat_tables([results, partial_results]) + self._next_row_index += partial_results.num_rows + + # If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table + # Valid only for metadata commands result set + if isinstance(results, ColumnTable) and pyarrow: + data = { + name: col + for name, col in zip(results.column_names, results.column_table) + } + return pyarrow.Table.from_pydict(data) + return results + + def fetchall_columnar(self): + """Fetch all (remaining) rows of a query result, returning them as a Columnar table.""" + results = self.results.remaining_rows() + self._next_row_index += results.num_rows + + while not self.has_been_closed_server_side and self.has_more_rows: + self._fill_results_buffer() + partial_results = self.results.remaining_rows() + results = self.merge_columnar(results, 
partial_results) + self._next_row_index += partial_results.num_rows + + return results + + def fetchone(self) -> Optional[Row]: + """ + Fetch the next row of a query result set, returning a single sequence, + or None when no more data is available. + """ + if isinstance(self.results, ColumnQueue): + res = self._convert_columnar_table(self.fetchmany_columnar(1)) + else: + res = self._convert_arrow_table(self.fetchmany_arrow(1)) + + if len(res) > 0: + return res[0] + else: + return None + + def fetchall(self) -> List[Row]: + """ + Fetch all (remaining) rows of a query result, returning them as a list of rows. + """ + if isinstance(self.results, ColumnQueue): + return self._convert_columnar_table(self.fetchall_columnar()) + else: + return self._convert_arrow_table(self.fetchall_arrow()) + + def fetchmany(self, size: int) -> List[Row]: + """ + Fetch the next set of rows of a query result, returning a list of rows. + + An empty sequence is returned when no more rows are available. + """ + if isinstance(self.results, ColumnQueue): + return self._convert_columnar_table(self.fetchmany_columnar(size)) + else: + return self._convert_arrow_table(self.fetchmany_arrow(size)) + + @property + def is_staging_operation(self) -> bool: + """Whether this result set represents a staging operation.""" + return self._is_staging_operation + + @staticmethod + def _get_schema_description(table_schema_message): + """ + Takes a TableSchema message and returns a description 7-tuple as specified by PEP-249 + """ + + def map_col_type(type_): + if type_.startswith("decimal"): + return "decimal" + else: + return type_ + + return [ + (column.name, map_col_type(column.datatype), None, None, None, None, None) + for column in table_schema_message.columns + ] diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py new file mode 100644 index 000000000..3bf0532dc --- /dev/null +++ b/src/databricks/sql/session.py @@ -0,0 +1,153 @@ +import logging +from typing import Dict, Tuple, List, 
Optional, Any + +from databricks.sql.thrift_api.TCLIService import ttypes +from databricks.sql.types import SSLOptions +from databricks.sql.auth.auth import get_python_sql_connector_auth_provider +from databricks.sql.exc import SessionAlreadyClosedError, DatabaseError, RequestError +from databricks.sql import __version__ +from databricks.sql import USER_AGENT_NAME +from databricks.sql.backend.thrift_backend import ThriftDatabricksClient +from databricks.sql.backend.databricks_client import DatabricksClient +from databricks.sql.backend.types import SessionId + +logger = logging.getLogger(__name__) + + +class Session: + def __init__( + self, + server_hostname: str, + http_path: str, + http_headers: Optional[List[Tuple[str, str]]] = None, + session_configuration: Optional[Dict[str, Any]] = None, + catalog: Optional[str] = None, + schema: Optional[str] = None, + _use_arrow_native_complex_types: Optional[bool] = True, + **kwargs, + ) -> None: + """ + Create a session to a Databricks SQL endpoint or a Databricks cluster. + + This class handles all session-related behavior and communication with the backend. + """ + + self.is_open = False + self.host = server_hostname + self.port = kwargs.get("_port", 443) + + self.session_configuration = session_configuration + self.catalog = catalog + self.schema = schema + + auth_provider = get_python_sql_connector_auth_provider( + server_hostname, **kwargs + ) + + user_agent_entry = kwargs.get("user_agent_entry") + if user_agent_entry is None: + user_agent_entry = kwargs.get("_user_agent_entry") + if user_agent_entry is not None: + logger.warning( + "[WARN] Parameter '_user_agent_entry' is deprecated; use 'user_agent_entry' instead. " + "This parameter will be removed in the upcoming releases." 
+ ) + + if user_agent_entry: + useragent_header = "{}/{} ({})".format( + USER_AGENT_NAME, __version__, user_agent_entry + ) + else: + useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__) + + base_headers = [("User-Agent", useragent_header)] + + self._ssl_options = SSLOptions( + # Double negation is generally a bad thing, but we have to keep backward compatibility + tls_verify=not kwargs.get( + "_tls_no_verify", False + ), # by default - verify cert and host + tls_verify_hostname=kwargs.get("_tls_verify_hostname", True), + tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"), + tls_client_cert_file=kwargs.get("_tls_client_cert_file"), + tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"), + tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), + ) + + self.backend: DatabricksClient = ThriftDatabricksClient( + self.host, + self.port, + http_path, + (http_headers or []) + base_headers, + auth_provider, + ssl_options=self._ssl_options, + _use_arrow_native_complex_types=_use_arrow_native_complex_types, + **kwargs, + ) + + self.protocol_version = None + + def open(self): + self._session_id = self.backend.open_session( + session_configuration=self.session_configuration, + catalog=self.catalog, + schema=self.schema, + ) + self.protocol_version = self.get_protocol_version(self._session_id) + self.is_open = True + logger.info("Successfully opened session %s", str(self.guid_hex)) + + @staticmethod + def get_protocol_version(session_id: SessionId): + return session_id.protocol_version + + @staticmethod + def server_parameterized_queries_enabled(protocolVersion): + if ( + protocolVersion + and protocolVersion >= ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8 + ): + return True + else: + return False + + @property + def session_id(self) -> SessionId: + """Get the normalized session ID""" + return self._session_id + + @property + def guid(self) -> Any: + """Get the raw session ID (backend-specific)""" + return 
self._session_id.guid + + @property + def guid_hex(self) -> str: + """Get the session ID in hex format""" + return self._session_id.hex_guid + + def close(self) -> None: + """Close the underlying session.""" + logger.info("Closing session %s", self.guid_hex) + if not self.is_open: + logger.debug("Session appears to have been closed already") + return + + try: + self.backend.close_session(self._session_id) + except RequestError as e: + if isinstance(e.args[1], SessionAlreadyClosedError): + logger.info("Session was closed by a prior request") + except DatabaseError as e: + if "Invalid SessionHandle" in str(e): + logger.warning( + "Attempted to close session that was already closed: %s", e + ) + else: + logger.warning( + "Attempt to close session raised an exception at the server: %s", e + ) + except Exception as e: + logger.error("Attempt to close session raised a local exception: %s", e) + + self.is_open = False diff --git a/src/databricks/sql/types.py b/src/databricks/sql/types.py index fef22cd9f..e188ef577 100644 --- a/src/databricks/sql/types.py +++ b/src/databricks/sql/types.py @@ -158,6 +158,7 @@ def asDict(self, recursive: bool = False) -> Dict[str, Any]: >>> row.asDict(True) == {'key': 1, 'value': {'name': 'a', 'age': 2}} True """ + if not hasattr(self, "__fields__"): raise TypeError("Cannot convert a Row class into dict") diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 0ce2fa169..8b25eccc6 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -26,6 +26,7 @@ TSparkRowSetType, ) from databricks.sql.types import SSLOptions +from databricks.sql.backend.types import CommandId from databricks.sql.parameters.native import ParameterStructure, TDbsqlParameter @@ -73,6 +74,7 @@ def build_queue( Returns: ResultSetQueue """ + if row_set_type == TSparkRowSetType.ARROW_BASED_SET: arrow_table, n_valid_rows = convert_arrow_based_set_to_arrow_table( t_row_set.arrowBatches, lz4_compressed, arrow_schema_bytes @@ -172,6 
+174,7 @@ def __init__( :param n_valid_rows: The index of the last valid row in the table :param start_row_index: The first row in the table we should start fetching from """ + self.cur_row_index = start_row_index self.arrow_table = arrow_table self.n_valid_rows = n_valid_rows @@ -215,6 +218,7 @@ def __init__( lz4_compressed (bool): Whether the files are lz4 compressed. description (List[List[Any]]): Hive table schema description. """ + self.schema_bytes = schema_bytes self.max_download_threads = max_download_threads self.start_row_index = start_row_offset @@ -255,6 +259,7 @@ def next_n_rows(self, num_rows: int) -> "pyarrow.Table": Returns: pyarrow.Table """ + if not self.table: logger.debug("CloudFetchQueue: no more rows available") # Return empty pyarrow table to cause retry of fetch @@ -284,6 +289,7 @@ def remaining_rows(self) -> "pyarrow.Table": Returns: pyarrow.Table """ + if not self.table: # Return empty pyarrow table to cause retry of fetch return self._create_empty_table() @@ -345,7 +351,7 @@ def _create_empty_table(self) -> "pyarrow.Table": ExecuteResponse = namedtuple( "ExecuteResponse", "status has_been_closed_server_side has_more_rows description lz4_compressed is_staging_operation " - "command_handle arrow_queue arrow_schema_bytes", + "command_id arrow_queue arrow_schema_bytes", ) @@ -576,6 +582,7 @@ def transform_paramstyle( Returns: str """ + output = operation if ( param_structure == ParameterStructure.POSITIONAL diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 042fcc10a..a3f9b1af8 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -30,6 +30,7 @@ OperationalError, RequestError, ) +from databricks.sql.backend.types import CommandState from tests.e2e.common.predicates import ( pysql_has_version, pysql_supports_arrow, diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 91e426c64..a5db003e7 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -15,13 +15,18 @@ 
THandleIdentifier, TOperationState, TOperationType, + TOperationState, ) -from databricks.sql.thrift_backend import ThriftBackend +from databricks.sql.thrift_api.TCLIService import ttypes +from databricks.sql.backend.thrift_backend import ThriftDatabricksClient import databricks.sql import databricks.sql.client as client from databricks.sql import InterfaceError, DatabaseError, Error, NotSupportedError from databricks.sql.types import Row +from databricks.sql.result_set import ResultSet, ThriftResultSet +from databricks.sql.backend.types import CommandId, CommandState +from databricks.sql.utils import ExecuteResponse from databricks.sql.utils import ExecuteResponse from tests.unit.test_fetches import FetchTests @@ -29,28 +34,27 @@ from tests.unit.test_arrow_queue import ArrowQueueSuite -class ThriftBackendMockFactory: +class ThriftDatabricksClientMockFactory: @classmethod def new(cls): - ThriftBackendMock = Mock(spec=ThriftBackend) + ThriftBackendMock = Mock(spec=ThriftDatabricksClient) ThriftBackendMock.return_value = ThriftBackendMock cls.apply_property_to_mock(ThriftBackendMock, staging_allowed_local_path=None) - MockTExecuteStatementResp = MagicMock(spec=TExecuteStatementResp()) + mock_result_set = Mock(spec=ThriftResultSet) cls.apply_property_to_mock( - MockTExecuteStatementResp, + mock_result_set, description=None, - arrow_queue=None, is_staging_operation=False, - command_handle=b"\x22", + command_id=None, has_been_closed_server_side=True, has_more_rows=True, lz4_compressed=True, arrow_schema_bytes=b"schema", ) - ThriftBackendMock.execute_command.return_value = MockTExecuteStatementResp + ThriftBackendMock.execute_command.return_value = mock_result_set return ThriftBackendMock @@ -82,94 +86,7 @@ class ClientTestSuite(unittest.TestCase): "access_token": "tok", } - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_close_uses_the_correct_session_id(self, mock_client_class): - instance = mock_client_class.return_value - - mock_open_session_resp = 
MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b"\x22" - instance.open_session.return_value = mock_open_session_resp - - connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - connection.close() - - # Check the close session request has an id of x22 - close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b"\x22") - - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_auth_args(self, mock_client_class): - # Test that the following auth args work: - # token = foo, - # token = None, _tls_client_cert_file = something, _use_cert_as_auth = True - connection_args = [ - { - "server_hostname": "foo", - "http_path": None, - "access_token": "tok", - }, - { - "server_hostname": "foo", - "http_path": None, - "_tls_client_cert_file": "something", - "_use_cert_as_auth": True, - "access_token": None, - }, - ] - - for args in connection_args: - connection = databricks.sql.connect(**args) - host, port, http_path, *_ = mock_client_class.call_args[0] - self.assertEqual(args["server_hostname"], host) - self.assertEqual(args["http_path"], http_path) - connection.close() - - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_http_header_passthrough(self, mock_client_class): - http_headers = [("foo", "bar")] - databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, http_headers=http_headers) - - call_args = mock_client_class.call_args[0][3] - self.assertIn(("foo", "bar"), call_args) - - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_tls_arg_passthrough(self, mock_client_class): - databricks.sql.connect( - **self.DUMMY_CONNECTION_ARGS, - _tls_verify_hostname="hostname", - _tls_trusted_ca_file="trusted ca file", - _tls_client_cert_key_file="trusted client cert", - _tls_client_cert_key_password="key password", - ) - - kwargs = mock_client_class.call_args[1] - self.assertEqual(kwargs["_tls_verify_hostname"], "hostname") - 
self.assertEqual(kwargs["_tls_trusted_ca_file"], "trusted ca file") - self.assertEqual(kwargs["_tls_client_cert_key_file"], "trusted client cert") - self.assertEqual(kwargs["_tls_client_cert_key_password"], "key password") - - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_useragent_header(self, mock_client_class): - databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - - http_headers = mock_client_class.call_args[0][3] - user_agent_header = ( - "User-Agent", - "{}/{}".format(databricks.sql.USER_AGENT_NAME, databricks.sql.__version__), - ) - self.assertIn(user_agent_header, http_headers) - - databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, user_agent_entry="foobar") - user_agent_header_with_entry = ( - "User-Agent", - "{}/{} ({})".format( - databricks.sql.USER_AGENT_NAME, databricks.sql.__version__, "foobar" - ), - ) - http_headers = mock_client_class.call_args[0][3] - self.assertIn(user_agent_header_with_entry, http_headers) - - @patch("databricks.sql.client.ThriftBackend") + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_closing_connection_closes_commands(self, mock_thrift_client_class): """Test that closing a connection properly closes commands. 
@@ -181,13 +98,12 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): Args: mock_thrift_client_class: Mock for ThriftBackend class """ + for closed in (True, False): with self.subTest(closed=closed): # Set initial state based on whether the command is already closed initial_state = ( - TOperationState.FINISHED_STATE - if not closed - else TOperationState.CLOSED_STATE + CommandState.CLOSED if closed else CommandState.SUCCEEDED ) # Mock the execute response with controlled state @@ -195,54 +111,50 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): mock_execute_response.status = initial_state mock_execute_response.has_been_closed_server_side = closed mock_execute_response.is_staging_operation = False + mock_execute_response.command_id = Mock(spec=CommandId) # Mock the backend that will be used - mock_backend = Mock(spec=ThriftBackend) + mock_backend = Mock(spec=ThriftDatabricksClient) + mock_backend.staging_allowed_local_path = None mock_thrift_client_class.return_value = mock_backend # Create connection and cursor - connection = databricks.sql.connect( - server_hostname="foo", - http_path="dummy_path", - access_token="tok", - ) + connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) cursor = connection.cursor() - # Mock execute_command to return our execute response - cursor.thrift_backend.execute_command = Mock( - return_value=mock_execute_response + real_result_set = ThriftResultSet( + connection=connection, + execute_response=mock_execute_response, + thrift_client=mock_backend, ) + # Mock execute_command to return our real result set + cursor.backend.execute_command = Mock(return_value=real_result_set) + # Execute a command cursor.execute("SELECT 1") - # Get the active result set for later assertions - active_result_set = cursor.active_result_set - # Close the connection connection.close() # Verify the close logic worked: # 1. 
has_been_closed_server_side should always be True after close() - assert active_result_set.has_been_closed_server_side is True + assert real_result_set.has_been_closed_server_side is True # 2. op_state should always be CLOSED after close() - assert ( - active_result_set.op_state - == connection.thrift_backend.CLOSED_OP_STATE - ) + assert real_result_set.op_state == CommandState.CLOSED # 3. Backend close_command should be called appropriately if not closed: # Should have called backend.close_command during the close chain mock_backend.close_command.assert_called_once_with( - mock_execute_response.command_handle + mock_execute_response.command_id ) else: # Should NOT have called backend.close_command (already closed) mock_backend.close_command.assert_not_called() - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_cant_open_cursor_on_closed_connection(self, mock_client_class): connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) self.assertTrue(connection.open) @@ -252,7 +164,7 @@ def test_cant_open_cursor_on_closed_connection(self, mock_client_class): connection.cursor() self.assertIn("closed", str(cm.exception)) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) @patch("%s.client.Cursor" % PACKAGE_NAME) def test_arraysize_buffer_size_passthrough( self, mock_cursor_class, mock_client_class @@ -267,12 +179,16 @@ def test_arraysize_buffer_size_passthrough( def test_closing_result_set_with_closed_connection_soft_closes_commands(self): mock_connection = Mock() mock_backend = Mock() - result_set = client.ResultSet( + + result_set = ThriftResultSet( connection=mock_connection, - thrift_backend=mock_backend, execute_response=Mock(), + thrift_client=mock_backend, ) - mock_connection.open = False + # Setup session mock on the mock_connection + mock_session = Mock() + mock_session.open = False + type(mock_connection).session = 
PropertyMock(return_value=mock_session) result_set.close() @@ -284,28 +200,31 @@ def test_closing_result_set_hard_closes_commands(self): mock_results_response.has_been_closed_server_side = False mock_connection = Mock() mock_thrift_backend = Mock() - mock_connection.open = True - result_set = client.ResultSet( + # Setup session mock on the mock_connection + mock_session = Mock() + mock_session.open = True + type(mock_connection).session = PropertyMock(return_value=mock_session) + + result_set = ThriftResultSet( mock_connection, mock_results_response, mock_thrift_backend ) result_set.close() mock_thrift_backend.close_command.assert_called_once_with( - mock_results_response.command_handle + mock_results_response.command_id ) - @patch("%s.client.ResultSet" % PACKAGE_NAME) - def test_executing_multiple_commands_uses_the_most_recent_command( - self, mock_result_set_class - ): - + def test_executing_multiple_commands_uses_the_most_recent_command(self): mock_result_sets = [Mock(), Mock()] - mock_result_set_class.side_effect = mock_result_sets + # Set is_staging_operation to False to avoid _handle_staging_operation being called + for mock_rs in mock_result_sets: + mock_rs.is_staging_operation = False - cursor = client.Cursor( - connection=Mock(), thrift_backend=ThriftBackendMockFactory.new() - ) + mock_backend = ThriftDatabricksClientMockFactory.new() + mock_backend.execute_command.side_effect = mock_result_sets + + cursor = client.Cursor(connection=Mock(), backend=mock_backend) cursor.execute("SELECT 1;") cursor.execute("SELECT 1;") @@ -330,7 +249,7 @@ def test_closed_cursor_doesnt_allow_operations(self): self.assertIn("closed", e.msg) def test_negative_fetch_throws_exception(self): - result_set = client.ResultSet(Mock(), Mock(), Mock()) + result_set = ThriftResultSet(Mock(), Mock(), Mock()) with self.assertRaises(ValueError) as e: result_set.fetchmany(-1) @@ -341,21 +260,6 @@ def test_context_manager_closes_cursor(self): cursor.close = mock_close 
mock_close.assert_called_once_with() - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_context_manager_closes_connection(self, mock_client_class): - instance = mock_client_class.return_value - - mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b"\x22" - instance.open_session.return_value = mock_open_session_resp - - with databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) as connection: - pass - - # Check the close session request has an id of x22 - close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b"\x22") - def dict_product(self, dicts): """ Generate cartesion product of values in input dictionary, outputting a dictionary @@ -368,7 +272,7 @@ def dict_product(self, dicts): """ return (dict(zip(dicts.keys(), x)) for x in itertools.product(*dicts.values())) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.client.ThriftDatabricksClient" % PACKAGE_NAME) def test_get_schemas_parameters_passed_to_thrift_backend(self, mock_thrift_backend): req_args_combinations = self.dict_product( dict( @@ -389,7 +293,7 @@ def test_get_schemas_parameters_passed_to_thrift_backend(self, mock_thrift_backe for k, v in req_args.items(): self.assertEqual(v, call_args[k]) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.client.ThriftDatabricksClient" % PACKAGE_NAME) def test_get_tables_parameters_passed_to_thrift_backend(self, mock_thrift_backend): req_args_combinations = self.dict_product( dict( @@ -412,7 +316,7 @@ def test_get_tables_parameters_passed_to_thrift_backend(self, mock_thrift_backen for k, v in req_args.items(): self.assertEqual(v, call_args[k]) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.client.ThriftDatabricksClient" % PACKAGE_NAME) def test_get_columns_parameters_passed_to_thrift_backend(self, mock_thrift_backend): req_args_combinations = self.dict_product( dict( @@ -438,10 +342,10 @@ def 
test_get_columns_parameters_passed_to_thrift_backend(self, mock_thrift_backe def test_cancel_command_calls_the_backend(self): mock_thrift_backend = Mock() cursor = client.Cursor(Mock(), mock_thrift_backend) - mock_op_handle = Mock() - cursor.active_op_handle = mock_op_handle + mock_command_id = Mock() + cursor.active_command_id = mock_command_id cursor.cancel() - mock_thrift_backend.cancel_command.assert_called_with(mock_op_handle) + mock_thrift_backend.cancel_command.assert_called_with(mock_command_id) @patch("databricks.sql.client.logger") def test_cancel_command_will_issue_warning_for_cancel_with_no_executing_command( @@ -454,21 +358,6 @@ def test_cancel_command_will_issue_warning_for_cancel_with_no_executing_command( self.assertTrue(logger_instance.warning.called) self.assertFalse(mock_thrift_backend.cancel_command.called) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_max_number_of_retries_passthrough(self, mock_client_class): - databricks.sql.connect( - _retry_stop_after_attempts_count=54, **self.DUMMY_CONNECTION_ARGS - ) - - self.assertEqual( - mock_client_class.call_args[1]["_retry_stop_after_attempts_count"], 54 - ) - - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_socket_timeout_passthrough(self, mock_client_class): - databricks.sql.connect(_socket_timeout=234, **self.DUMMY_CONNECTION_ARGS) - self.assertEqual(mock_client_class.call_args[1]["_socket_timeout"], 234) - def test_version_is_canonical(self): version = databricks.sql.__version__ canonical_version_re = ( @@ -477,35 +366,8 @@ def test_version_is_canonical(self): ) self.assertIsNotNone(re.match(canonical_version_re, version)) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_configuration_passthrough(self, mock_client_class): - mock_session_config = Mock() - databricks.sql.connect( - session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS - ) - - self.assertEqual( - mock_client_class.return_value.open_session.call_args[0][0], - 
mock_session_config, - ) - - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_initial_namespace_passthrough(self, mock_client_class): - mock_cat = Mock() - mock_schem = Mock() - - databricks.sql.connect( - **self.DUMMY_CONNECTION_ARGS, catalog=mock_cat, schema=mock_schem - ) - self.assertEqual( - mock_client_class.return_value.open_session.call_args[0][1], mock_cat - ) - self.assertEqual( - mock_client_class.return_value.open_session.call_args[0][2], mock_schem - ) - def test_execute_parameter_passthrough(self): - mock_thrift_backend = ThriftBackendMockFactory.new() + mock_thrift_backend = ThriftDatabricksClientMockFactory.new() cursor = client.Cursor(Mock(), mock_thrift_backend) tests = [ @@ -529,16 +391,17 @@ def test_execute_parameter_passthrough(self): expected_query, ) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - @patch("%s.client.ResultSet" % PACKAGE_NAME) - def test_executemany_parameter_passhthrough_and_uses_last_result_set( - self, mock_result_set_class, mock_thrift_backend - ): + def test_executemany_parameter_passhthrough_and_uses_last_result_set(self): # Create a new mock result set each time the class is instantiated mock_result_set_instances = [Mock(), Mock(), Mock()] - mock_result_set_class.side_effect = mock_result_set_instances - mock_thrift_backend = ThriftBackendMockFactory.new() - cursor = client.Cursor(Mock(), mock_thrift_backend()) + # Set is_staging_operation to False to avoid _handle_staging_operation being called + for mock_rs in mock_result_set_instances: + mock_rs.is_staging_operation = False + + mock_backend = ThriftDatabricksClientMockFactory.new() + mock_backend.execute_command.side_effect = mock_result_set_instances + + cursor = client.Cursor(Mock(), mock_backend) params = [{"x": None}, {"x": "foo1"}, {"x": "bar2"}] expected_queries = ["SELECT NULL", "SELECT 'foo1'", "SELECT 'bar2'"] @@ -546,13 +409,13 @@ def test_executemany_parameter_passhthrough_and_uses_last_result_set( cursor.executemany("SELECT %(x)s", 
seq_of_parameters=params) self.assertEqual( - len(mock_thrift_backend.execute_command.call_args_list), + len(mock_backend.execute_command.call_args_list), len(expected_queries), "Expected execute_command to be called the same number of times as params were passed", ) for expected_query, call_args in zip( - expected_queries, mock_thrift_backend.execute_command.call_args_list + expected_queries, mock_backend.execute_command.call_args_list ): self.assertEqual(call_args[1]["operation"], expected_query) @@ -563,7 +426,7 @@ def test_executemany_parameter_passhthrough_and_uses_last_result_set( "last operation", ) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_commit_a_noop(self, mock_thrift_backend_class): c = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) c.commit() @@ -576,14 +439,14 @@ def test_setoutputsizes_a_noop(self): cursor = client.Cursor(Mock(), Mock()) cursor.setoutputsize(1) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_rollback_not_supported(self, mock_thrift_backend_class): c = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) with self.assertRaises(NotSupportedError): c.rollback() @unittest.skip("JDW: skipping winter 2024 as we're about to rewrite this interface") - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.client.ThriftDatabricksClient" % PACKAGE_NAME) def test_row_number_respected(self, mock_thrift_backend_class): def make_fake_row_slice(n_rows): mock_slice = Mock() @@ -608,7 +471,7 @@ def make_fake_row_slice(n_rows): self.assertEqual(cursor.rownumber, 29) @unittest.skip("JDW: skipping winter 2024 as we're about to rewrite this interface") - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.client.ThriftDatabricksClient" % PACKAGE_NAME) def test_disable_pandas_respected(self, mock_thrift_backend_class): mock_thrift_backend = mock_thrift_backend_class.return_value 
mock_table = Mock() @@ -661,24 +524,7 @@ def test_column_name_api(self): }, ) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_finalizer_closes_abandoned_connection(self, mock_client_class): - instance = mock_client_class.return_value - - mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b"\x22" - instance.open_session.return_value = mock_open_session_resp - - databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - - # not strictly necessary as the refcount is 0, but just to be sure - gc.collect() - - # Check the close session request has an id of x22 - close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b"\x22") - - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_cursor_keeps_connection_alive(self, mock_client_class): instance = mock_client_class.return_value @@ -697,17 +543,18 @@ def test_cursor_keeps_connection_alive(self, mock_client_class): @patch("%s.utils.ExecuteResponse" % PACKAGE_NAME, autospec=True) @patch("%s.client.Cursor._handle_staging_operation" % PACKAGE_NAME) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_staging_operation_response_is_handled( self, mock_client_class, mock_handle_staging_operation, mock_execute_response ): # If server sets ExecuteResponse.is_staging_operation True then _handle_staging_operation should be called - ThriftBackendMockFactory.apply_property_to_mock( + ThriftDatabricksClientMockFactory.apply_property_to_mock( mock_execute_response, is_staging_operation=True ) - mock_client_class.execute_command.return_value = mock_execute_response - mock_client_class.return_value = mock_client_class + mock_client = mock_client_class.return_value + mock_client.execute_command.return_value = Mock(is_staging_operation=True) + mock_client_class.return_value = mock_client 
connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) cursor = connection.cursor() @@ -716,7 +563,10 @@ def test_staging_operation_response_is_handled( mock_handle_staging_operation.call_count == 1 - @patch("%s.client.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new()) + @patch( + "%s.session.ThriftDatabricksClient" % PACKAGE_NAME, + ThriftDatabricksClientMockFactory.new(), + ) def test_access_current_query_id(self): operation_id = "EE6A8778-21FC-438B-92D8-96AC51EE3821" @@ -725,9 +575,13 @@ def test_access_current_query_id(self): self.assertIsNone(cursor.query_id) - cursor.active_op_handle = TOperationHandle( - operationId=THandleIdentifier(guid=UUID(operation_id).bytes, secret=0x00), - operationType=TOperationType.EXECUTE_STATEMENT, + cursor.active_command_id = CommandId.from_thrift_handle( + TOperationHandle( + operationId=THandleIdentifier( + guid=UUID(operation_id).bytes, secret=0x00 + ), + operationType=TOperationType.EXECUTE_STATEMENT, + ) ) self.assertEqual(cursor.query_id.upper(), operation_id.upper()) diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 71766f2cb..030510a64 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -9,6 +9,8 @@ import databricks.sql.client as client from databricks.sql.utils import ExecuteResponse, ArrowQueue +from databricks.sql.backend.thrift_backend import ThriftDatabricksClient +from databricks.sql.result_set import ThriftResultSet @pytest.mark.skipif(pa is None, reason="PyArrow is not installed") @@ -37,20 +39,20 @@ def make_dummy_result_set_from_initial_results(initial_results): # If the initial results have been set, then we should never try and fetch more schema, arrow_table = FetchTests.make_arrow_table(initial_results) arrow_queue = ArrowQueue(arrow_table, len(initial_results), 0) - rs = client.ResultSet( + rs = ThriftResultSet( connection=Mock(), - thrift_backend=None, execute_response=ExecuteResponse( status=None, has_been_closed_server_side=True, 
has_more_rows=False, description=Mock(), lz4_compressed=Mock(), - command_handle=None, + command_id=None, arrow_queue=arrow_queue, arrow_schema_bytes=schema.serialize().to_pybytes(), is_staging_operation=False, ), + thrift_client=None, ) num_cols = len(initial_results[0]) if initial_results else 0 rs.description = [ @@ -64,7 +66,7 @@ def make_dummy_result_set_from_batch_list(batch_list): batch_index = 0 def fetch_results( - op_handle, + command_id, max_rows, max_bytes, expected_row_start_offset, @@ -79,13 +81,12 @@ def fetch_results( return results, batch_index < len(batch_list) - mock_thrift_backend = Mock() + mock_thrift_backend = Mock(spec=ThriftDatabricksClient) mock_thrift_backend.fetch_results = fetch_results num_cols = len(batch_list[0][0]) if batch_list and batch_list[0] else 0 - rs = client.ResultSet( + rs = ThriftResultSet( connection=Mock(), - thrift_backend=mock_thrift_backend, execute_response=ExecuteResponse( status=None, has_been_closed_server_side=False, @@ -95,11 +96,12 @@ def fetch_results( for col_id in range(num_cols) ], lz4_compressed=Mock(), - command_handle=None, + command_id=None, arrow_queue=None, arrow_schema_bytes=None, is_staging_operation=False, ), + thrift_client=mock_thrift_backend, ) return rs diff --git a/tests/unit/test_fetches_bench.py b/tests/unit/test_fetches_bench.py index 552872221..b302c00da 100644 --- a/tests/unit/test_fetches_bench.py +++ b/tests/unit/test_fetches_bench.py @@ -31,13 +31,13 @@ def make_dummy_result_set_from_initial_results(arrow_table): arrow_queue = ArrowQueue(arrow_table, arrow_table.num_rows, 0) rs = client.ResultSet( connection=None, - thrift_backend=None, + backend=None, execute_response=ExecuteResponse( status=None, has_been_closed_server_side=True, has_more_rows=False, description=Mock(), - command_handle=None, + command_id=None, arrow_queue=arrow_queue, arrow_schema=arrow_table.schema, ), diff --git a/tests/unit/test_parameters.py b/tests/unit/test_parameters.py index 249730789..cf2e24951 100644 --- 
a/tests/unit/test_parameters.py +++ b/tests/unit/test_parameters.py @@ -24,6 +24,7 @@ MapParameter, ArrayParameter, ) +from databricks.sql.backend.types import SessionId from databricks.sql.parameters.native import ( TDbsqlParameter, TSparkParameter, @@ -46,7 +47,10 @@ class TestSessionHandleChecks(object): ( TOpenSessionResp( serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7, - sessionHandle=TSessionHandle(1, None), + sessionHandle=TSessionHandle( + sessionId=ttypes.THandleIdentifier(guid=0x36, secret=0x37), + serverProtocolVersion=None, + ), ), ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7, ), @@ -55,7 +59,8 @@ class TestSessionHandleChecks(object): TOpenSessionResp( serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7, sessionHandle=TSessionHandle( - 1, ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8 + sessionId=ttypes.THandleIdentifier(guid=0x36, secret=0x37), + serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8, ), ), ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8, diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py new file mode 100644 index 000000000..a5c751782 --- /dev/null +++ b/tests/unit/test_session.py @@ -0,0 +1,190 @@ +import pytest +from unittest.mock import patch, MagicMock, Mock, PropertyMock +import gc + +from databricks.sql.thrift_api.TCLIService.ttypes import ( + TOpenSessionResp, + TSessionHandle, + THandleIdentifier, +) +from databricks.sql.backend.types import SessionId, BackendType + +import databricks.sql + + +class TestSession: + """ + Unit tests for Session functionality + """ + + PACKAGE_NAME = "databricks.sql" + DUMMY_CONNECTION_ARGS = { + "server_hostname": "foo", + "http_path": "dummy_path", + "access_token": "tok", + } + + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) + def test_close_uses_the_correct_session_id(self, mock_client_class): + instance = mock_client_class.return_value + + # Create a mock SessionId 
that will be returned by open_session + mock_session_id = SessionId(BackendType.THRIFT, b"\x22", b"\x33") + instance.open_session.return_value = mock_session_id + + connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + connection.close() + + # Check that close_session was called with the correct SessionId + close_session_call_args = instance.close_session.call_args[0][0] + assert close_session_call_args.guid == b"\x22" + assert close_session_call_args.secret == b"\x33" + + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) + def test_auth_args(self, mock_client_class): + # Test that the following auth args work: + # token = foo, + # token = None, _tls_client_cert_file = something, _use_cert_as_auth = True + connection_args = [ + { + "server_hostname": "foo", + "http_path": None, + "access_token": "tok", + }, + { + "server_hostname": "foo", + "http_path": None, + "_tls_client_cert_file": "something", + "_use_cert_as_auth": True, + "access_token": None, + }, + ] + + for args in connection_args: + connection = databricks.sql.connect(**args) + host, port, http_path, *_ = mock_client_class.call_args[0] + assert args["server_hostname"] == host + assert args["http_path"] == http_path + connection.close() + + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) + def test_http_header_passthrough(self, mock_client_class): + http_headers = [("foo", "bar")] + databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, http_headers=http_headers) + + call_args = mock_client_class.call_args[0][3] + assert ("foo", "bar") in call_args + + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) + def test_tls_arg_passthrough(self, mock_client_class): + databricks.sql.connect( + **self.DUMMY_CONNECTION_ARGS, + _tls_verify_hostname="hostname", + _tls_trusted_ca_file="trusted ca file", + _tls_client_cert_key_file="trusted client cert", + _tls_client_cert_key_password="key password", + ) + + kwargs = mock_client_class.call_args[1] + assert 
kwargs["_tls_verify_hostname"] == "hostname" + assert kwargs["_tls_trusted_ca_file"] == "trusted ca file" + assert kwargs["_tls_client_cert_key_file"] == "trusted client cert" + assert kwargs["_tls_client_cert_key_password"] == "key password" + + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) + def test_useragent_header(self, mock_client_class): + databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + + http_headers = mock_client_class.call_args[0][3] + user_agent_header = ( + "User-Agent", + "{}/{}".format(databricks.sql.USER_AGENT_NAME, databricks.sql.__version__), + ) + assert user_agent_header in http_headers + + databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, user_agent_entry="foobar") + user_agent_header_with_entry = ( + "User-Agent", + "{}/{} ({})".format( + databricks.sql.USER_AGENT_NAME, databricks.sql.__version__, "foobar" + ), + ) + http_headers = mock_client_class.call_args[0][3] + assert user_agent_header_with_entry in http_headers + + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) + def test_context_manager_closes_connection(self, mock_client_class): + instance = mock_client_class.return_value + + # Create a mock SessionId that will be returned by open_session + mock_session_id = SessionId(BackendType.THRIFT, b"\x22", b"\x33") + instance.open_session.return_value = mock_session_id + + with databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) as connection: + pass + + # Check that close_session was called with the correct SessionId + close_session_call_args = instance.close_session.call_args[0][0] + assert close_session_call_args.guid == b"\x22" + assert close_session_call_args.secret == b"\x33" + + connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + connection.close = Mock() + try: + with pytest.raises(KeyboardInterrupt): + with connection: + raise KeyboardInterrupt("Simulated interrupt") + finally: + connection.close.assert_called() + + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) + def 
test_max_number_of_retries_passthrough(self, mock_client_class): + databricks.sql.connect( + _retry_stop_after_attempts_count=54, **self.DUMMY_CONNECTION_ARGS + ) + + assert mock_client_class.call_args[1]["_retry_stop_after_attempts_count"] == 54 + + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) + def test_socket_timeout_passthrough(self, mock_client_class): + databricks.sql.connect(_socket_timeout=234, **self.DUMMY_CONNECTION_ARGS) + assert mock_client_class.call_args[1]["_socket_timeout"] == 234 + + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) + def test_configuration_passthrough(self, mock_client_class): + mock_session_config = Mock() + databricks.sql.connect( + session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS + ) + + call_kwargs = mock_client_class.return_value.open_session.call_args[1] + assert call_kwargs["session_configuration"] == mock_session_config + + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) + def test_initial_namespace_passthrough(self, mock_client_class): + mock_cat = Mock() + mock_schem = Mock() + databricks.sql.connect( + **self.DUMMY_CONNECTION_ARGS, catalog=mock_cat, schema=mock_schem + ) + + call_kwargs = mock_client_class.return_value.open_session.call_args[1] + assert call_kwargs["catalog"] == mock_cat + assert call_kwargs["schema"] == mock_schem + + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) + def test_finalizer_closes_abandoned_connection(self, mock_client_class): + instance = mock_client_class.return_value + + mock_session_id = SessionId(BackendType.THRIFT, b"\x22", b"\x33") + instance.open_session.return_value = mock_session_id + + databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + + # not strictly necessary as the refcount is 0, but just to be sure + gc.collect() + + # Check that close_session was called with the correct SessionId + close_session_call_args = instance.close_session.call_args[0][0] + assert close_session_call_args.guid == b"\x22" + assert 
close_session_call_args.secret == b"\x33" diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 458ea9a82..2cfad7bf4 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -17,7 +17,9 @@ from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql import * from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.thrift_backend import ThriftBackend +from databricks.sql.backend.thrift_backend import ThriftDatabricksClient +from databricks.sql.result_set import ResultSet, ThriftResultSet +from databricks.sql.backend.types import CommandId, CommandState, SessionId, BackendType def retry_policy_factory(): @@ -51,6 +53,7 @@ class ThriftBackendTestSuite(unittest.TestCase): open_session_resp = ttypes.TOpenSessionResp( status=okay_status, serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4, + sessionHandle=session_handle, ) metadata_resp = ttypes.TGetResultSetMetadataResp( @@ -73,7 +76,7 @@ def test_make_request_checks_thrift_status_code(self): mock_method = Mock() mock_method.__name__ = "method name" mock_method.return_value = mock_response - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -92,7 +95,7 @@ def _make_type_desc(self, type): ) def _make_fake_thrift_backend(self): - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -126,14 +129,16 @@ def test_hive_schema_to_arrow_schema_preserves_column_names(self): ] t_table_schema = ttypes.TTableSchema(columns) - arrow_schema = ThriftBackend._hive_schema_to_arrow_schema(t_table_schema) + arrow_schema = ThriftDatabricksClient._hive_schema_to_arrow_schema( + t_table_schema + ) self.assertEqual(arrow_schema.field(0).name, "column 1") self.assertEqual(arrow_schema.field(1).name, "column 2") self.assertEqual(arrow_schema.field(2).name, "column 2") self.assertEqual(arrow_schema.field(3).name, "") - 
@patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_bad_protocol_versions_are_rejected(self, tcli_service_client_cass): t_http_client_instance = tcli_service_client_cass.return_value bad_protocol_versions = [ @@ -163,7 +168,7 @@ def test_bad_protocol_versions_are_rejected(self, tcli_service_client_cass): "expected server to use a protocol version", str(cm.exception) ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_okay_protocol_versions_succeed(self, tcli_service_client_cass): t_http_client_instance = tcli_service_client_cass.return_value good_protocol_versions = [ @@ -174,7 +179,9 @@ def test_okay_protocol_versions_succeed(self, tcli_service_client_cass): for protocol_version in good_protocol_versions: t_http_client_instance.OpenSession.return_value = ttypes.TOpenSessionResp( - status=self.okay_status, serverProtocolVersion=protocol_version + status=self.okay_status, + serverProtocolVersion=protocol_version, + sessionHandle=self.session_handle, ) thrift_backend = self._make_fake_thrift_backend() @@ -182,7 +189,7 @@ def test_okay_protocol_versions_succeed(self, tcli_service_client_cass): @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_headers_are_set(self, t_http_client_class): - ThriftBackend( + ThriftDatabricksClient( "foo", 123, "bar", @@ -229,7 +236,7 @@ def test_tls_cert_args_are_propagated( mock_ssl_context = mock_ssl_options.create_ssl_context() mock_create_default_context.assert_called_once_with(cafile=mock_trusted_ca_file) - ThriftBackend( + ThriftDatabricksClient( "foo", 123, "bar", @@ -315,7 +322,7 @@ def test_tls_no_verify_is_respected( mock_ssl_context = mock_ssl_options.create_ssl_context() mock_create_default_context.assert_called() - ThriftBackend( + ThriftDatabricksClient( "foo", 123, "bar", @@ 
-339,7 +346,7 @@ def test_tls_verify_hostname_is_respected( mock_ssl_context = mock_ssl_options.create_ssl_context() mock_create_default_context.assert_called() - ThriftBackend( + ThriftDatabricksClient( "foo", 123, "bar", @@ -356,7 +363,7 @@ def test_tls_verify_hostname_is_respected( @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_port_and_host_are_respected(self, t_http_client_class): - ThriftBackend( + ThriftDatabricksClient( "hostname", 123, "path_value", @@ -371,7 +378,7 @@ def test_port_and_host_are_respected(self, t_http_client_class): @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_host_with_https_does_not_duplicate(self, t_http_client_class): - ThriftBackend( + ThriftDatabricksClient( "https://hostname", 123, "path_value", @@ -386,7 +393,7 @@ def test_host_with_https_does_not_duplicate(self, t_http_client_class): @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_host_with_trailing_backslash_does_not_duplicate(self, t_http_client_class): - ThriftBackend( + ThriftDatabricksClient( "https://hostname/", 123, "path_value", @@ -401,7 +408,7 @@ def test_host_with_trailing_backslash_does_not_duplicate(self, t_http_client_cla @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_socket_timeout_is_propagated(self, t_http_client_class): - ThriftBackend( + ThriftDatabricksClient( "hostname", 123, "path_value", @@ -413,7 +420,7 @@ def test_socket_timeout_is_propagated(self, t_http_client_class): self.assertEqual( t_http_client_class.return_value.setTimeout.call_args[0][0], 129 * 1000 ) - ThriftBackend( + ThriftDatabricksClient( "hostname", 123, "path_value", @@ -423,7 +430,7 @@ def test_socket_timeout_is_propagated(self, t_http_client_class): _socket_timeout=0, ) self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], 0) - ThriftBackend( + ThriftDatabricksClient( "hostname", 123, "path_value", @@ -434,7 +441,7 @@ def test_socket_timeout_is_propagated(self, 
t_http_client_class): self.assertEqual( t_http_client_class.return_value.setTimeout.call_args[0][0], 900 * 1000 ) - ThriftBackend( + ThriftDatabricksClient( "hostname", 123, "path_value", @@ -467,9 +474,9 @@ def test_non_primitive_types_raise_error(self): t_table_schema = ttypes.TTableSchema(columns) with self.assertRaises(OperationalError): - ThriftBackend._hive_schema_to_arrow_schema(t_table_schema) + ThriftDatabricksClient._hive_schema_to_arrow_schema(t_table_schema) with self.assertRaises(OperationalError): - ThriftBackend._hive_schema_to_description(t_table_schema) + ThriftDatabricksClient._hive_schema_to_description(t_table_schema) def test_hive_schema_to_description_preserves_column_names_and_types(self): # Full coverage of all types is done in integration tests, this is just a @@ -493,7 +500,7 @@ def test_hive_schema_to_description_preserves_column_names_and_types(self): ] t_table_schema = ttypes.TTableSchema(columns) - description = ThriftBackend._hive_schema_to_description(t_table_schema) + description = ThriftDatabricksClient._hive_schema_to_description(t_table_schema) self.assertEqual( description, @@ -532,7 +539,7 @@ def test_hive_schema_to_description_preserves_scale_and_precision(self): ] t_table_schema = ttypes.TTableSchema(columns) - description = ThriftBackend._hive_schema_to_description(t_table_schema) + description = ThriftDatabricksClient._hive_schema_to_description(t_table_schema) self.assertEqual( description, [ @@ -545,7 +552,7 @@ def test_make_request_checks_status_code(self): ttypes.TStatusCode.ERROR_STATUS, ttypes.TStatusCode.INVALID_HANDLE_STATUS, ] - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -588,8 +595,9 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): resultSet=None, closeOperation=None, ), + operationHandle=self.operation_handle, ) - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -628,7 
+636,7 @@ def test_handle_execute_response_sets_compression_in_direct_results( closeOperation=None, ), ) - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -642,7 +650,7 @@ def test_handle_execute_response_sets_compression_in_direct_results( ) self.assertEqual(execute_response.lz4_compressed, lz4Compressed) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_checks_operation_state_in_polls( self, tcli_service_class ): @@ -672,7 +680,7 @@ def test_handle_execute_response_checks_operation_state_in_polls( ) tcli_service_instance.GetOperationStatus.return_value = op_state_resp - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -686,7 +694,7 @@ def test_handle_execute_response_checks_operation_state_in_polls( if op_state_resp.errorMessage: self.assertIn(op_state_resp.errorMessage, str(cm.exception)) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_status_uses_display_message_if_available(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value @@ -710,7 +718,7 @@ def test_get_status_uses_display_message_if_available(self, tcli_service_class): ) tcli_service_instance.ExecuteStatement.return_value = t_execute_resp - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -724,7 +732,7 @@ def test_get_status_uses_display_message_if_available(self, tcli_service_class): self.assertEqual(display_message, str(cm.exception)) self.assertIn(diagnostic_info, str(cm.exception.message_with_context())) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", 
autospec=True) def test_direct_results_uses_display_message_if_available(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value @@ -746,11 +754,12 @@ def test_direct_results_uses_display_message_if_available(self, tcli_service_cla resultSet=None, closeOperation=None, ), + operationHandle=self.operation_handle, ) tcli_service_instance.ExecuteStatement.return_value = t_execute_resp - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -776,6 +785,7 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): resultSet=None, closeOperation=None, ), + operationHandle=self.operation_handle, ) resp_2 = resp_type( @@ -788,6 +798,7 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): resultSet=None, closeOperation=None, ), + operationHandle=self.operation_handle, ) resp_3 = resp_type( @@ -798,6 +809,7 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): resultSet=ttypes.TFetchResultsResp(status=self.bad_status), closeOperation=None, ), + operationHandle=self.operation_handle, ) resp_4 = resp_type( @@ -808,11 +820,12 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): resultSet=None, closeOperation=ttypes.TCloseOperationResp(status=self.bad_status), ), + operationHandle=self.operation_handle, ) for error_resp in [resp_1, resp_2, resp_3, resp_4]: with self.subTest(error_resp=error_resp): - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -825,7 +838,7 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): thrift_backend._handle_execute_response(error_resp, Mock()) self.assertIn("this is a bad error", str(cm.exception)) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def 
test_handle_execute_response_can_handle_without_direct_results( self, tcli_service_class ): @@ -863,7 +876,7 @@ def test_handle_execute_response_can_handle_without_direct_results( op_state_2, op_state_3, ] - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -876,7 +889,7 @@ def test_handle_execute_response_can_handle_without_direct_results( ) self.assertEqual( results_message_response.status, - ttypes.TOperationState.FINISHED_STATE, + CommandState.SUCCEEDED, ) def test_handle_execute_response_can_handle_with_direct_results(self): @@ -900,7 +913,7 @@ def test_handle_execute_response_can_handle_with_direct_results(self): operationHandle=self.operation_handle, ) - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -917,7 +930,7 @@ def test_handle_execute_response_can_handle_with_direct_results(self): ttypes.TOperationState.FINISHED_STATE, ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_use_arrow_schema_if_available(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value arrow_schema_mock = MagicMock(name="Arrow schema mock") @@ -946,7 +959,7 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value hive_schema_mock = MagicMock(name="Hive schema mock") @@ -976,7 +989,7 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): @patch( "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() ) - 
@patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_reads_has_more_rows_in_direct_results( self, tcli_service_class, build_queue ): @@ -1020,7 +1033,7 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( @patch( "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_reads_has_more_rows_in_result_response( self, tcli_service_class, build_queue ): @@ -1064,7 +1077,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( thrift_backend._handle_execute_response(execute_resp, Mock()) _, has_more_rows_resp = thrift_backend.fetch_results( - op_handle=Mock(), + command_id=Mock(), max_rows=1, max_bytes=1, expected_row_start_offset=0, @@ -1075,7 +1088,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( self.assertEqual(has_more_rows, has_more_rows_resp) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_arrow_batches_row_count_are_respected(self, tcli_service_class): # make some semi-real arrow batches and check the number of rows is correct in the queue tcli_service_instance = tcli_service_class.return_value @@ -1108,7 +1121,7 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): .to_pybytes() ) - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1117,7 +1130,7 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): ssl_options=SSLOptions(), ) arrow_queue, has_more_results = thrift_backend.fetch_results( - op_handle=Mock(), + command_id=Mock(), 
max_rows=1, max_bytes=1, expected_row_start_offset=0, @@ -1128,14 +1141,14 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): self.assertEqual(arrow_queue.n_valid_rows, 15 * 10) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_execute_statement_calls_client_and_handle_execute_response( self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.ExecuteStatement.return_value = response - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1146,7 +1159,12 @@ def test_execute_statement_calls_client_and_handle_execute_response( thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() - thrift_backend.execute_command("foo", Mock(), 100, 200, Mock(), cursor_mock) + result = thrift_backend.execute_command( + "foo", Mock(), 100, 200, Mock(), cursor_mock + ) + # Verify the result is a ResultSet + self.assertIsInstance(result, ResultSet) + # Check call to client req = tcli_service_instance.ExecuteStatement.call_args[0][0] get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200) @@ -1157,14 +1175,14 @@ def test_execute_statement_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_catalogs_calls_client_and_handle_execute_response( self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetCatalogs.return_value = response - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1175,7 +1193,10 @@ def test_get_catalogs_calls_client_and_handle_execute_response( 
thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() - thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) + result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) + # Verify the result is a ResultSet + self.assertIsInstance(result, ResultSet) + # Check call to client req = tcli_service_instance.GetCatalogs.call_args[0][0] get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200) @@ -1185,14 +1206,14 @@ def test_get_catalogs_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_schemas_calls_client_and_handle_execute_response( self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetSchemas.return_value = response - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1203,7 +1224,7 @@ def test_get_schemas_calls_client_and_handle_execute_response( thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() - thrift_backend.get_schemas( + result = thrift_backend.get_schemas( Mock(), 100, 200, @@ -1211,6 +1232,9 @@ def test_get_schemas_calls_client_and_handle_execute_response( catalog_name="catalog_pattern", schema_name="schema_pattern", ) + # Verify the result is a ResultSet + self.assertIsInstance(result, ResultSet) + # Check call to client req = tcli_service_instance.GetSchemas.call_args[0][0] get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200) @@ -1222,14 +1246,14 @@ def test_get_schemas_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_tables_calls_client_and_handle_execute_response( 
self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetTables.return_value = response - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1240,7 +1264,7 @@ def test_get_tables_calls_client_and_handle_execute_response( thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() - thrift_backend.get_tables( + result = thrift_backend.get_tables( Mock(), 100, 200, @@ -1250,6 +1274,9 @@ def test_get_tables_calls_client_and_handle_execute_response( table_name="table_pattern", table_types=["type1", "type2"], ) + # Verify the result is a ResultSet + self.assertIsInstance(result, ResultSet) + # Check call to client req = tcli_service_instance.GetTables.call_args[0][0] get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200) @@ -1263,14 +1290,14 @@ def test_get_tables_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_columns_calls_client_and_handle_execute_response( self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetColumns.return_value = response - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1281,7 +1308,7 @@ def test_get_columns_calls_client_and_handle_execute_response( thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() - thrift_backend.get_columns( + result = thrift_backend.get_columns( Mock(), 100, 200, @@ -1291,6 +1318,9 @@ def test_get_columns_calls_client_and_handle_execute_response( table_name="table_pattern", column_name="column_pattern", ) + # Verify the result is a ResultSet + self.assertIsInstance(result, ResultSet) + # Check call to client req = 
tcli_service_instance.GetColumns.call_args[0][0] get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200) @@ -1304,12 +1334,12 @@ def test_get_columns_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_open_session_user_provided_session_id_optional(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1320,10 +1350,10 @@ def test_open_session_user_provided_session_id_optional(self, tcli_service_class thrift_backend.open_session({}, None, None) self.assertEqual(len(tcli_service_instance.OpenSession.call_args_list), 1) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_op_handle_respected_in_close_command(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1331,16 +1361,17 @@ def test_op_handle_respected_in_close_command(self, tcli_service_class): auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend.close_command(self.operation_handle) + command_id = CommandId.from_thrift_handle(self.operation_handle) + thrift_backend.close_command(command_id) self.assertEqual( tcli_service_instance.CloseOperation.call_args[0][0].operationHandle, self.operation_handle, ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_session_handle_respected_in_close_session(self, tcli_service_class): 
tcli_service_instance = tcli_service_class.return_value - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1348,13 +1379,14 @@ def test_session_handle_respected_in_close_session(self, tcli_service_class): auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend.close_session(self.session_handle) + session_id = SessionId.from_thrift_handle(self.session_handle) + thrift_backend.close_session(session_id) self.assertEqual( tcli_service_instance.CloseSession.call_args[0][0].sessionHandle, self.session_handle, ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_non_arrow_non_column_based_set_triggers_exception( self, tcli_service_class ): @@ -1392,7 +1424,7 @@ def test_non_arrow_non_column_based_set_triggers_exception( def test_create_arrow_table_raises_error_for_unsupported_type(self): t_row_set = ttypes.TRowSet() - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1403,12 +1435,16 @@ def test_create_arrow_table_raises_error_for_unsupported_type(self): with self.assertRaises(OperationalError): thrift_backend._create_arrow_table(t_row_set, Mock(), None, Mock()) - @patch("databricks.sql.thrift_backend.convert_arrow_based_set_to_arrow_table") - @patch("databricks.sql.thrift_backend.convert_column_based_set_to_arrow_table") + @patch( + "databricks.sql.backend.thrift_backend.convert_arrow_based_set_to_arrow_table" + ) + @patch( + "databricks.sql.backend.thrift_backend.convert_column_based_set_to_arrow_table" + ) def test_create_arrow_table_calls_correct_conversion_method( self, convert_col_mock, convert_arrow_mock ): - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1443,7 +1479,7 @@ def test_create_arrow_table_calls_correct_conversion_method( def 
test_convert_arrow_based_set_to_arrow_table( self, open_stream_mock, lz4_decompress_mock ): - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1597,17 +1633,18 @@ def test_convert_column_based_set_to_arrow_table_uses_types_from_col_set(self): self.assertEqual(arrow_table.column(2).to_pylist(), [1.15, 2.2, 3.3]) self.assertEqual(arrow_table.column(3).to_pylist(), [b"\x11", b"\x22", b"\x33"]) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_cancel_command_uses_active_op_handle(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value thrift_backend = self._make_fake_thrift_backend() - active_op_handle_mock = Mock() - thrift_backend.cancel_command(active_op_handle_mock) + # Create a proper CommandId from the existing operation_handle + command_id = CommandId.from_thrift_handle(self.operation_handle) + thrift_backend.cancel_command(command_id) self.assertEqual( tcli_service_instance.CancelOperation.call_args[0][0].operationHandle, - active_op_handle_mock, + self.operation_handle, ) def test_handle_execute_response_sets_active_op_handle(self): @@ -1615,19 +1652,27 @@ def test_handle_execute_response_sets_active_op_handle(self): thrift_backend._check_direct_results_for_error = Mock() thrift_backend._wait_until_command_done = Mock() thrift_backend._results_message_to_execute_response = Mock() + + # Create a mock response with a real operation handle mock_resp = Mock() + mock_resp.operationHandle = ( + self.operation_handle + ) # Use the real operation handle from the test class mock_cursor = Mock() thrift_backend._handle_execute_response(mock_resp, mock_cursor) - self.assertEqual(mock_resp.operationHandle, mock_cursor.active_op_handle) + self.assertEqual( + mock_resp.operationHandle, mock_cursor.active_command_id.to_thrift_handle() + ) 
@patch("databricks.sql.auth.thrift_http_client.THttpClient") @patch( "databricks.sql.thrift_api.TCLIService.TCLIService.Client.GetOperationStatus" ) @patch( - "databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory + "databricks.sql.backend.thrift_backend._retry_policy", + new_callable=retry_policy_factory, ) def test_make_request_will_retry_GetOperationStatus( self, mock_retry_policy, mock_GetOperationStatus, t_transport_class @@ -1654,7 +1699,7 @@ def test_make_request_will_retry_GetOperationStatus( EXPECTED_RETRIES = 2 - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1681,7 +1726,7 @@ def test_make_request_will_retry_GetOperationStatus( ) with self.assertLogs( - "databricks.sql.thrift_backend", level=logging.WARNING + "databricks.sql.backend.thrift_backend", level=logging.WARNING ) as cm: with self.assertRaises(RequestError): thrift_backend.make_request(client.GetOperationStatus, req) @@ -1702,7 +1747,8 @@ def test_make_request_will_retry_GetOperationStatus( "databricks.sql.thrift_api.TCLIService.TCLIService.Client.GetOperationStatus" ) @patch( - "databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory + "databricks.sql.backend.thrift_backend._retry_policy", + new_callable=retry_policy_factory, ) def test_make_request_will_retry_GetOperationStatus_for_http_error( self, mock_retry_policy, mock_gos @@ -1731,7 +1777,7 @@ def test_make_request_will_retry_GetOperationStatus_for_http_error( EXPECTED_RETRIES = 2 - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1763,7 +1809,7 @@ def test_make_request_wont_retry_if_error_code_not_429_or_503( mock_method.__name__ = "method name" mock_method.side_effect = Exception("This method fails") - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1779,7 +1825,8 @@ def test_make_request_wont_retry_if_error_code_not_429_or_503( 
@patch("databricks.sql.auth.thrift_http_client.THttpClient") @patch( - "databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory + "databricks.sql.backend.thrift_backend._retry_policy", + new_callable=retry_policy_factory, ) def test_make_request_will_retry_stop_after_attempts_count_if_retryable( self, mock_retry_policy, t_transport_class @@ -1791,7 +1838,7 @@ def test_make_request_will_retry_stop_after_attempts_count_if_retryable( mock_method.__name__ = "method name" mock_method.side_effect = Exception("This method fails") - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1820,7 +1867,7 @@ def test_make_request_will_read_error_message_headers_if_set( mock_method.__name__ = "method name" mock_method.side_effect = Exception("This method fails") - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1944,7 +1991,7 @@ def test_retry_args_passthrough(self, mock_http_client): "_retry_stop_after_attempts_count": 1, "_retry_stop_after_attempts_duration": 100, } - backend = ThriftBackend( + backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1959,7 +2006,12 @@ def test_retry_args_passthrough(self, mock_http_client): @patch("thrift.transport.THttpClient.THttpClient") def test_retry_args_bounding(self, mock_http_client): retry_delay_test_args_and_expected_values = {} - for k, (_, _, min, max) in databricks.sql.thrift_backend._retry_policy.items(): + for k, ( + _, + _, + min, + max, + ) in databricks.sql.backend.thrift_backend._retry_policy.items(): retry_delay_test_args_and_expected_values[k] = ( (min - 1, min), (max + 1, max), @@ -1970,7 +2022,7 @@ def test_retry_args_bounding(self, mock_http_client): k: v[i][0] for (k, v) in retry_delay_test_args_and_expected_values.items() } - backend = ThriftBackend( + backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1986,7 +2038,7 @@ def test_retry_args_bounding(self, mock_http_client): for 
arg, val in retry_delay_expected_vals.items(): self.assertEqual(getattr(backend, arg), val) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_configuration_passthrough(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp @@ -1998,7 +2050,7 @@ def test_configuration_passthrough(self, tcli_client_class): "42": "42", } - backend = ThriftBackend( + backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -2011,12 +2063,12 @@ def test_configuration_passthrough(self, tcli_client_class): open_session_req = tcli_client_class.return_value.OpenSession.call_args[0][0] self.assertEqual(open_session_req.configuration, expected_config) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_cant_set_timestamp_as_string_to_true(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp mock_config = {"spark.thriftserver.arrowBasedRowSet.timestampAsString": True} - backend = ThriftBackend( + backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -2036,13 +2088,14 @@ def _construct_open_session_with_namespace(self, can_use_multiple_cats, cat, sch serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4, canUseMultipleCatalogs=can_use_multiple_cats, initialNamespace=ttypes.TNamespace(catalogName=cat, schemaName=schem), + sessionHandle=self.session_handle, ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_initial_namespace_passthrough_to_open_session(self, tcli_client_class): tcli_service_instance = 
tcli_client_class.return_value - backend = ThriftBackend( + backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -2066,14 +2119,14 @@ def test_initial_namespace_passthrough_to_open_session(self, tcli_client_class): self.assertEqual(open_session_req.initialNamespace.catalogName, cat) self.assertEqual(open_session_req.initialNamespace.schemaName, schem) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_can_use_multiple_catalogs_is_set_in_open_session_req( self, tcli_client_class ): tcli_service_instance = tcli_client_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp - backend = ThriftBackend( + backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -2086,13 +2139,13 @@ def test_can_use_multiple_catalogs_is_set_in_open_session_req( open_session_req = tcli_client_class.return_value.OpenSession.call_args[0][0] self.assertTrue(open_session_req.canUseMultipleCatalogs) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_can_use_multiple_catalogs_is_false_fails_with_initial_catalog( self, tcli_client_class ): tcli_service_instance = tcli_client_class.return_value - backend = ThriftBackend( + backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -2126,7 +2179,7 @@ def test_can_use_multiple_catalogs_is_false_fails_with_initial_catalog( ) backend.open_session({}, cat, schem) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value @@ -2135,9 +2188,10 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): 
serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V3, canUseMultipleCatalogs=True, initialNamespace=ttypes.TNamespace(catalogName="cat", schemaName="schem"), + sessionHandle=self.session_handle, ) - backend = ThriftBackend( + backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -2154,8 +2208,10 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): str(cm.exception), ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - @patch("databricks.sql.thrift_backend.ThriftBackend._handle_execute_response") + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) + @patch( + "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response" + ) def test_execute_command_sets_complex_type_fields_correctly( self, mock_handle_execute_response, tcli_service_class ): @@ -2172,7 +2228,7 @@ def test_execute_command_sets_complex_type_fields_correctly( if decimals is not None: complex_arg_types["_use_arrow_native_decimals"] = decimals - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path",