Skip to content

Commit 4b6e331

Browse files
committed
retry
Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent 380b0b9 commit 4b6e331

File tree

3 files changed

+76
-12
lines changed

3 files changed

+76
-12
lines changed

src/databricks/sql/exc.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import json
22
import logging
33

4-
from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
5-
64
logger = logging.getLogger(__name__)
75

86
### PEP-249 Mandated ###
@@ -21,11 +19,11 @@ def __init__(
2119
self.context = context or {}
2220

2321
error_name = self.__class__.__name__
24-
if session_id_hex:
25-
telemetry_client = TelemetryClientFactory.get_telemetry_client(
26-
session_id_hex
27-
)
28-
telemetry_client.export_failure_log(error_name, self.message)
22+
23+
from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
24+
25+
telemetry_client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
26+
telemetry_client.export_failure_log(error_name, self.message)
2927

3028
def __str__(self):
3129
return self.message

src/databricks/sql/telemetry/telemetry_client.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
DatabricksOAuthProvider,
2323
ExternalAuthProvider,
2424
)
25+
from requests.adapters import HTTPAdapter
26+
from databricks.sql.auth.retry import DatabricksRetryPolicy, CommandType
2527
import sys
2628
import platform
2729
import uuid
@@ -31,6 +33,24 @@
3133
logger = logging.getLogger(__name__)
3234

3335

36+
class TelemetryHTTPAdapter(HTTPAdapter):
37+
"""
38+
Custom HTTP adapter to prepare our DatabricksRetryPolicy before each request.
39+
This ensures the retry timer is started and the command type is set correctly,
40+
allowing the policy to manage its state for the duration of the request retries.
41+
"""
42+
43+
def send(self, request, **kwargs):
44+
# The DatabricksRetryPolicy needs state set before the first attempt.
45+
if isinstance(self.max_retries, DatabricksRetryPolicy):
46+
# Telemetry requests are idempotent and safe to retry. We use CommandType.OTHER
47+
# to signal this to the retry policy, bypassing stricter rules for commands
48+
# like ExecuteStatement.
49+
self.max_retries.command_type = CommandType.OTHER
50+
self.max_retries.start_retry_timer()
51+
return super().send(request, **kwargs)
52+
53+
3454
class TelemetryHelper:
3555
"""Helper class for getting telemetry related information."""
3656

