Skip to content

Commit 5e01e7b

Browse files
Merge branch 'metadata-sea' into fetch-json-inline
2 parents eb1a9b4 + 09a1b11 commit 5e01e7b

File tree

3 files changed

+38
-42
lines changed

3 files changed

+38
-42
lines changed

src/databricks/sql/backend/sea/backend.py

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import logging
24
import time
35
import re
@@ -16,7 +18,7 @@
1618

1719
if TYPE_CHECKING:
1820
from databricks.sql.client import Cursor
19-
from databricks.sql.result_set import ResultSet
21+
from databricks.sql.result_set import SeaResultSet
2022

2123
from databricks.sql.backend.databricks_client import DatabricksClient
2224
from databricks.sql.backend.types import (
@@ -26,7 +28,7 @@
2628
BackendType,
2729
ExecuteResponse,
2830
)
29-
from databricks.sql.exc import DatabaseError, ServerOperationError
31+
from databricks.sql.exc import DatabaseError, ProgrammingError, ServerOperationError
3032
from databricks.sql.backend.sea.utils.http_client import SeaHttpClient
3133
from databricks.sql.types import SSLOptions
3234

@@ -171,7 +173,7 @@ def _extract_warehouse_id(self, http_path: str) -> str:
171173
f"Note: SEA only works for warehouses."
172174
)
173175
logger.error(error_message)
174-
raise ValueError(error_message)
176+
raise ProgrammingError(error_message)
175177

176178
@property
177179
def max_download_threads(self) -> int:
@@ -243,14 +245,14 @@ def close_session(self, session_id: SessionId) -> None:
243245
session_id: The session identifier returned by open_session()
244246
245247
Raises:
246-
ValueError: If the session ID is invalid
248+
ProgrammingError: If the session ID is invalid
247249
OperationalError: If there's an error closing the session
248250
"""
249251

250252
logger.debug("SeaDatabricksClient.close_session(session_id=%s)", session_id)
251253

252254
if session_id.backend_type != BackendType.SEA:
253-
raise ValueError("Not a valid SEA session ID")
255+
raise ProgrammingError("Not a valid SEA session ID")
254256
sea_session_id = session_id.to_sea_session_id()
255257

256258
request_data = DeleteSessionRequest(
@@ -402,12 +404,12 @@ def execute_command(
402404
max_rows: int,
403405
max_bytes: int,
404406
lz4_compression: bool,
405-
cursor: "Cursor",
407+
cursor: Cursor,
406408
use_cloud_fetch: bool,
407409
parameters: List[ttypes.TSparkParameter],
408410
async_op: bool,
409411
enforce_embedded_schema_correctness: bool,
410-
) -> Union["ResultSet", None]:
412+
) -> Union[SeaResultSet, None]:
411413
"""
412414
Execute a SQL command using the SEA backend.
413415
@@ -428,7 +430,7 @@ def execute_command(
428430
"""
429431

430432
if session_id.backend_type != BackendType.SEA:
431-
raise ValueError("Not a valid SEA session ID")
433+
raise ProgrammingError("Not a valid SEA session ID")
432434

433435
sea_session_id = session_id.to_sea_session_id()
434436

@@ -503,11 +505,11 @@ def cancel_command(self, command_id: CommandId) -> None:
503505
command_id: Command identifier to cancel
504506
505507
Raises:
506-
ValueError: If the command ID is invalid
508+
ProgrammingError: If the command ID is invalid
507509
"""
508510

509511
if command_id.backend_type != BackendType.SEA:
510-
raise ValueError("Not a valid SEA command ID")
512+
raise ProgrammingError("Not a valid SEA command ID")
511513

512514
sea_statement_id = command_id.to_sea_statement_id()
513515

@@ -526,11 +528,11 @@ def close_command(self, command_id: CommandId) -> None:
526528
command_id: Command identifier to close
527529
528530
Raises:
529-
ValueError: If the command ID is invalid
531+
ProgrammingError: If the command ID is invalid
530532
"""
531533

532534
if command_id.backend_type != BackendType.SEA:
533-
raise ValueError("Not a valid SEA command ID")
535+
raise ProgrammingError("Not a valid SEA command ID")
534536

535537
sea_statement_id = command_id.to_sea_statement_id()
536538

@@ -552,7 +554,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
552554
CommandState: The current state of the command
553555
554556
Raises:
555-
ValueError: If the command ID is invalid
557+
ProgrammingError: If the command ID is invalid
556558
"""
557559

558560
if command_id.backend_type != BackendType.SEA:
@@ -574,8 +576,8 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
574576
def get_execution_result(
575577
self,
576578
command_id: CommandId,
577-
cursor: "Cursor",
578-
) -> "ResultSet":
579+
cursor: Cursor,
580+
) -> SeaResultSet:
579581
"""
580582
Get the result of a command execution.
581583
@@ -584,14 +586,14 @@ def get_execution_result(
584586
cursor: Cursor executing the command
585587
586588
Returns:
587-
ResultSet: A SeaResultSet instance with the execution results
589+
SeaResultSet: A SeaResultSet instance with the execution results
588590
589591
Raises:
590592
ValueError: If the command ID is invalid
591593
"""
592594

