Skip to content

Commit 08827ef

Browse files
add minimal retry func
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 9ba6cb3 commit 08827ef

File tree

2 files changed

+118
-40
lines changed

2 files changed

+118
-40
lines changed

src/databricks/sql/backend/sea/backend.py

Lines changed: 66 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -130,14 +130,35 @@ def __init__(
130130
# Extract warehouse ID from http_path
131131
self.warehouse_id = self._extract_warehouse_id(http_path)
132132

133-
# Initialize ThriftHttpClient
133+
# Extract retry policy parameters
134+
retry_policy = kwargs.get("_retry_policy", None)
135+
retry_stop_after_attempts_count = kwargs.get("_retry_stop_after_attempts_count", 30)
136+
retry_stop_after_attempts_duration = kwargs.get("_retry_stop_after_attempts_duration", 600)
137+
retry_delay_min = kwargs.get("_retry_delay_min", 1)
138+
retry_delay_max = kwargs.get("_retry_delay_max", 60)
139+
retry_delay_default = kwargs.get("_retry_delay_default", 5)
140+
retry_dangerous_codes = kwargs.get("_retry_dangerous_codes", [])
141+
142+
# Create retry policy if not provided
143+
if not retry_policy:
144+
from databricks.sql.auth.retry import DatabricksRetryPolicy
145+
retry_policy = DatabricksRetryPolicy(
146+
delay_min=retry_delay_min,
147+
delay_max=retry_delay_max,
148+
stop_after_attempts_count=retry_stop_after_attempts_count,
149+
stop_after_attempts_duration=retry_stop_after_attempts_duration,
150+
delay_default=retry_delay_default,
151+
force_dangerous_codes=retry_dangerous_codes,
152+
)
153+
154+
# Initialize ThriftHttpClient with retry policy
134155
thrift_client = THttpClient(
135156
auth_provider=auth_provider,
136157
uri_or_host=f"https://{server_hostname}:{port}",
137158
path=http_path,
138159
ssl_options=ssl_options,
139160
max_connections=kwargs.get("max_connections", 1),
140-
retry_policy=kwargs.get("_retry_stop_after_attempts_count", 30),
161+
retry_policy=retry_policy,
141162
)
142163

143164
# Set custom headers
@@ -394,7 +415,7 @@ def _results_message_to_execute_response(self, sea_response, command_id):
394415
description=description,
395416
has_been_closed_server_side=False,
396417
lz4_compressed=lz4_compressed,
397-
is_staging_operation=False,
418+
is_staging_operation=manifest_obj.is_volume_operation,
398419
arrow_schema_bytes=None,
399420
result_format=manifest_obj.format,
400421
)
@@ -475,48 +496,56 @@ def execute_command(
475496
result_compression=result_compression,
476497
)
477498

478-
response_data = self.http_client.post(
479-
path=self.STATEMENT_PATH, data=request.to_dict()
480-
)
481-
response = ExecuteStatementResponse.from_dict(response_data)
482-
statement_id = response.statement_id
483-
if not statement_id:
484-
raise ServerOperationError(
485-
"Failed to execute command: No statement ID returned",
486-
{
487-
"operation-id": None,
488-
"diagnostic-info": None,
489-
},
499+
try:
500+
response_data = self.http_client.post(
501+
path=self.STATEMENT_PATH, data=request.to_dict()
490502
)
503+
response = ExecuteStatementResponse.from_dict(response_data)
504+
statement_id = response.statement_id
505+
if not statement_id:
506+
raise ServerOperationError(
507+
"Failed to execute command: No statement ID returned",
508+
{
509+
"operation-id": None,
510+
"diagnostic-info": None,
511+
},
512+
)
491513

492-
command_id = CommandId.from_sea_statement_id(statement_id)
514+
command_id = CommandId.from_sea_statement_id(statement_id)
493515

494-
# Store the command ID in the cursor
495-
cursor.active_command_id = command_id
516+
# Store the command ID in the cursor
517+
cursor.active_command_id = command_id
496518

497-
# If async operation, return and let the client poll for results
498-
if async_op:
499-
return None
519+
# If async operation, return and let the client poll for results
520+
if async_op:
521+
return None
500522

501-
# For synchronous operation, wait for the statement to complete
502-
status = response.status
503-
state = status.state
523+
# For synchronous operation, wait for the statement to complete
524+
status = response.status
525+
state = status.state
504526

505-
# Keep polling until we reach a terminal state
506-
while state in [CommandState.PENDING, CommandState.RUNNING]:
507-
time.sleep(0.5) # add a small delay to avoid excessive API calls
508-
state = self.get_query_state(command_id)
527+
# Keep polling until we reach a terminal state
528+
while state in [CommandState.PENDING, CommandState.RUNNING]:
529+
time.sleep(0.5) # add a small delay to avoid excessive API calls
530+
state = self.get_query_state(command_id)
509531

510-
if state != CommandState.SUCCEEDED:
511-
raise ServerOperationError(
512-
f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}",
513-
{
514-
"operation-id": command_id.to_sea_statement_id(),
515-
"diagnostic-info": None,
516-
},
517-
)
532+
if state != CommandState.SUCCEEDED:
533+
raise ServerOperationError(
534+
f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}",
535+
{
536+
"operation-id": command_id.to_sea_statement_id(),
537+
"diagnostic-info": None,
538+
},
539+
)
518540

