Skip to content

Commit 4fd2a3f

Browse files
Merge branch 'main' into sea-migration
2 parents ef5836b + 141a004 commit 4fd2a3f

20 files changed

+291
-404
lines changed

src/databricks/sql/backend/databricks_client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def execute_command(
9696
max_rows: Maximum number of rows to fetch in a single fetch batch
9797
max_bytes: Maximum number of bytes to fetch in a single fetch batch
9898
lz4_compression: Whether to use LZ4 compression for result data
99-
cursor: The cursor object that will handle the results
99+
cursor: The cursor object that will handle the results. The command id is set in this cursor.
100100
use_cloud_fetch: Whether to use cloud fetch for retrieving large result sets
101101
parameters: List of parameters to bind to the query
102102
async_op: Whether to execute the command asynchronously
@@ -282,7 +282,9 @@ def get_tables(
282282
max_bytes: Maximum number of bytes to fetch in a single batch
283283
cursor: The cursor object that will handle the results
284284
catalog_name: Optional catalog name pattern to filter by
285+
if catalog_name is None, we fetch across all catalogs
285286
schema_name: Optional schema name pattern to filter by
287+
if schema_name is None, we fetch across all schemas
286288
table_name: Optional table name pattern to filter by
287289
table_types: Optional list of table types to filter by (e.g., ['TABLE', 'VIEW'])
288290
@@ -321,6 +323,7 @@ def get_columns(
321323
catalog_name: Optional catalog name pattern to filter by
322324
schema_name: Optional schema name pattern to filter by
323325
table_name: Optional table name pattern to filter by
326+
if table_name is None, we fetch across all tables
324327
column_name: Optional column name pattern to filter by
325328
326329
Returns:

src/databricks/sql/backend/sea/backend.py

Lines changed: 59 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import re
66
from typing import Any, Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set
77

8-
from databricks.sql.backend.sea.models.base import ExternalLink, ResultManifest
8+
from databricks.sql.backend.sea.models.base import ResultManifest, StatementStatus
99
from databricks.sql.backend.sea.utils.constants import (
1010
ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP,
1111
ResultFormat,
@@ -19,7 +19,7 @@
1919
if TYPE_CHECKING:
2020
from databricks.sql.client import Cursor
2121

22-
from databricks.sql.backend.sea.result_set import SeaResultSet
22+
from databricks.sql.result_set import SeaResultSet
2323

2424
from databricks.sql.backend.databricks_client import DatabricksClient
2525
from databricks.sql.backend.types import (
@@ -45,14 +45,34 @@
4545
GetStatementResponse,
4646
CreateSessionResponse,
4747
)
48-
from databricks.sql.backend.sea.models.responses import GetChunksResponse
4948

5049
logger = logging.getLogger(__name__)
5150

5251

5352
def _filter_session_configuration(
5453
session_configuration: Optional[Dict[str, Any]],
5554
) -> Dict[str, str]:
55+
"""
56+
Filter and normalise the provided session configuration parameters.
57+
58+
The Statement Execution API supports only a subset of SQL session
59+
configuration options. This helper validates the supplied
60+
``session_configuration`` dictionary against the allow-list defined in
61+
``ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP`` and returns a new
62+
dictionary that contains **only** the supported parameters.
63+
64+
Args:
65+
session_configuration: Optional mapping of session configuration
66+
names to their desired values. Key comparison is
67+
case-insensitive.
68+
69+
Returns:
70+
Dict[str, str]: A dictionary containing only the supported
71+
configuration parameters with lower-case keys and string values. If
72+
*session_configuration* is ``None`` or empty, an empty dictionary is
73+
returned.
74+
"""
75+
5676
if not session_configuration:
5777
return {}
5878

@@ -90,7 +110,6 @@ class SeaDatabricksClient(DatabricksClient):
90110
STATEMENT_PATH = BASE_PATH + "statements"
91111
STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}"
92112
CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel"
93-
CHUNK_PATH_WITH_ID_AND_INDEX = STATEMENT_PATH + "/{}/result/chunks/{}"
94113

95114
# SEA constants
96115
POLL_INTERVAL_SECONDS = 0.2
@@ -143,7 +162,7 @@ def __init__(
143162
http_path=http_path,
144163
http_headers=http_headers,
145164
auth_provider=auth_provider,
146-
ssl_options=self._ssl_options,
165+
ssl_options=ssl_options,
147166
**kwargs,
148167
)
149168

@@ -275,32 +294,9 @@ def close_session(self, session_id: SessionId) -> None:
275294
data=request_data.to_dict(),
276295
)
277296

278-
@staticmethod
279-
def get_default_session_configuration_value(name: str) -> Optional[str]:
280-
"""
281-
Get the default value for a session configuration parameter.
282-
283-
Args:
284-
name: The name of the session configuration parameter
285-
286-
Returns:
287-
The default value if the parameter is supported, None otherwise
288-
"""
289-
return ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.get(name.upper())
290-
291-
@staticmethod
292-
def get_allowed_session_configurations() -> List[str]:
293-
"""
294-
Get the list of allowed session configuration parameters.
295-
296-
Returns:
297-
List of allowed session configuration parameter names
298-
"""
299-
return list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys())
300-
301297
def _extract_description_from_manifest(
302298
self, manifest: ResultManifest
303-
) -> List[Tuple]:
299+
) -> Optional[List]:
304300
"""
305301
Extract column description from a manifest object, in the format defined by
306302
the spec: https://peps.python.org/pep-0249/#description
@@ -309,28 +305,39 @@ def _extract_description_from_manifest(
309305
manifest: The ResultManifest object containing schema information
310306
311307
Returns:
312-
List[Tuple]: A list of column tuples
308+
Optional[List]: A list of column tuples or None if no columns are found
313309
"""
314310

