Skip to content

Commit c200ad0

Browse files
preliminary reetries
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 947fcbf commit c200ad0

File tree

4 files changed

+348
-43
lines changed

4 files changed

+348
-43
lines changed

src/databricks/sql/auth/thrift_http_client.py

Lines changed: 75 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -352,13 +352,21 @@ def _check_rest_response_for_error(
352352
"""
353353
if status_code >= 400:
354354
error_message = f"REST HTTP request failed with status {status_code}"
355+
error_code = None
355356

356357
# Try to extract error details from JSON response
357358
if response_data:
358359
try:
359360
error_details = json.loads(response_data.decode("utf-8"))
360-
if isinstance(error_details, dict) and "message" in error_details:
361-
error_message = f"{error_message}: {error_details['message']}"
361+
if isinstance(error_details, dict):
362+
if "message" in error_details:
363+
error_message = (
364+
f"{error_message}: {error_details['message']}"
365+
)
366+
if "error_code" in error_details:
367+
error_code = error_details["error_code"]
368+
elif "errorCode" in error_details:
369+
error_code = error_details["errorCode"]
362370
logger.error(
363371
f"Request failed (status {status_code}): {error_details}"
364372
)
@@ -369,6 +377,69 @@ def _check_rest_response_for_error(
369377
else:
370378
logger.error(f"Request failed (status {status_code}): No response data")
371379

372-
from databricks.sql.exc import RequestError
380+
from databricks.sql.exc import (
381+
RequestError,
382+
OperationalError,
383+
DatabaseError,
384+
SessionAlreadyClosedError,
385+
CursorAlreadyClosedError,
386+
)
373387

374-
raise RequestError(error_message)
388+
# Map status codes to appropriate exceptions to match Thrift behavior
389+
if status_code == 429:
390+
# Rate limiting errors
391+
retry_after = None
392+
if self.headers and "Retry-After" in self.headers:
393+
retry_after = self.headers["Retry-After"]
394+
395+
rate_limit_msg = f"Maximum rate has been exceeded. Please reduce the rate of requests and try again"
396+
if retry_after:
397+
rate_limit_msg += f" after {retry_after} seconds."
398+
raise RequestError(rate_limit_msg)
399+
400+
elif status_code == 503:
401+
# Service unavailable errors
402+
raise OperationalError(
403+
"TEMPORARILY_UNAVAILABLE: Service temporarily unavailable"
404+
)
405+
406+
elif status_code == 404:
407+
# Not found errors - could be session or operation already closed
408+
if error_message and "session" in error_message.lower():
409+
raise SessionAlreadyClosedError(
410+
"Session was closed by a prior request"
411+
)
412+
elif error_message and (
413+
"operation" in error_message.lower()
414+
or "statement" in error_message.lower()
415+
):
416+
raise CursorAlreadyClosedError(
417+
"Operation was canceled by a prior request"
418+
)
419+
else:
420+
raise RequestError(error_message)
421+
422+
elif status_code == 401:
423+
# Authentication errors
424+
raise OperationalError(
425+
"Authentication failed. Please check your credentials."
426+
)
427+
428+
elif status_code == 403:
429+
# Permission errors
430+
raise OperationalError(
431+
"Permission denied. You do not have access to this resource."
432+
)
433+
434+
elif status_code == 400:
435+
# Bad request errors - often syntax errors
436+
if error_message and "syntax" in error_message.lower():
437+
raise DatabaseError(
438+
f"Syntax error in SQL statement: {error_message}"
439+
)
440+
else:
441+
raise RequestError(error_message)
442+
443+
else:
444+
# Generic errors
445+
raise RequestError(error_message)

src/databricks/sql/backend/sea/backend.py

Lines changed: 113 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,40 @@ def __init__(
130130
# Extract warehouse ID from http_path
131131
self.warehouse_id = self._extract_warehouse_id(http_path)
132132

133-
# Initialize ThriftHttpClient
133+
# Extract retry policy parameters
134+
retry_policy = kwargs.get("_retry_policy", None)
135+
retry_stop_after_attempts_count = kwargs.get(
136+
"_retry_stop_after_attempts_count", 30
137+
)
138+
retry_stop_after_attempts_duration = kwargs.get(
139+
"_retry_stop_after_attempts_duration", 600
140+
)
141+
retry_delay_min = kwargs.get("_retry_delay_min", 1)
142+
retry_delay_max = kwargs.get("_retry_delay_max", 60)
143+
retry_delay_default = kwargs.get("_retry_delay_default", 5)
144+
retry_dangerous_codes = kwargs.get("_retry_dangerous_codes", [])
145+
146+
# Create retry policy if not provided
147+
if not retry_policy:
148+
from databricks.sql.auth.retry import DatabricksRetryPolicy
149+
150+
retry_policy = DatabricksRetryPolicy(
151+
delay_min=retry_delay_min,
152+
delay_max=retry_delay_max,
153+
stop_after_attempts_count=retry_stop_after_attempts_count,
154+
stop_after_attempts_duration=retry_stop_after_attempts_duration,
155+
delay_default=retry_delay_default,
156+
force_dangerous_codes=retry_dangerous_codes,
157+
)
158+
159+
# Initialize ThriftHttpClient with retry policy
134160
thrift_client = THttpClient(
135161
auth_provider=auth_provider,
136162
uri_or_host=f"https://{server_hostname}:{port}",
137163
path=http_path,
138164
ssl_options=ssl_options,
139165
max_connections=kwargs.get("max_connections", 1),
140-
retry_policy=kwargs.get("_retry_stop_after_attempts_count", 30),
166+
retry_policy=retry_policy,
141167
)
142168

143169
# Set custom headers
@@ -475,48 +501,99 @@ def execute_command(
475501
result_compression=result_compression,
476502
)
477503

478-
response_data = self.http_client.post(
479-
path=self.STATEMENT_PATH, data=request.to_dict()
480-
)
481-
response = ExecuteStatementResponse.from_dict(response_data)
482-
statement_id = response.statement_id
483-
if not statement_id:
484-
raise ServerOperationError(
485-
"Failed to execute command: No statement ID returned",
486-
{
487-
"operation-id": None,
488-
"diagnostic-info": None,
489-
},
504+
try:
505+
response_data = self.http_client.post(
506+
path=self.STATEMENT_PATH, data=request.to_dict()
490507
)
508+
response = ExecuteStatementResponse.from_dict(response_data)
509+
statement_id = response.statement_id
510+
511+
if not statement_id:
512+
raise ServerOperationError(
513+
"Failed to execute command: No statement ID returned",
514+
{
515+
"operation-id": None,
516+
"diagnostic-info": None,
517+
},
518+
)
491519

492-
command_id = CommandId.from_sea_statement_id(statement_id)
520+
command_id = CommandId.from_sea_statement_id(statement_id)
493521

494-
# Store the command ID in the cursor
495-
cursor.active_command_id = command_id
522+
# Store the command ID in the cursor
523+
cursor.active_command_id = command_id
496524

497-
# If async operation, return and let the client poll for results
498-
if async_op:
499-
return None
525+
# If async operation, return and let the client poll for results
526+
if async_op:
527+
return None
500528

501-
# For synchronous operation, wait for the statement to complete
502-
status = response.status
503-
state = status.state
529+
# For synchronous operation, wait for the statement to complete
530+
status = response.status
531+
state = status.state
504532

505-
# Keep polling until we reach a terminal state
506-
while state in [CommandState.PENDING, CommandState.RUNNING]:
507-
time.sleep(0.5) # add a small delay to avoid excessive API calls
508-
state = self.get_query_state(command_id)
533+
# Keep polling until we reach a terminal state
534+
while state in [CommandState.PENDING, CommandState.RUNNING]:
535+
time.sleep(0.5) # add a small delay to avoid excessive API calls
536+
state = self.get_query_state(command_id)
509537

510-
if state != CommandState.SUCCEEDED:
511-
raise ServerOperationError(
512-
f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}",
513-
{
514-
"operation-id": command_id.to_sea_statement_id(),
515-
"diagnostic-info": None,
516-
},
517-
)
538+
if state != CommandState.SUCCEEDED:
539+
error_message = (
540+
status.error.message if status.error else "Unknown error"
541+
)
542+
error_code = status.error.error_code if status.error else None
543+
544+
# Map error codes to appropriate exceptions to match Thrift behavior
545+
from databricks.sql.exc import (
546+
DatabaseError,
547+
ProgrammingError,
548+
OperationalError,
549+
)
550+
551+
if (
552+
error_code == "SYNTAX_ERROR"
553+
or "syntax error" in error_message.lower()
554+
):
555+
raise DatabaseError(
556+
f"Syntax error in SQL statement: {error_message}"
557+
)
558+
elif error_code == "TEMPORARILY_UNAVAILABLE":
559+
raise OperationalError(
560+
f"Service temporarily unavailable: {error_message}"
561+
)
562+
elif error_code == "PERMISSION_DENIED":
563+
raise OperationalError(f"Permission denied: {error_message}")
564+
else:
565+
raise ServerOperationError(
566+
f"Statement execution failed: {error_message}",
567+
{
568+
"operation-id": command_id.to_sea_statement_id(),
569+
"diagnostic-info": None,
570+
},
571+
)
518572

519-
return self.get_execution_result(command_id, cursor)
573+
return self.get_execution_result(command_id, cursor)
574+
575+
except Exception as e:
576+
# Map exceptions to match Thrift behavior
577+
from databricks.sql.exc import DatabaseError, OperationalError, RequestError
578+
579+
if isinstance(e, (DatabaseError, OperationalError, RequestError)):
580+
# Pass through these exceptions as they're already properly typed
581+
raise
582+
elif "syntax error" in str(e).lower():
583+
# Syntax errors
584+
raise DatabaseError(f"Syntax error in SQL statement: {str(e)}")
585+
elif "permission denied" in str(e).lower():
586+
# Permission errors
587+
raise OperationalError(f"Permission denied: {str(e)}")
588+
elif "database" in str(e).lower() and "not found" in str(e).lower():
589+
# Database not found errors
590+
raise DatabaseError(f"Database not found: {str(e)}")
591+
elif "table" in str(e).lower() and "not found" in str(e).lower():
592+
# Table not found errors
593+
raise DatabaseError(f"Table not found: {str(e)}")
594+
else:
595+
# Generic operational errors
596+
raise OperationalError(f"Error executing command: {str(e)}")
520597

521598
def cancel_command(self, command_id: CommandId) -> None:
522599
"""

src/databricks/sql/backend/sea/utils/http_client_adapter.py

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import Dict, Optional, Any
1010

1111
from databricks.sql.auth.thrift_http_client import THttpClient
12+
from databricks.sql.auth.retry import CommandType
1213

1314
logger = logging.getLogger(__name__)
1415

@@ -36,14 +37,49 @@ def __init__(
3637
"""
3738
self.thrift_client = thrift_client
3839

40+
def _determine_command_type(
41+
self, path: str, method: str, data: Optional[Dict[str, Any]] = None
42+
) -> CommandType:
43+
"""
44+
Determine the CommandType based on the request path and method.
45+
46+
Args:
47+
path: API endpoint path
48+
method: HTTP method (GET, POST, DELETE)
49+
data: Request payload data
50+
51+
Returns:
52+
CommandType: The appropriate CommandType enum value
53+
"""
54+
# Extract the base path component (e.g., "sessions", "statements")
55+
path_parts = path.strip("/").split("/")
56+
base_path = path_parts[0] if path_parts else ""
57+
58+
# Check for specific operations based on path and method
59+
if "statements" in path:
60+
if method == "POST" and any(part == "cancel" for part in path_parts):
61+
return CommandType.CLOSE_OPERATION
62+
elif method == "POST" and not any(part == "cancel" for part in path_parts):
63+
return CommandType.EXECUTE_STATEMENT
64+
elif method == "GET":
65+
return CommandType.GET_OPERATION_STATUS
66+
elif method == "DELETE":
67+
return CommandType.CLOSE_OPERATION
68+
elif "sessions" in path:
69+
if method == "DELETE":
70+
return CommandType.CLOSE_SESSION
71+
72+
# Default for any other operations
73+
return CommandType.OTHER
74+
3975
def get(
4076
self,
4177
path: str,
4278
params: Optional[Dict[str, Any]] = None,
4379
headers: Optional[Dict[str, str]] = None,
4480
) -> Dict[str, Any]:
4581
"""
46-
Convenience method for GET requests.
82+
Convenience method for GET requests with retry support.
4783
4884
Args:
4985
path: API endpoint path
@@ -53,6 +89,10 @@ def get(
5389
Returns:
5490
Response data parsed from JSON
5591
"""
92+
command_type = self._determine_command_type(path, "GET")
93+
self.thrift_client.set_retry_command_type(command_type)
94+
self.thrift_client.startRetryTimer()
95+
5696
return self.thrift_client.make_rest_request(
5797
"GET", path, params=params, headers=headers
5898
)
@@ -65,7 +105,7 @@ def post(
65105
headers: Optional[Dict[str, str]] = None,
66106
) -> Dict[str, Any]:
67107
"""
68-
Convenience method for POST requests.
108+
Convenience method for POST requests with retry support.
69109
70110
Args:
71111
path: API endpoint path
@@ -76,6 +116,10 @@ def post(
76116
Returns:
77117
Response data parsed from JSON
78118
"""
119+
command_type = self._determine_command_type(path, "POST", data)
120+
self.thrift_client.set_retry_command_type(command_type)
121+
self.thrift_client.startRetryTimer()
122+
79123
return self.thrift_client.make_rest_request(
80124
"POST", path, data=data, params=params, headers=headers
81125
)
@@ -88,7 +132,7 @@ def delete(
88132
headers: Optional[Dict[str, str]] = None,
89133
) -> Dict[str, Any]:
90134
"""
91-
Convenience method for DELETE requests.
135+
Convenience method for DELETE requests with retry support.
92136
93137
Args:
94138
path: API endpoint path
@@ -99,6 +143,10 @@ def delete(
99143
Returns:
100144
Response data parsed from JSON
101145
"""
146+
command_type = self._determine_command_type(path, "DELETE", data)
147+
self.thrift_client.set_retry_command_type(command_type)
148+
self.thrift_client.startRetryTimer()
149+
102150
return self.thrift_client.make_rest_request(
103151
"DELETE", path, data=data, params=params, headers=headers
104152
)

0 commit comments

Comments
 (0)