593595
if command_id.backend_type != BackendType.SEA:
594-
raise ValueError("Not a valid SEA command ID")
596+
raise ProgrammingError("Not a valid SEA command ID")
595597

596598
sea_statement_id = command_id.to_sea_statement_id()
597599

@@ -628,8 +630,8 @@ def get_catalogs(
628630
session_id: SessionId,
629631
max_rows: int,
630632
max_bytes: int,
631-
cursor: "Cursor",
632-
) -> "ResultSet":
633+
cursor: Cursor,
634+
) -> SeaResultSet:
633635
"""Get available catalogs by executing 'SHOW CATALOGS'."""
634636
result = self.execute_command(
635637
operation=MetadataCommands.SHOW_CATALOGS.value,
@@ -651,13 +653,13 @@ def get_schemas(
651653
session_id: SessionId,
652654
max_rows: int,
653655
max_bytes: int,
654-
cursor: "Cursor",
656+
cursor: Cursor,
655657
catalog_name: Optional[str] = None,
656658
schema_name: Optional[str] = None,
657-
) -> "ResultSet":
659+
) -> SeaResultSet:
658660
"""Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'."""
659661
if not catalog_name:
660-
raise ValueError("Catalog name is required for get_schemas")
662+
raise DatabaseError("Catalog name is required for get_schemas")
661663

662664
operation = MetadataCommands.SHOW_SCHEMAS.value.format(catalog_name)
663665

@@ -684,12 +686,12 @@ def get_tables(
684686
session_id: SessionId,
685687
max_rows: int,
686688
max_bytes: int,
687-
cursor: "Cursor",
689+
cursor: Cursor,
688690
catalog_name: Optional[str] = None,
689691
schema_name: Optional[str] = None,
690692
table_name: Optional[str] = None,
691693
table_types: Optional[List[str]] = None,
692-
) -> "ResultSet":
694+
) -> SeaResultSet:
693695
"""Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'."""
694696
operation = (
695697
MetadataCommands.SHOW_TABLES_ALL_CATALOGS.value
@@ -719,12 +721,6 @@ def get_tables(
719721
)
720722
assert result is not None, "execute_command returned None in synchronous mode"
721723

722-
from databricks.sql.result_set import SeaResultSet
723-
724-
assert isinstance(
725-
result, SeaResultSet
726-
), "execute_command returned a non-SeaResultSet"
727-
728724
# Apply client-side filtering by table_types
729725
from databricks.sql.backend.sea.utils.filters import ResultSetFilter
730726

@@ -737,15 +733,15 @@ def get_columns(
737733
session_id: SessionId,
738734
max_rows: int,
739735
max_bytes: int,
740-
cursor: "Cursor",
736+
cursor: Cursor,
741737
catalog_name: Optional[str] = None,
742738
schema_name: Optional[str] = None,
743739
table_name: Optional[str] = None,
744740
column_name: Optional[str] = None,
745-
) -> "ResultSet":
741+
) -> SeaResultSet:
746742
"""Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'."""
747743
if not catalog_name:
748-
raise ValueError("Catalog name is required for get_columns")
744+
raise DatabaseError("Catalog name is required for get_columns")
749745

