1
+ from __future__ import annotations
2
+
1
3
import logging
2
4
import time
3
5
import re
16
18
17
19
if TYPE_CHECKING :
18
20
from databricks .sql .client import Cursor
19
- from databricks .sql .result_set import ResultSet
21
+ from databricks .sql .result_set import SeaResultSet
20
22
21
23
from databricks .sql .backend .databricks_client import DatabricksClient
22
24
from databricks .sql .backend .types import (
26
28
BackendType ,
27
29
ExecuteResponse ,
28
30
)
29
- from databricks .sql .exc import DatabaseError , ServerOperationError
31
+ from databricks .sql .exc import DatabaseError , ProgrammingError , ServerOperationError
30
32
from databricks .sql .backend .sea .utils .http_client import SeaHttpClient
31
33
from databricks .sql .types import SSLOptions
32
34
@@ -171,7 +173,7 @@ def _extract_warehouse_id(self, http_path: str) -> str:
171
173
f"Note: SEA only works for warehouses."
172
174
)
173
175
logger .error (error_message )
174
- raise ValueError (error_message )
176
+ raise ProgrammingError (error_message )
175
177
176
178
@property
177
179
def max_download_threads (self ) -> int :
@@ -243,14 +245,14 @@ def close_session(self, session_id: SessionId) -> None:
243
245
session_id: The session identifier returned by open_session()
244
246
245
247
Raises:
246
- ValueError : If the session ID is invalid
248
+ ProgrammingError : If the session ID is invalid
247
249
OperationalError: If there's an error closing the session
248
250
"""
249
251
250
252
logger .debug ("SeaDatabricksClient.close_session(session_id=%s)" , session_id )
251
253
252
254
if session_id .backend_type != BackendType .SEA :
253
- raise ValueError ("Not a valid SEA session ID" )
255
+ raise ProgrammingError ("Not a valid SEA session ID" )
254
256
sea_session_id = session_id .to_sea_session_id ()
255
257
256
258
request_data = DeleteSessionRequest (
@@ -402,12 +404,12 @@ def execute_command(
402
404
max_rows : int ,
403
405
max_bytes : int ,
404
406
lz4_compression : bool ,
405
- cursor : " Cursor" ,
407
+ cursor : Cursor ,
406
408
use_cloud_fetch : bool ,
407
409
parameters : List [ttypes .TSparkParameter ],
408
410
async_op : bool ,
409
411
enforce_embedded_schema_correctness : bool ,
410
- ) -> Union ["ResultSet" , None ]:
412
+ ) -> Union [SeaResultSet , None ]:
411
413
"""
412
414
Execute a SQL command using the SEA backend.
413
415
@@ -428,7 +430,7 @@ def execute_command(
428
430
"""
429
431
430
432
if session_id .backend_type != BackendType .SEA :
431
- raise ValueError ("Not a valid SEA session ID" )
433
+ raise ProgrammingError ("Not a valid SEA session ID" )
432
434
433
435
sea_session_id = session_id .to_sea_session_id ()
434
436
@@ -503,11 +505,11 @@ def cancel_command(self, command_id: CommandId) -> None:
503
505
command_id: Command identifier to cancel
504
506
505
507
Raises:
506
- ValueError : If the command ID is invalid
508
+ ProgrammingError : If the command ID is invalid
507
509
"""
508
510
509
511
if command_id .backend_type != BackendType .SEA :
510
- raise ValueError ("Not a valid SEA command ID" )
512
+ raise ProgrammingError ("Not a valid SEA command ID" )
511
513
512
514
sea_statement_id = command_id .to_sea_statement_id ()
513
515
@@ -526,11 +528,11 @@ def close_command(self, command_id: CommandId) -> None:
526
528
command_id: Command identifier to close
527
529
528
530
Raises:
529
- ValueError : If the command ID is invalid
531
+ ProgrammingError : If the command ID is invalid
530
532
"""
531
533
532
534
if command_id .backend_type != BackendType .SEA :
533
- raise ValueError ("Not a valid SEA command ID" )
535
+ raise ProgrammingError ("Not a valid SEA command ID" )
534
536
535
537
sea_statement_id = command_id .to_sea_statement_id ()
536
538
@@ -552,7 +554,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
552
554
CommandState: The current state of the command
553
555
554
556
Raises:
555
- ValueError : If the command ID is invalid
557
+ ProgrammingError : If the command ID is invalid
556
558
"""
557
559
558
560
if command_id .backend_type != BackendType .SEA :
@@ -574,8 +576,8 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
574
576
def get_execution_result (
575
577
self ,
576
578
command_id : CommandId ,
577
- cursor : " Cursor" ,
578
- ) -> "ResultSet" :
579
+ cursor : Cursor ,
580
+ ) -> SeaResultSet :
579
581
"""
580
582
Get the result of a command execution.
581
583
@@ -584,14 +586,14 @@ def get_execution_result(
584
586
cursor: Cursor executing the command
585
587
586
588
Returns:
587
- ResultSet : A SeaResultSet instance with the execution results
589
+ SeaResultSet : A SeaResultSet instance with the execution results
588
590
589
591
Raises:
590
592
ValueError: If the command ID is invalid
591
593
"""
592
594
593
595
if command_id .backend_type != BackendType .SEA :
594
- raise ValueError ("Not a valid SEA command ID" )
596
+ raise ProgrammingError ("Not a valid SEA command ID" )
595
597
596
598
sea_statement_id = command_id .to_sea_statement_id ()
597
599
@@ -628,8 +630,8 @@ def get_catalogs(
628
630
session_id : SessionId ,
629
631
max_rows : int ,
630
632
max_bytes : int ,
631
- cursor : " Cursor" ,
632
- ) -> "ResultSet" :
633
+ cursor : Cursor ,
634
+ ) -> SeaResultSet :
633
635
"""Get available catalogs by executing 'SHOW CATALOGS'."""
634
636
result = self .execute_command (
635
637
operation = MetadataCommands .SHOW_CATALOGS .value ,
@@ -651,13 +653,13 @@ def get_schemas(
651
653
session_id : SessionId ,
652
654
max_rows : int ,
653
655
max_bytes : int ,
654
- cursor : " Cursor" ,
656
+ cursor : Cursor ,
655
657
catalog_name : Optional [str ] = None ,
656
658
schema_name : Optional [str ] = None ,
657
- ) -> "ResultSet" :
659
+ ) -> SeaResultSet :
658
660
"""Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'."""
659
661
if not catalog_name :
660
- raise ValueError ("Catalog name is required for get_schemas" )
662
+ raise DatabaseError ("Catalog name is required for get_schemas" )
661
663
662
664
operation = MetadataCommands .SHOW_SCHEMAS .value .format (catalog_name )
663
665
@@ -684,12 +686,12 @@ def get_tables(
684
686
session_id : SessionId ,
685
687
max_rows : int ,
686
688
max_bytes : int ,
687
- cursor : " Cursor" ,
689
+ cursor : Cursor ,
688
690
catalog_name : Optional [str ] = None ,
689
691
schema_name : Optional [str ] = None ,
690
692
table_name : Optional [str ] = None ,
691
693
table_types : Optional [List [str ]] = None ,
692
- ) -> "ResultSet" :
694
+ ) -> SeaResultSet :
693
695
"""Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'."""
694
696
operation = (
695
697
MetadataCommands .SHOW_TABLES_ALL_CATALOGS .value
@@ -719,12 +721,6 @@ def get_tables(
719
721
)
720
722
assert result is not None , "execute_command returned None in synchronous mode"
721
723
722
- from databricks .sql .result_set import SeaResultSet
723
-
724
- assert isinstance (
725
- result , SeaResultSet
726
- ), "execute_command returned a non-SeaResultSet"
727
-
728
724
# Apply client-side filtering by table_types
729
725
from databricks .sql .backend .sea .utils .filters import ResultSetFilter
730
726
@@ -737,15 +733,15 @@ def get_columns(
737
733
session_id : SessionId ,
738
734
max_rows : int ,
739
735
max_bytes : int ,
740
- cursor : " Cursor" ,
736
+ cursor : Cursor ,
741
737
catalog_name : Optional [str ] = None ,
742
738
schema_name : Optional [str ] = None ,
743
739
table_name : Optional [str ] = None ,
744
740
column_name : Optional [str ] = None ,
745
- ) -> "ResultSet" :
741
+ ) -> SeaResultSet :
746
742
"""Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'."""
747
743
if not catalog_name :
748
- raise ValueError ("Catalog name is required for get_columns" )
744
+ raise DatabaseError ("Catalog name is required for get_columns" )
749
745
750
746
operation = MetadataCommands .SHOW_COLUMNS .value .format (catalog_name )
751
747
0 commit comments