 import logging
+import uuid
 import time
 import re
-from typing import Any, Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set
+from typing import Dict, Tuple, List, Optional, Any, Union, TYPE_CHECKING, Set
 
-from databricks.sql.backend.sea.models.base import ResultManifest
+from databricks.sql.backend.sea.models.base import ExternalLink
 
 from databricks.sql.backend.sea.utils.constants import (
     ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP,
     ResultFormat,
     ResultDisposition,
     ResultCompression,
     WaitTimeout,
-    MetadataCommands,
 )
 
 if TYPE_CHECKING:
     BackendType,
     ExecuteResponse,
 )
-from databricks.sql.exc import DatabaseError, ServerOperationError
+from databricks.sql.exc import ServerOperationError
 from databricks.sql.backend.sea.utils.http_client import SeaHttpClient
+from databricks.sql.thrift_api.TCLIService import ttypes
 from databricks.sql.types import SSLOptions
 
 from databricks.sql.backend.sea.models import (
     ExecuteStatementResponse,
     GetStatementResponse,
     CreateSessionResponse,
+    GetChunksResponse,
 )
 from databricks.sql.backend.sea.models.responses import (
-    _parse_status,
-    _parse_manifest,
-    _parse_result,
+    parse_status,
+    parse_manifest,
+    parse_result,
 )
 
 logger = logging.getLogger(__name__)
@@ -90,9 +92,7 @@ class SeaDatabricksClient(DatabricksClient):
     STATEMENT_PATH = BASE_PATH + "statements"
     STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}"
     CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel"
-
-    # SEA constants
-    POLL_INTERVAL_SECONDS = 0.2
+    CHUNK_PATH_WITH_ID_AND_INDEX = STATEMENT_PATH + "/{}/result/chunks/{}"
 
     def __init__(
         self,
@@ -124,7 +124,7 @@ def __init__(
             http_path,
         )
 
-        self._max_download_threads = kwargs.get("max_download_threads", 10)
+        super().__init__(ssl_options, **kwargs)
 
         # Extract warehouse ID from http_path
        self.warehouse_id = self._extract_warehouse_id(http_path)
@@ -136,7 +136,7 @@ def __init__(
             http_path=http_path,
             http_headers=http_headers,
             auth_provider=auth_provider,
-            ssl_options=ssl_options,
+            ssl_options=self._ssl_options,
             **kwargs,
         )
 
@@ -291,28 +291,28 @@ def get_allowed_session_configurations() -> List[str]:
         """
         return list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys())
 
-    def _extract_description_from_manifest(
-        self, manifest: ResultManifest
-    ) -> Optional[List]:
+    def _extract_description_from_manifest(self, manifest_obj) -> Optional[List]:
         """
-        Extract column description from a manifest object, in the format defined by
-        the spec: https://peps.python.org/pep-0249/#description
+        Extract column description from a manifest object.
 
         Args:
-            manifest: The ResultManifest object containing schema information
+            manifest_obj: The ResultManifest object containing schema information
 
         Returns:
             Optional[List]: A list of column tuples or None if no columns are found
         """
 
-        schema_data = manifest.schema
+        schema_data = manifest_obj.schema
         columns_data = schema_data.get("columns", [])
 
         if not columns_data:
             return None
 
         columns = []
         for col_data in columns_data:
+            if not isinstance(col_data, dict):
+                continue
+
             # Format: (name, type_code, display_size, internal_size, precision, scale, null_ok)
             columns.append(
                 (
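As a point of reference for the tuple layout named in the "# Format:" comment above, here is a rough sketch of what one PEP 249 description entry could look like for a single column; the column dict and its field names are invented for illustration and are not taken from this diff:

    # Hypothetical column entry as it might appear in manifest_obj.schema["columns"];
    # names and values are made up for illustration only.
    col_data = {"name": "price", "type_name": "DECIMAL", "precision": 10, "scale": 2}

    # PEP 249 entry: (name, type_code, display_size, internal_size, precision, scale, null_ok)
    description_entry = (
        col_data["name"],
        col_data["type_name"],
        None,                      # display_size: not provided here
        None,                      # internal_size: not provided here
        col_data.get("precision"),
        col_data.get("scale"),
        None,                      # null_ok: unknown in this sketch
    )
    print(description_entry)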
@@ -328,9 +328,38 @@ def _extract_description_from_manifest(
 
         return columns if columns else None
 
-    def _results_message_to_execute_response(
-        self, response: GetStatementResponse
-    ) -> ExecuteResponse:
+    def get_chunk_link(self, statement_id: str, chunk_index: int) -> ExternalLink:
+        """
+        Get the external link for a specific result chunk.
+
+        Args:
+            statement_id: The statement ID
+            chunk_index: The index of the chunk to fetch
+
+        Returns:
+            ExternalLink: External link for the chunk
+        """
+
+        response_data = self.http_client._make_request(
+            method="GET",
+            path=self.CHUNK_PATH_WITH_ID_AND_INDEX.format(statement_id, chunk_index),
+        )
+        response = GetChunksResponse.from_dict(response_data)
+
+        links = response.external_links
+        link = next((l for l in links if l.chunk_index == chunk_index), None)
+        if not link:
+            raise ServerOperationError(
+                f"No link found for chunk index {chunk_index}",
+                {
+                    "operation-id": statement_id,
+                    "diagnostic-info": None,
+                },
+            )
+
+        return link
+
+    def _results_message_to_execute_response(self, sea_response, command_id):
         """
         Convert a SEA response to an ExecuteResponse and extract result data.
 
@@ -339,65 +368,33 @@ def _results_message_to_execute_response(
             command_id: The command ID
 
         Returns:
-            ExecuteResponse: The normalized execute response
+            tuple: (ExecuteResponse, ResultData, ResultManifest) - The normalized execute response,
+                result data object, and manifest object
         """
 
+        # Parse the response
+        status = parse_status(sea_response)
+        manifest_obj = parse_manifest(sea_response)
+        result_data_obj = parse_result(sea_response)
+
         # Extract description from manifest schema
-        description = self._extract_description_from_manifest(response.manifest)
+        description = self._extract_description_from_manifest(manifest_obj)
 
         # Check for compression
-        lz4_compressed = (
-            response.manifest.result_compression == ResultCompression.LZ4_FRAME
-        )
+        lz4_compressed = manifest_obj.result_compression == "LZ4_FRAME"
 
         execute_response = ExecuteResponse(
-            command_id=CommandId.from_sea_statement_id(response.statement_id),
-            status=response.status.state,
+            command_id=command_id,
+            status=status.state,
             description=description,
             has_been_closed_server_side=False,
             lz4_compressed=lz4_compressed,
             is_staging_operation=False,
             arrow_schema_bytes=None,
-            result_format=response.manifest.format,
+            result_format=manifest_obj.format,
         )
 
-        return execute_response
-
-    def _check_command_not_in_failed_or_closed_state(
-        self, state: CommandState, command_id: CommandId
-    ) -> None:
-        if state == CommandState.CLOSED:
-            raise DatabaseError(
-                "Command {} unexpectedly closed server side".format(command_id),
-                {
-                    "operation-id": command_id,
-                },
-            )
-        if state == CommandState.FAILED:
-            raise ServerOperationError(
-                "Command {} failed".format(command_id),
-                {
-                    "operation-id": command_id,
-                },
-            )
-
-    def _wait_until_command_done(
-        self, response: ExecuteStatementResponse
-    ) -> CommandState:
-        """
-        Wait until a command is done.
-        """
-
-        state = response.status.state
-        command_id = CommandId.from_sea_statement_id(response.statement_id)
-
-        while state in [CommandState.PENDING, CommandState.RUNNING]:
-            time.sleep(self.POLL_INTERVAL_SECONDS)
-            state = self.get_query_state(command_id)
-
-        self._check_command_not_in_failed_or_closed_state(state, command_id)
-
-        return state
+        return execute_response, result_data_obj, manifest_obj
 
     def execute_command(
         self,
@@ -408,7 +405,7 @@ def execute_command(
         lz4_compression: bool,
         cursor: "Cursor",
         use_cloud_fetch: bool,
-        parameters: List[Dict[str, Any]],
+        parameters: List,
         async_op: bool,
         enforce_embedded_schema_correctness: bool,
     ) -> Union["ResultSet", None]:
@@ -442,9 +439,9 @@ def execute_command(
             for param in parameters:
                 sea_parameters.append(
                     StatementParameter(
-                        name=param["name"],
-                        value=param["value"],
-                        type=param["type"] if "type" in param else None,
+                        name=param.name,
+                        value=param.value,
+                        type=param.type if hasattr(param, "type") else None,
                     )
                 )
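execute_command now reads parameters as attribute-style objects (param.name, param.value, param.type) instead of dicts. A minimal stand-in, using a hypothetical dataclass in place of the connector's real parameter type, shows what the loop above consumes:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ParamStub:  # hypothetical stand-in, not the connector's actual class
        name: str
        value: str
        type: Optional[str] = None

    params = [ParamStub("p1", "42", "INT"), ParamStub("p2", "hello")]
    for param in params:
        # mirrors the attribute access pattern in the diff
        print(param.name, param.value, param.type if hasattr(param, "type") else None)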
@@ -496,7 +493,24 @@ def execute_command(
         if async_op:
             return None
 
-        self._wait_until_command_done(response)
+        # For synchronous operation, wait for the statement to complete
+        status = response.status
+        state = status.state
+
+        # Keep polling until we reach a terminal state
+        while state in [CommandState.PENDING, CommandState.RUNNING]:
+            time.sleep(0.5)  # add a small delay to avoid excessive API calls
+            state = self.get_query_state(command_id)
+
+        if state != CommandState.SUCCEEDED:
+            raise ServerOperationError(
+                f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}",
+                {
+                    "operation-id": command_id.to_sea_statement_id(),
+                    "diagnostic-info": None,
+                },
+            )
+
         return self.get_execution_result(command_id, cursor)
 
     def cancel_command(self, command_id: CommandId) -> None:
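The synchronous branch above replaces the old _wait_until_command_done and _check_command_not_in_failed_or_closed_state helpers with an inline loop. Condensed to its contract (the client name is illustrative), the flow is:

    # Poll until the command leaves PENDING/RUNNING; any terminal state other than
    # SUCCEEDED falls through to the error branch.
    state = response.status.state
    while state in [CommandState.PENDING, CommandState.RUNNING]:
        time.sleep(0.5)                        # same fixed delay as the new code
        state = client.get_query_state(command_id)
    if state != CommandState.SUCCEEDED:
        raise ServerOperationError(f"Statement did not succeed: {state}")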
@@ -608,21 +622,25 @@ def get_execution_result(
             path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
             data=request.to_dict(),
         )
-        response = GetStatementResponse.from_dict(response_data)
 
         # Create and return a SeaResultSet
         from databricks.sql.result_set import SeaResultSet
 
-        execute_response = self._results_message_to_execute_response(response)
+        # Convert the response to an ExecuteResponse and extract result data
+        (
+            execute_response,
+            result_data,
+            manifest,
+        ) = self._results_message_to_execute_response(response_data, command_id)
 
         return SeaResultSet(
             connection=cursor.connection,
             execute_response=execute_response,
             sea_client=self,
             buffer_size_bytes=cursor.buffer_size_bytes,
             arraysize=cursor.arraysize,
-            result_data=response.result,
-            manifest=response.manifest,
+            result_data=result_data,
+            manifest=manifest,
         )
 
     # == Metadata Operations ==
@@ -636,7 +654,7 @@ def get_catalogs(
     ) -> "ResultSet":
         """Get available catalogs by executing 'SHOW CATALOGS'."""
         result = self.execute_command(
-            operation=MetadataCommands.SHOW_CATALOGS.value,
+            operation="SHOW CATALOGS",
             session_id=session_id,
             max_rows=max_rows,
             max_bytes=max_bytes,
@@ -663,10 +681,10 @@ def get_schemas(
         if not catalog_name:
             raise ValueError("Catalog name is required for get_schemas")
 
-        operation = MetadataCommands.SHOW_SCHEMAS.value.format(catalog_name)
+        operation = f"SHOW SCHEMAS IN `{catalog_name}`"
 
         if schema_name:
-            operation += MetadataCommands.LIKE_PATTERN.value.format(schema_name)
+            operation += f" LIKE '{schema_name}'"
 
         result = self.execute_command(
             operation=operation,
@@ -698,19 +716,17 @@ def get_tables(
         if not catalog_name:
             raise ValueError("Catalog name is required for get_tables")
 
-        operation = (
-            MetadataCommands.SHOW_TABLES_ALL_CATALOGS.value
+        operation = "SHOW TABLES IN " + (
+            "ALL CATALOGS"
             if catalog_name in [None, "*", "%"]
-            else MetadataCommands.SHOW_TABLES.value.format(
-                MetadataCommands.CATALOG_SPECIFIC.value.format(catalog_name)
-            )
+            else f"CATALOG `{catalog_name}`"
         )
 
         if schema_name:
-            operation += MetadataCommands.SCHEMA_LIKE_PATTERN.value.format(schema_name)
+            operation += f" SCHEMA LIKE '{schema_name}'"
 
         if table_name:
-            operation += MetadataCommands.LIKE_PATTERN.value.format(table_name)
+            operation += f" LIKE '{table_name}'"
 
         result = self.execute_command(
             operation=operation,
@@ -726,7 +742,7 @@ def get_tables(
         )
         assert result is not None, "execute_command returned None in synchronous mode"
 
-        # Apply client-side filtering by table_types
+        # Apply client-side filtering by table_types if specified
         from databricks.sql.backend.filters import ResultSetFilter
 
         result = ResultSetFilter.filter_tables_by_type(result, table_types)
@@ -748,16 +764,16 @@ def get_columns(
         if not catalog_name:
             raise ValueError("Catalog name is required for get_columns")
 
-        operation = MetadataCommands.SHOW_COLUMNS.value.format(catalog_name)
+        operation = f"SHOW COLUMNS IN CATALOG `{catalog_name}`"
 
         if schema_name:
-            operation += MetadataCommands.SCHEMA_LIKE_PATTERN.value.format(schema_name)
+            operation += f" SCHEMA LIKE '{schema_name}'"
 
         if table_name:
-            operation += MetadataCommands.TABLE_LIKE_PATTERN.value.format(table_name)
+            operation += f" TABLE LIKE '{table_name}'"
 
         if column_name:
-            operation += MetadataCommands.LIKE_PATTERN.value.format(column_name)
+            operation += f" LIKE '{column_name}'"
 
         result = self.execute_command(
             operation=operation,
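The metadata helpers above now assemble SHOW statements inline rather than through the removed MetadataCommands enum. For reference, a small self-checking sketch of the strings this formatting produces; the catalog, schema, table, and column names are made up:

    catalog, schema, table, column = "main", "def%", "t%", "c%"

    schemas_sql = f"SHOW SCHEMAS IN `{catalog}`" + f" LIKE '{schema}'"
    tables_sql = (
        "SHOW TABLES IN " + f"CATALOG `{catalog}`"
        + f" SCHEMA LIKE '{schema}'"
        + f" LIKE '{table}'"
    )
    columns_sql = (
        f"SHOW COLUMNS IN CATALOG `{catalog}`"
        + f" SCHEMA LIKE '{schema}'"
        + f" TABLE LIKE '{table}'"
        + f" LIKE '{column}'"
    )

    assert schemas_sql == "SHOW SCHEMAS IN `main` LIKE 'def%'"
    assert tables_sql == "SHOW TABLES IN CATALOG `main` SCHEMA LIKE 'def%' LIKE 't%'"
    assert columns_sql == (
        "SHOW COLUMNS IN CATALOG `main` SCHEMA LIKE 'def%' TABLE LIKE 't%' LIKE 'c%'"
    )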