750746
operation = MetadataCommands.SHOW_COLUMNS.value.format(catalog_name)
751747

src/databricks/sql/backend/sea/utils/filters.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,6 @@ def filter_by_column_values(
112112
result_set,
113113
lambda row: (
114114
len(row) > column_index
115-
and isinstance(row[column_index], str)
116115
and (
117116
row[column_index].upper()
118117
if not case_sensitive

tests/unit/test_sea_backend.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from databricks.sql.exc import (
1919
Error,
2020
NotSupportedError,
21+
ProgrammingError,
2122
ServerOperationError,
2223
DatabaseError,
2324
)
@@ -129,7 +130,7 @@ def test_initialization(self, mock_http_client):
129130
assert client3.max_download_threads == 5
130131

131132
# Test with invalid HTTP path
132-
with pytest.raises(ValueError) as excinfo:
133+
with pytest.raises(ProgrammingError) as excinfo:
133134
SeaDatabricksClient(
134135
server_hostname="test-server.databricks.com",
135136
port=443,
@@ -195,7 +196,7 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i
195196
)
196197

197198
# Test close_session with invalid ID type
198-
with pytest.raises(ValueError) as excinfo:
199+
with pytest.raises(ProgrammingError) as excinfo:
199200
sea_client.close_session(thrift_session_id)
200201
assert "Not a valid SEA session ID" in str(excinfo.value)
201202

@@ -244,7 +245,7 @@ def test_command_execution_sync(
244245
assert cmd_id_arg.guid == "test-statement-123"
245246

246247
# Test with invalid session ID
247-
with pytest.raises(ValueError) as excinfo:
248+
with pytest.raises(ProgrammingError) as excinfo:
248249
mock_thrift_handle = MagicMock()
249250
mock_thrift_handle.sessionId.guid = b"guid"
250251
mock_thrift_handle.sessionId.secret = b"secret"
@@ -452,7 +453,7 @@ def test_command_management(
452453
)
453454

454455
# Test cancel_command with invalid ID
455-
with pytest.raises(ValueError) as excinfo:
456+
with pytest.raises(ProgrammingError) as excinfo:
456457
sea_client.cancel_command(thrift_command_id)
457458
assert "Not a valid SEA command ID" in str(excinfo.value)
458459

@@ -466,7 +467,7 @@ def test_command_management(
466467
)
467468

468469
# Test close_command with invalid ID
469-
with pytest.raises(ValueError) as excinfo:
470+
with pytest.raises(ProgrammingError) as excinfo:
470471
sea_client.close_command(thrift_command_id)
471472
assert "Not a valid SEA command ID" in str(excinfo.value)
472473

@@ -525,7 +526,7 @@ def test_command_management(
525526
assert result.status == CommandState.SUCCEEDED
526527

527528
# Test get_execution_result with invalid ID
528-
with pytest.raises(ValueError) as excinfo:
529+
with pytest.raises(ProgrammingError) as excinfo:
529530
sea_client.get_execution_result(thrift_command_id, mock_cursor)
530531
assert "Not a valid SEA command ID" in str(excinfo.value)
531532

@@ -721,7 +722,7 @@ def test_get_schemas(self, sea_client, sea_session_id, mock_cursor):
721722
)
722723

723724
# Case 3: Without catalog name (should raise ValueError)
724-
with pytest.raises(ValueError) as excinfo:
725+
with pytest.raises(DatabaseError) as excinfo:
725726
sea_client.get_schemas(
726727
session_id=sea_session_id,
727728
max_rows=100,
@@ -872,7 +873,7 @@ def test_get_columns(self, sea_client, sea_session_id, mock_cursor):
872873
)
873874

874875
# Case 3: Without catalog name (should raise ValueError)
875-
with pytest.raises(ValueError) as excinfo:
876+
with pytest.raises(DatabaseError) as excinfo:
876877
sea_client.get_columns(
877878
session_id=sea_session_id,
878879
max_rows=100,

0 commit comments

Comments
 (0)