40
40
GetStatementResponse ,
41
41
CreateSessionResponse ,
42
42
)
43
+ from databricks .sql .backend .sea .models .responses import (
44
+ parse_status ,
45
+ parse_manifest ,
46
+ parse_result ,
47
+ )
43
48
44
49
logger = logging .getLogger (__name__ )
45
50
@@ -75,9 +80,6 @@ def _filter_session_configuration(
75
80
class SeaDatabricksClient (DatabricksClient ):
76
81
"""
77
82
Statement Execution API (SEA) implementation of the DatabricksClient interface.
78
-
79
- This implementation provides session management functionality for SEA,
80
- while other operations raise NotImplementedError.
81
83
"""
82
84
83
85
# SEA API paths
@@ -119,7 +121,6 @@ def __init__(
119
121
)
120
122
121
123
self ._max_download_threads = kwargs .get ("max_download_threads" , 10 )
122
- self .ssl_options = ssl_options
123
124
124
125
# Extract warehouse ID from http_path
125
126
self .warehouse_id = self ._extract_warehouse_id (http_path )
@@ -298,16 +299,16 @@ def _results_message_to_execute_response(self, sea_response, command_id):
298
299
tuple: (ExecuteResponse, ResultData, ResultManifest) - The normalized execute response,
299
300
result data object, and manifest object
300
301
"""
301
- # Extract status
302
- status_data = sea_response .get ("status" , {})
303
- state = CommandState .from_sea_state (status_data .get ("state" , "" ))
304
302
305
- # Extract description from manifest
303
+ # Parse the response
304
+ status = parse_status (sea_response )
305
+ manifest_obj = parse_manifest (sea_response )
306
+ result_data_obj = parse_result (sea_response )
307
+
308
+ # Extract description from manifest schema
306
309
description = None
307
- manifest_data = sea_response .get ("manifest" , {})
308
- schema_data = manifest_data .get ("schema" , {})
310
+ schema_data = manifest_obj .schema
309
311
columns_data = schema_data .get ("columns" , [])
310
-
311
312
if columns_data :
312
313
columns = []
313
314
for col_data in columns_data :
@@ -329,61 +330,17 @@ def _results_message_to_execute_response(self, sea_response, command_id):
329
330
description = columns if columns else None
330
331
331
332
# Check for compression
332
- lz4_compressed = manifest_data .get ("result_compression" ) == "LZ4_FRAME"
333
-
334
- # Initialize result_data_obj and manifest_obj
335
- result_data_obj = None
336
- manifest_obj = None
337
-
338
- result_data = sea_response .get ("result" , {})
339
- if result_data :
340
- # Convert external links
341
- external_links = None
342
- if "external_links" in result_data :
343
- external_links = []
344
- for link_data in result_data ["external_links" ]:
345
- external_links .append (
346
- ExternalLink (
347
- external_link = link_data .get ("external_link" , "" ),
348
- expiration = link_data .get ("expiration" , "" ),
349
- chunk_index = link_data .get ("chunk_index" , 0 ),
350
- byte_count = link_data .get ("byte_count" , 0 ),
351
- row_count = link_data .get ("row_count" , 0 ),
352
- row_offset = link_data .get ("row_offset" , 0 ),
353
- next_chunk_index = link_data .get ("next_chunk_index" ),
354
- next_chunk_internal_link = link_data .get (
355
- "next_chunk_internal_link"
356
- ),
357
- http_headers = link_data .get ("http_headers" , {}),
358
- )
359
- )
360
-
361
- # Create the result data object
362
- result_data_obj = ResultData (
363
- data = result_data .get ("data_array" ), external_links = external_links
364
- )
365
-
366
- # Create the manifest object
367
- manifest_obj = ResultManifest (
368
- format = manifest_data .get ("format" , "" ),
369
- schema = manifest_data .get ("schema" , {}),
370
- total_row_count = manifest_data .get ("total_row_count" , 0 ),
371
- total_byte_count = manifest_data .get ("total_byte_count" , 0 ),
372
- total_chunk_count = manifest_data .get ("total_chunk_count" , 0 ),
373
- truncated = manifest_data .get ("truncated" , False ),
374
- chunks = manifest_data .get ("chunks" ),
375
- result_compression = manifest_data .get ("result_compression" ),
376
- )
333
+ lz4_compressed = manifest_obj .result_compression == "LZ4_FRAME"
377
334
378
335
execute_response = ExecuteResponse (
379
336
command_id = command_id ,
380
- status = state ,
337
+ status = status . state ,
381
338
description = description ,
382
339
has_been_closed_server_side = False ,
383
340
lz4_compressed = lz4_compressed ,
384
341
is_staging_operation = False ,
385
342
arrow_schema_bytes = None , # to be extracted during fetch phase for ARROW
386
- result_format = manifest_data . get ( " format" ) ,
343
+ result_format = manifest_obj . format ,
387
344
)
388
345
389
346
return execute_response , result_data_obj , manifest_obj
@@ -419,6 +376,7 @@ def execute_command(
419
376
Returns:
420
377
ResultSet: A SeaResultSet instance for the executed command
421
378
"""
379
+
422
380
if session_id .backend_type != BackendType .SEA :
423
381
raise ValueError ("Not a valid SEA session ID" )
424
382
@@ -506,6 +464,7 @@ def cancel_command(self, command_id: CommandId) -> None:
506
464
Raises:
507
465
ValueError: If the command ID is invalid
508
466
"""
467
+
509
468
if command_id .backend_type != BackendType .SEA :
510
469
raise ValueError ("Not a valid SEA command ID" )
511
470
@@ -528,6 +487,7 @@ def close_command(self, command_id: CommandId) -> None:
528
487
Raises:
529
488
ValueError: If the command ID is invalid
530
489
"""
490
+
531
491
if command_id .backend_type != BackendType .SEA :
532
492
raise ValueError ("Not a valid SEA command ID" )
533
493
@@ -553,6 +513,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
553
513
Raises:
554
514
ValueError: If the command ID is invalid
555
515
"""
516
+
556
517
if command_id .backend_type != BackendType .SEA :
557
518
raise ValueError ("Not a valid SEA command ID" )
558
519
@@ -587,6 +548,7 @@ def get_execution_result(
587
548
Raises:
588
549
ValueError: If the command ID is invalid
589
550
"""
551
+
590
552
if command_id .backend_type != BackendType .SEA :
591
553
raise ValueError ("Not a valid SEA command ID" )
592
554
0 commit comments