1
1
import logging
2
2
import time
3
3
import re
4
- from typing import Dict , Tuple , List , Optional , Union , TYPE_CHECKING , Set
4
+ from typing import Any , Dict , Tuple , List , Optional , Union , TYPE_CHECKING , Set
5
5
6
+ from databricks .sql .backend .sea .models .base import ResultManifest
6
7
from databricks .sql .backend .sea .utils .constants import (
7
8
ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP ,
8
9
ResultFormat ,
23
24
BackendType ,
24
25
ExecuteResponse ,
25
26
)
26
- from databricks .sql .exc import ServerOperationError
27
+ from databricks .sql .exc import DatabaseError , ServerOperationError
27
28
from databricks .sql .backend .sea .utils .http_client import SeaHttpClient
28
29
from databricks .sql .types import SSLOptions
29
30
@@ -89,6 +90,9 @@ class SeaDatabricksClient(DatabricksClient):
89
90
STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}"
90
91
CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel"
91
92
93
+ # SEA constants
94
+ POLL_INTERVAL_SECONDS = 0.2
95
+
92
96
def __init__ (
93
97
self ,
94
98
server_hostname : str ,
@@ -286,28 +290,28 @@ def get_allowed_session_configurations() -> List[str]:
286
290
"""
287
291
return list (ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP .keys ())
288
292
289
- def _extract_description_from_manifest (self , manifest_obj ) -> Optional [List ]:
293
+ def _extract_description_from_manifest (
294
+ self , manifest : ResultManifest
295
+ ) -> Optional [List ]:
290
296
"""
291
- Extract column description from a manifest object.
297
+ Extract column description from a manifest object, in the format defined by
298
+ the spec: https://peps.python.org/pep-0249/#description
292
299
293
300
Args:
294
- manifest_obj : The ResultManifest object containing schema information
301
+ manifest : The ResultManifest object containing schema information
295
302
296
303
Returns:
297
304
Optional[List]: A list of column tuples or None if no columns are found
298
305
"""
299
306
300
- schema_data = manifest_obj .schema
307
+ schema_data = manifest .schema
301
308
columns_data = schema_data .get ("columns" , [])
302
309
303
310
if not columns_data :
304
311
return None
305
312
306
313
columns = []
307
314
for col_data in columns_data :
308
- if not isinstance (col_data , dict ):
309
- continue
310
-
311
315
# Format: (name, type_code, display_size, internal_size, precision, scale, null_ok)
312
316
columns .append (
313
317
(
@@ -323,7 +327,9 @@ def _extract_description_from_manifest(self, manifest_obj) -> Optional[List]:
323
327
324
328
return columns if columns else None
325
329
326
- def _results_message_to_execute_response (self , sea_response , command_id ):
330
+ def _results_message_to_execute_response (
331
+ self , response : GetStatementResponse
332
+ ) -> ExecuteResponse :
327
333
"""
328
334
Convert a SEA response to an ExecuteResponse and extract result data.
329
335
@@ -332,33 +338,65 @@ def _results_message_to_execute_response(self, sea_response, command_id):
332
338
command_id: The command ID
333
339
334
340
Returns:
335
- tuple: (ExecuteResponse, ResultData, ResultManifest) - The normalized execute response,
336
- result data object, and manifest object
341
+ ExecuteResponse: The normalized execute response
337
342
"""
338
343
339
- # Parse the response
340
- status = parse_status (sea_response )
341
- manifest_obj = parse_manifest (sea_response )
342
- result_data_obj = parse_result (sea_response )
343
-
344
344
# Extract description from manifest schema
345
- description = self ._extract_description_from_manifest (manifest_obj )
345
+ description = self ._extract_description_from_manifest (response . manifest )
346
346
347
347
# Check for compression
348
- lz4_compressed = manifest_obj .result_compression == "LZ4_FRAME"
348
+ lz4_compressed = (
349
+ response .manifest .result_compression == ResultCompression .LZ4_FRAME
350
+ )
349
351
350
352
execute_response = ExecuteResponse (
351
- command_id = command_id ,
352
- status = status .state ,
353
+ command_id = CommandId . from_sea_statement_id ( response . statement_id ) ,
354
+ status = response . status .state ,
353
355
description = description ,
354
356
has_been_closed_server_side = False ,
355
357
lz4_compressed = lz4_compressed ,
356
358
is_staging_operation = False ,
357
359
arrow_schema_bytes = None ,
358
- result_format = manifest_obj .format ,
360
+ result_format = response . manifest .format ,
359
361
)
360
362
361
- return execute_response , result_data_obj , manifest_obj
363
+ return execute_response
364
+
365
+ def _check_command_not_in_failed_or_closed_state (
366
+ self , state : CommandState , command_id : CommandId
367
+ ) -> None :
368
+ if state == CommandState .CLOSED :
369
+ raise DatabaseError (
370
+ "Command {} unexpectedly closed server side" .format (command_id ),
371
+ {
372
+ "operation-id" : command_id ,
373
+ },
374
+ )
375
+ if state == CommandState .FAILED :
376
+ raise ServerOperationError (
377
+ "Command {} failed" .format (command_id ),
378
+ {
379
+ "operation-id" : command_id ,
380
+ },
381
+ )
382
+
383
+ def _wait_until_command_done (
384
+ self , response : ExecuteStatementResponse
385
+ ) -> CommandState :
386
+ """
387
+ Wait until a command is done.
388
+ """
389
+
390
+ state = response .status .state
391
+ command_id = CommandId .from_sea_statement_id (response .statement_id )
392
+
393
+ while state in [CommandState .PENDING , CommandState .RUNNING ]:
394
+ time .sleep (self .POLL_INTERVAL_SECONDS )
395
+ state = self .get_query_state (command_id )
396
+
397
+ self ._check_command_not_in_failed_or_closed_state (state , command_id )
398
+
399
+ return state
362
400
363
401
def execute_command (
364
402
self ,
@@ -369,7 +407,7 @@ def execute_command(
369
407
lz4_compression : bool ,
370
408
cursor : "Cursor" ,
371
409
use_cloud_fetch : bool ,
372
- parameters : List ,
410
+ parameters : List [ Dict [ str , Any ]] ,
373
411
async_op : bool ,
374
412
enforce_embedded_schema_correctness : bool ,
375
413
) -> Union ["ResultSet" , None ]:
@@ -403,9 +441,9 @@ def execute_command(
403
441
for param in parameters :
404
442
sea_parameters .append (
405
443
StatementParameter (
406
- name = param . name ,
407
- value = param . value ,
408
- type = param . type if hasattr ( param , "type" ) else None ,
444
+ name = param [ " name" ] ,
445
+ value = param [ " value" ] ,
446
+ type = param [ " type" ] if "type" in param else None ,
409
447
)
410
448
)
411
449
@@ -457,24 +495,7 @@ def execute_command(
457
495
if async_op :
458
496
return None
459
497
460
- # For synchronous operation, wait for the statement to complete
461
- status = response .status
462
- state = status .state
463
-
464
- # Keep polling until we reach a terminal state
465
- while state in [CommandState .PENDING , CommandState .RUNNING ]:
466
- time .sleep (0.5 ) # add a small delay to avoid excessive API calls
467
- state = self .get_query_state (command_id )
468
-
469
- if state != CommandState .SUCCEEDED :
470
- raise ServerOperationError (
471
- f"Statement execution did not succeed: { status .error .message if status .error else 'Unknown error' } " ,
472
- {
473
- "operation-id" : command_id .to_sea_statement_id (),
474
- "diagnostic-info" : None ,
475
- },
476
- )
477
-
498
+ self ._wait_until_command_done (response )
478
499
return self .get_execution_result (command_id , cursor )
479
500
480
501
def cancel_command (self , command_id : CommandId ) -> None :
@@ -586,25 +607,21 @@ def get_execution_result(
586
607
path = self .STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
587
608
data = request .to_dict (),
588
609
)
610
+ response = GetStatementResponse .from_dict (response_data )
589
611
590
612
# Create and return a SeaResultSet
591
613
from databricks .sql .result_set import SeaResultSet
592
614
593
- # Convert the response to an ExecuteResponse and extract result data
594
- (
595
- execute_response ,
596
- result_data ,
597
- manifest ,
598
- ) = self ._results_message_to_execute_response (response_data , command_id )
615
+ execute_response = self ._results_message_to_execute_response (response )
599
616
600
617
return SeaResultSet (
601
618
connection = cursor .connection ,
602
619
execute_response = execute_response ,
603
620
sea_client = self ,
604
621
buffer_size_bytes = cursor .buffer_size_bytes ,
605
622
arraysize = cursor .arraysize ,
606
- result_data = result_data ,
607
- manifest = manifest ,
623
+ result_data = response . result ,
624
+ manifest = response . manifest ,
608
625
)
609
626
610
627
# == Metadata Operations ==
0 commit comments