@@ -130,14 +130,35 @@ def __init__(
130
130
# Extract warehouse ID from http_path
131
131
self .warehouse_id = self ._extract_warehouse_id (http_path )
132
132
133
- # Initialize ThriftHttpClient
133
+ # Extract retry policy parameters
134
+ retry_policy = kwargs .get ("_retry_policy" , None )
135
+ retry_stop_after_attempts_count = kwargs .get ("_retry_stop_after_attempts_count" , 30 )
136
+ retry_stop_after_attempts_duration = kwargs .get ("_retry_stop_after_attempts_duration" , 600 )
137
+ retry_delay_min = kwargs .get ("_retry_delay_min" , 1 )
138
+ retry_delay_max = kwargs .get ("_retry_delay_max" , 60 )
139
+ retry_delay_default = kwargs .get ("_retry_delay_default" , 5 )
140
+ retry_dangerous_codes = kwargs .get ("_retry_dangerous_codes" , [])
141
+
142
+ # Create retry policy if not provided
143
+ if not retry_policy :
144
+ from databricks .sql .auth .retry import DatabricksRetryPolicy
145
+ retry_policy = DatabricksRetryPolicy (
146
+ delay_min = retry_delay_min ,
147
+ delay_max = retry_delay_max ,
148
+ stop_after_attempts_count = retry_stop_after_attempts_count ,
149
+ stop_after_attempts_duration = retry_stop_after_attempts_duration ,
150
+ delay_default = retry_delay_default ,
151
+ force_dangerous_codes = retry_dangerous_codes ,
152
+ )
153
+
154
+ # Initialize ThriftHttpClient with retry policy
134
155
thrift_client = THttpClient (
135
156
auth_provider = auth_provider ,
136
157
uri_or_host = f"https://{ server_hostname } :{ port } " ,
137
158
path = http_path ,
138
159
ssl_options = ssl_options ,
139
160
max_connections = kwargs .get ("max_connections" , 1 ),
140
- retry_policy = kwargs . get ( "_retry_stop_after_attempts_count" , 30 ) ,
161
+ retry_policy = retry_policy ,
141
162
)
142
163
143
164
# Set custom headers
@@ -394,7 +415,7 @@ def _results_message_to_execute_response(self, sea_response, command_id):
394
415
description = description ,
395
416
has_been_closed_server_side = False ,
396
417
lz4_compressed = lz4_compressed ,
397
- is_staging_operation = False ,
418
+ is_staging_operation = manifest_obj . is_volume_operation ,
398
419
arrow_schema_bytes = None ,
399
420
result_format = manifest_obj .format ,
400
421
)
@@ -475,48 +496,56 @@ def execute_command(
475
496
result_compression = result_compression ,
476
497
)
477
498
478
- response_data = self .http_client .post (
479
- path = self .STATEMENT_PATH , data = request .to_dict ()
480
- )
481
- response = ExecuteStatementResponse .from_dict (response_data )
482
- statement_id = response .statement_id
483
- if not statement_id :
484
- raise ServerOperationError (
485
- "Failed to execute command: No statement ID returned" ,
486
- {
487
- "operation-id" : None ,
488
- "diagnostic-info" : None ,
489
- },
499
+ try :
500
+ response_data = self .http_client .post (
501
+ path = self .STATEMENT_PATH , data = request .to_dict ()
490
502
)
503
+ response = ExecuteStatementResponse .from_dict (response_data )
504
+ statement_id = response .statement_id
505
+ if not statement_id :
506
+ raise ServerOperationError (
507
+ "Failed to execute command: No statement ID returned" ,
508
+ {
509
+ "operation-id" : None ,
510
+ "diagnostic-info" : None ,
511
+ },
512
+ )
491
513
492
- command_id = CommandId .from_sea_statement_id (statement_id )
514
+ command_id = CommandId .from_sea_statement_id (statement_id )
493
515
494
- # Store the command ID in the cursor
495
- cursor .active_command_id = command_id
516
+ # Store the command ID in the cursor
517
+ cursor .active_command_id = command_id
496
518
497
- # If async operation, return and let the client poll for results
498
- if async_op :
499
- return None
519
+ # If async operation, return and let the client poll for results
520
+ if async_op :
521
+ return None
500
522
501
- # For synchronous operation, wait for the statement to complete
502
- status = response .status
503
- state = status .state
523
+ # For synchronous operation, wait for the statement to complete
524
+ status = response .status
525
+ state = status .state
504
526
505
- # Keep polling until we reach a terminal state
506
- while state in [CommandState .PENDING , CommandState .RUNNING ]:
507
- time .sleep (0.5 ) # add a small delay to avoid excessive API calls
508
- state = self .get_query_state (command_id )
527
+ # Keep polling until we reach a terminal state
528
+ while state in [CommandState .PENDING , CommandState .RUNNING ]:
529
+ time .sleep (0.5 ) # add a small delay to avoid excessive API calls
530
+ state = self .get_query_state (command_id )
509
531
510
- if state != CommandState .SUCCEEDED :
511
- raise ServerOperationError (
512
- f"Statement execution did not succeed: { status .error .message if status .error else 'Unknown error' } " ,
513
- {
514
- "operation-id" : command_id .to_sea_statement_id (),
515
- "diagnostic-info" : None ,
516
- },
517
- )
532
+ if state != CommandState .SUCCEEDED :
533
+ raise ServerOperationError (
534
+ f"Statement execution did not succeed: { status .error .message if status .error else 'Unknown error' } " ,
535
+ {
536
+ "operation-id" : command_id .to_sea_statement_id (),
537
+ "diagnostic-info" : None ,
538
+ },
539
+ )
518
540
519
- return self .get_execution_result (command_id , cursor )
541
+ return self .get_execution_result (command_id , cursor )
542
+ except Exception as e :
543
+ # Map exceptions to match Thrift behavior
544
+ from databricks .sql .exc import RequestError , OperationalError
545
+ if isinstance (e , (ServerOperationError , RequestError )):
546
+ raise
547
+ else :
548
+ raise OperationalError (f"Error executing command: { str (e )} " )
520
549
521
550
def cancel_command (self , command_id : CommandId ) -> None :
522
551
"""
0 commit comments