5
5
import re
6
6
from typing import Any , Dict , Tuple , List , Optional , Union , TYPE_CHECKING , Set
7
7
8
- from databricks .sql .backend .sea .models .base import ExternalLink , ResultManifest
8
+ from databricks .sql .backend .sea .models .base import ResultManifest , StatementStatus
9
9
from databricks .sql .backend .sea .utils .constants import (
10
10
ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP ,
11
11
ResultFormat ,
19
19
if TYPE_CHECKING :
20
20
from databricks .sql .client import Cursor
21
21
22
- from databricks .sql .backend . sea . result_set import SeaResultSet
22
+ from databricks .sql .result_set import SeaResultSet
23
23
24
24
from databricks .sql .backend .databricks_client import DatabricksClient
25
25
from databricks .sql .backend .types import (
45
45
GetStatementResponse ,
46
46
CreateSessionResponse ,
47
47
)
48
- from databricks .sql .backend .sea .models .responses import GetChunksResponse
49
48
50
49
logger = logging .getLogger (__name__ )
51
50
52
51
53
52
def _filter_session_configuration (
54
53
session_configuration : Optional [Dict [str , Any ]],
55
54
) -> Dict [str , str ]:
55
+ """
56
+ Filter and normalise the provided session configuration parameters.
57
+
58
+ The Statement Execution API supports only a subset of SQL session
59
+ configuration options. This helper validates the supplied
60
+ ``session_configuration`` dictionary against the allow-list defined in
61
+ ``ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP`` and returns a new
62
+ dictionary that contains **only** the supported parameters.
63
+
64
+ Args:
65
+ session_configuration: Optional mapping of session configuration
66
+ names to their desired values. Key comparison is
67
+ case-insensitive.
68
+
69
+ Returns:
70
+ Dict[str, str]: A dictionary containing only the supported
71
+ configuration parameters with lower-case keys and string values. If
72
+ *session_configuration* is ``None`` or empty, an empty dictionary is
73
+ returned.
74
+ """
75
+
56
76
if not session_configuration :
57
77
return {}
58
78
@@ -90,7 +110,6 @@ class SeaDatabricksClient(DatabricksClient):
90
110
STATEMENT_PATH = BASE_PATH + "statements"
91
111
STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}"
92
112
CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel"
93
- CHUNK_PATH_WITH_ID_AND_INDEX = STATEMENT_PATH + "/{}/result/chunks/{}"
94
113
95
114
# SEA constants
96
115
POLL_INTERVAL_SECONDS = 0.2
@@ -143,7 +162,7 @@ def __init__(
143
162
http_path = http_path ,
144
163
http_headers = http_headers ,
145
164
auth_provider = auth_provider ,
146
- ssl_options = self . _ssl_options ,
165
+ ssl_options = ssl_options ,
147
166
** kwargs ,
148
167
)
149
168
@@ -275,32 +294,9 @@ def close_session(self, session_id: SessionId) -> None:
275
294
data = request_data .to_dict (),
276
295
)
277
296
278
- @staticmethod
279
- def get_default_session_configuration_value (name : str ) -> Optional [str ]:
280
- """
281
- Get the default value for a session configuration parameter.
282
-
283
- Args:
284
- name: The name of the session configuration parameter
285
-
286
- Returns:
287
- The default value if the parameter is supported, None otherwise
288
- """
289
- return ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP .get (name .upper ())
290
-
291
- @staticmethod
292
- def get_allowed_session_configurations () -> List [str ]:
293
- """
294
- Get the list of allowed session configuration parameters.
295
-
296
- Returns:
297
- List of allowed session configuration parameter names
298
- """
299
- return list (ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP .keys ())
300
-
301
297
def _extract_description_from_manifest (
302
298
self , manifest : ResultManifest
303
- ) -> List [ Tuple ]:
299
+ ) -> Optional [ List ]:
304
300
"""
305
301
Extract column description from a manifest object, in the format defined by
306
302
the spec: https://peps.python.org/pep-0249/#description
@@ -309,28 +305,39 @@ def _extract_description_from_manifest(
309
305
manifest: The ResultManifest object containing schema information
310
306
311
307
Returns:
312
- List[Tuple ]: A list of column tuples
308
+ Optional[List ]: A list of column tuples or None if no columns are found
313
309
"""
314
310
315
311
schema_data = manifest .schema
316
312
columns_data = schema_data .get ("columns" , [])
317
313
314
+ if not columns_data :
315
+ return None
316
+
318
317
columns = []
319
318
for col_data in columns_data :
320
319
# Format: (name, type_code, display_size, internal_size, precision, scale, null_ok)
320
+ name = col_data .get ("name" , "" )
321
+ type_name = col_data .get ("type_name" , "" )
322
+ type_name = (
323
+ type_name [:- 5 ] if type_name .endswith ("_TYPE" ) else type_name
324
+ ).lower ()
325
+ precision = col_data .get ("type_precision" )
326
+ scale = col_data .get ("type_scale" )
327
+
321
328
columns .append (
322
329
(
323
- col_data . get ( " name" , "" ) , # name
324
- col_data . get ( " type_name" , "" ) , # type_code
330
+ name , # name
331
+ type_name , # type_code
325
332
None , # display_size (not provided by SEA)
326
333
None , # internal_size (not provided by SEA)
327
- col_data . get ( " precision" ) , # precision
328
- col_data . get ( " scale" ) , # scale
329
- col_data . get ( "nullable" , True ) , # null_ok
334
+ precision , # precision
335
+ scale , # scale
336
+ None , # null_ok
330
337
)
331
338
)
332
339
333
- return columns
340
+ return columns if columns else None
334
341
335
342
def _results_message_to_execute_response (
336
343
self , response : Union [ExecuteStatementResponse , GetStatementResponse ]
@@ -351,7 +358,7 @@ def _results_message_to_execute_response(
351
358
352
359
# Check for compression
353
360
lz4_compressed = (
354
- response .manifest .result_compression == ResultCompression .LZ4_FRAME . value
361
+ response .manifest .result_compression == ResultCompression .LZ4_FRAME
355
362
)
356
363
357
364
execute_response = ExecuteResponse (
@@ -389,8 +396,9 @@ def _response_to_result_set(
389
396
)
390
397
391
398
def _check_command_not_in_failed_or_closed_state (
392
- self , state : CommandState , command_id : CommandId
399
+ self , status : StatementStatus , command_id : CommandId
393
400
) -> None :
401
+ state = status .state
394
402
if state == CommandState .CLOSED :
395
403
raise DatabaseError (
396
404
"Command {} unexpectedly closed server side" .format (command_id ),
@@ -399,8 +407,11 @@ def _check_command_not_in_failed_or_closed_state(
399
407
},
400
408
)
401
409
if state == CommandState .FAILED :
410
+ error = status .error
411
+ error_code = error .error_code if error else "UNKNOWN_ERROR_CODE"
412
+ error_message = error .message if error else "UNKNOWN_ERROR_MESSAGE"
402
413
raise ServerOperationError (
403
- "Command {} failed " .format (command_id ),
414
+ "Command failed: {} - {} " .format (error_code , error_message ),
404
415
{
405
416
"operation-id" : command_id ,
406
417
},
@@ -414,16 +425,18 @@ def _wait_until_command_done(
414
425
"""
415
426
416
427
final_response : Union [ExecuteStatementResponse , GetStatementResponse ] = response
417
-
418
- state = final_response .status .state
419
428
command_id = CommandId .from_sea_statement_id (final_response .statement_id )
420
429
421
- while state in [CommandState .PENDING , CommandState .RUNNING ]:
430
+ while final_response .status .state in [
431
+ CommandState .PENDING ,
432
+ CommandState .RUNNING ,
433
+ ]:
422
434
time .sleep (self .POLL_INTERVAL_SECONDS )
423
435
final_response = self ._poll_query (command_id )
424
- state = final_response .status .state
425
436
426
- self ._check_command_not_in_failed_or_closed_state (state , command_id )
437
+ self ._check_command_not_in_failed_or_closed_state (
438
+ final_response .status , command_id
439
+ )
427
440
428
441
return final_response
429
442
@@ -457,7 +470,7 @@ def execute_command(
457
470
enforce_embedded_schema_correctness: Whether to enforce schema correctness
458
471
459
472
Returns:
460
- SeaResultSet : A SeaResultSet instance for the executed command
473
+ ResultSet : A SeaResultSet instance for the executed command
461
474
"""
462
475
463
476
if session_id .backend_type != BackendType .SEA :
@@ -513,14 +526,6 @@ def execute_command(
513
526
)
514
527
response = ExecuteStatementResponse .from_dict (response_data )
515
528
statement_id = response .statement_id
516
- if not statement_id :
517
- raise ServerOperationError (
518
- "Failed to execute command: No statement ID returned" ,
519
- {
520
- "operation-id" : None ,
521
- "diagnostic-info" : None ,
522
- },
523
- )
524
529
525
530
command_id = CommandId .from_sea_statement_id (statement_id )
526
531
@@ -552,8 +557,6 @@ def cancel_command(self, command_id: CommandId) -> None:
552
557
raise ValueError ("Not a valid SEA command ID" )
553
558
554
559
sea_statement_id = command_id .to_sea_statement_id ()
555
- if sea_statement_id is None :
556
- raise ValueError ("Not a valid SEA command ID" )
557
560
558
561
request = CancelStatementRequest (statement_id = sea_statement_id )
559
562
self ._http_client ._make_request (
@@ -577,8 +580,6 @@ def close_command(self, command_id: CommandId) -> None:
577
580
raise ValueError ("Not a valid SEA command ID" )
578
581
579
582
sea_statement_id = command_id .to_sea_statement_id ()
580
- if sea_statement_id is None :
581
- raise ValueError ("Not a valid SEA command ID" )
582
583
583
584
request = CloseStatementRequest (statement_id = sea_statement_id )
584
585
self ._http_client ._make_request (
@@ -596,8 +597,6 @@ def _poll_query(self, command_id: CommandId) -> GetStatementResponse:
596
597
raise ValueError ("Not a valid SEA command ID" )
597
598
598
599
sea_statement_id = command_id .to_sea_statement_id ()
599
- if sea_statement_id is None :
600
- raise ValueError ("Not a valid SEA command ID" )
601
600
602
601
request = GetStatementRequest (statement_id = sea_statement_id )
603
602
response_data = self ._http_client ._make_request (
@@ -620,7 +619,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
620
619
CommandState: The current state of the command
621
620
622
621
Raises:
623
- ProgrammingError : If the command ID is invalid
622
+ ValueError : If the command ID is invalid
624
623
"""
625
624
626
625
response = self ._poll_query (command_id )
@@ -648,27 +647,6 @@ def get_execution_result(
648
647
response = self ._poll_query (command_id )
649
648
return self ._response_to_result_set (response , cursor )
650
649
651
- def get_chunk_links (
652
- self , statement_id : str , chunk_index : int
653
- ) -> List [ExternalLink ]:
654
- """
655
- Get links for chunks starting from the specified index.
656
- Args:
657
- statement_id: The statement ID
658
- chunk_index: The starting chunk index
659
- Returns:
660
- ExternalLink: External link for the chunk
661
- """
662
-
663
- response_data = self ._http_client ._make_request (
664
- method = "GET" ,
665
- path = self .CHUNK_PATH_WITH_ID_AND_INDEX .format (statement_id , chunk_index ),
666
- )
667
- response = GetChunksResponse .from_dict (response_data )
668
-
669
- links = response .external_links or []
670
- return links
671
-
672
650
# == Metadata Operations ==
673
651
674
652
def get_catalogs (
0 commit comments