Introduce SeaDatabricksClient (Session Implementation) #582
@@ -0,0 +1,22 @@
"""
Models for the SEA (Statement Execution API) backend.

This package contains data models for SEA API requests and responses.
"""

from databricks.sql.backend.models.requests import (
    CreateSessionRequest,
    DeleteSessionRequest,
)

from databricks.sql.backend.models.responses import (
    CreateSessionResponse,
)

__all__ = [
    # Request models
    "CreateSessionRequest",
    "DeleteSessionRequest",
    # Response models
    "CreateSessionResponse",
]
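
The package-level re-exports let callers pull the session models straight from `databricks.sql.backend.models` instead of the individual submodules. A minimal sketch of that import style; the warehouse ID and catalog values below are placeholders, not values from this PR:

```python
from databricks.sql.backend.models import (
    CreateSessionRequest,
    CreateSessionResponse,
    DeleteSessionRequest,
)

# Build a session-creation request for a made-up warehouse ID.
request = CreateSessionRequest(warehouse_id="abc123", catalog="main")
print(request.to_dict())  # {'warehouse_id': 'abc123', 'catalog': 'main'}
```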
@@ -0,0 +1,39 @@
from typing import Dict, Any, Optional
from dataclasses import dataclass


@dataclass
class CreateSessionRequest:
    """Request to create a new session."""

    warehouse_id: str
    session_confs: Optional[Dict[str, str]] = None
    catalog: Optional[str] = None
    schema: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert the request to a dictionary for JSON serialization."""
        result: Dict[str, Any] = {"warehouse_id": self.warehouse_id}

        if self.session_confs:
            result["session_confs"] = self.session_confs

        if self.catalog:
            result["catalog"] = self.catalog

        if self.schema:
            result["schema"] = self.schema

        return result


@dataclass
class DeleteSessionRequest:
    """Request to delete a session."""

    warehouse_id: str
    session_id: str

    def to_dict(self) -> Dict[str, str]:
        """Convert the request to a dictionary for JSON serialization."""
        return {"warehouse_id": self.warehouse_id, "session_id": self.session_id}
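
Because `to_dict()` only includes the optional fields when they are truthy, unset values never appear in the serialized request body. A small illustrative sketch of that behaviour, using made-up values (the `ANSI_MODE` config is just an example of a session parameter):

```python
from databricks.sql.backend.models.requests import CreateSessionRequest

full = CreateSessionRequest(
    warehouse_id="abc123",
    session_confs={"ANSI_MODE": "true"},
    catalog="main",
    schema="default",
)
print(full.to_dict())
# {'warehouse_id': 'abc123', 'session_confs': {'ANSI_MODE': 'true'},
#  'catalog': 'main', 'schema': 'default'}

minimal = CreateSessionRequest(warehouse_id="abc123")
print(minimal.to_dict())  # {'warehouse_id': 'abc123'} - optional fields omitted
```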
@@ -0,0 +1,14 @@
from typing import Dict, Any
from dataclasses import dataclass


@dataclass
class CreateSessionResponse:
    """Response from creating a new session."""

    session_id: str

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "CreateSessionResponse":
        """Create a CreateSessionResponse from a dictionary."""
        return cls(session_id=data.get("session_id", ""))
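
As a rough sketch of the parsing behaviour: `from_dict` pulls `session_id` out of the raw JSON payload and falls back to an empty string when the key is missing, which is the condition `open_session` later checks to detect a failed session creation. The payloads below are invented examples:

```python
from databricks.sql.backend.models.responses import CreateSessionResponse

ok = CreateSessionResponse.from_dict({"session_id": "01234567-89ab-cdef"})
print(ok.session_id)  # 01234567-89ab-cdef

# Missing key -> empty string, which open_session treats as an error condition.
bad = CreateSessionResponse.from_dict({})
print(bad.session_id == "")  # True
```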
@@ -0,0 +1,305 @@
import logging
import re
from typing import Dict, Tuple, List, Optional, TYPE_CHECKING

if TYPE_CHECKING:
    from databricks.sql.client import Cursor

from databricks.sql.backend.databricks_client import DatabricksClient
from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType
from databricks.sql.exc import ServerOperationError
from databricks.sql.backend.utils.http_client import SeaHttpClient
from databricks.sql.thrift_api.TCLIService import ttypes
from databricks.sql.types import SSLOptions

from databricks.sql.backend.models import (
    CreateSessionRequest,
    DeleteSessionRequest,
    CreateSessionResponse,
)

logger = logging.getLogger(__name__)


class SeaDatabricksClient(DatabricksClient):
    """
    Statement Execution API (SEA) implementation of the DatabricksClient interface.
    """

    # SEA API paths
    BASE_PATH = "/api/2.0/sql/"
    SESSION_PATH = BASE_PATH + "sessions"
    SESSION_PATH_WITH_ID = SESSION_PATH + "/{}"
    STATEMENT_PATH = BASE_PATH + "statements"
    STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}"
    CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel"

    def __init__(
        self,
        server_hostname: str,
        port: int,
        http_path: str,
        http_headers: List[Tuple[str, str]],
        auth_provider,
        ssl_options: SSLOptions,
        **kwargs,
    ):
        """
        Initialize the SEA backend client.

        Args:
            server_hostname: Hostname of the Databricks server
            port: Port number for the connection
            http_path: HTTP path for the connection
            http_headers: List of HTTP headers to include in requests
            auth_provider: Authentication provider
            ssl_options: SSL configuration options
            **kwargs: Additional keyword arguments
        """

        logger.debug(
            "SeaDatabricksClient.__init__(server_hostname=%s, port=%s, http_path=%s)",
            server_hostname,
            port,
            http_path,
        )

        self._max_download_threads = kwargs.get("max_download_threads", 10)

        # Extract warehouse ID from http_path
        self.warehouse_id = self._extract_warehouse_id(http_path)

        # Initialize HTTP client
        self.http_client = SeaHttpClient(
            server_hostname=server_hostname,
            port=port,
            http_path=http_path,
            http_headers=http_headers,
            auth_provider=auth_provider,
            ssl_options=ssl_options,
            **kwargs,
        )

    def _extract_warehouse_id(self, http_path: str) -> str:
        """
        Extract the warehouse ID from the HTTP path.

        Args:
            http_path: The HTTP path from which to extract the warehouse ID

        Returns:
            The extracted warehouse ID

        Raises:
            ValueError: If the warehouse ID cannot be extracted from the path
        """

        warehouse_pattern = re.compile(r".*/warehouses/(.+)")
        endpoint_pattern = re.compile(r".*/endpoints/(.+)")
Review comment: Do we have endpoints now? I don't think so.
Reply: The JDBC driver supports it.

        for pattern in [warehouse_pattern, endpoint_pattern]:
            match = pattern.match(http_path)
            if not match:
                continue
            warehouse_id = match.group(1)
            logger.debug(
                f"Extracted warehouse ID: {warehouse_id} from path: {http_path}"
            )
            return warehouse_id

        # If no match found, raise error
        error_message = (
            f"Could not extract warehouse ID from http_path: {http_path}. "
            f"Expected format: /path/to/warehouses/{{warehouse_id}} or "
            f"/path/to/endpoints/{{warehouse_id}}"
        )
        logger.error(error_message)
        raise ValueError(error_message)

    @property
    def max_download_threads(self) -> int:
        """Get the maximum number of download threads for cloud fetch operations."""
        return self._max_download_threads

    def open_session(
        self,
        session_configuration: Optional[Dict[str, str]],
        catalog: Optional[str],
        schema: Optional[str],
    ) -> SessionId:
        """
        Opens a new session with the Databricks SQL service using SEA.

        Args:
            session_configuration: Optional dictionary of configuration parameters for the session
Review comment: Right now only a select set of Spark configs are allowed in DBSQL as session configs: https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-parameters. Please document this appropriately. Also, please test whether the server ignores any other arbitrary session config; if not, please include client-side checks/filtering.
Reply: Thanks, I added a link to the above doc in the doc-string. I tested passing some arbitrary (unsupported) parameters in the session config and did not get any warning or error from the server, so the server seems to ignore irrelevant params. To be clear, this means we need not include any client-side filtering, right?
            catalog: Optional catalog name to use as the initial catalog for the session
            schema: Optional schema name to use as the initial schema for the session

        Returns:
            SessionId: A session identifier object that can be used for subsequent operations

        Raises:
            Error: If the session configuration is invalid
            OperationalError: If there's an error establishing the session
        """

        logger.debug(
            "SeaDatabricksClient.open_session(session_configuration=%s, catalog=%s, schema=%s)",
            session_configuration,
            catalog,
            schema,
        )

        request_data = CreateSessionRequest(
            warehouse_id=self.warehouse_id,
            session_confs=session_configuration,
            catalog=catalog,
            schema=schema,
        )

        response = self.http_client._make_request(
            method="POST", path=self.SESSION_PATH, data=request_data.to_dict()
        )

        session_response = CreateSessionResponse.from_dict(response)
        session_id = session_response.session_id
        if not session_id:
            raise ServerOperationError(
                "Failed to create session: No session ID returned",
                {
                    "operation-id": None,
                    "diagnostic-info": None,
                },
            )

        return SessionId.from_sea_session_id(session_id)

    def close_session(self, session_id: SessionId) -> None:
        """
        Closes an existing session with the Databricks SQL service.

        Args:
            session_id: The session identifier returned by open_session()

        Raises:
            ValueError: If the session ID is invalid
            OperationalError: If there's an error closing the session
        """

        logger.debug("SeaDatabricksClient.close_session(session_id=%s)", session_id)

        if session_id.backend_type != BackendType.SEA:
            raise ValueError("Not a valid SEA session ID")
        sea_session_id = session_id.to_sea_session_id()

        request_data = DeleteSessionRequest(
            warehouse_id=self.warehouse_id,
            session_id=sea_session_id,
        )

        self.http_client._make_request(
            method="DELETE",
            path=self.SESSION_PATH_WITH_ID.format(sea_session_id),
            data=request_data.to_dict(),
        )

    # == Not Implemented Operations ==
    # These methods will be implemented in future iterations

    def execute_command(
        self,
        operation: str,
        session_id: SessionId,
        max_rows: int,
        max_bytes: int,
        lz4_compression: bool,
        cursor: "Cursor",
        use_cloud_fetch: bool,
        parameters: List[ttypes.TSparkParameter],
        async_op: bool,
        enforce_embedded_schema_correctness: bool,
    ):
        """Not implemented yet."""
        raise NotImplementedError(
            "execute_command is not yet implemented for SEA backend"
        )

    def cancel_command(self, command_id: CommandId) -> None:
        """Not implemented yet."""
        raise NotImplementedError(
            "cancel_command is not yet implemented for SEA backend"
        )

    def close_command(self, command_id: CommandId) -> None:
        """Not implemented yet."""
        raise NotImplementedError(
            "close_command is not yet implemented for SEA backend"
        )

    def get_query_state(self, command_id: CommandId) -> CommandState:
        """Not implemented yet."""
        raise NotImplementedError(
            "get_query_state is not yet implemented for SEA backend"
        )

    def get_execution_result(
        self,
        command_id: CommandId,
        cursor: "Cursor",
    ):
        """Not implemented yet."""
        raise NotImplementedError(
            "get_execution_result is not yet implemented for SEA backend"
        )

    # == Metadata Operations ==

    def get_catalogs(
        self,
        session_id: SessionId,
        max_rows: int,
        max_bytes: int,
        cursor: "Cursor",
    ):
        """Not implemented yet."""
        raise NotImplementedError("get_catalogs is not yet implemented for SEA backend")

    def get_schemas(
        self,
        session_id: SessionId,
        max_rows: int,
        max_bytes: int,
        cursor: "Cursor",
        catalog_name: Optional[str] = None,
        schema_name: Optional[str] = None,
    ):
        """Not implemented yet."""
        raise NotImplementedError("get_schemas is not yet implemented for SEA backend")

    def get_tables(
        self,
        session_id: SessionId,
        max_rows: int,
        max_bytes: int,
        cursor: "Cursor",
        catalog_name: Optional[str] = None,
        schema_name: Optional[str] = None,
        table_name: Optional[str] = None,
        table_types: Optional[List[str]] = None,
    ):
        """Not implemented yet."""
        raise NotImplementedError("get_tables is not yet implemented for SEA backend")

    def get_columns(
        self,
        session_id: SessionId,
        max_rows: int,
        max_bytes: int,
        cursor: "Cursor",
        catalog_name: Optional[str] = None,
        schema_name: Optional[str] = None,
        table_name: Optional[str] = None,
        column_name: Optional[str] = None,
    ):
        """Not implemented yet."""
        raise NotImplementedError("get_columns is not yet implemented for SEA backend")
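
Taken together, the session half of the client supports a simple open/close lifecycle: the warehouse ID is parsed from the HTTP path at construction time, `open_session` POSTs to the sessions endpoint and wraps the returned ID in a `SessionId`, and `close_session` issues the corresponding DELETE. A hedged usage sketch is below. The hostname, HTTP path, token, and session config are placeholders, the `AccessTokenAuthProvider` import assumes the connector's existing token-based auth provider, and the import path for `SeaDatabricksClient` itself is not shown in this PR, so it is left implicit:

```python
from databricks.sql.auth.authenticators import AccessTokenAuthProvider  # assumed connector auth provider
from databricks.sql.types import SSLOptions

# SeaDatabricksClient is the class added in this PR; import it from wherever
# the new module lands in the package.
client = SeaDatabricksClient(
    server_hostname="example.cloud.databricks.com",   # placeholder workspace host
    port=443,
    http_path="/sql/1.0/warehouses/abc123",           # matches the /warehouses/{id} pattern
    http_headers=[],
    auth_provider=AccessTokenAuthProvider("dapi-example-token"),
    ssl_options=SSLOptions(),
)

# open_session() returns a SessionId wrapping the SEA session UUID.
session_id = client.open_session(
    session_configuration={"ANSI_MODE": "true"},      # illustrative session config
    catalog="main",
    schema="default",
)

# Statement execution is not implemented yet in this PR; only session teardown follows.
client.close_session(session_id)
```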