From f3e4a97047725060ba45420595aa9600e4e06100 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 3 Jun 2025 08:25:53 +0000 Subject: [PATCH 01/24] [squashed from prev branch] introduce sea client with session open and close functionality Signed-off-by: varun-edachali-dbx --- .github/workflows/code-quality-checks.yml | 14 +- .github/workflows/integration.yml | 10 +- examples/experimental/sea_connector_test.py | 65 ++++ src/databricks/sql/backend/sea_backend.py | 301 ++++++++++++++++++ src/databricks/sql/backend/thrift_backend.py | 7 +- .../sql/backend/utils/http_client.py | 172 ++++++++++ src/databricks/sql/session.py | 36 ++- tests/unit/test_parameters.py | 8 +- tests/unit/test_sea_backend.py | 168 ++++++++++ 9 files changed, 763 insertions(+), 18 deletions(-) create mode 100644 examples/experimental/sea_connector_test.py create mode 100644 src/databricks/sql/backend/sea_backend.py create mode 100644 src/databricks/sql/backend/utils/http_client.py create mode 100644 tests/unit/test_sea_backend.py diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index b6db61a3c..462d22369 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -1,7 +1,15 @@ name: Code Quality Checks - -on: [pull_request] - +on: + push: + branches: + - main + - sea-migration + - telemetry + pull_request: + branches: + - main + - sea-migration + - telemetry jobs: run-unit-tests: runs-on: ubuntu-latest diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 127c8ff4f..ccd3a580d 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -1,10 +1,14 @@ name: Integration Tests - on: - push: + push: + paths-ignore: + - "**.MD" + - "**.md" + pull_request: branches: - main - pull_request: + - sea-migration + - telemetry jobs: run-e2e-tests: diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py new file mode 100644 index 000000000..dcfcd475f --- /dev/null +++ b/examples/experimental/sea_connector_test.py @@ -0,0 +1,65 @@ +import os +import sys +import logging +from databricks.sql.client import Connection + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + +def test_sea_session(): + """ + Test opening and closing a SEA session using the connector. + + This function connects to a Databricks SQL endpoint using the SEA backend, + opens a session, and then closes it. 
+ + Required environment variables: + - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname + - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint + - DATABRICKS_TOKEN: Personal access token for authentication + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.") + sys.exit(1) + + logger.info(f"Connecting to {server_hostname}") + logger.info(f"HTTP Path: {http_path}") + if catalog: + logger.info(f"Using catalog: {catalog}") + + try: + logger.info("Creating connection with SEA backend...") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client" # add custom user agent + ) + + logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}") + logger.info(f"backend type: {type(connection.session.backend)}") + + # Close the connection + logger.info("Closing the SEA session...") + connection.close() + logger.info("Successfully closed SEA session") + + except Exception as e: + logger.error(f"Error testing SEA session: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + sys.exit(1) + + logger.info("SEA session test completed successfully") + +if __name__ == "__main__": + test_sea_session() \ No newline at end of file diff --git a/src/databricks/sql/backend/sea_backend.py b/src/databricks/sql/backend/sea_backend.py new file mode 100644 index 000000000..e5dc721ac --- /dev/null +++ b/src/databricks/sql/backend/sea_backend.py @@ -0,0 +1,301 @@ +import logging +import uuid +from typing import Dict, Tuple, List, Optional, Any, Union, TYPE_CHECKING + +if TYPE_CHECKING: + from databricks.sql.client import Cursor + +from databricks.sql.backend.databricks_client import DatabricksClient +from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType +from databricks.sql.exc import Error, NotSupportedError +from databricks.sql.backend.utils.http_client import CustomHttpClient +from databricks.sql.thrift_api.TCLIService import ttypes +from databricks.sql.types import SSLOptions + +logger = logging.getLogger(__name__) + + +class SeaDatabricksClient(DatabricksClient): + """ + Statement Execution API (SEA) implementation of the DatabricksClient interface. + + This implementation provides session management functionality for SEA, + while other operations raise NotImplementedError. + """ + + # SEA API paths + BASE_PATH = "/api/2.0/sql/" + SESSION_PATH = BASE_PATH + "sessions" + SESSION_PATH_WITH_ID = SESSION_PATH + "/{}" + STATEMENT_PATH = BASE_PATH + "statements" + STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}" + CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel" + + def __init__( + self, + server_hostname: str, + port: int, + http_path: str, + http_headers: List[Tuple[str, str]], + auth_provider, + ssl_options: SSLOptions, + staging_allowed_local_path: Union[None, str, List[str]] = None, + **kwargs, + ): + """ + Initialize the SEA backend client. 
+ + Args: + server_hostname: Hostname of the Databricks server + port: Port number for the connection + http_path: HTTP path for the connection + http_headers: List of HTTP headers to include in requests + auth_provider: Authentication provider + ssl_options: SSL configuration options + staging_allowed_local_path: Allowed local paths for staging operations + **kwargs: Additional keyword arguments + """ + logger.debug( + "SEADatabricksClient.__init__(server_hostname=%s, port=%s, http_path=%s)", + server_hostname, + port, + http_path, + ) + + self._staging_allowed_local_path = staging_allowed_local_path + self._ssl_options = ssl_options + self._max_download_threads = kwargs.get("max_download_threads", 10) + + # Extract warehouse ID from http_path + self.warehouse_id = self._extract_warehouse_id(http_path) + + # Initialize HTTP client + self.http_client = CustomHttpClient( + server_hostname=server_hostname, + port=port, + http_path=http_path, + http_headers=http_headers, + auth_provider=auth_provider, + ssl_options=ssl_options, + **kwargs, + ) + + def _extract_warehouse_id(self, http_path: str) -> str: + """ + Extract the warehouse ID from the HTTP path. + + The warehouse ID is expected to be the last segment of the path when the + second-to-last segment is either 'warehouses' or 'endpoints'. + This matches the JDBC implementation which supports both formats. + + Args: + http_path: The HTTP path from which to extract the warehouse ID + + Returns: + The extracted warehouse ID + + Raises: + Error: If the warehouse ID cannot be extracted from the path + """ + path_parts = http_path.strip("/").split("/") + warehouse_id = None + + if len(path_parts) >= 3 and path_parts[-2] in ["warehouses", "endpoints"]: + warehouse_id = path_parts[-1] + logger.debug( + f"Extracted warehouse ID: {warehouse_id} from path: {http_path}" + ) + + if not warehouse_id: + error_message = ( + f"Could not extract warehouse ID from http_path: {http_path}. " + f"Expected format: /path/to/warehouses/{{warehouse_id}} or " + f"/path/to/endpoints/{{warehouse_id}}" + ) + logger.error(error_message) + raise ValueError(error_message) + + return warehouse_id + + @property + def staging_allowed_local_path(self) -> Union[None, str, List[str]]: + """Get the allowed local paths for staging operations.""" + return self._staging_allowed_local_path + + @property + def ssl_options(self) -> SSLOptions: + """Get the SSL options for this client.""" + return self._ssl_options + + @property + def max_download_threads(self) -> int: + """Get the maximum number of download threads for cloud fetch operations.""" + return self._max_download_threads + + def open_session( + self, + session_configuration: Optional[Dict[str, str]], + catalog: Optional[str], + schema: Optional[str], + ) -> SessionId: + """ + Opens a new session with the Databricks SQL service using SEA. 
+ + Args: + session_configuration: Optional dictionary of configuration parameters for the session + catalog: Optional catalog name to use as the initial catalog for the session + schema: Optional schema name to use as the initial schema for the session + + Returns: + SessionId: A session identifier object that can be used for subsequent operations + + Raises: + Error: If the session configuration is invalid + OperationalError: If there's an error establishing the session + """ + logger.debug( + "SEADatabricksClient.open_session(session_configuration=%s, catalog=%s, schema=%s)", + session_configuration, + catalog, + schema, + ) + + request_data: Dict[str, Any] = {"warehouse_id": self.warehouse_id} + if session_configuration: + request_data["session_confs"] = session_configuration + if catalog: + request_data["catalog"] = catalog + if schema: + request_data["schema"] = schema + + response = self.http_client._make_request( + method="POST", path=self.SESSION_PATH, data=request_data + ) + + session_id = response.get("session_id") + if not session_id: + raise Error("Failed to create session: No session ID returned") + + return SessionId.from_sea_session_id(session_id) + + def close_session(self, session_id: SessionId) -> None: + """ + Closes an existing session with the Databricks SQL service. + + Args: + session_id: The session identifier returned by open_session() + + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error closing the session + """ + logger.debug("SEADatabricksClient.close_session(session_id=%s)", session_id) + + if session_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA session ID") + sea_session_id = session_id.to_sea_session_id() + + request_data = {"warehouse_id": self.warehouse_id} + + self.http_client._make_request( + method="DELETE", + path=self.SESSION_PATH_WITH_ID.format(sea_session_id), + data=request_data, + ) + + # == Not Implemented Operations == + # These methods will be implemented in future iterations + + def execute_command( + self, + operation: str, + session_id: SessionId, + max_rows: int, + max_bytes: int, + lz4_compression: bool, + cursor: "Cursor", + use_cloud_fetch: bool, + parameters: List[ttypes.TSparkParameter], + async_op: bool, + enforce_embedded_schema_correctness: bool, + ): + """Not implemented yet.""" + raise NotSupportedError( + "execute_command is not yet implemented for SEA backend" + ) + + def cancel_command(self, command_id: CommandId) -> None: + """Not implemented yet.""" + raise NotSupportedError("cancel_command is not yet implemented for SEA backend") + + def close_command(self, command_id: CommandId) -> None: + """Not implemented yet.""" + raise NotSupportedError("close_command is not yet implemented for SEA backend") + + def get_query_state(self, command_id: CommandId) -> CommandState: + """Not implemented yet.""" + raise NotSupportedError( + "get_query_state is not yet implemented for SEA backend" + ) + + def get_execution_result( + self, + command_id: CommandId, + cursor: "Cursor", + ): + """Not implemented yet.""" + raise NotSupportedError( + "get_execution_result is not yet implemented for SEA backend" + ) + + # == Metadata Operations == + + def get_catalogs( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + ): + """Not implemented yet.""" + raise NotSupportedError("get_catalogs is not yet implemented for SEA backend") + + def get_schemas( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + 
catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + ): + """Not implemented yet.""" + raise NotSupportedError("get_schemas is not yet implemented for SEA backend") + + def get_tables( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + table_name: Optional[str] = None, + table_types: Optional[List[str]] = None, + ): + """Not implemented yet.""" + raise NotSupportedError("get_tables is not yet implemented for SEA backend") + + def get_columns( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + table_name: Optional[str] = None, + column_name: Optional[str] = None, + ): + """Not implemented yet.""" + raise NotSupportedError("get_columns is not yet implemented for SEA backend") diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index de388f1d4..265e60d85 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -53,6 +53,7 @@ ) from databricks.sql.types import SSLOptions from databricks.sql.backend.databricks_client import DatabricksClient +from databricks.sql.result_set import ResultSet, ThriftResultSet logger = logging.getLogger(__name__) @@ -1165,7 +1166,11 @@ def _handle_execute_response(self, resp, cursor): resp.directResults and resp.directResults.operationStatus, ) - return self._results_message_to_execute_response(resp, final_operation_state) + execute_response = self._results_message_to_execute_response( + resp, final_operation_state + ) + execute_response = execute_response._replace(command_id=command_id) + return execute_response def _handle_execute_response_async(self, resp, cursor): command_id = CommandId.from_thrift_handle(resp.operationHandle) diff --git a/src/databricks/sql/backend/utils/http_client.py b/src/databricks/sql/backend/utils/http_client.py new file mode 100644 index 000000000..8cc229850 --- /dev/null +++ b/src/databricks/sql/backend/utils/http_client.py @@ -0,0 +1,172 @@ +import json +import logging +import requests +from typing import Dict, Any, Optional, Union, List +from urllib.parse import urljoin + +from databricks.sql.auth.authenticators import AuthProvider +from databricks.sql.types import SSLOptions + +logger = logging.getLogger(__name__) + + +class CustomHttpClient: + """ + HTTP client for Statement Execution API (SEA). + + This client handles the HTTP communication with the SEA endpoints, + including authentication, request formatting, and response parsing. + """ + + def __init__( + self, + server_hostname: str, + port: int, + http_path: str, + http_headers: List[tuple], + auth_provider: AuthProvider, + ssl_options: SSLOptions, + **kwargs, + ): + """ + Initialize the SEA HTTP client. 
+ + Args: + server_hostname: Hostname of the Databricks server + port: Port number for the connection + http_path: HTTP path for the connection + http_headers: List of HTTP headers to include in requests + auth_provider: Authentication provider + ssl_options: SSL configuration options + **kwargs: Additional keyword arguments + """ + self.server_hostname = server_hostname + self.port = port + self.http_path = http_path + self.auth_provider = auth_provider + self.ssl_options = ssl_options + + self.base_url = f"https://{server_hostname}:{port}" + + self.headers = dict(http_headers) + self.headers.update({"Content-Type": "application/json"}) + + self.max_retries = kwargs.get("_retry_stop_after_attempts_count", 30) + + # Create a session for connection pooling + self.session = requests.Session() + + # Configure SSL verification + if ssl_options.tls_verify: + self.session.verify = ssl_options.tls_trusted_ca_file or True + else: + self.session.verify = False + + # Configure client certificates if provided + if ssl_options.tls_client_cert_file: + client_cert = ssl_options.tls_client_cert_file + client_key = ssl_options.tls_client_cert_key_file + client_key_password = ssl_options.tls_client_cert_key_password + + if client_key: + self.session.cert = (client_cert, client_key) + else: + self.session.cert = client_cert + + if client_key_password: + # Note: requests doesn't directly support key passwords + # This would require more complex handling with libraries like pyOpenSSL + logger.warning( + "Client key password provided but not supported by requests library" + ) + + def _get_auth_headers(self) -> Dict[str, str]: + """Get authentication headers from the auth provider.""" + headers: Dict[str, str] = {} + self.auth_provider.add_headers(headers) + return headers + + def _make_request( + self, method: str, path: str, data: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + """ + Make an HTTP request to the SEA endpoint. + + Args: + method: HTTP method (GET, POST, DELETE) + path: API endpoint path + data: Request payload data + + Returns: + Dict[str, Any]: Response data parsed from JSON + + Raises: + RequestError: If the request fails + """ + url = urljoin(self.base_url, path) + headers = {**self.headers, **self._get_auth_headers()} + + logger.debug(f"making {method} request to {url}") + + try: + if method.upper() == "GET": + response = self.session.get(url, headers=headers, params=data) + elif method.upper() == "POST": + response = self.session.post(url, headers=headers, json=data) + elif method.upper() == "DELETE": + # For DELETE requests, use params for data (query parameters) + response = self.session.delete(url, headers=headers, params=data) + else: + raise ValueError(f"Unsupported HTTP method: {method}") + + # Check for HTTP errors + response.raise_for_status() + + # Log response details + logger.debug(f"Response status: {response.status_code}") + + # Parse JSON response + if response.content: + result = response.json() + # Log response content (but limit it for large responses) + content_str = json.dumps(result) + if len(content_str) > 1000: + logger.debug( + f"Response content (truncated): {content_str[:1000]}..." 
+ ) + else: + logger.debug(f"Response content: {content_str}") + return result + return {} + + except requests.exceptions.RequestException as e: + # Handle request errors + error_message = f"SEA HTTP request failed: {str(e)}" + logger.error(error_message) + + # Extract error details from response if available + if hasattr(e, "response") and e.response is not None: + try: + error_details = e.response.json() + error_message = ( + f"{error_message}: {error_details.get('message', '')}" + ) + logger.error( + f"Response status: {e.response.status_code}, Error details: {error_details}" + ) + except (ValueError, KeyError): + # If we can't parse the JSON, just log the raw content + content_str = ( + e.response.content.decode("utf-8", errors="replace") + if isinstance(e.response.content, bytes) + else str(e.response.content) + ) + logger.error( + f"Response status: {e.response.status_code}, Raw content: {content_str}" + ) + pass + + # Re-raise as a RequestError + from databricks.sql.exc import RequestError + + raise RequestError(error_message, e) diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 6d69b5487..98883310f 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -8,6 +8,7 @@ from databricks.sql import __version__ from databricks.sql import USER_AGENT_NAME from databricks.sql.backend.thrift_backend import ThriftDatabricksClient +from databricks.sql.backend.sea_backend import SeaDatabricksClient from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, BackendType @@ -74,16 +75,31 @@ def __init__( tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), ) - self.backend: DatabricksClient = ThriftDatabricksClient( - self.host, - self.port, - http_path, - (http_headers or []) + base_headers, - auth_provider, - ssl_options=self._ssl_options, - _use_arrow_native_complex_types=_use_arrow_native_complex_types, - **kwargs, - ) + # Determine which backend to use + use_sea = kwargs.get("use_sea", False) + + if use_sea: + self.backend: DatabricksClient = SeaDatabricksClient( + self.host, + self.port, + http_path, + (http_headers or []) + base_headers, + auth_provider, + ssl_options=self._ssl_options, + _use_arrow_native_complex_types=_use_arrow_native_complex_types, + **kwargs, + ) + else: + self.backend = ThriftDatabricksClient( + self.host, + self.port, + http_path, + (http_headers or []) + base_headers, + auth_provider, + ssl_options=self._ssl_options, + _use_arrow_native_complex_types=_use_arrow_native_complex_types, + **kwargs, + ) self.protocol_version = None diff --git a/tests/unit/test_parameters.py b/tests/unit/test_parameters.py index 37e6cf1c9..949230d1e 100644 --- a/tests/unit/test_parameters.py +++ b/tests/unit/test_parameters.py @@ -64,7 +64,13 @@ class TestSessionHandleChecks(object): ], ) def test_get_protocol_version_fallback_behavior(self, test_input, expected): - assert Connection.get_protocol_version(test_input) == expected + properties = ( + {"serverProtocolVersion": test_input.serverProtocolVersion} + if test_input.serverProtocolVersion + else {} + ) + session_id = SessionId.from_thrift_handle(test_input.sessionHandle, properties) + assert Connection.get_protocol_version(session_id) == expected @pytest.mark.parametrize( "test_input,expected", diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py new file mode 100644 index 000000000..72009d6cf --- /dev/null +++ b/tests/unit/test_sea_backend.py @@ -0,0 +1,168 @@ +import 
pytest +from unittest.mock import patch, MagicMock, Mock + +from databricks.sql.backend.sea_backend import SeaDatabricksClient +from databricks.sql.backend.types import SessionId, BackendType +from databricks.sql.types import SSLOptions +from databricks.sql.auth.authenticators import AuthProvider +from databricks.sql.exc import Error, NotSupportedError + + +class TestSeaBackend: + """Test suite for the SeaDatabricksClient class.""" + + @pytest.fixture + def mock_http_client(self): + """Create a mock HTTP client.""" + with patch( + "databricks.sql.backend.sea_backend.CustomHttpClient" + ) as mock_client_class: + mock_client = mock_client_class.return_value + yield mock_client + + @pytest.fixture + def sea_client(self, mock_http_client): + """Create a SeaDatabricksClient instance with mocked dependencies.""" + server_hostname = "test-server.databricks.com" + port = 443 + http_path = "/sql/warehouses/abc123" + http_headers = [("header1", "value1"), ("header2", "value2")] + auth_provider = AuthProvider() + ssl_options = SSLOptions() + + client = SeaDatabricksClient( + server_hostname=server_hostname, + port=port, + http_path=http_path, + http_headers=http_headers, + auth_provider=auth_provider, + ssl_options=ssl_options, + ) + + return client + + def test_init_extracts_warehouse_id(self, mock_http_client): + """Test that the constructor properly extracts the warehouse ID from the HTTP path.""" + # Test with warehouses format + client1 = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/warehouses/abc123", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) + assert client1.warehouse_id == "abc123" + + # Test with endpoints format + client2 = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/endpoints/def456", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) + assert client2.warehouse_id == "def456" + + def test_init_raises_error_for_invalid_http_path(self, mock_http_client): + """Test that the constructor raises an error for invalid HTTP paths.""" + with pytest.raises(ValueError) as excinfo: + SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/invalid/path", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) + assert "Could not extract warehouse ID" in str(excinfo.value) + + def test_open_session_basic(self, sea_client, mock_http_client): + """Test the open_session method with minimal parameters.""" + # Set up mock response + mock_http_client._make_request.return_value = {"session_id": "test-session-123"} + + # Call the method + session_id = sea_client.open_session(None, None, None) + + # Verify the result + assert isinstance(session_id, SessionId) + assert session_id.backend_type == BackendType.SEA + assert session_id.guid == "test-session-123" + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once_with( + method="POST", path=sea_client.SESSION_PATH, data={"warehouse_id": "abc123"} + ) + + def test_open_session_with_all_parameters(self, sea_client, mock_http_client): + """Test the open_session method with all parameters.""" + # Set up mock response + mock_http_client._make_request.return_value = {"session_id": "test-session-456"} + + # Call the method with all parameters + session_config = {"spark.sql.shuffle.partitions": "10"} + catalog = "test_catalog" + schema = "test_schema" + + session_id = 
sea_client.open_session(session_config, catalog, schema) + + # Verify the result + assert isinstance(session_id, SessionId) + assert session_id.backend_type == BackendType.SEA + assert session_id.guid == "test-session-456" + + # Verify the HTTP request + expected_data = { + "warehouse_id": "abc123", + "session_confs": session_config, + "catalog": catalog, + "schema": schema, + } + mock_http_client._make_request.assert_called_once_with( + method="POST", path=sea_client.SESSION_PATH, data=expected_data + ) + + def test_open_session_error_handling(self, sea_client, mock_http_client): + """Test error handling in the open_session method.""" + # Set up mock response without session_id + mock_http_client._make_request.return_value = {} + + # Call the method and expect an error + with pytest.raises(Error) as excinfo: + sea_client.open_session(None, None, None) + + assert "Failed to create session" in str(excinfo.value) + + def test_close_session_valid_id(self, sea_client, mock_http_client): + """Test closing a session with a valid session ID.""" + # Create a valid SEA session ID + session_id = SessionId.from_sea_session_id("test-session-789") + + # Set up mock response + mock_http_client._make_request.return_value = {} + + # Call the method + sea_client.close_session(session_id) + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once_with( + method="DELETE", + path=sea_client.SESSION_PATH_WITH_ID.format("test-session-789"), + data={"warehouse_id": "abc123"}, + ) + + def test_close_session_invalid_id_type(self, sea_client): + """Test closing a session with an invalid session ID type.""" + # Create a Thrift session ID (not SEA) + mock_thrift_handle = MagicMock() + mock_thrift_handle.sessionId.guid = b"guid" + mock_thrift_handle.sessionId.secret = b"secret" + session_id = SessionId.from_thrift_handle(mock_thrift_handle) + + # Call the method and expect an error + with pytest.raises(ValueError) as excinfo: + sea_client.close_session(session_id) + + assert "Not a valid SEA session ID" in str(excinfo.value) From 3df57529056b0882728e2c99ba0f091fc6e65184 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 3 Jun 2025 08:29:55 +0000 Subject: [PATCH 02/24] remove accidental changes to workflows (merge artifacts) Signed-off-by: varun-edachali-dbx --- .github/workflows/code-quality-checks.yml | 14 +++----------- .github/workflows/integration.yml | 10 +++------- 2 files changed, 6 insertions(+), 18 deletions(-) diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index 462d22369..b6db61a3c 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -1,15 +1,7 @@ name: Code Quality Checks -on: - push: - branches: - - main - - sea-migration - - telemetry - pull_request: - branches: - - main - - sea-migration - - telemetry + +on: [pull_request] + jobs: run-unit-tests: runs-on: ubuntu-latest diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index ccd3a580d..127c8ff4f 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -1,14 +1,10 @@ name: Integration Tests + on: - push: - paths-ignore: - - "**.MD" - - "**.md" - pull_request: + push: branches: - main - - sea-migration - - telemetry + pull_request: jobs: run-e2e-tests: From 9146a94be8d0d6868fcc1abdb76df3a895084b43 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 3 Jun 2025 08:33:47 +0000 Subject: [PATCH 03/24] pass test_input to get_protocol_version instead of 
session_id to maintain previous API Signed-off-by: varun-edachali-dbx --- tests/unit/test_parameters.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/tests/unit/test_parameters.py b/tests/unit/test_parameters.py index 949230d1e..37e6cf1c9 100644 --- a/tests/unit/test_parameters.py +++ b/tests/unit/test_parameters.py @@ -64,13 +64,7 @@ class TestSessionHandleChecks(object): ], ) def test_get_protocol_version_fallback_behavior(self, test_input, expected): - properties = ( - {"serverProtocolVersion": test_input.serverProtocolVersion} - if test_input.serverProtocolVersion - else {} - ) - session_id = SessionId.from_thrift_handle(test_input.sessionHandle, properties) - assert Connection.get_protocol_version(session_id) == expected + assert Connection.get_protocol_version(test_input) == expected @pytest.mark.parametrize( "test_input,expected", From 9b39e37dc074dcb3adfd94d9b3cdc6118c304476 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 3 Jun 2025 08:51:38 +0000 Subject: [PATCH 04/24] formatting (black + line gaps after multi-line pydocs) Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 1 + src/databricks/sql/backend/sea_backend.py | 10 +++++++--- src/databricks/sql/backend/thrift_backend.py | 7 +------ src/databricks/sql/backend/utils/http_client.py | 2 ++ 4 files changed, 11 insertions(+), 9 deletions(-) diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index dcfcd475f..a27099da7 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -18,6 +18,7 @@ def test_sea_session(): - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint - DATABRICKS_TOKEN: Personal access token for authentication """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") http_path = os.environ.get("DATABRICKS_HTTP_PATH") access_token = os.environ.get("DATABRICKS_TOKEN") diff --git a/src/databricks/sql/backend/sea_backend.py b/src/databricks/sql/backend/sea_backend.py index e5dc721ac..10e9acc00 100644 --- a/src/databricks/sql/backend/sea_backend.py +++ b/src/databricks/sql/backend/sea_backend.py @@ -55,8 +55,9 @@ def __init__( staging_allowed_local_path: Allowed local paths for staging operations **kwargs: Additional keyword arguments """ + logger.debug( - "SEADatabricksClient.__init__(server_hostname=%s, port=%s, http_path=%s)", + "SeaDatabricksClient.__init__(server_hostname=%s, port=%s, http_path=%s)", server_hostname, port, http_path, @@ -97,6 +98,7 @@ def _extract_warehouse_id(self, http_path: str) -> str: Raises: Error: If the warehouse ID cannot be extracted from the path """ + path_parts = http_path.strip("/").split("/") warehouse_id = None @@ -153,8 +155,9 @@ def open_session( Error: If the session configuration is invalid OperationalError: If there's an error establishing the session """ + logger.debug( - "SEADatabricksClient.open_session(session_configuration=%s, catalog=%s, schema=%s)", + "SeaDatabricksClient.open_session(session_configuration=%s, catalog=%s, schema=%s)", session_configuration, catalog, schema, @@ -189,7 +192,8 @@ def close_session(self, session_id: SessionId) -> None: ValueError: If the session ID is invalid OperationalError: If there's an error closing the session """ - logger.debug("SEADatabricksClient.close_session(session_id=%s)", session_id) + + logger.debug("SeaDatabricksClient.close_session(session_id=%s)", session_id) if session_id.backend_type != BackendType.SEA: raise ValueError("Not a valid 
SEA session ID") diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 265e60d85..de388f1d4 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -53,7 +53,6 @@ ) from databricks.sql.types import SSLOptions from databricks.sql.backend.databricks_client import DatabricksClient -from databricks.sql.result_set import ResultSet, ThriftResultSet logger = logging.getLogger(__name__) @@ -1166,11 +1165,7 @@ def _handle_execute_response(self, resp, cursor): resp.directResults and resp.directResults.operationStatus, ) - execute_response = self._results_message_to_execute_response( - resp, final_operation_state - ) - execute_response = execute_response._replace(command_id=command_id) - return execute_response + return self._results_message_to_execute_response(resp, final_operation_state) def _handle_execute_response_async(self, resp, cursor): command_id = CommandId.from_thrift_handle(resp.operationHandle) diff --git a/src/databricks/sql/backend/utils/http_client.py b/src/databricks/sql/backend/utils/http_client.py index 8cc229850..82980792c 100644 --- a/src/databricks/sql/backend/utils/http_client.py +++ b/src/databricks/sql/backend/utils/http_client.py @@ -40,6 +40,7 @@ def __init__( ssl_options: SSL configuration options **kwargs: Additional keyword arguments """ + self.server_hostname = server_hostname self.port = port self.http_path = http_path @@ -103,6 +104,7 @@ def _make_request( Raises: RequestError: If the request fails """ + url = urljoin(self.base_url, path) headers = {**self.headers, **self._get_auth_headers()} From 1ccbcd2ad55d3ca864cfcf10179dc759abcd20ab Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 3 Jun 2025 09:20:56 +0000 Subject: [PATCH 05/24] use factory for backend instantiation Signed-off-by: varun-edachali-dbx --- src/databricks/sql/session.py | 65 ++++++++++++++++++++++------------- tests/unit/test_session.py | 16 +++++---- 2 files changed, 50 insertions(+), 31 deletions(-) diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 98883310f..58e8919c9 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, Tuple, List, Optional, Any +from typing import Dict, Tuple, List, Optional, Any, Type from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions @@ -62,6 +62,7 @@ def __init__( useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__) base_headers = [("User-Agent", useragent_header)] + all_headers = (http_headers or []) + base_headers self._ssl_options = SSLOptions( # Double negation is generally a bad thing, but we have to keep backward compatibility @@ -75,33 +76,49 @@ def __init__( tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), ) - # Determine which backend to use + self.backend = self._create_backend( + server_hostname, + http_path, + all_headers, + auth_provider, + _use_arrow_native_complex_types, + kwargs, + ) + + self.protocol_version = None + + def _create_backend( + self, + server_hostname: str, + http_path: str, + all_headers: List[Tuple[str, str]], + auth_provider, + _use_arrow_native_complex_types: bool, + kwargs: dict, + ) -> DatabricksClient: + """Create and return the appropriate backend client.""" use_sea = kwargs.get("use_sea", False) if use_sea: - self.backend: DatabricksClient = SeaDatabricksClient( - self.host, - self.port, - http_path, - (http_headers 
or []) + base_headers, - auth_provider, - ssl_options=self._ssl_options, - _use_arrow_native_complex_types=_use_arrow_native_complex_types, - **kwargs, - ) + logger.debug("Creating SEA backend client") + databricks_client_class = SeaDatabricksClient else: - self.backend = ThriftDatabricksClient( - self.host, - self.port, - http_path, - (http_headers or []) + base_headers, - auth_provider, - ssl_options=self._ssl_options, - _use_arrow_native_complex_types=_use_arrow_native_complex_types, - **kwargs, - ) - - self.protocol_version = None + logger.debug("Creating Thrift backend client") + databricks_client_class = ThriftDatabricksClient + + # Prepare common arguments + common_args = { + "server_hostname": server_hostname, + "port": self.port, + "http_path": http_path, + "http_headers": all_headers, + "auth_provider": auth_provider, + "ssl_options": self._ssl_options, + "_use_arrow_native_complex_types": _use_arrow_native_complex_types, + **kwargs, + } + + return databricks_client_class(**common_args) def open(self): self._session_id = self.backend.open_session( diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 858119f92..7db4b1338 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -62,9 +62,9 @@ def test_auth_args(self, mock_client_class): for args in connection_args: connection = databricks.sql.connect(**args) - host, port, http_path, *_ = mock_client_class.call_args[0] - self.assertEqual(args["server_hostname"], host) - self.assertEqual(args["http_path"], http_path) + call_kwargs = mock_client_class.call_args[1] + self.assertEqual(args["server_hostname"], call_kwargs["server_hostname"]) + self.assertEqual(args["http_path"], call_kwargs["http_path"]) connection.close() @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) @@ -72,8 +72,8 @@ def test_http_header_passthrough(self, mock_client_class): http_headers = [("foo", "bar")] databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, http_headers=http_headers) - call_args = mock_client_class.call_args[0][3] - self.assertIn(("foo", "bar"), call_args) + call_kwargs = mock_client_class.call_args[1] + self.assertIn(("foo", "bar"), call_kwargs["http_headers"]) @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_tls_arg_passthrough(self, mock_client_class): @@ -95,7 +95,8 @@ def test_tls_arg_passthrough(self, mock_client_class): def test_useragent_header(self, mock_client_class): databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - http_headers = mock_client_class.call_args[0][3] + call_kwargs = mock_client_class.call_args[1] + http_headers = call_kwargs["http_headers"] user_agent_header = ( "User-Agent", "{}/{}".format(databricks.sql.USER_AGENT_NAME, databricks.sql.__version__), @@ -109,7 +110,8 @@ def test_useragent_header(self, mock_client_class): databricks.sql.USER_AGENT_NAME, databricks.sql.__version__, "foobar" ), ) - http_headers = mock_client_class.call_args[0][3] + call_kwargs = mock_client_class.call_args[1] + http_headers = call_kwargs["http_headers"] self.assertIn(user_agent_header_with_entry, http_headers) @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) From 3528523dcc103456341052dca77835e82538808a Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 3 Jun 2025 09:37:11 +0000 Subject: [PATCH 06/24] fix type issues Signed-off-by: varun-edachali-dbx --- src/databricks/sql/session.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 58e8919c9..86bf24a7e 100644 --- 
a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -93,12 +93,13 @@ def _create_backend( http_path: str, all_headers: List[Tuple[str, str]], auth_provider, - _use_arrow_native_complex_types: bool, + _use_arrow_native_complex_types: Optional[bool], kwargs: dict, ) -> DatabricksClient: """Create and return the appropriate backend client.""" use_sea = kwargs.get("use_sea", False) + databricks_client_class: Type[DatabricksClient] if use_sea: logger.debug("Creating SEA backend client") databricks_client_class = SeaDatabricksClient From b39e83ba2b8671202e954d5c379c12d8913bd5c5 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 3 Jun 2025 09:39:17 +0000 Subject: [PATCH 07/24] remove redundant comments Signed-off-by: varun-edachali-dbx --- src/databricks/sql/session.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 86bf24a7e..6804002bb 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -107,7 +107,6 @@ def _create_backend( logger.debug("Creating Thrift backend client") databricks_client_class = ThriftDatabricksClient - # Prepare common arguments common_args = { "server_hostname": server_hostname, "port": self.port, @@ -118,7 +117,6 @@ def _create_backend( "_use_arrow_native_complex_types": _use_arrow_native_complex_types, **kwargs, } - return databricks_client_class(**common_args) def open(self): From ba361261dcd03f9f5d6a3ae9d17af86724de30dd Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 3 Jun 2025 16:08:09 +0000 Subject: [PATCH 08/24] introduce models for requests and responses Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/models/__init__.py | 22 ++++++++++ src/databricks/sql/backend/models/requests.py | 39 ++++++++++++++++++ .../sql/backend/models/responses.py | 14 +++++++ src/databricks/sql/backend/sea_backend.py | 41 +++++++++++++------ tests/unit/test_sea_backend.py | 2 +- 5 files changed, 104 insertions(+), 14 deletions(-) create mode 100644 src/databricks/sql/backend/models/__init__.py create mode 100644 src/databricks/sql/backend/models/requests.py create mode 100644 src/databricks/sql/backend/models/responses.py diff --git a/src/databricks/sql/backend/models/__init__.py b/src/databricks/sql/backend/models/__init__.py new file mode 100644 index 000000000..667235cce --- /dev/null +++ b/src/databricks/sql/backend/models/__init__.py @@ -0,0 +1,22 @@ +""" +Models for the SEA (Statement Execution API) backend. + +This package contains data models for SEA API requests and responses. 
+""" + +from databricks.sql.backend.models.requests import ( + CreateSessionRequest, + DeleteSessionRequest, +) + +from databricks.sql.backend.models.responses import ( + CreateSessionResponse, +) + +__all__ = [ + # Request models + "CreateSessionRequest", + "DeleteSessionRequest", + # Response models + "CreateSessionResponse", +] diff --git a/src/databricks/sql/backend/models/requests.py b/src/databricks/sql/backend/models/requests.py new file mode 100644 index 000000000..7966cb502 --- /dev/null +++ b/src/databricks/sql/backend/models/requests.py @@ -0,0 +1,39 @@ +from typing import Dict, Any, Optional +from dataclasses import dataclass + + +@dataclass +class CreateSessionRequest: + """Request to create a new session.""" + + warehouse_id: str + session_confs: Optional[Dict[str, str]] = None + catalog: Optional[str] = None + schema: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + result: Dict[str, Any] = {"warehouse_id": self.warehouse_id} + + if self.session_confs: + result["session_confs"] = self.session_confs + + if self.catalog: + result["catalog"] = self.catalog + + if self.schema: + result["schema"] = self.schema + + return result + + +@dataclass +class DeleteSessionRequest: + """Request to delete a session.""" + + warehouse_id: str + session_id: str + + def to_dict(self) -> Dict[str, str]: + """Convert the request to a dictionary for JSON serialization.""" + return {"warehouse_id": self.warehouse_id, "session_id": self.session_id} diff --git a/src/databricks/sql/backend/models/responses.py b/src/databricks/sql/backend/models/responses.py new file mode 100644 index 000000000..1bb54590f --- /dev/null +++ b/src/databricks/sql/backend/models/responses.py @@ -0,0 +1,14 @@ +from typing import Dict, Any +from dataclasses import dataclass + + +@dataclass +class CreateSessionResponse: + """Response from creating a new session.""" + + session_id: str + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "CreateSessionResponse": + """Create a CreateSessionResponse from a dictionary.""" + return cls(session_id=data.get("session_id", "")) diff --git a/src/databricks/sql/backend/sea_backend.py b/src/databricks/sql/backend/sea_backend.py index 10e9acc00..20288175a 100644 --- a/src/databricks/sql/backend/sea_backend.py +++ b/src/databricks/sql/backend/sea_backend.py @@ -7,11 +7,17 @@ from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType -from databricks.sql.exc import Error, NotSupportedError +from databricks.sql.exc import Error, NotSupportedError, ServerOperationError from databricks.sql.backend.utils.http_client import CustomHttpClient from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions +from databricks.sql.backend.models import ( + CreateSessionRequest, + DeleteSessionRequest, + CreateSessionResponse, +) + logger = logging.getLogger(__name__) @@ -163,21 +169,27 @@ def open_session( schema, ) - request_data: Dict[str, Any] = {"warehouse_id": self.warehouse_id} - if session_configuration: - request_data["session_confs"] = session_configuration - if catalog: - request_data["catalog"] = catalog - if schema: - request_data["schema"] = schema + request_data = CreateSessionRequest( + warehouse_id=self.warehouse_id, + session_confs=session_configuration, + catalog=catalog, + schema=schema, + ) response = self.http_client._make_request( - method="POST", 
path=self.SESSION_PATH, data=request_data + method="POST", path=self.SESSION_PATH, data=request_data.to_dict() ) - session_id = response.get("session_id") + session_response = CreateSessionResponse.from_dict(response) + session_id = session_response.session_id if not session_id: - raise Error("Failed to create session: No session ID returned") + raise ServerOperationError( + "Failed to create session: No session ID returned", + { + "operation-id": None, + "diagnostic-info": None, + }, + ) return SessionId.from_sea_session_id(session_id) @@ -199,12 +211,15 @@ def close_session(self, session_id: SessionId) -> None: raise ValueError("Not a valid SEA session ID") sea_session_id = session_id.to_sea_session_id() - request_data = {"warehouse_id": self.warehouse_id} + request_data = DeleteSessionRequest( + warehouse_id=self.warehouse_id, + session_id=sea_session_id, + ) self.http_client._make_request( method="DELETE", path=self.SESSION_PATH_WITH_ID.format(sea_session_id), - data=request_data, + data=request_data.to_dict(), ) # == Not Implemented Operations == diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 72009d6cf..76fc6d273 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -150,7 +150,7 @@ def test_close_session_valid_id(self, sea_client, mock_http_client): mock_http_client._make_request.assert_called_once_with( method="DELETE", path=sea_client.SESSION_PATH_WITH_ID.format("test-session-789"), - data={"warehouse_id": "abc123"}, + data={"session_id": "test-session-789", "warehouse_id": "abc123"}, ) def test_close_session_invalid_id_type(self, sea_client): From 059cd4ddf21b8e1cc9aabec64eddf5d9f3d77e63 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 4 Jun 2025 11:14:14 +0530 Subject: [PATCH 09/24] remove http client and test script to prevent diff from showing up post http-client merge Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 66 ------- .../sql/backend/utils/http_client.py | 174 ------------------ 2 files changed, 240 deletions(-) delete mode 100644 examples/experimental/sea_connector_test.py delete mode 100644 src/databricks/sql/backend/utils/http_client.py diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py deleted file mode 100644 index a27099da7..000000000 --- a/examples/experimental/sea_connector_test.py +++ /dev/null @@ -1,66 +0,0 @@ -import os -import sys -import logging -from databricks.sql.client import Connection - -logging.basicConfig(level=logging.DEBUG) -logger = logging.getLogger(__name__) - -def test_sea_session(): - """ - Test opening and closing a SEA session using the connector. - - This function connects to a Databricks SQL endpoint using the SEA backend, - opens a session, and then closes it. 
- - Required environment variables: - - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname - - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint - - DATABRICKS_TOKEN: Personal access token for authentication - """ - - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.") - sys.exit(1) - - logger.info(f"Connecting to {server_hostname}") - logger.info(f"HTTP Path: {http_path}") - if catalog: - logger.info(f"Using catalog: {catalog}") - - try: - logger.info("Creating connection with SEA backend...") - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client" # add custom user agent - ) - - logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}") - logger.info(f"backend type: {type(connection.session.backend)}") - - # Close the connection - logger.info("Closing the SEA session...") - connection.close() - logger.info("Successfully closed SEA session") - - except Exception as e: - logger.error(f"Error testing SEA session: {str(e)}") - import traceback - logger.error(traceback.format_exc()) - sys.exit(1) - - logger.info("SEA session test completed successfully") - -if __name__ == "__main__": - test_sea_session() \ No newline at end of file diff --git a/src/databricks/sql/backend/utils/http_client.py b/src/databricks/sql/backend/utils/http_client.py deleted file mode 100644 index 82980792c..000000000 --- a/src/databricks/sql/backend/utils/http_client.py +++ /dev/null @@ -1,174 +0,0 @@ -import json -import logging -import requests -from typing import Dict, Any, Optional, Union, List -from urllib.parse import urljoin - -from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.types import SSLOptions - -logger = logging.getLogger(__name__) - - -class CustomHttpClient: - """ - HTTP client for Statement Execution API (SEA). - - This client handles the HTTP communication with the SEA endpoints, - including authentication, request formatting, and response parsing. - """ - - def __init__( - self, - server_hostname: str, - port: int, - http_path: str, - http_headers: List[tuple], - auth_provider: AuthProvider, - ssl_options: SSLOptions, - **kwargs, - ): - """ - Initialize the SEA HTTP client. 
- - Args: - server_hostname: Hostname of the Databricks server - port: Port number for the connection - http_path: HTTP path for the connection - http_headers: List of HTTP headers to include in requests - auth_provider: Authentication provider - ssl_options: SSL configuration options - **kwargs: Additional keyword arguments - """ - - self.server_hostname = server_hostname - self.port = port - self.http_path = http_path - self.auth_provider = auth_provider - self.ssl_options = ssl_options - - self.base_url = f"https://{server_hostname}:{port}" - - self.headers = dict(http_headers) - self.headers.update({"Content-Type": "application/json"}) - - self.max_retries = kwargs.get("_retry_stop_after_attempts_count", 30) - - # Create a session for connection pooling - self.session = requests.Session() - - # Configure SSL verification - if ssl_options.tls_verify: - self.session.verify = ssl_options.tls_trusted_ca_file or True - else: - self.session.verify = False - - # Configure client certificates if provided - if ssl_options.tls_client_cert_file: - client_cert = ssl_options.tls_client_cert_file - client_key = ssl_options.tls_client_cert_key_file - client_key_password = ssl_options.tls_client_cert_key_password - - if client_key: - self.session.cert = (client_cert, client_key) - else: - self.session.cert = client_cert - - if client_key_password: - # Note: requests doesn't directly support key passwords - # This would require more complex handling with libraries like pyOpenSSL - logger.warning( - "Client key password provided but not supported by requests library" - ) - - def _get_auth_headers(self) -> Dict[str, str]: - """Get authentication headers from the auth provider.""" - headers: Dict[str, str] = {} - self.auth_provider.add_headers(headers) - return headers - - def _make_request( - self, method: str, path: str, data: Optional[Dict[str, Any]] = None - ) -> Dict[str, Any]: - """ - Make an HTTP request to the SEA endpoint. - - Args: - method: HTTP method (GET, POST, DELETE) - path: API endpoint path - data: Request payload data - - Returns: - Dict[str, Any]: Response data parsed from JSON - - Raises: - RequestError: If the request fails - """ - - url = urljoin(self.base_url, path) - headers = {**self.headers, **self._get_auth_headers()} - - logger.debug(f"making {method} request to {url}") - - try: - if method.upper() == "GET": - response = self.session.get(url, headers=headers, params=data) - elif method.upper() == "POST": - response = self.session.post(url, headers=headers, json=data) - elif method.upper() == "DELETE": - # For DELETE requests, use params for data (query parameters) - response = self.session.delete(url, headers=headers, params=data) - else: - raise ValueError(f"Unsupported HTTP method: {method}") - - # Check for HTTP errors - response.raise_for_status() - - # Log response details - logger.debug(f"Response status: {response.status_code}") - - # Parse JSON response - if response.content: - result = response.json() - # Log response content (but limit it for large responses) - content_str = json.dumps(result) - if len(content_str) > 1000: - logger.debug( - f"Response content (truncated): {content_str[:1000]}..." 
- ) - else: - logger.debug(f"Response content: {content_str}") - return result - return {} - - except requests.exceptions.RequestException as e: - # Handle request errors - error_message = f"SEA HTTP request failed: {str(e)}" - logger.error(error_message) - - # Extract error details from response if available - if hasattr(e, "response") and e.response is not None: - try: - error_details = e.response.json() - error_message = ( - f"{error_message}: {error_details.get('message', '')}" - ) - logger.error( - f"Response status: {e.response.status_code}, Error details: {error_details}" - ) - except (ValueError, KeyError): - # If we can't parse the JSON, just log the raw content - content_str = ( - e.response.content.decode("utf-8", errors="replace") - if isinstance(e.response.content, bytes) - else str(e.response.content) - ) - logger.error( - f"Response status: {e.response.status_code}, Raw content: {content_str}" - ) - pass - - # Re-raise as a RequestError - from databricks.sql.exc import RequestError - - raise RequestError(error_message, e) From 6830327f7c233a8d645845453496486b7eedd8a6 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 4 Jun 2025 11:24:49 +0530 Subject: [PATCH 10/24] Introduce Sea HTTP Client and test script (#583) * introduce http client (temp) and sea test file Signed-off-by: varun-edachali-dbx * reduce verbosity Signed-off-by: varun-edachali-dbx * redundant comment Signed-off-by: varun-edachali-dbx * reduce redundancy, params and data separate Signed-off-by: varun-edachali-dbx * rename client Signed-off-by: varun-edachali-dbx * fix type issues Signed-off-by: varun-edachali-dbx * reduce repetition in request calls Signed-off-by: varun-edachali-dbx * remove un-necessary elifs Signed-off-by: varun-edachali-dbx * add newline at EOF Signed-off-by: varun-edachali-dbx --------- Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 66 +++++++ .../sql/backend/utils/http_client.py | 186 ++++++++++++++++++ 2 files changed, 252 insertions(+) create mode 100644 examples/experimental/sea_connector_test.py create mode 100644 src/databricks/sql/backend/utils/http_client.py diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py new file mode 100644 index 000000000..abe6bd1ab --- /dev/null +++ b/examples/experimental/sea_connector_test.py @@ -0,0 +1,66 @@ +import os +import sys +import logging +from databricks.sql.client import Connection + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + +def test_sea_session(): + """ + Test opening and closing a SEA session using the connector. + + This function connects to a Databricks SQL endpoint using the SEA backend, + opens a session, and then closes it. 
+ + Required environment variables: + - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname + - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint + - DATABRICKS_TOKEN: Personal access token for authentication + """ + + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.") + sys.exit(1) + + logger.info(f"Connecting to {server_hostname}") + logger.info(f"HTTP Path: {http_path}") + if catalog: + logger.info(f"Using catalog: {catalog}") + + try: + logger.info("Creating connection with SEA backend...") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client" # add custom user agent + ) + + logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}") + logger.info(f"backend type: {type(connection.session.backend)}") + + # Close the connection + logger.info("Closing the SEA session...") + connection.close() + logger.info("Successfully closed SEA session") + + except Exception as e: + logger.error(f"Error testing SEA session: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + sys.exit(1) + + logger.info("SEA session test completed successfully") + +if __name__ == "__main__": + test_sea_session() diff --git a/src/databricks/sql/backend/utils/http_client.py b/src/databricks/sql/backend/utils/http_client.py new file mode 100644 index 000000000..f0b931ee4 --- /dev/null +++ b/src/databricks/sql/backend/utils/http_client.py @@ -0,0 +1,186 @@ +import json +import logging +import requests +from typing import Callable, Dict, Any, Optional, Union, List, Tuple +from urllib.parse import urljoin + +from databricks.sql.auth.authenticators import AuthProvider +from databricks.sql.types import SSLOptions + +logger = logging.getLogger(__name__) + + +class SeaHttpClient: + """ + HTTP client for Statement Execution API (SEA). + + This client handles the HTTP communication with the SEA endpoints, + including authentication, request formatting, and response parsing. + """ + + def __init__( + self, + server_hostname: str, + port: int, + http_path: str, + http_headers: List[Tuple[str, str]], + auth_provider: AuthProvider, + ssl_options: SSLOptions, + **kwargs, + ): + """ + Initialize the SEA HTTP client. 
+ + Args: + server_hostname: Hostname of the Databricks server + port: Port number for the connection + http_path: HTTP path for the connection + http_headers: List of HTTP headers to include in requests + auth_provider: Authentication provider + ssl_options: SSL configuration options + **kwargs: Additional keyword arguments + """ + + self.server_hostname = server_hostname + self.port = port + self.http_path = http_path + self.auth_provider = auth_provider + self.ssl_options = ssl_options + + self.base_url = f"https://{server_hostname}:{port}" + + self.headers: Dict[str, str] = dict(http_headers) + self.headers.update({"Content-Type": "application/json"}) + + self.max_retries = kwargs.get("_retry_stop_after_attempts_count", 30) + + # Create a session for connection pooling + self.session = requests.Session() + + # Configure SSL verification + if ssl_options.tls_verify: + self.session.verify = ssl_options.tls_trusted_ca_file or True + else: + self.session.verify = False + + # Configure client certificates if provided + if ssl_options.tls_client_cert_file: + client_cert = ssl_options.tls_client_cert_file + client_key = ssl_options.tls_client_cert_key_file + client_key_password = ssl_options.tls_client_cert_key_password + + if client_key: + self.session.cert = (client_cert, client_key) + else: + self.session.cert = client_cert + + if client_key_password: + # Note: requests doesn't directly support key passwords + # This would require more complex handling with libraries like pyOpenSSL + logger.warning( + "Client key password provided but not supported by requests library" + ) + + def _get_auth_headers(self) -> Dict[str, str]: + """Get authentication headers from the auth provider.""" + headers: Dict[str, str] = {} + self.auth_provider.add_headers(headers) + return headers + + def _get_call(self, method: str) -> Callable: + """Get the appropriate HTTP method function.""" + method = method.upper() + if method == "GET": + return self.session.get + if method == "POST": + return self.session.post + if method == "DELETE": + return self.session.delete + raise ValueError(f"Unsupported HTTP method: {method}") + + def _make_request( + self, + method: str, + path: str, + data: Optional[Dict[str, Any]] = None, + params: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """ + Make an HTTP request to the SEA endpoint. + + Args: + method: HTTP method (GET, POST, DELETE) + path: API endpoint path + data: Request payload data + params: Query parameters + + Returns: + Dict[str, Any]: Response data parsed from JSON + + Raises: + RequestError: If the request fails + """ + + url = urljoin(self.base_url, path) + headers: Dict[str, str] = {**self.headers, **self._get_auth_headers()} + + logger.debug(f"making {method} request to {url}") + + try: + call = self._get_call(method) + response = call( + url=url, + headers=headers, + json=data, + params=params, + ) + + # Check for HTTP errors + response.raise_for_status() + + # Log response details + logger.debug(f"Response status: {response.status_code}") + + # Parse JSON response + if response.content: + result = response.json() + # Log response content (but limit it for large responses) + content_str = json.dumps(result) + if len(content_str) > 1000: + logger.debug( + f"Response content (truncated): {content_str[:1000]}..." 
+ ) + else: + logger.debug(f"Response content: {content_str}") + return result + return {} + + except requests.exceptions.RequestException as e: + # Handle request errors and extract details from response if available + error_message = f"SEA HTTP request failed: {str(e)}" + + if hasattr(e, "response") and e.response is not None: + status_code = e.response.status_code + try: + error_details = e.response.json() + error_message = ( + f"{error_message}: {error_details.get('message', '')}" + ) + logger.error( + f"Request failed (status {status_code}): {error_details}" + ) + except (ValueError, KeyError): + # If we can't parse JSON, log raw content + content = ( + e.response.content.decode("utf-8", errors="replace") + if isinstance(e.response.content, bytes) + else str(e.response.content) + ) + logger.error(f"Request failed (status {status_code}): {content}") + else: + logger.error(error_message) + + # Re-raise as a RequestError + from databricks.sql.exc import RequestError + + raise RequestError(error_message, e) From ab847da53596cda474e11d15ddad0081c97c5088 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 4 Jun 2025 05:59:02 +0000 Subject: [PATCH 11/24] CustomHttpClient -> SeaHttpClient Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea_backend.py | 4 ++-- tests/unit/test_sea_backend.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/backend/sea_backend.py b/src/databricks/sql/backend/sea_backend.py index 20288175a..38839c07e 100644 --- a/src/databricks/sql/backend/sea_backend.py +++ b/src/databricks/sql/backend/sea_backend.py @@ -8,7 +8,7 @@ from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType from databricks.sql.exc import Error, NotSupportedError, ServerOperationError -from databricks.sql.backend.utils.http_client import CustomHttpClient +from databricks.sql.backend.utils.http_client import SeaHttpClient from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions @@ -77,7 +77,7 @@ def __init__( self.warehouse_id = self._extract_warehouse_id(http_path) # Initialize HTTP client - self.http_client = CustomHttpClient( + self.http_client = SeaHttpClient( server_hostname=server_hostname, port=port, http_path=http_path, diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 76fc6d273..ddec74f33 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -15,7 +15,7 @@ class TestSeaBackend: def mock_http_client(self): """Create a mock HTTP client.""" with patch( - "databricks.sql.backend.sea_backend.CustomHttpClient" + "databricks.sql.backend.sea_backend.SeaHttpClient" ) as mock_client_class: mock_client = mock_client_class.return_value yield mock_client From 1c399d57a3eae9eb1a8cf7a51acbb0d82859eee5 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 4 Jun 2025 06:02:30 +0000 Subject: [PATCH 12/24] redundant comment in backend client Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea_backend.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/databricks/sql/backend/sea_backend.py b/src/databricks/sql/backend/sea_backend.py index 38839c07e..5209ac253 100644 --- a/src/databricks/sql/backend/sea_backend.py +++ b/src/databricks/sql/backend/sea_backend.py @@ -24,9 +24,6 @@ class SeaDatabricksClient(DatabricksClient): """ Statement Execution API (SEA) implementation of the DatabricksClient interface. 
- - This implementation provides session management functionality for SEA, - while other operations raise NotImplementedError. """ # SEA API paths From 42c45812655ccc6c23b75b7f273a50720955352b Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 4 Jun 2025 06:19:33 +0000 Subject: [PATCH 13/24] regex for warehouse_id instead of .split, remove excess imports and behaviour Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea_backend.py | 80 ++++++++++------------- 1 file changed, 33 insertions(+), 47 deletions(-) diff --git a/src/databricks/sql/backend/sea_backend.py b/src/databricks/sql/backend/sea_backend.py index 5209ac253..f0e7f78e0 100644 --- a/src/databricks/sql/backend/sea_backend.py +++ b/src/databricks/sql/backend/sea_backend.py @@ -1,13 +1,13 @@ import logging -import uuid -from typing import Dict, Tuple, List, Optional, Any, Union, TYPE_CHECKING +import re +from typing import Dict, Tuple, List, Optional, Union, TYPE_CHECKING if TYPE_CHECKING: from databricks.sql.client import Cursor from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType -from databricks.sql.exc import Error, NotSupportedError, ServerOperationError +from databricks.sql.exc import ServerOperationError from databricks.sql.backend.utils.http_client import SeaHttpClient from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions @@ -88,10 +88,6 @@ def _extract_warehouse_id(self, http_path: str) -> str: """ Extract the warehouse ID from the HTTP path. - The warehouse ID is expected to be the last segment of the path when the - second-to-last segment is either 'warehouses' or 'endpoints'. - This matches the JDBC implementation which supports both formats. - Args: http_path: The HTTP path from which to extract the warehouse ID @@ -99,38 +95,28 @@ def _extract_warehouse_id(self, http_path: str) -> str: The extracted warehouse ID Raises: - Error: If the warehouse ID cannot be extracted from the path + ValueError: If the warehouse ID cannot be extracted from the path """ - - path_parts = http_path.strip("/").split("/") - warehouse_id = None - - if len(path_parts) >= 3 and path_parts[-2] in ["warehouses", "endpoints"]: - warehouse_id = path_parts[-1] - logger.debug( - f"Extracted warehouse ID: {warehouse_id} from path: {http_path}" - ) - - if not warehouse_id: - error_message = ( - f"Could not extract warehouse ID from http_path: {http_path}. " - f"Expected format: /path/to/warehouses/{{warehouse_id}} or " - f"/path/to/endpoints/{{warehouse_id}}" - ) - logger.error(error_message) - raise ValueError(error_message) - - return warehouse_id - - @property - def staging_allowed_local_path(self) -> Union[None, str, List[str]]: - """Get the allowed local paths for staging operations.""" - return self._staging_allowed_local_path - - @property - def ssl_options(self) -> SSLOptions: - """Get the SSL options for this client.""" - return self._ssl_options + warehouse_pattern = re.compile(r".*/warehouses/(.+)") + endpoint_pattern = re.compile(r".*/endpoints/(.+)") + + for pattern in [warehouse_pattern, endpoint_pattern]: + match = pattern.match(http_path) + if match: + warehouse_id = match.group(1) + logger.debug( + f"Extracted warehouse ID: {warehouse_id} from path: {http_path}" + ) + return warehouse_id + + # If no match found, raise error + error_message = ( + f"Could not extract warehouse ID from http_path: {http_path}. 
" + f"Expected format: /path/to/warehouses/{{warehouse_id}} or " + f"/path/to/endpoints/{{warehouse_id}}" + ) + logger.error(error_message) + raise ValueError(error_message) @property def max_download_threads(self) -> int: @@ -236,21 +222,21 @@ def execute_command( enforce_embedded_schema_correctness: bool, ): """Not implemented yet.""" - raise NotSupportedError( + raise NotImplementedError( "execute_command is not yet implemented for SEA backend" ) def cancel_command(self, command_id: CommandId) -> None: """Not implemented yet.""" - raise NotSupportedError("cancel_command is not yet implemented for SEA backend") + raise NotImplementedError("cancel_command is not yet implemented for SEA backend") def close_command(self, command_id: CommandId) -> None: """Not implemented yet.""" - raise NotSupportedError("close_command is not yet implemented for SEA backend") + raise NotImplementedError("close_command is not yet implemented for SEA backend") def get_query_state(self, command_id: CommandId) -> CommandState: """Not implemented yet.""" - raise NotSupportedError( + raise NotImplementedError( "get_query_state is not yet implemented for SEA backend" ) @@ -260,7 +246,7 @@ def get_execution_result( cursor: "Cursor", ): """Not implemented yet.""" - raise NotSupportedError( + raise NotImplementedError( "get_execution_result is not yet implemented for SEA backend" ) @@ -274,7 +260,7 @@ def get_catalogs( cursor: "Cursor", ): """Not implemented yet.""" - raise NotSupportedError("get_catalogs is not yet implemented for SEA backend") + raise NotImplementedError("get_catalogs is not yet implemented for SEA backend") def get_schemas( self, @@ -286,7 +272,7 @@ def get_schemas( schema_name: Optional[str] = None, ): """Not implemented yet.""" - raise NotSupportedError("get_schemas is not yet implemented for SEA backend") + raise NotImplementedError("get_schemas is not yet implemented for SEA backend") def get_tables( self, @@ -300,7 +286,7 @@ def get_tables( table_types: Optional[List[str]] = None, ): """Not implemented yet.""" - raise NotSupportedError("get_tables is not yet implemented for SEA backend") + raise NotImplementedError("get_tables is not yet implemented for SEA backend") def get_columns( self, @@ -314,4 +300,4 @@ def get_columns( column_name: Optional[str] = None, ): """Not implemented yet.""" - raise NotSupportedError("get_columns is not yet implemented for SEA backend") + raise NotImplementedError("get_columns is not yet implemented for SEA backend") From 8bfca4541abf3f6b8c255dc7227ab9bff2aba8c9 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 4 Jun 2025 06:22:25 +0000 Subject: [PATCH 14/24] remove redundant attributes Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea_backend.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/databricks/sql/backend/sea_backend.py b/src/databricks/sql/backend/sea_backend.py index f0e7f78e0..f15ea54a8 100644 --- a/src/databricks/sql/backend/sea_backend.py +++ b/src/databricks/sql/backend/sea_backend.py @@ -42,7 +42,6 @@ def __init__( http_headers: List[Tuple[str, str]], auth_provider, ssl_options: SSLOptions, - staging_allowed_local_path: Union[None, str, List[str]] = None, **kwargs, ): """ @@ -55,7 +54,6 @@ def __init__( http_headers: List of HTTP headers to include in requests auth_provider: Authentication provider ssl_options: SSL configuration options - staging_allowed_local_path: Allowed local paths for staging operations **kwargs: Additional keyword arguments """ @@ -66,8 +64,6 @@ def __init__( http_path, ) - 
self._staging_allowed_local_path = staging_allowed_local_path - self._ssl_options = ssl_options self._max_download_threads = kwargs.get("max_download_threads", 10) # Extract warehouse ID from http_path From 5005b136a5f066d6891e467151550ba3f7c5c738 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 4 Jun 2025 06:25:09 +0000 Subject: [PATCH 15/24] formatting (black) Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea_backend.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/databricks/sql/backend/sea_backend.py b/src/databricks/sql/backend/sea_backend.py index f15ea54a8..01b331e09 100644 --- a/src/databricks/sql/backend/sea_backend.py +++ b/src/databricks/sql/backend/sea_backend.py @@ -95,7 +95,7 @@ def _extract_warehouse_id(self, http_path: str) -> str: """ warehouse_pattern = re.compile(r".*/warehouses/(.+)") endpoint_pattern = re.compile(r".*/endpoints/(.+)") - + for pattern in [warehouse_pattern, endpoint_pattern]: match = pattern.match(http_path) if match: @@ -104,7 +104,7 @@ def _extract_warehouse_id(self, http_path: str) -> str: f"Extracted warehouse ID: {warehouse_id} from path: {http_path}" ) return warehouse_id - + # If no match found, raise error error_message = ( f"Could not extract warehouse ID from http_path: {http_path}. " @@ -224,11 +224,15 @@ def execute_command( def cancel_command(self, command_id: CommandId) -> None: """Not implemented yet.""" - raise NotImplementedError("cancel_command is not yet implemented for SEA backend") + raise NotImplementedError( + "cancel_command is not yet implemented for SEA backend" + ) def close_command(self, command_id: CommandId) -> None: """Not implemented yet.""" - raise NotImplementedError("close_command is not yet implemented for SEA backend") + raise NotImplementedError( + "close_command is not yet implemented for SEA backend" + ) def get_query_state(self, command_id: CommandId) -> CommandState: """Not implemented yet.""" From 8efa68c50177a9a021cbf00726cd429a570ae3ff Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 4 Jun 2025 06:45:40 +0000 Subject: [PATCH 16/24] [nit] reduce nested code Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea_backend.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/databricks/sql/backend/sea_backend.py b/src/databricks/sql/backend/sea_backend.py index 01b331e09..2440db2cc 100644 --- a/src/databricks/sql/backend/sea_backend.py +++ b/src/databricks/sql/backend/sea_backend.py @@ -98,12 +98,13 @@ def _extract_warehouse_id(self, http_path: str) -> str: for pattern in [warehouse_pattern, endpoint_pattern]: match = pattern.match(http_path) - if match: - warehouse_id = match.group(1) - logger.debug( - f"Extracted warehouse ID: {warehouse_id} from path: {http_path}" - ) - return warehouse_id + if not match: + continue + warehouse_id = match.group(1) + logger.debug( + f"Extracted warehouse ID: {warehouse_id} from path: {http_path}" + ) + return warehouse_id # If no match found, raise error error_message = ( From 6e41ebf76e7fa90a60635644392ed531d2e99914 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 5 Jun 2025 10:04:26 +0530 Subject: [PATCH 17/24] line gap after multi-line pydoc Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea_backend.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/databricks/sql/backend/sea_backend.py b/src/databricks/sql/backend/sea_backend.py index 2440db2cc..7c0ed0e07 100644 --- a/src/databricks/sql/backend/sea_backend.py +++ 
b/src/databricks/sql/backend/sea_backend.py @@ -93,6 +93,7 @@ def _extract_warehouse_id(self, http_path: str) -> str: Raises: ValueError: If the warehouse ID cannot be extracted from the path """ + warehouse_pattern = re.compile(r".*/warehouses/(.+)") endpoint_pattern = re.compile(r".*/endpoints/(.+)") From ed4931e1061fce220feacff3cf5068981014e882 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 5 Jun 2025 11:16:39 +0530 Subject: [PATCH 18/24] redundant imports Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea_backend.py | 2 +- src/databricks/sql/session.py | 2 +- tests/unit/test_sea_backend.py | 4 ++-- tests/unit/test_session.py | 7 +------ 4 files changed, 5 insertions(+), 10 deletions(-) diff --git a/src/databricks/sql/backend/sea_backend.py b/src/databricks/sql/backend/sea_backend.py index 7c0ed0e07..2e3290be1 100644 --- a/src/databricks/sql/backend/sea_backend.py +++ b/src/databricks/sql/backend/sea_backend.py @@ -1,6 +1,6 @@ import logging import re -from typing import Dict, Tuple, List, Optional, Union, TYPE_CHECKING +from typing import Dict, Tuple, List, Optional, TYPE_CHECKING if TYPE_CHECKING: from databricks.sql.client import Cursor diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 6804002bb..7ed387213 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -10,7 +10,7 @@ from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.backend.sea_backend import SeaDatabricksClient from databricks.sql.backend.databricks_client import DatabricksClient -from databricks.sql.backend.types import SessionId, BackendType +from databricks.sql.backend.types import SessionId logger = logging.getLogger(__name__) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index ddec74f33..154d286c9 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -1,11 +1,11 @@ import pytest -from unittest.mock import patch, MagicMock, Mock +from unittest.mock import patch, MagicMock from databricks.sql.backend.sea_backend import SeaDatabricksClient from databricks.sql.backend.types import SessionId, BackendType from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.exc import Error, NotSupportedError +from databricks.sql.exc import Error class TestSeaBackend: diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 7db4b1338..92de8d8fd 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -1,12 +1,7 @@ import unittest -from unittest.mock import patch, MagicMock, Mock, PropertyMock +from unittest.mock import patch, Mock import gc -from databricks.sql.thrift_api.TCLIService.ttypes import ( - TOpenSessionResp, - TSessionHandle, - THandleIdentifier, -) from databricks.sql.backend.types import SessionId, BackendType import databricks.sql From 4ff64edbb5ff1ac9639511064d25b4032b7c811b Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Sat, 7 Jun 2025 16:41:49 +0000 Subject: [PATCH 19/24] move sea backend and models into separate sea/ dir Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/{sea_backend.py => sea/backend.py} | 2 +- src/databricks/sql/backend/{ => sea}/models/__init__.py | 4 ++-- src/databricks/sql/backend/{ => sea}/models/requests.py | 0 src/databricks/sql/backend/{ => sea}/models/responses.py | 0 src/databricks/sql/session.py | 2 +- tests/unit/test_sea_backend.py | 4 ++-- 6 files changed, 6 insertions(+), 6 
deletions(-) rename src/databricks/sql/backend/{sea_backend.py => sea/backend.py} (99%) rename src/databricks/sql/backend/{ => sea}/models/__init__.py (75%) rename src/databricks/sql/backend/{ => sea}/models/requests.py (100%) rename src/databricks/sql/backend/{ => sea}/models/responses.py (100%) diff --git a/src/databricks/sql/backend/sea_backend.py b/src/databricks/sql/backend/sea/backend.py similarity index 99% rename from src/databricks/sql/backend/sea_backend.py rename to src/databricks/sql/backend/sea/backend.py index 2e3290be1..d2d5aa3b0 100644 --- a/src/databricks/sql/backend/sea_backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -12,7 +12,7 @@ from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions -from databricks.sql.backend.models import ( +from databricks.sql.backend.sea.models import ( CreateSessionRequest, DeleteSessionRequest, CreateSessionResponse, diff --git a/src/databricks/sql/backend/models/__init__.py b/src/databricks/sql/backend/sea/models/__init__.py similarity index 75% rename from src/databricks/sql/backend/models/__init__.py rename to src/databricks/sql/backend/sea/models/__init__.py index 667235cce..c9310d367 100644 --- a/src/databricks/sql/backend/models/__init__.py +++ b/src/databricks/sql/backend/sea/models/__init__.py @@ -4,12 +4,12 @@ This package contains data models for SEA API requests and responses. """ -from databricks.sql.backend.models.requests import ( +from databricks.sql.backend.sea.models.requests import ( CreateSessionRequest, DeleteSessionRequest, ) -from databricks.sql.backend.models.responses import ( +from databricks.sql.backend.sea.models.responses import ( CreateSessionResponse, ) diff --git a/src/databricks/sql/backend/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py similarity index 100% rename from src/databricks/sql/backend/models/requests.py rename to src/databricks/sql/backend/sea/models/requests.py diff --git a/src/databricks/sql/backend/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py similarity index 100% rename from src/databricks/sql/backend/models/responses.py rename to src/databricks/sql/backend/sea/models/responses.py diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 7ed387213..7c33d9b2d 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -8,7 +8,7 @@ from databricks.sql import __version__ from databricks.sql import USER_AGENT_NAME from databricks.sql.backend.thrift_backend import ThriftDatabricksClient -from databricks.sql.backend.sea_backend import SeaDatabricksClient +from databricks.sql.backend.sea.backend import SeaDatabricksClient from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 154d286c9..ed22bd3af 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -1,7 +1,7 @@ import pytest from unittest.mock import patch, MagicMock -from databricks.sql.backend.sea_backend import SeaDatabricksClient +from databricks.sql.backend.sea.backend import SeaDatabricksClient from databricks.sql.backend.types import SessionId, BackendType from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider @@ -15,7 +15,7 @@ class TestSeaBackend: def mock_http_client(self): """Create a mock HTTP client.""" with patch( - 
"databricks.sql.backend.sea_backend.SeaHttpClient" + "databricks.sql.backend.sea.backend.SeaHttpClient" ) as mock_client_class: mock_client = mock_client_class.return_value yield mock_client From 9aebea29d4c6da5360a24ae420c390e5f665f454 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Sat, 7 Jun 2025 16:52:55 +0000 Subject: [PATCH 20/24] move http client into separate sea/ dir Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 2 +- src/databricks/sql/backend/{ => sea}/utils/http_client.py | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename src/databricks/sql/backend/{ => sea}/utils/http_client.py (100%) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index d2d5aa3b0..ca4a91499 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -8,7 +8,7 @@ from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType from databricks.sql.exc import ServerOperationError -from databricks.sql.backend.utils.http_client import SeaHttpClient +from databricks.sql.backend.sea.utils.http_client import SeaHttpClient from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions diff --git a/src/databricks/sql/backend/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py similarity index 100% rename from src/databricks/sql/backend/utils/http_client.py rename to src/databricks/sql/backend/sea/utils/http_client.py From a05f1fd7205e35729a1b97da8add9a3444914a5f Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Sat, 7 Jun 2025 17:35:35 +0000 Subject: [PATCH 21/24] change commands to include ones in docs Signed-off-by: varun-edachali-dbx --- tests/unit/test_sea_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index ed22bd3af..97effe2c2 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -102,7 +102,7 @@ def test_open_session_with_all_parameters(self, sea_client, mock_http_client): mock_http_client._make_request.return_value = {"session_id": "test-session-456"} # Call the method with all parameters - session_config = {"spark.sql.shuffle.partitions": "10"} + session_config = {"ANSI_MODE": "FALSE", "STATEMENT_TIMEOUT": "3600"} catalog = "test_catalog" schema = "test_schema" From 46104e2f3d5adad82a121793a0e4beb9be4ab422 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Sat, 7 Jun 2025 17:49:15 +0000 Subject: [PATCH 22/24] add link to sql-ref-parameters for session-confs Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index ca4a91499..40e767080 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -131,7 +131,9 @@ def open_session( Opens a new session with the Databricks SQL service using SEA. Args: - session_configuration: Optional dictionary of configuration parameters for the session + session_configuration: Optional dictionary of configuration parameters for the session. 
+ Only specific parameters are supported as documented at: + https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-parameters catalog: Optional catalog name to use as the initial catalog for the session schema: Optional schema name to use as the initial schema for the session From 390c1e76f1a8cda0d940f800aba56417e5152cd7 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Sun, 8 Jun 2025 16:03:26 +0000 Subject: [PATCH 23/24] add client side filtering for session confs, add note on warehouses over endoints Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 61 ++++++++++++++++++- .../sql/backend/sea/utils/constants.py | 17 ++++++ tests/unit/test_sea_backend.py | 45 ++++++++++++-- 3 files changed, 117 insertions(+), 6 deletions(-) create mode 100644 src/databricks/sql/backend/sea/utils/constants.py diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 40e767080..97d25a058 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,6 +1,6 @@ import logging import re -from typing import Dict, Tuple, List, Optional, TYPE_CHECKING +from typing import Dict, Tuple, List, Optional, TYPE_CHECKING, Set if TYPE_CHECKING: from databricks.sql.client import Cursor @@ -9,6 +9,9 @@ from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType from databricks.sql.exc import ServerOperationError from databricks.sql.backend.sea.utils.http_client import SeaHttpClient +from databricks.sql.backend.sea.utils.constants import ( + ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, +) from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions @@ -21,6 +24,34 @@ logger = logging.getLogger(__name__) +def _filter_session_configuration( + session_configuration: Optional[Dict[str, str]] +) -> Optional[Dict[str, str]]: + if not session_configuration: + return None + + filtered_session_configuration = {} + ignored_configs: Set[str] = set() + + for key, value in session_configuration.items(): + if key.upper() in ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP: + filtered_session_configuration[key.lower()] = value + else: + ignored_configs.add(key) + + if ignored_configs: + logger.warning( + "Some session configurations were ignored because they are not supported: %s", + ignored_configs, + ) + logger.warning( + "Supported session configurations are: %s", + list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys()), + ) + + return filtered_session_configuration + + class SeaDatabricksClient(DatabricksClient): """ Statement Execution API (SEA) implementation of the DatabricksClient interface. @@ -111,7 +142,8 @@ def _extract_warehouse_id(self, http_path: str) -> str: error_message = ( f"Could not extract warehouse ID from http_path: {http_path}. " f"Expected format: /path/to/warehouses/{{warehouse_id}} or " - f"/path/to/endpoints/{{warehouse_id}}" + f"/path/to/endpoints/{{warehouse_id}}." + f"Note: SEA only works for warehouses." 
) logger.error(error_message) raise ValueError(error_message) @@ -152,6 +184,8 @@ def open_session( schema, ) + session_configuration = _filter_session_configuration(session_configuration) + request_data = CreateSessionRequest( warehouse_id=self.warehouse_id, session_confs=session_configuration, @@ -205,6 +239,29 @@ def close_session(self, session_id: SessionId) -> None: data=request_data.to_dict(), ) + @staticmethod + def get_default_session_configuration_value(name: str) -> Optional[str]: + """ + Get the default value for a session configuration parameter. + + Args: + name: The name of the session configuration parameter + + Returns: + The default value if the parameter is supported, None otherwise + """ + return ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.get(name.upper()) + + @staticmethod + def get_allowed_session_configurations() -> List[str]: + """ + Get the list of allowed session configuration parameters. + + Returns: + List of allowed session configuration parameter names + """ + return list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys()) + # == Not Implemented Operations == # These methods will be implemented in future iterations diff --git a/src/databricks/sql/backend/sea/utils/constants.py b/src/databricks/sql/backend/sea/utils/constants.py new file mode 100644 index 000000000..9160ef6ad --- /dev/null +++ b/src/databricks/sql/backend/sea/utils/constants.py @@ -0,0 +1,17 @@ +""" +Constants for the Statement Execution API (SEA) backend. +""" + +from typing import Dict + +# from https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-parameters +ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP: Dict[str, str] = { + "ANSI_MODE": "true", + "ENABLE_PHOTON": "true", + "LEGACY_TIME_PARSER_POLICY": "Exception", + "MAX_FILE_PARTITION_BYTES": "128m", + "READ_ONLY_EXTERNAL_METASTORE": "false", + "STATEMENT_TIMEOUT": "0", + "TIMEZONE": "UTC", + "USE_CACHED_RESULT": "true", +} diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 97effe2c2..c57cd6aae 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -101,8 +101,12 @@ def test_open_session_with_all_parameters(self, sea_client, mock_http_client): # Set up mock response mock_http_client._make_request.return_value = {"session_id": "test-session-456"} - # Call the method with all parameters - session_config = {"ANSI_MODE": "FALSE", "STATEMENT_TIMEOUT": "3600"} + # Call the method with all parameters, including both supported and unsupported configurations + session_config = { + "ANSI_MODE": "FALSE", # Supported parameter + "STATEMENT_TIMEOUT": "3600", # Supported parameter + "unsupported_param": "value", # Unsupported parameter + } catalog = "test_catalog" schema = "test_schema" @@ -113,10 +117,14 @@ def test_open_session_with_all_parameters(self, sea_client, mock_http_client): assert session_id.backend_type == BackendType.SEA assert session_id.guid == "test-session-456" - # Verify the HTTP request + # Verify the HTTP request - only supported parameters should be included + # and keys should be in lowercase expected_data = { "warehouse_id": "abc123", - "session_confs": session_config, + "session_confs": { + "ansi_mode": "FALSE", + "statement_timeout": "3600", + }, "catalog": catalog, "schema": schema, } @@ -166,3 +174,32 @@ def test_close_session_invalid_id_type(self, sea_client): sea_client.close_session(session_id) assert "Not a valid SEA session ID" in str(excinfo.value) + + def test_session_configuration_helpers(self): + """Test the session configuration helper methods.""" + # 
Test getting default value for a supported parameter + default_value = SeaDatabricksClient.get_default_session_configuration_value( + "ANSI_MODE" + ) + assert default_value == "true" + + # Test getting default value for an unsupported parameter + default_value = SeaDatabricksClient.get_default_session_configuration_value( + "UNSUPPORTED_PARAM" + ) + assert default_value is None + + # Test getting the list of allowed configurations + allowed_configs = SeaDatabricksClient.get_allowed_session_configurations() + + expected_keys = { + "ANSI_MODE", + "ENABLE_PHOTON", + "LEGACY_TIME_PARSER_POLICY", + "MAX_FILE_PARTITION_BYTES", + "READ_ONLY_EXTERNAL_METASTORE", + "STATEMENT_TIMEOUT", + "TIMEZONE", + "USE_CACHED_RESULT", + } + assert set(allowed_configs) == expected_keys From 86ee56fda021c5784a1e025ce86d60ad92bbf375 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Sun, 8 Jun 2025 16:20:03 +0000 Subject: [PATCH 24/24] test unimplemented methods and max_download_threads prop Signed-off-by: varun-edachali-dbx --- tests/unit/test_sea_backend.py | 78 ++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index c57cd6aae..bc2688a68 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -203,3 +203,81 @@ def test_session_configuration_helpers(self): "USE_CACHED_RESULT", } assert set(allowed_configs) == expected_keys + + def test_unimplemented_methods(self, sea_client): + """Test that unimplemented methods raise NotImplementedError.""" + # Create dummy parameters for testing + session_id = SessionId.from_sea_session_id("test-session") + command_id = MagicMock() + cursor = MagicMock() + + # Test execute_command + with pytest.raises(NotImplementedError) as excinfo: + sea_client.execute_command( + operation="SELECT 1", + session_id=session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert "execute_command is not yet implemented" in str(excinfo.value) + + # Test cancel_command + with pytest.raises(NotImplementedError) as excinfo: + sea_client.cancel_command(command_id) + assert "cancel_command is not yet implemented" in str(excinfo.value) + + # Test close_command + with pytest.raises(NotImplementedError) as excinfo: + sea_client.close_command(command_id) + assert "close_command is not yet implemented" in str(excinfo.value) + + # Test get_query_state + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_query_state(command_id) + assert "get_query_state is not yet implemented" in str(excinfo.value) + + # Test get_execution_result + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_execution_result(command_id, cursor) + assert "get_execution_result is not yet implemented" in str(excinfo.value) + + # Test metadata operations + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_catalogs(session_id, 100, 1000, cursor) + assert "get_catalogs is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_schemas(session_id, 100, 1000, cursor) + assert "get_schemas is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_tables(session_id, 100, 1000, cursor) + assert "get_tables is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + 
sea_client.get_columns(session_id, 100, 1000, cursor) + assert "get_columns is not yet implemented" in str(excinfo.value) + + def test_max_download_threads_property(self, sea_client): + """Test the max_download_threads property.""" + assert sea_client.max_download_threads == 10 + + # Create a client with a custom value + custom_client = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/warehouses/abc123", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + max_download_threads=20, + ) + + # Verify the custom value is returned + assert custom_client.max_download_threads == 20
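Taken together, the patches above add the SEA backend, its dedicated HTTP client under sea/utils/, and client-side filtering of session configuration parameters. The snippet below is a minimal usage sketch, not part of the patch series: it assumes the module layout from the final patches (databricks.sql.backend.sea.backend) plus the use_sea flag and environment variables already used in examples/experimental/sea_connector_test.py.

import os

from databricks.sql.client import Connection
from databricks.sql.backend.sea.backend import SeaDatabricksClient

# Session-configuration helpers added in the later patches; keys are matched
# case-insensitively against ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.
print(SeaDatabricksClient.get_allowed_session_configurations())
# e.g. ['ANSI_MODE', 'ENABLE_PHOTON', 'LEGACY_TIME_PARSER_POLICY', ...]
print(SeaDatabricksClient.get_default_session_configuration_value("ansi_mode"))
# -> 'true' (lookup is case-insensitive)
print(SeaDatabricksClient.get_default_session_configuration_value("spark.sql.shuffle.partitions"))
# -> None; unsupported keys would be dropped by the backend's client-side filter

# Open and close a SEA session, mirroring the experimental test script.
connection = Connection(
    server_hostname=os.environ["DATABRICKS_SERVER_HOSTNAME"],
    http_path=os.environ["DATABRICKS_HTTP_PATH"],  # must contain /warehouses/<id> or /endpoints/<id>
    access_token=os.environ["DATABRICKS_TOKEN"],
    use_sea=True,
    user_agent_entry="SEA-Test-Client",
)
print(f"Opened SEA session: {connection.get_session_id_hex()}")
connection.close()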