1
1
import logging
2
- import uuid
3
2
import time
4
3
import re
5
- from typing import Dict , Tuple , List , Optional , Any , Union , TYPE_CHECKING , Set
4
+ from typing import Any , Dict , Tuple , List , Optional , Union , TYPE_CHECKING , Set
6
5
7
- from databricks .sql .backend .sea .models .base import ExternalLink
6
+ from databricks .sql .backend .sea .models .base import ResultManifest
8
7
from databricks .sql .backend .sea .utils .constants import (
9
8
ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP ,
10
9
ResultFormat ,
11
10
ResultDisposition ,
12
11
ResultCompression ,
13
12
WaitTimeout ,
13
+ MetadataCommands ,
14
14
)
15
15
16
16
if TYPE_CHECKING :
25
25
BackendType ,
26
26
ExecuteResponse ,
27
27
)
28
- from databricks .sql .exc import ServerOperationError
28
+ from databricks .sql .exc import DatabaseError , ServerOperationError
29
29
from databricks .sql .backend .sea .utils .http_client import SeaHttpClient
30
- from databricks .sql .thrift_api .TCLIService import ttypes
31
30
from databricks .sql .types import SSLOptions
32
31
33
32
from databricks .sql .backend .sea .models import (
41
40
ExecuteStatementResponse ,
42
41
GetStatementResponse ,
43
42
CreateSessionResponse ,
44
- GetChunksResponse ,
45
43
)
46
44
from databricks .sql .backend .sea .models .responses import (
47
- parse_status ,
48
- parse_manifest ,
49
- parse_result ,
45
+ _parse_status ,
46
+ _parse_manifest ,
47
+ _parse_result ,
50
48
)
51
49
52
50
logger = logging .getLogger (__name__ )
@@ -92,7 +90,9 @@ class SeaDatabricksClient(DatabricksClient):
92
90
STATEMENT_PATH = BASE_PATH + "statements"
93
91
STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}"
94
92
CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel"
95
- CHUNK_PATH_WITH_ID_AND_INDEX = STATEMENT_PATH + "/{}/result/chunks/{}"
93
+
94
+ # SEA constants
95
+ POLL_INTERVAL_SECONDS = 0.2
96
96
97
97
def __init__ (
98
98
self ,
@@ -124,7 +124,7 @@ def __init__(
124
124
http_path ,
125
125
)
126
126
127
- super (). __init__ ( ssl_options , ** kwargs )
127
+ self . _max_download_threads = kwargs . get ( "max_download_threads" , 10 )
128
128
129
129
# Extract warehouse ID from http_path
130
130
self .warehouse_id = self ._extract_warehouse_id (http_path )
@@ -136,7 +136,7 @@ def __init__(
136
136
http_path = http_path ,
137
137
http_headers = http_headers ,
138
138
auth_provider = auth_provider ,
139
- ssl_options = self . _ssl_options ,
139
+ ssl_options = ssl_options ,
140
140
** kwargs ,
141
141
)
142
142
@@ -291,28 +291,28 @@ def get_allowed_session_configurations() -> List[str]:
291
291
"""
292
292
return list (ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP .keys ())
293
293
294
- def _extract_description_from_manifest (self , manifest_obj ) -> Optional [List ]:
294
+ def _extract_description_from_manifest (
295
+ self , manifest : ResultManifest
296
+ ) -> Optional [List ]:
295
297
"""
296
- Extract column description from a manifest object.
298
+ Extract column description from a manifest object, in the format defined by
299
+ the spec: https://peps.python.org/pep-0249/#description
297
300
298
301
Args:
299
- manifest_obj : The ResultManifest object containing schema information
302
+ manifest : The ResultManifest object containing schema information
300
303
301
304
Returns:
302
305
Optional[List]: A list of column tuples or None if no columns are found
303
306
"""
304
307
305
- schema_data = manifest_obj .schema
308
+ schema_data = manifest .schema
306
309
columns_data = schema_data .get ("columns" , [])
307
310
308
311
if not columns_data :
309
312
return None
310
313
311
314
columns = []
312
315
for col_data in columns_data :
313
- if not isinstance (col_data , dict ):
314
- continue
315
-
316
316
# Format: (name, type_code, display_size, internal_size, precision, scale, null_ok)
317
317
columns .append (
318
318
(
@@ -328,38 +328,9 @@ def _extract_description_from_manifest(self, manifest_obj) -> Optional[List]:
328
328
329
329
return columns if columns else None
330
330
331
- def get_chunk_link (self , statement_id : str , chunk_index : int ) -> ExternalLink :
332
- """
333
- Get links for chunks starting from the specified index.
334
-
335
- Args:
336
- statement_id: The statement ID
337
- chunk_index: The starting chunk index
338
-
339
- Returns:
340
- ExternalLink: External link for the chunk
341
- """
342
-
343
- response_data = self .http_client ._make_request (
344
- method = "GET" ,
345
- path = self .CHUNK_PATH_WITH_ID_AND_INDEX .format (statement_id , chunk_index ),
346
- )
347
- response = GetChunksResponse .from_dict (response_data )
348
-
349
- links = response .external_links
350
- link = next ((l for l in links if l .chunk_index == chunk_index ), None )
351
- if not link :
352
- raise ServerOperationError (
353
- f"No link found for chunk index { chunk_index } " ,
354
- {
355
- "operation-id" : statement_id ,
356
- "diagnostic-info" : None ,
357
- },
358
- )
359
-
360
- return link
361
-
362
- def _results_message_to_execute_response (self , sea_response , command_id ):
331
+ def _results_message_to_execute_response (
332
+ self , response : GetStatementResponse
333
+ ) -> ExecuteResponse :
363
334
"""
364
335
Convert a SEA response to an ExecuteResponse and extract result data.
365
336
@@ -368,33 +339,65 @@ def _results_message_to_execute_response(self, sea_response, command_id):
368
339
command_id: The command ID
369
340
370
341
Returns:
371
- tuple: (ExecuteResponse, ResultData, ResultManifest) - The normalized execute response,
372
- result data object, and manifest object
342
+ ExecuteResponse: The normalized execute response
373
343
"""
374
344
375
- # Parse the response
376
- status = parse_status (sea_response )
377
- manifest_obj = parse_manifest (sea_response )
378
- result_data_obj = parse_result (sea_response )
379
-
380
345
# Extract description from manifest schema
381
- description = self ._extract_description_from_manifest (manifest_obj )
346
+ description = self ._extract_description_from_manifest (response . manifest )
382
347
383
348
# Check for compression
384
- lz4_compressed = manifest_obj .result_compression == "LZ4_FRAME"
349
+ lz4_compressed = (
350
+ response .manifest .result_compression == ResultCompression .LZ4_FRAME
351
+ )
385
352
386
353
execute_response = ExecuteResponse (
387
- command_id = command_id ,
388
- status = status .state ,
354
+ command_id = CommandId . from_sea_statement_id ( response . statement_id ) ,
355
+ status = response . status .state ,
389
356
description = description ,
390
357
has_been_closed_server_side = False ,
391
358
lz4_compressed = lz4_compressed ,
392
359
is_staging_operation = False ,
393
360
arrow_schema_bytes = None ,
394
- result_format = manifest_obj .format ,
361
+ result_format = response . manifest .format ,
395
362
)
396
363
397
- return execute_response , result_data_obj , manifest_obj
364
+ return execute_response
365
+
366
def _check_command_not_in_failed_or_closed_state(
    self, state: CommandState, command_id: CommandId
) -> None:
    """
    Raise if a command has reached the CLOSED or FAILED state.

    Any other state passes through silently, so callers can use this as a
    guard before consuming the results of a command.

    Args:
        state: The command state to validate
        command_id: The ID of the command; used in the error message and
            attached to the error context under "operation-id"

    Raises:
        DatabaseError: If the state is CommandState.CLOSED
        ServerOperationError: If the state is CommandState.FAILED
    """

    if state == CommandState.CLOSED:
        raise DatabaseError(
            "Command {} unexpectedly closed server side".format(command_id),
            {
                "operation-id": command_id,
            },
        )
    if state == CommandState.FAILED:
        raise ServerOperationError(
            "Command {} failed".format(command_id),
            {
                "operation-id": command_id,
            },
        )
383
+
384
def _wait_until_command_done(
    self, response: ExecuteStatementResponse
) -> CommandState:
    """
    Poll the state of a command until it is no longer pending or running.

    Args:
        response: The response to the statement execution request; supplies
            the initial state and the statement ID used for polling

    Returns:
        CommandState: The first observed state that is neither PENDING nor
            RUNNING

    Raises:
        DatabaseError: If the command was closed server side
        ServerOperationError: If the command failed
    """

    state = response.status.state
    command_id = CommandId.from_sea_statement_id(response.statement_id)

    while state in (CommandState.PENDING, CommandState.RUNNING):
        # Pause between polls to avoid excessive API calls
        time.sleep(self.POLL_INTERVAL_SECONDS)
        state = self.get_query_state(command_id)

    # The loop only exits once the command is no longer pending or running;
    # turn the failure states into exceptions
    self._check_command_not_in_failed_or_closed_state(state, command_id)

    return state
398
401
399
402
def execute_command (
400
403
self ,
@@ -405,7 +408,7 @@ def execute_command(
405
408
lz4_compression : bool ,
406
409
cursor : "Cursor" ,
407
410
use_cloud_fetch : bool ,
408
- parameters : List ,
411
+ parameters : List [ Dict [ str , Any ]] ,
409
412
async_op : bool ,
410
413
enforce_embedded_schema_correctness : bool ,
411
414
) -> Union ["ResultSet" , None ]:
@@ -439,9 +442,9 @@ def execute_command(
439
442
for param in parameters :
440
443
sea_parameters .append (
441
444
StatementParameter (
442
- name = param . name ,
443
- value = param . value ,
444
- type = param . type if hasattr ( param , "type" ) else None ,
445
+ name = param [ " name" ] ,
446
+ value = param [ " value" ] ,
447
+ type = param [ " type" ] if "type" in param else None ,
445
448
)
446
449
)
447
450
@@ -493,24 +496,7 @@ def execute_command(
493
496
if async_op :
494
497
return None
495
498
496
- # For synchronous operation, wait for the statement to complete
497
- status = response .status
498
- state = status .state
499
-
500
- # Keep polling until we reach a terminal state
501
- while state in [CommandState .PENDING , CommandState .RUNNING ]:
502
- time .sleep (0.5 ) # add a small delay to avoid excessive API calls
503
- state = self .get_query_state (command_id )
504
-
505
- if state != CommandState .SUCCEEDED :
506
- raise ServerOperationError (
507
- f"Statement execution did not succeed: { status .error .message if status .error else 'Unknown error' } " ,
508
- {
509
- "operation-id" : command_id .to_sea_statement_id (),
510
- "diagnostic-info" : None ,
511
- },
512
- )
513
-
499
+ self ._wait_until_command_done (response )
514
500
return self .get_execution_result (command_id , cursor )
515
501
516
502
def cancel_command (self , command_id : CommandId ) -> None :
@@ -622,25 +608,21 @@ def get_execution_result(
622
608
path = self .STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
623
609
data = request .to_dict (),
624
610
)
611
+ response = GetStatementResponse .from_dict (response_data )
625
612
626
613
# Create and return a SeaResultSet
627
614
from databricks .sql .result_set import SeaResultSet
628
615
629
- # Convert the response to an ExecuteResponse and extract result data
630
- (
631
- execute_response ,
632
- result_data ,
633
- manifest ,
634
- ) = self ._results_message_to_execute_response (response_data , command_id )
616
+ execute_response = self ._results_message_to_execute_response (response )
635
617
636
618
return SeaResultSet (
637
619
connection = cursor .connection ,
638
620
execute_response = execute_response ,
639
621
sea_client = self ,
640
622
buffer_size_bytes = cursor .buffer_size_bytes ,
641
623
arraysize = cursor .arraysize ,
642
- result_data = result_data ,
643
- manifest = manifest ,
624
+ result_data = response . result ,
625
+ manifest = response . manifest ,
644
626
)
645
627
646
628
# == Metadata Operations ==
@@ -654,7 +636,7 @@ def get_catalogs(
654
636
) -> "ResultSet" :
655
637
"""Get available catalogs by executing 'SHOW CATALOGS'."""
656
638
result = self .execute_command (
657
- operation = "SHOW CATALOGS" ,
639
+ operation = MetadataCommands . SHOW_CATALOGS . value ,
658
640
session_id = session_id ,
659
641
max_rows = max_rows ,
660
642
max_bytes = max_bytes ,
@@ -681,10 +663,10 @@ def get_schemas(
681
663
if not catalog_name :
682
664
raise ValueError ("Catalog name is required for get_schemas" )
683
665
684
- operation = f"SHOW SCHEMAS IN ` { catalog_name } `"
666
+ operation = MetadataCommands . SHOW_SCHEMAS . value . format ( catalog_name )
685
667
686
668
if schema_name :
687
- operation += f" LIKE ' { schema_name } '"
669
+ operation += MetadataCommands . LIKE_PATTERN . value . format ( schema_name )
688
670
689
671
result = self .execute_command (
690
672
operation = operation ,
@@ -716,17 +698,19 @@ def get_tables(
716
698
if not catalog_name :
717
699
raise ValueError ("Catalog name is required for get_tables" )
718
700
719
- operation = "SHOW TABLES IN " + (
720
- "ALL CATALOGS"
701
+ operation = (
702
+ MetadataCommands . SHOW_TABLES_ALL_CATALOGS . value
721
703
if catalog_name in [None , "*" , "%" ]
722
- else f"CATALOG `{ catalog_name } `"
704
+ else MetadataCommands .SHOW_TABLES .value .format (
705
+ MetadataCommands .CATALOG_SPECIFIC .value .format (catalog_name )
706
+ )
723
707
)
724
708
725
709
if schema_name :
726
- operation += f" SCHEMA LIKE ' { schema_name } '"
710
+ operation += MetadataCommands . SCHEMA_LIKE_PATTERN . value . format ( schema_name )
727
711
728
712
if table_name :
729
- operation += f" LIKE ' { table_name } '"
713
+ operation += MetadataCommands . LIKE_PATTERN . value . format ( table_name )
730
714
731
715
result = self .execute_command (
732
716
operation = operation ,
@@ -742,7 +726,7 @@ def get_tables(
742
726
)
743
727
assert result is not None , "execute_command returned None in synchronous mode"
744
728
745
- # Apply client-side filtering by table_types if specified
729
+ # Apply client-side filtering by table_types
746
730
from databricks .sql .backend .filters import ResultSetFilter
747
731
748
732
result = ResultSetFilter .filter_tables_by_type (result , table_types )
@@ -764,16 +748,16 @@ def get_columns(
764
748
if not catalog_name :
765
749
raise ValueError ("Catalog name is required for get_columns" )
766
750
767
- operation = f"SHOW COLUMNS IN CATALOG ` { catalog_name } `"
751
+ operation = MetadataCommands . SHOW_COLUMNS . value . format ( catalog_name )
768
752
769
753
if schema_name :
770
- operation += f" SCHEMA LIKE ' { schema_name } '"
754
+ operation += MetadataCommands . SCHEMA_LIKE_PATTERN . value . format ( schema_name )
771
755
772
756
if table_name :
773
- operation += f" TABLE LIKE ' { table_name } '"
757
+ operation += MetadataCommands . TABLE_LIKE_PATTERN . value . format ( table_name )
774
758
775
759
if column_name :
776
- operation += f" LIKE ' { column_name } '"
760
+ operation += MetadataCommands . LIKE_PATTERN . value . format ( column_name )
777
761
778
762
result = self .execute_command (
779
763
operation = operation ,
0 commit comments