315311
schema_data = manifest.schema
316312
columns_data = schema_data.get("columns", [])
317313

314+
if not columns_data:
315+
return None
316+
318317
columns = []
319318
for col_data in columns_data:
320319
# Format: (name, type_code, display_size, internal_size, precision, scale, null_ok)
320+
name = col_data.get("name", "")
321+
type_name = col_data.get("type_name", "")
322+
type_name = (
323+
type_name[:-5] if type_name.endswith("_TYPE") else type_name
324+
).lower()
325+
precision = col_data.get("type_precision")
326+
scale = col_data.get("type_scale")
327+
321328
columns.append(
322329
(
323-
col_data.get("name", ""), # name
324-
col_data.get("type_name", ""), # type_code
330+
name, # name
331+
type_name, # type_code
325332
None, # display_size (not provided by SEA)
326333
None, # internal_size (not provided by SEA)
327-
col_data.get("precision"), # precision
328-
col_data.get("scale"), # scale
329-
col_data.get("nullable", True), # null_ok
334+
precision, # precision
335+
scale, # scale
336+
None, # null_ok
330337
)
331338
)
332339

333-
return columns
340+
return columns if columns else None
334341

335342
def _results_message_to_execute_response(
336343
self, response: Union[ExecuteStatementResponse, GetStatementResponse]
@@ -351,7 +358,7 @@ def _results_message_to_execute_response(
351358

352359
# Check for compression
353360
lz4_compressed = (
354-
response.manifest.result_compression == ResultCompression.LZ4_FRAME.value
361+
response.manifest.result_compression == ResultCompression.LZ4_FRAME
355362
)
356363

357364
execute_response = ExecuteResponse(
@@ -389,8 +396,9 @@ def _response_to_result_set(
389396
)
390397

391398
def _check_command_not_in_failed_or_closed_state(
392-
self, state: CommandState, command_id: CommandId
399+
self, status: StatementStatus, command_id: CommandId
393400
) -> None:
401+
state = status.state
394402
if state == CommandState.CLOSED:
395403
raise DatabaseError(
396404
"Command {} unexpectedly closed server side".format(command_id),
@@ -399,8 +407,11 @@ def _check_command_not_in_failed_or_closed_state(
399407
},
400408
)
401409
if state == CommandState.FAILED:
410+
error = status.error
411+
error_code = error.error_code if error else "UNKNOWN_ERROR_CODE"
412+
error_message = error.message if error else "UNKNOWN_ERROR_MESSAGE"
402413
raise ServerOperationError(
403-
"Command {} failed".format(command_id),
414+
"Command failed: {} - {}".format(error_code, error_message),
404415
{
405416
"operation-id": command_id,
406417
},
@@ -414,16 +425,18 @@ def _wait_until_command_done(
414425
"""
415426

416427
final_response: Union[ExecuteStatementResponse, GetStatementResponse] = response
417-
418-
state = final_response.status.state
419428
command_id = CommandId.from_sea_statement_id(final_response.statement_id)
420429

421-
while state in [CommandState.PENDING, CommandState.RUNNING]:
430+
while final_response.status.state in [
431+
CommandState.PENDING,
432+
CommandState.RUNNING,
433+
]:
422434
time.sleep(self.POLL_INTERVAL_SECONDS)
423435
final_response = self._poll_query(command_id)
424-
state = final_response.status.state
425436

426-
self._check_command_not_in_failed_or_closed_state(state, command_id)
437+
self._check_command_not_in_failed_or_closed_state(
438+
final_response.status, command_id
439+
)
427440

428441
return final_response
429442

@@ -457,7 +470,7 @@ def execute_command(
457470
enforce_embedded_schema_correctness: Whether to enforce schema correctness
458471
459472
Returns:
460-
SeaResultSet: A SeaResultSet instance for the executed command
473+
ResultSet: A SeaResultSet instance for the executed command
461474
"""
462475

463476
if session_id.backend_type != BackendType.SEA:
@@ -513,14 +526,6 @@ def execute_command(
513526
)
514527
response = ExecuteStatementResponse.from_dict(response_data)
515528
statement_id = response.statement_id
516-
if not statement_id:
517-
raise ServerOperationError(
518-
"Failed to execute command: No statement ID returned",
519-
{
520-
"operation-id": None,
521-
"diagnostic-info": None,
522-
},
523-
)
524529

525530
command_id = CommandId.from_sea_statement_id(statement_id)
526531

@@ -552,8 +557,6 @@ def cancel_command(self, command_id: CommandId) -> None:
552557
raise ValueError("Not a valid SEA command ID")
553558

554559
sea_statement_id = command_id.to_sea_statement_id()
555-
if sea_statement_id is None:
556-
raise ValueError("Not a valid SEA command ID")
557560

558561
request = CancelStatementRequest(statement_id=sea_statement_id)
559562
self._http_client._make_request(
@@ -577,8 +580,6 @@ def close_command(self, command_id: CommandId) -> None:
577580
raise ValueError("Not a valid SEA command ID")
578581

579582
sea_statement_id = command_id.to_sea_statement_id()
580-
if sea_statement_id is None:
581-
raise ValueError("Not a valid SEA command ID")
582583

583584
request = CloseStatementRequest(statement_id=sea_statement_id)
584585
self._http_client._make_request(
@@ -596,8 +597,6 @@ def _poll_query(self, command_id: CommandId) -> GetStatementResponse:
596597
raise ValueError("Not a valid SEA command ID")
597598

598599
sea_statement_id = command_id.to_sea_statement_id()
599-
if sea_statement_id is None:
600-
raise ValueError("Not a valid SEA command ID")
601600

602601
request = GetStatementRequest(statement_id=sea_statement_id)
603602
response_data = self._http_client._make_request(
@@ -620,7 +619,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
620619
CommandState: The current state of the command
621620
622621
Raises:
623-
ProgrammingError: If the command ID is invalid
622+
ValueError: If the command ID is invalid
624623
"""
625624

626625
response = self._poll_query(command_id)
@@ -648,27 +647,6 @@ def get_execution_result(
648647
response = self._poll_query(command_id)
649648
return self._response_to_result_set(response, cursor)
650649

651-
def get_chunk_links(
652-
self, statement_id: str, chunk_index: int
653-
) -> List[ExternalLink]:
654-
"""
655-
Get links for chunks starting from the specified index.
656-
Args:
657-
statement_id: The statement ID
658-
chunk_index: The starting chunk index
659-
Returns:
660-
ExternalLink: External link for the chunk
661-
"""
662-
663-
response_data = self._http_client._make_request(
664-
method="GET",
665-
path=self.CHUNK_PATH_WITH_ID_AND_INDEX.format(statement_id, chunk_index),
666-
)
667-
response = GetChunksResponse.from_dict(response_data)
668-
669-
links = response.external_links or []
670-
return links
671-
672650
# == Metadata Operations ==
673651

674652
def get_catalogs(

src/databricks/sql/backend/sea/models/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
StatementStatus,
1010
ExternalLink,
1111
ResultData,
12-
ColumnInfo,
1312
ResultManifest,
1413
)
1514

@@ -27,7 +26,6 @@
2726
ExecuteStatementResponse,
2827
GetStatementResponse,
2928
CreateSessionResponse,
30-
GetChunksResponse,
3129
)
3230

3331
__all__ = [
@@ -36,7 +34,6 @@
3634
"StatementStatus",
3735
"ExternalLink",
3836
"ResultData",
39-
"ColumnInfo",
4037
"ResultManifest",
4138
# Request models
4239
"StatementParameter",
@@ -50,5 +47,4 @@
5047
"ExecuteStatementResponse",
5148
"GetStatementResponse",
5249
"CreateSessionResponse",
53-
"GetChunksResponse",
5450
]

src/databricks/sql/backend/sea/models/base.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -67,25 +67,12 @@ class ResultData:
6767
attachment: Optional[bytes] = None
6868

6969

70-
@dataclass
71-
class ColumnInfo:
72-
"""Information about a column in the result set."""
73-
74-
name: str
75-
type_name: str
76-
type_text: str
77-
nullable: bool = True
78-
precision: Optional[int] = None
79-
scale: Optional[int] = None
80-
ordinal_position: Optional[int] = None
81-
82-
8370
@dataclass
8471
class ResultManifest:
8572
"""Manifest information for a result set."""
8673

8774
format: str
88-
schema: Dict[str, Any] # Will contain column information
75+
schema: Dict[str, Any]
8976
total_row_count: int
9077
total_byte_count: int
9178
total_chunk_count: int

src/databricks/sql/backend/sea/models/requests.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ def to_dict(self) -> Dict[str, Any]:
5454
result["parameters"] = [
5555
{
5656
"name": param.name,
57-
**({"value": param.value} if param.value is not None else {}),
58-
**({"type": param.type} if param.type is not None else {}),
57+
"value": param.value,
58+
"type": param.type,
5959
}
6060
for param in self.parameters
6161
]

0 commit comments

Comments
 (0)