@@ -130,14 +130,40 @@ def __init__(
130
130
# Extract warehouse ID from http_path
131
131
self .warehouse_id = self ._extract_warehouse_id (http_path )
132
132
133
- # Initialize ThriftHttpClient
133
+ # Extract retry policy parameters
134
+ retry_policy = kwargs .get ("_retry_policy" , None )
135
+ retry_stop_after_attempts_count = kwargs .get (
136
+ "_retry_stop_after_attempts_count" , 30
137
+ )
138
+ retry_stop_after_attempts_duration = kwargs .get (
139
+ "_retry_stop_after_attempts_duration" , 600
140
+ )
141
+ retry_delay_min = kwargs .get ("_retry_delay_min" , 1 )
142
+ retry_delay_max = kwargs .get ("_retry_delay_max" , 60 )
143
+ retry_delay_default = kwargs .get ("_retry_delay_default" , 5 )
144
+ retry_dangerous_codes = kwargs .get ("_retry_dangerous_codes" , [])
145
+
146
+ # Create retry policy if not provided
147
+ if not retry_policy :
148
+ from databricks .sql .auth .retry import DatabricksRetryPolicy
149
+
150
+ retry_policy = DatabricksRetryPolicy (
151
+ delay_min = retry_delay_min ,
152
+ delay_max = retry_delay_max ,
153
+ stop_after_attempts_count = retry_stop_after_attempts_count ,
154
+ stop_after_attempts_duration = retry_stop_after_attempts_duration ,
155
+ delay_default = retry_delay_default ,
156
+ force_dangerous_codes = retry_dangerous_codes ,
157
+ )
158
+
159
+ # Initialize ThriftHttpClient with retry policy
134
160
thrift_client = THttpClient (
135
161
auth_provider = auth_provider ,
136
162
uri_or_host = f"https://{ server_hostname } :{ port } " ,
137
163
path = http_path ,
138
164
ssl_options = ssl_options ,
139
165
max_connections = kwargs .get ("max_connections" , 1 ),
140
- retry_policy = kwargs . get ( "_retry_stop_after_attempts_count" , 30 ) ,
166
+ retry_policy = retry_policy ,
141
167
)
142
168
143
169
# Set custom headers
@@ -475,48 +501,99 @@ def execute_command(
475
501
result_compression = result_compression ,
476
502
)
477
503
478
- response_data = self .http_client .post (
479
- path = self .STATEMENT_PATH , data = request .to_dict ()
480
- )
481
- response = ExecuteStatementResponse .from_dict (response_data )
482
- statement_id = response .statement_id
483
- if not statement_id :
484
- raise ServerOperationError (
485
- "Failed to execute command: No statement ID returned" ,
486
- {
487
- "operation-id" : None ,
488
- "diagnostic-info" : None ,
489
- },
504
+ try :
505
+ response_data = self .http_client .post (
506
+ path = self .STATEMENT_PATH , data = request .to_dict ()
490
507
)
508
+ response = ExecuteStatementResponse .from_dict (response_data )
509
+ statement_id = response .statement_id
510
+
511
+ if not statement_id :
512
+ raise ServerOperationError (
513
+ "Failed to execute command: No statement ID returned" ,
514
+ {
515
+ "operation-id" : None ,
516
+ "diagnostic-info" : None ,
517
+ },
518
+ )
491
519
492
- command_id = CommandId .from_sea_statement_id (statement_id )
520
+ command_id = CommandId .from_sea_statement_id (statement_id )
493
521
494
- # Store the command ID in the cursor
495
- cursor .active_command_id = command_id
522
+ # Store the command ID in the cursor
523
+ cursor .active_command_id = command_id
496
524
497
- # If async operation, return and let the client poll for results
498
- if async_op :
499
- return None
525
+ # If async operation, return and let the client poll for results
526
+ if async_op :
527
+ return None
500
528
501
- # For synchronous operation, wait for the statement to complete
502
- status = response .status
503
- state = status .state
529
+ # For synchronous operation, wait for the statement to complete
530
+ status = response .status
531
+ state = status .state
504
532
505
- # Keep polling until we reach a terminal state
506
- while state in [CommandState .PENDING , CommandState .RUNNING ]:
507
- time .sleep (0.5 ) # add a small delay to avoid excessive API calls
508
- state = self .get_query_state (command_id )
533
+ # Keep polling until we reach a terminal state
534
+ while state in [CommandState .PENDING , CommandState .RUNNING ]:
535
+ time .sleep (0.5 ) # add a small delay to avoid excessive API calls
536
+ state = self .get_query_state (command_id )
509
537
510
- if state != CommandState .SUCCEEDED :
511
- raise ServerOperationError (
512
- f"Statement execution did not succeed: { status .error .message if status .error else 'Unknown error' } " ,
513
- {
514
- "operation-id" : command_id .to_sea_statement_id (),
515
- "diagnostic-info" : None ,
516
- },
517
- )
538
+ if state != CommandState .SUCCEEDED :
539
+ error_message = (
540
+ status .error .message if status .error else "Unknown error"
541
+ )
542
+ error_code = status .error .error_code if status .error else None
543
+
544
+ # Map error codes to appropriate exceptions to match Thrift behavior
545
+ from databricks .sql .exc import (
546
+ DatabaseError ,
547
+ ProgrammingError ,
548
+ OperationalError ,
549
+ )
550
+
551
+ if (
552
+ error_code == "SYNTAX_ERROR"
553
+ or "syntax error" in error_message .lower ()
554
+ ):
555
+ raise DatabaseError (
556
+ f"Syntax error in SQL statement: { error_message } "
557
+ )
558
+ elif error_code == "TEMPORARILY_UNAVAILABLE" :
559
+ raise OperationalError (
560
+ f"Service temporarily unavailable: { error_message } "
561
+ )
562
+ elif error_code == "PERMISSION_DENIED" :
563
+ raise OperationalError (f"Permission denied: { error_message } " )
564
+ else :
565
+ raise ServerOperationError (
566
+ f"Statement execution failed: { error_message } " ,
567
+ {
568
+ "operation-id" : command_id .to_sea_statement_id (),
569
+ "diagnostic-info" : None ,
570
+ },
571
+ )
518
572
519
- return self .get_execution_result (command_id , cursor )
573
+ return self .get_execution_result (command_id , cursor )
574
+
575
+ except Exception as e :
576
+ # Map exceptions to match Thrift behavior
577
+ from databricks .sql .exc import DatabaseError , OperationalError , RequestError
578
+
579
+ if isinstance (e , (DatabaseError , OperationalError , RequestError )):
580
+ # Pass through these exceptions as they're already properly typed
581
+ raise
582
+ elif "syntax error" in str (e ).lower ():
583
+ # Syntax errors
584
+ raise DatabaseError (f"Syntax error in SQL statement: { str (e )} " )
585
+ elif "permission denied" in str (e ).lower ():
586
+ # Permission errors
587
+ raise OperationalError (f"Permission denied: { str (e )} " )
588
+ elif "database" in str (e ).lower () and "not found" in str (e ).lower ():
589
+ # Database not found errors
590
+ raise DatabaseError (f"Database not found: { str (e )} " )
591
+ elif "table" in str (e ).lower () and "not found" in str (e ).lower ():
592
+ # Table not found errors
593
+ raise DatabaseError (f"Table not found: { str (e )} " )
594
+ else :
595
+ # Generic operational errors
596
+ raise OperationalError (f"Error executing command: { str (e )} " )
520
597
521
598
def cancel_command (self , command_id : CommandId ) -> None :
522
599
"""
0 commit comments