Introduce SeaDatabricksClient (Session Implementation) #582


Merged
25 commits merged on Jun 9, 2025

Commits
f3e4a97
[squashed from prev branch] introduce sea client with session open an…
varun-edachali-dbx Jun 3, 2025
3df5752
remove accidental changes to workflows (merge artifacts)
varun-edachali-dbx Jun 3, 2025
9146a94
pass test_input to get_protocol_version instead of session_id to main…
varun-edachali-dbx Jun 3, 2025
9b39e37
formatting (black + line gaps after multi-line pydocs)
varun-edachali-dbx Jun 3, 2025
1ccbcd2
use factory for backend instantiation
varun-edachali-dbx Jun 3, 2025
3528523
fix type issues
varun-edachali-dbx Jun 3, 2025
b39e83b
remove redundant comments
varun-edachali-dbx Jun 3, 2025
ba36126
introduce models for requests and responses
varun-edachali-dbx Jun 3, 2025
059cd4d
remove http client and test script
varun-edachali-dbx Jun 4, 2025
6830327
Introduce Sea HTTP Client and test script (#583)
varun-edachali-dbx Jun 4, 2025
ab847da
CustomHttpClient -> SeaHttpClient
varun-edachali-dbx Jun 4, 2025
1c399d5
redundant comment in backend client
varun-edachali-dbx Jun 4, 2025
42c4581
regex for warehouse_id instead of .split, remove excess imports and b…
varun-edachali-dbx Jun 4, 2025
8bfca45
remove redundant attributes
varun-edachali-dbx Jun 4, 2025
5005b13
formatting (black)
varun-edachali-dbx Jun 4, 2025
8efa68c
[nit] reduce nested code
varun-edachali-dbx Jun 4, 2025
6e41ebf
line gap after multi-line pydoc
varun-edachali-dbx Jun 5, 2025
638e1df
Merge branch 'sea-migration' into sessions-sea
varun-edachali-dbx Jun 5, 2025
ed4931e
redundant imports
varun-edachali-dbx Jun 5, 2025
4ff64ed
move sea backend and models into separate sea/ dir
varun-edachali-dbx Jun 7, 2025
9aebea2
move http client into separate sea/ dir
varun-edachali-dbx Jun 7, 2025
a05f1fd
change commands to include ones in docs
varun-edachali-dbx Jun 7, 2025
46104e2
add link to sql-ref-parameters for session-confs
varun-edachali-dbx Jun 7, 2025
390c1e7
add client side filtering for session confs, add note on warehouses o…
varun-edachali-dbx Jun 8, 2025
86ee56f
test unimplemented methods and max_download_threads prop
varun-edachali-dbx Jun 8, 2025
364 changes: 364 additions & 0 deletions src/databricks/sql/backend/sea/backend.py
@@ -0,0 +1,364 @@
import logging
import re
from typing import Dict, Tuple, List, Optional, TYPE_CHECKING, Set

if TYPE_CHECKING:
    from databricks.sql.client import Cursor

from databricks.sql.backend.databricks_client import DatabricksClient
from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType
from databricks.sql.exc import ServerOperationError
from databricks.sql.backend.sea.utils.http_client import SeaHttpClient
from databricks.sql.backend.sea.utils.constants import (
    ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP,
)
from databricks.sql.thrift_api.TCLIService import ttypes
from databricks.sql.types import SSLOptions

from databricks.sql.backend.sea.models import (
    CreateSessionRequest,
    DeleteSessionRequest,
    CreateSessionResponse,
)

logger = logging.getLogger(__name__)


def _filter_session_configuration(
    session_configuration: Optional[Dict[str, str]]
) -> Optional[Dict[str, str]]:
    if not session_configuration:
        return None

    filtered_session_configuration = {}
    ignored_configs: Set[str] = set()

    for key, value in session_configuration.items():
        if key.upper() in ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP:
            filtered_session_configuration[key.lower()] = value
        else:
            ignored_configs.add(key)

    if ignored_configs:
        logger.warning(
            "Some session configurations were ignored because they are not supported: %s",
            ignored_configs,
        )
        logger.warning(
            "Supported session configurations are: %s",
            list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys()),
        )

    return filtered_session_configuration
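

# Illustrative behavior, assuming "ANSI_MODE" is among the allowed keys and
# "my_custom_conf" is not (both keys here are hypothetical examples):
#   _filter_session_configuration({"ANSI_MODE": "true", "my_custom_conf": "x"})
#   -> {"ansi_mode": "true"}, with "my_custom_conf" reported in the warning above.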


class SeaDatabricksClient(DatabricksClient):
    """
    Statement Execution API (SEA) implementation of the DatabricksClient interface.
    """

    # SEA API paths
    BASE_PATH = "/api/2.0/sql/"
    SESSION_PATH = BASE_PATH + "sessions"
    SESSION_PATH_WITH_ID = SESSION_PATH + "/{}"
    STATEMENT_PATH = BASE_PATH + "statements"
    STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}"
    CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel"

    def __init__(
        self,
        server_hostname: str,
        port: int,
        http_path: str,
        http_headers: List[Tuple[str, str]],
        auth_provider,
        ssl_options: SSLOptions,
        **kwargs,
    ):
        """
        Initialize the SEA backend client.

        Args:
            server_hostname: Hostname of the Databricks server
            port: Port number for the connection
            http_path: HTTP path for the connection
            http_headers: List of HTTP headers to include in requests
            auth_provider: Authentication provider
            ssl_options: SSL configuration options
            **kwargs: Additional keyword arguments
        """

        logger.debug(
            "SeaDatabricksClient.__init__(server_hostname=%s, port=%s, http_path=%s)",
            server_hostname,
            port,
            http_path,
        )

        self._max_download_threads = kwargs.get("max_download_threads", 10)

        # Extract warehouse ID from http_path
        self.warehouse_id = self._extract_warehouse_id(http_path)

        # Initialize HTTP client
        self.http_client = SeaHttpClient(
            server_hostname=server_hostname,
            port=port,
            http_path=http_path,
            http_headers=http_headers,
            auth_provider=auth_provider,
            ssl_options=ssl_options,
            **kwargs,
        )

    def _extract_warehouse_id(self, http_path: str) -> str:
        """
        Extract the warehouse ID from the HTTP path.

        Args:
            http_path: The HTTP path from which to extract the warehouse ID

        Returns:
            The extracted warehouse ID

        Raises:
            ValueError: If the warehouse ID cannot be extracted from the path
        """

        warehouse_pattern = re.compile(r".*/warehouses/(.+)")
        endpoint_pattern = re.compile(r".*/endpoints/(.+)")

        for pattern in [warehouse_pattern, endpoint_pattern]:
            match = pattern.match(http_path)
            if not match:
                continue
            warehouse_id = match.group(1)
            logger.debug(
                f"Extracted warehouse ID: {warehouse_id} from path: {http_path}"
            )
            return warehouse_id

        # If no match found, raise error
        error_message = (
            f"Could not extract warehouse ID from http_path: {http_path}. "
            f"Expected format: /path/to/warehouses/{{warehouse_id}} or "
            f"/path/to/endpoints/{{warehouse_id}}. "
            f"Note: SEA only works for warehouses."
        )
        logger.error(error_message)
        raise ValueError(error_message)
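
    # Illustrative examples with hypothetical paths:
    #   _extract_warehouse_id("/sql/1.0/warehouses/abc123") -> "abc123"
    #   _extract_warehouse_id("/sql/1.0/endpoints/abc123")  -> "abc123"
    #   _extract_warehouse_id("/sql/protocolv1/o/123/456")  -> raises ValueError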

    @property
    def max_download_threads(self) -> int:
        """Get the maximum number of download threads for cloud fetch operations."""
        return self._max_download_threads

    def open_session(
        self,
        session_configuration: Optional[Dict[str, str]],
        catalog: Optional[str],
        schema: Optional[str],
    ) -> SessionId:
        """
        Opens a new session with the Databricks SQL service using SEA.

        Args:
            session_configuration: Optional dictionary of configuration parameters for the session.
                Only specific parameters are supported as documented at:
                https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-parameters
            catalog: Optional catalog name to use as the initial catalog for the session
            schema: Optional schema name to use as the initial schema for the session

        Returns:
            SessionId: A session identifier object that can be used for subsequent operations

        Raises:
            Error: If the session configuration is invalid
            OperationalError: If there's an error establishing the session
        """

        logger.debug(
            "SeaDatabricksClient.open_session(session_configuration=%s, catalog=%s, schema=%s)",
            session_configuration,
            catalog,
            schema,
        )

        session_configuration = _filter_session_configuration(session_configuration)

        request_data = CreateSessionRequest(
            warehouse_id=self.warehouse_id,
            session_confs=session_configuration,
            catalog=catalog,
            schema=schema,
        )

        response = self.http_client._make_request(
            method="POST", path=self.SESSION_PATH, data=request_data.to_dict()
        )

        session_response = CreateSessionResponse.from_dict(response)
        session_id = session_response.session_id
        if not session_id:
            raise ServerOperationError(
                "Failed to create session: No session ID returned",
                {
                    "operation-id": None,
                    "diagnostic-info": None,
                },
            )

        return SessionId.from_sea_session_id(session_id)

    def close_session(self, session_id: SessionId) -> None:
        """
        Closes an existing session with the Databricks SQL service.

        Args:
            session_id: The session identifier returned by open_session()

        Raises:
            ValueError: If the session ID is invalid
            OperationalError: If there's an error closing the session
        """

        logger.debug("SeaDatabricksClient.close_session(session_id=%s)", session_id)

        if session_id.backend_type != BackendType.SEA:
            raise ValueError("Not a valid SEA session ID")
        sea_session_id = session_id.to_sea_session_id()

        request_data = DeleteSessionRequest(
            warehouse_id=self.warehouse_id,
            session_id=sea_session_id,
        )

        self.http_client._make_request(
            method="DELETE",
            path=self.SESSION_PATH_WITH_ID.format(sea_session_id),
            data=request_data.to_dict(),
        )

    @staticmethod
    def get_default_session_configuration_value(name: str) -> Optional[str]:
        """
        Get the default value for a session configuration parameter.

        Args:
            name: The name of the session configuration parameter

        Returns:
            The default value if the parameter is supported, None otherwise
        """
        return ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.get(name.upper())

    @staticmethod
    def get_allowed_session_configurations() -> List[str]:
        """
        Get the list of allowed session configuration parameters.

        Returns:
            List of allowed session configuration parameter names
        """
        return list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys())

    # == Not Implemented Operations ==
    # These methods will be implemented in future iterations

    def execute_command(
        self,
        operation: str,
        session_id: SessionId,
        max_rows: int,
        max_bytes: int,
        lz4_compression: bool,
        cursor: "Cursor",
        use_cloud_fetch: bool,
        parameters: List[ttypes.TSparkParameter],
        async_op: bool,
        enforce_embedded_schema_correctness: bool,
    ):
        """Not implemented yet."""
        raise NotImplementedError(
            "execute_command is not yet implemented for SEA backend"
        )

    def cancel_command(self, command_id: CommandId) -> None:
        """Not implemented yet."""
        raise NotImplementedError(
            "cancel_command is not yet implemented for SEA backend"
        )

    def close_command(self, command_id: CommandId) -> None:
        """Not implemented yet."""
        raise NotImplementedError(
            "close_command is not yet implemented for SEA backend"
        )

    def get_query_state(self, command_id: CommandId) -> CommandState:
        """Not implemented yet."""
        raise NotImplementedError(
            "get_query_state is not yet implemented for SEA backend"
        )

    def get_execution_result(
        self,
        command_id: CommandId,
        cursor: "Cursor",
    ):
        """Not implemented yet."""
        raise NotImplementedError(
            "get_execution_result is not yet implemented for SEA backend"
        )

    # == Metadata Operations ==

    def get_catalogs(
        self,
        session_id: SessionId,
        max_rows: int,
        max_bytes: int,
        cursor: "Cursor",
    ):
        """Not implemented yet."""
        raise NotImplementedError("get_catalogs is not yet implemented for SEA backend")

    def get_schemas(
        self,
        session_id: SessionId,
        max_rows: int,
        max_bytes: int,
        cursor: "Cursor",
        catalog_name: Optional[str] = None,
        schema_name: Optional[str] = None,
    ):
        """Not implemented yet."""
        raise NotImplementedError("get_schemas is not yet implemented for SEA backend")

    def get_tables(
        self,
        session_id: SessionId,
        max_rows: int,
        max_bytes: int,
        cursor: "Cursor",
        catalog_name: Optional[str] = None,
        schema_name: Optional[str] = None,
        table_name: Optional[str] = None,
        table_types: Optional[List[str]] = None,
    ):
        """Not implemented yet."""
        raise NotImplementedError("get_tables is not yet implemented for SEA backend")

    def get_columns(
        self,
        session_id: SessionId,
        max_rows: int,
        max_bytes: int,
        cursor: "Cursor",
        catalog_name: Optional[str] = None,
        schema_name: Optional[str] = None,
        table_name: Optional[str] = None,
        column_name: Optional[str] = None,
    ):
        """Not implemented yet."""
        raise NotImplementedError("get_columns is not yet implemented for SEA backend")
22 changes: 22 additions & 0 deletions src/databricks/sql/backend/sea/models/__init__.py
@@ -0,0 +1,22 @@
"""
Models for the SEA (Statement Execution API) backend.

This package contains data models for SEA API requests and responses.
"""

from databricks.sql.backend.sea.models.requests import (
CreateSessionRequest,
DeleteSessionRequest,
)

from databricks.sql.backend.sea.models.responses import (
CreateSessionResponse,
)

__all__ = [
# Request models
"CreateSessionRequest",
"DeleteSessionRequest",
# Response models
"CreateSessionResponse",
]
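
The concrete request and response models (requests.py, responses.py) are not included in this excerpt. As a rough sketch of the shape backend.py relies on, with a deliberately hypothetical class name so it is not read as the PR's actual definition:

from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class CreateSessionRequestSketch:
    """Hypothetical stand-in mirroring how backend.py constructs and serializes the request."""

    warehouse_id: str
    session_confs: Optional[Dict[str, str]] = None
    catalog: Optional[str] = None
    schema: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        # Only include fields that were actually set, keeping the POST body minimal.
        body: Dict[str, Any] = {"warehouse_id": self.warehouse_id}
        if self.session_confs:
            body["session_confs"] = self.session_confs
        if self.catalog:
            body["catalog"] = self.catalog
        if self.schema:
            body["schema"] = self.schema
        return body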