@@ -146,6 +166,11 @@ class TelemetryClient(BaseTelemetryClient):
146166
It uses a thread pool to handle asynchronous operations, that it gets from the TelemetryClientFactory.
147167
"""
148168

169+
TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT = 3
170+
TELEMETRY_RETRY_DELAY_MIN = 0.5 # seconds
171+
TELEMETRY_RETRY_DELAY_MAX = 5.0 # seconds
172+
TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION = 30.0
173+
149174
# Telemetry endpoint paths
150175
TELEMETRY_AUTHENTICATED_PATH = "/telemetry-ext"
151176
TELEMETRY_UNAUTHENTICATED_PATH = "/telemetry-unauth"
@@ -170,6 +195,18 @@ def __init__(
170195
self._host_url = host_url
171196
self._executor = executor
172197

198+
self._telemetry_retry_policy = DatabricksRetryPolicy(
199+
delay_min=self.TELEMETRY_RETRY_DELAY_MIN,
200+
delay_max=self.TELEMETRY_RETRY_DELAY_MAX,
201+
stop_after_attempts_count=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT,
202+
stop_after_attempts_duration=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION,
203+
delay_default=1.0, # Not directly used by telemetry, but required by constructor
204+
force_dangerous_codes=[], # Telemetry doesn't have "dangerous" codes
205+
)
206+
self._session = requests.Session()
207+
adapter = TelemetryHTTPAdapter(max_retries=self._telemetry_retry_policy)
208+
self._session.mount("https://", adapter)
209+
173210
def _export_event(self, event):
174211
"""Add an event to the batch queue and flush if batch is full"""
175212
logger.debug("Exporting event for connection %s", self._session_id_hex)
@@ -215,7 +252,7 @@ def _send_telemetry(self, events):
215252
try:
216253
logger.debug("Submitting telemetry request to thread pool")
217254
future = self._executor.submit(
218-
requests.post,
255+
self._session.post,
219256
url,
220257
data=json.dumps(request),
221258
headers=headers,
@@ -303,6 +340,7 @@ def close(self):
303340
"""Flush remaining events before closing"""
304341
logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex)
305342
self._flush()
343+
self._session.close()
306344

307345

308346
class TelemetryClientFactory:

tests/unit/test_telemetry.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def test_export_event(self, telemetry_client_setup):
198198
client._flush.assert_called_once()
199199
assert len(client._events_batch) == 10
200200

201-
@patch("requests.post")
201+
@patch("requests.Session.post")
202202
def test_send_telemetry_authenticated(self, mock_post, telemetry_client_setup):
203203
"""Test sending telemetry to the server with authentication."""
204204
client = telemetry_client_setup["client"]
@@ -212,12 +212,12 @@ def test_send_telemetry_authenticated(self, mock_post, telemetry_client_setup):
212212

213213
executor.submit.assert_called_once()
214214
args, kwargs = executor.submit.call_args
215-
assert args[0] == requests.post
215+
assert args[0] == client._session.post
216216
assert kwargs["timeout"] == 10
217217
assert "Authorization" in kwargs["headers"]
218218
assert kwargs["headers"]["Authorization"] == "Bearer test-token"
219219

220-
@patch("requests.post")
220+
@patch("requests.Session.post")
221221
def test_send_telemetry_unauthenticated(self, mock_post, telemetry_client_setup):
222222
"""Test sending telemetry to the server without authentication."""
223223
host_url = telemetry_client_setup["host_url"]
@@ -239,7 +239,7 @@ def test_send_telemetry_unauthenticated(self, mock_post, telemetry_client_setup)
239239

240240
executor.submit.assert_called_once()
241241
args, kwargs = executor.submit.call_args
242-
assert args[0] == requests.post
242+
assert args[0] == unauthenticated_client._session.post
243243
assert kwargs["timeout"] == 10
244244
assert "Authorization" not in kwargs["headers"] # No auth header
245245
assert kwargs["headers"]["Accept"] == "application/json"
@@ -331,6 +331,34 @@ class TestBaseClient(BaseTelemetryClient):
331331
with pytest.raises(TypeError):
332332
TestBaseClient() # Can't instantiate abstract class
333333

334+
def test_telemetry_http_adapter_retry_policy(self, telemetry_client_setup):
335+
"""Test that TelemetryHTTPAdapter properly configures DatabricksRetryPolicy."""
336+
from databricks.sql.telemetry.telemetry_client import TelemetryHTTPAdapter
337+
from databricks.sql.auth.retry import DatabricksRetryPolicy, CommandType
338+
339+
client = telemetry_client_setup["client"]
340+
341+
# Verify that the session has the TelemetryHTTPAdapter mounted
342+
adapter = client._session.adapters.get("https://")
343+
assert isinstance(adapter, TelemetryHTTPAdapter)
344+
assert isinstance(adapter.max_retries, DatabricksRetryPolicy)
345+
346+
# Verify that the retry policy has the correct configuration
347+
retry_policy = adapter.max_retries
348+
assert retry_policy.delay_min == client.TELEMETRY_RETRY_DELAY_MIN
349+
assert retry_policy.delay_max == client.TELEMETRY_RETRY_DELAY_MAX
350+
assert retry_policy.stop_after_attempts_count == client.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT
351+
assert retry_policy.stop_after_attempts_duration == client.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION
352+
353+
# Test that the adapter's send method would properly configure the retry policy
354+
# by directly testing the logic that sets command_type and starts the timer
355+
if isinstance(adapter.max_retries, DatabricksRetryPolicy):
356+
adapter.max_retries.command_type = CommandType.OTHER
357+
adapter.max_retries.start_retry_timer()
358+
359+
# Verify that the retry policy was configured correctly
360+
assert retry_policy.command_type == CommandType.OTHER
361+
334362

335363
class TestTelemetryHelper:
336364
"""Tests for the TelemetryHelper class."""

0 commit comments

Comments
 (0)