519-
return self.get_execution_result(command_id, cursor)
541+
return self.get_execution_result(command_id, cursor)
542+
except Exception as e:
543+
# Map exceptions to match Thrift behavior
544+
from databricks.sql.exc import RequestError, OperationalError
545+
if isinstance(e, (ServerOperationError, RequestError)):
546+
raise
547+
else:
548+
raise OperationalError(f"Error executing command: {str(e)}")
520549

521550
def cancel_command(self, command_id: CommandId) -> None:
522551
"""

src/databricks/sql/backend/sea/utils/http_client_adapter.py

Lines changed: 52 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
from typing import Dict, Optional, Any
1010

1111
from databricks.sql.auth.thrift_http_client import THttpClient
12+
from databricks.sql.auth.retry import CommandType
1213

1314
logger = logging.getLogger(__name__)
1415

@@ -36,14 +37,50 @@ def __init__(
3637
"""
3738
self.thrift_client = thrift_client
3839

40+
def _determine_command_type(self, path: str, method: str, data: Optional[Dict[str, Any]] = None) -> CommandType:
    """
    Determine the CommandType based on the request path and method.

    Classification uses substring checks on ``path``; the "statements"
    check takes precedence over the "sessions" check. Any combination
    that is not explicitly recognised maps to ``CommandType.OTHER``.

    Args:
        path: API endpoint path
        method: HTTP method (GET, POST, DELETE)
        data: Request payload data. Not consulted by the current
            classification; accepted so callers can pass it through.

    Returns:
        CommandType: The appropriate CommandType enum value
    """
    if "statements" in path:
        if method == "POST":
            # A POST whose path contains "cancel" is classified as closing
            # the operation; every other statement POST is an execution.
            if "cancel" in path:
                return CommandType.CLOSE_OPERATION
            return CommandType.EXECUTE_STATEMENT
        if method == "GET":
            return CommandType.GET_OPERATION_STATUS
        if method == "DELETE":
            return CommandType.CLOSE_OPERATION
    elif "sessions" in path and method == "DELETE":
        return CommandType.CLOSE_SESSION

    # Session creation (POST on a sessions path) and any other
    # unrecognised path/method combination.
    return CommandType.OTHER
75+
3976
def get(
4077
self,
4178
path: str,
4279
params: Optional[Dict[str, Any]] = None,
4380
headers: Optional[Dict[str, str]] = None,
4481
) -> Dict[str, Any]:
4582
"""
46-
Convenience method for GET requests.
83+
Convenience method for GET requests with retry support.
4784
4885
Args:
4986
path: API endpoint path
@@ -53,6 +90,10 @@ def get(
5390
Returns:
5491
Response data parsed from JSON
5592
"""
93+
command_type = self._determine_command_type(path, "GET")
94+
self.thrift_client.set_retry_command_type(command_type)
95+
self.thrift_client.startRetryTimer()
96+
5697
return self.thrift_client.make_rest_request(
5798
"GET", path, params=params, headers=headers
5899
)
@@ -65,7 +106,7 @@ def post(
65106
headers: Optional[Dict[str, str]] = None,
66107
) -> Dict[str, Any]:
67108
"""
68-
Convenience method for POST requests.
109+
Convenience method for POST requests with retry support.
69110
70111
Args:
71112
path: API endpoint path
@@ -76,6 +117,10 @@ def post(
76117
Returns:
77118
Response data parsed from JSON
78119
"""
120+
command_type = self._determine_command_type(path, "POST", data)
121+
self.thrift_client.set_retry_command_type(command_type)
122+
self.thrift_client.startRetryTimer()
123+
79124
return self.thrift_client.make_rest_request(
80125
"POST", path, data=data, params=params, headers=headers
81126
)
@@ -88,7 +133,7 @@ def delete(
88133
headers: Optional[Dict[str, str]] = None,
89134
) -> Dict[str, Any]:
90135
"""
91-
Convenience method for DELETE requests.
136+
Convenience method for DELETE requests with retry support.
92137
93138
Args:
94139
path: API endpoint path
@@ -99,6 +144,10 @@ def delete(
99144
Returns:
100145
Response data parsed from JSON
101146
"""
147+
command_type = self._determine_command_type(path, "DELETE", data)
148+
self.thrift_client.set_retry_command_type(command_type)
149+
self.thrift_client.startRetryTimer()
150+
102151
return self.thrift_client.make_rest_request(
103152
"DELETE", path, data=data, params=params, headers=headers
104153
)

0 commit comments

Comments
 (0)