diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py new file mode 100644 index 000000000..97d25a058 --- /dev/null +++ b/src/databricks/sql/backend/sea/backend.py @@ -0,0 +1,364 @@ +import logging +import re +from typing import Dict, Tuple, List, Optional, TYPE_CHECKING, Set + +if TYPE_CHECKING: + from databricks.sql.client import Cursor + +from databricks.sql.backend.databricks_client import DatabricksClient +from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType +from databricks.sql.exc import ServerOperationError +from databricks.sql.backend.sea.utils.http_client import SeaHttpClient +from databricks.sql.backend.sea.utils.constants import ( + ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, +) +from databricks.sql.thrift_api.TCLIService import ttypes +from databricks.sql.types import SSLOptions + +from databricks.sql.backend.sea.models import ( + CreateSessionRequest, + DeleteSessionRequest, + CreateSessionResponse, +) + +logger = logging.getLogger(__name__) + + +def _filter_session_configuration( + session_configuration: Optional[Dict[str, str]] +) -> Optional[Dict[str, str]]: + if not session_configuration: + return None + + filtered_session_configuration = {} + ignored_configs: Set[str] = set() + + for key, value in session_configuration.items(): + if key.upper() in ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP: + filtered_session_configuration[key.lower()] = value + else: + ignored_configs.add(key) + + if ignored_configs: + logger.warning( + "Some session configurations were ignored because they are not supported: %s", + ignored_configs, + ) + logger.warning( + "Supported session configurations are: %s", + list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys()), + ) + + return filtered_session_configuration + + +class SeaDatabricksClient(DatabricksClient): + """ + Statement Execution API (SEA) implementation of the DatabricksClient interface. 
+ """ + + # SEA API paths + BASE_PATH = "/api/2.0/sql/" + SESSION_PATH = BASE_PATH + "sessions" + SESSION_PATH_WITH_ID = SESSION_PATH + "/{}" + STATEMENT_PATH = BASE_PATH + "statements" + STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}" + CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel" + + def __init__( + self, + server_hostname: str, + port: int, + http_path: str, + http_headers: List[Tuple[str, str]], + auth_provider, + ssl_options: SSLOptions, + **kwargs, + ): + """ + Initialize the SEA backend client. + + Args: + server_hostname: Hostname of the Databricks server + port: Port number for the connection + http_path: HTTP path for the connection + http_headers: List of HTTP headers to include in requests + auth_provider: Authentication provider + ssl_options: SSL configuration options + **kwargs: Additional keyword arguments + """ + + logger.debug( + "SeaDatabricksClient.__init__(server_hostname=%s, port=%s, http_path=%s)", + server_hostname, + port, + http_path, + ) + + self._max_download_threads = kwargs.get("max_download_threads", 10) + + # Extract warehouse ID from http_path + self.warehouse_id = self._extract_warehouse_id(http_path) + + # Initialize HTTP client + self.http_client = SeaHttpClient( + server_hostname=server_hostname, + port=port, + http_path=http_path, + http_headers=http_headers, + auth_provider=auth_provider, + ssl_options=ssl_options, + **kwargs, + ) + + def _extract_warehouse_id(self, http_path: str) -> str: + """ + Extract the warehouse ID from the HTTP path. 
+ + Args: + http_path: The HTTP path from which to extract the warehouse ID + + Returns: + The extracted warehouse ID + + Raises: + ValueError: If the warehouse ID cannot be extracted from the path + """ + + warehouse_pattern = re.compile(r".*/warehouses/(.+)") + endpoint_pattern = re.compile(r".*/endpoints/(.+)") + + for pattern in [warehouse_pattern, endpoint_pattern]: + match = pattern.match(http_path) + if not match: + continue + warehouse_id = match.group(1) + logger.debug( + f"Extracted warehouse ID: {warehouse_id} from path: {http_path}" + ) + return warehouse_id + + # If no match found, raise error + error_message = ( + f"Could not extract warehouse ID from http_path: {http_path}. " + f"Expected format: /path/to/warehouses/{{warehouse_id}} or " + f"/path/to/endpoints/{{warehouse_id}}. " + f"Note: SEA only works for warehouses." + ) + logger.error(error_message) + raise ValueError(error_message) + + @property + def max_download_threads(self) -> int: + """Get the maximum number of download threads for cloud fetch operations.""" + return self._max_download_threads + + def open_session( + self, + session_configuration: Optional[Dict[str, str]], + catalog: Optional[str], + schema: Optional[str], + ) -> SessionId: + """ + Opens a new session with the Databricks SQL service using SEA. + + Args: + session_configuration: Optional dictionary of configuration parameters for the session. 
+ Only specific parameters are supported as documented at: + https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-parameters + catalog: Optional catalog name to use as the initial catalog for the session + schema: Optional schema name to use as the initial schema for the session + + Returns: + SessionId: A session identifier object that can be used for subsequent operations + + Raises: + Error: If the session configuration is invalid + OperationalError: If there's an error establishing the session + """ + + logger.debug( + "SeaDatabricksClient.open_session(session_configuration=%s, catalog=%s, schema=%s)", + session_configuration, + catalog, + schema, + ) + + session_configuration = _filter_session_configuration(session_configuration) + + request_data = CreateSessionRequest( + warehouse_id=self.warehouse_id, + session_confs=session_configuration, + catalog=catalog, + schema=schema, + ) + + response = self.http_client._make_request( + method="POST", path=self.SESSION_PATH, data=request_data.to_dict() + ) + + session_response = CreateSessionResponse.from_dict(response) + session_id = session_response.session_id + if not session_id: + raise ServerOperationError( + "Failed to create session: No session ID returned", + { + "operation-id": None, + "diagnostic-info": None, + }, + ) + + return SessionId.from_sea_session_id(session_id) + + def close_session(self, session_id: SessionId) -> None: + """ + Closes an existing session with the Databricks SQL service. 
+ + Args: + session_id: The session identifier returned by open_session() + + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error closing the session + """ + + logger.debug("SeaDatabricksClient.close_session(session_id=%s)", session_id) + + if session_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA session ID") + sea_session_id = session_id.to_sea_session_id() + + request_data = DeleteSessionRequest( + warehouse_id=self.warehouse_id, + session_id=sea_session_id, + ) + + self.http_client._make_request( + method="DELETE", + path=self.SESSION_PATH_WITH_ID.format(sea_session_id), + data=request_data.to_dict(), + ) + + @staticmethod + def get_default_session_configuration_value(name: str) -> Optional[str]: + """ + Get the default value for a session configuration parameter. + + Args: + name: The name of the session configuration parameter + + Returns: + The default value if the parameter is supported, None otherwise + """ + return ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.get(name.upper()) + + @staticmethod + def get_allowed_session_configurations() -> List[str]: + """ + Get the list of allowed session configuration parameters. 
+ + Returns: + List of allowed session configuration parameter names + """ + return list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys()) + + # == Not Implemented Operations == + # These methods will be implemented in future iterations + + def execute_command( + self, + operation: str, + session_id: SessionId, + max_rows: int, + max_bytes: int, + lz4_compression: bool, + cursor: "Cursor", + use_cloud_fetch: bool, + parameters: List[ttypes.TSparkParameter], + async_op: bool, + enforce_embedded_schema_correctness: bool, + ): + """Not implemented yet.""" + raise NotImplementedError( + "execute_command is not yet implemented for SEA backend" + ) + + def cancel_command(self, command_id: CommandId) -> None: + """Not implemented yet.""" + raise NotImplementedError( + "cancel_command is not yet implemented for SEA backend" + ) + + def close_command(self, command_id: CommandId) -> None: + """Not implemented yet.""" + raise NotImplementedError( + "close_command is not yet implemented for SEA backend" + ) + + def get_query_state(self, command_id: CommandId) -> CommandState: + """Not implemented yet.""" + raise NotImplementedError( + "get_query_state is not yet implemented for SEA backend" + ) + + def get_execution_result( + self, + command_id: CommandId, + cursor: "Cursor", + ): + """Not implemented yet.""" + raise NotImplementedError( + "get_execution_result is not yet implemented for SEA backend" + ) + + # == Metadata Operations == + + def get_catalogs( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + ): + """Not implemented yet.""" + raise NotImplementedError("get_catalogs is not yet implemented for SEA backend") + + def get_schemas( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + ): + """Not implemented yet.""" + raise NotImplementedError("get_schemas is not yet implemented for SEA backend") + + def get_tables( + 
self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + table_name: Optional[str] = None, + table_types: Optional[List[str]] = None, + ): + """Not implemented yet.""" + raise NotImplementedError("get_tables is not yet implemented for SEA backend") + + def get_columns( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + table_name: Optional[str] = None, + column_name: Optional[str] = None, + ): + """Not implemented yet.""" + raise NotImplementedError("get_columns is not yet implemented for SEA backend") diff --git a/src/databricks/sql/backend/sea/models/__init__.py b/src/databricks/sql/backend/sea/models/__init__.py new file mode 100644 index 000000000..c9310d367 --- /dev/null +++ b/src/databricks/sql/backend/sea/models/__init__.py @@ -0,0 +1,22 @@ +""" +Models for the SEA (Statement Execution API) backend. + +This package contains data models for SEA API requests and responses. 
+""" + +from databricks.sql.backend.sea.models.requests import ( + CreateSessionRequest, + DeleteSessionRequest, +) + +from databricks.sql.backend.sea.models.responses import ( + CreateSessionResponse, +) + +__all__ = [ + # Request models + "CreateSessionRequest", + "DeleteSessionRequest", + # Response models + "CreateSessionResponse", +] diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py new file mode 100644 index 000000000..7966cb502 --- /dev/null +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -0,0 +1,39 @@ +from typing import Dict, Any, Optional +from dataclasses import dataclass + + +@dataclass +class CreateSessionRequest: + """Request to create a new session.""" + + warehouse_id: str + session_confs: Optional[Dict[str, str]] = None + catalog: Optional[str] = None + schema: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + result: Dict[str, Any] = {"warehouse_id": self.warehouse_id} + + if self.session_confs: + result["session_confs"] = self.session_confs + + if self.catalog: + result["catalog"] = self.catalog + + if self.schema: + result["schema"] = self.schema + + return result + + +@dataclass +class DeleteSessionRequest: + """Request to delete a session.""" + + warehouse_id: str + session_id: str + + def to_dict(self) -> Dict[str, str]: + """Convert the request to a dictionary for JSON serialization.""" + return {"warehouse_id": self.warehouse_id, "session_id": self.session_id} diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py new file mode 100644 index 000000000..1bb54590f --- /dev/null +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -0,0 +1,14 @@ +from typing import Dict, Any +from dataclasses import dataclass + + +@dataclass +class CreateSessionResponse: + """Response from creating a new session.""" + + session_id: str + + 
@classmethod + def from_dict(cls, data: Dict[str, Any]) -> "CreateSessionResponse": + """Create a CreateSessionResponse from a dictionary.""" + return cls(session_id=data.get("session_id", "")) diff --git a/src/databricks/sql/backend/sea/utils/constants.py b/src/databricks/sql/backend/sea/utils/constants.py new file mode 100644 index 000000000..9160ef6ad --- /dev/null +++ b/src/databricks/sql/backend/sea/utils/constants.py @@ -0,0 +1,17 @@ +""" +Constants for the Statement Execution API (SEA) backend. +""" + +from typing import Dict + +# from https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-parameters +ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP: Dict[str, str] = { + "ANSI_MODE": "true", + "ENABLE_PHOTON": "true", + "LEGACY_TIME_PARSER_POLICY": "Exception", + "MAX_FILE_PARTITION_BYTES": "128m", + "READ_ONLY_EXTERNAL_METASTORE": "false", + "STATEMENT_TIMEOUT": "0", + "TIMEZONE": "UTC", + "USE_CACHED_RESULT": "true", +} diff --git a/src/databricks/sql/backend/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py similarity index 100% rename from src/databricks/sql/backend/utils/http_client.py rename to src/databricks/sql/backend/sea/utils/http_client.py diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 6d69b5487..7c33d9b2d 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, Tuple, List, Optional, Any +from typing import Dict, Tuple, List, Optional, Any, Type from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions @@ -8,8 +8,9 @@ from databricks.sql import __version__ from databricks.sql import USER_AGENT_NAME from databricks.sql.backend.thrift_backend import ThriftDatabricksClient +from databricks.sql.backend.sea.backend import SeaDatabricksClient from databricks.sql.backend.databricks_client import DatabricksClient -from databricks.sql.backend.types import SessionId, 
BackendType +from databricks.sql.backend.types import SessionId logger = logging.getLogger(__name__) @@ -61,6 +62,7 @@ def __init__( useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__) base_headers = [("User-Agent", useragent_header)] + all_headers = (http_headers or []) + base_headers self._ssl_options = SSLOptions( # Double negation is generally a bad thing, but we have to keep backward compatibility @@ -74,19 +76,49 @@ def __init__( tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), ) - self.backend: DatabricksClient = ThriftDatabricksClient( - self.host, - self.port, + self.backend = self._create_backend( + server_hostname, http_path, - (http_headers or []) + base_headers, + all_headers, auth_provider, - ssl_options=self._ssl_options, - _use_arrow_native_complex_types=_use_arrow_native_complex_types, - **kwargs, + _use_arrow_native_complex_types, + kwargs, ) self.protocol_version = None + def _create_backend( + self, + server_hostname: str, + http_path: str, + all_headers: List[Tuple[str, str]], + auth_provider, + _use_arrow_native_complex_types: Optional[bool], + kwargs: dict, + ) -> DatabricksClient: + """Create and return the appropriate backend client.""" + use_sea = kwargs.get("use_sea", False) + + databricks_client_class: Type[DatabricksClient] + if use_sea: + logger.debug("Creating SEA backend client") + databricks_client_class = SeaDatabricksClient + else: + logger.debug("Creating Thrift backend client") + databricks_client_class = ThriftDatabricksClient + + common_args = { + "server_hostname": server_hostname, + "port": self.port, + "http_path": http_path, + "http_headers": all_headers, + "auth_provider": auth_provider, + "ssl_options": self._ssl_options, + "_use_arrow_native_complex_types": _use_arrow_native_complex_types, + **kwargs, + } + return databricks_client_class(**common_args) + def open(self): self._session_id = self.backend.open_session( session_configuration=self.session_configuration, diff --git 
a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py new file mode 100644 index 000000000..bc2688a68 --- /dev/null +++ b/tests/unit/test_sea_backend.py @@ -0,0 +1,283 @@ +import pytest +from unittest.mock import patch, MagicMock + +from databricks.sql.backend.sea.backend import SeaDatabricksClient +from databricks.sql.backend.types import SessionId, BackendType +from databricks.sql.types import SSLOptions +from databricks.sql.auth.authenticators import AuthProvider +from databricks.sql.exc import Error + + +class TestSeaBackend: + """Test suite for the SeaDatabricksClient class.""" + + @pytest.fixture + def mock_http_client(self): + """Create a mock HTTP client.""" + with patch( + "databricks.sql.backend.sea.backend.SeaHttpClient" + ) as mock_client_class: + mock_client = mock_client_class.return_value + yield mock_client + + @pytest.fixture + def sea_client(self, mock_http_client): + """Create a SeaDatabricksClient instance with mocked dependencies.""" + server_hostname = "test-server.databricks.com" + port = 443 + http_path = "/sql/warehouses/abc123" + http_headers = [("header1", "value1"), ("header2", "value2")] + auth_provider = AuthProvider() + ssl_options = SSLOptions() + + client = SeaDatabricksClient( + server_hostname=server_hostname, + port=port, + http_path=http_path, + http_headers=http_headers, + auth_provider=auth_provider, + ssl_options=ssl_options, + ) + + return client + + def test_init_extracts_warehouse_id(self, mock_http_client): + """Test that the constructor properly extracts the warehouse ID from the HTTP path.""" + # Test with warehouses format + client1 = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/warehouses/abc123", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) + assert client1.warehouse_id == "abc123" + + # Test with endpoints format + client2 = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + 
http_path="/sql/endpoints/def456", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) + assert client2.warehouse_id == "def456" + + def test_init_raises_error_for_invalid_http_path(self, mock_http_client): + """Test that the constructor raises an error for invalid HTTP paths.""" + with pytest.raises(ValueError) as excinfo: + SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/invalid/path", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) + assert "Could not extract warehouse ID" in str(excinfo.value) + + def test_open_session_basic(self, sea_client, mock_http_client): + """Test the open_session method with minimal parameters.""" + # Set up mock response + mock_http_client._make_request.return_value = {"session_id": "test-session-123"} + + # Call the method + session_id = sea_client.open_session(None, None, None) + + # Verify the result + assert isinstance(session_id, SessionId) + assert session_id.backend_type == BackendType.SEA + assert session_id.guid == "test-session-123" + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once_with( + method="POST", path=sea_client.SESSION_PATH, data={"warehouse_id": "abc123"} + ) + + def test_open_session_with_all_parameters(self, sea_client, mock_http_client): + """Test the open_session method with all parameters.""" + # Set up mock response + mock_http_client._make_request.return_value = {"session_id": "test-session-456"} + + # Call the method with all parameters, including both supported and unsupported configurations + session_config = { + "ANSI_MODE": "FALSE", # Supported parameter + "STATEMENT_TIMEOUT": "3600", # Supported parameter + "unsupported_param": "value", # Unsupported parameter + } + catalog = "test_catalog" + schema = "test_schema" + + session_id = sea_client.open_session(session_config, catalog, schema) + + # Verify the result + assert isinstance(session_id, SessionId) + 
assert session_id.backend_type == BackendType.SEA + assert session_id.guid == "test-session-456" + + # Verify the HTTP request - only supported parameters should be included + # and keys should be in lowercase + expected_data = { + "warehouse_id": "abc123", + "session_confs": { + "ansi_mode": "FALSE", + "statement_timeout": "3600", + }, + "catalog": catalog, + "schema": schema, + } + mock_http_client._make_request.assert_called_once_with( + method="POST", path=sea_client.SESSION_PATH, data=expected_data + ) + + def test_open_session_error_handling(self, sea_client, mock_http_client): + """Test error handling in the open_session method.""" + # Set up mock response without session_id + mock_http_client._make_request.return_value = {} + + # Call the method and expect an error + with pytest.raises(Error) as excinfo: + sea_client.open_session(None, None, None) + + assert "Failed to create session" in str(excinfo.value) + + def test_close_session_valid_id(self, sea_client, mock_http_client): + """Test closing a session with a valid session ID.""" + # Create a valid SEA session ID + session_id = SessionId.from_sea_session_id("test-session-789") + + # Set up mock response + mock_http_client._make_request.return_value = {} + + # Call the method + sea_client.close_session(session_id) + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once_with( + method="DELETE", + path=sea_client.SESSION_PATH_WITH_ID.format("test-session-789"), + data={"session_id": "test-session-789", "warehouse_id": "abc123"}, + ) + + def test_close_session_invalid_id_type(self, sea_client): + """Test closing a session with an invalid session ID type.""" + # Create a Thrift session ID (not SEA) + mock_thrift_handle = MagicMock() + mock_thrift_handle.sessionId.guid = b"guid" + mock_thrift_handle.sessionId.secret = b"secret" + session_id = SessionId.from_thrift_handle(mock_thrift_handle) + + # Call the method and expect an error + with pytest.raises(ValueError) as excinfo: + 
sea_client.close_session(session_id) + + assert "Not a valid SEA session ID" in str(excinfo.value) + + def test_session_configuration_helpers(self): + """Test the session configuration helper methods.""" + # Test getting default value for a supported parameter + default_value = SeaDatabricksClient.get_default_session_configuration_value( + "ANSI_MODE" + ) + assert default_value == "true" + + # Test getting default value for an unsupported parameter + default_value = SeaDatabricksClient.get_default_session_configuration_value( + "UNSUPPORTED_PARAM" + ) + assert default_value is None + + # Test getting the list of allowed configurations + allowed_configs = SeaDatabricksClient.get_allowed_session_configurations() + + expected_keys = { + "ANSI_MODE", + "ENABLE_PHOTON", + "LEGACY_TIME_PARSER_POLICY", + "MAX_FILE_PARTITION_BYTES", + "READ_ONLY_EXTERNAL_METASTORE", + "STATEMENT_TIMEOUT", + "TIMEZONE", + "USE_CACHED_RESULT", + } + assert set(allowed_configs) == expected_keys + + def test_unimplemented_methods(self, sea_client): + """Test that unimplemented methods raise NotImplementedError.""" + # Create dummy parameters for testing + session_id = SessionId.from_sea_session_id("test-session") + command_id = MagicMock() + cursor = MagicMock() + + # Test execute_command + with pytest.raises(NotImplementedError) as excinfo: + sea_client.execute_command( + operation="SELECT 1", + session_id=session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert "execute_command is not yet implemented" in str(excinfo.value) + + # Test cancel_command + with pytest.raises(NotImplementedError) as excinfo: + sea_client.cancel_command(command_id) + assert "cancel_command is not yet implemented" in str(excinfo.value) + + # Test close_command + with pytest.raises(NotImplementedError) as excinfo: + sea_client.close_command(command_id) + assert 
"close_command is not yet implemented" in str(excinfo.value) + + # Test get_query_state + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_query_state(command_id) + assert "get_query_state is not yet implemented" in str(excinfo.value) + + # Test get_execution_result + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_execution_result(command_id, cursor) + assert "get_execution_result is not yet implemented" in str(excinfo.value) + + # Test metadata operations + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_catalogs(session_id, 100, 1000, cursor) + assert "get_catalogs is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_schemas(session_id, 100, 1000, cursor) + assert "get_schemas is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_tables(session_id, 100, 1000, cursor) + assert "get_tables is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_columns(session_id, 100, 1000, cursor) + assert "get_columns is not yet implemented" in str(excinfo.value) + + def test_max_download_threads_property(self, sea_client): + """Test the max_download_threads property.""" + assert sea_client.max_download_threads == 10 + + # Create a client with a custom value + custom_client = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/warehouses/abc123", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + max_download_threads=20, + ) + + # Verify the custom value is returned + assert custom_client.max_download_threads == 20 diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 858119f92..92de8d8fd 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -1,12 +1,7 @@ import unittest -from unittest.mock import patch, 
MagicMock, Mock, PropertyMock +from unittest.mock import patch, Mock import gc -from databricks.sql.thrift_api.TCLIService.ttypes import ( - TOpenSessionResp, - TSessionHandle, - THandleIdentifier, -) from databricks.sql.backend.types import SessionId, BackendType import databricks.sql @@ -62,9 +57,9 @@ def test_auth_args(self, mock_client_class): for args in connection_args: connection = databricks.sql.connect(**args) - host, port, http_path, *_ = mock_client_class.call_args[0] - self.assertEqual(args["server_hostname"], host) - self.assertEqual(args["http_path"], http_path) + call_kwargs = mock_client_class.call_args[1] + self.assertEqual(args["server_hostname"], call_kwargs["server_hostname"]) + self.assertEqual(args["http_path"], call_kwargs["http_path"]) connection.close() @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) @@ -72,8 +67,8 @@ def test_http_header_passthrough(self, mock_client_class): http_headers = [("foo", "bar")] databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, http_headers=http_headers) - call_args = mock_client_class.call_args[0][3] - self.assertIn(("foo", "bar"), call_args) + call_kwargs = mock_client_class.call_args[1] + self.assertIn(("foo", "bar"), call_kwargs["http_headers"]) @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_tls_arg_passthrough(self, mock_client_class): @@ -95,7 +90,8 @@ def test_tls_arg_passthrough(self, mock_client_class): def test_useragent_header(self, mock_client_class): databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - http_headers = mock_client_class.call_args[0][3] + call_kwargs = mock_client_class.call_args[1] + http_headers = call_kwargs["http_headers"] user_agent_header = ( "User-Agent", "{}/{}".format(databricks.sql.USER_AGENT_NAME, databricks.sql.__version__), @@ -109,7 +105,8 @@ def test_useragent_header(self, mock_client_class): databricks.sql.USER_AGENT_NAME, databricks.sql.__version__, "foobar" ), ) - http_headers = mock_client_class.call_args[0][3] + call_kwargs = 
mock_client_class.call_args[1] + http_headers = call_kwargs["http_headers"] self.assertIn(user_agent_header_with_entry, http_headers) @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)