Skip to content

Commit 0887bc1

Browse files
Introduce SeaDatabricksClient (Session Implementation) (#582)
* [squashed from prev branch] introduce sea client with session open and close functionality Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * remove accidental changes to workflows (merge artifacts) Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * pass test_input to get_protocol_version instead of session_id to maintain previous API Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * formatting (black + line gaps after multi-line pydocs) Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * use factory for backend instantiation Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * fix type issues Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * remove redundant comments Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * introduce models for requests and responses Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * remove http client and test script to prevent diff from showing up post http-client merge Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * Introduce Sea HTTP Client and test script (#583) * introduce http client (temp) and sea test file Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * reduce verbosity Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * redundant comment Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * reduce redundancy, params and data separate Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * rename client Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * fix type issues Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * reduce repetition in request calls Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * remove un-necessary elifs Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * add newline at EOF Signed-off-by: varun-edachali-dbx 
<varun.edachali@databricks.com> --------- Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * CustomHttpClient -> SeaHttpClient Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * redundant comment in backend client Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * regex for warehouse_id instead of .split, remove excess imports and behaviour Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * remove redundant attributes Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * formatting (black) Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * [nit] reduce nested code Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * line gap after multi-line pydoc Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * redundant imports Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * move sea backend and models into separate sea/ dir Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * move http client into separate sea/ dir Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * change commands to include ones in docs Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * add link to sql-ref-parameters for session-confs Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * add client side filtering for session confs, add note on warehouses over endpoints Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * test unimplemented methods and max_download_threads prop Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> --------- Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 9625229 commit 0887bc1

File tree

9 files changed

+790
-22
lines changed

9 files changed

+790
-22
lines changed
Lines changed: 364 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,364 @@
1+
import logging
2+
import re
3+
from typing import Dict, Tuple, List, Optional, TYPE_CHECKING, Set
4+
5+
if TYPE_CHECKING:
6+
from databricks.sql.client import Cursor
7+
8+
from databricks.sql.backend.databricks_client import DatabricksClient
9+
from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType
10+
from databricks.sql.exc import ServerOperationError
11+
from databricks.sql.backend.sea.utils.http_client import SeaHttpClient
12+
from databricks.sql.backend.sea.utils.constants import (
13+
ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP,
14+
)
15+
from databricks.sql.thrift_api.TCLIService import ttypes
16+
from databricks.sql.types import SSLOptions
17+
18+
from databricks.sql.backend.sea.models import (
19+
CreateSessionRequest,
20+
DeleteSessionRequest,
21+
CreateSessionResponse,
22+
)
23+
24+
logger = logging.getLogger(__name__)
25+
26+
27+
def _filter_session_configuration(
    session_configuration: Optional[Dict[str, str]]
) -> Optional[Dict[str, str]]:
    """
    Keep only the session configuration entries that SEA supports.

    Keys are matched case-insensitively against the allow-list; accepted
    entries are returned with lower-cased keys. Unsupported keys are dropped
    and reported through warnings. Returns None when no configuration is given.
    """
    if not session_configuration:
        return None

    accepted: Dict[str, str] = {}
    rejected: Set[str] = set()

    for name, value in session_configuration.items():
        # The allow-list is keyed by upper-case names; SEA expects lower-case.
        if name.upper() in ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP:
            accepted[name.lower()] = value
        else:
            rejected.add(name)

    if rejected:
        logger.warning(
            "Some session configurations were ignored because they are not supported: %s",
            rejected,
        )
        logger.warning(
            "Supported session configurations are: %s",
            list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys()),
        )

    return accepted
53+
54+
55+
class SeaDatabricksClient(DatabricksClient):
    """
    Statement Execution API (SEA) implementation of the DatabricksClient interface.

    Currently implements session management (open_session / close_session);
    statement execution and metadata operations raise NotImplementedError and
    will be added in future iterations.
    """

    # SEA API paths
    BASE_PATH = "/api/2.0/sql/"
    SESSION_PATH = BASE_PATH + "sessions"
    SESSION_PATH_WITH_ID = SESSION_PATH + "/{}"
    STATEMENT_PATH = BASE_PATH + "statements"
    STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}"
    CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel"

    def __init__(
        self,
        server_hostname: str,
        port: int,
        http_path: str,
        http_headers: List[Tuple[str, str]],
        auth_provider,
        ssl_options: SSLOptions,
        **kwargs,
    ):
        """
        Initialize the SEA backend client.

        Args:
            server_hostname: Hostname of the Databricks server
            port: Port number for the connection
            http_path: HTTP path for the connection
            http_headers: List of HTTP headers to include in requests
            auth_provider: Authentication provider
            ssl_options: SSL configuration options
            **kwargs: Additional keyword arguments (e.g. max_download_threads)

        Raises:
            ValueError: If the warehouse ID cannot be extracted from http_path
        """

        logger.debug(
            "SeaDatabricksClient.__init__(server_hostname=%s, port=%s, http_path=%s)",
            server_hostname,
            port,
            http_path,
        )

        self._max_download_threads = kwargs.get("max_download_threads", 10)

        # Extract warehouse ID from http_path
        self.warehouse_id = self._extract_warehouse_id(http_path)

        # Initialize HTTP client
        self.http_client = SeaHttpClient(
            server_hostname=server_hostname,
            port=port,
            http_path=http_path,
            http_headers=http_headers,
            auth_provider=auth_provider,
            ssl_options=ssl_options,
            **kwargs,
        )

    def _extract_warehouse_id(self, http_path: str) -> str:
        """
        Extract the warehouse ID from the HTTP path.

        Args:
            http_path: The HTTP path from which to extract the warehouse ID

        Returns:
            The extracted warehouse ID

        Raises:
            ValueError: If the warehouse ID cannot be extracted from the path
        """

        warehouse_pattern = re.compile(r".*/warehouses/(.+)")
        endpoint_pattern = re.compile(r".*/endpoints/(.+)")

        for pattern in [warehouse_pattern, endpoint_pattern]:
            match = pattern.match(http_path)
            if not match:
                continue
            warehouse_id = match.group(1)
            # Lazy %-style args for consistency with the module's other logger
            # calls (avoids f-string formatting when debug logging is off).
            logger.debug(
                "Extracted warehouse ID: %s from path: %s", warehouse_id, http_path
            )
            return warehouse_id

        # If no match found, raise error.
        # Note the trailing space after "{warehouse_id}}. " — without it the
        # concatenated fragments ran together as "….Note: SEA only works…".
        error_message = (
            f"Could not extract warehouse ID from http_path: {http_path}. "
            f"Expected format: /path/to/warehouses/{{warehouse_id}} or "
            f"/path/to/endpoints/{{warehouse_id}}. "
            f"Note: SEA only works for warehouses."
        )
        logger.error(error_message)
        raise ValueError(error_message)

    @property
    def max_download_threads(self) -> int:
        """Get the maximum number of download threads for cloud fetch operations."""
        return self._max_download_threads

    def open_session(
        self,
        session_configuration: Optional[Dict[str, str]],
        catalog: Optional[str],
        schema: Optional[str],
    ) -> SessionId:
        """
        Opens a new session with the Databricks SQL service using SEA.

        Args:
            session_configuration: Optional dictionary of configuration parameters for the session.
                Only specific parameters are supported as documented at:
                https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-parameters
            catalog: Optional catalog name to use as the initial catalog for the session
            schema: Optional schema name to use as the initial schema for the session

        Returns:
            SessionId: A session identifier object that can be used for subsequent operations

        Raises:
            Error: If the session configuration is invalid
            OperationalError: If there's an error establishing the session
        """

        logger.debug(
            "SeaDatabricksClient.open_session(session_configuration=%s, catalog=%s, schema=%s)",
            session_configuration,
            catalog,
            schema,
        )

        # Drop unsupported configuration keys client-side (warns about ignored ones).
        session_configuration = _filter_session_configuration(session_configuration)

        request_data = CreateSessionRequest(
            warehouse_id=self.warehouse_id,
            session_confs=session_configuration,
            catalog=catalog,
            schema=schema,
        )

        response = self.http_client._make_request(
            method="POST", path=self.SESSION_PATH, data=request_data.to_dict()
        )

        session_response = CreateSessionResponse.from_dict(response)
        session_id = session_response.session_id
        if not session_id:
            raise ServerOperationError(
                "Failed to create session: No session ID returned",
                {
                    "operation-id": None,
                    "diagnostic-info": None,
                },
            )

        return SessionId.from_sea_session_id(session_id)

    def close_session(self, session_id: SessionId) -> None:
        """
        Closes an existing session with the Databricks SQL service.

        Args:
            session_id: The session identifier returned by open_session()

        Raises:
            ValueError: If the session ID is invalid
            OperationalError: If there's an error closing the session
        """

        logger.debug("SeaDatabricksClient.close_session(session_id=%s)", session_id)

        if session_id.backend_type != BackendType.SEA:
            raise ValueError("Not a valid SEA session ID")
        sea_session_id = session_id.to_sea_session_id()

        request_data = DeleteSessionRequest(
            warehouse_id=self.warehouse_id,
            session_id=sea_session_id,
        )

        self.http_client._make_request(
            method="DELETE",
            path=self.SESSION_PATH_WITH_ID.format(sea_session_id),
            data=request_data.to_dict(),
        )

    @staticmethod
    def get_default_session_configuration_value(name: str) -> Optional[str]:
        """
        Get the default value for a session configuration parameter.

        Args:
            name: The name of the session configuration parameter

        Returns:
            The default value if the parameter is supported, None otherwise
        """
        return ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.get(name.upper())

    @staticmethod
    def get_allowed_session_configurations() -> List[str]:
        """
        Get the list of allowed session configuration parameters.

        Returns:
            List of allowed session configuration parameter names
        """
        return list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys())

    # == Not Implemented Operations ==
    # These methods will be implemented in future iterations

    def execute_command(
        self,
        operation: str,
        session_id: SessionId,
        max_rows: int,
        max_bytes: int,
        lz4_compression: bool,
        cursor: "Cursor",
        use_cloud_fetch: bool,
        parameters: List[ttypes.TSparkParameter],
        async_op: bool,
        enforce_embedded_schema_correctness: bool,
    ):
        """Not implemented yet."""
        raise NotImplementedError(
            "execute_command is not yet implemented for SEA backend"
        )

    def cancel_command(self, command_id: CommandId) -> None:
        """Not implemented yet."""
        raise NotImplementedError(
            "cancel_command is not yet implemented for SEA backend"
        )

    def close_command(self, command_id: CommandId) -> None:
        """Not implemented yet."""
        raise NotImplementedError(
            "close_command is not yet implemented for SEA backend"
        )

    def get_query_state(self, command_id: CommandId) -> CommandState:
        """Not implemented yet."""
        raise NotImplementedError(
            "get_query_state is not yet implemented for SEA backend"
        )

    def get_execution_result(
        self,
        command_id: CommandId,
        cursor: "Cursor",
    ):
        """Not implemented yet."""
        raise NotImplementedError(
            "get_execution_result is not yet implemented for SEA backend"
        )

    # == Metadata Operations ==

    def get_catalogs(
        self,
        session_id: SessionId,
        max_rows: int,
        max_bytes: int,
        cursor: "Cursor",
    ):
        """Not implemented yet."""
        raise NotImplementedError("get_catalogs is not yet implemented for SEA backend")

    def get_schemas(
        self,
        session_id: SessionId,
        max_rows: int,
        max_bytes: int,
        cursor: "Cursor",
        catalog_name: Optional[str] = None,
        schema_name: Optional[str] = None,
    ):
        """Not implemented yet."""
        raise NotImplementedError("get_schemas is not yet implemented for SEA backend")

    def get_tables(
        self,
        session_id: SessionId,
        max_rows: int,
        max_bytes: int,
        cursor: "Cursor",
        catalog_name: Optional[str] = None,
        schema_name: Optional[str] = None,
        table_name: Optional[str] = None,
        table_types: Optional[List[str]] = None,
    ):
        """Not implemented yet."""
        raise NotImplementedError("get_tables is not yet implemented for SEA backend")

    def get_columns(
        self,
        session_id: SessionId,
        max_rows: int,
        max_bytes: int,
        cursor: "Cursor",
        catalog_name: Optional[str] = None,
        schema_name: Optional[str] = None,
        table_name: Optional[str] = None,
        column_name: Optional[str] = None,
    ):
        """Not implemented yet."""
        raise NotImplementedError("get_columns is not yet implemented for SEA backend")
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
"""
2+
Models for the SEA (Statement Execution API) backend.
3+
4+
This package contains data models for SEA API requests and responses.
5+
"""
6+
7+
from databricks.sql.backend.sea.models.requests import (
8+
CreateSessionRequest,
9+
DeleteSessionRequest,
10+
)
11+
12+
from databricks.sql.backend.sea.models.responses import (
13+
CreateSessionResponse,
14+
)
15+
16+
__all__ = [
17+
# Request models
18+
"CreateSessionRequest",
19+
"DeleteSessionRequest",
20+
# Response models
21+
"CreateSessionResponse",
22+
]

0 commit comments

Comments
 (0)