From 138c2aebab99659d1c970fa70e4a431fec78aae2 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:24:22 +0000 Subject: [PATCH 01/33] [squash from exec-sea] bring over execution phase changes Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 124 ++- .../sql/backend/databricks_client.py | 30 - src/databricks/sql/backend/filters.py | 143 ++++ src/databricks/sql/backend/sea/backend.py | 360 ++++++++- .../sql/backend/sea/models/__init__.py | 30 + src/databricks/sql/backend/sea/models/base.py | 68 ++ .../sql/backend/sea/models/requests.py | 110 ++- .../sql/backend/sea/models/responses.py | 95 ++- src/databricks/sql/backend/thrift_backend.py | 99 ++- src/databricks/sql/backend/types.py | 64 +- src/databricks/sql/client.py | 1 - src/databricks/sql/result_set.py | 234 ++++-- src/databricks/sql/session.py | 2 +- src/databricks/sql/utils.py | 7 - tests/unit/test_client.py | 22 +- tests/unit/test_fetches.py | 13 +- tests/unit/test_fetches_bench.py | 3 +- tests/unit/test_result_set_filter.py | 246 ++++++ tests/unit/test_sea_backend.py | 755 +++++++++++++++--- tests/unit/test_sea_result_set.py | 275 +++++++ tests/unit/test_session.py | 5 + tests/unit/test_thrift_backend.py | 55 +- 22 files changed, 2375 insertions(+), 366 deletions(-) create mode 100644 src/databricks/sql/backend/filters.py create mode 100644 src/databricks/sql/backend/sea/models/base.py create mode 100644 tests/unit/test_result_set_filter.py create mode 100644 tests/unit/test_sea_result_set.py diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index abe6bd1ab..87b62efea 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -6,34 +6,122 @@ logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) + +def test_sea_query_exec(): + """ + Test executing a query using the SEA backend with result compression. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query with result compression enabled and disabled, + and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
+ ) + sys.exit(1) + + try: + # Test with compression enabled + logger.info("Creating connection with LZ4 compression enabled") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=True, # Enable cloud fetch to use compression + enable_query_result_lz4_compression=True, # Enable LZ4 compression + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + logger.info(f"backend type: {type(connection.session.backend)}") + + # Execute a simple query with compression enabled + cursor = connection.cursor(arraysize=0, buffer_size_bytes=0) + logger.info("Executing query with LZ4 compression: SELECT 1 as test_value") + cursor.execute("SELECT 1 as test_value") + logger.info("Query with compression executed successfully") + cursor.close() + connection.close() + logger.info("Successfully closed SEA session with compression enabled") + + # Test with compression disabled + logger.info("Creating connection with LZ4 compression disabled") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=False, # Enable cloud fetch + enable_query_result_lz4_compression=False, # Disable LZ4 compression + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query with compression disabled + cursor = connection.cursor(arraysize=0, buffer_size_bytes=0) + logger.info("Executing query without compression: SELECT 1 as test_value") + cursor.execute("SELECT 1 as test_value") + logger.info("Query without compression executed successfully") + cursor.close() + connection.close() + logger.info("Successfully closed SEA session with compression disabled") + + except Exception as e: + logger.error(f"Error during SEA query execution test: {str(e)}") + import traceback + + logger.error(traceback.format_exc()) + sys.exit(1) + + logger.info("SEA query execution test with compression completed successfully") + + def test_sea_session(): """ Test opening and closing a SEA session using the connector. - + This function connects to a Databricks SQL endpoint using the SEA backend, opens a session, and then closes it. - + Required environment variables: - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint - DATABRICKS_TOKEN: Personal access token for authentication """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") http_path = os.environ.get("DATABRICKS_HTTP_PATH") access_token = os.environ.get("DATABRICKS_TOKEN") catalog = os.environ.get("DATABRICKS_CATALOG") - + if not all([server_hostname, http_path, access_token]): logger.error("Missing required environment variables.") - logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
+ ) sys.exit(1) - + logger.info(f"Connecting to {server_hostname}") logger.info(f"HTTP Path: {http_path}") if catalog: logger.info(f"Using catalog: {catalog}") - + try: logger.info("Creating connection with SEA backend...") connection = Connection( @@ -42,25 +130,33 @@ def test_sea_session(): access_token=access_token, catalog=catalog, schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client" # add custom user agent + use_sea=True, + user_agent_entry="SEA-Test-Client", # add custom user agent + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" ) - - logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}") logger.info(f"backend type: {type(connection.session.backend)}") - + # Close the connection logger.info("Closing the SEA session...") connection.close() logger.info("Successfully closed SEA session") - + except Exception as e: logger.error(f"Error testing SEA session: {str(e)}") import traceback + logger.error(traceback.format_exc()) sys.exit(1) - + logger.info("SEA session test completed successfully") + if __name__ == "__main__": + # Test session management test_sea_session() + + # Test query execution with compression + test_sea_query_exec() diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index 20b059fa7..bbca4c502 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -16,8 +16,6 @@ from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.backend.types import SessionId, CommandId, CommandState -from databricks.sql.utils import ExecuteResponse -from databricks.sql.types import SSLOptions # Forward reference for type hints from typing import TYPE_CHECKING @@ -88,34 +86,6 @@ def execute_command( async_op: bool, enforce_embedded_schema_correctness: bool, ) -> Union["ResultSet", None]: - """ - Executes a SQL command or query within the specified session. - - This method sends a SQL command to the server for execution and handles - the response. It can operate in both synchronous and asynchronous modes. - - Args: - operation: The SQL command or query to execute - session_id: The session identifier in which to execute the command - max_rows: Maximum number of rows to fetch in a single fetch batch - max_bytes: Maximum number of bytes to fetch in a single fetch batch - lz4_compression: Whether to use LZ4 compression for result data - cursor: The cursor object that will handle the results - use_cloud_fetch: Whether to use cloud fetch for retrieving large result sets - parameters: List of parameters to bind to the query - async_op: Whether to execute the command asynchronously - enforce_embedded_schema_correctness: Whether to enforce schema correctness - - Returns: - If async_op is False, returns a ResultSet object containing the - query results and metadata. If async_op is True, returns None and the - results must be fetched later using get_execution_result(). - - Raises: - ValueError: If the session ID is invalid - OperationalError: If there's an error executing the command - ServerOperationError: If the server encounters an error during execution - """ pass @abstractmethod diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py new file mode 100644 index 000000000..7f48b6179 --- /dev/null +++ b/src/databricks/sql/backend/filters.py @@ -0,0 +1,143 @@ +""" +Client-side filtering utilities for Databricks SQL connector. 
+ +This module provides filtering capabilities for result sets returned by different backends. +""" + +import logging +from typing import ( + List, + Optional, + Any, + Callable, + TYPE_CHECKING, +) + +if TYPE_CHECKING: + from databricks.sql.result_set import ResultSet + +from databricks.sql.result_set import SeaResultSet + +logger = logging.getLogger(__name__) + + +class ResultSetFilter: + """ + A general-purpose filter for result sets that can be applied to any backend. + + This class provides methods to filter result sets based on various criteria, + similar to the client-side filtering in the JDBC connector. + """ + + @staticmethod + def _filter_sea_result_set( + result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool] + ) -> "SeaResultSet": + """ + Filter a SEA result set using the provided filter function. + + Args: + result_set: The SEA result set to filter + filter_func: Function that takes a row and returns True if the row should be included + + Returns: + A filtered SEA result set + """ + # Create a filtered version of the result set + filtered_response = result_set._response.copy() + + # If there's a result with rows, filter them + if ( + "result" in filtered_response + and "data_array" in filtered_response["result"] + ): + rows = filtered_response["result"]["data_array"] + filtered_rows = [row for row in rows if filter_func(row)] + filtered_response["result"]["data_array"] = filtered_rows + + # Update row count if present + if "row_count" in filtered_response["result"]: + filtered_response["result"]["row_count"] = len(filtered_rows) + + # Create a new result set with the filtered data + return SeaResultSet( + connection=result_set.connection, + sea_response=filtered_response, + sea_client=result_set.backend, + buffer_size_bytes=result_set.buffer_size_bytes, + arraysize=result_set.arraysize, + ) + + @staticmethod + def filter_by_column_values( + result_set: "ResultSet", + column_index: int, + allowed_values: List[str], + case_sensitive: bool = False, + ) -> "ResultSet": + """ + Filter a result set by values in a specific column. + + Args: + result_set: The result set to filter + column_index: The index of the column to filter on + allowed_values: List of allowed values for the column + case_sensitive: Whether to perform case-sensitive comparison + + Returns: + A filtered result set + """ + # Convert to uppercase for case-insensitive comparison if needed + if not case_sensitive: + allowed_values = [v.upper() for v in allowed_values] + + # Determine the type of result set and apply appropriate filtering + if isinstance(result_set, SeaResultSet): + return ResultSetFilter._filter_sea_result_set( + result_set, + lambda row: ( + len(row) > column_index + and isinstance(row[column_index], str) + and ( + row[column_index].upper() + if not case_sensitive + else row[column_index] + ) + in allowed_values + ), + ) + + # For other result set types, return the original (should be handled by specific implementations) + logger.warning( + f"Filtering not implemented for result set type: {type(result_set).__name__}" + ) + return result_set + + @staticmethod + def filter_tables_by_type( + result_set: "ResultSet", table_types: Optional[List[str]] = None + ) -> "ResultSet": + """ + Filter a result set of tables by the specified table types. + + This is a client-side filter that processes the result set after it has been + retrieved from the server. It filters out tables whose type does not match + any of the types in the table_types list. 
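+
+ A minimal usage sketch (illustrative; assumes `result` is a SeaResultSet
+ built from a SHOW TABLES query):
+
+ filtered = ResultSetFilter.filter_tables_by_type(result, ["TABLE", "VIEW"])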
+ + Args: + result_set: The original result set containing tables + table_types: List of table types to include (e.g., ["TABLE", "VIEW"]) + + Returns: + A filtered result set containing only tables of the specified types + """ + # Default table types if none specified + DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"] + valid_types = ( + table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES + ) + + # Table type is typically in the 6th column (index 5) + return ResultSetFilter.filter_by_column_values( + result_set, 5, valid_types, case_sensitive=False + ) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 97d25a058..c7a4ed1b1 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,23 +1,34 @@ import logging import re -from typing import Dict, Tuple, List, Optional, TYPE_CHECKING, Set +import uuid +import time +from typing import Dict, Set, Tuple, List, Optional, Any, Union, TYPE_CHECKING + +from databricks.sql.backend.sea.utils.constants import ( + ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, +) if TYPE_CHECKING: from databricks.sql.client import Cursor + from databricks.sql.result_set import ResultSet from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType -from databricks.sql.exc import ServerOperationError +from databricks.sql.exc import Error, NotSupportedError, ServerOperationError from databricks.sql.backend.sea.utils.http_client import SeaHttpClient -from databricks.sql.backend.sea.utils.constants import ( - ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, -) from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions from databricks.sql.backend.sea.models import ( + ExecuteStatementRequest, + GetStatementRequest, + CancelStatementRequest, + CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, + StatementParameter, + ExecuteStatementResponse, + GetStatementResponse, CreateSessionResponse, ) @@ -55,6 +66,9 @@ def _filter_session_configuration( class SeaDatabricksClient(DatabricksClient): """ Statement Execution API (SEA) implementation of the DatabricksClient interface. + + This implementation provides session management functionality for SEA, + while other operations raise NotImplementedError. """ # SEA API paths @@ -274,41 +288,222 @@ def execute_command( lz4_compression: bool, cursor: "Cursor", use_cloud_fetch: bool, - parameters: List[ttypes.TSparkParameter], + parameters: List, async_op: bool, enforce_embedded_schema_correctness: bool, - ): - """Not implemented yet.""" - raise NotImplementedError( - "execute_command is not yet implemented for SEA backend" + ) -> Union["ResultSet", None]: + """ + Execute a SQL command using the SEA backend. 
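+
+ Roughly: the statement is submitted with a POST to the statement path;
+ unless async_op is set, the backend then polls get_query_state until the
+ statement reaches a terminal state and returns the completed result via
+ get_execution_result.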
+ + Args: + operation: SQL command to execute + session_id: Session identifier + max_rows: Maximum number of rows to fetch + max_bytes: Maximum number of bytes to fetch + lz4_compression: Whether to use LZ4 compression + cursor: Cursor executing the command + use_cloud_fetch: Whether to use cloud fetch + parameters: SQL parameters + async_op: Whether to execute asynchronously + enforce_embedded_schema_correctness: Whether to enforce schema correctness + + Returns: + ResultSet: A SeaResultSet instance for the executed command + """ + if session_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA session ID") + + sea_session_id = session_id.to_sea_session_id() + + # Convert parameters to StatementParameter objects + sea_parameters = [] + if parameters: + for param in parameters: + sea_parameters.append( + StatementParameter( + name=param.name, + value=param.value, + type=param.type if hasattr(param, "type") else None, + ) + ) + + format = "ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY" + disposition = "EXTERNAL_LINKS" if use_cloud_fetch else "INLINE" + result_compression = "LZ4_FRAME" if lz4_compression else "NONE" + + request = ExecuteStatementRequest( + warehouse_id=self.warehouse_id, + session_id=sea_session_id, + statement=operation, + disposition=disposition, + format=format, + wait_timeout="0s" if async_op else "10s", + on_wait_timeout="CONTINUE", + row_limit=max_rows if max_rows > 0 else None, + byte_limit=max_bytes if max_bytes > 0 else None, + parameters=sea_parameters if sea_parameters else None, + result_compression=result_compression, ) + response_data = self.http_client._make_request( + method="POST", path=self.STATEMENT_PATH, data=request.to_dict() + ) + response = ExecuteStatementResponse.from_dict(response_data) + statement_id = response.statement_id + if not statement_id: + raise ServerOperationError( + "Failed to execute command: No statement ID returned", + { + "operation-id": None, + "diagnostic-info": None, + }, + ) + + command_id = CommandId.from_sea_statement_id(statement_id) + + # Store the command ID in the cursor + cursor.active_command_id = command_id + + # If async operation, return None and let the client poll for results + if async_op: + return None + + # For synchronous operation, wait for the statement to complete + # Poll until the statement is done + status = response.status + state = status.state + + # Keep polling until we reach a terminal state + while state in [CommandState.PENDING, CommandState.RUNNING]: + time.sleep(0.5) # add a small delay to avoid excessive API calls + state = self.get_query_state(command_id) + + if state != CommandState.SUCCEEDED: + raise ServerOperationError( + f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}", + { + "operation-id": command_id.to_sea_statement_id(), + "diagnostic-info": None, + }, + ) + + return self.get_execution_result(command_id, cursor) + def cancel_command(self, command_id: CommandId) -> None: - """Not implemented yet.""" - raise NotImplementedError( - "cancel_command is not yet implemented for SEA backend" + """ + Cancel a running command. 
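+
+ A minimal sketch (illustrative; assumes the cursor has a command in
+ flight, so cursor.active_command_id is set):
+
+ backend.cancel_command(cursor.active_command_id)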
+ + Args: + command_id: Command identifier to cancel + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = CancelStatementRequest(statement_id=sea_statement_id) + self.http_client._make_request( + method="POST", + path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) def close_command(self, command_id: CommandId) -> None: - """Not implemented yet.""" - raise NotImplementedError( - "close_command is not yet implemented for SEA backend" + """ + Close a command and release resources. + + Args: + command_id: Command identifier to close + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = CloseStatementRequest(statement_id=sea_statement_id) + self.http_client._make_request( + method="DELETE", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) def get_query_state(self, command_id: CommandId) -> CommandState: - """Not implemented yet.""" - raise NotImplementedError( - "get_query_state is not yet implemented for SEA backend" + """ + Get the state of a running query. + + Args: + command_id: Command identifier + + Returns: + CommandState: The current state of the command + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = GetStatementRequest(statement_id=sea_statement_id) + response_data = self.http_client._make_request( + method="GET", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) + # Parse the response + response = GetStatementResponse.from_dict(response_data) + return response.status.state + def get_execution_result( self, command_id: CommandId, cursor: "Cursor", - ): - """Not implemented yet.""" - raise NotImplementedError( - "get_execution_result is not yet implemented for SEA backend" + ) -> "ResultSet": + """ + Get the result of a command execution. 
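+
+ The statement result is fetched with a GET on the statement path and
+ wrapped in a SeaResultSet bound to the cursor's connection.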
+ + Args: + command_id: Command identifier + cursor: Cursor executing the command + + Returns: + ResultSet: A SeaResultSet instance with the execution results + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + # Create the request model + request = GetStatementRequest(statement_id=sea_statement_id) + + # Get the statement result + response_data = self.http_client._make_request( + method="GET", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), + ) + + # Create and return a SeaResultSet + from databricks.sql.result_set import SeaResultSet + + return SeaResultSet( + connection=cursor.connection, + sea_response=response_data, + sea_client=self, + buffer_size_bytes=cursor.buffer_size_bytes, + arraysize=cursor.arraysize, ) # == Metadata Operations == @@ -319,9 +514,22 @@ def get_catalogs( max_rows: int, max_bytes: int, cursor: "Cursor", - ): - """Not implemented yet.""" - raise NotImplementedError("get_catalogs is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get available catalogs by executing 'SHOW CATALOGS'.""" + result = self.execute_command( + operation="SHOW CATALOGS", + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result def get_schemas( self, @@ -331,9 +539,30 @@ def get_schemas( cursor: "Cursor", catalog_name: Optional[str] = None, schema_name: Optional[str] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_schemas is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_schemas") + + operation = f"SHOW SCHEMAS IN `{catalog_name}`" + + if schema_name: + operation += f" LIKE '{schema_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result def get_tables( self, @@ -345,9 +574,43 @@ def get_tables( schema_name: Optional[str] = None, table_name: Optional[str] = None, table_types: Optional[List[str]] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_tables is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_tables") + + operation = "SHOW TABLES IN " + ( + "ALL CATALOGS" + if catalog_name in [None, "*", "%"] + else f"CATALOG `{catalog_name}`" + ) + + if schema_name: + operation += f" SCHEMA LIKE '{schema_name}'" + + if table_name: + operation += f" LIKE '{table_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + 
enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + + # Apply client-side filtering by table_types if specified + from databricks.sql.backend.filters import ResultSetFilter + + result = ResultSetFilter.filter_tables_by_type(result, table_types) + + return result def get_columns( self, @@ -359,6 +622,33 @@ def get_columns( schema_name: Optional[str] = None, table_name: Optional[str] = None, column_name: Optional[str] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_columns is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_columns") + + operation = f"SHOW COLUMNS IN CATALOG `{catalog_name}`" + + if schema_name: + operation += f" SCHEMA LIKE '{schema_name}'" + + if table_name: + operation += f" TABLE LIKE '{table_name}'" + + if column_name: + operation += f" LIKE '{column_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result diff --git a/src/databricks/sql/backend/sea/models/__init__.py b/src/databricks/sql/backend/sea/models/__init__.py index c9310d367..b7c8bd399 100644 --- a/src/databricks/sql/backend/sea/models/__init__.py +++ b/src/databricks/sql/backend/sea/models/__init__.py @@ -4,19 +4,49 @@ This package contains data models for SEA API requests and responses. """ +from databricks.sql.backend.sea.models.base import ( + ServiceError, + StatementStatus, + ExternalLink, + ResultData, + ColumnInfo, + ResultManifest, +) + from databricks.sql.backend.sea.models.requests import ( + StatementParameter, + ExecuteStatementRequest, + GetStatementRequest, + CancelStatementRequest, + CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, ) from databricks.sql.backend.sea.models.responses import ( + ExecuteStatementResponse, + GetStatementResponse, CreateSessionResponse, ) __all__ = [ + # Base models + "ServiceError", + "StatementStatus", + "ExternalLink", + "ResultData", + "ColumnInfo", + "ResultManifest", # Request models + "StatementParameter", + "ExecuteStatementRequest", + "GetStatementRequest", + "CancelStatementRequest", + "CloseStatementRequest", "CreateSessionRequest", "DeleteSessionRequest", # Response models + "ExecuteStatementResponse", + "GetStatementResponse", "CreateSessionResponse", ] diff --git a/src/databricks/sql/backend/sea/models/base.py b/src/databricks/sql/backend/sea/models/base.py new file mode 100644 index 000000000..671f7be13 --- /dev/null +++ b/src/databricks/sql/backend/sea/models/base.py @@ -0,0 +1,68 @@ +""" +Base models for the SEA (Statement Execution API) backend. + +These models define the common structures used in SEA API requests and responses. 
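+
+For example, StatementStatus pairs a CommandState with an optional
+ServiceError, and ResultManifest carries the result schema along with
+row and byte counts.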
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + +from databricks.sql.backend.types import CommandState + + +@dataclass +class ServiceError: + """Error information returned by the SEA API.""" + + message: str + error_code: Optional[str] = None + + +@dataclass +class StatementStatus: + """Status information for a statement execution.""" + + state: CommandState + error: Optional[ServiceError] = None + sql_state: Optional[str] = None + + +@dataclass +class ExternalLink: + """External link information for result data.""" + + external_link: str + expiration: str + chunk_index: int + + +@dataclass +class ResultData: + """Result data from a statement execution.""" + + data: Optional[List[List[Any]]] = None + external_links: Optional[List[ExternalLink]] = None + + +@dataclass +class ColumnInfo: + """Information about a column in the result set.""" + + name: str + type_name: str + type_text: str + nullable: bool = True + precision: Optional[int] = None + scale: Optional[int] = None + ordinal_position: Optional[int] = None + + +@dataclass +class ResultManifest: + """Manifest information for a result set.""" + + schema: List[ColumnInfo] + total_row_count: int + total_byte_count: int + truncated: bool = False + chunk_count: Optional[int] = None diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index 7966cb502..1c519d931 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -1,5 +1,111 @@ -from typing import Dict, Any, Optional -from dataclasses import dataclass +""" +Request models for the SEA (Statement Execution API) backend. + +These models define the structures used in SEA API requests. 
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + + +@dataclass +class StatementParameter: + """Parameter for a SQL statement.""" + + name: str + value: Optional[str] = None + type: Optional[str] = None + + +@dataclass +class ExecuteStatementRequest: + """Request to execute a SQL statement.""" + + warehouse_id: str + statement: str + session_id: str + disposition: str = "EXTERNAL_LINKS" + format: str = "JSON_ARRAY" + wait_timeout: str = "10s" + on_wait_timeout: str = "CONTINUE" + row_limit: Optional[int] = None + byte_limit: Optional[int] = None + parameters: Optional[List[StatementParameter]] = None + catalog: Optional[str] = None + schema: Optional[str] = None + result_compression: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + result: Dict[str, Any] = { + "warehouse_id": self.warehouse_id, + "session_id": self.session_id, + "statement": self.statement, + "disposition": self.disposition, + "format": self.format, + "wait_timeout": self.wait_timeout, + "on_wait_timeout": self.on_wait_timeout, + } + + if self.row_limit is not None and self.row_limit > 0: + result["row_limit"] = self.row_limit + + if self.byte_limit is not None and self.byte_limit > 0: + result["byte_limit"] = self.byte_limit + + if self.catalog: + result["catalog"] = self.catalog + + if self.schema: + result["schema"] = self.schema + + if self.result_compression: + result["result_compression"] = self.result_compression + + if self.parameters: + result["parameters"] = [ + { + "name": param.name, + **({"value": param.value} if param.value is not None else {}), + **({"type": param.type} if param.type is not None else {}), + } + for param in self.parameters + ] + + return result + + +@dataclass +class GetStatementRequest: + """Request to get information about a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} + + +@dataclass +class CancelStatementRequest: + """Request to cancel a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} + + +@dataclass +class CloseStatementRequest: + """Request to close a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} @dataclass diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 1bb54590f..d70459b9f 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -1,5 +1,96 @@ -from typing import Dict, Any -from dataclasses import dataclass +""" +Response models for the SEA (Statement Execution API) backend. + +These models define the structures used in SEA API responses. 
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + +from databricks.sql.backend.types import CommandState +from databricks.sql.backend.sea.models.base import ( + StatementStatus, + ResultManifest, + ResultData, + ServiceError, +) + + +@dataclass +class ExecuteStatementResponse: + """Response from executing a SQL statement.""" + + statement_id: str + status: StatementStatus + manifest: Optional[ResultManifest] = None + result: Optional[ResultData] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": + """Create an ExecuteStatementResponse from a dictionary.""" + status_data = data.get("status", {}) + error = None + if "error" in status_data: + error_data = status_data["error"] + error = ServiceError( + message=error_data.get("message", ""), + error_code=error_data.get("error_code"), + ) + + state = CommandState.from_sea_state(status_data.get("state", "")) + if state is None: + raise ValueError(f"Invalid state: {status_data.get('state', '')}") + status = StatementStatus( + state=state, + error=error, + sql_state=status_data.get("sql_state"), + ) + + return cls( + statement_id=data.get("statement_id", ""), + status=status, + manifest=data.get("manifest"), # We'll parse this more fully if needed + result=data.get("result"), # We'll parse this more fully if needed + ) + + +@dataclass +class GetStatementResponse: + """Response from getting information about a statement.""" + + statement_id: str + status: StatementStatus + manifest: Optional[ResultManifest] = None + result: Optional[ResultData] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": + """Create a GetStatementResponse from a dictionary.""" + status_data = data.get("status", {}) + error = None + if "error" in status_data: + error_data = status_data["error"] + error = ServiceError( + message=error_data.get("message", ""), + error_code=error_data.get("error_code"), + ) + + state = CommandState.from_sea_state(status_data.get("state", "")) + if state is None: + raise ValueError(f"Invalid state: {status_data.get('state', '')}") + + status = StatementStatus( + state=state, + error=error, + sql_state=status_data.get("sql_state"), + ) + + return cls( + statement_id=data.get("statement_id", ""), + status=status, + manifest=data.get("manifest"), # We'll parse this more fully if needed + result=data.get("result"), # We'll parse this more fully if needed + ) @dataclass diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index de388f1d4..e03d6f235 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -5,11 +5,10 @@ import time import uuid import threading -from typing import List, Optional, Union, Any, TYPE_CHECKING +from typing import List, Union, Any, TYPE_CHECKING if TYPE_CHECKING: from databricks.sql.client import Cursor - from databricks.sql.result_set import ResultSet, ThriftResultSet from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState from databricks.sql.backend.types import ( @@ -17,8 +16,9 @@ SessionId, CommandId, BackendType, + guid_to_hex_id, + ExecuteResponse, ) -from databricks.sql.backend.utils import guid_to_hex_id try: import pyarrow @@ -42,7 +42,7 @@ ) from databricks.sql.utils import ( - ExecuteResponse, + ResultSetQueueFactory, _bound, RequestErrorInfo, NoRetryReason, @@ -53,6 +53,7 @@ ) from databricks.sql.types import SSLOptions from 
databricks.sql.backend.databricks_client import DatabricksClient +from databricks.sql.result_set import ResultSet, ThriftResultSet logger = logging.getLogger(__name__) @@ -351,7 +352,6 @@ def make_request(self, method, request, retryable=True): Will stop retry attempts if total elapsed time + next retry delay would exceed _retry_stop_after_attempts_duration. """ - # basic strategy: build range iterator rep'ing number of available # retries. bounds can be computed from there. iterate over it with # retries until success or final failure achieved. @@ -797,23 +797,27 @@ def _results_message_to_execute_response(self, resp, operation_state): command_id = CommandId.from_thrift_handle(resp.operationHandle) - return ExecuteResponse( - arrow_queue=arrow_queue_opt, - status=CommandState.from_thrift_state(operation_state), - has_been_closed_server_side=has_been_closed_server_side, - has_more_rows=has_more_rows, - lz4_compressed=lz4_compressed, - is_staging_operation=is_staging_operation, - command_id=command_id, - description=description, - arrow_schema_bytes=schema_bytes, + status = CommandState.from_thrift_state(operation_state) + if status is None: + raise ValueError(f"Invalid operation state: {operation_state}") + + return ( + ExecuteResponse( + command_id=command_id, + status=status, + description=description, + has_more_rows=has_more_rows, + results_queue=arrow_queue_opt, + has_been_closed_server_side=has_been_closed_server_side, + lz4_compressed=lz4_compressed, + is_staging_operation=is_staging_operation, + ), + schema_bytes, ) def get_execution_result( self, command_id: CommandId, cursor: "Cursor" ) -> "ResultSet": - from databricks.sql.result_set import ThriftResultSet - thrift_handle = command_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift command ID") @@ -863,15 +867,14 @@ def get_execution_result( ) execute_response = ExecuteResponse( - arrow_queue=queue, - status=CommandState.from_thrift_state(resp.status), - has_been_closed_server_side=False, + command_id=command_id, + status=resp.status, + description=description, has_more_rows=has_more_rows, + results_queue=queue, + has_been_closed_server_side=False, lz4_compressed=lz4_compressed, is_staging_operation=is_staging_operation, - command_id=command_id, - description=description, - arrow_schema_bytes=schema_bytes, ) return ThriftResultSet( @@ -881,6 +884,7 @@ def get_execution_result( buffer_size_bytes=cursor.buffer_size_bytes, arraysize=cursor.arraysize, use_cloud_fetch=cursor.connection.use_cloud_fetch, + arrow_schema_bytes=schema_bytes, ) def _wait_until_command_done(self, op_handle, initial_operation_status_resp): @@ -909,10 +913,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState: poll_resp = self._poll_for_status(thrift_handle) operation_state = poll_resp.operationState self._check_command_not_in_error_or_closed_state(thrift_handle, poll_resp) - state = CommandState.from_thrift_state(operation_state) - if state is None: - raise ValueError(f"Unknown command state: {operation_state}") - return state + return CommandState.from_thrift_state(operation_state) @staticmethod def _check_direct_results_for_error(t_spark_direct_results): @@ -947,8 +948,6 @@ def execute_command( async_op=False, enforce_embedded_schema_correctness=False, ) -> Union["ResultSet", None]: - from databricks.sql.result_set import ThriftResultSet - thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -995,7 +994,9 @@ def execute_command( 
self._handle_execute_response_async(resp, cursor) return None else: - execute_response = self._handle_execute_response(resp, cursor) + execute_response, arrow_schema_bytes = self._handle_execute_response( + resp, cursor + ) return ThriftResultSet( connection=cursor.connection, @@ -1004,6 +1005,7 @@ def execute_command( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=use_cloud_fetch, + arrow_schema_bytes=arrow_schema_bytes, ) def get_catalogs( @@ -1013,8 +1015,6 @@ def get_catalogs( max_bytes: int, cursor: "Cursor", ) -> "ResultSet": - from databricks.sql.result_set import ThriftResultSet - thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1027,7 +1027,9 @@ def get_catalogs( ) resp = self.make_request(self._client.GetCatalogs, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, arrow_schema_bytes = self._handle_execute_response( + resp, cursor + ) return ThriftResultSet( connection=cursor.connection, @@ -1036,6 +1038,7 @@ def get_catalogs( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + arrow_schema_bytes=arrow_schema_bytes, ) def get_schemas( @@ -1047,8 +1050,6 @@ def get_schemas( catalog_name=None, schema_name=None, ) -> "ResultSet": - from databricks.sql.result_set import ThriftResultSet - thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1063,7 +1064,9 @@ def get_schemas( ) resp = self.make_request(self._client.GetSchemas, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, arrow_schema_bytes = self._handle_execute_response( + resp, cursor + ) return ThriftResultSet( connection=cursor.connection, @@ -1072,6 +1075,7 @@ def get_schemas( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + arrow_schema_bytes=arrow_schema_bytes, ) def get_tables( @@ -1085,8 +1089,6 @@ def get_tables( table_name=None, table_types=None, ) -> "ResultSet": - from databricks.sql.result_set import ThriftResultSet - thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1103,7 +1105,9 @@ def get_tables( ) resp = self.make_request(self._client.GetTables, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, arrow_schema_bytes = self._handle_execute_response( + resp, cursor + ) return ThriftResultSet( connection=cursor.connection, @@ -1112,6 +1116,7 @@ def get_tables( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + arrow_schema_bytes=arrow_schema_bytes, ) def get_columns( @@ -1125,8 +1130,6 @@ def get_columns( table_name=None, column_name=None, ) -> "ResultSet": - from databricks.sql.result_set import ThriftResultSet - thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1143,7 +1146,9 @@ def get_columns( ) resp = self.make_request(self._client.GetColumns, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, arrow_schema_bytes = self._handle_execute_response( + resp, cursor + ) return ThriftResultSet( connection=cursor.connection, @@ -1152,6 +1157,7 @@ def get_columns( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + arrow_schema_bytes=arrow_schema_bytes, ) def 
_handle_execute_response(self, resp, cursor): @@ -1165,7 +1171,12 @@ def _handle_execute_response(self, resp, cursor): resp.directResults and resp.directResults.operationStatus, ) - return self._results_message_to_execute_response(resp, final_operation_state) + ( + execute_response, + arrow_schema_bytes, + ) = self._results_message_to_execute_response(resp, final_operation_state) + execute_response.command_id = command_id + return execute_response, arrow_schema_bytes def _handle_execute_response_async(self, resp, cursor): command_id = CommandId.from_thrift_handle(resp.operationHandle) @@ -1226,7 +1237,7 @@ def cancel_command(self, command_id: CommandId) -> None: if not thrift_handle: raise ValueError("Not a valid Thrift command ID") - logger.debug("Cancelling command {}".format(guid_to_hex_id(command_id.guid))) + logger.debug("Cancelling command {}".format(command_id.guid)) req = ttypes.TCancelOperationReq(thrift_handle) self.make_request(self._client.CancelOperation, req) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 9cd21b5e6..5bf02e0ea 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -1,5 +1,6 @@ +from dataclasses import dataclass from enum import Enum -from typing import Dict, Optional, Any +from typing import Dict, List, Optional, Any, Tuple import logging from databricks.sql.backend.utils import guid_to_hex_id @@ -80,6 +81,28 @@ def from_thrift_state( else: return None + @classmethod + def from_sea_state(cls, state: str) -> Optional["CommandState"]: + """ + Map SEA state string to CommandState enum. + + Args: + state: SEA state string + + Returns: + CommandState: The corresponding CommandState enum value + """ + state_mapping = { + "PENDING": cls.PENDING, + "RUNNING": cls.RUNNING, + "SUCCEEDED": cls.SUCCEEDED, + "FAILED": cls.FAILED, + "CLOSED": cls.CLOSED, + "CANCELED": cls.CANCELLED, + } + + return state_mapping.get(state, None) + class BackendType(Enum): """ @@ -285,28 +308,6 @@ def __init__( self.has_result_set = has_result_set self.modified_row_count = modified_row_count - def __str__(self) -> str: - """ - Return a string representation of the CommandId. - - For SEA backend, returns the guid. - For Thrift backend, returns a format like "guid|secret". 
- - Returns: - A string representation of the command ID - """ - - if self.backend_type == BackendType.SEA: - return str(self.guid) - elif self.backend_type == BackendType.THRIFT: - secret_hex = ( - guid_to_hex_id(self.secret) - if isinstance(self.secret, bytes) - else str(self.secret) - ) - return f"{self.to_hex_guid()}|{secret_hex}" - return str(self.guid) - @classmethod def from_thrift_handle(cls, operation_handle): """ @@ -318,7 +319,6 @@ def from_thrift_handle(cls, operation_handle): Returns: A CommandId instance """ - if operation_handle is None: return None @@ -394,3 +394,19 @@ def to_hex_guid(self) -> str: return guid_to_hex_id(self.guid) else: return str(self.guid) + + +@dataclass +class ExecuteResponse: + """Response from executing a SQL command.""" + + command_id: CommandId + status: CommandState + description: Optional[ + List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]] + ] = None + has_more_rows: bool = False + results_queue: Optional[Any] = None + has_been_closed_server_side: bool = False + lz4_compressed: bool = True + is_staging_operation: bool = False diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 9f7c060a7..e145e4e58 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -24,7 +24,6 @@ from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.utils import ( - ExecuteResponse, ParamEscaper, inject_parameters, transform_paramstyle, diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index a0d8d3579..fc8595839 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -1,26 +1,23 @@ from abc import ABC, abstractmethod -from typing import List, Optional, Any, Union, TYPE_CHECKING +from typing import List, Optional, Any, Union, Tuple, TYPE_CHECKING import logging import time import pandas -from databricks.sql.backend.types import CommandId, CommandState - try: import pyarrow except ImportError: pyarrow = None if TYPE_CHECKING: - from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.client import Connection - from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import Row from databricks.sql.exc import Error, RequestError, CursorAlreadyClosedError -from databricks.sql.utils import ExecuteResponse, ColumnTable, ColumnQueue +from databricks.sql.utils import ColumnTable, ColumnQueue +from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse logger = logging.getLogger(__name__) @@ -34,32 +31,31 @@ class ResultSet(ABC): def __init__( self, - connection: "Connection", - backend: "DatabricksClient", - command_id: CommandId, - op_state: Optional[CommandState], - has_been_closed_server_side: bool, + connection, + backend, arraysize: int, buffer_size_bytes: int, + command_id=None, + status=None, + has_been_closed_server_side: bool = False, + has_more_rows: bool = False, + results_queue=None, + description=None, + is_staging_operation: bool = False, ): - """ - A ResultSet manages the results of a single command. 
- - :param connection: The parent connection that was used to execute this command - :param backend: The specialised backend client to be invoked in the fetch phase - :param execute_response: A `ExecuteResponse` class returned by a command execution - :param result_buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch - amount :param arraysize: The max number of rows to fetch at a time (PEP-249) - """ - self.command_id = command_id - self.op_state = op_state - self.has_been_closed_server_side = has_been_closed_server_side + """Initialize the base ResultSet with common properties.""" self.connection = connection - self.backend = backend + self.backend = backend # Store the backend client directly self.arraysize = arraysize self.buffer_size_bytes = buffer_size_bytes self._next_row_index = 0 - self.description = None + self.description = description + self.command_id = command_id + self.status = status + self.has_been_closed_server_side = has_been_closed_server_side + self._has_more_rows = has_more_rows + self.results = results_queue + self._is_staging_operation = is_staging_operation def __iter__(self): while True: @@ -74,10 +70,9 @@ def rownumber(self): return self._next_row_index @property - @abstractmethod def is_staging_operation(self) -> bool: """Whether this result set represents a staging operation.""" - pass + return self._is_staging_operation # Define abstract methods that concrete implementations must implement @abstractmethod @@ -101,12 +96,12 @@ def fetchall(self) -> List[Row]: pass @abstractmethod - def fetchmany_arrow(self, size: int) -> "pyarrow.Table": + def fetchmany_arrow(self, size: int) -> Any: """Fetch the next set of rows as an Arrow table.""" pass @abstractmethod - def fetchall_arrow(self) -> "pyarrow.Table": + def fetchall_arrow(self) -> Any: """Fetch all remaining rows as an Arrow table.""" pass @@ -119,7 +114,7 @@ def close(self) -> None: """ try: if ( - self.op_state != CommandState.CLOSED + self.status != CommandState.CLOSED and not self.has_been_closed_server_side and self.connection.open ): @@ -129,7 +124,7 @@ def close(self) -> None: logger.info("Operation was canceled by a prior request") finally: self.has_been_closed_server_side = True - self.op_state = CommandState.CLOSED + self.status = CommandState.CLOSED class ThriftResultSet(ResultSet): @@ -138,11 +133,12 @@ class ThriftResultSet(ResultSet): def __init__( self, connection: "Connection", - execute_response: ExecuteResponse, + execute_response: "ExecuteResponse", thrift_client: "ThriftDatabricksClient", buffer_size_bytes: int = 104857600, arraysize: int = 10000, use_cloud_fetch: bool = True, + arrow_schema_bytes: Optional[bytes] = None, ): """ Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient. 
@@ -154,37 +150,33 @@ def __init__( buffer_size_bytes: Buffer size for fetching results arraysize: Default number of rows to fetch use_cloud_fetch: Whether to use cloud fetch for retrieving results + arrow_schema_bytes: Arrow schema bytes for the result set """ - super().__init__( - connection, - thrift_client, - execute_response.command_id, - execute_response.status, - execute_response.has_been_closed_server_side, - arraysize, - buffer_size_bytes, - ) - # Initialize ThriftResultSet-specific attributes - self.has_been_closed_server_side = execute_response.has_been_closed_server_side - self.has_more_rows = execute_response.has_more_rows - self.lz4_compressed = execute_response.lz4_compressed - self.description = execute_response.description - self._arrow_schema_bytes = execute_response.arrow_schema_bytes + self._arrow_schema_bytes = arrow_schema_bytes self._use_cloud_fetch = use_cloud_fetch - self._is_staging_operation = execute_response.is_staging_operation + self.lz4_compressed = execute_response.lz4_compressed - # Initialize results queue - if execute_response.arrow_queue: - # In this case the server has taken the fast path and returned an initial batch of - # results - self.results = execute_response.arrow_queue - else: - # In this case, there are results waiting on the server so we fetch now for simplicity + # Call parent constructor with common attributes + super().__init__( + connection=connection, + backend=thrift_client, + arraysize=arraysize, + buffer_size_bytes=buffer_size_bytes, + command_id=execute_response.command_id, + status=execute_response.status, + has_been_closed_server_side=execute_response.has_been_closed_server_side, + has_more_rows=execute_response.has_more_rows, + results_queue=execute_response.results_queue, + description=execute_response.description, + is_staging_operation=execute_response.is_staging_operation, + ) + + # Initialize results queue if not provided + if not self.results: self._fill_results_buffer() def _fill_results_buffer(self): - # At initialization or if the server does not have cloud fetch result links available results, has_more_rows = self.backend.fetch_results( command_id=self.command_id, max_rows=self.arraysize, @@ -196,7 +188,7 @@ def _fill_results_buffer(self): use_cloud_fetch=self._use_cloud_fetch, ) self.results = results - self.has_more_rows = has_more_rows + self._has_more_rows = has_more_rows def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] @@ -248,7 +240,7 @@ def _convert_arrow_table(self, table): res = df.to_numpy(na_value=None, dtype="object") return [ResultRow(*v) for v in res] - def merge_columnar(self, result1, result2) -> "ColumnTable": + def merge_columnar(self, result1, result2): """ Function to merge / combining the columnar results into a single result :param result1: @@ -280,7 +272,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self._has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -305,7 +297,7 @@ def fetchmany_columnar(self, size: int): while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self._has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -320,7 +312,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not 
self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self._has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() if isinstance(results, ColumnTable) and isinstance( @@ -346,7 +338,7 @@ def fetchall_columnar(self): results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self._has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() results = self.merge_columnar(results, partial_results) @@ -389,24 +381,110 @@ def fetchmany(self, size: int) -> List[Row]: else: return self._convert_arrow_table(self.fetchmany_arrow(size)) - @property - def is_staging_operation(self) -> bool: - """Whether this result set represents a staging operation.""" - return self._is_staging_operation - @staticmethod - def _get_schema_description(table_schema_message): +class SeaResultSet(ResultSet): + """ResultSet implementation for SEA backend.""" + + def __init__( + self, + connection, + sea_client, + buffer_size_bytes: int = 104857600, + arraysize: int = 10000, + execute_response=None, + sea_response=None, + ): """ - Takes a TableSchema message and returns a description 7-tuple as specified by PEP-249 + Initialize a SeaResultSet with the response from a SEA query execution. + + Args: + connection: The parent connection + sea_client: The SeaDatabricksClient instance for direct access + buffer_size_bytes: Buffer size for fetching results + arraysize: Default number of rows to fetch + execute_response: Response from the execute command (new style) + sea_response: Direct SEA response (legacy style) """ + # Handle both initialization styles + if execute_response is not None: + # New style with ExecuteResponse + command_id = execute_response.command_id + status = execute_response.status + has_been_closed_server_side = execute_response.has_been_closed_server_side + has_more_rows = execute_response.has_more_rows + results_queue = execute_response.results_queue + description = execute_response.description + is_staging_operation = execute_response.is_staging_operation + self._response = getattr(execute_response, "sea_response", {}) + self.statement_id = command_id.to_sea_statement_id() if command_id else None + elif sea_response is not None: + # Legacy style with direct sea_response + self._response = sea_response + # Extract values from sea_response + command_id = CommandId.from_sea_statement_id( + sea_response.get("statement_id", "") + ) + self.statement_id = sea_response.get("statement_id", "") + + # Extract status + status_data = sea_response.get("status", {}) + status = CommandState.from_sea_state(status_data.get("state", "PENDING")) + + # Set defaults for other fields + has_been_closed_server_side = False + has_more_rows = False + results_queue = None + description = None + is_staging_operation = False + else: + raise ValueError("Either execute_response or sea_response must be provided") - def map_col_type(type_): - if type_.startswith("decimal"): - return "decimal" - else: - return type_ + # Call parent constructor with common attributes + super().__init__( + connection=connection, + backend=sea_client, + arraysize=arraysize, + buffer_size_bytes=buffer_size_bytes, + command_id=command_id, + status=status, + has_been_closed_server_side=has_been_closed_server_side, + has_more_rows=has_more_rows, + results_queue=results_queue, + description=description, + 
is_staging_operation=is_staging_operation,
+        )
-        return [
-            (column.name, map_col_type(column.datatype), None, None, None, None, None)
-            for column in table_schema_message.columns
-        ]
+    def _fill_results_buffer(self):
+        """Fill the results buffer from the backend."""
+        raise NotImplementedError("_fill_results_buffer is not implemented for SEA backend")
+
+    def fetchone(self) -> Optional[Row]:
+        """
+        Fetch the next row of a query result set, returning a single sequence,
+        or None when no more data is available.
+        """
+
+        raise NotImplementedError("fetchone is not implemented for SEA backend")
+
+    def fetchmany(self, size: Optional[int] = None) -> List[Row]:
+        """
+        Fetch the next set of rows of a query result, returning a list of rows.
+
+        An empty sequence is returned when no more rows are available.
+        """
+
+        raise NotImplementedError("fetchmany is not implemented for SEA backend")
+
+    def fetchall(self) -> List[Row]:
+        """
+        Fetch all (remaining) rows of a query result, returning them as a list of rows.
+        """
+        raise NotImplementedError("fetchall is not implemented for SEA backend")
+
+    def fetchmany_arrow(self, size: int) -> Any:
+        """Fetch the next set of rows as an Arrow table."""
+        raise NotImplementedError("fetchmany_arrow is not implemented for SEA backend")
+
+    def fetchall_arrow(self) -> Any:
+        """Fetch all remaining rows as an Arrow table."""
+        raise NotImplementedError("fetchall_arrow is not implemented for SEA backend")
diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py
index 7c33d9b2d..76aec4675 100644
--- a/src/databricks/sql/session.py
+++ b/src/databricks/sql/session.py
@@ -10,7 +10,7 @@
 from databricks.sql.backend.thrift_backend import ThriftDatabricksClient
 from databricks.sql.backend.sea.backend import SeaDatabricksClient
 from databricks.sql.backend.databricks_client import DatabricksClient
-from databricks.sql.backend.types import SessionId
+from databricks.sql.backend.types import SessionId, BackendType
 
 logger = logging.getLogger(__name__)
 
diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py
index 2622b1172..edb13ef6d 100644
--- a/src/databricks/sql/utils.py
+++ b/src/databricks/sql/utils.py
@@ -349,13 +349,6 @@ def _create_empty_table(self) -> "pyarrow.Table":
         return create_arrow_table_from_arrow_file(self.schema_bytes, self.description)
 
 
-ExecuteResponse = namedtuple(
-    "ExecuteResponse",
-    "status has_been_closed_server_side has_more_rows description lz4_compressed is_staging_operation "
-    "command_id arrow_queue arrow_schema_bytes",
-)
-
-
 def _bound(min_x, max_x, x):
     """Bound x by [min_x, max_x]
 
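The namedtuple removed above is superseded by an ExecuteResponse defined in databricks.sql.backend.types (see the import changes below). Its definition is not part of this hunk, so the construction below is only a sketch inferred from the updated test fixtures later in this patch; note that arrow_queue is renamed to results_queue and arrow_schema_bytes is no longer carried on the response:

    # Sketch only: field names are taken from the updated test fixtures in this
    # patch; the real class lives in databricks/sql/backend/types.py, which this
    # hunk does not show.
    from databricks.sql.backend.types import CommandState, ExecuteResponse

    response = ExecuteResponse(
        command_id=None,  # a CommandId in real usage; the tests pass None
        status=CommandState.SUCCEEDED,
        has_been_closed_server_side=False,
        has_more_rows=False,
        results_queue=None,  # was arrow_queue; may hold a pre-fetched batch
        description=[("col1", "string", None, None, None, None, None)],
        lz4_compressed=False,
        is_staging_operation=False,
    )

diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py
index 1a7950870..090ec255e 100644
--- a/tests/unit/test_client.py
+++ b/tests/unit/test_client.py
@@ -26,7 +26,7 @@
 from databricks.sql.types import Row
 from databricks.sql.result_set import ResultSet, ThriftResultSet
 from databricks.sql.backend.types import CommandId, CommandState
-from databricks.sql.utils import ExecuteResponse
+from databricks.sql.backend.types import ExecuteResponse
 
 from tests.unit.test_fetches import FetchTests
 from tests.unit.test_thrift_backend import ThriftBackendTestSuite
@@ -121,10 +121,10 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class):
 
             # Verify initial state
             self.assertEqual(real_result_set.has_been_closed_server_side, closed)
-            expected_op_state = (
+            expected_status = (
                 CommandState.CLOSED if closed else CommandState.SUCCEEDED
             )
-            self.assertEqual(real_result_set.op_state, expected_op_state)
+            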
self.assertEqual(real_result_set.status, expected_status) # Mock execute_command to return our real result set cursor.backend.execute_command = Mock(return_value=real_result_set) @@ -146,8 +146,8 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): # 1. has_been_closed_server_side should always be True after close() self.assertTrue(real_result_set.has_been_closed_server_side) - # 2. op_state should always be CLOSED after close() - self.assertEqual(real_result_set.op_state, CommandState.CLOSED) + # 2. status should always be CLOSED after close() + self.assertEqual(real_result_set.status, CommandState.CLOSED) # 3. Backend close_command should be called appropriately if not closed: @@ -556,7 +556,7 @@ def test_cursor_keeps_connection_alive(self, mock_client_class): self.assertEqual(instance.close_session.call_count, 0) cursor.close() - @patch("%s.utils.ExecuteResponse" % PACKAGE_NAME, autospec=True) + @patch("%s.backend.types.ExecuteResponse" % PACKAGE_NAME) @patch("%s.client.Cursor._handle_staging_operation" % PACKAGE_NAME) @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_staging_operation_response_is_handled( @@ -678,10 +678,10 @@ def test_resultset_close_handles_cursor_already_closed_error(self): """Test that ResultSet.close() handles CursorAlreadyClosedError properly.""" result_set = client.ThriftResultSet.__new__(client.ThriftResultSet) result_set.backend = Mock() - result_set.backend.CLOSED_OP_STATE = "CLOSED" + result_set.backend.CLOSED_OP_STATE = CommandState.CLOSED result_set.connection = Mock() result_set.connection.open = True - result_set.op_state = "RUNNING" + result_set.status = CommandState.RUNNING result_set.has_been_closed_server_side = False result_set.command_id = Mock() @@ -695,7 +695,7 @@ def __init__(self): try: try: if ( - result_set.op_state != result_set.backend.CLOSED_OP_STATE + result_set.status != result_set.backend.CLOSED_OP_STATE and not result_set.has_been_closed_server_side and result_set.connection.open ): @@ -705,7 +705,7 @@ def __init__(self): pass finally: result_set.has_been_closed_server_side = True - result_set.op_state = result_set.backend.CLOSED_OP_STATE + result_set.status = result_set.backend.CLOSED_OP_STATE result_set.backend.close_command.assert_called_once_with( result_set.command_id @@ -713,7 +713,7 @@ def __init__(self): assert result_set.has_been_closed_server_side is True - assert result_set.op_state == result_set.backend.CLOSED_OP_STATE + assert result_set.status == result_set.backend.CLOSED_OP_STATE finally: pass diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 030510a64..7249a59e6 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -8,7 +8,8 @@ pa = None import databricks.sql.client as client -from databricks.sql.utils import ExecuteResponse, ArrowQueue +from databricks.sql.backend.types import ExecuteResponse +from databricks.sql.utils import ArrowQueue from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.result_set import ThriftResultSet @@ -42,14 +43,13 @@ def make_dummy_result_set_from_initial_results(initial_results): rs = ThriftResultSet( connection=Mock(), execute_response=ExecuteResponse( + command_id=None, status=None, has_been_closed_server_side=True, has_more_rows=False, description=Mock(), lz4_compressed=Mock(), - command_id=None, - arrow_queue=arrow_queue, - arrow_schema_bytes=schema.serialize().to_pybytes(), + results_queue=arrow_queue, is_staging_operation=False, ), thrift_client=None, @@ 
-88,6 +88,7 @@ def fetch_results( rs = ThriftResultSet( connection=Mock(), execute_response=ExecuteResponse( + command_id=None, status=None, has_been_closed_server_side=False, has_more_rows=True, @@ -96,9 +97,7 @@ def fetch_results( for col_id in range(num_cols) ], lz4_compressed=Mock(), - command_id=None, - arrow_queue=None, - arrow_schema_bytes=None, + results_queue=None, is_staging_operation=False, ), thrift_client=mock_thrift_backend, diff --git a/tests/unit/test_fetches_bench.py b/tests/unit/test_fetches_bench.py index b302c00da..7e025cf82 100644 --- a/tests/unit/test_fetches_bench.py +++ b/tests/unit/test_fetches_bench.py @@ -10,7 +10,8 @@ import pytest import databricks.sql.client as client -from databricks.sql.utils import ExecuteResponse, ArrowQueue +from databricks.sql.backend.types import ExecuteResponse +from databricks.sql.utils import ArrowQueue @pytest.mark.skipif(pa is None, reason="PyArrow is not installed") diff --git a/tests/unit/test_result_set_filter.py b/tests/unit/test_result_set_filter.py new file mode 100644 index 000000000..e8eb2a757 --- /dev/null +++ b/tests/unit/test_result_set_filter.py @@ -0,0 +1,246 @@ +""" +Tests for the ResultSetFilter class. + +This module contains tests for the ResultSetFilter class, which provides +filtering capabilities for result sets returned by different backends. +""" + +import pytest +from unittest.mock import patch, MagicMock, Mock + +from databricks.sql.backend.filters import ResultSetFilter +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import CommandId, CommandState, BackendType + + +class TestResultSetFilter: + """Test suite for the ResultSetFilter class.""" + + @pytest.fixture + def mock_connection(self): + """Create a mock connection.""" + connection = Mock() + connection.open = True + return connection + + @pytest.fixture + def mock_sea_client(self): + """Create a mock SEA client.""" + return Mock() + + @pytest.fixture + def sea_response_with_tables(self): + """Create a sample SEA response with table data based on the server schema.""" + return { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 7, + "columns": [ + {"name": "namespace", "type_text": "STRING", "position": 0}, + {"name": "tableName", "type_text": "STRING", "position": 1}, + {"name": "isTemporary", "type_text": "BOOLEAN", "position": 2}, + {"name": "information", "type_text": "STRING", "position": 3}, + {"name": "catalogName", "type_text": "STRING", "position": 4}, + {"name": "tableType", "type_text": "STRING", "position": 5}, + {"name": "remarks", "type_text": "STRING", "position": 6}, + ], + }, + "total_row_count": 4, + }, + "result": { + "data_array": [ + ["schema1", "table1", False, None, "catalog1", "TABLE", None], + ["schema1", "table2", False, None, "catalog1", "VIEW", None], + [ + "schema1", + "table3", + False, + None, + "catalog1", + "SYSTEM TABLE", + None, + ], + ["schema1", "table4", False, None, "catalog1", "EXTERNAL", None], + ] + }, + } + + @pytest.fixture + def sea_result_set_with_tables( + self, mock_connection, mock_sea_client, sea_response_with_tables + ): + """Create a SeaResultSet with table data.""" + # Create a deep copy of the response to avoid test interference + import copy + + sea_response_copy = copy.deepcopy(sea_response_with_tables) + + return SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy, + buffer_size_bytes=1000, + arraysize=100, + 
) + + def test_filter_tables_by_type_default(self, sea_result_set_with_tables): + """Test filtering tables by type with default types.""" + # Default types are TABLE, VIEW, SYSTEM TABLE + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables + ) + + # Verify that only the default types are included + assert len(filtered_result._response["result"]["data_array"]) == 3 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" in table_types + assert "EXTERNAL" not in table_types + + def test_filter_tables_by_type_custom(self, sea_result_set_with_tables): + """Test filtering tables by type with custom types.""" + # Filter for only TABLE and EXTERNAL + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=["TABLE", "EXTERNAL"] + ) + + # Verify that only the specified types are included + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "EXTERNAL" in table_types + assert "VIEW" not in table_types + assert "SYSTEM TABLE" not in table_types + + def test_filter_tables_by_type_case_insensitive(self, sea_result_set_with_tables): + """Test that table type filtering is case-insensitive.""" + # Filter for lowercase "table" and "view" + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=["table", "view"] + ) + + # Verify that the matching types are included despite case differences + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" not in table_types + assert "EXTERNAL" not in table_types + + def test_filter_tables_by_type_empty_list(self, sea_result_set_with_tables): + """Test filtering tables with an empty type list (should use defaults).""" + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=[] + ) + + # Verify that default types are used + assert len(filtered_result._response["result"]["data_array"]) == 3 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" in table_types + assert "EXTERNAL" not in table_types + + def test_filter_by_column_values(self, sea_result_set_with_tables): + """Test filtering by values in a specific column.""" + # Filter by namespace in column index 0 + filtered_result = ResultSetFilter.filter_by_column_values( + sea_result_set_with_tables, column_index=0, allowed_values=["schema1"] + ) + + # All rows have schema1 in namespace, so all should be included + assert len(filtered_result._response["result"]["data_array"]) == 4 + + # Filter by table name in column index 1 + filtered_result = ResultSetFilter.filter_by_column_values( + sea_result_set_with_tables, + column_index=1, + allowed_values=["table1", "table3"], + ) + + # Only rows with table1 or table3 should be included + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_names = [ + row[1] for row in filtered_result._response["result"]["data_array"] + ] + assert "table1" in table_names + assert "table3" in table_names + assert "table2" not in 
table_names
+        assert "table4" not in table_names
+
+    def test_filter_by_column_values_case_sensitive(
+        self, mock_connection, mock_sea_client, sea_response_with_tables
+    ):
+        """Test case-sensitive filtering by column values."""
+        import copy
+
+        # Create a fresh result set for the first test
+        sea_response_copy1 = copy.deepcopy(sea_response_with_tables)
+        result_set1 = SeaResultSet(
+            connection=mock_connection,
+            sea_client=mock_sea_client,
+            sea_response=sea_response_copy1,
+            buffer_size_bytes=1000,
+            arraysize=100,
+        )
+
+        # First test: case-sensitive filtering with lowercase values (should find no matches)
+        filtered_result = ResultSetFilter.filter_by_column_values(
+            result_set1,
+            column_index=5,  # tableType column
+            allowed_values=["table", "view"],  # lowercase
+            case_sensitive=True,
+        )
+
+        # Verify no matches with lowercase values
+        assert len(filtered_result._response["result"]["data_array"]) == 0
+
+        # Create a fresh result set for the second test
+        sea_response_copy2 = copy.deepcopy(sea_response_with_tables)
+        result_set2 = SeaResultSet(
+            connection=mock_connection,
+            sea_client=mock_sea_client,
+            sea_response=sea_response_copy2,
+            buffer_size_bytes=1000,
+            arraysize=100,
+        )
+
+        # Second test: case-sensitive filtering with correct case (should find matches)
+        filtered_result = ResultSetFilter.filter_by_column_values(
+            result_set2,
+            column_index=5,  # tableType column
+            allowed_values=["TABLE", "VIEW"],  # correct case
+            case_sensitive=True,
+        )
+
+        # Verify matches with correct case
+        assert len(filtered_result._response["result"]["data_array"]) == 2
+
+        # Extract the table types from the filtered results
+        table_types = [
+            row[5] for row in filtered_result._response["result"]["data_array"]
+        ]
+        assert "TABLE" in table_types
+        assert "VIEW" in table_types
+
+    def test_filter_by_column_values_out_of_bounds(self, sea_result_set_with_tables):
+        """Test filtering with a column index that's out of bounds."""
+        # Filter by column index 10 (out of bounds)
+        filtered_result = ResultSetFilter.filter_by_column_values(
+            sea_result_set_with_tables, column_index=10, allowed_values=["value"]
+        )
+
+        # No rows should match
+        assert len(filtered_result._response["result"]["data_array"]) == 0
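For orientation, here is a minimal usage sketch (not part of the patch) of the filters these tests exercise, assuming a SeaResultSet built legacy-style from a raw response dict as in the fixtures above; the mocks stand in for a live connection and client:

    from unittest.mock import Mock
    from databricks.sql.backend.filters import ResultSetFilter
    from databricks.sql.result_set import SeaResultSet

    rs = SeaResultSet(
        connection=Mock(),
        sea_client=Mock(),
        sea_response={
            "statement_id": "stmt-123",
            "status": {"state": "SUCCEEDED"},
            "result": {
                "data_array": [
                    ["schema1", "t1", False, None, "cat1", "TABLE", None],
                    ["schema1", "t2", False, None, "cat1", "EXTERNAL", None],
                ]
            },
        },
    )
    # Keep only TABLE/VIEW/SYSTEM TABLE rows (the documented defaults).
    tables = ResultSetFilter.filter_tables_by_type(rs)
    # Keep rows whose column 1 (the table name) is in the allow-list.
    by_name = ResultSetFilter.filter_by_column_values(
        rs, column_index=1, allowed_values=["t1"]
    )

diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py
index bc2688a68..c2078f731 100644
--- a/tests/unit/test_sea_backend.py
+++ b/tests/unit/test_sea_backend.py
@@ -1,11 +1,20 @@
+"""
+Tests for the SEA (Statement Execution API) backend implementation.
+
+This module contains tests for the SeaDatabricksClient class, which implements
+the Databricks SQL connector's SEA backend functionality.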
+""" + +import json import pytest -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, Mock from databricks.sql.backend.sea.backend import SeaDatabricksClient -from databricks.sql.backend.types import SessionId, BackendType +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.exc import Error +from databricks.sql.exc import Error, NotSupportedError class TestSeaBackend: @@ -41,6 +50,23 @@ def sea_client(self, mock_http_client): return client + @pytest.fixture + def sea_session_id(self): + """Create a SEA session ID.""" + return SessionId.from_sea_session_id("test-session-123") + + @pytest.fixture + def sea_command_id(self): + """Create a SEA command ID.""" + return CommandId.from_sea_statement_id("test-statement-123") + + @pytest.fixture + def mock_cursor(self): + """Create a mock cursor.""" + cursor = Mock() + cursor.active_command_id = None + return cursor + def test_init_extracts_warehouse_id(self, mock_http_client): """Test that the constructor properly extracts the warehouse ID from the HTTP path.""" # Test with warehouses format @@ -175,109 +201,650 @@ def test_close_session_invalid_id_type(self, sea_client): assert "Not a valid SEA session ID" in str(excinfo.value) - def test_session_configuration_helpers(self): - """Test the session configuration helper methods.""" - # Test getting default value for a supported parameter - default_value = SeaDatabricksClient.get_default_session_configuration_value( - "ANSI_MODE" - ) - assert default_value == "true" + # Tests for command execution and management + + def test_execute_command_sync( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command synchronously.""" + # Set up mock responses + execute_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "schema": [ + { + "name": "col1", + "type_name": "STRING", + "type_text": "string", + "nullable": True, + } + ], + "total_row_count": 1, + "total_byte_count": 100, + }, + "result": {"data": [["value1"]]}, + } + mock_http_client._make_request.return_value = execute_response + + # Mock the get_execution_result method + with patch.object( + sea_client, "get_execution_result", return_value="mock_result_set" + ) as mock_get_result: + # Call the method + result = sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.STATEMENT_PATH + assert "warehouse_id" in kwargs["data"] + assert "session_id" in kwargs["data"] + assert "statement" in kwargs["data"] + assert kwargs["data"]["statement"] == "SELECT 1" + + # Verify get_execution_result was called with the right command ID + mock_get_result.assert_called_once() + cmd_id_arg = mock_get_result.call_args[0][0] + assert isinstance(cmd_id_arg, CommandId) + assert cmd_id_arg.guid == "test-statement-123" + + def 
test_execute_command_async( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command asynchronously.""" + # Set up mock response + execute_response = { + "statement_id": "test-statement-456", + "status": {"state": "PENDING"}, + } + mock_http_client._make_request.return_value = execute_response - # Test getting default value for an unsupported parameter - default_value = SeaDatabricksClient.get_default_session_configuration_value( - "UNSUPPORTED_PARAM" + # Call the method + result = sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=True, # Async mode + enforce_embedded_schema_correctness=False, ) - assert default_value is None - - # Test getting the list of allowed configurations - allowed_configs = SeaDatabricksClient.get_allowed_session_configurations() - - expected_keys = { - "ANSI_MODE", - "ENABLE_PHOTON", - "LEGACY_TIME_PARSER_POLICY", - "MAX_FILE_PARTITION_BYTES", - "READ_ONLY_EXTERNAL_METASTORE", - "STATEMENT_TIMEOUT", - "TIMEZONE", - "USE_CACHED_RESULT", + + # Verify the result is None for async operation + assert result is None + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.STATEMENT_PATH + assert "wait_timeout" in kwargs["data"] + assert kwargs["data"]["wait_timeout"] == "0s" # Async mode uses 0s timeout + + # Verify the command ID was stored in the cursor + assert hasattr(mock_cursor, "active_command_id") + assert isinstance(mock_cursor.active_command_id, CommandId) + assert mock_cursor.active_command_id.guid == "test-statement-456" + + def test_execute_command_with_polling( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command that requires polling.""" + # Set up mock responses for initial request and polling + initial_response = { + "statement_id": "test-statement-789", + "status": {"state": "RUNNING"}, + } + poll_response = { + "statement_id": "test-statement-789", + "status": {"state": "SUCCEEDED"}, + "manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0}, + "result": {"data": []}, + } + + # Configure mock to return different responses on subsequent calls + mock_http_client._make_request.side_effect = [initial_response, poll_response] + + # Mock the get_execution_result method + with patch.object( + sea_client, "get_execution_result", return_value="mock_result_set" + ) as mock_get_result: + # Mock time.sleep to avoid actual delays + with patch("time.sleep"): + # Call the method + result = sea_client.execute_command( + operation="SELECT * FROM large_table", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify the HTTP requests (initial and poll) + assert mock_http_client._make_request.call_count == 2 + + # Verify get_execution_result was called with the right command ID + mock_get_result.assert_called_once() + cmd_id_arg = mock_get_result.call_args[0][0] + assert isinstance(cmd_id_arg, CommandId) + assert cmd_id_arg.guid == "test-statement-789" + + def test_execute_command_with_parameters( + self, 
sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command with parameters.""" + # Set up mock response + execute_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, } - assert set(allowed_configs) == expected_keys + mock_http_client._make_request.return_value = execute_response - def test_unimplemented_methods(self, sea_client): - """Test that unimplemented methods raise NotImplementedError.""" - # Create dummy parameters for testing - session_id = SessionId.from_sea_session_id("test-session") - command_id = MagicMock() - cursor = MagicMock() + # Create parameter mock + param = MagicMock() + param.name = "param1" + param.value = "value1" + param.type = "STRING" - # Test execute_command - with pytest.raises(NotImplementedError) as excinfo: + # Mock the get_execution_result method + with patch.object(sea_client, "get_execution_result") as mock_get_result: + # Call the method with parameters sea_client.execute_command( - operation="SELECT 1", - session_id=session_id, + operation="SELECT * FROM table WHERE col = :param1", + session_id=sea_session_id, max_rows=100, max_bytes=1000, lz4_compression=False, - cursor=cursor, + cursor=mock_cursor, use_cloud_fetch=False, - parameters=[], + parameters=[param], async_op=False, enforce_embedded_schema_correctness=False, ) - assert "execute_command is not yet implemented" in str(excinfo.value) - - # Test cancel_command - with pytest.raises(NotImplementedError) as excinfo: - sea_client.cancel_command(command_id) - assert "cancel_command is not yet implemented" in str(excinfo.value) - - # Test close_command - with pytest.raises(NotImplementedError) as excinfo: - sea_client.close_command(command_id) - assert "close_command is not yet implemented" in str(excinfo.value) - - # Test get_query_state - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_query_state(command_id) - assert "get_query_state is not yet implemented" in str(excinfo.value) - - # Test get_execution_result - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_execution_result(command_id, cursor) - assert "get_execution_result is not yet implemented" in str(excinfo.value) - - # Test metadata operations - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_catalogs(session_id, 100, 1000, cursor) - assert "get_catalogs is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_schemas(session_id, 100, 1000, cursor) - assert "get_schemas is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_tables(session_id, 100, 1000, cursor) - assert "get_tables is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_columns(session_id, 100, 1000, cursor) - assert "get_columns is not yet implemented" in str(excinfo.value) - - def test_max_download_threads_property(self, sea_client): - """Test the max_download_threads property.""" - assert sea_client.max_download_threads == 10 - - # Create a client with a custom value - custom_client = SeaDatabricksClient( - server_hostname="test-server.databricks.com", - port=443, - http_path="/sql/warehouses/abc123", - http_headers=[], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), - max_download_threads=20, + + # Verify the HTTP request contains parameters + mock_http_client._make_request.assert_called_once() + args, kwargs = 
mock_http_client._make_request.call_args + assert "parameters" in kwargs["data"] + assert len(kwargs["data"]["parameters"]) == 1 + assert kwargs["data"]["parameters"][0]["name"] == "param1" + assert kwargs["data"]["parameters"][0]["value"] == "value1" + assert kwargs["data"]["parameters"][0]["type"] == "STRING" + + def test_execute_command_failure( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command that fails.""" + # Set up mock response for a failed execution + error_response = { + "statement_id": "test-statement-123", + "status": { + "state": "FAILED", + "error": { + "message": "Syntax error in SQL", + "error_code": "SYNTAX_ERROR", + }, + }, + } + + # Configure the mock to return the error response for the initial request + # and then raise an exception when trying to poll (to simulate immediate failure) + mock_http_client._make_request.side_effect = [ + error_response, # Initial response + Error( + "Statement execution did not succeed: Syntax error in SQL" + ), # Will be raised during polling + ] + + # Mock time.sleep to avoid actual delays + with patch("time.sleep"): + # Call the method and expect an error + with pytest.raises(Error) as excinfo: + sea_client.execute_command( + operation="SELECT * FROM nonexistent_table", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + assert "Statement execution did not succeed" in str(excinfo.value) + + def test_cancel_command(self, sea_client, mock_http_client, sea_command_id): + """Test canceling a command.""" + # Set up mock response + mock_http_client._make_request.return_value = {} + + # Call the method + sea_client.cancel_command(sea_command_id) + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.CANCEL_STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + def test_close_command(self, sea_client, mock_http_client, sea_command_id): + """Test closing a command.""" + # Set up mock response + mock_http_client._make_request.return_value = {} + + # Call the method + sea_client.close_command(sea_command_id) + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "DELETE" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + def test_get_query_state(self, sea_client, mock_http_client, sea_command_id): + """Test getting the state of a query.""" + # Set up mock response + mock_http_client._make_request.return_value = { + "statement_id": "test-statement-123", + "status": {"state": "RUNNING"}, + } + + # Call the method + state = sea_client.get_query_state(sea_command_id) + + # Verify the result + assert state == CommandState.RUNNING + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "GET" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" ) - # Verify the custom value is returned - assert custom_client.max_download_threads == 20 + def test_get_execution_result( + self, sea_client, mock_http_client, mock_cursor, sea_command_id + ): + """Test getting the 
result of a command execution."""
+        # Set up mock response
+        sea_response = {
+            "statement_id": "test-statement-123",
+            "status": {"state": "SUCCEEDED"},
+            "manifest": {
+                "format": "JSON_ARRAY",
+                "schema": {
+                    "column_count": 1,
+                    "columns": [
+                        {
+                            "name": "test_value",
+                            "type_text": "INT",
+                            "type_name": "INT",
+                            "position": 0,
+                        }
+                    ],
+                },
+                "total_chunk_count": 1,
+                "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}],
+                "total_row_count": 1,
+                "truncated": False,
+            },
+            "result": {
+                "chunk_index": 0,
+                "row_offset": 0,
+                "row_count": 1,
+                "data_array": [["1"]],
+            },
+        }
+        mock_http_client._make_request.return_value = sea_response
+
+        # Create a real result set to verify the implementation
+        result = sea_client.get_execution_result(sea_command_id, mock_cursor)
+
+        # Verify basic properties of the result
+        assert result.statement_id == "test-statement-123"
+        assert result.status == CommandState.SUCCEEDED
+
+        # Verify the HTTP request
+        mock_http_client._make_request.assert_called_once()
+        args, kwargs = mock_http_client._make_request.call_args
+        assert kwargs["method"] == "GET"
+        assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format(
+            "test-statement-123"
+        )
+
+    # Tests for metadata operations
+
+    def test_get_catalogs(self, sea_client, mock_cursor, sea_session_id):
+        """Test getting catalogs."""
+        # Mock execute_command to verify it's called with the right parameters
+        with patch.object(
+            sea_client, "execute_command", return_value="mock_result_set"
+        ) as mock_execute:
+            # Call the method
+            result = sea_client.get_catalogs(
+                session_id=sea_session_id,
+                max_rows=100,
+                max_bytes=1000,
+                cursor=mock_cursor,
+            )
+
+            # Verify the result
+            assert result == "mock_result_set"
+
+            # Verify execute_command was called with the right parameters
+            mock_execute.assert_called_once()
+            args, kwargs = mock_execute.call_args
+            assert kwargs["operation"] == "SHOW CATALOGS"
+            assert kwargs["session_id"] == sea_session_id
+            assert kwargs["max_rows"] == 100
+            assert kwargs["max_bytes"] == 1000
+            assert kwargs["cursor"] == mock_cursor
+            assert kwargs["lz4_compression"] is False
+            assert kwargs["use_cloud_fetch"] is False
+            assert kwargs["parameters"] == []
+            assert kwargs["async_op"] is False
+            assert kwargs["enforce_embedded_schema_correctness"] is False
+
+    def test_get_schemas_with_catalog_only(
+        self, sea_client, mock_cursor, sea_session_id
+    ):
+        """Test getting schemas with only catalog name."""
+        # Mock execute_command to verify it's called with the right parameters
+        with patch.object(
+            sea_client, "execute_command", return_value="mock_result_set"
+        ) as mock_execute:
+            # Call the method with only catalog
+            result = sea_client.get_schemas(
+                session_id=sea_session_id,
+                max_rows=100,
+                max_bytes=1000,
+                cursor=mock_cursor,
+                catalog_name="test_catalog",
+            )
+
+            # Verify the result
+            assert result == "mock_result_set"
+
+            # Verify execute_command was called with the right parameters
+            mock_execute.assert_called_once()
+            args, kwargs = mock_execute.call_args
+            assert kwargs["operation"] == "SHOW SCHEMAS IN `test_catalog`"
+            assert kwargs["session_id"] == sea_session_id
+            assert kwargs["max_rows"] == 100
+            assert kwargs["max_bytes"] == 1000
+            assert kwargs["cursor"] == mock_cursor
+            assert kwargs["async_op"] is False
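+
+    # NOTE (illustrative, not the actual implementation): the operation strings
+    # asserted in these tests suggest the backend composes metadata queries as
+    # SHOW statements, roughly like this for get_schemas:
+    #
+    #     operation = f"SHOW SCHEMAS IN `{catalog_name}`"
+    #     if schema_name:
+    #         operation += f" LIKE '{schema_name}'"
+    #
+    # The real logic lives in SeaDatabricksClient and is not shown in this hunk.
+
+    def test_get_schemas_with_catalog_and_schema_pattern(
+        self, sea_client, mock_cursor, sea_session_id
+    ):
+        """Test getting schemas with catalog and schema pattern."""
+        # Mock execute_command to verify it's called with the right 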
parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with catalog and schema pattern + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW SCHEMAS IN `test_catalog` LIKE 'test_schema'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_schemas_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting schemas without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + ) + + assert "Catalog name is required for get_schemas" in str(excinfo.value) + + def test_get_tables_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW TABLES IN CATALOG `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_with_all_catalogs( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables from all catalogs using wildcard.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with wildcard catalog + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="*", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW TABLES IN ALL CATALOGS" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_with_schema_and_table_patterns( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables with schema and table patterns.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + 
sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with catalog, schema, and table patterns + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW TABLES IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' LIKE 'test_table'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting tables without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + table_name="test_table", + ) + + assert "Catalog name is required for get_tables" in str(excinfo.value) + + def test_get_columns_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting columns with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW COLUMNS IN CATALOG `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_columns_with_all_patterns( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting columns with all patterns specified.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with all patterns + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + column_name="test_column", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW COLUMNS IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' TABLE LIKE 'test_table' LIKE 'test_column'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def 
test_get_columns_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting columns without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + table_name="test_table", + ) + + assert "Catalog name is required for get_columns" in str(excinfo.value) diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py new file mode 100644 index 000000000..f666fd613 --- /dev/null +++ b/tests/unit/test_sea_result_set.py @@ -0,0 +1,275 @@ +""" +Tests for the SeaResultSet class. + +This module contains tests for the SeaResultSet class, which implements +the result set functionality for the SEA (Statement Execution API) backend. +""" + +import pytest +from unittest.mock import patch, MagicMock, Mock + +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import CommandId, CommandState, BackendType + + +class TestSeaResultSet: + """Test suite for the SeaResultSet class.""" + + @pytest.fixture + def mock_connection(self): + """Create a mock connection.""" + connection = Mock() + connection.open = True + return connection + + @pytest.fixture + def mock_sea_client(self): + """Create a mock SEA client.""" + return Mock() + + @pytest.fixture + def sea_response(self): + """Create a sample SEA response.""" + return { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 1, + "columns": [ + { + "name": "test_value", + "type_text": "INT", + "type_name": "INT", + "position": 0, + } + ], + }, + "total_chunk_count": 1, + "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], + "total_row_count": 1, + "truncated": False, + }, + "result": { + "chunk_index": 0, + "row_offset": 0, + "row_count": 1, + "data_array": [["1"]], + }, + } + + @pytest.fixture + def execute_response(self): + """Create a sample execute response.""" + mock_response = Mock() + mock_response.command_id = CommandId.from_sea_statement_id("test-statement-123") + mock_response.status = CommandState.SUCCEEDED + mock_response.has_been_closed_server_side = False + mock_response.has_more_rows = False + mock_response.results_queue = None + mock_response.description = [ + ("test_value", "INT", None, None, None, None, None) + ] + mock_response.is_staging_operation = False + mock_response.sea_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "result": {"data_array": [["1"]]}, + } + return mock_response + + def test_init_with_sea_response( + self, mock_connection, mock_sea_client, sea_response + ): + """Test initializing SeaResultSet with a SEA response.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Verify basic properties + assert result_set.statement_id == "test-statement-123" + assert result_set.status == CommandState.SUCCEEDED + assert result_set.command_id.guid == "test-statement-123" + assert result_set.command_id.backend_type == BackendType.SEA + assert result_set.connection == mock_connection + assert result_set.backend == mock_sea_client + assert result_set.buffer_size_bytes == 1000 + assert result_set.arraysize == 100 + assert result_set._response == sea_response + + def 
test_init_with_execute_response( + self, mock_connection, mock_sea_client, execute_response + ): + """Test initializing SeaResultSet with an execute response.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + execute_response=execute_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Verify basic properties + assert result_set.statement_id == "test-statement-123" + assert result_set.status == CommandState.SUCCEEDED + assert result_set.command_id.guid == "test-statement-123" + assert result_set.command_id.backend_type == BackendType.SEA + assert result_set.connection == mock_connection + assert result_set.backend == mock_sea_client + assert result_set.buffer_size_bytes == 1000 + assert result_set.arraysize == 100 + assert result_set._response == execute_response.sea_response + + def test_init_with_no_response(self, mock_connection, mock_sea_client): + """Test that initialization fails when neither response type is provided.""" + with pytest.raises(ValueError) as excinfo: + SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + assert "Either execute_response or sea_response must be provided" in str( + excinfo.value + ) + + def test_close(self, mock_connection, mock_sea_client, sea_response): + """Test closing a result set.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Close the result set + result_set.close() + + # Verify the backend's close_command was called + mock_sea_client.close_command.assert_called_once_with(result_set.command_id) + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_close_when_already_closed_server_side( + self, mock_connection, mock_sea_client, sea_response + ): + """Test closing a result set that has already been closed server-side.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + result_set.has_been_closed_server_side = True + + # Close the result set + result_set.close() + + # Verify the backend's close_command was NOT called + mock_sea_client.close_command.assert_not_called() + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_close_when_connection_closed( + self, mock_connection, mock_sea_client, sea_response + ): + """Test closing a result set when the connection is closed.""" + mock_connection.open = False + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Close the result set + result_set.close() + + # Verify the backend's close_command was NOT called + mock_sea_client.close_command.assert_not_called() + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_unimplemented_methods( + self, mock_connection, mock_sea_client, sea_response + ): + """Test that unimplemented methods raise NotImplementedError.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Test each unimplemented method individually with specific error messages + with pytest.raises( + 
NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + result_set.fetchone() + + with pytest.raises( + NotImplementedError, match="fetchmany is not implemented for SEA backend" + ): + result_set.fetchmany(10) + + with pytest.raises( + NotImplementedError, match="fetchmany is not implemented for SEA backend" + ): + # Test with default parameter value + result_set.fetchmany() + + with pytest.raises( + NotImplementedError, match="fetchall is not implemented for SEA backend" + ): + result_set.fetchall() + + with pytest.raises( + NotImplementedError, + match="fetchmany_arrow is not implemented for SEA backend", + ): + result_set.fetchmany_arrow(10) + + with pytest.raises( + NotImplementedError, + match="fetchall_arrow is not implemented for SEA backend", + ): + result_set.fetchall_arrow() + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + # Test iteration protocol (calls fetchone internally) + next(iter(result_set)) + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + # Test using the result set in a for loop + for row in result_set: + pass + + def test_fill_results_buffer_not_implemented( + self, mock_connection, mock_sea_client, sea_response + ): + """Test that _fill_results_buffer raises NotImplementedError.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + result_set._fill_results_buffer() diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 92de8d8fd..fef070362 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2,6 +2,11 @@ from unittest.mock import patch, Mock import gc +from databricks.sql.thrift_api.TCLIService.ttypes import ( + TOpenSessionResp, + TSessionHandle, + THandleIdentifier, +) from databricks.sql.backend.types import SessionId, BackendType import databricks.sql diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 57a2a61e3..b8de970db 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -619,11 +619,18 @@ def test_handle_execute_response_sets_compression_in_direct_results( lz4Compressed = Mock() resultSet = MagicMock() resultSet.results.startRowOffset = 0 + + # Create a valid operation status + op_status = ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + t_execute_resp = resp_type( status=Mock(), operationHandle=Mock(), directResults=ttypes.TSparkDirectResults( - operationStatus=Mock(), + operationStatus=op_status, resultSetMetadata=ttypes.TGetResultSetMetadataResp( status=self.okay_status, resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET, @@ -644,7 +651,7 @@ def test_handle_execute_response_sets_compression_in_direct_results( ssl_options=SSLOptions(), ) - execute_response = thrift_backend._handle_execute_response( + execute_response, _ = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) self.assertEqual(execute_response.lz4_compressed, lz4Compressed) @@ -878,11 +885,12 @@ def test_handle_execute_response_can_handle_without_direct_results( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - results_message_response = thrift_backend._handle_execute_response( + execute_response, _ = 
thrift_backend._handle_execute_response( execute_resp, Mock() ) + self.assertEqual( - results_message_response.status, + execute_response.status, CommandState.SUCCEEDED, ) @@ -915,7 +923,9 @@ def test_handle_execute_response_can_handle_with_direct_results(self): auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._results_message_to_execute_response = Mock() + thrift_backend._results_message_to_execute_response = Mock( + return_value=(Mock(), Mock()) + ) thrift_backend._handle_execute_response(execute_resp, Mock()) @@ -943,15 +953,21 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): operationHandle=self.operation_handle, ) + # Mock the operation status response + op_state = ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = ( t_get_result_set_metadata_resp ) thrift_backend = self._make_fake_thrift_backend() - execute_response = thrift_backend._handle_execute_response( + execute_response, arrow_schema_bytes = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) - self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock) + self.assertEqual(arrow_schema_bytes, arrow_schema_mock) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): @@ -971,6 +987,12 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): operationHandle=self.operation_handle, ) + # Mock the operation status response + op_state = ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = hive_schema_req thrift_backend = self._make_fake_thrift_backend() thrift_backend._handle_execute_response(t_execute_resp, Mock()) @@ -1018,7 +1040,7 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( ) thrift_backend = self._make_fake_thrift_backend() - execute_response = thrift_backend._handle_execute_response( + execute_response, _ = thrift_backend._handle_execute_response( execute_resp, Mock() ) @@ -1150,7 +1172,7 @@ def test_execute_statement_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.execute_command( @@ -1184,7 +1206,7 @@ def test_get_catalogs_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) @@ -1215,7 +1237,7 @@ def test_get_schemas_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_schemas( @@ -1255,7 +1277,7 @@ def test_get_tables_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), 
ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_tables( @@ -1299,7 +1321,7 @@ def test_get_columns_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_columns( @@ -1645,7 +1667,9 @@ def test_handle_execute_response_sets_active_op_handle(self): thrift_backend = self._make_fake_thrift_backend() thrift_backend._check_direct_results_for_error = Mock() thrift_backend._wait_until_command_done = Mock() - thrift_backend._results_message_to_execute_response = Mock() + thrift_backend._results_message_to_execute_response = Mock( + return_value=(Mock(), Mock()) + ) # Create a mock response with a real operation handle mock_resp = Mock() @@ -2204,7 +2228,8 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) @patch( - "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response" + "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response", + return_value=(Mock(), Mock()), ) def test_execute_command_sets_complex_type_fields_correctly( self, mock_handle_execute_response, tcli_service_class From 3e3ab94e8fa3dd02e4b05b5fc35939aef57793a2 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:31:37 +0000 Subject: [PATCH 02/33] remove excess test Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 124 +++----------------- 1 file changed, 14 insertions(+), 110 deletions(-) diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index 87b62efea..abe6bd1ab 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -6,122 +6,34 @@ logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) - -def test_sea_query_exec(): - """ - Test executing a query using the SEA backend with result compression. - - This function connects to a Databricks SQL endpoint using the SEA backend, - executes a simple query with result compression enabled and disabled, - and verifies that execution completes successfully. - """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
- ) - sys.exit(1) - - try: - # Test with compression enabled - logger.info("Creating connection with LZ4 compression enabled") - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - use_cloud_fetch=True, # Enable cloud fetch to use compression - enable_query_result_lz4_compression=True, # Enable LZ4 compression - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - logger.info(f"backend type: {type(connection.session.backend)}") - - # Execute a simple query with compression enabled - cursor = connection.cursor(arraysize=0, buffer_size_bytes=0) - logger.info("Executing query with LZ4 compression: SELECT 1 as test_value") - cursor.execute("SELECT 1 as test_value") - logger.info("Query with compression executed successfully") - cursor.close() - connection.close() - logger.info("Successfully closed SEA session with compression enabled") - - # Test with compression disabled - logger.info("Creating connection with LZ4 compression disabled") - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - use_cloud_fetch=False, # Enable cloud fetch - enable_query_result_lz4_compression=False, # Disable LZ4 compression - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - - # Execute a simple query with compression disabled - cursor = connection.cursor(arraysize=0, buffer_size_bytes=0) - logger.info("Executing query without compression: SELECT 1 as test_value") - cursor.execute("SELECT 1 as test_value") - logger.info("Query without compression executed successfully") - cursor.close() - connection.close() - logger.info("Successfully closed SEA session with compression disabled") - - except Exception as e: - logger.error(f"Error during SEA query execution test: {str(e)}") - import traceback - - logger.error(traceback.format_exc()) - sys.exit(1) - - logger.info("SEA query execution test with compression completed successfully") - - def test_sea_session(): """ Test opening and closing a SEA session using the connector. - + This function connects to a Databricks SQL endpoint using the SEA backend, opens a session, and then closes it. - + Required environment variables: - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint - DATABRICKS_TOKEN: Personal access token for authentication """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") http_path = os.environ.get("DATABRICKS_HTTP_PATH") access_token = os.environ.get("DATABRICKS_TOKEN") catalog = os.environ.get("DATABRICKS_CATALOG") - + if not all([server_hostname, http_path, access_token]): logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
- ) + logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.") sys.exit(1) - + logger.info(f"Connecting to {server_hostname}") logger.info(f"HTTP Path: {http_path}") if catalog: logger.info(f"Using catalog: {catalog}") - + try: logger.info("Creating connection with SEA backend...") connection = Connection( @@ -130,33 +42,25 @@ def test_sea_session(): access_token=access_token, catalog=catalog, schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", # add custom user agent - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + use_sea=True, + user_agent_entry="SEA-Test-Client" # add custom user agent ) + + logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}") logger.info(f"backend type: {type(connection.session.backend)}") - + # Close the connection logger.info("Closing the SEA session...") connection.close() logger.info("Successfully closed SEA session") - + except Exception as e: logger.error(f"Error testing SEA session: {str(e)}") import traceback - logger.error(traceback.format_exc()) sys.exit(1) - + logger.info("SEA session test completed successfully") - if __name__ == "__main__": - # Test session management test_sea_session() - - # Test query execution with compression - test_sea_query_exec() From 4a781653375d8f06dd7d9ad745446e49a355c680 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:33:02 +0000 Subject: [PATCH 03/33] add docstring Signed-off-by: varun-edachali-dbx --- .../sql/backend/databricks_client.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index bbca4c502..cd347d9ab 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -86,6 +86,33 @@ def execute_command( async_op: bool, enforce_embedded_schema_correctness: bool, ) -> Union["ResultSet", None]: + """ + Executes a SQL command or query within the specified session. + This method sends a SQL command to the server for execution and handles + the response. It can operate in both synchronous and asynchronous modes. + + Args: + operation: The SQL command or query to execute + session_id: The session identifier in which to execute the command + max_rows: Maximum number of rows to fetch in a single fetch batch + max_bytes: Maximum number of bytes to fetch in a single fetch batch + lz4_compression: Whether to use LZ4 compression for result data + cursor: The cursor object that will handle the results + use_cloud_fetch: Whether to use cloud fetch for retrieving large result sets + parameters: List of parameters to bind to the query + async_op: Whether to execute the command asynchronously + enforce_embedded_schema_correctness: Whether to enforce schema correctness + + Returns: + If async_op is False, returns a ResultSet object containing the + query results and metadata. If async_op is True, returns None and the + results must be fetched later using get_execution_result(). 
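A minimal usage sketch of the two modes this docstring describes, assuming a
concrete client implementing this interface plus an open session_id and cursor
(all literal values below are illustrative, not prescribed defaults):

    # Synchronous mode: blocks until the statement completes, returns a ResultSet.
    result_set = client.execute_command(
        operation="SELECT 1",
        session_id=session_id,
        max_rows=10000,
        max_bytes=10 * 1024 * 1024,
        lz4_compression=False,
        cursor=cursor,
        use_cloud_fetch=False,
        parameters=[],
        async_op=False,
        enforce_embedded_schema_correctness=False,
    )

    # Asynchronous mode: returns None immediately; the command id is recorded
    # on the cursor, and results are fetched once the command finishes.
    client.execute_command(
        operation="SELECT 1",
        session_id=session_id,
        max_rows=10000,
        max_bytes=10 * 1024 * 1024,
        lz4_compression=False,
        cursor=cursor,
        use_cloud_fetch=False,
        parameters=[],
        async_op=True,
        enforce_embedded_schema_correctness=False,
    )
    result_set = client.get_execution_result(cursor.active_command_id, cursor)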
+ + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error executing the command + ServerOperationError: If the server encounters an error during execution + """ pass @abstractmethod From 0dac4aaf90dba50151dd7565adee270a794e8330 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:34:49 +0000 Subject: [PATCH 04/33] remvoe exec func in sea backend Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 360 +++------------------- 1 file changed, 35 insertions(+), 325 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index c7a4ed1b1..97d25a058 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,34 +1,23 @@ import logging import re -import uuid -import time -from typing import Dict, Set, Tuple, List, Optional, Any, Union, TYPE_CHECKING - -from databricks.sql.backend.sea.utils.constants import ( - ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, -) +from typing import Dict, Tuple, List, Optional, TYPE_CHECKING, Set if TYPE_CHECKING: from databricks.sql.client import Cursor - from databricks.sql.result_set import ResultSet from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType -from databricks.sql.exc import Error, NotSupportedError, ServerOperationError +from databricks.sql.exc import ServerOperationError from databricks.sql.backend.sea.utils.http_client import SeaHttpClient +from databricks.sql.backend.sea.utils.constants import ( + ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, +) from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions from databricks.sql.backend.sea.models import ( - ExecuteStatementRequest, - GetStatementRequest, - CancelStatementRequest, - CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, - StatementParameter, - ExecuteStatementResponse, - GetStatementResponse, CreateSessionResponse, ) @@ -66,9 +55,6 @@ def _filter_session_configuration( class SeaDatabricksClient(DatabricksClient): """ Statement Execution API (SEA) implementation of the DatabricksClient interface. - - This implementation provides session management functionality for SEA, - while other operations raise NotImplementedError. """ # SEA API paths @@ -288,222 +274,41 @@ def execute_command( lz4_compression: bool, cursor: "Cursor", use_cloud_fetch: bool, - parameters: List, + parameters: List[ttypes.TSparkParameter], async_op: bool, enforce_embedded_schema_correctness: bool, - ) -> Union["ResultSet", None]: - """ - Execute a SQL command using the SEA backend. 
- - Args: - operation: SQL command to execute - session_id: Session identifier - max_rows: Maximum number of rows to fetch - max_bytes: Maximum number of bytes to fetch - lz4_compression: Whether to use LZ4 compression - cursor: Cursor executing the command - use_cloud_fetch: Whether to use cloud fetch - parameters: SQL parameters - async_op: Whether to execute asynchronously - enforce_embedded_schema_correctness: Whether to enforce schema correctness - - Returns: - ResultSet: A SeaResultSet instance for the executed command - """ - if session_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA session ID") - - sea_session_id = session_id.to_sea_session_id() - - # Convert parameters to StatementParameter objects - sea_parameters = [] - if parameters: - for param in parameters: - sea_parameters.append( - StatementParameter( - name=param.name, - value=param.value, - type=param.type if hasattr(param, "type") else None, - ) - ) - - format = "ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY" - disposition = "EXTERNAL_LINKS" if use_cloud_fetch else "INLINE" - result_compression = "LZ4_FRAME" if lz4_compression else "NONE" - - request = ExecuteStatementRequest( - warehouse_id=self.warehouse_id, - session_id=sea_session_id, - statement=operation, - disposition=disposition, - format=format, - wait_timeout="0s" if async_op else "10s", - on_wait_timeout="CONTINUE", - row_limit=max_rows if max_rows > 0 else None, - byte_limit=max_bytes if max_bytes > 0 else None, - parameters=sea_parameters if sea_parameters else None, - result_compression=result_compression, - ) - - response_data = self.http_client._make_request( - method="POST", path=self.STATEMENT_PATH, data=request.to_dict() + ): + """Not implemented yet.""" + raise NotImplementedError( + "execute_command is not yet implemented for SEA backend" ) - response = ExecuteStatementResponse.from_dict(response_data) - statement_id = response.statement_id - if not statement_id: - raise ServerOperationError( - "Failed to execute command: No statement ID returned", - { - "operation-id": None, - "diagnostic-info": None, - }, - ) - - command_id = CommandId.from_sea_statement_id(statement_id) - - # Store the command ID in the cursor - cursor.active_command_id = command_id - - # If async operation, return None and let the client poll for results - if async_op: - return None - - # For synchronous operation, wait for the statement to complete - # Poll until the statement is done - status = response.status - state = status.state - - # Keep polling until we reach a terminal state - while state in [CommandState.PENDING, CommandState.RUNNING]: - time.sleep(0.5) # add a small delay to avoid excessive API calls - state = self.get_query_state(command_id) - - if state != CommandState.SUCCEEDED: - raise ServerOperationError( - f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}", - { - "operation-id": command_id.to_sea_statement_id(), - "diagnostic-info": None, - }, - ) - - return self.get_execution_result(command_id, cursor) def cancel_command(self, command_id: CommandId) -> None: - """ - Cancel a running command. 
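The request-shaping rules in the execute_command body removed above reduce to
a small pure function; this sketch restates them standalone, with the field
names and values exactly as in the deleted code:

    def sea_statement_options(use_cloud_fetch: bool, lz4_compression: bool, async_op: bool) -> dict:
        """Mirror of the deleted inline logic that picked SEA request fields."""
        return {
            # Cloud fetch switches both the wire format and where results live.
            "format": "ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY",
            "disposition": "EXTERNAL_LINKS" if use_cloud_fetch else "INLINE",
            "result_compression": "LZ4_FRAME" if lz4_compression else "NONE",
            # Async submissions return immediately; sync ones wait up to 10s
            # server-side before the client falls back to polling get_query_state.
            "wait_timeout": "0s" if async_op else "10s",
            "on_wait_timeout": "CONTINUE",
        }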
- - Args: - command_id: Command identifier to cancel - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - request = CancelStatementRequest(statement_id=sea_statement_id) - self.http_client._make_request( - method="POST", - path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), + """Not implemented yet.""" + raise NotImplementedError( + "cancel_command is not yet implemented for SEA backend" ) def close_command(self, command_id: CommandId) -> None: - """ - Close a command and release resources. - - Args: - command_id: Command identifier to close - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - request = CloseStatementRequest(statement_id=sea_statement_id) - self.http_client._make_request( - method="DELETE", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), + """Not implemented yet.""" + raise NotImplementedError( + "close_command is not yet implemented for SEA backend" ) def get_query_state(self, command_id: CommandId) -> CommandState: - """ - Get the state of a running query. - - Args: - command_id: Command identifier - - Returns: - CommandState: The current state of the command - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - request = GetStatementRequest(statement_id=sea_statement_id) - response_data = self.http_client._make_request( - method="GET", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), + """Not implemented yet.""" + raise NotImplementedError( + "get_query_state is not yet implemented for SEA backend" ) - # Parse the response - response = GetStatementResponse.from_dict(response_data) - return response.status.state - def get_execution_result( self, command_id: CommandId, cursor: "Cursor", - ) -> "ResultSet": - """ - Get the result of a command execution. 
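Each command-lifecycle method removed above wrapped exactly one HTTP call;
condensed here as comments, with the request classes and path constants as
referenced in the deleted code:

    # cancel_command  -> POST   CANCEL_STATEMENT_PATH_WITH_ID.format(statement_id)
    #                    body: CancelStatementRequest(statement_id).to_dict()
    # close_command   -> DELETE STATEMENT_PATH_WITH_ID.format(statement_id)
    #                    body: CloseStatementRequest(statement_id).to_dict()
    # get_query_state -> GET    STATEMENT_PATH_WITH_ID.format(statement_id)
    #                    parsed via GetStatementResponse.from_dict(...).status.state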
- - Args: - command_id: Command identifier - cursor: Cursor executing the command - - Returns: - ResultSet: A SeaResultSet instance with the execution results - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - # Create the request model - request = GetStatementRequest(statement_id=sea_statement_id) - - # Get the statement result - response_data = self.http_client._make_request( - method="GET", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), - ) - - # Create and return a SeaResultSet - from databricks.sql.result_set import SeaResultSet - - return SeaResultSet( - connection=cursor.connection, - sea_response=response_data, - sea_client=self, - buffer_size_bytes=cursor.buffer_size_bytes, - arraysize=cursor.arraysize, + ): + """Not implemented yet.""" + raise NotImplementedError( + "get_execution_result is not yet implemented for SEA backend" ) # == Metadata Operations == @@ -514,22 +319,9 @@ def get_catalogs( max_rows: int, max_bytes: int, cursor: "Cursor", - ) -> "ResultSet": - """Get available catalogs by executing 'SHOW CATALOGS'.""" - result = self.execute_command( - operation="SHOW CATALOGS", - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - return result + ): + """Not implemented yet.""" + raise NotImplementedError("get_catalogs is not yet implemented for SEA backend") def get_schemas( self, @@ -539,30 +331,9 @@ def get_schemas( cursor: "Cursor", catalog_name: Optional[str] = None, schema_name: Optional[str] = None, - ) -> "ResultSet": - """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_schemas") - - operation = f"SHOW SCHEMAS IN `{catalog_name}`" - - if schema_name: - operation += f" LIKE '{schema_name}'" - - result = self.execute_command( - operation=operation, - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - return result + ): + """Not implemented yet.""" + raise NotImplementedError("get_schemas is not yet implemented for SEA backend") def get_tables( self, @@ -574,43 +345,9 @@ def get_tables( schema_name: Optional[str] = None, table_name: Optional[str] = None, table_types: Optional[List[str]] = None, - ) -> "ResultSet": - """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_tables") - - operation = "SHOW TABLES IN " + ( - "ALL CATALOGS" - if catalog_name in [None, "*", "%"] - else f"CATALOG `{catalog_name}`" - ) - - if schema_name: - operation += f" SCHEMA LIKE '{schema_name}'" - - if table_name: - operation += f" LIKE '{table_name}'" - - result = self.execute_command( - operation=operation, - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - 
enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - - # Apply client-side filtering by table_types if specified - from databricks.sql.backend.filters import ResultSetFilter - - result = ResultSetFilter.filter_tables_by_type(result, table_types) - - return result + ): + """Not implemented yet.""" + raise NotImplementedError("get_tables is not yet implemented for SEA backend") def get_columns( self, @@ -622,33 +359,6 @@ def get_columns( schema_name: Optional[str] = None, table_name: Optional[str] = None, column_name: Optional[str] = None, - ) -> "ResultSet": - """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_columns") - - operation = f"SHOW COLUMNS IN CATALOG `{catalog_name}`" - - if schema_name: - operation += f" SCHEMA LIKE '{schema_name}'" - - if table_name: - operation += f" TABLE LIKE '{table_name}'" - - if column_name: - operation += f" LIKE '{column_name}'" - - result = self.execute_command( - operation=operation, - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - return result + ): + """Not implemented yet.""" + raise NotImplementedError("get_columns is not yet implemented for SEA backend") From 1b794c7df6f5e414ef793a5da0f2b8ba19c9bc61 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:35:40 +0000 Subject: [PATCH 05/33] remove excess files Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/filters.py | 143 -------------- tests/unit/test_result_set_filter.py | 246 ----------------------- tests/unit/test_sea_result_set.py | 275 -------------------------- 3 files changed, 664 deletions(-) delete mode 100644 src/databricks/sql/backend/filters.py delete mode 100644 tests/unit/test_result_set_filter.py delete mode 100644 tests/unit/test_sea_result_set.py diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py deleted file mode 100644 index 7f48b6179..000000000 --- a/src/databricks/sql/backend/filters.py +++ /dev/null @@ -1,143 +0,0 @@ -""" -Client-side filtering utilities for Databricks SQL connector. - -This module provides filtering capabilities for result sets returned by different backends. -""" - -import logging -from typing import ( - List, - Optional, - Any, - Callable, - TYPE_CHECKING, -) - -if TYPE_CHECKING: - from databricks.sql.result_set import ResultSet - -from databricks.sql.result_set import SeaResultSet - -logger = logging.getLogger(__name__) - - -class ResultSetFilter: - """ - A general-purpose filter for result sets that can be applied to any backend. - - This class provides methods to filter result sets based on various criteria, - similar to the client-side filtering in the JDBC connector. - """ - - @staticmethod - def _filter_sea_result_set( - result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool] - ) -> "SeaResultSet": - """ - Filter a SEA result set using the provided filter function. 
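The metadata helpers deleted above all routed through execute_command by
building SHOW statements; a condensed sketch of the get_tables string-building,
with the wildcard handling exactly as in the removed code:

    def show_tables_operation(catalog_name, schema_name=None, table_name=None) -> str:
        # "*" and "%" (and None) select all catalogs, per the deleted get_tables.
        operation = "SHOW TABLES IN " + (
            "ALL CATALOGS"
            if catalog_name in (None, "*", "%")
            else f"CATALOG `{catalog_name}`"
        )
        if schema_name:
            operation += f" SCHEMA LIKE '{schema_name}'"
        if table_name:
            operation += f" LIKE '{table_name}'"
        return operation

    # show_tables_operation("main", "sales%")
    #   -> "SHOW TABLES IN CATALOG `main` SCHEMA LIKE 'sales%'"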
- - Args: - result_set: The SEA result set to filter - filter_func: Function that takes a row and returns True if the row should be included - - Returns: - A filtered SEA result set - """ - # Create a filtered version of the result set - filtered_response = result_set._response.copy() - - # If there's a result with rows, filter them - if ( - "result" in filtered_response - and "data_array" in filtered_response["result"] - ): - rows = filtered_response["result"]["data_array"] - filtered_rows = [row for row in rows if filter_func(row)] - filtered_response["result"]["data_array"] = filtered_rows - - # Update row count if present - if "row_count" in filtered_response["result"]: - filtered_response["result"]["row_count"] = len(filtered_rows) - - # Create a new result set with the filtered data - return SeaResultSet( - connection=result_set.connection, - sea_response=filtered_response, - sea_client=result_set.backend, - buffer_size_bytes=result_set.buffer_size_bytes, - arraysize=result_set.arraysize, - ) - - @staticmethod - def filter_by_column_values( - result_set: "ResultSet", - column_index: int, - allowed_values: List[str], - case_sensitive: bool = False, - ) -> "ResultSet": - """ - Filter a result set by values in a specific column. - - Args: - result_set: The result set to filter - column_index: The index of the column to filter on - allowed_values: List of allowed values for the column - case_sensitive: Whether to perform case-sensitive comparison - - Returns: - A filtered result set - """ - # Convert to uppercase for case-insensitive comparison if needed - if not case_sensitive: - allowed_values = [v.upper() for v in allowed_values] - - # Determine the type of result set and apply appropriate filtering - if isinstance(result_set, SeaResultSet): - return ResultSetFilter._filter_sea_result_set( - result_set, - lambda row: ( - len(row) > column_index - and isinstance(row[column_index], str) - and ( - row[column_index].upper() - if not case_sensitive - else row[column_index] - ) - in allowed_values - ), - ) - - # For other result set types, return the original (should be handled by specific implementations) - logger.warning( - f"Filtering not implemented for result set type: {type(result_set).__name__}" - ) - return result_set - - @staticmethod - def filter_tables_by_type( - result_set: "ResultSet", table_types: Optional[List[str]] = None - ) -> "ResultSet": - """ - Filter a result set of tables by the specified table types. - - This is a client-side filter that processes the result set after it has been - retrieved from the server. It filters out tables whose type does not match - any of the types in the table_types list. - - Args: - result_set: The original result set containing tables - table_types: List of table types to include (e.g., ["TABLE", "VIEW"]) - - Returns: - A filtered result set containing only tables of the specified types - """ - # Default table types if none specified - DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"] - valid_types = ( - table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES - ) - - # Table type is typically in the 6th column (index 5) - return ResultSetFilter.filter_by_column_values( - result_set, 5, valid_types, case_sensitive=False - ) diff --git a/tests/unit/test_result_set_filter.py b/tests/unit/test_result_set_filter.py deleted file mode 100644 index e8eb2a757..000000000 --- a/tests/unit/test_result_set_filter.py +++ /dev/null @@ -1,246 +0,0 @@ -""" -Tests for the ResultSetFilter class. 
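The predicate at the heart of the deleted ResultSetFilter is small enough to
restate standalone; this sketch assumes rows shaped like SHOW TABLES output,
with the table type at column index 5 as the deleted filter_tables_by_type
expected:

    DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"]

    def keep_table_row(row, table_types=None, case_sensitive=False):
        """Return True if the row's table type is in the allowed set."""
        allowed = table_types or DEFAULT_TABLE_TYPES
        if not case_sensitive:
            allowed = [t.upper() for t in allowed]
        if len(row) <= 5 or not isinstance(row[5], str):
            return False
        value = row[5] if case_sensitive else row[5].upper()
        return value in allowed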
- -This module contains tests for the ResultSetFilter class, which provides -filtering capabilities for result sets returned by different backends. -""" - -import pytest -from unittest.mock import patch, MagicMock, Mock - -from databricks.sql.backend.filters import ResultSetFilter -from databricks.sql.result_set import SeaResultSet -from databricks.sql.backend.types import CommandId, CommandState, BackendType - - -class TestResultSetFilter: - """Test suite for the ResultSetFilter class.""" - - @pytest.fixture - def mock_connection(self): - """Create a mock connection.""" - connection = Mock() - connection.open = True - return connection - - @pytest.fixture - def mock_sea_client(self): - """Create a mock SEA client.""" - return Mock() - - @pytest.fixture - def sea_response_with_tables(self): - """Create a sample SEA response with table data based on the server schema.""" - return { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "format": "JSON_ARRAY", - "schema": { - "column_count": 7, - "columns": [ - {"name": "namespace", "type_text": "STRING", "position": 0}, - {"name": "tableName", "type_text": "STRING", "position": 1}, - {"name": "isTemporary", "type_text": "BOOLEAN", "position": 2}, - {"name": "information", "type_text": "STRING", "position": 3}, - {"name": "catalogName", "type_text": "STRING", "position": 4}, - {"name": "tableType", "type_text": "STRING", "position": 5}, - {"name": "remarks", "type_text": "STRING", "position": 6}, - ], - }, - "total_row_count": 4, - }, - "result": { - "data_array": [ - ["schema1", "table1", False, None, "catalog1", "TABLE", None], - ["schema1", "table2", False, None, "catalog1", "VIEW", None], - [ - "schema1", - "table3", - False, - None, - "catalog1", - "SYSTEM TABLE", - None, - ], - ["schema1", "table4", False, None, "catalog1", "EXTERNAL", None], - ] - }, - } - - @pytest.fixture - def sea_result_set_with_tables( - self, mock_connection, mock_sea_client, sea_response_with_tables - ): - """Create a SeaResultSet with table data.""" - # Create a deep copy of the response to avoid test interference - import copy - - sea_response_copy = copy.deepcopy(sea_response_with_tables) - - return SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response_copy, - buffer_size_bytes=1000, - arraysize=100, - ) - - def test_filter_tables_by_type_default(self, sea_result_set_with_tables): - """Test filtering tables by type with default types.""" - # Default types are TABLE, VIEW, SYSTEM TABLE - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables - ) - - # Verify that only the default types are included - assert len(filtered_result._response["result"]["data_array"]) == 3 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - assert "SYSTEM TABLE" in table_types - assert "EXTERNAL" not in table_types - - def test_filter_tables_by_type_custom(self, sea_result_set_with_tables): - """Test filtering tables by type with custom types.""" - # Filter for only TABLE and EXTERNAL - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables, table_types=["TABLE", "EXTERNAL"] - ) - - # Verify that only the specified types are included - assert len(filtered_result._response["result"]["data_array"]) == 2 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert 
"EXTERNAL" in table_types - assert "VIEW" not in table_types - assert "SYSTEM TABLE" not in table_types - - def test_filter_tables_by_type_case_insensitive(self, sea_result_set_with_tables): - """Test that table type filtering is case-insensitive.""" - # Filter for lowercase "table" and "view" - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables, table_types=["table", "view"] - ) - - # Verify that the matching types are included despite case differences - assert len(filtered_result._response["result"]["data_array"]) == 2 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - assert "SYSTEM TABLE" not in table_types - assert "EXTERNAL" not in table_types - - def test_filter_tables_by_type_empty_list(self, sea_result_set_with_tables): - """Test filtering tables with an empty type list (should use defaults).""" - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables, table_types=[] - ) - - # Verify that default types are used - assert len(filtered_result._response["result"]["data_array"]) == 3 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - assert "SYSTEM TABLE" in table_types - assert "EXTERNAL" not in table_types - - def test_filter_by_column_values(self, sea_result_set_with_tables): - """Test filtering by values in a specific column.""" - # Filter by namespace in column index 0 - filtered_result = ResultSetFilter.filter_by_column_values( - sea_result_set_with_tables, column_index=0, allowed_values=["schema1"] - ) - - # All rows have schema1 in namespace, so all should be included - assert len(filtered_result._response["result"]["data_array"]) == 4 - - # Filter by table name in column index 1 - filtered_result = ResultSetFilter.filter_by_column_values( - sea_result_set_with_tables, - column_index=1, - allowed_values=["table1", "table3"], - ) - - # Only rows with table1 or table3 should be included - assert len(filtered_result._response["result"]["data_array"]) == 2 - table_names = [ - row[1] for row in filtered_result._response["result"]["data_array"] - ] - assert "table1" in table_names - assert "table3" in table_names - assert "table2" not in table_names - assert "table4" not in table_names - - def test_filter_by_column_values_case_sensitive( - self, mock_connection, mock_sea_client, sea_response_with_tables - ): - """Test case-sensitive filtering by column values.""" - import copy - - # Create a fresh result set for the first test - sea_response_copy1 = copy.deepcopy(sea_response_with_tables) - result_set1 = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response_copy1, - buffer_size_bytes=1000, - arraysize=100, - ) - - # First test: Case-sensitive filtering with lowercase values (should find no matches) - filtered_result = ResultSetFilter.filter_by_column_values( - result_set1, - column_index=5, # tableType column - allowed_values=["table", "view"], # lowercase - case_sensitive=True, - ) - - # Verify no matches with lowercase values - assert len(filtered_result._response["result"]["data_array"]) == 0 - - # Create a fresh result set for the second test - sea_response_copy2 = copy.deepcopy(sea_response_with_tables) - result_set2 = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response_copy2, - buffer_size_bytes=1000, - 
arraysize=100, - ) - - # Second test: Case-sensitive filtering with correct case (should find matches) - filtered_result = ResultSetFilter.filter_by_column_values( - result_set2, - column_index=5, # tableType column - allowed_values=["TABLE", "VIEW"], # correct case - case_sensitive=True, - ) - - # Verify matches with correct case - assert len(filtered_result._response["result"]["data_array"]) == 2 - - # Extract the table types from the filtered results - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - - def test_filter_by_column_values_out_of_bounds(self, sea_result_set_with_tables): - """Test filtering with a column index that's out of bounds.""" - # Filter by column index 10 (out of bounds) - filtered_result = ResultSetFilter.filter_by_column_values( - sea_result_set_with_tables, column_index=10, allowed_values=["value"] - ) - - # No rows should match - assert len(filtered_result._response["result"]["data_array"]) == 0 diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py deleted file mode 100644 index f666fd613..000000000 --- a/tests/unit/test_sea_result_set.py +++ /dev/null @@ -1,275 +0,0 @@ -""" -Tests for the SeaResultSet class. - -This module contains tests for the SeaResultSet class, which implements -the result set functionality for the SEA (Statement Execution API) backend. -""" - -import pytest -from unittest.mock import patch, MagicMock, Mock - -from databricks.sql.result_set import SeaResultSet -from databricks.sql.backend.types import CommandId, CommandState, BackendType - - -class TestSeaResultSet: - """Test suite for the SeaResultSet class.""" - - @pytest.fixture - def mock_connection(self): - """Create a mock connection.""" - connection = Mock() - connection.open = True - return connection - - @pytest.fixture - def mock_sea_client(self): - """Create a mock SEA client.""" - return Mock() - - @pytest.fixture - def sea_response(self): - """Create a sample SEA response.""" - return { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "format": "JSON_ARRAY", - "schema": { - "column_count": 1, - "columns": [ - { - "name": "test_value", - "type_text": "INT", - "type_name": "INT", - "position": 0, - } - ], - }, - "total_chunk_count": 1, - "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], - "total_row_count": 1, - "truncated": False, - }, - "result": { - "chunk_index": 0, - "row_offset": 0, - "row_count": 1, - "data_array": [["1"]], - }, - } - - @pytest.fixture - def execute_response(self): - """Create a sample execute response.""" - mock_response = Mock() - mock_response.command_id = CommandId.from_sea_statement_id("test-statement-123") - mock_response.status = CommandState.SUCCEEDED - mock_response.has_been_closed_server_side = False - mock_response.has_more_rows = False - mock_response.results_queue = None - mock_response.description = [ - ("test_value", "INT", None, None, None, None, None) - ] - mock_response.is_staging_operation = False - mock_response.sea_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "result": {"data_array": [["1"]]}, - } - return mock_response - - def test_init_with_sea_response( - self, mock_connection, mock_sea_client, sea_response - ): - """Test initializing SeaResultSet with a SEA response.""" - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response, - 
buffer_size_bytes=1000, - arraysize=100, - ) - - # Verify basic properties - assert result_set.statement_id == "test-statement-123" - assert result_set.status == CommandState.SUCCEEDED - assert result_set.command_id.guid == "test-statement-123" - assert result_set.command_id.backend_type == BackendType.SEA - assert result_set.connection == mock_connection - assert result_set.backend == mock_sea_client - assert result_set.buffer_size_bytes == 1000 - assert result_set.arraysize == 100 - assert result_set._response == sea_response - - def test_init_with_execute_response( - self, mock_connection, mock_sea_client, execute_response - ): - """Test initializing SeaResultSet with an execute response.""" - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - execute_response=execute_response, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Verify basic properties - assert result_set.statement_id == "test-statement-123" - assert result_set.status == CommandState.SUCCEEDED - assert result_set.command_id.guid == "test-statement-123" - assert result_set.command_id.backend_type == BackendType.SEA - assert result_set.connection == mock_connection - assert result_set.backend == mock_sea_client - assert result_set.buffer_size_bytes == 1000 - assert result_set.arraysize == 100 - assert result_set._response == execute_response.sea_response - - def test_init_with_no_response(self, mock_connection, mock_sea_client): - """Test that initialization fails when neither response type is provided.""" - with pytest.raises(ValueError) as excinfo: - SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - assert "Either execute_response or sea_response must be provided" in str( - excinfo.value - ) - - def test_close(self, mock_connection, mock_sea_client, sea_response): - """Test closing a result set.""" - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Close the result set - result_set.close() - - # Verify the backend's close_command was called - mock_sea_client.close_command.assert_called_once_with(result_set.command_id) - assert result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_close_when_already_closed_server_side( - self, mock_connection, mock_sea_client, sea_response - ): - """Test closing a result set that has already been closed server-side.""" - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response, - buffer_size_bytes=1000, - arraysize=100, - ) - result_set.has_been_closed_server_side = True - - # Close the result set - result_set.close() - - # Verify the backend's close_command was NOT called - mock_sea_client.close_command.assert_not_called() - assert result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_close_when_connection_closed( - self, mock_connection, mock_sea_client, sea_response - ): - """Test closing a result set when the connection is closed.""" - mock_connection.open = False - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Close the result set - result_set.close() - - # Verify the backend's close_command was NOT called - mock_sea_client.close_command.assert_not_called() - assert 
result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_unimplemented_methods( - self, mock_connection, mock_sea_client, sea_response - ): - """Test that unimplemented methods raise NotImplementedError.""" - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Test each unimplemented method individually with specific error messages - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - result_set.fetchone() - - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - result_set.fetchmany(10) - - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - # Test with default parameter value - result_set.fetchmany() - - with pytest.raises( - NotImplementedError, match="fetchall is not implemented for SEA backend" - ): - result_set.fetchall() - - with pytest.raises( - NotImplementedError, - match="fetchmany_arrow is not implemented for SEA backend", - ): - result_set.fetchmany_arrow(10) - - with pytest.raises( - NotImplementedError, - match="fetchall_arrow is not implemented for SEA backend", - ): - result_set.fetchall_arrow() - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - # Test iteration protocol (calls fetchone internally) - next(iter(result_set)) - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - # Test using the result set in a for loop - for row in result_set: - pass - - def test_fill_results_buffer_not_implemented( - self, mock_connection, mock_sea_client, sea_response - ): - """Test that _fill_results_buffer raises NotImplementedError.""" - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response, - buffer_size_bytes=1000, - arraysize=100, - ) - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - result_set._fill_results_buffer() From da5a6fe7511e927c511d61adb222b8a6a0da14d3 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:39:11 +0000 Subject: [PATCH 06/33] remove excess models Signed-off-by: varun-edachali-dbx --- .../sql/backend/sea/models/__init__.py | 30 ----- src/databricks/sql/backend/sea/models/base.py | 68 ----------- .../sql/backend/sea/models/requests.py | 110 +----------------- .../sql/backend/sea/models/responses.py | 95 +-------------- 4 files changed, 4 insertions(+), 299 deletions(-) delete mode 100644 src/databricks/sql/backend/sea/models/base.py diff --git a/src/databricks/sql/backend/sea/models/__init__.py b/src/databricks/sql/backend/sea/models/__init__.py index b7c8bd399..c9310d367 100644 --- a/src/databricks/sql/backend/sea/models/__init__.py +++ b/src/databricks/sql/backend/sea/models/__init__.py @@ -4,49 +4,19 @@ This package contains data models for SEA API requests and responses. 
""" -from databricks.sql.backend.sea.models.base import ( - ServiceError, - StatementStatus, - ExternalLink, - ResultData, - ColumnInfo, - ResultManifest, -) - from databricks.sql.backend.sea.models.requests import ( - StatementParameter, - ExecuteStatementRequest, - GetStatementRequest, - CancelStatementRequest, - CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, ) from databricks.sql.backend.sea.models.responses import ( - ExecuteStatementResponse, - GetStatementResponse, CreateSessionResponse, ) __all__ = [ - # Base models - "ServiceError", - "StatementStatus", - "ExternalLink", - "ResultData", - "ColumnInfo", - "ResultManifest", # Request models - "StatementParameter", - "ExecuteStatementRequest", - "GetStatementRequest", - "CancelStatementRequest", - "CloseStatementRequest", "CreateSessionRequest", "DeleteSessionRequest", # Response models - "ExecuteStatementResponse", - "GetStatementResponse", "CreateSessionResponse", ] diff --git a/src/databricks/sql/backend/sea/models/base.py b/src/databricks/sql/backend/sea/models/base.py deleted file mode 100644 index 671f7be13..000000000 --- a/src/databricks/sql/backend/sea/models/base.py +++ /dev/null @@ -1,68 +0,0 @@ -""" -Base models for the SEA (Statement Execution API) backend. - -These models define the common structures used in SEA API requests and responses. -""" - -from typing import Dict, List, Any, Optional, Union -from dataclasses import dataclass, field - -from databricks.sql.backend.types import CommandState - - -@dataclass -class ServiceError: - """Error information returned by the SEA API.""" - - message: str - error_code: Optional[str] = None - - -@dataclass -class StatementStatus: - """Status information for a statement execution.""" - - state: CommandState - error: Optional[ServiceError] = None - sql_state: Optional[str] = None - - -@dataclass -class ExternalLink: - """External link information for result data.""" - - external_link: str - expiration: str - chunk_index: int - - -@dataclass -class ResultData: - """Result data from a statement execution.""" - - data: Optional[List[List[Any]]] = None - external_links: Optional[List[ExternalLink]] = None - - -@dataclass -class ColumnInfo: - """Information about a column in the result set.""" - - name: str - type_name: str - type_text: str - nullable: bool = True - precision: Optional[int] = None - scale: Optional[int] = None - ordinal_position: Optional[int] = None - - -@dataclass -class ResultManifest: - """Manifest information for a result set.""" - - schema: List[ColumnInfo] - total_row_count: int - total_byte_count: int - truncated: bool = False - chunk_count: Optional[int] = None diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index 1c519d931..7966cb502 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -1,111 +1,5 @@ -""" -Request models for the SEA (Statement Execution API) backend. - -These models define the structures used in SEA API requests. 
-""" - -from typing import Dict, List, Any, Optional, Union -from dataclasses import dataclass, field - - -@dataclass -class StatementParameter: - """Parameter for a SQL statement.""" - - name: str - value: Optional[str] = None - type: Optional[str] = None - - -@dataclass -class ExecuteStatementRequest: - """Request to execute a SQL statement.""" - - warehouse_id: str - statement: str - session_id: str - disposition: str = "EXTERNAL_LINKS" - format: str = "JSON_ARRAY" - wait_timeout: str = "10s" - on_wait_timeout: str = "CONTINUE" - row_limit: Optional[int] = None - byte_limit: Optional[int] = None - parameters: Optional[List[StatementParameter]] = None - catalog: Optional[str] = None - schema: Optional[str] = None - result_compression: Optional[str] = None - - def to_dict(self) -> Dict[str, Any]: - """Convert the request to a dictionary for JSON serialization.""" - result: Dict[str, Any] = { - "warehouse_id": self.warehouse_id, - "session_id": self.session_id, - "statement": self.statement, - "disposition": self.disposition, - "format": self.format, - "wait_timeout": self.wait_timeout, - "on_wait_timeout": self.on_wait_timeout, - } - - if self.row_limit is not None and self.row_limit > 0: - result["row_limit"] = self.row_limit - - if self.byte_limit is not None and self.byte_limit > 0: - result["byte_limit"] = self.byte_limit - - if self.catalog: - result["catalog"] = self.catalog - - if self.schema: - result["schema"] = self.schema - - if self.result_compression: - result["result_compression"] = self.result_compression - - if self.parameters: - result["parameters"] = [ - { - "name": param.name, - **({"value": param.value} if param.value is not None else {}), - **({"type": param.type} if param.type is not None else {}), - } - for param in self.parameters - ] - - return result - - -@dataclass -class GetStatementRequest: - """Request to get information about a statement.""" - - statement_id: str - - def to_dict(self) -> Dict[str, Any]: - """Convert the request to a dictionary for JSON serialization.""" - return {"statement_id": self.statement_id} - - -@dataclass -class CancelStatementRequest: - """Request to cancel a statement.""" - - statement_id: str - - def to_dict(self) -> Dict[str, Any]: - """Convert the request to a dictionary for JSON serialization.""" - return {"statement_id": self.statement_id} - - -@dataclass -class CloseStatementRequest: - """Request to close a statement.""" - - statement_id: str - - def to_dict(self) -> Dict[str, Any]: - """Convert the request to a dictionary for JSON serialization.""" - return {"statement_id": self.statement_id} +from typing import Dict, Any, Optional +from dataclasses import dataclass @dataclass diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index d70459b9f..1bb54590f 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -1,96 +1,5 @@ -""" -Response models for the SEA (Statement Execution API) backend. - -These models define the structures used in SEA API responses. 
-""" - -from typing import Dict, List, Any, Optional, Union -from dataclasses import dataclass, field - -from databricks.sql.backend.types import CommandState -from databricks.sql.backend.sea.models.base import ( - StatementStatus, - ResultManifest, - ResultData, - ServiceError, -) - - -@dataclass -class ExecuteStatementResponse: - """Response from executing a SQL statement.""" - - statement_id: str - status: StatementStatus - manifest: Optional[ResultManifest] = None - result: Optional[ResultData] = None - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": - """Create an ExecuteStatementResponse from a dictionary.""" - status_data = data.get("status", {}) - error = None - if "error" in status_data: - error_data = status_data["error"] - error = ServiceError( - message=error_data.get("message", ""), - error_code=error_data.get("error_code"), - ) - - state = CommandState.from_sea_state(status_data.get("state", "")) - if state is None: - raise ValueError(f"Invalid state: {status_data.get('state', '')}") - status = StatementStatus( - state=state, - error=error, - sql_state=status_data.get("sql_state"), - ) - - return cls( - statement_id=data.get("statement_id", ""), - status=status, - manifest=data.get("manifest"), # We'll parse this more fully if needed - result=data.get("result"), # We'll parse this more fully if needed - ) - - -@dataclass -class GetStatementResponse: - """Response from getting information about a statement.""" - - statement_id: str - status: StatementStatus - manifest: Optional[ResultManifest] = None - result: Optional[ResultData] = None - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": - """Create a GetStatementResponse from a dictionary.""" - status_data = data.get("status", {}) - error = None - if "error" in status_data: - error_data = status_data["error"] - error = ServiceError( - message=error_data.get("message", ""), - error_code=error_data.get("error_code"), - ) - - state = CommandState.from_sea_state(status_data.get("state", "")) - if state is None: - raise ValueError(f"Invalid state: {status_data.get('state', '')}") - - status = StatementStatus( - state=state, - error=error, - sql_state=status_data.get("sql_state"), - ) - - return cls( - statement_id=data.get("statement_id", ""), - status=status, - manifest=data.get("manifest"), # We'll parse this more fully if needed - result=data.get("result"), # We'll parse this more fully if needed - ) +from typing import Dict, Any +from dataclasses import dataclass @dataclass From 686ade4fbf8e43a053b61f27220066852682167e Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:40:50 +0000 Subject: [PATCH 07/33] remove excess sea backend tests Signed-off-by: varun-edachali-dbx --- tests/unit/test_sea_backend.py | 755 ++++----------------------------- 1 file changed, 94 insertions(+), 661 deletions(-) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index c2078f731..bc2688a68 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -1,20 +1,11 @@ -""" -Tests for the SEA (Statement Execution API) backend implementation. - -This module contains tests for the SeaDatabricksClient class, which implements -the Databricks SQL connector's SEA backend functionality. 
-""" - -import json import pytest -from unittest.mock import patch, MagicMock, Mock +from unittest.mock import patch, MagicMock from databricks.sql.backend.sea.backend import SeaDatabricksClient -from databricks.sql.result_set import SeaResultSet -from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType +from databricks.sql.backend.types import SessionId, BackendType from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.exc import Error, NotSupportedError +from databricks.sql.exc import Error class TestSeaBackend: @@ -50,23 +41,6 @@ def sea_client(self, mock_http_client): return client - @pytest.fixture - def sea_session_id(self): - """Create a SEA session ID.""" - return SessionId.from_sea_session_id("test-session-123") - - @pytest.fixture - def sea_command_id(self): - """Create a SEA command ID.""" - return CommandId.from_sea_statement_id("test-statement-123") - - @pytest.fixture - def mock_cursor(self): - """Create a mock cursor.""" - cursor = Mock() - cursor.active_command_id = None - return cursor - def test_init_extracts_warehouse_id(self, mock_http_client): """Test that the constructor properly extracts the warehouse ID from the HTTP path.""" # Test with warehouses format @@ -201,650 +175,109 @@ def test_close_session_invalid_id_type(self, sea_client): assert "Not a valid SEA session ID" in str(excinfo.value) - # Tests for command execution and management - - def test_execute_command_sync( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command synchronously.""" - # Set up mock responses - execute_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "schema": [ - { - "name": "col1", - "type_name": "STRING", - "type_text": "string", - "nullable": True, - } - ], - "total_row_count": 1, - "total_byte_count": 100, - }, - "result": {"data": [["value1"]]}, - } - mock_http_client._make_request.return_value = execute_response - - # Mock the get_execution_result method - with patch.object( - sea_client, "get_execution_result", return_value="mock_result_set" - ) as mock_get_result: - # Call the method - result = sea_client.execute_command( - operation="SELECT 1", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == sea_client.STATEMENT_PATH - assert "warehouse_id" in kwargs["data"] - assert "session_id" in kwargs["data"] - assert "statement" in kwargs["data"] - assert kwargs["data"]["statement"] == "SELECT 1" - - # Verify get_execution_result was called with the right command ID - mock_get_result.assert_called_once() - cmd_id_arg = mock_get_result.call_args[0][0] - assert isinstance(cmd_id_arg, CommandId) - assert cmd_id_arg.guid == "test-statement-123" - - def test_execute_command_async( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command asynchronously.""" - # Set up mock response - execute_response = { - "statement_id": "test-statement-456", - "status": {"state": "PENDING"}, - } - 
mock_http_client._make_request.return_value = execute_response - - # Call the method - result = sea_client.execute_command( - operation="SELECT 1", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=True, # Async mode - enforce_embedded_schema_correctness=False, + def test_session_configuration_helpers(self): + """Test the session configuration helper methods.""" + # Test getting default value for a supported parameter + default_value = SeaDatabricksClient.get_default_session_configuration_value( + "ANSI_MODE" ) + assert default_value == "true" - # Verify the result is None for async operation - assert result is None - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == sea_client.STATEMENT_PATH - assert "wait_timeout" in kwargs["data"] - assert kwargs["data"]["wait_timeout"] == "0s" # Async mode uses 0s timeout - - # Verify the command ID was stored in the cursor - assert hasattr(mock_cursor, "active_command_id") - assert isinstance(mock_cursor.active_command_id, CommandId) - assert mock_cursor.active_command_id.guid == "test-statement-456" - - def test_execute_command_with_polling( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command that requires polling.""" - # Set up mock responses for initial request and polling - initial_response = { - "statement_id": "test-statement-789", - "status": {"state": "RUNNING"}, - } - poll_response = { - "statement_id": "test-statement-789", - "status": {"state": "SUCCEEDED"}, - "manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0}, - "result": {"data": []}, - } - - # Configure mock to return different responses on subsequent calls - mock_http_client._make_request.side_effect = [initial_response, poll_response] - - # Mock the get_execution_result method - with patch.object( - sea_client, "get_execution_result", return_value="mock_result_set" - ) as mock_get_result: - # Mock time.sleep to avoid actual delays - with patch("time.sleep"): - # Call the method - result = sea_client.execute_command( - operation="SELECT * FROM large_table", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify the HTTP requests (initial and poll) - assert mock_http_client._make_request.call_count == 2 - - # Verify get_execution_result was called with the right command ID - mock_get_result.assert_called_once() - cmd_id_arg = mock_get_result.call_args[0][0] - assert isinstance(cmd_id_arg, CommandId) - assert cmd_id_arg.guid == "test-statement-789" - - def test_execute_command_with_parameters( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command with parameters.""" - # Set up mock response - execute_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, + # Test getting default value for an unsupported parameter + default_value = SeaDatabricksClient.get_default_session_configuration_value( + "UNSUPPORTED_PARAM" + ) + assert default_value is None + + # Test getting the list of allowed configurations + allowed_configs = 
SeaDatabricksClient.get_allowed_session_configurations() + + expected_keys = { + "ANSI_MODE", + "ENABLE_PHOTON", + "LEGACY_TIME_PARSER_POLICY", + "MAX_FILE_PARTITION_BYTES", + "READ_ONLY_EXTERNAL_METASTORE", + "STATEMENT_TIMEOUT", + "TIMEZONE", + "USE_CACHED_RESULT", } - mock_http_client._make_request.return_value = execute_response + assert set(allowed_configs) == expected_keys - # Create parameter mock - param = MagicMock() - param.name = "param1" - param.value = "value1" - param.type = "STRING" + def test_unimplemented_methods(self, sea_client): + """Test that unimplemented methods raise NotImplementedError.""" + # Create dummy parameters for testing + session_id = SessionId.from_sea_session_id("test-session") + command_id = MagicMock() + cursor = MagicMock() - # Mock the get_execution_result method - with patch.object(sea_client, "get_execution_result") as mock_get_result: - # Call the method with parameters + # Test execute_command + with pytest.raises(NotImplementedError) as excinfo: sea_client.execute_command( - operation="SELECT * FROM table WHERE col = :param1", - session_id=sea_session_id, + operation="SELECT 1", + session_id=session_id, max_rows=100, max_bytes=1000, lz4_compression=False, - cursor=mock_cursor, + cursor=cursor, use_cloud_fetch=False, - parameters=[param], + parameters=[], async_op=False, enforce_embedded_schema_correctness=False, ) - - # Verify the HTTP request contains parameters - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert "parameters" in kwargs["data"] - assert len(kwargs["data"]["parameters"]) == 1 - assert kwargs["data"]["parameters"][0]["name"] == "param1" - assert kwargs["data"]["parameters"][0]["value"] == "value1" - assert kwargs["data"]["parameters"][0]["type"] == "STRING" - - def test_execute_command_failure( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command that fails.""" - # Set up mock response for a failed execution - error_response = { - "statement_id": "test-statement-123", - "status": { - "state": "FAILED", - "error": { - "message": "Syntax error in SQL", - "error_code": "SYNTAX_ERROR", - }, - }, - } - - # Configure the mock to return the error response for the initial request - # and then raise an exception when trying to poll (to simulate immediate failure) - mock_http_client._make_request.side_effect = [ - error_response, # Initial response - Error( - "Statement execution did not succeed: Syntax error in SQL" - ), # Will be raised during polling - ] - - # Mock time.sleep to avoid actual delays - with patch("time.sleep"): - # Call the method and expect an error - with pytest.raises(Error) as excinfo: - sea_client.execute_command( - operation="SELECT * FROM nonexistent_table", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - assert "Statement execution did not succeed" in str(excinfo.value) - - def test_cancel_command(self, sea_client, mock_http_client, sea_command_id): - """Test canceling a command.""" - # Set up mock response - mock_http_client._make_request.return_value = {} - - # Call the method - sea_client.cancel_command(sea_command_id) - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == 
sea_client.CANCEL_STATEMENT_PATH_WITH_ID.format( - "test-statement-123" - ) - - def test_close_command(self, sea_client, mock_http_client, sea_command_id): - """Test closing a command.""" - # Set up mock response - mock_http_client._make_request.return_value = {} - - # Call the method - sea_client.close_command(sea_command_id) - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "DELETE" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" - ) - - def test_get_query_state(self, sea_client, mock_http_client, sea_command_id): - """Test getting the state of a query.""" - # Set up mock response - mock_http_client._make_request.return_value = { - "statement_id": "test-statement-123", - "status": {"state": "RUNNING"}, - } - - # Call the method - state = sea_client.get_query_state(sea_command_id) - - # Verify the result - assert state == CommandState.RUNNING - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "GET" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" - ) - - def test_get_execution_result( - self, sea_client, mock_http_client, mock_cursor, sea_command_id - ): - """Test getting the result of a command execution.""" - # Set up mock response - sea_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "format": "JSON_ARRAY", - "schema": { - "column_count": 1, - "columns": [ - { - "name": "test_value", - "type_text": "INT", - "type_name": "INT", - "position": 0, - } - ], - }, - "total_chunk_count": 1, - "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], - "total_row_count": 1, - "truncated": False, - }, - "result": { - "chunk_index": 0, - "row_offset": 0, - "row_count": 1, - "data_array": [["1"]], - }, - } - mock_http_client._make_request.return_value = sea_response - - # Create a real result set to verify the implementation - result = sea_client.get_execution_result(sea_command_id, mock_cursor) - print(result) - - # Verify basic properties of the result - assert result.statement_id == "test-statement-123" - assert result.status == CommandState.SUCCEEDED - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "GET" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" + assert "execute_command is not yet implemented" in str(excinfo.value) + + # Test cancel_command + with pytest.raises(NotImplementedError) as excinfo: + sea_client.cancel_command(command_id) + assert "cancel_command is not yet implemented" in str(excinfo.value) + + # Test close_command + with pytest.raises(NotImplementedError) as excinfo: + sea_client.close_command(command_id) + assert "close_command is not yet implemented" in str(excinfo.value) + + # Test get_query_state + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_query_state(command_id) + assert "get_query_state is not yet implemented" in str(excinfo.value) + + # Test get_execution_result + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_execution_result(command_id, cursor) + assert "get_execution_result is not yet implemented" in str(excinfo.value) + + # Test metadata operations + with 
pytest.raises(NotImplementedError) as excinfo: + sea_client.get_catalogs(session_id, 100, 1000, cursor) + assert "get_catalogs is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_schemas(session_id, 100, 1000, cursor) + assert "get_schemas is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_tables(session_id, 100, 1000, cursor) + assert "get_tables is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_columns(session_id, 100, 1000, cursor) + assert "get_columns is not yet implemented" in str(excinfo.value) + + def test_max_download_threads_property(self, sea_client): + """Test the max_download_threads property.""" + assert sea_client.max_download_threads == 10 + + # Create a client with a custom value + custom_client = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/warehouses/abc123", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + max_download_threads=20, ) - # Tests for metadata operations - - def test_get_catalogs(self, sea_client, mock_cursor, sea_session_id): - """Test getting catalogs.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method - result = sea_client.get_catalogs( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW CATALOGS" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["lz4_compression"] is False - assert kwargs["use_cloud_fetch"] is False - assert kwargs["parameters"] == [] - assert kwargs["async_op"] is False - assert kwargs["enforce_embedded_schema_correctness"] is False - - def test_get_schemas_with_catalog_only( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting schemas with only catalog name.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with only catalog - result = sea_client.get_schemas( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW SCHEMAS IN `test_catalog`" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_schemas_with_catalog_and_schema_pattern( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting schemas with catalog and schema pattern.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", 
return_value="mock_result_set" - ) as mock_execute: - # Call the method with catalog and schema pattern - result = sea_client.get_schemas( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - schema_name="test_schema", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert ( - kwargs["operation"] - == "SHOW SCHEMAS IN `test_catalog` LIKE 'test_schema'" - ) - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_schemas_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): - """Test getting schemas without a catalog name raises an error.""" - with pytest.raises(ValueError) as excinfo: - sea_client.get_schemas( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name=None, # No catalog name - schema_name="test_schema", - ) - - assert "Catalog name is required for get_schemas" in str(excinfo.value) - - def test_get_tables_with_catalog_only( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting tables with only catalog name.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with only catalog - result = sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW TABLES IN CATALOG `test_catalog`" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_tables_with_all_catalogs( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting tables from all catalogs using wildcard.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with wildcard catalog - result = sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="*", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW TABLES IN ALL CATALOGS" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_tables_with_schema_and_table_patterns( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting tables with schema and table patterns.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) 
as mock_execute: - # Call the method with catalog, schema, and table patterns - result = sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - schema_name="test_schema", - table_name="test_table", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert ( - kwargs["operation"] - == "SHOW TABLES IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' LIKE 'test_table'" - ) - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_tables_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): - """Test getting tables without a catalog name raises an error.""" - with pytest.raises(ValueError) as excinfo: - sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name=None, # No catalog name - schema_name="test_schema", - table_name="test_table", - ) - - assert "Catalog name is required for get_tables" in str(excinfo.value) - - def test_get_columns_with_catalog_only( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting columns with only catalog name.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with only catalog - result = sea_client.get_columns( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW COLUMNS IN CATALOG `test_catalog`" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_columns_with_all_patterns( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting columns with all patterns specified.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with all patterns - result = sea_client.get_columns( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - schema_name="test_schema", - table_name="test_table", - column_name="test_column", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert ( - kwargs["operation"] - == "SHOW COLUMNS IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' TABLE LIKE 'test_table' LIKE 'test_column'" - ) - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_columns_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): - 
"""Test getting columns without a catalog name raises an error.""" - with pytest.raises(ValueError) as excinfo: - sea_client.get_columns( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name=None, # No catalog name - schema_name="test_schema", - table_name="test_table", - ) - - assert "Catalog name is required for get_columns" in str(excinfo.value) + # Verify the custom value is returned + assert custom_client.max_download_threads == 20 From 31e6c8305154e9c6384b422be35ac17b6f851e0c Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:54:05 +0000 Subject: [PATCH 08/33] cleanup Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 8 +- src/databricks/sql/backend/types.py | 38 ++++---- src/databricks/sql/result_set.py | 91 ++++++++------------ 3 files changed, 65 insertions(+), 72 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index e03d6f235..21a6befbe 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -913,7 +913,10 @@ def get_query_state(self, command_id: CommandId) -> CommandState: poll_resp = self._poll_for_status(thrift_handle) operation_state = poll_resp.operationState self._check_command_not_in_error_or_closed_state(thrift_handle, poll_resp) - return CommandState.from_thrift_state(operation_state) + state = CommandState.from_thrift_state(operation_state) + if state is None: + raise ValueError(f"Invalid operation state: {operation_state}") + return state @staticmethod def _check_direct_results_for_error(t_spark_direct_results): @@ -1175,7 +1178,6 @@ def _handle_execute_response(self, resp, cursor): execute_response, arrow_schema_bytes, ) = self._results_message_to_execute_response(resp, final_operation_state) - execute_response.command_id = command_id return execute_response, arrow_schema_bytes def _handle_execute_response_async(self, resp, cursor): @@ -1237,7 +1239,7 @@ def cancel_command(self, command_id: CommandId) -> None: if not thrift_handle: raise ValueError("Not a valid Thrift command ID") - logger.debug("Cancelling command {}".format(command_id.guid)) + logger.debug("Cancelling command {}".format(guid_to_hex_id(command_id.guid))) req = ttypes.TCancelOperationReq(thrift_handle) self.make_request(self._client.CancelOperation, req) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 5bf02e0ea..3107083fb 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -285,9 +285,6 @@ def __init__( backend_type: BackendType, guid: Any, secret: Optional[Any] = None, - operation_type: Optional[int] = None, - has_result_set: bool = False, - modified_row_count: Optional[int] = None, ): """ Initialize a CommandId. @@ -296,17 +293,34 @@ def __init__( backend_type: The type of backend (THRIFT or SEA) guid: The primary identifier for the command secret: The secret part of the identifier (only used for Thrift) - operation_type: The operation type (only used for Thrift) - has_result_set: Whether the command has a result set - modified_row_count: The number of rows modified by the command """ self.backend_type = backend_type self.guid = guid self.secret = secret - self.operation_type = operation_type - self.has_result_set = has_result_set - self.modified_row_count = modified_row_count + + + def __str__(self) -> str: + """ + Return a string representation of the CommandId. 
+ + For SEA backend, returns the guid. + For Thrift backend, returns a format like "guid|secret". + + Returns: + A string representation of the command ID + """ + + if self.backend_type == BackendType.SEA: + return str(self.guid) + elif self.backend_type == BackendType.THRIFT: + secret_hex = ( + guid_to_hex_id(self.secret) + if isinstance(self.secret, bytes) + else str(self.secret) + ) + return f"{self.to_hex_guid()}|{secret_hex}" + return str(self.guid) @classmethod def from_thrift_handle(cls, operation_handle): @@ -329,9 +343,6 @@ def from_thrift_handle(cls, operation_handle): BackendType.THRIFT, guid_bytes, secret_bytes, - operation_handle.operationType, - operation_handle.hasResultSet, - operation_handle.modifiedRowCount, ) @classmethod @@ -364,9 +375,6 @@ def to_thrift_handle(self): handle_identifier = ttypes.THandleIdentifier(guid=self.guid, secret=self.secret) return ttypes.TOperationHandle( operationId=handle_identifier, - operationType=self.operation_type, - hasResultSet=self.has_result_set, - modifiedRowCount=self.modified_row_count, ) def to_sea_statement_id(self): diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index fc8595839..12ee129cf 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -5,6 +5,8 @@ import time import pandas +from databricks.sql.backend.sea.backend import SeaDatabricksClient + try: import pyarrow except ImportError: @@ -13,6 +15,7 @@ if TYPE_CHECKING: from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.client import Connection +from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import Row from databricks.sql.exc import Error, RequestError, CursorAlreadyClosedError @@ -31,21 +34,37 @@ class ResultSet(ABC): def __init__( self, - connection, - backend, + connection: "Connection", + backend: "DatabricksClient", arraysize: int, buffer_size_bytes: int, - command_id=None, - status=None, + command_id: CommandId, + status: CommandState, has_been_closed_server_side: bool = False, has_more_rows: bool = False, results_queue=None, description=None, is_staging_operation: bool = False, ): - """Initialize the base ResultSet with common properties.""" + """ + A ResultSet manages the results of a single command. 
+ + Args: + connection: The parent connection + backend: The backend client + arraysize: The max number of rows to fetch at a time (PEP-249) + buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch + command_id: The command ID + status: The command status + has_been_closed_server_side: Whether the command has been closed on the server + has_more_rows: Whether the command has more rows + results_queue: The results queue + description: column description of the results + is_staging_operation: Whether the command is a staging operation + """ + self.connection = connection - self.backend = backend # Store the backend client directly + self.backend = backend self.arraysize = arraysize self.buffer_size_bytes = buffer_size_bytes self._next_row_index = 0 @@ -240,7 +259,7 @@ def _convert_arrow_table(self, table): res = df.to_numpy(na_value=None, dtype="object") return [ResultRow(*v) for v in res] - def merge_columnar(self, result1, result2): + def merge_columnar(self, result1, result2) -> "ColumnTable": """ Function to merge / combining the columnar results into a single result :param result1: @@ -387,12 +406,11 @@ class SeaResultSet(ResultSet): def __init__( self, - connection, - sea_client, + connection: "Connection", + execute_response: "ExecuteResponse", + sea_client: "SeaDatabricksClient", buffer_size_bytes: int = 104857600, arraysize: int = 10000, - execute_response=None, - sea_response=None, ): """ Initialize a SeaResultSet with the response from a SEA query execution. @@ -402,56 +420,21 @@ def __init__( sea_client: The SeaDatabricksClient instance for direct access buffer_size_bytes: Buffer size for fetching results arraysize: Default number of rows to fetch - execute_response: Response from the execute command (new style) - sea_response: Direct SEA response (legacy style) + execute_response: Response from the execute command """ - # Handle both initialization styles - if execute_response is not None: - # New style with ExecuteResponse - command_id = execute_response.command_id - status = execute_response.status - has_been_closed_server_side = execute_response.has_been_closed_server_side - has_more_rows = execute_response.has_more_rows - results_queue = execute_response.results_queue - description = execute_response.description - is_staging_operation = execute_response.is_staging_operation - self._response = getattr(execute_response, "sea_response", {}) - self.statement_id = command_id.to_sea_statement_id() if command_id else None - elif sea_response is not None: - # Legacy style with direct sea_response - self._response = sea_response - # Extract values from sea_response - command_id = CommandId.from_sea_statement_id( - sea_response.get("statement_id", "") - ) - self.statement_id = sea_response.get("statement_id", "") - - # Extract status - status_data = sea_response.get("status", {}) - status = CommandState.from_sea_state(status_data.get("state", "PENDING")) - - # Set defaults for other fields - has_been_closed_server_side = False - has_more_rows = False - results_queue = None - description = None - is_staging_operation = False - else: - raise ValueError("Either execute_response or sea_response must be provided") - # Call parent constructor with common attributes super().__init__( connection=connection, backend=sea_client, arraysize=arraysize, buffer_size_bytes=buffer_size_bytes, - command_id=command_id, - status=status, - has_been_closed_server_side=has_been_closed_server_side, - has_more_rows=has_more_rows, - results_queue=results_queue, - description=description, - 
is_staging_operation=is_staging_operation, + command_id=execute_response.command_id, + status=execute_response.status, + has_been_closed_server_side=execute_response.has_been_closed_server_side, + has_more_rows=execute_response.has_more_rows, + results_queue=execute_response.results_queue, + description=execute_response.description, + is_staging_operation=execute_response.is_staging_operation, ) def _fill_results_buffer(self): From 69ea23811e03705998baba569bcda259a0646de5 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:56:09 +0000 Subject: [PATCH 09/33] re-introduce get_schema_desc Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/types.py | 1 - src/databricks/sql/result_set.py | 21 +++++++++++++++++++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 3107083fb..7a276c102 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -299,7 +299,6 @@ def __init__( self.guid = guid self.secret = secret - def __str__(self) -> str: """ Return a string representation of the CommandId. diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 12ee129cf..1fee995e5 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -59,12 +59,12 @@ def __init__( has_been_closed_server_side: Whether the command has been closed on the server has_more_rows: Whether the command has more rows results_queue: The results queue - description: column description of the results + description: column description of the results is_staging_operation: Whether the command is a staging operation """ self.connection = connection - self.backend = backend + self.backend = backend self.arraysize = arraysize self.buffer_size_bytes = buffer_size_bytes self._next_row_index = 0 @@ -400,6 +400,23 @@ def fetchmany(self, size: int) -> List[Row]: else: return self._convert_arrow_table(self.fetchmany_arrow(size)) + @staticmethod + def _get_schema_description(table_schema_message): + """ + Takes a TableSchema message and returns a description 7-tuple as specified by PEP-249 + """ + + def map_col_type(type_): + if type_.startswith("decimal"): + return "decimal" + else: + return type_ + + return [ + (column.name, map_col_type(column.datatype), None, None, None, None, None) + for column in table_schema_message.columns + ] + class SeaResultSet(ResultSet): """ResultSet implementation for SEA backend.""" From 66d75171991f9fcc98d541729a3127aea0d37a81 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:57:53 +0000 Subject: [PATCH 10/33] remove SeaResultSet Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 72 -------------------------------- 1 file changed, 72 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 1fee995e5..eaabcc186 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -416,75 +416,3 @@ def map_col_type(type_): (column.name, map_col_type(column.datatype), None, None, None, None, None) for column in table_schema_message.columns ] - - -class SeaResultSet(ResultSet): - """ResultSet implementation for SEA backend.""" - - def __init__( - self, - connection: "Connection", - execute_response: "ExecuteResponse", - sea_client: "SeaDatabricksClient", - buffer_size_bytes: int = 104857600, - arraysize: int = 10000, - ): - """ - Initialize a SeaResultSet with the response from a SEA query 
execution. - - Args: - connection: The parent connection - sea_client: The SeaDatabricksClient instance for direct access - buffer_size_bytes: Buffer size for fetching results - arraysize: Default number of rows to fetch - execute_response: Response from the execute command - """ - - super().__init__( - connection=connection, - backend=sea_client, - arraysize=arraysize, - buffer_size_bytes=buffer_size_bytes, - command_id=execute_response.command_id, - status=execute_response.status, - has_been_closed_server_side=execute_response.has_been_closed_server_side, - has_more_rows=execute_response.has_more_rows, - results_queue=execute_response.results_queue, - description=execute_response.description, - is_staging_operation=execute_response.is_staging_operation, - ) - - def _fill_results_buffer(self): - """Fill the results buffer from the backend.""" - raise NotImplementedError("fetchone is not implemented for SEA backend") - - def fetchone(self) -> Optional[Row]: - """ - Fetch the next row of a query result set, returning a single sequence, - or None when no more data is available. - """ - - raise NotImplementedError("fetchone is not implemented for SEA backend") - - def fetchmany(self, size: Optional[int] = None) -> List[Row]: - """ - Fetch the next set of rows of a query result, returning a list of rows. - - An empty sequence is returned when no more rows are available. - """ - - raise NotImplementedError("fetchmany is not implemented for SEA backend") - - def fetchall(self) -> List[Row]: - """ - Fetch all (remaining) rows of a query result, returning them as a list of rows. - """ - raise NotImplementedError("fetchall is not implemented for SEA backend") - - def fetchmany_arrow(self, size: int) -> Any: - """Fetch the next set of rows as an Arrow table.""" - raise NotImplementedError("fetchmany_arrow is not implemented for SEA backend") - - def fetchall_arrow(self) -> Any: - """Fetch all remaining rows as an Arrow table.""" - raise NotImplementedError("fetchall_arrow is not implemented for SEA backend") From 71feef96b3a41889a5cd9313fc81910cebd7a084 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 07:01:22 +0000 Subject: [PATCH 11/33] clean imports and attributes Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/databricks_client.py | 1 + src/databricks/sql/backend/sea/utils/http_client.py | 2 +- src/databricks/sql/result_set.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index cd347d9ab..8fda71e1e 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -88,6 +88,7 @@ def execute_command( ) -> Union["ResultSet", None]: """ Executes a SQL command or query within the specified session. + This method sends a SQL command to the server for execution and handles the response. It can operate in both synchronous and asynchronous modes. 
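
For reference, a minimal sketch of how a caller can drive the asynchronous path of execute_command that the docstring above describes: submit with async_op=True, poll get_query_state until the command leaves a non-terminal state, then materialize the results with get_execution_result. The wrapper function name, the 0.5s poll cadence, and the literal fetch limits are illustrative assumptions, not part of the interface:

import time
from databricks.sql.backend.types import CommandState

def run_async_command(backend, cursor, session_id, operation):
    # async_op=True submits the statement and returns immediately;
    # the backend stores the new CommandId on the cursor
    backend.execute_command(
        operation=operation,
        session_id=session_id,
        max_rows=100,
        max_bytes=1000,
        lz4_compression=False,
        cursor=cursor,
        use_cloud_fetch=False,
        parameters=[],
        async_op=True,
        enforce_embedded_schema_correctness=False,
    )
    command_id = cursor.active_command_id

    # poll until the command leaves the PENDING/RUNNING states
    while backend.get_query_state(command_id) in (
        CommandState.PENDING,
        CommandState.RUNNING,
    ):
        time.sleep(0.5)  # illustrative cadence, mirroring the SEA polling loop

    # fetch the ResultSet once execution has finished
    return backend.get_execution_result(command_id, cursor)
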
diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index f0b931ee4..fe292919c 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -1,7 +1,7 @@ import json import logging import requests -from typing import Callable, Dict, Any, Optional, Union, List, Tuple +from typing import Callable, Dict, Any, Optional, List, Tuple from urllib.parse import urljoin from databricks.sql.auth.authenticators import AuthProvider diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index eaabcc186..a33fc977d 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -72,7 +72,7 @@ def __init__( self.command_id = command_id self.status = status self.has_been_closed_server_side = has_been_closed_server_side - self._has_more_rows = has_more_rows + self.has_more_rows = has_more_rows self.results = results_queue self._is_staging_operation = is_staging_operation From ae9862f90e7cf0a4949d6b1c7e04fdbae222c2d8 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 07:05:53 +0000 Subject: [PATCH 12/33] pass CommandId to ExecResp Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 7 ++++++- src/databricks/sql/result_set.py | 10 +++++----- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 21a6befbe..316cf24a0 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -352,6 +352,7 @@ def make_request(self, method, request, retryable=True): Will stop retry attempts if total elapsed time + next retry delay would exceed _retry_stop_after_attempts_duration. """ + # basic strategy: build range iterator rep'ing number of available # retries. bounds can be computed from there. iterate over it with # retries until success or final failure achieved. 
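
The "basic strategy" comment above is the core shape of make_request's retry loop. A minimal standalone sketch of that shape, assuming a fixed attempt count and a constant delay (the real code computes its bounds from the retry policy and stops early when elapsed time plus the next delay would exceed _retry_stop_after_attempts_duration):

import time

def bounded_retry(do_request, max_attempts=5, delay_s=1.0):
    last_error = None
    # range() fixes the number of available attempts up front; each pass
    # either returns on success or records the failure and continues
    for attempt in range(max_attempts):
        try:
            return do_request()
        except Exception as err:  # the real loop inspects retryability here
            last_error = err
            if attempt + 1 < max_attempts:
                time.sleep(delay_s)
    # final failure: attempts exhausted
    raise last_error
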
@@ -866,9 +867,13 @@ def get_execution_result( ssl_options=self._ssl_options, ) + status = CommandState.from_thrift_state(resp.status) + if status is None: + raise ValueError(f"Invalid operation state: {resp.status}") + execute_response = ExecuteResponse( command_id=command_id, - status=resp.status, + status=status, description=description, has_more_rows=has_more_rows, results_queue=queue, diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index a33fc977d..a0cb73732 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -207,7 +207,7 @@ def _fill_results_buffer(self): use_cloud_fetch=self._use_cloud_fetch, ) self.results = results - self._has_more_rows = has_more_rows + self.has_more_rows = has_more_rows def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] @@ -291,7 +291,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self._has_more_rows + and self.has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -316,7 +316,7 @@ def fetchmany_columnar(self, size: int): while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self._has_more_rows + and self.has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -331,7 +331,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self._has_more_rows: + while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() if isinstance(results, ColumnTable) and isinstance( @@ -357,7 +357,7 @@ def fetchall_columnar(self): results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self._has_more_rows: + while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() results = self.merge_columnar(results, partial_results) From d8aa69e40438c33014e0d5afaec6a4175e64bea8 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 07:08:04 +0000 Subject: [PATCH 13/33] remove changes in types Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/types.py | 57 +++++++++-------------------- 1 file changed, 17 insertions(+), 40 deletions(-) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 7a276c102..9cd21b5e6 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -1,6 +1,5 @@ -from dataclasses import dataclass from enum import Enum -from typing import Dict, List, Optional, Any, Tuple +from typing import Dict, Optional, Any import logging from databricks.sql.backend.utils import guid_to_hex_id @@ -81,28 +80,6 @@ def from_thrift_state( else: return None - @classmethod - def from_sea_state(cls, state: str) -> Optional["CommandState"]: - """ - Map SEA state string to CommandState enum. 
- - Args: - state: SEA state string - - Returns: - CommandState: The corresponding CommandState enum value - """ - state_mapping = { - "PENDING": cls.PENDING, - "RUNNING": cls.RUNNING, - "SUCCEEDED": cls.SUCCEEDED, - "FAILED": cls.FAILED, - "CLOSED": cls.CLOSED, - "CANCELED": cls.CANCELLED, - } - - return state_mapping.get(state, None) - class BackendType(Enum): """ @@ -285,6 +262,9 @@ def __init__( backend_type: BackendType, guid: Any, secret: Optional[Any] = None, + operation_type: Optional[int] = None, + has_result_set: bool = False, + modified_row_count: Optional[int] = None, ): """ Initialize a CommandId. @@ -293,11 +273,17 @@ def __init__( backend_type: The type of backend (THRIFT or SEA) guid: The primary identifier for the command secret: The secret part of the identifier (only used for Thrift) + operation_type: The operation type (only used for Thrift) + has_result_set: Whether the command has a result set + modified_row_count: The number of rows modified by the command """ self.backend_type = backend_type self.guid = guid self.secret = secret + self.operation_type = operation_type + self.has_result_set = has_result_set + self.modified_row_count = modified_row_count def __str__(self) -> str: """ @@ -332,6 +318,7 @@ def from_thrift_handle(cls, operation_handle): Returns: A CommandId instance """ + if operation_handle is None: return None @@ -342,6 +329,9 @@ def from_thrift_handle(cls, operation_handle): BackendType.THRIFT, guid_bytes, secret_bytes, + operation_handle.operationType, + operation_handle.hasResultSet, + operation_handle.modifiedRowCount, ) @classmethod @@ -374,6 +364,9 @@ def to_thrift_handle(self): handle_identifier = ttypes.THandleIdentifier(guid=self.guid, secret=self.secret) return ttypes.TOperationHandle( operationId=handle_identifier, + operationType=self.operation_type, + hasResultSet=self.has_result_set, + modifiedRowCount=self.modified_row_count, ) def to_sea_statement_id(self): @@ -401,19 +394,3 @@ def to_hex_guid(self) -> str: return guid_to_hex_id(self.guid) else: return str(self.guid) - - -@dataclass -class ExecuteResponse: - """Response from executing a SQL command.""" - - command_id: CommandId - status: CommandState - description: Optional[ - List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]] - ] = None - has_more_rows: bool = False - results_queue: Optional[Any] = None - has_been_closed_server_side: bool = False - lz4_compressed: bool = True - is_staging_operation: bool = False From db139bc1179bb7cab6ec6f283cdfa0646b04b01b Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 07:09:35 +0000 Subject: [PATCH 14/33] add back essential types (ExecResponse, from_sea_state) Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/types.py | 39 ++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 9cd21b5e6..958eaa289 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -1,5 +1,6 @@ +from dataclasses import dataclass from enum import Enum -from typing import Dict, Optional, Any +from typing import Dict, List, Optional, Any, Tuple import logging from databricks.sql.backend.utils import guid_to_hex_id @@ -80,6 +81,27 @@ def from_thrift_state( else: return None + @classmethod + def from_sea_state(cls, state: str) -> Optional["CommandState"]: + """ + Map SEA state string to CommandState enum. 
+ Args: + state: SEA state string + Returns: + CommandState: The corresponding CommandState enum value + """ + state_mapping = { + "PENDING": cls.PENDING, + "RUNNING": cls.RUNNING, + "SUCCEEDED": cls.SUCCEEDED, + "FAILED": cls.FAILED, + "CLOSED": cls.CLOSED, + "CANCELED": cls.CANCELLED, + } + + return state_mapping.get(state, None) + + class BackendType(Enum): """ @@ -394,3 +416,18 @@ def to_hex_guid(self) -> str: return guid_to_hex_id(self.guid) else: return str(self.guid) + +@dataclass +class ExecuteResponse: + """Response from executing a SQL command.""" + + command_id: CommandId + status: CommandState + description: Optional[ + List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]] + ] = None + has_more_rows: bool = False + results_queue: Optional[Any] = None + has_been_closed_server_side: bool = False + lz4_compressed: bool = True + is_staging_operation: bool = False \ No newline at end of file From b977b1210a5d39543b8a3734128ba820e597337f Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 07:11:23 +0000 Subject: [PATCH 15/33] fix fetch types Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/types.py | 4 ++-- src/databricks/sql/result_set.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 958eaa289..3d24a09a1 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -102,7 +102,6 @@ def from_sea_state(cls, state: str) -> Optional["CommandState"]: return state_mapping.get(state, None) - class BackendType(Enum): """ Enum representing the type of backend @@ -417,6 +416,7 @@ def to_hex_guid(self) -> str: else: return str(self.guid) + @dataclass class ExecuteResponse: """Response from executing a SQL command.""" @@ -430,4 +430,4 @@ class ExecuteResponse: results_queue: Optional[Any] = None has_been_closed_server_side: bool = False lz4_compressed: bool = True - is_staging_operation: bool = False \ No newline at end of file + is_staging_operation: bool = False diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index a0cb73732..e177d495f 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -115,12 +115,12 @@ def fetchall(self) -> List[Row]: pass @abstractmethod - def fetchmany_arrow(self, size: int) -> Any: + def fetchmany_arrow(self, size: int) -> "pyarrow.Table": """Fetch the next set of rows as an Arrow table.""" pass @abstractmethod - def fetchall_arrow(self) -> Any: + def fetchall_arrow(self) -> "pyarrow.Table": """Fetch all remaining rows as an Arrow table.""" pass From da615c0db8ba2037c106b533331cf1ca1c9f49f9 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 07:12:45 +0000 Subject: [PATCH 16/33] excess imports Signed-off-by: varun-edachali-dbx --- tests/unit/test_session.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index fef070362..92de8d8fd 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2,11 +2,6 @@ from unittest.mock import patch, Mock import gc -from databricks.sql.thrift_api.TCLIService.ttypes import ( - TOpenSessionResp, - TSessionHandle, - THandleIdentifier, -) from databricks.sql.backend.types import SessionId, BackendType import databricks.sql From 0da04a6f1086998927a28759fc67da4e2c8c71c6 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 07:15:59 +0000 Subject: [PATCH 17/33] reduce diff 
by maintaining logs Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 316cf24a0..821559ad3 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -800,7 +800,7 @@ def _results_message_to_execute_response(self, resp, operation_state): status = CommandState.from_thrift_state(operation_state) if status is None: - raise ValueError(f"Invalid operation state: {operation_state}") + raise ValueError(f"Unknown command state: {operation_state}") return ( ExecuteResponse( From ea9d456ee9ca47434618a079698fa166b6c8a308 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 08:47:54 +0000 Subject: [PATCH 18/33] fix int test types Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 4 +--- tests/e2e/common/retry_test_mixins.py | 2 +- tests/e2e/test_driver.py | 6 +++--- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 821559ad3..f41f4b6d8 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -867,9 +867,7 @@ def get_execution_result( ssl_options=self._ssl_options, ) - status = CommandState.from_thrift_state(resp.status) - if status is None: - raise ValueError(f"Invalid operation state: {resp.status}") + status = self.get_query_state(command_id) execute_response = ExecuteResponse( command_id=command_id, diff --git a/tests/e2e/common/retry_test_mixins.py b/tests/e2e/common/retry_test_mixins.py index b5d01a45d..dd509c062 100755 --- a/tests/e2e/common/retry_test_mixins.py +++ b/tests/e2e/common/retry_test_mixins.py @@ -326,7 +326,7 @@ def test_retry_abort_close_operation_on_404(self, caplog): with self.connection(extra_params={**self._retry_policy}) as conn: with conn.cursor() as curs: with patch( - "databricks.sql.utils.ExecuteResponse.has_been_closed_server_side", + "databricks.sql.backend.types.ExecuteResponse.has_been_closed_server_side", new_callable=PropertyMock, return_value=False, ): diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 22897644f..8cfed7c28 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -933,12 +933,12 @@ def test_result_set_close(self): result_set = cursor.active_result_set assert result_set is not None - initial_op_state = result_set.op_state + initial_op_state = result_set.status result_set.close() - assert result_set.op_state == CommandState.CLOSED - assert result_set.op_state != initial_op_state + assert result_set.status == CommandState.CLOSED + assert result_set.status != initial_op_state # Closing the result set again should be a no-op and not raise exceptions result_set.close() From 8985c624bcdbb7e0abfa73b7a1a2dbad15b4e1ec Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 08:55:24 +0000 Subject: [PATCH 19/33] [squashed from exec-sea] init execution func Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 124 ++- .../sql/backend/databricks_client.py | 28 - src/databricks/sql/backend/filters.py | 143 ++++ src/databricks/sql/backend/sea/backend.py | 360 ++++++++- .../sql/backend/sea/models/__init__.py | 30 + src/databricks/sql/backend/sea/models/base.py | 68 ++ .../sql/backend/sea/models/requests.py | 110 ++- 
.../sql/backend/sea/models/responses.py | 95 ++- src/databricks/sql/backend/thrift_backend.py | 3 +- src/databricks/sql/backend/types.py | 25 +- src/databricks/sql/result_set.py | 118 ++- tests/unit/test_result_set_filter.py | 246 ++++++ tests/unit/test_sea_backend.py | 755 +++++++++++++++--- tests/unit/test_sea_result_set.py | 275 +++++++ tests/unit/test_session.py | 5 + 15 files changed, 2166 insertions(+), 219 deletions(-) create mode 100644 src/databricks/sql/backend/filters.py create mode 100644 src/databricks/sql/backend/sea/models/base.py create mode 100644 tests/unit/test_result_set_filter.py create mode 100644 tests/unit/test_sea_result_set.py diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index abe6bd1ab..87b62efea 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -6,34 +6,122 @@ logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) + +def test_sea_query_exec(): + """ + Test executing a query using the SEA backend with result compression. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query with result compression enabled and disabled, + and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + sys.exit(1) + + try: + # Test with compression enabled + logger.info("Creating connection with LZ4 compression enabled") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=True, # Enable cloud fetch to use compression + enable_query_result_lz4_compression=True, # Enable LZ4 compression + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + logger.info(f"backend type: {type(connection.session.backend)}") + + # Execute a simple query with compression enabled + cursor = connection.cursor(arraysize=0, buffer_size_bytes=0) + logger.info("Executing query with LZ4 compression: SELECT 1 as test_value") + cursor.execute("SELECT 1 as test_value") + logger.info("Query with compression executed successfully") + cursor.close() + connection.close() + logger.info("Successfully closed SEA session with compression enabled") + + # Test with compression disabled + logger.info("Creating connection with LZ4 compression disabled") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=False, # Enable cloud fetch + enable_query_result_lz4_compression=False, # Disable LZ4 compression + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query with compression disabled + cursor = connection.cursor(arraysize=0, buffer_size_bytes=0) + logger.info("Executing query without compression: SELECT 1 as test_value") + 
cursor.execute("SELECT 1 as test_value") + logger.info("Query without compression executed successfully") + cursor.close() + connection.close() + logger.info("Successfully closed SEA session with compression disabled") + + except Exception as e: + logger.error(f"Error during SEA query execution test: {str(e)}") + import traceback + + logger.error(traceback.format_exc()) + sys.exit(1) + + logger.info("SEA query execution test with compression completed successfully") + + def test_sea_session(): """ Test opening and closing a SEA session using the connector. - + This function connects to a Databricks SQL endpoint using the SEA backend, opens a session, and then closes it. - + Required environment variables: - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint - DATABRICKS_TOKEN: Personal access token for authentication """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") http_path = os.environ.get("DATABRICKS_HTTP_PATH") access_token = os.environ.get("DATABRICKS_TOKEN") catalog = os.environ.get("DATABRICKS_CATALOG") - + if not all([server_hostname, http_path, access_token]): logger.error("Missing required environment variables.") - logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) sys.exit(1) - + logger.info(f"Connecting to {server_hostname}") logger.info(f"HTTP Path: {http_path}") if catalog: logger.info(f"Using catalog: {catalog}") - + try: logger.info("Creating connection with SEA backend...") connection = Connection( @@ -42,25 +130,33 @@ def test_sea_session(): access_token=access_token, catalog=catalog, schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client" # add custom user agent + use_sea=True, + user_agent_entry="SEA-Test-Client", # add custom user agent + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" ) - - logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}") logger.info(f"backend type: {type(connection.session.backend)}") - + # Close the connection logger.info("Closing the SEA session...") connection.close() logger.info("Successfully closed SEA session") - + except Exception as e: logger.error(f"Error testing SEA session: {str(e)}") import traceback + logger.error(traceback.format_exc()) sys.exit(1) - + logger.info("SEA session test completed successfully") + if __name__ == "__main__": + # Test session management test_sea_session() + + # Test query execution with compression + test_sea_query_exec() diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index 8fda71e1e..bbca4c502 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -86,34 +86,6 @@ def execute_command( async_op: bool, enforce_embedded_schema_correctness: bool, ) -> Union["ResultSet", None]: - """ - Executes a SQL command or query within the specified session. - - This method sends a SQL command to the server for execution and handles - the response. It can operate in both synchronous and asynchronous modes. 
- - Args: - operation: The SQL command or query to execute - session_id: The session identifier in which to execute the command - max_rows: Maximum number of rows to fetch in a single fetch batch - max_bytes: Maximum number of bytes to fetch in a single fetch batch - lz4_compression: Whether to use LZ4 compression for result data - cursor: The cursor object that will handle the results - use_cloud_fetch: Whether to use cloud fetch for retrieving large result sets - parameters: List of parameters to bind to the query - async_op: Whether to execute the command asynchronously - enforce_embedded_schema_correctness: Whether to enforce schema correctness - - Returns: - If async_op is False, returns a ResultSet object containing the - query results and metadata. If async_op is True, returns None and the - results must be fetched later using get_execution_result(). - - Raises: - ValueError: If the session ID is invalid - OperationalError: If there's an error executing the command - ServerOperationError: If the server encounters an error during execution - """ pass @abstractmethod diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py new file mode 100644 index 000000000..7f48b6179 --- /dev/null +++ b/src/databricks/sql/backend/filters.py @@ -0,0 +1,143 @@ +""" +Client-side filtering utilities for Databricks SQL connector. + +This module provides filtering capabilities for result sets returned by different backends. +""" + +import logging +from typing import ( + List, + Optional, + Any, + Callable, + TYPE_CHECKING, +) + +if TYPE_CHECKING: + from databricks.sql.result_set import ResultSet + +from databricks.sql.result_set import SeaResultSet + +logger = logging.getLogger(__name__) + + +class ResultSetFilter: + """ + A general-purpose filter for result sets that can be applied to any backend. + + This class provides methods to filter result sets based on various criteria, + similar to the client-side filtering in the JDBC connector. + """ + + @staticmethod + def _filter_sea_result_set( + result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool] + ) -> "SeaResultSet": + """ + Filter a SEA result set using the provided filter function. + + Args: + result_set: The SEA result set to filter + filter_func: Function that takes a row and returns True if the row should be included + + Returns: + A filtered SEA result set + """ + # Create a filtered version of the result set + filtered_response = result_set._response.copy() + + # If there's a result with rows, filter them + if ( + "result" in filtered_response + and "data_array" in filtered_response["result"] + ): + rows = filtered_response["result"]["data_array"] + filtered_rows = [row for row in rows if filter_func(row)] + filtered_response["result"]["data_array"] = filtered_rows + + # Update row count if present + if "row_count" in filtered_response["result"]: + filtered_response["result"]["row_count"] = len(filtered_rows) + + # Create a new result set with the filtered data + return SeaResultSet( + connection=result_set.connection, + sea_response=filtered_response, + sea_client=result_set.backend, + buffer_size_bytes=result_set.buffer_size_bytes, + arraysize=result_set.arraysize, + ) + + @staticmethod + def filter_by_column_values( + result_set: "ResultSet", + column_index: int, + allowed_values: List[str], + case_sensitive: bool = False, + ) -> "ResultSet": + """ + Filter a result set by values in a specific column. 
+ + Args: + result_set: The result set to filter + column_index: The index of the column to filter on + allowed_values: List of allowed values for the column + case_sensitive: Whether to perform case-sensitive comparison + + Returns: + A filtered result set + """ + # Convert to uppercase for case-insensitive comparison if needed + if not case_sensitive: + allowed_values = [v.upper() for v in allowed_values] + + # Determine the type of result set and apply appropriate filtering + if isinstance(result_set, SeaResultSet): + return ResultSetFilter._filter_sea_result_set( + result_set, + lambda row: ( + len(row) > column_index + and isinstance(row[column_index], str) + and ( + row[column_index].upper() + if not case_sensitive + else row[column_index] + ) + in allowed_values + ), + ) + + # For other result set types, return the original (should be handled by specific implementations) + logger.warning( + f"Filtering not implemented for result set type: {type(result_set).__name__}" + ) + return result_set + + @staticmethod + def filter_tables_by_type( + result_set: "ResultSet", table_types: Optional[List[str]] = None + ) -> "ResultSet": + """ + Filter a result set of tables by the specified table types. + + This is a client-side filter that processes the result set after it has been + retrieved from the server. It filters out tables whose type does not match + any of the types in the table_types list. + + Args: + result_set: The original result set containing tables + table_types: List of table types to include (e.g., ["TABLE", "VIEW"]) + + Returns: + A filtered result set containing only tables of the specified types + """ + # Default table types if none specified + DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"] + valid_types = ( + table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES + ) + + # Table type is typically in the 6th column (index 5) + return ResultSetFilter.filter_by_column_values( + result_set, 5, valid_types, case_sensitive=False + ) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 97d25a058..c7a4ed1b1 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,23 +1,34 @@ import logging import re -from typing import Dict, Tuple, List, Optional, TYPE_CHECKING, Set +import uuid +import time +from typing import Dict, Set, Tuple, List, Optional, Any, Union, TYPE_CHECKING + +from databricks.sql.backend.sea.utils.constants import ( + ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, +) if TYPE_CHECKING: from databricks.sql.client import Cursor + from databricks.sql.result_set import ResultSet from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType -from databricks.sql.exc import ServerOperationError +from databricks.sql.exc import Error, NotSupportedError, ServerOperationError from databricks.sql.backend.sea.utils.http_client import SeaHttpClient -from databricks.sql.backend.sea.utils.constants import ( - ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, -) from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions from databricks.sql.backend.sea.models import ( + ExecuteStatementRequest, + GetStatementRequest, + CancelStatementRequest, + CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, + StatementParameter, + ExecuteStatementResponse, + GetStatementResponse, CreateSessionResponse, ) @@ -55,6 +66,9 @@ 
def _filter_session_configuration( class SeaDatabricksClient(DatabricksClient): """ Statement Execution API (SEA) implementation of the DatabricksClient interface. + + This implementation provides session management functionality for SEA, + while other operations raise NotImplementedError. """ # SEA API paths @@ -274,41 +288,222 @@ def execute_command( lz4_compression: bool, cursor: "Cursor", use_cloud_fetch: bool, - parameters: List[ttypes.TSparkParameter], + parameters: List, async_op: bool, enforce_embedded_schema_correctness: bool, - ): - """Not implemented yet.""" - raise NotImplementedError( - "execute_command is not yet implemented for SEA backend" + ) -> Union["ResultSet", None]: + """ + Execute a SQL command using the SEA backend. + + Args: + operation: SQL command to execute + session_id: Session identifier + max_rows: Maximum number of rows to fetch + max_bytes: Maximum number of bytes to fetch + lz4_compression: Whether to use LZ4 compression + cursor: Cursor executing the command + use_cloud_fetch: Whether to use cloud fetch + parameters: SQL parameters + async_op: Whether to execute asynchronously + enforce_embedded_schema_correctness: Whether to enforce schema correctness + + Returns: + ResultSet: A SeaResultSet instance for the executed command + """ + if session_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA session ID") + + sea_session_id = session_id.to_sea_session_id() + + # Convert parameters to StatementParameter objects + sea_parameters = [] + if parameters: + for param in parameters: + sea_parameters.append( + StatementParameter( + name=param.name, + value=param.value, + type=param.type if hasattr(param, "type") else None, + ) + ) + + format = "ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY" + disposition = "EXTERNAL_LINKS" if use_cloud_fetch else "INLINE" + result_compression = "LZ4_FRAME" if lz4_compression else "NONE" + + request = ExecuteStatementRequest( + warehouse_id=self.warehouse_id, + session_id=sea_session_id, + statement=operation, + disposition=disposition, + format=format, + wait_timeout="0s" if async_op else "10s", + on_wait_timeout="CONTINUE", + row_limit=max_rows if max_rows > 0 else None, + byte_limit=max_bytes if max_bytes > 0 else None, + parameters=sea_parameters if sea_parameters else None, + result_compression=result_compression, ) + response_data = self.http_client._make_request( + method="POST", path=self.STATEMENT_PATH, data=request.to_dict() + ) + response = ExecuteStatementResponse.from_dict(response_data) + statement_id = response.statement_id + if not statement_id: + raise ServerOperationError( + "Failed to execute command: No statement ID returned", + { + "operation-id": None, + "diagnostic-info": None, + }, + ) + + command_id = CommandId.from_sea_statement_id(statement_id) + + # Store the command ID in the cursor + cursor.active_command_id = command_id + + # If async operation, return None and let the client poll for results + if async_op: + return None + + # For synchronous operation, wait for the statement to complete + # Poll until the statement is done + status = response.status + state = status.state + + # Keep polling until we reach a terminal state + while state in [CommandState.PENDING, CommandState.RUNNING]: + time.sleep(0.5) # add a small delay to avoid excessive API calls + state = self.get_query_state(command_id) + + if state != CommandState.SUCCEEDED: + raise ServerOperationError( + f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}", + { + 
"operation-id": command_id.to_sea_statement_id(), + "diagnostic-info": None, + }, + ) + + return self.get_execution_result(command_id, cursor) + def cancel_command(self, command_id: CommandId) -> None: - """Not implemented yet.""" - raise NotImplementedError( - "cancel_command is not yet implemented for SEA backend" + """ + Cancel a running command. + + Args: + command_id: Command identifier to cancel + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = CancelStatementRequest(statement_id=sea_statement_id) + self.http_client._make_request( + method="POST", + path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) def close_command(self, command_id: CommandId) -> None: - """Not implemented yet.""" - raise NotImplementedError( - "close_command is not yet implemented for SEA backend" + """ + Close a command and release resources. + + Args: + command_id: Command identifier to close + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = CloseStatementRequest(statement_id=sea_statement_id) + self.http_client._make_request( + method="DELETE", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) def get_query_state(self, command_id: CommandId) -> CommandState: - """Not implemented yet.""" - raise NotImplementedError( - "get_query_state is not yet implemented for SEA backend" + """ + Get the state of a running query. + + Args: + command_id: Command identifier + + Returns: + CommandState: The current state of the command + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = GetStatementRequest(statement_id=sea_statement_id) + response_data = self.http_client._make_request( + method="GET", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) + # Parse the response + response = GetStatementResponse.from_dict(response_data) + return response.status.state + def get_execution_result( self, command_id: CommandId, cursor: "Cursor", - ): - """Not implemented yet.""" - raise NotImplementedError( - "get_execution_result is not yet implemented for SEA backend" + ) -> "ResultSet": + """ + Get the result of a command execution. 
+ + Args: + command_id: Command identifier + cursor: Cursor executing the command + + Returns: + ResultSet: A SeaResultSet instance with the execution results + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + # Create the request model + request = GetStatementRequest(statement_id=sea_statement_id) + + # Get the statement result + response_data = self.http_client._make_request( + method="GET", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), + ) + + # Create and return a SeaResultSet + from databricks.sql.result_set import SeaResultSet + + return SeaResultSet( + connection=cursor.connection, + sea_response=response_data, + sea_client=self, + buffer_size_bytes=cursor.buffer_size_bytes, + arraysize=cursor.arraysize, ) # == Metadata Operations == @@ -319,9 +514,22 @@ def get_catalogs( max_rows: int, max_bytes: int, cursor: "Cursor", - ): - """Not implemented yet.""" - raise NotImplementedError("get_catalogs is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get available catalogs by executing 'SHOW CATALOGS'.""" + result = self.execute_command( + operation="SHOW CATALOGS", + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result def get_schemas( self, @@ -331,9 +539,30 @@ def get_schemas( cursor: "Cursor", catalog_name: Optional[str] = None, schema_name: Optional[str] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_schemas is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_schemas") + + operation = f"SHOW SCHEMAS IN `{catalog_name}`" + + if schema_name: + operation += f" LIKE '{schema_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result def get_tables( self, @@ -345,9 +574,43 @@ def get_tables( schema_name: Optional[str] = None, table_name: Optional[str] = None, table_types: Optional[List[str]] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_tables is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_tables") + + operation = "SHOW TABLES IN " + ( + "ALL CATALOGS" + if catalog_name in [None, "*", "%"] + else f"CATALOG `{catalog_name}`" + ) + + if schema_name: + operation += f" SCHEMA LIKE '{schema_name}'" + + if table_name: + operation += f" LIKE '{table_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + 
enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + + # Apply client-side filtering by table_types if specified + from databricks.sql.backend.filters import ResultSetFilter + + result = ResultSetFilter.filter_tables_by_type(result, table_types) + + return result def get_columns( self, @@ -359,6 +622,33 @@ def get_columns( schema_name: Optional[str] = None, table_name: Optional[str] = None, column_name: Optional[str] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_columns is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_columns") + + operation = f"SHOW COLUMNS IN CATALOG `{catalog_name}`" + + if schema_name: + operation += f" SCHEMA LIKE '{schema_name}'" + + if table_name: + operation += f" TABLE LIKE '{table_name}'" + + if column_name: + operation += f" LIKE '{column_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result diff --git a/src/databricks/sql/backend/sea/models/__init__.py b/src/databricks/sql/backend/sea/models/__init__.py index c9310d367..b7c8bd399 100644 --- a/src/databricks/sql/backend/sea/models/__init__.py +++ b/src/databricks/sql/backend/sea/models/__init__.py @@ -4,19 +4,49 @@ This package contains data models for SEA API requests and responses. """ +from databricks.sql.backend.sea.models.base import ( + ServiceError, + StatementStatus, + ExternalLink, + ResultData, + ColumnInfo, + ResultManifest, +) + from databricks.sql.backend.sea.models.requests import ( + StatementParameter, + ExecuteStatementRequest, + GetStatementRequest, + CancelStatementRequest, + CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, ) from databricks.sql.backend.sea.models.responses import ( + ExecuteStatementResponse, + GetStatementResponse, CreateSessionResponse, ) __all__ = [ + # Base models + "ServiceError", + "StatementStatus", + "ExternalLink", + "ResultData", + "ColumnInfo", + "ResultManifest", # Request models + "StatementParameter", + "ExecuteStatementRequest", + "GetStatementRequest", + "CancelStatementRequest", + "CloseStatementRequest", "CreateSessionRequest", "DeleteSessionRequest", # Response models + "ExecuteStatementResponse", + "GetStatementResponse", "CreateSessionResponse", ] diff --git a/src/databricks/sql/backend/sea/models/base.py b/src/databricks/sql/backend/sea/models/base.py new file mode 100644 index 000000000..671f7be13 --- /dev/null +++ b/src/databricks/sql/backend/sea/models/base.py @@ -0,0 +1,68 @@ +""" +Base models for the SEA (Statement Execution API) backend. + +These models define the common structures used in SEA API requests and responses. 
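+
+Example (illustrative): building a status object from these dataclasses:
+
+    status = StatementStatus(state=CommandState.SUCCEEDED)
+    if status.error is not None:
+        raise RuntimeError(status.error.message)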
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + +from databricks.sql.backend.types import CommandState + + +@dataclass +class ServiceError: + """Error information returned by the SEA API.""" + + message: str + error_code: Optional[str] = None + + +@dataclass +class StatementStatus: + """Status information for a statement execution.""" + + state: CommandState + error: Optional[ServiceError] = None + sql_state: Optional[str] = None + + +@dataclass +class ExternalLink: + """External link information for result data.""" + + external_link: str + expiration: str + chunk_index: int + + +@dataclass +class ResultData: + """Result data from a statement execution.""" + + data: Optional[List[List[Any]]] = None + external_links: Optional[List[ExternalLink]] = None + + +@dataclass +class ColumnInfo: + """Information about a column in the result set.""" + + name: str + type_name: str + type_text: str + nullable: bool = True + precision: Optional[int] = None + scale: Optional[int] = None + ordinal_position: Optional[int] = None + + +@dataclass +class ResultManifest: + """Manifest information for a result set.""" + + schema: List[ColumnInfo] + total_row_count: int + total_byte_count: int + truncated: bool = False + chunk_count: Optional[int] = None diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index 7966cb502..1c519d931 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -1,5 +1,111 @@ -from typing import Dict, Any, Optional -from dataclasses import dataclass +""" +Request models for the SEA (Statement Execution API) backend. + +These models define the structures used in SEA API requests. 
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + + +@dataclass +class StatementParameter: + """Parameter for a SQL statement.""" + + name: str + value: Optional[str] = None + type: Optional[str] = None + + +@dataclass +class ExecuteStatementRequest: + """Request to execute a SQL statement.""" + + warehouse_id: str + statement: str + session_id: str + disposition: str = "EXTERNAL_LINKS" + format: str = "JSON_ARRAY" + wait_timeout: str = "10s" + on_wait_timeout: str = "CONTINUE" + row_limit: Optional[int] = None + byte_limit: Optional[int] = None + parameters: Optional[List[StatementParameter]] = None + catalog: Optional[str] = None + schema: Optional[str] = None + result_compression: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + result: Dict[str, Any] = { + "warehouse_id": self.warehouse_id, + "session_id": self.session_id, + "statement": self.statement, + "disposition": self.disposition, + "format": self.format, + "wait_timeout": self.wait_timeout, + "on_wait_timeout": self.on_wait_timeout, + } + + if self.row_limit is not None and self.row_limit > 0: + result["row_limit"] = self.row_limit + + if self.byte_limit is not None and self.byte_limit > 0: + result["byte_limit"] = self.byte_limit + + if self.catalog: + result["catalog"] = self.catalog + + if self.schema: + result["schema"] = self.schema + + if self.result_compression: + result["result_compression"] = self.result_compression + + if self.parameters: + result["parameters"] = [ + { + "name": param.name, + **({"value": param.value} if param.value is not None else {}), + **({"type": param.type} if param.type is not None else {}), + } + for param in self.parameters + ] + + return result + + +@dataclass +class GetStatementRequest: + """Request to get information about a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} + + +@dataclass +class CancelStatementRequest: + """Request to cancel a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} + + +@dataclass +class CloseStatementRequest: + """Request to close a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} @dataclass diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 1bb54590f..d70459b9f 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -1,5 +1,96 @@ -from typing import Dict, Any -from dataclasses import dataclass +""" +Response models for the SEA (Statement Execution API) backend. + +These models define the structures used in SEA API responses. 
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + +from databricks.sql.backend.types import CommandState +from databricks.sql.backend.sea.models.base import ( + StatementStatus, + ResultManifest, + ResultData, + ServiceError, +) + + +@dataclass +class ExecuteStatementResponse: + """Response from executing a SQL statement.""" + + statement_id: str + status: StatementStatus + manifest: Optional[ResultManifest] = None + result: Optional[ResultData] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": + """Create an ExecuteStatementResponse from a dictionary.""" + status_data = data.get("status", {}) + error = None + if "error" in status_data: + error_data = status_data["error"] + error = ServiceError( + message=error_data.get("message", ""), + error_code=error_data.get("error_code"), + ) + + state = CommandState.from_sea_state(status_data.get("state", "")) + if state is None: + raise ValueError(f"Invalid state: {status_data.get('state', '')}") + status = StatementStatus( + state=state, + error=error, + sql_state=status_data.get("sql_state"), + ) + + return cls( + statement_id=data.get("statement_id", ""), + status=status, + manifest=data.get("manifest"), # We'll parse this more fully if needed + result=data.get("result"), # We'll parse this more fully if needed + ) + + +@dataclass +class GetStatementResponse: + """Response from getting information about a statement.""" + + statement_id: str + status: StatementStatus + manifest: Optional[ResultManifest] = None + result: Optional[ResultData] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": + """Create a GetStatementResponse from a dictionary.""" + status_data = data.get("status", {}) + error = None + if "error" in status_data: + error_data = status_data["error"] + error = ServiceError( + message=error_data.get("message", ""), + error_code=error_data.get("error_code"), + ) + + state = CommandState.from_sea_state(status_data.get("state", "")) + if state is None: + raise ValueError(f"Invalid state: {status_data.get('state', '')}") + + status = StatementStatus( + state=state, + error=error, + sql_state=status_data.get("sql_state"), + ) + + return cls( + statement_id=data.get("statement_id", ""), + status=status, + manifest=data.get("manifest"), # We'll parse this more fully if needed + result=data.get("result"), # We'll parse this more fully if needed + ) @dataclass diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index f41f4b6d8..810c2e7a1 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -352,7 +352,6 @@ def make_request(self, method, request, retryable=True): Will stop retry attempts if total elapsed time + next retry delay would exceed _retry_stop_after_attempts_duration. """ - # basic strategy: build range iterator rep'ing number of available # retries. bounds can be computed from there. iterate over it with # retries until success or final failure achieved. 
@@ -1242,7 +1241,7 @@ def cancel_command(self, command_id: CommandId) -> None: if not thrift_handle: raise ValueError("Not a valid Thrift command ID") - logger.debug("Cancelling command {}".format(guid_to_hex_id(command_id.guid))) + logger.debug("Cancelling command {}".format(command_id.guid)) req = ttypes.TCancelOperationReq(thrift_handle) self.make_request(self._client.CancelOperation, req) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 3d24a09a1..5bf02e0ea 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -85,8 +85,10 @@ def from_thrift_state( def from_sea_state(cls, state: str) -> Optional["CommandState"]: """ Map SEA state string to CommandState enum. + Args: state: SEA state string + Returns: CommandState: The corresponding CommandState enum value """ @@ -306,28 +308,6 @@ def __init__( self.has_result_set = has_result_set self.modified_row_count = modified_row_count - def __str__(self) -> str: - """ - Return a string representation of the CommandId. - - For SEA backend, returns the guid. - For Thrift backend, returns a format like "guid|secret". - - Returns: - A string representation of the command ID - """ - - if self.backend_type == BackendType.SEA: - return str(self.guid) - elif self.backend_type == BackendType.THRIFT: - secret_hex = ( - guid_to_hex_id(self.secret) - if isinstance(self.secret, bytes) - else str(self.secret) - ) - return f"{self.to_hex_guid()}|{secret_hex}" - return str(self.guid) - @classmethod def from_thrift_handle(cls, operation_handle): """ @@ -339,7 +319,6 @@ def from_thrift_handle(cls, operation_handle): Returns: A CommandId instance """ - if operation_handle is None: return None diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index e177d495f..a4beda629 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -64,7 +64,7 @@ def __init__( """ self.connection = connection - self.backend = backend + self.backend = backend # Store the backend client directly self.arraysize = arraysize self.buffer_size_bytes = buffer_size_bytes self._next_row_index = 0 @@ -115,12 +115,12 @@ def fetchall(self) -> List[Row]: pass @abstractmethod - def fetchmany_arrow(self, size: int) -> "pyarrow.Table": + def fetchmany_arrow(self, size: int) -> Any: """Fetch the next set of rows as an Arrow table.""" pass @abstractmethod - def fetchall_arrow(self) -> "pyarrow.Table": + def fetchall_arrow(self) -> Any: """Fetch all remaining rows as an Arrow table.""" pass @@ -207,7 +207,7 @@ def _fill_results_buffer(self): use_cloud_fetch=self._use_cloud_fetch, ) self.results = results - self.has_more_rows = has_more_rows + self._has_more_rows = has_more_rows def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] @@ -259,7 +259,7 @@ def _convert_arrow_table(self, table): res = df.to_numpy(na_value=None, dtype="object") return [ResultRow(*v) for v in res] - def merge_columnar(self, result1, result2) -> "ColumnTable": + def merge_columnar(self, result1, result2): """ Function to merge / combining the columnar results into a single result :param result1: @@ -291,7 +291,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self._has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -316,7 +316,7 @@ def fetchmany_columnar(self, size: int): while ( 
n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self._has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -331,7 +331,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self._has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() if isinstance(results, ColumnTable) and isinstance( @@ -357,7 +357,7 @@ def fetchall_columnar(self): results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self._has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() results = self.merge_columnar(results, partial_results) @@ -403,16 +403,96 @@ def fetchmany(self, size: int) -> List[Row]: @staticmethod def _get_schema_description(table_schema_message): """ - Takes a TableSchema message and returns a description 7-tuple as specified by PEP-249 + Initialize a SeaResultSet with the response from a SEA query execution. + + Args: + connection: The parent connection + sea_client: The SeaDatabricksClient instance for direct access + buffer_size_bytes: Buffer size for fetching results + arraysize: Default number of rows to fetch + execute_response: Response from the execute command (new style) + sea_response: Direct SEA response (legacy style) """ + # Handle both initialization styles + if execute_response is not None: + # New style with ExecuteResponse + command_id = execute_response.command_id + status = execute_response.status + has_been_closed_server_side = execute_response.has_been_closed_server_side + has_more_rows = execute_response.has_more_rows + results_queue = execute_response.results_queue + description = execute_response.description + is_staging_operation = execute_response.is_staging_operation + self._response = getattr(execute_response, "sea_response", {}) + self.statement_id = command_id.to_sea_statement_id() if command_id else None + elif sea_response is not None: + # Legacy style with direct sea_response + self._response = sea_response + # Extract values from sea_response + command_id = CommandId.from_sea_statement_id( + sea_response.get("statement_id", "") + ) + self.statement_id = sea_response.get("statement_id", "") + + # Extract status + status_data = sea_response.get("status", {}) + status = CommandState.from_sea_state(status_data.get("state", "PENDING")) + + # Set defaults for other fields + has_been_closed_server_side = False + has_more_rows = False + results_queue = None + description = None + is_staging_operation = False + else: + raise ValueError("Either execute_response or sea_response must be provided") - def map_col_type(type_): - if type_.startswith("decimal"): - return "decimal" - else: - return type_ + # Call parent constructor with common attributes + super().__init__( + connection=connection, + backend=sea_client, + arraysize=arraysize, + buffer_size_bytes=buffer_size_bytes, + command_id=command_id, + status=status, + has_been_closed_server_side=has_been_closed_server_side, + has_more_rows=has_more_rows, + results_queue=results_queue, + description=description, + is_staging_operation=is_staging_operation, + ) - return [ - (column.name, map_col_type(column.datatype), None, None, 
None, None, None) - for column in table_schema_message.columns - ] + def _fill_results_buffer(self): + """Fill the results buffer from the backend.""" + raise NotImplementedError("fetchone is not implemented for SEA backend") + + def fetchone(self) -> Optional[Row]: + """ + Fetch the next row of a query result set, returning a single sequence, + or None when no more data is available. + """ + + raise NotImplementedError("fetchone is not implemented for SEA backend") + + def fetchmany(self, size: Optional[int] = None) -> List[Row]: + """ + Fetch the next set of rows of a query result, returning a list of rows. + + An empty sequence is returned when no more rows are available. + """ + + raise NotImplementedError("fetchmany is not implemented for SEA backend") + + def fetchall(self) -> List[Row]: + """ + Fetch all (remaining) rows of a query result, returning them as a list of rows. + """ + raise NotImplementedError("fetchall is not implemented for SEA backend") + + def fetchmany_arrow(self, size: int) -> Any: + """Fetch the next set of rows as an Arrow table.""" + raise NotImplementedError("fetchmany_arrow is not implemented for SEA backend") + + def fetchall_arrow(self) -> Any: + """Fetch all remaining rows as an Arrow table.""" + raise NotImplementedError("fetchall_arrow is not implemented for SEA backend") diff --git a/tests/unit/test_result_set_filter.py b/tests/unit/test_result_set_filter.py new file mode 100644 index 000000000..e8eb2a757 --- /dev/null +++ b/tests/unit/test_result_set_filter.py @@ -0,0 +1,246 @@ +""" +Tests for the ResultSetFilter class. + +This module contains tests for the ResultSetFilter class, which provides +filtering capabilities for result sets returned by different backends. +""" + +import pytest +from unittest.mock import patch, MagicMock, Mock + +from databricks.sql.backend.filters import ResultSetFilter +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import CommandId, CommandState, BackendType + + +class TestResultSetFilter: + """Test suite for the ResultSetFilter class.""" + + @pytest.fixture + def mock_connection(self): + """Create a mock connection.""" + connection = Mock() + connection.open = True + return connection + + @pytest.fixture + def mock_sea_client(self): + """Create a mock SEA client.""" + return Mock() + + @pytest.fixture + def sea_response_with_tables(self): + """Create a sample SEA response with table data based on the server schema.""" + return { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 7, + "columns": [ + {"name": "namespace", "type_text": "STRING", "position": 0}, + {"name": "tableName", "type_text": "STRING", "position": 1}, + {"name": "isTemporary", "type_text": "BOOLEAN", "position": 2}, + {"name": "information", "type_text": "STRING", "position": 3}, + {"name": "catalogName", "type_text": "STRING", "position": 4}, + {"name": "tableType", "type_text": "STRING", "position": 5}, + {"name": "remarks", "type_text": "STRING", "position": 6}, + ], + }, + "total_row_count": 4, + }, + "result": { + "data_array": [ + ["schema1", "table1", False, None, "catalog1", "TABLE", None], + ["schema1", "table2", False, None, "catalog1", "VIEW", None], + [ + "schema1", + "table3", + False, + None, + "catalog1", + "SYSTEM TABLE", + None, + ], + ["schema1", "table4", False, None, "catalog1", "EXTERNAL", None], + ] + }, + } + + @pytest.fixture + def sea_result_set_with_tables( + self, mock_connection, 
mock_sea_client, sea_response_with_tables + ): + """Create a SeaResultSet with table data.""" + # Create a deep copy of the response to avoid test interference + import copy + + sea_response_copy = copy.deepcopy(sea_response_with_tables) + + return SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy, + buffer_size_bytes=1000, + arraysize=100, + ) + + def test_filter_tables_by_type_default(self, sea_result_set_with_tables): + """Test filtering tables by type with default types.""" + # Default types are TABLE, VIEW, SYSTEM TABLE + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables + ) + + # Verify that only the default types are included + assert len(filtered_result._response["result"]["data_array"]) == 3 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" in table_types + assert "EXTERNAL" not in table_types + + def test_filter_tables_by_type_custom(self, sea_result_set_with_tables): + """Test filtering tables by type with custom types.""" + # Filter for only TABLE and EXTERNAL + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=["TABLE", "EXTERNAL"] + ) + + # Verify that only the specified types are included + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "EXTERNAL" in table_types + assert "VIEW" not in table_types + assert "SYSTEM TABLE" not in table_types + + def test_filter_tables_by_type_case_insensitive(self, sea_result_set_with_tables): + """Test that table type filtering is case-insensitive.""" + # Filter for lowercase "table" and "view" + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=["table", "view"] + ) + + # Verify that the matching types are included despite case differences + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" not in table_types + assert "EXTERNAL" not in table_types + + def test_filter_tables_by_type_empty_list(self, sea_result_set_with_tables): + """Test filtering tables with an empty type list (should use defaults).""" + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=[] + ) + + # Verify that default types are used + assert len(filtered_result._response["result"]["data_array"]) == 3 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" in table_types + assert "EXTERNAL" not in table_types + + def test_filter_by_column_values(self, sea_result_set_with_tables): + """Test filtering by values in a specific column.""" + # Filter by namespace in column index 0 + filtered_result = ResultSetFilter.filter_by_column_values( + sea_result_set_with_tables, column_index=0, allowed_values=["schema1"] + ) + + # All rows have schema1 in namespace, so all should be included + assert len(filtered_result._response["result"]["data_array"]) == 4 + + # Filter by table name in column index 1 + filtered_result = ResultSetFilter.filter_by_column_values( + 
sea_result_set_with_tables, + column_index=1, + allowed_values=["table1", "table3"], + ) + + # Only rows with table1 or table3 should be included + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_names = [ + row[1] for row in filtered_result._response["result"]["data_array"] + ] + assert "table1" in table_names + assert "table3" in table_names + assert "table2" not in table_names + assert "table4" not in table_names + + def test_filter_by_column_values_case_sensitive( + self, mock_connection, mock_sea_client, sea_response_with_tables + ): + """Test case-sensitive filtering by column values.""" + import copy + + # Create a fresh result set for the first test + sea_response_copy1 = copy.deepcopy(sea_response_with_tables) + result_set1 = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy1, + buffer_size_bytes=1000, + arraysize=100, + ) + + # First test: Case-sensitive filtering with lowercase values (should find no matches) + filtered_result = ResultSetFilter.filter_by_column_values( + result_set1, + column_index=5, # tableType column + allowed_values=["table", "view"], # lowercase + case_sensitive=True, + ) + + # Verify no matches with lowercase values + assert len(filtered_result._response["result"]["data_array"]) == 0 + + # Create a fresh result set for the second test + sea_response_copy2 = copy.deepcopy(sea_response_with_tables) + result_set2 = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy2, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Second test: Case-sensitive filtering with correct case (should find matches) + filtered_result = ResultSetFilter.filter_by_column_values( + result_set2, + column_index=5, # tableType column + allowed_values=["TABLE", "VIEW"], # correct case + case_sensitive=True, + ) + + # Verify matches with correct case + assert len(filtered_result._response["result"]["data_array"]) == 2 + + # Extract the table types from the filtered results + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + + def test_filter_by_column_values_out_of_bounds(self, sea_result_set_with_tables): + """Test filtering with a column index that's out of bounds.""" + # Filter by column index 10 (out of bounds) + filtered_result = ResultSetFilter.filter_by_column_values( + sea_result_set_with_tables, column_index=10, allowed_values=["value"] + ) + + # No rows should match + assert len(filtered_result._response["result"]["data_array"]) == 0 diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index bc2688a68..c2078f731 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -1,11 +1,20 @@ +""" +Tests for the SEA (Statement Execution API) backend implementation. + +This module contains tests for the SeaDatabricksClient class, which implements +the Databricks SQL connector's SEA backend functionality. 
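+
+The suite mocks SeaHttpClient._make_request and asserts on the outgoing
+request payloads, e.g. (sketch):
+
+    mock_http_client._make_request.return_value = {
+        "statement_id": "test-statement-123",
+        "status": {"state": "SUCCEEDED"},
+    }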
+""" + +import json import pytest -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, Mock from databricks.sql.backend.sea.backend import SeaDatabricksClient -from databricks.sql.backend.types import SessionId, BackendType +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.exc import Error +from databricks.sql.exc import Error, NotSupportedError class TestSeaBackend: @@ -41,6 +50,23 @@ def sea_client(self, mock_http_client): return client + @pytest.fixture + def sea_session_id(self): + """Create a SEA session ID.""" + return SessionId.from_sea_session_id("test-session-123") + + @pytest.fixture + def sea_command_id(self): + """Create a SEA command ID.""" + return CommandId.from_sea_statement_id("test-statement-123") + + @pytest.fixture + def mock_cursor(self): + """Create a mock cursor.""" + cursor = Mock() + cursor.active_command_id = None + return cursor + def test_init_extracts_warehouse_id(self, mock_http_client): """Test that the constructor properly extracts the warehouse ID from the HTTP path.""" # Test with warehouses format @@ -175,109 +201,650 @@ def test_close_session_invalid_id_type(self, sea_client): assert "Not a valid SEA session ID" in str(excinfo.value) - def test_session_configuration_helpers(self): - """Test the session configuration helper methods.""" - # Test getting default value for a supported parameter - default_value = SeaDatabricksClient.get_default_session_configuration_value( - "ANSI_MODE" - ) - assert default_value == "true" + # Tests for command execution and management + + def test_execute_command_sync( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command synchronously.""" + # Set up mock responses + execute_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "schema": [ + { + "name": "col1", + "type_name": "STRING", + "type_text": "string", + "nullable": True, + } + ], + "total_row_count": 1, + "total_byte_count": 100, + }, + "result": {"data": [["value1"]]}, + } + mock_http_client._make_request.return_value = execute_response + + # Mock the get_execution_result method + with patch.object( + sea_client, "get_execution_result", return_value="mock_result_set" + ) as mock_get_result: + # Call the method + result = sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.STATEMENT_PATH + assert "warehouse_id" in kwargs["data"] + assert "session_id" in kwargs["data"] + assert "statement" in kwargs["data"] + assert kwargs["data"]["statement"] == "SELECT 1" + + # Verify get_execution_result was called with the right command ID + mock_get_result.assert_called_once() + cmd_id_arg = mock_get_result.call_args[0][0] + assert isinstance(cmd_id_arg, CommandId) + assert cmd_id_arg.guid == "test-statement-123" + + def 
test_execute_command_async( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command asynchronously.""" + # Set up mock response + execute_response = { + "statement_id": "test-statement-456", + "status": {"state": "PENDING"}, + } + mock_http_client._make_request.return_value = execute_response - # Test getting default value for an unsupported parameter - default_value = SeaDatabricksClient.get_default_session_configuration_value( - "UNSUPPORTED_PARAM" + # Call the method + result = sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=True, # Async mode + enforce_embedded_schema_correctness=False, ) - assert default_value is None - - # Test getting the list of allowed configurations - allowed_configs = SeaDatabricksClient.get_allowed_session_configurations() - - expected_keys = { - "ANSI_MODE", - "ENABLE_PHOTON", - "LEGACY_TIME_PARSER_POLICY", - "MAX_FILE_PARTITION_BYTES", - "READ_ONLY_EXTERNAL_METASTORE", - "STATEMENT_TIMEOUT", - "TIMEZONE", - "USE_CACHED_RESULT", + + # Verify the result is None for async operation + assert result is None + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.STATEMENT_PATH + assert "wait_timeout" in kwargs["data"] + assert kwargs["data"]["wait_timeout"] == "0s" # Async mode uses 0s timeout + + # Verify the command ID was stored in the cursor + assert hasattr(mock_cursor, "active_command_id") + assert isinstance(mock_cursor.active_command_id, CommandId) + assert mock_cursor.active_command_id.guid == "test-statement-456" + + def test_execute_command_with_polling( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command that requires polling.""" + # Set up mock responses for initial request and polling + initial_response = { + "statement_id": "test-statement-789", + "status": {"state": "RUNNING"}, + } + poll_response = { + "statement_id": "test-statement-789", + "status": {"state": "SUCCEEDED"}, + "manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0}, + "result": {"data": []}, + } + + # Configure mock to return different responses on subsequent calls + mock_http_client._make_request.side_effect = [initial_response, poll_response] + + # Mock the get_execution_result method + with patch.object( + sea_client, "get_execution_result", return_value="mock_result_set" + ) as mock_get_result: + # Mock time.sleep to avoid actual delays + with patch("time.sleep"): + # Call the method + result = sea_client.execute_command( + operation="SELECT * FROM large_table", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify the HTTP requests (initial and poll) + assert mock_http_client._make_request.call_count == 2 + + # Verify get_execution_result was called with the right command ID + mock_get_result.assert_called_once() + cmd_id_arg = mock_get_result.call_args[0][0] + assert isinstance(cmd_id_arg, CommandId) + assert cmd_id_arg.guid == "test-statement-789" + + def test_execute_command_with_parameters( + self, 
sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command with parameters.""" + # Set up mock response + execute_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, } - assert set(allowed_configs) == expected_keys + mock_http_client._make_request.return_value = execute_response - def test_unimplemented_methods(self, sea_client): - """Test that unimplemented methods raise NotImplementedError.""" - # Create dummy parameters for testing - session_id = SessionId.from_sea_session_id("test-session") - command_id = MagicMock() - cursor = MagicMock() + # Create parameter mock + param = MagicMock() + param.name = "param1" + param.value = "value1" + param.type = "STRING" - # Test execute_command - with pytest.raises(NotImplementedError) as excinfo: + # Mock the get_execution_result method + with patch.object(sea_client, "get_execution_result") as mock_get_result: + # Call the method with parameters sea_client.execute_command( - operation="SELECT 1", - session_id=session_id, + operation="SELECT * FROM table WHERE col = :param1", + session_id=sea_session_id, max_rows=100, max_bytes=1000, lz4_compression=False, - cursor=cursor, + cursor=mock_cursor, use_cloud_fetch=False, - parameters=[], + parameters=[param], async_op=False, enforce_embedded_schema_correctness=False, ) - assert "execute_command is not yet implemented" in str(excinfo.value) - - # Test cancel_command - with pytest.raises(NotImplementedError) as excinfo: - sea_client.cancel_command(command_id) - assert "cancel_command is not yet implemented" in str(excinfo.value) - - # Test close_command - with pytest.raises(NotImplementedError) as excinfo: - sea_client.close_command(command_id) - assert "close_command is not yet implemented" in str(excinfo.value) - - # Test get_query_state - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_query_state(command_id) - assert "get_query_state is not yet implemented" in str(excinfo.value) - - # Test get_execution_result - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_execution_result(command_id, cursor) - assert "get_execution_result is not yet implemented" in str(excinfo.value) - - # Test metadata operations - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_catalogs(session_id, 100, 1000, cursor) - assert "get_catalogs is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_schemas(session_id, 100, 1000, cursor) - assert "get_schemas is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_tables(session_id, 100, 1000, cursor) - assert "get_tables is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_columns(session_id, 100, 1000, cursor) - assert "get_columns is not yet implemented" in str(excinfo.value) - - def test_max_download_threads_property(self, sea_client): - """Test the max_download_threads property.""" - assert sea_client.max_download_threads == 10 - - # Create a client with a custom value - custom_client = SeaDatabricksClient( - server_hostname="test-server.databricks.com", - port=443, - http_path="/sql/warehouses/abc123", - http_headers=[], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), - max_download_threads=20, + + # Verify the HTTP request contains parameters + mock_http_client._make_request.assert_called_once() + args, kwargs = 
mock_http_client._make_request.call_args + assert "parameters" in kwargs["data"] + assert len(kwargs["data"]["parameters"]) == 1 + assert kwargs["data"]["parameters"][0]["name"] == "param1" + assert kwargs["data"]["parameters"][0]["value"] == "value1" + assert kwargs["data"]["parameters"][0]["type"] == "STRING" + + def test_execute_command_failure( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command that fails.""" + # Set up mock response for a failed execution + error_response = { + "statement_id": "test-statement-123", + "status": { + "state": "FAILED", + "error": { + "message": "Syntax error in SQL", + "error_code": "SYNTAX_ERROR", + }, + }, + } + + # Configure the mock to return the error response for the initial request + # and then raise an exception when trying to poll (to simulate immediate failure) + mock_http_client._make_request.side_effect = [ + error_response, # Initial response + Error( + "Statement execution did not succeed: Syntax error in SQL" + ), # Will be raised during polling + ] + + # Mock time.sleep to avoid actual delays + with patch("time.sleep"): + # Call the method and expect an error + with pytest.raises(Error) as excinfo: + sea_client.execute_command( + operation="SELECT * FROM nonexistent_table", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + assert "Statement execution did not succeed" in str(excinfo.value) + + def test_cancel_command(self, sea_client, mock_http_client, sea_command_id): + """Test canceling a command.""" + # Set up mock response + mock_http_client._make_request.return_value = {} + + # Call the method + sea_client.cancel_command(sea_command_id) + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.CANCEL_STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + def test_close_command(self, sea_client, mock_http_client, sea_command_id): + """Test closing a command.""" + # Set up mock response + mock_http_client._make_request.return_value = {} + + # Call the method + sea_client.close_command(sea_command_id) + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "DELETE" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + def test_get_query_state(self, sea_client, mock_http_client, sea_command_id): + """Test getting the state of a query.""" + # Set up mock response + mock_http_client._make_request.return_value = { + "statement_id": "test-statement-123", + "status": {"state": "RUNNING"}, + } + + # Call the method + state = sea_client.get_query_state(sea_command_id) + + # Verify the result + assert state == CommandState.RUNNING + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "GET" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" ) - # Verify the custom value is returned - assert custom_client.max_download_threads == 20 + def test_get_execution_result( + self, sea_client, mock_http_client, mock_cursor, sea_command_id + ): + """Test getting the 
result of a command execution.""" + # Set up mock response + sea_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 1, + "columns": [ + { + "name": "test_value", + "type_text": "INT", + "type_name": "INT", + "position": 0, + } + ], + }, + "total_chunk_count": 1, + "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], + "total_row_count": 1, + "truncated": False, + }, + "result": { + "chunk_index": 0, + "row_offset": 0, + "row_count": 1, + "data_array": [["1"]], + }, + } + mock_http_client._make_request.return_value = sea_response + + # Create a real result set to verify the implementation + result = sea_client.get_execution_result(sea_command_id, mock_cursor) + print(result) + + # Verify basic properties of the result + assert result.statement_id == "test-statement-123" + assert result.status == CommandState.SUCCEEDED + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "GET" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + # Tests for metadata operations + + def test_get_catalogs(self, sea_client, mock_cursor, sea_session_id): + """Test getting catalogs.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method + result = sea_client.get_catalogs( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW CATALOGS" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["lz4_compression"] is False + assert kwargs["use_cloud_fetch"] is False + assert kwargs["parameters"] == [] + assert kwargs["async_op"] is False + assert kwargs["enforce_embedded_schema_correctness"] is False + + def test_get_schemas_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting schemas with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW SCHEMAS IN `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_schemas_with_catalog_and_schema_pattern( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting schemas with catalog and schema pattern.""" + # Mock execute_command to verify it's called with the right 
parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with catalog and schema pattern + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW SCHEMAS IN `test_catalog` LIKE 'test_schema'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_schemas_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting schemas without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + ) + + assert "Catalog name is required for get_schemas" in str(excinfo.value) + + def test_get_tables_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW TABLES IN CATALOG `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_with_all_catalogs( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables from all catalogs using wildcard.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with wildcard catalog + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="*", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW TABLES IN ALL CATALOGS" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_with_schema_and_table_patterns( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables with schema and table patterns.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + 
sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with catalog, schema, and table patterns + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW TABLES IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' LIKE 'test_table'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting tables without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + table_name="test_table", + ) + + assert "Catalog name is required for get_tables" in str(excinfo.value) + + def test_get_columns_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting columns with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW COLUMNS IN CATALOG `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_columns_with_all_patterns( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting columns with all patterns specified.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with all patterns + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + column_name="test_column", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW COLUMNS IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' TABLE LIKE 'test_table' LIKE 'test_column'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def 
test_get_columns_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting columns without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + table_name="test_table", + ) + + assert "Catalog name is required for get_columns" in str(excinfo.value) diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py new file mode 100644 index 000000000..f666fd613 --- /dev/null +++ b/tests/unit/test_sea_result_set.py @@ -0,0 +1,275 @@ +""" +Tests for the SeaResultSet class. + +This module contains tests for the SeaResultSet class, which implements +the result set functionality for the SEA (Statement Execution API) backend. +""" + +import pytest +from unittest.mock import patch, MagicMock, Mock + +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import CommandId, CommandState, BackendType + + +class TestSeaResultSet: + """Test suite for the SeaResultSet class.""" + + @pytest.fixture + def mock_connection(self): + """Create a mock connection.""" + connection = Mock() + connection.open = True + return connection + + @pytest.fixture + def mock_sea_client(self): + """Create a mock SEA client.""" + return Mock() + + @pytest.fixture + def sea_response(self): + """Create a sample SEA response.""" + return { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 1, + "columns": [ + { + "name": "test_value", + "type_text": "INT", + "type_name": "INT", + "position": 0, + } + ], + }, + "total_chunk_count": 1, + "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], + "total_row_count": 1, + "truncated": False, + }, + "result": { + "chunk_index": 0, + "row_offset": 0, + "row_count": 1, + "data_array": [["1"]], + }, + } + + @pytest.fixture + def execute_response(self): + """Create a sample execute response.""" + mock_response = Mock() + mock_response.command_id = CommandId.from_sea_statement_id("test-statement-123") + mock_response.status = CommandState.SUCCEEDED + mock_response.has_been_closed_server_side = False + mock_response.has_more_rows = False + mock_response.results_queue = None + mock_response.description = [ + ("test_value", "INT", None, None, None, None, None) + ] + mock_response.is_staging_operation = False + mock_response.sea_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "result": {"data_array": [["1"]]}, + } + return mock_response + + def test_init_with_sea_response( + self, mock_connection, mock_sea_client, sea_response + ): + """Test initializing SeaResultSet with a SEA response.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Verify basic properties + assert result_set.statement_id == "test-statement-123" + assert result_set.status == CommandState.SUCCEEDED + assert result_set.command_id.guid == "test-statement-123" + assert result_set.command_id.backend_type == BackendType.SEA + assert result_set.connection == mock_connection + assert result_set.backend == mock_sea_client + assert result_set.buffer_size_bytes == 1000 + assert result_set.arraysize == 100 + assert result_set._response == sea_response + + def 
test_init_with_execute_response( + self, mock_connection, mock_sea_client, execute_response + ): + """Test initializing SeaResultSet with an execute response.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + execute_response=execute_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Verify basic properties + assert result_set.statement_id == "test-statement-123" + assert result_set.status == CommandState.SUCCEEDED + assert result_set.command_id.guid == "test-statement-123" + assert result_set.command_id.backend_type == BackendType.SEA + assert result_set.connection == mock_connection + assert result_set.backend == mock_sea_client + assert result_set.buffer_size_bytes == 1000 + assert result_set.arraysize == 100 + assert result_set._response == execute_response.sea_response + + def test_init_with_no_response(self, mock_connection, mock_sea_client): + """Test that initialization fails when neither response type is provided.""" + with pytest.raises(ValueError) as excinfo: + SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + assert "Either execute_response or sea_response must be provided" in str( + excinfo.value + ) + + def test_close(self, mock_connection, mock_sea_client, sea_response): + """Test closing a result set.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Close the result set + result_set.close() + + # Verify the backend's close_command was called + mock_sea_client.close_command.assert_called_once_with(result_set.command_id) + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_close_when_already_closed_server_side( + self, mock_connection, mock_sea_client, sea_response + ): + """Test closing a result set that has already been closed server-side.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + result_set.has_been_closed_server_side = True + + # Close the result set + result_set.close() + + # Verify the backend's close_command was NOT called + mock_sea_client.close_command.assert_not_called() + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_close_when_connection_closed( + self, mock_connection, mock_sea_client, sea_response + ): + """Test closing a result set when the connection is closed.""" + mock_connection.open = False + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Close the result set + result_set.close() + + # Verify the backend's close_command was NOT called + mock_sea_client.close_command.assert_not_called() + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_unimplemented_methods( + self, mock_connection, mock_sea_client, sea_response + ): + """Test that unimplemented methods raise NotImplementedError.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Test each unimplemented method individually with specific error messages + with pytest.raises( + 
NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + result_set.fetchone() + + with pytest.raises( + NotImplementedError, match="fetchmany is not implemented for SEA backend" + ): + result_set.fetchmany(10) + + with pytest.raises( + NotImplementedError, match="fetchmany is not implemented for SEA backend" + ): + # Test with default parameter value + result_set.fetchmany() + + with pytest.raises( + NotImplementedError, match="fetchall is not implemented for SEA backend" + ): + result_set.fetchall() + + with pytest.raises( + NotImplementedError, + match="fetchmany_arrow is not implemented for SEA backend", + ): + result_set.fetchmany_arrow(10) + + with pytest.raises( + NotImplementedError, + match="fetchall_arrow is not implemented for SEA backend", + ): + result_set.fetchall_arrow() + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + # Test iteration protocol (calls fetchone internally) + next(iter(result_set)) + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + # Test using the result set in a for loop + for row in result_set: + pass + + def test_fill_results_buffer_not_implemented( + self, mock_connection, mock_sea_client, sea_response + ): + """Test that _fill_results_buffer raises NotImplementedError.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + result_set._fill_results_buffer() diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 92de8d8fd..fef070362 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2,6 +2,11 @@ from unittest.mock import patch, Mock import gc +from databricks.sql.thrift_api.TCLIService.ttypes import ( + TOpenSessionResp, + TSessionHandle, + THandleIdentifier, +) from databricks.sql.backend.types import SessionId, BackendType import databricks.sql From d9bcdbef396433e01b298fca9a27b1bce2b1414b Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 09:03:13 +0000 Subject: [PATCH 20/33] remove irrelevant changes Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 124 +----- .../sql/backend/databricks_client.py | 30 ++ src/databricks/sql/backend/sea/backend.py | 360 ++---------------- .../sql/backend/sea/models/__init__.py | 30 -- src/databricks/sql/backend/sea/models/base.py | 68 ---- .../sql/backend/sea/models/requests.py | 110 +----- .../sql/backend/sea/models/responses.py | 95 +---- src/databricks/sql/backend/types.py | 64 ++-- 8 files changed, 107 insertions(+), 774 deletions(-) delete mode 100644 src/databricks/sql/backend/sea/models/base.py diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index 87b62efea..abe6bd1ab 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -6,122 +6,34 @@ logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) - -def test_sea_query_exec(): - """ - Test executing a query using the SEA backend with result compression. - - This function connects to a Databricks SQL endpoint using the SEA backend, - executes a simple query with result compression enabled and disabled, - and verifies that execution completes successfully. 
- """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." - ) - sys.exit(1) - - try: - # Test with compression enabled - logger.info("Creating connection with LZ4 compression enabled") - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - use_cloud_fetch=True, # Enable cloud fetch to use compression - enable_query_result_lz4_compression=True, # Enable LZ4 compression - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - logger.info(f"backend type: {type(connection.session.backend)}") - - # Execute a simple query with compression enabled - cursor = connection.cursor(arraysize=0, buffer_size_bytes=0) - logger.info("Executing query with LZ4 compression: SELECT 1 as test_value") - cursor.execute("SELECT 1 as test_value") - logger.info("Query with compression executed successfully") - cursor.close() - connection.close() - logger.info("Successfully closed SEA session with compression enabled") - - # Test with compression disabled - logger.info("Creating connection with LZ4 compression disabled") - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - use_cloud_fetch=False, # Enable cloud fetch - enable_query_result_lz4_compression=False, # Disable LZ4 compression - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - - # Execute a simple query with compression disabled - cursor = connection.cursor(arraysize=0, buffer_size_bytes=0) - logger.info("Executing query without compression: SELECT 1 as test_value") - cursor.execute("SELECT 1 as test_value") - logger.info("Query without compression executed successfully") - cursor.close() - connection.close() - logger.info("Successfully closed SEA session with compression disabled") - - except Exception as e: - logger.error(f"Error during SEA query execution test: {str(e)}") - import traceback - - logger.error(traceback.format_exc()) - sys.exit(1) - - logger.info("SEA query execution test with compression completed successfully") - - def test_sea_session(): """ Test opening and closing a SEA session using the connector. - + This function connects to a Databricks SQL endpoint using the SEA backend, opens a session, and then closes it. - + Required environment variables: - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint - DATABRICKS_TOKEN: Personal access token for authentication """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") http_path = os.environ.get("DATABRICKS_HTTP_PATH") access_token = os.environ.get("DATABRICKS_TOKEN") catalog = os.environ.get("DATABRICKS_CATALOG") - + if not all([server_hostname, http_path, access_token]): logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
- ) + logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.") sys.exit(1) - + logger.info(f"Connecting to {server_hostname}") logger.info(f"HTTP Path: {http_path}") if catalog: logger.info(f"Using catalog: {catalog}") - + try: logger.info("Creating connection with SEA backend...") connection = Connection( @@ -130,33 +42,25 @@ def test_sea_session(): access_token=access_token, catalog=catalog, schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", # add custom user agent - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + use_sea=True, + user_agent_entry="SEA-Test-Client" # add custom user agent ) + + logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}") logger.info(f"backend type: {type(connection.session.backend)}") - + # Close the connection logger.info("Closing the SEA session...") connection.close() logger.info("Successfully closed SEA session") - + except Exception as e: logger.error(f"Error testing SEA session: {str(e)}") import traceback - logger.error(traceback.format_exc()) sys.exit(1) - + logger.info("SEA session test completed successfully") - if __name__ == "__main__": - # Test session management test_sea_session() - - # Test query execution with compression - test_sea_query_exec() diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index bbca4c502..20b059fa7 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -16,6 +16,8 @@ from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.backend.types import SessionId, CommandId, CommandState +from databricks.sql.utils import ExecuteResponse +from databricks.sql.types import SSLOptions # Forward reference for type hints from typing import TYPE_CHECKING @@ -86,6 +88,34 @@ def execute_command( async_op: bool, enforce_embedded_schema_correctness: bool, ) -> Union["ResultSet", None]: + """ + Executes a SQL command or query within the specified session. + + This method sends a SQL command to the server for execution and handles + the response. It can operate in both synchronous and asynchronous modes. + + Args: + operation: The SQL command or query to execute + session_id: The session identifier in which to execute the command + max_rows: Maximum number of rows to fetch in a single fetch batch + max_bytes: Maximum number of bytes to fetch in a single fetch batch + lz4_compression: Whether to use LZ4 compression for result data + cursor: The cursor object that will handle the results + use_cloud_fetch: Whether to use cloud fetch for retrieving large result sets + parameters: List of parameters to bind to the query + async_op: Whether to execute the command asynchronously + enforce_embedded_schema_correctness: Whether to enforce schema correctness + + Returns: + If async_op is False, returns a ResultSet object containing the + query results and metadata. If async_op is True, returns None and the + results must be fetched later using get_execution_result(). 
+ + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error executing the command + ServerOperationError: If the server encounters an error during execution + """ pass @abstractmethod diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index c7a4ed1b1..97d25a058 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,34 +1,23 @@ import logging import re -import uuid -import time -from typing import Dict, Set, Tuple, List, Optional, Any, Union, TYPE_CHECKING - -from databricks.sql.backend.sea.utils.constants import ( - ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, -) +from typing import Dict, Tuple, List, Optional, TYPE_CHECKING, Set if TYPE_CHECKING: from databricks.sql.client import Cursor - from databricks.sql.result_set import ResultSet from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType -from databricks.sql.exc import Error, NotSupportedError, ServerOperationError +from databricks.sql.exc import ServerOperationError from databricks.sql.backend.sea.utils.http_client import SeaHttpClient +from databricks.sql.backend.sea.utils.constants import ( + ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, +) from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions from databricks.sql.backend.sea.models import ( - ExecuteStatementRequest, - GetStatementRequest, - CancelStatementRequest, - CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, - StatementParameter, - ExecuteStatementResponse, - GetStatementResponse, CreateSessionResponse, ) @@ -66,9 +55,6 @@ def _filter_session_configuration( class SeaDatabricksClient(DatabricksClient): """ Statement Execution API (SEA) implementation of the DatabricksClient interface. - - This implementation provides session management functionality for SEA, - while other operations raise NotImplementedError. """ # SEA API paths @@ -288,222 +274,41 @@ def execute_command( lz4_compression: bool, cursor: "Cursor", use_cloud_fetch: bool, - parameters: List, + parameters: List[ttypes.TSparkParameter], async_op: bool, enforce_embedded_schema_correctness: bool, - ) -> Union["ResultSet", None]: - """ - Execute a SQL command using the SEA backend. 
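# Illustrative sketch (hedged, not part of the patch) of the execute_command
# contract documented in the docstring above; `client`, `session_id`, and
# `cursor` are assumed to have been created elsewhere.
result_set = client.execute_command(
    operation="SELECT 1 AS test_value",
    session_id=session_id,
    max_rows=100,
    max_bytes=1000,
    lz4_compression=False,
    cursor=cursor,
    use_cloud_fetch=False,
    parameters=[],
    async_op=False,  # synchronous: returns a ResultSet once execution finishes
    enforce_embedded_schema_correctness=False,
)
# With async_op=True the call returns None immediately; the results are
# retrieved later via client.get_execution_result(cursor.active_command_id, cursor).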
- - Args: - operation: SQL command to execute - session_id: Session identifier - max_rows: Maximum number of rows to fetch - max_bytes: Maximum number of bytes to fetch - lz4_compression: Whether to use LZ4 compression - cursor: Cursor executing the command - use_cloud_fetch: Whether to use cloud fetch - parameters: SQL parameters - async_op: Whether to execute asynchronously - enforce_embedded_schema_correctness: Whether to enforce schema correctness - - Returns: - ResultSet: A SeaResultSet instance for the executed command - """ - if session_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA session ID") - - sea_session_id = session_id.to_sea_session_id() - - # Convert parameters to StatementParameter objects - sea_parameters = [] - if parameters: - for param in parameters: - sea_parameters.append( - StatementParameter( - name=param.name, - value=param.value, - type=param.type if hasattr(param, "type") else None, - ) - ) - - format = "ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY" - disposition = "EXTERNAL_LINKS" if use_cloud_fetch else "INLINE" - result_compression = "LZ4_FRAME" if lz4_compression else "NONE" - - request = ExecuteStatementRequest( - warehouse_id=self.warehouse_id, - session_id=sea_session_id, - statement=operation, - disposition=disposition, - format=format, - wait_timeout="0s" if async_op else "10s", - on_wait_timeout="CONTINUE", - row_limit=max_rows if max_rows > 0 else None, - byte_limit=max_bytes if max_bytes > 0 else None, - parameters=sea_parameters if sea_parameters else None, - result_compression=result_compression, - ) - - response_data = self.http_client._make_request( - method="POST", path=self.STATEMENT_PATH, data=request.to_dict() + ): + """Not implemented yet.""" + raise NotImplementedError( + "execute_command is not yet implemented for SEA backend" ) - response = ExecuteStatementResponse.from_dict(response_data) - statement_id = response.statement_id - if not statement_id: - raise ServerOperationError( - "Failed to execute command: No statement ID returned", - { - "operation-id": None, - "diagnostic-info": None, - }, - ) - - command_id = CommandId.from_sea_statement_id(statement_id) - - # Store the command ID in the cursor - cursor.active_command_id = command_id - - # If async operation, return None and let the client poll for results - if async_op: - return None - - # For synchronous operation, wait for the statement to complete - # Poll until the statement is done - status = response.status - state = status.state - - # Keep polling until we reach a terminal state - while state in [CommandState.PENDING, CommandState.RUNNING]: - time.sleep(0.5) # add a small delay to avoid excessive API calls - state = self.get_query_state(command_id) - - if state != CommandState.SUCCEEDED: - raise ServerOperationError( - f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}", - { - "operation-id": command_id.to_sea_statement_id(), - "diagnostic-info": None, - }, - ) - - return self.get_execution_result(command_id, cursor) def cancel_command(self, command_id: CommandId) -> None: - """ - Cancel a running command. 
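# Sketch of the synchronous wait that the implementation removed above
# performed, assuming `backend` (a DatabricksClient) and `command_id` exist;
# CommandState is imported as in the module header.
import time

state = backend.get_query_state(command_id)
while state in (CommandState.PENDING, CommandState.RUNNING):
    time.sleep(0.5)  # small delay to avoid excessive API calls
    state = backend.get_query_state(command_id)
# Any terminal state other than SUCCEEDED is then surfaced as a ServerOperationError.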
- - Args: - command_id: Command identifier to cancel - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - request = CancelStatementRequest(statement_id=sea_statement_id) - self.http_client._make_request( - method="POST", - path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), + """Not implemented yet.""" + raise NotImplementedError( + "cancel_command is not yet implemented for SEA backend" ) def close_command(self, command_id: CommandId) -> None: - """ - Close a command and release resources. - - Args: - command_id: Command identifier to close - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - request = CloseStatementRequest(statement_id=sea_statement_id) - self.http_client._make_request( - method="DELETE", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), + """Not implemented yet.""" + raise NotImplementedError( + "close_command is not yet implemented for SEA backend" ) def get_query_state(self, command_id: CommandId) -> CommandState: - """ - Get the state of a running query. - - Args: - command_id: Command identifier - - Returns: - CommandState: The current state of the command - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - request = GetStatementRequest(statement_id=sea_statement_id) - response_data = self.http_client._make_request( - method="GET", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), + """Not implemented yet.""" + raise NotImplementedError( + "get_query_state is not yet implemented for SEA backend" ) - # Parse the response - response = GetStatementResponse.from_dict(response_data) - return response.status.state - def get_execution_result( self, command_id: CommandId, cursor: "Cursor", - ) -> "ResultSet": - """ - Get the result of a command execution. 
- - Args: - command_id: Command identifier - cursor: Cursor executing the command - - Returns: - ResultSet: A SeaResultSet instance with the execution results - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - # Create the request model - request = GetStatementRequest(statement_id=sea_statement_id) - - # Get the statement result - response_data = self.http_client._make_request( - method="GET", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), - ) - - # Create and return a SeaResultSet - from databricks.sql.result_set import SeaResultSet - - return SeaResultSet( - connection=cursor.connection, - sea_response=response_data, - sea_client=self, - buffer_size_bytes=cursor.buffer_size_bytes, - arraysize=cursor.arraysize, + ): + """Not implemented yet.""" + raise NotImplementedError( + "get_execution_result is not yet implemented for SEA backend" ) # == Metadata Operations == @@ -514,22 +319,9 @@ def get_catalogs( max_rows: int, max_bytes: int, cursor: "Cursor", - ) -> "ResultSet": - """Get available catalogs by executing 'SHOW CATALOGS'.""" - result = self.execute_command( - operation="SHOW CATALOGS", - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - return result + ): + """Not implemented yet.""" + raise NotImplementedError("get_catalogs is not yet implemented for SEA backend") def get_schemas( self, @@ -539,30 +331,9 @@ def get_schemas( cursor: "Cursor", catalog_name: Optional[str] = None, schema_name: Optional[str] = None, - ) -> "ResultSet": - """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_schemas") - - operation = f"SHOW SCHEMAS IN `{catalog_name}`" - - if schema_name: - operation += f" LIKE '{schema_name}'" - - result = self.execute_command( - operation=operation, - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - return result + ): + """Not implemented yet.""" + raise NotImplementedError("get_schemas is not yet implemented for SEA backend") def get_tables( self, @@ -574,43 +345,9 @@ def get_tables( schema_name: Optional[str] = None, table_name: Optional[str] = None, table_types: Optional[List[str]] = None, - ) -> "ResultSet": - """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_tables") - - operation = "SHOW TABLES IN " + ( - "ALL CATALOGS" - if catalog_name in [None, "*", "%"] - else f"CATALOG `{catalog_name}`" - ) - - if schema_name: - operation += f" SCHEMA LIKE '{schema_name}'" - - if table_name: - operation += f" LIKE '{table_name}'" - - result = self.execute_command( - operation=operation, - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - 
enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - - # Apply client-side filtering by table_types if specified - from databricks.sql.backend.filters import ResultSetFilter - - result = ResultSetFilter.filter_tables_by_type(result, table_types) - - return result + ): + """Not implemented yet.""" + raise NotImplementedError("get_tables is not yet implemented for SEA backend") def get_columns( self, @@ -622,33 +359,6 @@ def get_columns( schema_name: Optional[str] = None, table_name: Optional[str] = None, column_name: Optional[str] = None, - ) -> "ResultSet": - """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_columns") - - operation = f"SHOW COLUMNS IN CATALOG `{catalog_name}`" - - if schema_name: - operation += f" SCHEMA LIKE '{schema_name}'" - - if table_name: - operation += f" TABLE LIKE '{table_name}'" - - if column_name: - operation += f" LIKE '{column_name}'" - - result = self.execute_command( - operation=operation, - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - return result + ): + """Not implemented yet.""" + raise NotImplementedError("get_columns is not yet implemented for SEA backend") diff --git a/src/databricks/sql/backend/sea/models/__init__.py b/src/databricks/sql/backend/sea/models/__init__.py index b7c8bd399..c9310d367 100644 --- a/src/databricks/sql/backend/sea/models/__init__.py +++ b/src/databricks/sql/backend/sea/models/__init__.py @@ -4,49 +4,19 @@ This package contains data models for SEA API requests and responses. """ -from databricks.sql.backend.sea.models.base import ( - ServiceError, - StatementStatus, - ExternalLink, - ResultData, - ColumnInfo, - ResultManifest, -) - from databricks.sql.backend.sea.models.requests import ( - StatementParameter, - ExecuteStatementRequest, - GetStatementRequest, - CancelStatementRequest, - CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, ) from databricks.sql.backend.sea.models.responses import ( - ExecuteStatementResponse, - GetStatementResponse, CreateSessionResponse, ) __all__ = [ - # Base models - "ServiceError", - "StatementStatus", - "ExternalLink", - "ResultData", - "ColumnInfo", - "ResultManifest", # Request models - "StatementParameter", - "ExecuteStatementRequest", - "GetStatementRequest", - "CancelStatementRequest", - "CloseStatementRequest", "CreateSessionRequest", "DeleteSessionRequest", # Response models - "ExecuteStatementResponse", - "GetStatementResponse", "CreateSessionResponse", ] diff --git a/src/databricks/sql/backend/sea/models/base.py b/src/databricks/sql/backend/sea/models/base.py deleted file mode 100644 index 671f7be13..000000000 --- a/src/databricks/sql/backend/sea/models/base.py +++ /dev/null @@ -1,68 +0,0 @@ -""" -Base models for the SEA (Statement Execution API) backend. - -These models define the common structures used in SEA API requests and responses. 
-""" - -from typing import Dict, List, Any, Optional, Union -from dataclasses import dataclass, field - -from databricks.sql.backend.types import CommandState - - -@dataclass -class ServiceError: - """Error information returned by the SEA API.""" - - message: str - error_code: Optional[str] = None - - -@dataclass -class StatementStatus: - """Status information for a statement execution.""" - - state: CommandState - error: Optional[ServiceError] = None - sql_state: Optional[str] = None - - -@dataclass -class ExternalLink: - """External link information for result data.""" - - external_link: str - expiration: str - chunk_index: int - - -@dataclass -class ResultData: - """Result data from a statement execution.""" - - data: Optional[List[List[Any]]] = None - external_links: Optional[List[ExternalLink]] = None - - -@dataclass -class ColumnInfo: - """Information about a column in the result set.""" - - name: str - type_name: str - type_text: str - nullable: bool = True - precision: Optional[int] = None - scale: Optional[int] = None - ordinal_position: Optional[int] = None - - -@dataclass -class ResultManifest: - """Manifest information for a result set.""" - - schema: List[ColumnInfo] - total_row_count: int - total_byte_count: int - truncated: bool = False - chunk_count: Optional[int] = None diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index 1c519d931..7966cb502 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -1,111 +1,5 @@ -""" -Request models for the SEA (Statement Execution API) backend. - -These models define the structures used in SEA API requests. -""" - -from typing import Dict, List, Any, Optional, Union -from dataclasses import dataclass, field - - -@dataclass -class StatementParameter: - """Parameter for a SQL statement.""" - - name: str - value: Optional[str] = None - type: Optional[str] = None - - -@dataclass -class ExecuteStatementRequest: - """Request to execute a SQL statement.""" - - warehouse_id: str - statement: str - session_id: str - disposition: str = "EXTERNAL_LINKS" - format: str = "JSON_ARRAY" - wait_timeout: str = "10s" - on_wait_timeout: str = "CONTINUE" - row_limit: Optional[int] = None - byte_limit: Optional[int] = None - parameters: Optional[List[StatementParameter]] = None - catalog: Optional[str] = None - schema: Optional[str] = None - result_compression: Optional[str] = None - - def to_dict(self) -> Dict[str, Any]: - """Convert the request to a dictionary for JSON serialization.""" - result: Dict[str, Any] = { - "warehouse_id": self.warehouse_id, - "session_id": self.session_id, - "statement": self.statement, - "disposition": self.disposition, - "format": self.format, - "wait_timeout": self.wait_timeout, - "on_wait_timeout": self.on_wait_timeout, - } - - if self.row_limit is not None and self.row_limit > 0: - result["row_limit"] = self.row_limit - - if self.byte_limit is not None and self.byte_limit > 0: - result["byte_limit"] = self.byte_limit - - if self.catalog: - result["catalog"] = self.catalog - - if self.schema: - result["schema"] = self.schema - - if self.result_compression: - result["result_compression"] = self.result_compression - - if self.parameters: - result["parameters"] = [ - { - "name": param.name, - **({"value": param.value} if param.value is not None else {}), - **({"type": param.type} if param.type is not None else {}), - } - for param in self.parameters - ] - - return result - - -@dataclass -class 
GetStatementRequest: - """Request to get information about a statement.""" - - statement_id: str - - def to_dict(self) -> Dict[str, Any]: - """Convert the request to a dictionary for JSON serialization.""" - return {"statement_id": self.statement_id} - - -@dataclass -class CancelStatementRequest: - """Request to cancel a statement.""" - - statement_id: str - - def to_dict(self) -> Dict[str, Any]: - """Convert the request to a dictionary for JSON serialization.""" - return {"statement_id": self.statement_id} - - -@dataclass -class CloseStatementRequest: - """Request to close a statement.""" - - statement_id: str - - def to_dict(self) -> Dict[str, Any]: - """Convert the request to a dictionary for JSON serialization.""" - return {"statement_id": self.statement_id} +from typing import Dict, Any, Optional +from dataclasses import dataclass @dataclass diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index d70459b9f..1bb54590f 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -1,96 +1,5 @@ -""" -Response models for the SEA (Statement Execution API) backend. - -These models define the structures used in SEA API responses. -""" - -from typing import Dict, List, Any, Optional, Union -from dataclasses import dataclass, field - -from databricks.sql.backend.types import CommandState -from databricks.sql.backend.sea.models.base import ( - StatementStatus, - ResultManifest, - ResultData, - ServiceError, -) - - -@dataclass -class ExecuteStatementResponse: - """Response from executing a SQL statement.""" - - statement_id: str - status: StatementStatus - manifest: Optional[ResultManifest] = None - result: Optional[ResultData] = None - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": - """Create an ExecuteStatementResponse from a dictionary.""" - status_data = data.get("status", {}) - error = None - if "error" in status_data: - error_data = status_data["error"] - error = ServiceError( - message=error_data.get("message", ""), - error_code=error_data.get("error_code"), - ) - - state = CommandState.from_sea_state(status_data.get("state", "")) - if state is None: - raise ValueError(f"Invalid state: {status_data.get('state', '')}") - status = StatementStatus( - state=state, - error=error, - sql_state=status_data.get("sql_state"), - ) - - return cls( - statement_id=data.get("statement_id", ""), - status=status, - manifest=data.get("manifest"), # We'll parse this more fully if needed - result=data.get("result"), # We'll parse this more fully if needed - ) - - -@dataclass -class GetStatementResponse: - """Response from getting information about a statement.""" - - statement_id: str - status: StatementStatus - manifest: Optional[ResultManifest] = None - result: Optional[ResultData] = None - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": - """Create a GetStatementResponse from a dictionary.""" - status_data = data.get("status", {}) - error = None - if "error" in status_data: - error_data = status_data["error"] - error = ServiceError( - message=error_data.get("message", ""), - error_code=error_data.get("error_code"), - ) - - state = CommandState.from_sea_state(status_data.get("state", "")) - if state is None: - raise ValueError(f"Invalid state: {status_data.get('state', '')}") - - status = StatementStatus( - state=state, - error=error, - sql_state=status_data.get("sql_state"), - ) - - return cls( - 
statement_id=data.get("statement_id", ""), - status=status, - manifest=data.get("manifest"), # We'll parse this more fully if needed - result=data.get("result"), # We'll parse this more fully if needed - ) +from typing import Dict, Any +from dataclasses import dataclass @dataclass diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 5bf02e0ea..9cd21b5e6 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -1,6 +1,5 @@ -from dataclasses import dataclass from enum import Enum -from typing import Dict, List, Optional, Any, Tuple +from typing import Dict, Optional, Any import logging from databricks.sql.backend.utils import guid_to_hex_id @@ -81,28 +80,6 @@ def from_thrift_state( else: return None - @classmethod - def from_sea_state(cls, state: str) -> Optional["CommandState"]: - """ - Map SEA state string to CommandState enum. - - Args: - state: SEA state string - - Returns: - CommandState: The corresponding CommandState enum value - """ - state_mapping = { - "PENDING": cls.PENDING, - "RUNNING": cls.RUNNING, - "SUCCEEDED": cls.SUCCEEDED, - "FAILED": cls.FAILED, - "CLOSED": cls.CLOSED, - "CANCELED": cls.CANCELLED, - } - - return state_mapping.get(state, None) - class BackendType(Enum): """ @@ -308,6 +285,28 @@ def __init__( self.has_result_set = has_result_set self.modified_row_count = modified_row_count + def __str__(self) -> str: + """ + Return a string representation of the CommandId. + + For SEA backend, returns the guid. + For Thrift backend, returns a format like "guid|secret". + + Returns: + A string representation of the command ID + """ + + if self.backend_type == BackendType.SEA: + return str(self.guid) + elif self.backend_type == BackendType.THRIFT: + secret_hex = ( + guid_to_hex_id(self.secret) + if isinstance(self.secret, bytes) + else str(self.secret) + ) + return f"{self.to_hex_guid()}|{secret_hex}" + return str(self.guid) + @classmethod def from_thrift_handle(cls, operation_handle): """ @@ -319,6 +318,7 @@ def from_thrift_handle(cls, operation_handle): Returns: A CommandId instance """ + if operation_handle is None: return None @@ -394,19 +394,3 @@ def to_hex_guid(self) -> str: return guid_to_hex_id(self.guid) else: return str(self.guid) - - -@dataclass -class ExecuteResponse: - """Response from executing a SQL command.""" - - command_id: CommandId - status: CommandState - description: Optional[ - List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]] - ] = None - has_more_rows: bool = False - results_queue: Optional[Any] = None - has_been_closed_server_side: bool = False - lz4_compressed: bool = True - is_staging_operation: bool = False From ee9fa1c972bad75557ac0671d5eef96c0a0cff21 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 09:03:59 +0000 Subject: [PATCH 21/33] remove ResultSetFilter functionality Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/filters.py | 143 --------------- tests/unit/test_result_set_filter.py | 246 -------------------------- 2 files changed, 389 deletions(-) delete mode 100644 src/databricks/sql/backend/filters.py delete mode 100644 tests/unit/test_result_set_filter.py diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py deleted file mode 100644 index 7f48b6179..000000000 --- a/src/databricks/sql/backend/filters.py +++ /dev/null @@ -1,143 +0,0 @@ -""" -Client-side filtering utilities for Databricks SQL connector. 
- -This module provides filtering capabilities for result sets returned by different backends. -""" - -import logging -from typing import ( - List, - Optional, - Any, - Callable, - TYPE_CHECKING, -) - -if TYPE_CHECKING: - from databricks.sql.result_set import ResultSet - -from databricks.sql.result_set import SeaResultSet - -logger = logging.getLogger(__name__) - - -class ResultSetFilter: - """ - A general-purpose filter for result sets that can be applied to any backend. - - This class provides methods to filter result sets based on various criteria, - similar to the client-side filtering in the JDBC connector. - """ - - @staticmethod - def _filter_sea_result_set( - result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool] - ) -> "SeaResultSet": - """ - Filter a SEA result set using the provided filter function. - - Args: - result_set: The SEA result set to filter - filter_func: Function that takes a row and returns True if the row should be included - - Returns: - A filtered SEA result set - """ - # Create a filtered version of the result set - filtered_response = result_set._response.copy() - - # If there's a result with rows, filter them - if ( - "result" in filtered_response - and "data_array" in filtered_response["result"] - ): - rows = filtered_response["result"]["data_array"] - filtered_rows = [row for row in rows if filter_func(row)] - filtered_response["result"]["data_array"] = filtered_rows - - # Update row count if present - if "row_count" in filtered_response["result"]: - filtered_response["result"]["row_count"] = len(filtered_rows) - - # Create a new result set with the filtered data - return SeaResultSet( - connection=result_set.connection, - sea_response=filtered_response, - sea_client=result_set.backend, - buffer_size_bytes=result_set.buffer_size_bytes, - arraysize=result_set.arraysize, - ) - - @staticmethod - def filter_by_column_values( - result_set: "ResultSet", - column_index: int, - allowed_values: List[str], - case_sensitive: bool = False, - ) -> "ResultSet": - """ - Filter a result set by values in a specific column. - - Args: - result_set: The result set to filter - column_index: The index of the column to filter on - allowed_values: List of allowed values for the column - case_sensitive: Whether to perform case-sensitive comparison - - Returns: - A filtered result set - """ - # Convert to uppercase for case-insensitive comparison if needed - if not case_sensitive: - allowed_values = [v.upper() for v in allowed_values] - - # Determine the type of result set and apply appropriate filtering - if isinstance(result_set, SeaResultSet): - return ResultSetFilter._filter_sea_result_set( - result_set, - lambda row: ( - len(row) > column_index - and isinstance(row[column_index], str) - and ( - row[column_index].upper() - if not case_sensitive - else row[column_index] - ) - in allowed_values - ), - ) - - # For other result set types, return the original (should be handled by specific implementations) - logger.warning( - f"Filtering not implemented for result set type: {type(result_set).__name__}" - ) - return result_set - - @staticmethod - def filter_tables_by_type( - result_set: "ResultSet", table_types: Optional[List[str]] = None - ) -> "ResultSet": - """ - Filter a result set of tables by the specified table types. - - This is a client-side filter that processes the result set after it has been - retrieved from the server. It filters out tables whose type does not match - any of the types in the table_types list. 
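# Illustrative use of the helper removed by this patch, assuming
# `table_results` is a SeaResultSet produced by a get_tables() call.
from databricks.sql.backend.filters import ResultSetFilter

views_and_tables = ResultSetFilter.filter_tables_by_type(
    table_results, table_types=["TABLE", "VIEW"]
)
# Passing table_types=None (or an empty list) falls back to the defaults,
# ["TABLE", "VIEW", "SYSTEM TABLE"], matched case-insensitively on column 5.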
- - Args: - result_set: The original result set containing tables - table_types: List of table types to include (e.g., ["TABLE", "VIEW"]) - - Returns: - A filtered result set containing only tables of the specified types - """ - # Default table types if none specified - DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"] - valid_types = ( - table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES - ) - - # Table type is typically in the 6th column (index 5) - return ResultSetFilter.filter_by_column_values( - result_set, 5, valid_types, case_sensitive=False - ) diff --git a/tests/unit/test_result_set_filter.py b/tests/unit/test_result_set_filter.py deleted file mode 100644 index e8eb2a757..000000000 --- a/tests/unit/test_result_set_filter.py +++ /dev/null @@ -1,246 +0,0 @@ -""" -Tests for the ResultSetFilter class. - -This module contains tests for the ResultSetFilter class, which provides -filtering capabilities for result sets returned by different backends. -""" - -import pytest -from unittest.mock import patch, MagicMock, Mock - -from databricks.sql.backend.filters import ResultSetFilter -from databricks.sql.result_set import SeaResultSet -from databricks.sql.backend.types import CommandId, CommandState, BackendType - - -class TestResultSetFilter: - """Test suite for the ResultSetFilter class.""" - - @pytest.fixture - def mock_connection(self): - """Create a mock connection.""" - connection = Mock() - connection.open = True - return connection - - @pytest.fixture - def mock_sea_client(self): - """Create a mock SEA client.""" - return Mock() - - @pytest.fixture - def sea_response_with_tables(self): - """Create a sample SEA response with table data based on the server schema.""" - return { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "format": "JSON_ARRAY", - "schema": { - "column_count": 7, - "columns": [ - {"name": "namespace", "type_text": "STRING", "position": 0}, - {"name": "tableName", "type_text": "STRING", "position": 1}, - {"name": "isTemporary", "type_text": "BOOLEAN", "position": 2}, - {"name": "information", "type_text": "STRING", "position": 3}, - {"name": "catalogName", "type_text": "STRING", "position": 4}, - {"name": "tableType", "type_text": "STRING", "position": 5}, - {"name": "remarks", "type_text": "STRING", "position": 6}, - ], - }, - "total_row_count": 4, - }, - "result": { - "data_array": [ - ["schema1", "table1", False, None, "catalog1", "TABLE", None], - ["schema1", "table2", False, None, "catalog1", "VIEW", None], - [ - "schema1", - "table3", - False, - None, - "catalog1", - "SYSTEM TABLE", - None, - ], - ["schema1", "table4", False, None, "catalog1", "EXTERNAL", None], - ] - }, - } - - @pytest.fixture - def sea_result_set_with_tables( - self, mock_connection, mock_sea_client, sea_response_with_tables - ): - """Create a SeaResultSet with table data.""" - # Create a deep copy of the response to avoid test interference - import copy - - sea_response_copy = copy.deepcopy(sea_response_with_tables) - - return SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response_copy, - buffer_size_bytes=1000, - arraysize=100, - ) - - def test_filter_tables_by_type_default(self, sea_result_set_with_tables): - """Test filtering tables by type with default types.""" - # Default types are TABLE, VIEW, SYSTEM TABLE - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables - ) - - # Verify that only the default types are included - assert 
len(filtered_result._response["result"]["data_array"]) == 3 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - assert "SYSTEM TABLE" in table_types - assert "EXTERNAL" not in table_types - - def test_filter_tables_by_type_custom(self, sea_result_set_with_tables): - """Test filtering tables by type with custom types.""" - # Filter for only TABLE and EXTERNAL - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables, table_types=["TABLE", "EXTERNAL"] - ) - - # Verify that only the specified types are included - assert len(filtered_result._response["result"]["data_array"]) == 2 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "EXTERNAL" in table_types - assert "VIEW" not in table_types - assert "SYSTEM TABLE" not in table_types - - def test_filter_tables_by_type_case_insensitive(self, sea_result_set_with_tables): - """Test that table type filtering is case-insensitive.""" - # Filter for lowercase "table" and "view" - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables, table_types=["table", "view"] - ) - - # Verify that the matching types are included despite case differences - assert len(filtered_result._response["result"]["data_array"]) == 2 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - assert "SYSTEM TABLE" not in table_types - assert "EXTERNAL" not in table_types - - def test_filter_tables_by_type_empty_list(self, sea_result_set_with_tables): - """Test filtering tables with an empty type list (should use defaults).""" - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables, table_types=[] - ) - - # Verify that default types are used - assert len(filtered_result._response["result"]["data_array"]) == 3 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - assert "SYSTEM TABLE" in table_types - assert "EXTERNAL" not in table_types - - def test_filter_by_column_values(self, sea_result_set_with_tables): - """Test filtering by values in a specific column.""" - # Filter by namespace in column index 0 - filtered_result = ResultSetFilter.filter_by_column_values( - sea_result_set_with_tables, column_index=0, allowed_values=["schema1"] - ) - - # All rows have schema1 in namespace, so all should be included - assert len(filtered_result._response["result"]["data_array"]) == 4 - - # Filter by table name in column index 1 - filtered_result = ResultSetFilter.filter_by_column_values( - sea_result_set_with_tables, - column_index=1, - allowed_values=["table1", "table3"], - ) - - # Only rows with table1 or table3 should be included - assert len(filtered_result._response["result"]["data_array"]) == 2 - table_names = [ - row[1] for row in filtered_result._response["result"]["data_array"] - ] - assert "table1" in table_names - assert "table3" in table_names - assert "table2" not in table_names - assert "table4" not in table_names - - def test_filter_by_column_values_case_sensitive( - self, mock_connection, mock_sea_client, sea_response_with_tables - ): - """Test case-sensitive filtering by column values.""" - import copy - - # Create a fresh result set for the first test - sea_response_copy1 = 
copy.deepcopy(sea_response_with_tables) - result_set1 = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response_copy1, - buffer_size_bytes=1000, - arraysize=100, - ) - - # First test: Case-sensitive filtering with lowercase values (should find no matches) - filtered_result = ResultSetFilter.filter_by_column_values( - result_set1, - column_index=5, # tableType column - allowed_values=["table", "view"], # lowercase - case_sensitive=True, - ) - - # Verify no matches with lowercase values - assert len(filtered_result._response["result"]["data_array"]) == 0 - - # Create a fresh result set for the second test - sea_response_copy2 = copy.deepcopy(sea_response_with_tables) - result_set2 = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response_copy2, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Second test: Case-sensitive filtering with correct case (should find matches) - filtered_result = ResultSetFilter.filter_by_column_values( - result_set2, - column_index=5, # tableType column - allowed_values=["TABLE", "VIEW"], # correct case - case_sensitive=True, - ) - - # Verify matches with correct case - assert len(filtered_result._response["result"]["data_array"]) == 2 - - # Extract the table types from the filtered results - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - - def test_filter_by_column_values_out_of_bounds(self, sea_result_set_with_tables): - """Test filtering with a column index that's out of bounds.""" - # Filter by column index 10 (out of bounds) - filtered_result = ResultSetFilter.filter_by_column_values( - sea_result_set_with_tables, column_index=10, allowed_values=["value"] - ) - - # No rows should match - assert len(filtered_result._response["result"]["data_array"]) == 0 From 24c6152e9c2c003aa3074057c3d7d6e98d8d1916 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 09:06:23 +0000 Subject: [PATCH 22/33] remove more irrelevant changes Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/types.py | 39 +- tests/unit/test_sea_backend.py | 755 ++++------------------------ 2 files changed, 132 insertions(+), 662 deletions(-) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 9cd21b5e6..3d24a09a1 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -1,5 +1,6 @@ +from dataclasses import dataclass from enum import Enum -from typing import Dict, Optional, Any +from typing import Dict, List, Optional, Any, Tuple import logging from databricks.sql.backend.utils import guid_to_hex_id @@ -80,6 +81,26 @@ def from_thrift_state( else: return None + @classmethod + def from_sea_state(cls, state: str) -> Optional["CommandState"]: + """ + Map SEA state string to CommandState enum. 
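+        For example, the SEA state "CANCELED" maps to
+        CommandState.CANCELLED, and any state outside the mapping yields None.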
+ Args: + state: SEA state string + Returns: + CommandState: The corresponding CommandState enum value + """ + state_mapping = { + "PENDING": cls.PENDING, + "RUNNING": cls.RUNNING, + "SUCCEEDED": cls.SUCCEEDED, + "FAILED": cls.FAILED, + "CLOSED": cls.CLOSED, + "CANCELED": cls.CANCELLED, + } + + return state_mapping.get(state, None) + class BackendType(Enum): """ @@ -394,3 +415,19 @@ def to_hex_guid(self) -> str: return guid_to_hex_id(self.guid) else: return str(self.guid) + + +@dataclass +class ExecuteResponse: + """Response from executing a SQL command.""" + + command_id: CommandId + status: CommandState + description: Optional[ + List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]] + ] = None + has_more_rows: bool = False + results_queue: Optional[Any] = None + has_been_closed_server_side: bool = False + lz4_compressed: bool = True + is_staging_operation: bool = False diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index c2078f731..bc2688a68 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -1,20 +1,11 @@ -""" -Tests for the SEA (Statement Execution API) backend implementation. - -This module contains tests for the SeaDatabricksClient class, which implements -the Databricks SQL connector's SEA backend functionality. -""" - -import json import pytest -from unittest.mock import patch, MagicMock, Mock +from unittest.mock import patch, MagicMock from databricks.sql.backend.sea.backend import SeaDatabricksClient -from databricks.sql.result_set import SeaResultSet -from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType +from databricks.sql.backend.types import SessionId, BackendType from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.exc import Error, NotSupportedError +from databricks.sql.exc import Error class TestSeaBackend: @@ -50,23 +41,6 @@ def sea_client(self, mock_http_client): return client - @pytest.fixture - def sea_session_id(self): - """Create a SEA session ID.""" - return SessionId.from_sea_session_id("test-session-123") - - @pytest.fixture - def sea_command_id(self): - """Create a SEA command ID.""" - return CommandId.from_sea_statement_id("test-statement-123") - - @pytest.fixture - def mock_cursor(self): - """Create a mock cursor.""" - cursor = Mock() - cursor.active_command_id = None - return cursor - def test_init_extracts_warehouse_id(self, mock_http_client): """Test that the constructor properly extracts the warehouse ID from the HTTP path.""" # Test with warehouses format @@ -201,650 +175,109 @@ def test_close_session_invalid_id_type(self, sea_client): assert "Not a valid SEA session ID" in str(excinfo.value) - # Tests for command execution and management - - def test_execute_command_sync( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command synchronously.""" - # Set up mock responses - execute_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "schema": [ - { - "name": "col1", - "type_name": "STRING", - "type_text": "string", - "nullable": True, - } - ], - "total_row_count": 1, - "total_byte_count": 100, - }, - "result": {"data": [["value1"]]}, - } - mock_http_client._make_request.return_value = execute_response - - # Mock the get_execution_result method - with patch.object( - sea_client, "get_execution_result", return_value="mock_result_set" - ) as mock_get_result: - # Call the method 
- result = sea_client.execute_command( - operation="SELECT 1", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == sea_client.STATEMENT_PATH - assert "warehouse_id" in kwargs["data"] - assert "session_id" in kwargs["data"] - assert "statement" in kwargs["data"] - assert kwargs["data"]["statement"] == "SELECT 1" - - # Verify get_execution_result was called with the right command ID - mock_get_result.assert_called_once() - cmd_id_arg = mock_get_result.call_args[0][0] - assert isinstance(cmd_id_arg, CommandId) - assert cmd_id_arg.guid == "test-statement-123" - - def test_execute_command_async( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command asynchronously.""" - # Set up mock response - execute_response = { - "statement_id": "test-statement-456", - "status": {"state": "PENDING"}, - } - mock_http_client._make_request.return_value = execute_response - - # Call the method - result = sea_client.execute_command( - operation="SELECT 1", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=True, # Async mode - enforce_embedded_schema_correctness=False, + def test_session_configuration_helpers(self): + """Test the session configuration helper methods.""" + # Test getting default value for a supported parameter + default_value = SeaDatabricksClient.get_default_session_configuration_value( + "ANSI_MODE" ) + assert default_value == "true" - # Verify the result is None for async operation - assert result is None - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == sea_client.STATEMENT_PATH - assert "wait_timeout" in kwargs["data"] - assert kwargs["data"]["wait_timeout"] == "0s" # Async mode uses 0s timeout - - # Verify the command ID was stored in the cursor - assert hasattr(mock_cursor, "active_command_id") - assert isinstance(mock_cursor.active_command_id, CommandId) - assert mock_cursor.active_command_id.guid == "test-statement-456" - - def test_execute_command_with_polling( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command that requires polling.""" - # Set up mock responses for initial request and polling - initial_response = { - "statement_id": "test-statement-789", - "status": {"state": "RUNNING"}, - } - poll_response = { - "statement_id": "test-statement-789", - "status": {"state": "SUCCEEDED"}, - "manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0}, - "result": {"data": []}, - } - - # Configure mock to return different responses on subsequent calls - mock_http_client._make_request.side_effect = [initial_response, poll_response] - - # Mock the get_execution_result method - with patch.object( - sea_client, "get_execution_result", return_value="mock_result_set" - ) as mock_get_result: - # Mock time.sleep to avoid actual delays - with patch("time.sleep"): - # Call the method - result = 
sea_client.execute_command( - operation="SELECT * FROM large_table", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify the HTTP requests (initial and poll) - assert mock_http_client._make_request.call_count == 2 - - # Verify get_execution_result was called with the right command ID - mock_get_result.assert_called_once() - cmd_id_arg = mock_get_result.call_args[0][0] - assert isinstance(cmd_id_arg, CommandId) - assert cmd_id_arg.guid == "test-statement-789" - - def test_execute_command_with_parameters( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command with parameters.""" - # Set up mock response - execute_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, + # Test getting default value for an unsupported parameter + default_value = SeaDatabricksClient.get_default_session_configuration_value( + "UNSUPPORTED_PARAM" + ) + assert default_value is None + + # Test getting the list of allowed configurations + allowed_configs = SeaDatabricksClient.get_allowed_session_configurations() + + expected_keys = { + "ANSI_MODE", + "ENABLE_PHOTON", + "LEGACY_TIME_PARSER_POLICY", + "MAX_FILE_PARTITION_BYTES", + "READ_ONLY_EXTERNAL_METASTORE", + "STATEMENT_TIMEOUT", + "TIMEZONE", + "USE_CACHED_RESULT", } - mock_http_client._make_request.return_value = execute_response + assert set(allowed_configs) == expected_keys - # Create parameter mock - param = MagicMock() - param.name = "param1" - param.value = "value1" - param.type = "STRING" + def test_unimplemented_methods(self, sea_client): + """Test that unimplemented methods raise NotImplementedError.""" + # Create dummy parameters for testing + session_id = SessionId.from_sea_session_id("test-session") + command_id = MagicMock() + cursor = MagicMock() - # Mock the get_execution_result method - with patch.object(sea_client, "get_execution_result") as mock_get_result: - # Call the method with parameters + # Test execute_command + with pytest.raises(NotImplementedError) as excinfo: sea_client.execute_command( - operation="SELECT * FROM table WHERE col = :param1", - session_id=sea_session_id, + operation="SELECT 1", + session_id=session_id, max_rows=100, max_bytes=1000, lz4_compression=False, - cursor=mock_cursor, + cursor=cursor, use_cloud_fetch=False, - parameters=[param], + parameters=[], async_op=False, enforce_embedded_schema_correctness=False, ) - - # Verify the HTTP request contains parameters - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert "parameters" in kwargs["data"] - assert len(kwargs["data"]["parameters"]) == 1 - assert kwargs["data"]["parameters"][0]["name"] == "param1" - assert kwargs["data"]["parameters"][0]["value"] == "value1" - assert kwargs["data"]["parameters"][0]["type"] == "STRING" - - def test_execute_command_failure( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command that fails.""" - # Set up mock response for a failed execution - error_response = { - "statement_id": "test-statement-123", - "status": { - "state": "FAILED", - "error": { - "message": "Syntax error in SQL", - "error_code": "SYNTAX_ERROR", - }, - }, - } - - # Configure the mock to return the error response for the initial request - # 
and then raise an exception when trying to poll (to simulate immediate failure) - mock_http_client._make_request.side_effect = [ - error_response, # Initial response - Error( - "Statement execution did not succeed: Syntax error in SQL" - ), # Will be raised during polling - ] - - # Mock time.sleep to avoid actual delays - with patch("time.sleep"): - # Call the method and expect an error - with pytest.raises(Error) as excinfo: - sea_client.execute_command( - operation="SELECT * FROM nonexistent_table", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - assert "Statement execution did not succeed" in str(excinfo.value) - - def test_cancel_command(self, sea_client, mock_http_client, sea_command_id): - """Test canceling a command.""" - # Set up mock response - mock_http_client._make_request.return_value = {} - - # Call the method - sea_client.cancel_command(sea_command_id) - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == sea_client.CANCEL_STATEMENT_PATH_WITH_ID.format( - "test-statement-123" - ) - - def test_close_command(self, sea_client, mock_http_client, sea_command_id): - """Test closing a command.""" - # Set up mock response - mock_http_client._make_request.return_value = {} - - # Call the method - sea_client.close_command(sea_command_id) - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "DELETE" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" - ) - - def test_get_query_state(self, sea_client, mock_http_client, sea_command_id): - """Test getting the state of a query.""" - # Set up mock response - mock_http_client._make_request.return_value = { - "statement_id": "test-statement-123", - "status": {"state": "RUNNING"}, - } - - # Call the method - state = sea_client.get_query_state(sea_command_id) - - # Verify the result - assert state == CommandState.RUNNING - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "GET" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" - ) - - def test_get_execution_result( - self, sea_client, mock_http_client, mock_cursor, sea_command_id - ): - """Test getting the result of a command execution.""" - # Set up mock response - sea_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "format": "JSON_ARRAY", - "schema": { - "column_count": 1, - "columns": [ - { - "name": "test_value", - "type_text": "INT", - "type_name": "INT", - "position": 0, - } - ], - }, - "total_chunk_count": 1, - "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], - "total_row_count": 1, - "truncated": False, - }, - "result": { - "chunk_index": 0, - "row_offset": 0, - "row_count": 1, - "data_array": [["1"]], - }, - } - mock_http_client._make_request.return_value = sea_response - - # Create a real result set to verify the implementation - result = sea_client.get_execution_result(sea_command_id, mock_cursor) - print(result) - - # Verify basic properties of the result - assert 
result.statement_id == "test-statement-123" - assert result.status == CommandState.SUCCEEDED - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "GET" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" + assert "execute_command is not yet implemented" in str(excinfo.value) + + # Test cancel_command + with pytest.raises(NotImplementedError) as excinfo: + sea_client.cancel_command(command_id) + assert "cancel_command is not yet implemented" in str(excinfo.value) + + # Test close_command + with pytest.raises(NotImplementedError) as excinfo: + sea_client.close_command(command_id) + assert "close_command is not yet implemented" in str(excinfo.value) + + # Test get_query_state + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_query_state(command_id) + assert "get_query_state is not yet implemented" in str(excinfo.value) + + # Test get_execution_result + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_execution_result(command_id, cursor) + assert "get_execution_result is not yet implemented" in str(excinfo.value) + + # Test metadata operations + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_catalogs(session_id, 100, 1000, cursor) + assert "get_catalogs is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_schemas(session_id, 100, 1000, cursor) + assert "get_schemas is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_tables(session_id, 100, 1000, cursor) + assert "get_tables is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_columns(session_id, 100, 1000, cursor) + assert "get_columns is not yet implemented" in str(excinfo.value) + + def test_max_download_threads_property(self, sea_client): + """Test the max_download_threads property.""" + assert sea_client.max_download_threads == 10 + + # Create a client with a custom value + custom_client = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/warehouses/abc123", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + max_download_threads=20, ) - # Tests for metadata operations - - def test_get_catalogs(self, sea_client, mock_cursor, sea_session_id): - """Test getting catalogs.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method - result = sea_client.get_catalogs( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW CATALOGS" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["lz4_compression"] is False - assert kwargs["use_cloud_fetch"] is False - assert kwargs["parameters"] == [] - assert kwargs["async_op"] is False - assert kwargs["enforce_embedded_schema_correctness"] is False - - def 
test_get_schemas_with_catalog_only( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting schemas with only catalog name.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with only catalog - result = sea_client.get_schemas( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW SCHEMAS IN `test_catalog`" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_schemas_with_catalog_and_schema_pattern( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting schemas with catalog and schema pattern.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with catalog and schema pattern - result = sea_client.get_schemas( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - schema_name="test_schema", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert ( - kwargs["operation"] - == "SHOW SCHEMAS IN `test_catalog` LIKE 'test_schema'" - ) - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_schemas_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): - """Test getting schemas without a catalog name raises an error.""" - with pytest.raises(ValueError) as excinfo: - sea_client.get_schemas( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name=None, # No catalog name - schema_name="test_schema", - ) - - assert "Catalog name is required for get_schemas" in str(excinfo.value) - - def test_get_tables_with_catalog_only( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting tables with only catalog name.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with only catalog - result = sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW TABLES IN CATALOG `test_catalog`" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def 
test_get_tables_with_all_catalogs( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting tables from all catalogs using wildcard.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with wildcard catalog - result = sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="*", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW TABLES IN ALL CATALOGS" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_tables_with_schema_and_table_patterns( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting tables with schema and table patterns.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with catalog, schema, and table patterns - result = sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - schema_name="test_schema", - table_name="test_table", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert ( - kwargs["operation"] - == "SHOW TABLES IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' LIKE 'test_table'" - ) - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_tables_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): - """Test getting tables without a catalog name raises an error.""" - with pytest.raises(ValueError) as excinfo: - sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name=None, # No catalog name - schema_name="test_schema", - table_name="test_table", - ) - - assert "Catalog name is required for get_tables" in str(excinfo.value) - - def test_get_columns_with_catalog_only( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting columns with only catalog name.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with only catalog - result = sea_client.get_columns( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW COLUMNS IN CATALOG `test_catalog`" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == 
mock_cursor - assert kwargs["async_op"] is False - - def test_get_columns_with_all_patterns( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting columns with all patterns specified.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with all patterns - result = sea_client.get_columns( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - schema_name="test_schema", - table_name="test_table", - column_name="test_column", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert ( - kwargs["operation"] - == "SHOW COLUMNS IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' TABLE LIKE 'test_table' LIKE 'test_column'" - ) - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_columns_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): - """Test getting columns without a catalog name raises an error.""" - with pytest.raises(ValueError) as excinfo: - sea_client.get_columns( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name=None, # No catalog name - schema_name="test_schema", - table_name="test_table", - ) - - assert "Catalog name is required for get_columns" in str(excinfo.value) + # Verify the custom value is returned + assert custom_client.max_download_threads == 20 From 67fd1012f9496724aa05183f82d9c92f0c40f1ed Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 09:10:48 +0000 Subject: [PATCH 23/33] remove more irrelevant changes Signed-off-by: varun-edachali-dbx --- .../sql/backend/databricks_client.py | 2 - src/databricks/sql/backend/thrift_backend.py | 3 +- src/databricks/sql/result_set.py | 91 +++++++++---------- 3 files changed, 44 insertions(+), 52 deletions(-) diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index 20b059fa7..8fda71e1e 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -16,8 +16,6 @@ from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.backend.types import SessionId, CommandId, CommandState -from databricks.sql.utils import ExecuteResponse -from databricks.sql.types import SSLOptions # Forward reference for type hints from typing import TYPE_CHECKING diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 810c2e7a1..f41f4b6d8 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -352,6 +352,7 @@ def make_request(self, method, request, retryable=True): Will stop retry attempts if total elapsed time + next retry delay would exceed _retry_stop_after_attempts_duration. """ + # basic strategy: build range iterator rep'ing number of available # retries. bounds can be computed from there. iterate over it with # retries until success or final failure achieved. 
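
The docstring and comment above describe the shape of this retry loop: iterate
over a fixed budget of attempts, and stop early once the elapsed time plus the
next delay would exceed _retry_stop_after_attempts_duration. A minimal sketch
of that pattern, using hypothetical names rather than the connector's actual
helpers:

    import time

    def call_with_retries(attempt_fn, max_attempts=5, delay_s=2.0, stop_after_s=30.0):
        start = time.monotonic()
        for attempt in range(max_attempts):
            try:
                return attempt_fn()
            except OSError:
                # Give up if this was the last attempt, or if sleeping again
                # would push total elapsed time past the duration cap.
                if attempt == max_attempts - 1 or (
                    time.monotonic() - start + delay_s > stop_after_s
                ):
                    raise
                time.sleep(delay_s)
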
@@ -1241,7 +1242,7 @@ def cancel_command(self, command_id: CommandId) -> None: if not thrift_handle: raise ValueError("Not a valid Thrift command ID") - logger.debug("Cancelling command {}".format(command_id.guid)) + logger.debug("Cancelling command {}".format(guid_to_hex_id(command_id.guid))) req = ttypes.TCancelOperationReq(thrift_handle) self.make_request(self._client.CancelOperation, req) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index a4beda629..dd61408db 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -64,7 +64,7 @@ def __init__( """ self.connection = connection - self.backend = backend # Store the backend client directly + self.backend = backend self.arraysize = arraysize self.buffer_size_bytes = buffer_size_bytes self._next_row_index = 0 @@ -115,12 +115,12 @@ def fetchall(self) -> List[Row]: pass @abstractmethod - def fetchmany_arrow(self, size: int) -> Any: + def fetchmany_arrow(self, size: int) -> "pyarrow.Table": """Fetch the next set of rows as an Arrow table.""" pass @abstractmethod - def fetchall_arrow(self) -> Any: + def fetchall_arrow(self) -> "pyarrow.Table": """Fetch all remaining rows as an Arrow table.""" pass @@ -207,7 +207,7 @@ def _fill_results_buffer(self): use_cloud_fetch=self._use_cloud_fetch, ) self.results = results - self._has_more_rows = has_more_rows + self.has_more_rows = has_more_rows def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] @@ -291,7 +291,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self._has_more_rows + and self.has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -316,7 +316,7 @@ def fetchmany_columnar(self, size: int): while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self._has_more_rows + and self.has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -331,7 +331,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self._has_more_rows: + while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() if isinstance(results, ColumnTable) and isinstance( @@ -357,7 +357,7 @@ def fetchall_columnar(self): results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self._has_more_rows: + while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() results = self.merge_columnar(results, partial_results) @@ -402,6 +402,33 @@ def fetchmany(self, size: int) -> List[Row]: @staticmethod def _get_schema_description(table_schema_message): + """ + Takes a TableSchema message and returns a description 7-tuple as specified by PEP-249 + """ + + def map_col_type(type_): + if type_.startswith("decimal"): + return "decimal" + else: + return type_ + + return [ + (column.name, map_col_type(column.datatype), None, None, None, None, None) + for column in table_schema_message.columns + ] + + +class SeaResultSet(ResultSet): + """ResultSet implementation for the SEA backend.""" + + def __init__( + self, + connection: "Connection", + execute_response: "ExecuteResponse", + 
sea_client: "SeaDatabricksClient", + buffer_size_bytes: int = 104857600, + arraysize: int = 10000, + ): """ Initialize a SeaResultSet with the response from a SEA query execution. @@ -413,53 +440,19 @@ def _get_schema_description(table_schema_message): execute_response: Response from the execute command (new style) sea_response: Direct SEA response (legacy style) """ - # Handle both initialization styles - if execute_response is not None: - # New style with ExecuteResponse - command_id = execute_response.command_id - status = execute_response.status - has_been_closed_server_side = execute_response.has_been_closed_server_side - has_more_rows = execute_response.has_more_rows - results_queue = execute_response.results_queue - description = execute_response.description - is_staging_operation = execute_response.is_staging_operation - self._response = getattr(execute_response, "sea_response", {}) - self.statement_id = command_id.to_sea_statement_id() if command_id else None - elif sea_response is not None: - # Legacy style with direct sea_response - self._response = sea_response - # Extract values from sea_response - command_id = CommandId.from_sea_statement_id( - sea_response.get("statement_id", "") - ) - self.statement_id = sea_response.get("statement_id", "") - - # Extract status - status_data = sea_response.get("status", {}) - status = CommandState.from_sea_state(status_data.get("state", "PENDING")) - - # Set defaults for other fields - has_been_closed_server_side = False - has_more_rows = False - results_queue = None - description = None - is_staging_operation = False - else: - raise ValueError("Either execute_response or sea_response must be provided") - # Call parent constructor with common attributes super().__init__( connection=connection, backend=sea_client, arraysize=arraysize, buffer_size_bytes=buffer_size_bytes, - command_id=command_id, - status=status, - has_been_closed_server_side=has_been_closed_server_side, - has_more_rows=has_more_rows, - results_queue=results_queue, - description=description, - is_staging_operation=is_staging_operation, + command_id=execute_response.command_id, + status=execute_response.status, + has_been_closed_server_side=execute_response.has_been_closed_server_side, + has_more_rows=execute_response.has_more_rows, + results_queue=execute_response.results_queue, + description=execute_response.description, + is_staging_operation=execute_response.is_staging_operation, ) def _fill_results_buffer(self): From 271fcafbb04e7c5e08423b7536dac57f9595c5b6 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 09:12:13 +0000 Subject: [PATCH 24/33] even more irrelevant changes Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 2 +- tests/unit/test_session.py | 5 ----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index dd61408db..7ea4976c1 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -259,7 +259,7 @@ def _convert_arrow_table(self, table): res = df.to_numpy(na_value=None, dtype="object") return [ResultRow(*v) for v in res] - def merge_columnar(self, result1, result2): + def merge_columnar(self, result1, result2) -> "ColumnTable": """ Function to merge / combining the columnar results into a single result :param result1: diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index fef070362..92de8d8fd 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2,11 +2,6 @@ from 
unittest.mock import patch, Mock import gc -from databricks.sql.thrift_api.TCLIService.ttypes import ( - TOpenSessionResp, - TSessionHandle, - THandleIdentifier, -) from databricks.sql.backend.types import SessionId, BackendType import databricks.sql From bf26ea3e4dae441d0e82d1f55c3da36ee2282568 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 09:19:46 +0000 Subject: [PATCH 25/33] remove sea response as init option Signed-off-by: varun-edachali-dbx --- tests/unit/test_sea_result_set.py | 103 ++++-------------------------- 1 file changed, 14 insertions(+), 89 deletions(-) diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index f666fd613..02421a915 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -27,38 +27,6 @@ def mock_sea_client(self): """Create a mock SEA client.""" return Mock() - @pytest.fixture - def sea_response(self): - """Create a sample SEA response.""" - return { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "format": "JSON_ARRAY", - "schema": { - "column_count": 1, - "columns": [ - { - "name": "test_value", - "type_text": "INT", - "type_name": "INT", - "position": 0, - } - ], - }, - "total_chunk_count": 1, - "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], - "total_row_count": 1, - "truncated": False, - }, - "result": { - "chunk_index": 0, - "row_offset": 0, - "row_count": 1, - "data_array": [["1"]], - }, - } - @pytest.fixture def execute_response(self): """Create a sample execute response.""" @@ -72,78 +40,35 @@ def execute_response(self): ("test_value", "INT", None, None, None, None, None) ] mock_response.is_staging_operation = False - mock_response.sea_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "result": {"data_array": [["1"]]}, - } return mock_response - def test_init_with_sea_response( - self, mock_connection, mock_sea_client, sea_response - ): - """Test initializing SeaResultSet with a SEA response.""" - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Verify basic properties - assert result_set.statement_id == "test-statement-123" - assert result_set.status == CommandState.SUCCEEDED - assert result_set.command_id.guid == "test-statement-123" - assert result_set.command_id.backend_type == BackendType.SEA - assert result_set.connection == mock_connection - assert result_set.backend == mock_sea_client - assert result_set.buffer_size_bytes == 1000 - assert result_set.arraysize == 100 - assert result_set._response == sea_response - def test_init_with_execute_response( self, mock_connection, mock_sea_client, execute_response ): """Test initializing SeaResultSet with an execute response.""" result_set = SeaResultSet( connection=mock_connection, - sea_client=mock_sea_client, execute_response=execute_response, + sea_client=mock_sea_client, buffer_size_bytes=1000, arraysize=100, ) # Verify basic properties - assert result_set.statement_id == "test-statement-123" + assert result_set.command_id == execute_response.command_id assert result_set.status == CommandState.SUCCEEDED - assert result_set.command_id.guid == "test-statement-123" - assert result_set.command_id.backend_type == BackendType.SEA assert result_set.connection == mock_connection assert result_set.backend == mock_sea_client assert result_set.buffer_size_bytes == 1000 assert result_set.arraysize == 
100 - assert result_set._response == execute_response.sea_response + assert result_set.description == execute_response.description - def test_init_with_no_response(self, mock_connection, mock_sea_client): - """Test that initialization fails when neither response type is provided.""" - with pytest.raises(ValueError) as excinfo: - SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - assert "Either execute_response or sea_response must be provided" in str( - excinfo.value - ) - - def test_close(self, mock_connection, mock_sea_client, sea_response): + def test_close(self, mock_connection, mock_sea_client, execute_response): """Test closing a result set.""" result_set = SeaResultSet( connection=mock_connection, + execute_response=execute_response, sea_client=mock_sea_client, - sea_response=sea_response, buffer_size_bytes=1000, arraysize=100, ) @@ -157,13 +82,13 @@ def test_close(self, mock_connection, mock_sea_client, sea_response): assert result_set.status == CommandState.CLOSED def test_close_when_already_closed_server_side( - self, mock_connection, mock_sea_client, sea_response + self, mock_connection, mock_sea_client, execute_response ): """Test closing a result set that has already been closed server-side.""" result_set = SeaResultSet( connection=mock_connection, + execute_response=execute_response, sea_client=mock_sea_client, - sea_response=sea_response, buffer_size_bytes=1000, arraysize=100, ) @@ -178,14 +103,14 @@ def test_close_when_already_closed_server_side( assert result_set.status == CommandState.CLOSED def test_close_when_connection_closed( - self, mock_connection, mock_sea_client, sea_response + self, mock_connection, mock_sea_client, execute_response ): """Test closing a result set when the connection is closed.""" mock_connection.open = False result_set = SeaResultSet( connection=mock_connection, + execute_response=execute_response, sea_client=mock_sea_client, - sea_response=sea_response, buffer_size_bytes=1000, arraysize=100, ) @@ -199,13 +124,13 @@ def test_close_when_connection_closed( assert result_set.status == CommandState.CLOSED def test_unimplemented_methods( - self, mock_connection, mock_sea_client, sea_response + self, mock_connection, mock_sea_client, execute_response ): """Test that unimplemented methods raise NotImplementedError.""" result_set = SeaResultSet( connection=mock_connection, + execute_response=execute_response, sea_client=mock_sea_client, - sea_response=sea_response, buffer_size_bytes=1000, arraysize=100, ) @@ -258,13 +183,13 @@ def test_unimplemented_methods( pass def test_fill_results_buffer_not_implemented( - self, mock_connection, mock_sea_client, sea_response + self, mock_connection, mock_sea_client, execute_response ): """Test that _fill_results_buffer raises NotImplementedError.""" result_set = SeaResultSet( connection=mock_connection, + execute_response=execute_response, sea_client=mock_sea_client, - sea_response=sea_response, buffer_size_bytes=1000, arraysize=100, ) @@ -272,4 +197,4 @@ def test_fill_results_buffer_not_implemented( with pytest.raises( NotImplementedError, match="fetchone is not implemented for SEA backend" ): - result_set._fill_results_buffer() + result_set._fill_results_buffer() \ No newline at end of file From ed7cf9138e937774546fa0f3e793a6eb8768060a Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 10:06:36 +0000 Subject: [PATCH 26/33] exec test example scripts Signed-off-by: varun-edachali-dbx --- 
examples/experimental/sea_connector_test.py | 147 ++++++++++------ examples/experimental/tests/__init__.py | 1 + .../tests/test_sea_async_query.py | 165 ++++++++++++++++++ .../experimental/tests/test_sea_metadata.py | 91 ++++++++++ .../experimental/tests/test_sea_session.py | 70 ++++++++ .../experimental/tests/test_sea_sync_query.py | 143 +++++++++++++++ 6 files changed, 566 insertions(+), 51 deletions(-) create mode 100644 examples/experimental/tests/__init__.py create mode 100644 examples/experimental/tests/test_sea_async_query.py create mode 100644 examples/experimental/tests/test_sea_metadata.py create mode 100644 examples/experimental/tests/test_sea_session.py create mode 100644 examples/experimental/tests/test_sea_sync_query.py diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index abe6bd1ab..33b5af334 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -1,66 +1,111 @@ +""" +Main script to run all SEA connector tests. + +This script imports and runs all the individual test modules and displays +a summary of test results with visual indicators. +""" import os import sys import logging -from databricks.sql.client import Connection +import importlib.util +from typing import Dict, Callable, List, Tuple -logging.basicConfig(level=logging.DEBUG) +# Configure logging +logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -def test_sea_session(): - """ - Test opening and closing a SEA session using the connector. +# Define test modules and their main test functions +TEST_MODULES = [ + "test_sea_session", + "test_sea_sync_query", + "test_sea_async_query", + "test_sea_metadata", +] + +def load_test_function(module_name: str) -> Callable: + """Load a test function from a module.""" + module_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "tests", + f"{module_name}.py" + ) + + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + # Get the main test function (assuming it starts with "test_") + for name in dir(module): + if name.startswith("test_") and callable(getattr(module, name)): + # For sync and async query modules, we want the main function that runs both tests + if name == f"test_sea_{module_name.replace('test_sea_', '')}_exec": + return getattr(module, name) - This function connects to a Databricks SQL endpoint using the SEA backend, - opens a session, and then closes it. 
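+    # (The loop above prefers the module's aggregate "test_sea_*_exec" runner.)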
+ # Fallback to the first test function found + for name in dir(module): + if name.startswith("test_") and callable(getattr(module, name)): + return getattr(module, name) - Required environment variables: - - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname - - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint - - DATABRICKS_TOKEN: Personal access token for authentication - """ + raise ValueError(f"No test function found in module {module_name}") - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") +def run_tests() -> List[Tuple[str, bool]]: + """Run all tests and return results.""" + results = [] - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.") - sys.exit(1) + for module_name in TEST_MODULES: + try: + test_func = load_test_function(module_name) + logger.info(f"\n{'=' * 50}") + logger.info(f"Running test: {module_name}") + logger.info(f"{'-' * 50}") + + success = test_func() + results.append((module_name, success)) + + status = "✅ PASSED" if success else "❌ FAILED" + logger.info(f"Test {module_name}: {status}") + + except Exception as e: + logger.error(f"Error loading or running test {module_name}: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + results.append((module_name, False)) - logger.info(f"Connecting to {server_hostname}") - logger.info(f"HTTP Path: {http_path}") - if catalog: - logger.info(f"Using catalog: {catalog}") + return results + +def print_summary(results: List[Tuple[str, bool]]) -> None: + """Print a summary of test results.""" + logger.info(f"\n{'=' * 50}") + logger.info("TEST SUMMARY") + logger.info(f"{'-' * 50}") - try: - logger.info("Creating connection with SEA backend...") - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client" # add custom user agent - ) - - logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}") - logger.info(f"backend type: {type(connection.session.backend)}") - - # Close the connection - logger.info("Closing the SEA session...") - connection.close() - logger.info("Successfully closed SEA session") - - except Exception as e: - logger.error(f"Error testing SEA session: {str(e)}") - import traceback - logger.error(traceback.format_exc()) - sys.exit(1) + passed = sum(1 for _, success in results if success) + total = len(results) - logger.info("SEA session test completed successfully") + for module_name, success in results: + status = "✅ PASSED" if success else "❌ FAILED" + logger.info(f"{status} - {module_name}") + + logger.info(f"{'-' * 50}") + logger.info(f"Total: {total} | Passed: {passed} | Failed: {total - passed}") + logger.info(f"{'=' * 50}") if __name__ == "__main__": - test_sea_session() + # Check if required environment variables are set + required_vars = ["DATABRICKS_SERVER_HOSTNAME", "DATABRICKS_HTTP_PATH", "DATABRICKS_TOKEN"] + missing_vars = [var for var in required_vars if not os.environ.get(var)] + + if missing_vars: + logger.error(f"Missing required environment variables: {', '.join(missing_vars)}") + logger.error("Please set these variables before running the tests.") + sys.exit(1) + + # Run 
all tests + results = run_tests() + + # Print summary + print_summary(results) + + # Exit with appropriate status code + all_passed = all(success for _, success in results) + sys.exit(0 if all_passed else 1) \ No newline at end of file diff --git a/examples/experimental/tests/__init__.py b/examples/experimental/tests/__init__.py new file mode 100644 index 000000000..5e1a8a58b --- /dev/null +++ b/examples/experimental/tests/__init__.py @@ -0,0 +1 @@ +# This file makes the tests directory a Python package \ No newline at end of file diff --git a/examples/experimental/tests/test_sea_async_query.py b/examples/experimental/tests/test_sea_async_query.py new file mode 100644 index 000000000..a4f3702f9 --- /dev/null +++ b/examples/experimental/tests/test_sea_async_query.py @@ -0,0 +1,165 @@ +""" +Test for SEA asynchronous query execution functionality. +""" +import os +import sys +import logging +import time +from databricks.sql.client import Connection +from databricks.sql.backend.types import CommandState + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_async_query_with_cloud_fetch(): + """ + Test executing a query asynchronously using the SEA backend with cloud fetch enabled. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query asynchronously with cloud fetch enabled, and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + try: + # Create connection with cloud fetch enabled + logger.info("Creating connection for asynchronous query execution with cloud fetch enabled") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=True, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query asynchronously + cursor = connection.cursor() + logger.info("Executing asynchronous query with cloud fetch: SELECT 1 as test_value") + cursor.execute_async("SELECT 1 as test_value") + logger.info("Asynchronous query submitted successfully with cloud fetch enabled") + + # Check query state + logger.info("Checking query state...") + while cursor.is_query_pending(): + logger.info("Query is still pending, waiting...") + time.sleep(1) + + logger.info("Query is no longer pending, getting results...") + cursor.get_async_execution_result() + logger.info("Successfully retrieved asynchronous query results with cloud fetch enabled") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error during SEA asynchronous query execution test with cloud fetch: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return False + + +def test_sea_async_query_without_cloud_fetch(): + """ + Test executing a query asynchronously using the SEA backend with cloud fetch disabled. 
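+    With cloud fetch disabled, result data is expected to come back inline in
+    the response rather than as external download links.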
+ + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query asynchronously with cloud fetch disabled, and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + try: + # Create connection with cloud fetch disabled + logger.info("Creating connection for asynchronous query execution with cloud fetch disabled") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=False, + enable_query_result_lz4_compression=False, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query asynchronously + cursor = connection.cursor() + logger.info("Executing asynchronous query without cloud fetch: SELECT 1 as test_value") + cursor.execute_async("SELECT 1 as test_value") + logger.info("Asynchronous query submitted successfully with cloud fetch disabled") + + # Check query state + logger.info("Checking query state...") + while cursor.is_query_pending(): + logger.info("Query is still pending, waiting...") + time.sleep(1) + + logger.info("Query is no longer pending, getting results...") + cursor.get_async_execution_result() + logger.info("Successfully retrieved asynchronous query results with cloud fetch disabled") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error during SEA asynchronous query execution test without cloud fetch: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return False + + +def test_sea_async_query_exec(): + """ + Run both asynchronous query tests and return overall success. + """ + with_cloud_fetch_success = test_sea_async_query_with_cloud_fetch() + logger.info(f"Asynchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}") + + without_cloud_fetch_success = test_sea_async_query_without_cloud_fetch() + logger.info(f"Asynchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}") + + return with_cloud_fetch_success and without_cloud_fetch_success + + +if __name__ == "__main__": + success = test_sea_async_query_exec() + sys.exit(0 if success else 1) \ No newline at end of file diff --git a/examples/experimental/tests/test_sea_metadata.py b/examples/experimental/tests/test_sea_metadata.py new file mode 100644 index 000000000..ba760b61a --- /dev/null +++ b/examples/experimental/tests/test_sea_metadata.py @@ -0,0 +1,91 @@ +""" +Test for SEA metadata functionality. +""" +import os +import sys +import logging +from databricks.sql.client import Connection + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_metadata(): + """ + Test metadata operations using the SEA backend. 
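+    The SEA backend is expected to serve these by issuing SHOW CATALOGS,
+    SHOW SCHEMAS, SHOW TABLES, and SHOW COLUMNS statements.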
+ + This function connects to a Databricks SQL endpoint using the SEA backend, + and executes metadata operations like catalogs(), schemas(), tables(), and columns(). + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + if not catalog: + logger.error("DATABRICKS_CATALOG environment variable is required for metadata tests.") + return False + + try: + # Create connection + logger.info("Creating connection for metadata operations") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Test catalogs + cursor = connection.cursor() + logger.info("Fetching catalogs...") + cursor.catalogs() + logger.info("Successfully fetched catalogs") + + # Test schemas + logger.info(f"Fetching schemas for catalog '{catalog}'...") + cursor.schemas(catalog_name=catalog) + logger.info("Successfully fetched schemas") + + # Test tables + logger.info(f"Fetching tables for catalog '{catalog}', schema 'default'...") + cursor.tables(catalog_name=catalog, schema_name="default") + logger.info("Successfully fetched tables") + + # Test columns for a specific table + # Using a common table that should exist in most environments + logger.info(f"Fetching columns for catalog '{catalog}', schema 'default', table 'information_schema'...") + cursor.columns(catalog_name=catalog, schema_name="default", table_name="information_schema") + logger.info("Successfully fetched columns") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error during SEA metadata test: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return False + + +if __name__ == "__main__": + success = test_sea_metadata() + sys.exit(0 if success else 1) \ No newline at end of file diff --git a/examples/experimental/tests/test_sea_session.py b/examples/experimental/tests/test_sea_session.py new file mode 100644 index 000000000..c0f6817da --- /dev/null +++ b/examples/experimental/tests/test_sea_session.py @@ -0,0 +1,70 @@ +""" +Test for SEA session management functionality. +""" +import os +import sys +import logging +from databricks.sql.client import Connection + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_session(): + """ + Test opening and closing a SEA session using the connector. + + This function connects to a Databricks SQL endpoint using the SEA backend, + opens a session, and then closes it. 
+ + Required environment variables: + - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname + - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint + - DATABRICKS_TOKEN: Personal access token for authentication + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + try: + logger.info("Creating connection with SEA backend...") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + logger.info(f"Backend type: {type(connection.session.backend)}") + + # Close the connection + logger.info("Closing the SEA session...") + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error testing SEA session: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return False + + +if __name__ == "__main__": + success = test_sea_session() + sys.exit(0 if success else 1) \ No newline at end of file diff --git a/examples/experimental/tests/test_sea_sync_query.py b/examples/experimental/tests/test_sea_sync_query.py new file mode 100644 index 000000000..4879e587a --- /dev/null +++ b/examples/experimental/tests/test_sea_sync_query.py @@ -0,0 +1,143 @@ +""" +Test for SEA synchronous query execution functionality. +""" +import os +import sys +import logging +from databricks.sql.client import Connection + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_sync_query_with_cloud_fetch(): + """ + Test executing a query synchronously using the SEA backend with cloud fetch enabled. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query with cloud fetch enabled, and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
+ ) + return False + + try: + # Create connection with cloud fetch enabled + logger.info("Creating connection for synchronous query execution with cloud fetch enabled") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=True, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query + cursor = connection.cursor() + logger.info("Executing synchronous query with cloud fetch: SELECT 1 as test_value") + cursor.execute("SELECT 1 as test_value") + logger.info("Query executed successfully with cloud fetch enabled") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error during SEA synchronous query execution test with cloud fetch: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return False + + +def test_sea_sync_query_without_cloud_fetch(): + """ + Test executing a query synchronously using the SEA backend with cloud fetch disabled. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query with cloud fetch disabled, and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + try: + # Create connection with cloud fetch disabled + logger.info("Creating connection for synchronous query execution with cloud fetch disabled") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=False, + enable_query_result_lz4_compression=False, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query + cursor = connection.cursor() + logger.info("Executing synchronous query without cloud fetch: SELECT 1 as test_value") + cursor.execute("SELECT 1 as test_value") + logger.info("Query executed successfully with cloud fetch disabled") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error during SEA synchronous query execution test without cloud fetch: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return False + + +def test_sea_sync_query_exec(): + """ + Run both synchronous query tests and return overall success. 
+ """ + with_cloud_fetch_success = test_sea_sync_query_with_cloud_fetch() + logger.info(f"Synchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}") + + without_cloud_fetch_success = test_sea_sync_query_without_cloud_fetch() + logger.info(f"Synchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}") + + return with_cloud_fetch_success and without_cloud_fetch_success + + +if __name__ == "__main__": + success = test_sea_sync_query_exec() + sys.exit(0 if success else 1) \ No newline at end of file From dae15e37b6161740481084c405aeff84278c73cd Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 10:10:23 +0000 Subject: [PATCH 27/33] formatting (black) Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 53 ++++++++------ examples/experimental/tests/__init__.py | 1 - .../tests/test_sea_async_query.py | 72 +++++++++++++------ .../experimental/tests/test_sea_metadata.py | 27 ++++--- .../experimental/tests/test_sea_session.py | 5 +- .../experimental/tests/test_sea_sync_query.py | 48 +++++++++---- 6 files changed, 133 insertions(+), 73 deletions(-) diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index 33b5af334..b03f8ff64 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -22,90 +22,99 @@ "test_sea_metadata", ] + def load_test_function(module_name: str) -> Callable: """Load a test function from a module.""" module_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "tests", - f"{module_name}.py" + os.path.dirname(os.path.abspath(__file__)), "tests", f"{module_name}.py" ) - + spec = importlib.util.spec_from_file_location(module_name, module_path) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) - + # Get the main test function (assuming it starts with "test_") for name in dir(module): if name.startswith("test_") and callable(getattr(module, name)): # For sync and async query modules, we want the main function that runs both tests if name == f"test_sea_{module_name.replace('test_sea_', '')}_exec": return getattr(module, name) - + # Fallback to the first test function found for name in dir(module): if name.startswith("test_") and callable(getattr(module, name)): return getattr(module, name) - + raise ValueError(f"No test function found in module {module_name}") + def run_tests() -> List[Tuple[str, bool]]: """Run all tests and return results.""" results = [] - + for module_name in TEST_MODULES: try: test_func = load_test_function(module_name) logger.info(f"\n{'=' * 50}") logger.info(f"Running test: {module_name}") logger.info(f"{'-' * 50}") - + success = test_func() results.append((module_name, success)) - + status = "✅ PASSED" if success else "❌ FAILED" logger.info(f"Test {module_name}: {status}") - + except Exception as e: logger.error(f"Error loading or running test {module_name}: {str(e)}") import traceback + logger.error(traceback.format_exc()) results.append((module_name, False)) - + return results + def print_summary(results: List[Tuple[str, bool]]) -> None: """Print a summary of test results.""" logger.info(f"\n{'=' * 50}") logger.info("TEST SUMMARY") logger.info(f"{'-' * 50}") - + passed = sum(1 for _, success in results if success) total = len(results) - + for module_name, success in results: status = "✅ PASSED" if success else "❌ FAILED" logger.info(f"{status} - {module_name}") - + logger.info(f"{'-' * 50}") 
logger.info(f"Total: {total} | Passed: {passed} | Failed: {total - passed}") logger.info(f"{'=' * 50}") + if __name__ == "__main__": # Check if required environment variables are set - required_vars = ["DATABRICKS_SERVER_HOSTNAME", "DATABRICKS_HTTP_PATH", "DATABRICKS_TOKEN"] + required_vars = [ + "DATABRICKS_SERVER_HOSTNAME", + "DATABRICKS_HTTP_PATH", + "DATABRICKS_TOKEN", + ] missing_vars = [var for var in required_vars if not os.environ.get(var)] - + if missing_vars: - logger.error(f"Missing required environment variables: {', '.join(missing_vars)}") + logger.error( + f"Missing required environment variables: {', '.join(missing_vars)}" + ) logger.error("Please set these variables before running the tests.") sys.exit(1) - + # Run all tests results = run_tests() - + # Print summary print_summary(results) - + # Exit with appropriate status code all_passed = all(success for _, success in results) - sys.exit(0 if all_passed else 1) \ No newline at end of file + sys.exit(0 if all_passed else 1) diff --git a/examples/experimental/tests/__init__.py b/examples/experimental/tests/__init__.py index 5e1a8a58b..e69de29bb 100644 --- a/examples/experimental/tests/__init__.py +++ b/examples/experimental/tests/__init__.py @@ -1 +0,0 @@ -# This file makes the tests directory a Python package \ No newline at end of file diff --git a/examples/experimental/tests/test_sea_async_query.py b/examples/experimental/tests/test_sea_async_query.py index a4f3702f9..a776377c3 100644 --- a/examples/experimental/tests/test_sea_async_query.py +++ b/examples/experimental/tests/test_sea_async_query.py @@ -33,7 +33,9 @@ def test_sea_async_query_with_cloud_fetch(): try: # Create connection with cloud fetch enabled - logger.info("Creating connection for asynchronous query execution with cloud fetch enabled") + logger.info( + "Creating connection for asynchronous query execution with cloud fetch enabled" + ) connection = Connection( server_hostname=server_hostname, http_path=http_path, @@ -51,30 +53,39 @@ def test_sea_async_query_with_cloud_fetch(): # Execute a simple query asynchronously cursor = connection.cursor() - logger.info("Executing asynchronous query with cloud fetch: SELECT 1 as test_value") + logger.info( + "Executing asynchronous query with cloud fetch: SELECT 1 as test_value" + ) cursor.execute_async("SELECT 1 as test_value") - logger.info("Asynchronous query submitted successfully with cloud fetch enabled") - + logger.info( + "Asynchronous query submitted successfully with cloud fetch enabled" + ) + # Check query state logger.info("Checking query state...") while cursor.is_query_pending(): logger.info("Query is still pending, waiting...") time.sleep(1) - + logger.info("Query is no longer pending, getting results...") cursor.get_async_execution_result() - logger.info("Successfully retrieved asynchronous query results with cloud fetch enabled") - + logger.info( + "Successfully retrieved asynchronous query results with cloud fetch enabled" + ) + # Close resources cursor.close() connection.close() logger.info("Successfully closed SEA session") - + return True except Exception as e: - logger.error(f"Error during SEA asynchronous query execution test with cloud fetch: {str(e)}") + logger.error( + f"Error during SEA asynchronous query execution test with cloud fetch: {str(e)}" + ) import traceback + logger.error(traceback.format_exc()) return False @@ -100,7 +111,9 @@ def test_sea_async_query_without_cloud_fetch(): try: # Create connection with cloud fetch disabled - logger.info("Creating connection for asynchronous 
query execution with cloud fetch disabled") + logger.info( + "Creating connection for asynchronous query execution with cloud fetch disabled" + ) connection = Connection( server_hostname=server_hostname, http_path=http_path, @@ -119,30 +132,39 @@ def test_sea_async_query_without_cloud_fetch(): # Execute a simple query asynchronously cursor = connection.cursor() - logger.info("Executing asynchronous query without cloud fetch: SELECT 1 as test_value") + logger.info( + "Executing asynchronous query without cloud fetch: SELECT 1 as test_value" + ) cursor.execute_async("SELECT 1 as test_value") - logger.info("Asynchronous query submitted successfully with cloud fetch disabled") - + logger.info( + "Asynchronous query submitted successfully with cloud fetch disabled" + ) + # Check query state logger.info("Checking query state...") while cursor.is_query_pending(): logger.info("Query is still pending, waiting...") time.sleep(1) - + logger.info("Query is no longer pending, getting results...") cursor.get_async_execution_result() - logger.info("Successfully retrieved asynchronous query results with cloud fetch disabled") - + logger.info( + "Successfully retrieved asynchronous query results with cloud fetch disabled" + ) + # Close resources cursor.close() connection.close() logger.info("Successfully closed SEA session") - + return True except Exception as e: - logger.error(f"Error during SEA asynchronous query execution test without cloud fetch: {str(e)}") + logger.error( + f"Error during SEA asynchronous query execution test without cloud fetch: {str(e)}" + ) import traceback + logger.error(traceback.format_exc()) return False @@ -152,14 +174,18 @@ def test_sea_async_query_exec(): Run both asynchronous query tests and return overall success. """ with_cloud_fetch_success = test_sea_async_query_with_cloud_fetch() - logger.info(f"Asynchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}") - + logger.info( + f"Asynchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}" + ) + without_cloud_fetch_success = test_sea_async_query_without_cloud_fetch() - logger.info(f"Asynchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}") - + logger.info( + f"Asynchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}" + ) + return with_cloud_fetch_success and without_cloud_fetch_success if __name__ == "__main__": success = test_sea_async_query_exec() - sys.exit(0 if success else 1) \ No newline at end of file + sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_metadata.py b/examples/experimental/tests/test_sea_metadata.py index ba760b61a..c715e5984 100644 --- a/examples/experimental/tests/test_sea_metadata.py +++ b/examples/experimental/tests/test_sea_metadata.py @@ -28,9 +28,11 @@ def test_sea_metadata(): "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." ) return False - + if not catalog: - logger.error("DATABRICKS_CATALOG environment variable is required for metadata tests.") + logger.error( + "DATABRICKS_CATALOG environment variable is required for metadata tests." 
+ ) return False try: @@ -55,37 +57,42 @@ def test_sea_metadata(): logger.info("Fetching catalogs...") cursor.catalogs() logger.info("Successfully fetched catalogs") - + # Test schemas logger.info(f"Fetching schemas for catalog '{catalog}'...") cursor.schemas(catalog_name=catalog) logger.info("Successfully fetched schemas") - + # Test tables logger.info(f"Fetching tables for catalog '{catalog}', schema 'default'...") cursor.tables(catalog_name=catalog, schema_name="default") logger.info("Successfully fetched tables") - + # Test columns for a specific table # Using a common table that should exist in most environments - logger.info(f"Fetching columns for catalog '{catalog}', schema 'default', table 'information_schema'...") - cursor.columns(catalog_name=catalog, schema_name="default", table_name="information_schema") + logger.info( + f"Fetching columns for catalog '{catalog}', schema 'default', table 'information_schema'..." + ) + cursor.columns( + catalog_name=catalog, schema_name="default", table_name="information_schema" + ) logger.info("Successfully fetched columns") - + # Close resources cursor.close() connection.close() logger.info("Successfully closed SEA session") - + return True except Exception as e: logger.error(f"Error during SEA metadata test: {str(e)}") import traceback + logger.error(traceback.format_exc()) return False if __name__ == "__main__": success = test_sea_metadata() - sys.exit(0 if success else 1) \ No newline at end of file + sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_session.py b/examples/experimental/tests/test_sea_session.py index c0f6817da..516c1bbb8 100644 --- a/examples/experimental/tests/test_sea_session.py +++ b/examples/experimental/tests/test_sea_session.py @@ -55,16 +55,17 @@ def test_sea_session(): logger.info("Closing the SEA session...") connection.close() logger.info("Successfully closed SEA session") - + return True except Exception as e: logger.error(f"Error testing SEA session: {str(e)}") import traceback + logger.error(traceback.format_exc()) return False if __name__ == "__main__": success = test_sea_session() - sys.exit(0 if success else 1) \ No newline at end of file + sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_sync_query.py b/examples/experimental/tests/test_sea_sync_query.py index 4879e587a..07be8aafc 100644 --- a/examples/experimental/tests/test_sea_sync_query.py +++ b/examples/experimental/tests/test_sea_sync_query.py @@ -31,7 +31,9 @@ def test_sea_sync_query_with_cloud_fetch(): try: # Create connection with cloud fetch enabled - logger.info("Creating connection for synchronous query execution with cloud fetch enabled") + logger.info( + "Creating connection for synchronous query execution with cloud fetch enabled" + ) connection = Connection( server_hostname=server_hostname, http_path=http_path, @@ -49,20 +51,25 @@ def test_sea_sync_query_with_cloud_fetch(): # Execute a simple query cursor = connection.cursor() - logger.info("Executing synchronous query with cloud fetch: SELECT 1 as test_value") + logger.info( + "Executing synchronous query with cloud fetch: SELECT 1 as test_value" + ) cursor.execute("SELECT 1 as test_value") logger.info("Query executed successfully with cloud fetch enabled") - + # Close resources cursor.close() connection.close() logger.info("Successfully closed SEA session") - + return True except Exception as e: - logger.error(f"Error during SEA synchronous query execution test with cloud fetch: {str(e)}") + logger.error( + f"Error during SEA 
synchronous query execution test with cloud fetch: {str(e)}" + ) import traceback + logger.error(traceback.format_exc()) return False @@ -88,7 +95,9 @@ def test_sea_sync_query_without_cloud_fetch(): try: # Create connection with cloud fetch disabled - logger.info("Creating connection for synchronous query execution with cloud fetch disabled") + logger.info( + "Creating connection for synchronous query execution with cloud fetch disabled" + ) connection = Connection( server_hostname=server_hostname, http_path=http_path, @@ -107,20 +116,25 @@ def test_sea_sync_query_without_cloud_fetch(): # Execute a simple query cursor = connection.cursor() - logger.info("Executing synchronous query without cloud fetch: SELECT 1 as test_value") + logger.info( + "Executing synchronous query without cloud fetch: SELECT 1 as test_value" + ) cursor.execute("SELECT 1 as test_value") logger.info("Query executed successfully with cloud fetch disabled") - + # Close resources cursor.close() connection.close() logger.info("Successfully closed SEA session") - + return True except Exception as e: - logger.error(f"Error during SEA synchronous query execution test without cloud fetch: {str(e)}") + logger.error( + f"Error during SEA synchronous query execution test without cloud fetch: {str(e)}" + ) import traceback + logger.error(traceback.format_exc()) return False @@ -130,14 +144,18 @@ def test_sea_sync_query_exec(): Run both synchronous query tests and return overall success. """ with_cloud_fetch_success = test_sea_sync_query_with_cloud_fetch() - logger.info(f"Synchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}") - + logger.info( + f"Synchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}" + ) + without_cloud_fetch_success = test_sea_sync_query_without_cloud_fetch() - logger.info(f"Synchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}") - + logger.info( + f"Synchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}" + ) + return with_cloud_fetch_success and without_cloud_fetch_success if __name__ == "__main__": success = test_sea_sync_query_exec() - sys.exit(0 if success else 1) \ No newline at end of file + sys.exit(0 if success else 1) From 3e22c6c4f297a3c83dbebba7c57e3bc8c0c5fe9a Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 06:34:34 +0000 Subject: [PATCH 28/33] change to valid table name Signed-off-by: varun-edachali-dbx --- examples/experimental/tests/test_sea_metadata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/experimental/tests/test_sea_metadata.py b/examples/experimental/tests/test_sea_metadata.py index c715e5984..394c48b24 100644 --- a/examples/experimental/tests/test_sea_metadata.py +++ b/examples/experimental/tests/test_sea_metadata.py @@ -74,7 +74,7 @@ def test_sea_metadata(): f"Fetching columns for catalog '{catalog}', schema 'default', table 'information_schema'..." 
) cursor.columns( - catalog_name=catalog, schema_name="default", table_name="information_schema" + catalog_name=catalog, schema_name="default", table_name="customer" ) logger.info("Successfully fetched columns") From 165c4f35ce69f282b03e6522c6ea72c6d0a8f5fc Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 13:18:39 +0000 Subject: [PATCH 29/33] remove un-necessary changes covered by #588 Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 11 +- src/databricks/sql/result_set.py | 73 ------- tests/unit/test_sea_result_set.py | 200 ------------------- tests/unit/test_thrift_backend.py | 32 +-- 4 files changed, 7 insertions(+), 309 deletions(-) delete mode 100644 tests/unit/test_sea_result_set.py diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index d28a2c6fd..e824de1c2 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -15,6 +15,7 @@ CommandId, ExecuteResponse, ) +from databricks.sql.backend.utils import guid_to_hex_id try: @@ -841,8 +842,6 @@ def get_execution_result( status = self.get_query_state(command_id) - status = self.get_query_state(command_id) - execute_response = ExecuteResponse( command_id=command_id, status=status, @@ -895,7 +894,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState: self._check_command_not_in_error_or_closed_state(thrift_handle, poll_resp) state = CommandState.from_thrift_state(operation_state) if state is None: - raise ValueError(f"Invalid operation state: {operation_state}") + raise ValueError(f"Unknown command state: {operation_state}") return state @staticmethod @@ -1189,11 +1188,7 @@ def _handle_execute_response(self, resp, cursor): resp.directResults and resp.directResults.operationStatus, ) - ( - execute_response, - arrow_schema_bytes, - ) = self._results_message_to_execute_response(resp, final_operation_state) - return execute_response, arrow_schema_bytes + return self._results_message_to_execute_response(resp, final_operation_state) def _handle_execute_response_async(self, resp, cursor): command_id = CommandId.from_thrift_handle(resp.operationHandle) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 97b10cbbe..cf6940bb2 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -438,76 +438,3 @@ def map_col_type(type_): (column.name, map_col_type(column.datatype), None, None, None, None, None) for column in table_schema_message.columns ] - - -class SeaResultSet(ResultSet): - """ResultSet implementation for the SEA backend.""" - - def __init__( - self, - connection: "Connection", - execute_response: "ExecuteResponse", - sea_client: "SeaDatabricksClient", - buffer_size_bytes: int = 104857600, - arraysize: int = 10000, - ): - """ - Initialize a SeaResultSet with the response from a SEA query execution. 
- - Args: - connection: The parent connection - sea_client: The SeaDatabricksClient instance for direct access - buffer_size_bytes: Buffer size for fetching results - arraysize: Default number of rows to fetch - execute_response: Response from the execute command (new style) - sea_response: Direct SEA response (legacy style) - """ - - super().__init__( - connection=connection, - backend=sea_client, - arraysize=arraysize, - buffer_size_bytes=buffer_size_bytes, - command_id=execute_response.command_id, - status=execute_response.status, - has_been_closed_server_side=execute_response.has_been_closed_server_side, - has_more_rows=execute_response.has_more_rows, - results_queue=execute_response.results_queue, - description=execute_response.description, - is_staging_operation=execute_response.is_staging_operation, - ) - - def _fill_results_buffer(self): - """Fill the results buffer from the backend.""" - raise NotImplementedError("fetchone is not implemented for SEA backend") - - def fetchone(self) -> Optional[Row]: - """ - Fetch the next row of a query result set, returning a single sequence, - or None when no more data is available. - """ - - raise NotImplementedError("fetchone is not implemented for SEA backend") - - def fetchmany(self, size: Optional[int] = None) -> List[Row]: - """ - Fetch the next set of rows of a query result, returning a list of rows. - - An empty sequence is returned when no more rows are available. - """ - - raise NotImplementedError("fetchmany is not implemented for SEA backend") - - def fetchall(self) -> List[Row]: - """ - Fetch all (remaining) rows of a query result, returning them as a list of rows. - """ - raise NotImplementedError("fetchall is not implemented for SEA backend") - - def fetchmany_arrow(self, size: int) -> Any: - """Fetch the next set of rows as an Arrow table.""" - raise NotImplementedError("fetchmany_arrow is not implemented for SEA backend") - - def fetchall_arrow(self) -> Any: - """Fetch all remaining rows as an Arrow table.""" - raise NotImplementedError("fetchall_arrow is not implemented for SEA backend") diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py deleted file mode 100644 index 02421a915..000000000 --- a/tests/unit/test_sea_result_set.py +++ /dev/null @@ -1,200 +0,0 @@ -""" -Tests for the SeaResultSet class. - -This module contains tests for the SeaResultSet class, which implements -the result set functionality for the SEA (Statement Execution API) backend. 
-""" - -import pytest -from unittest.mock import patch, MagicMock, Mock - -from databricks.sql.result_set import SeaResultSet -from databricks.sql.backend.types import CommandId, CommandState, BackendType - - -class TestSeaResultSet: - """Test suite for the SeaResultSet class.""" - - @pytest.fixture - def mock_connection(self): - """Create a mock connection.""" - connection = Mock() - connection.open = True - return connection - - @pytest.fixture - def mock_sea_client(self): - """Create a mock SEA client.""" - return Mock() - - @pytest.fixture - def execute_response(self): - """Create a sample execute response.""" - mock_response = Mock() - mock_response.command_id = CommandId.from_sea_statement_id("test-statement-123") - mock_response.status = CommandState.SUCCEEDED - mock_response.has_been_closed_server_side = False - mock_response.has_more_rows = False - mock_response.results_queue = None - mock_response.description = [ - ("test_value", "INT", None, None, None, None, None) - ] - mock_response.is_staging_operation = False - return mock_response - - def test_init_with_execute_response( - self, mock_connection, mock_sea_client, execute_response - ): - """Test initializing SeaResultSet with an execute response.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Verify basic properties - assert result_set.command_id == execute_response.command_id - assert result_set.status == CommandState.SUCCEEDED - assert result_set.connection == mock_connection - assert result_set.backend == mock_sea_client - assert result_set.buffer_size_bytes == 1000 - assert result_set.arraysize == 100 - assert result_set.description == execute_response.description - - def test_close(self, mock_connection, mock_sea_client, execute_response): - """Test closing a result set.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Close the result set - result_set.close() - - # Verify the backend's close_command was called - mock_sea_client.close_command.assert_called_once_with(result_set.command_id) - assert result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_close_when_already_closed_server_side( - self, mock_connection, mock_sea_client, execute_response - ): - """Test closing a result set that has already been closed server-side.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - result_set.has_been_closed_server_side = True - - # Close the result set - result_set.close() - - # Verify the backend's close_command was NOT called - mock_sea_client.close_command.assert_not_called() - assert result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_close_when_connection_closed( - self, mock_connection, mock_sea_client, execute_response - ): - """Test closing a result set when the connection is closed.""" - mock_connection.open = False - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Close the result set - result_set.close() - - # Verify the backend's close_command was NOT called - 
mock_sea_client.close_command.assert_not_called() - assert result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_unimplemented_methods( - self, mock_connection, mock_sea_client, execute_response - ): - """Test that unimplemented methods raise NotImplementedError.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Test each unimplemented method individually with specific error messages - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - result_set.fetchone() - - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - result_set.fetchmany(10) - - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - # Test with default parameter value - result_set.fetchmany() - - with pytest.raises( - NotImplementedError, match="fetchall is not implemented for SEA backend" - ): - result_set.fetchall() - - with pytest.raises( - NotImplementedError, - match="fetchmany_arrow is not implemented for SEA backend", - ): - result_set.fetchmany_arrow(10) - - with pytest.raises( - NotImplementedError, - match="fetchall_arrow is not implemented for SEA backend", - ): - result_set.fetchall_arrow() - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - # Test iteration protocol (calls fetchone internally) - next(iter(result_set)) - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - # Test using the result set in a for loop - for row in result_set: - pass - - def test_fill_results_buffer_not_implemented( - self, mock_connection, mock_sea_client, execute_response - ): - """Test that _fill_results_buffer raises NotImplementedError.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - result_set._fill_results_buffer() \ No newline at end of file diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 88adcd3e9..57b5e9b58 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -619,13 +619,6 @@ def test_handle_execute_response_sets_compression_in_direct_results( lz4Compressed = Mock() resultSet = MagicMock() resultSet.results.startRowOffset = 0 - - # Create a valid operation status - op_status = ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ) - t_execute_resp = resp_type( status=Mock(), operationHandle=Mock(), @@ -927,9 +920,7 @@ def test_handle_execute_response_can_handle_with_direct_results(self): auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._results_message_to_execute_response = Mock( - return_value=(Mock(), Mock()) - ) + thrift_backend._results_message_to_execute_response = Mock() thrift_backend._handle_execute_response(execute_resp, Mock()) @@ -957,12 +948,6 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): operationHandle=self.operation_handle, ) - # Mock the operation status response - op_state = ttypes.TGetOperationStatusResp( - status=self.okay_status, - 
operationState=ttypes.TOperationState.FINISHED_STATE, - ) - tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = ( t_get_result_set_metadata_resp ) @@ -977,7 +962,7 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): t_execute_resp, Mock() ) - self.assertEqual(arrow_schema_bytes, arrow_schema_mock) + self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): @@ -997,12 +982,6 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): operationHandle=self.operation_handle, ) - # Mock the operation status response - op_state = ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ) - tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = hive_schema_req tcli_service_instance.GetOperationStatus.return_value = ( ttypes.TGetOperationStatusResp( @@ -1694,9 +1673,7 @@ def test_handle_execute_response_sets_active_op_handle(self): thrift_backend = self._make_fake_thrift_backend() thrift_backend._check_direct_results_for_error = Mock() thrift_backend._wait_until_command_done = Mock() - thrift_backend._results_message_to_execute_response = Mock( - return_value=(Mock(), Mock()) - ) + thrift_backend._results_message_to_execute_response = Mock() # Create a mock response with a real operation handle mock_resp = Mock() @@ -2256,8 +2233,7 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) @patch( - "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response", - return_value=(Mock(), Mock()), + "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response" ) def test_execute_command_sets_complex_type_fields_correctly( self, mock_handle_execute_response, tcli_service_class, mock_result_set From a6e40d0dce9acd43c29e2de76f7d64ce96f775a8 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 13:25:51 +0000 Subject: [PATCH 30/33] simplify test module Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 41 +++++++++------------ 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index b03f8ff64..3a8b163f5 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -1,20 +1,18 @@ """ Main script to run all SEA connector tests. -This script imports and runs all the individual test modules and displays +This script runs all the individual test modules and displays a summary of test results with visual indicators. 
""" import os import sys import logging -import importlib.util -from typing import Dict, Callable, List, Tuple +import subprocess +from typing import List, Tuple -# Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -# Define test modules and their main test functions TEST_MODULES = [ "test_sea_session", "test_sea_sync_query", @@ -23,29 +21,27 @@ ] -def load_test_function(module_name: str) -> Callable: - """Load a test function from a module.""" +def run_test_module(module_name: str) -> bool: + """Run a test module and return success status.""" module_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), "tests", f"{module_name}.py" ) - spec = importlib.util.spec_from_file_location(module_name, module_path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) + # Simply run the module as a script - each module handles its own test execution + result = subprocess.run( + [sys.executable, module_path], capture_output=True, text=True + ) - # Get the main test function (assuming it starts with "test_") - for name in dir(module): - if name.startswith("test_") and callable(getattr(module, name)): - # For sync and async query modules, we want the main function that runs both tests - if name == f"test_sea_{module_name.replace('test_sea_', '')}_exec": - return getattr(module, name) + # Log the output from the test module + if result.stdout: + for line in result.stdout.strip().split("\n"): + logger.info(line) - # Fallback to the first test function found - for name in dir(module): - if name.startswith("test_") and callable(getattr(module, name)): - return getattr(module, name) + if result.stderr: + for line in result.stderr.strip().split("\n"): + logger.error(line) - raise ValueError(f"No test function found in module {module_name}") + return result.returncode == 0 def run_tests() -> List[Tuple[str, bool]]: @@ -54,12 +50,11 @@ def run_tests() -> List[Tuple[str, bool]]: for module_name in TEST_MODULES: try: - test_func = load_test_function(module_name) logger.info(f"\n{'=' * 50}") logger.info(f"Running test: {module_name}") logger.info(f"{'-' * 50}") - success = test_func() + success = run_test_module(module_name) results.append((module_name, success)) status = "✅ PASSED" if success else "❌ FAILED" From 52e3088b31d659064e740388bd2f25df1c3b158f Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 13:26:23 +0000 Subject: [PATCH 31/33] logging -> debug level Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index 3a8b163f5..edd171b05 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -10,7 +10,7 @@ import subprocess from typing import List, Tuple -logging.basicConfig(level=logging.INFO) +logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) TEST_MODULES = [ From 641c09b0d2a5fb5c79b3b696f767f81d0b5283e4 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 13:28:18 +0000 Subject: [PATCH 32/33] change table name in log Signed-off-by: varun-edachali-dbx --- examples/experimental/tests/test_sea_metadata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/experimental/tests/test_sea_metadata.py b/examples/experimental/tests/test_sea_metadata.py index 394c48b24..a200d97d3 100644 --- 
a/examples/experimental/tests/test_sea_metadata.py +++ b/examples/experimental/tests/test_sea_metadata.py @@ -71,7 +71,7 @@ def test_sea_metadata(): # Test columns for a specific table # Using a common table that should exist in most environments logger.info( - f"Fetching columns for catalog '{catalog}', schema 'default', table 'information_schema'..." + f"Fetching columns for catalog '{catalog}', schema 'default', table 'customer'..." ) cursor.columns( catalog_name=catalog, schema_name="default", table_name="customer" From aa7b542ed6e9f719fbd43391648e0ee38294b884 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 27 Jun 2025 05:19:41 +0000 Subject: [PATCH 33/33] add basic documentation on env vars to be set Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index edd171b05..712f033c6 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -3,7 +3,13 @@ This script runs all the individual test modules and displays a summary of test results with visual indicators. + +In order to run the script, the following environment variables need to be set: +- DATABRICKS_SERVER_HOSTNAME: The hostname of the Databricks server +- DATABRICKS_HTTP_PATH: The HTTP path of the Databricks server +- DATABRICKS_TOKEN: The token to use for authentication """ + import os import sys import logging
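
After this series, the runner in `sea_connector_test.py` is a subprocess-based driver: each test module is executed with the current interpreter and judged by its exit code, and the runner itself exits non-zero if any module fails. A minimal sketch of how a caller might drive the suite end to end is below; the hostname, HTTP path, token, and catalog values are placeholders (not values from these patches), while the variable names, the script path, and the exit-code contract come from the patches above.

```python
import os
import subprocess
import sys

# Placeholder credentials -- substitute values for a real workspace.
os.environ.setdefault("DATABRICKS_SERVER_HOSTNAME", "my-workspace.cloud.databricks.com")
os.environ.setdefault("DATABRICKS_HTTP_PATH", "/sql/1.0/warehouses/abc123")
os.environ.setdefault("DATABRICKS_TOKEN", "dapi-placeholder-token")
# Only test_sea_metadata requires a catalog; it fails fast without one.
os.environ.setdefault("DATABRICKS_CATALOG", "samples")

# The runner prints a per-module summary and exits with
# `sys.exit(0 if all_passed else 1)`, so its return code can be
# propagated directly.
result = subprocess.run(
    [sys.executable, "examples/experimental/sea_connector_test.py"],
    text=True,
)
sys.exit(result.returncode)
```

Running each module as its own subprocess (rather than importing it, as the earlier `importlib`-based loader did) keeps the modules isolated: a crash or `sys.exit` in one test cannot take down the runner or leak state into the next module.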