Skip to content

Commit 71d306f

Browse files
authored
Add retry mechanism to telemetry requests (#617)
* telemetry retry Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com> * shifted tests to unit test, removed unused imports Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com> * tests Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com> --------- Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent 141a004 commit 71d306f

File tree

5 files changed

+184
-7
lines changed

5 files changed

+184
-7
lines changed

src/databricks/sql/common/http.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
import threading
66
from dataclasses import dataclass
77
from contextlib import contextmanager
8-
from typing import Generator
8+
from typing import Generator, Optional
99
import logging
10+
from requests.adapters import HTTPAdapter
11+
from databricks.sql.auth.retry import DatabricksRetryPolicy, CommandType
1012

1113
logger = logging.getLogger(__name__)
1214

@@ -81,3 +83,70 @@ def execute(
8183

8284
def close(self):
8385
self.session.close()
86+
87+
88+
class TelemetryHTTPAdapter(HTTPAdapter):
89+
"""
90+
Custom HTTP adapter to prepare our DatabricksRetryPolicy before each request.
91+
This ensures the retry timer is started and the command type is set correctly,
92+
allowing the policy to manage its state for the duration of the request retries.
93+
"""
94+
95+
def send(self, request, **kwargs):
96+
self.max_retries.command_type = CommandType.OTHER
97+
self.max_retries.start_retry_timer()
98+
return super().send(request, **kwargs)
99+
100+
101+
class TelemetryHttpClient: # TODO: Unify all the http clients in the PySQL Connector
102+
"""Singleton HTTP client for sending telemetry data."""
103+
104+
_instance: Optional["TelemetryHttpClient"] = None
105+
_lock = threading.Lock()
106+
107+
TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT = 3
108+
TELEMETRY_RETRY_DELAY_MIN = 1.0
109+
TELEMETRY_RETRY_DELAY_MAX = 10.0
110+
TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION = 30.0
111+
112+
def __init__(self):
113+
"""Initializes the session and mounts the custom retry adapter."""
114+
retry_policy = DatabricksRetryPolicy(
115+
delay_min=self.TELEMETRY_RETRY_DELAY_MIN,
116+
delay_max=self.TELEMETRY_RETRY_DELAY_MAX,
117+
stop_after_attempts_count=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT,
118+
stop_after_attempts_duration=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION,
119+
delay_default=1.0,
120+
force_dangerous_codes=[],
121+
)
122+
adapter = TelemetryHTTPAdapter(max_retries=retry_policy)
123+
self.session = requests.Session()
124+
self.session.mount("https://", adapter)
125+
self.session.mount("http://", adapter)
126+
127+
@classmethod
128+
def get_instance(cls) -> "TelemetryHttpClient":
129+
"""Get the singleton instance of the TelemetryHttpClient."""
130+
if cls._instance is None:
131+
with cls._lock:
132+
if cls._instance is None:
133+
logger.debug("Initializing singleton TelemetryHttpClient")
134+
cls._instance = TelemetryHttpClient()
135+
return cls._instance
136+
137+
def post(self, url: str, **kwargs) -> requests.Response:
138+
"""
139+
Executes a POST request using the configured session.
140+
141+
This is a blocking call intended to be run in a background thread.
142+
"""
143+
logger.debug("Executing telemetry POST request to: %s", url)
144+
return self.session.post(url, **kwargs)
145+
146+
def close(self):
147+
"""Closes the underlying requests.Session."""
148+
logger.debug("Closing TelemetryHttpClient session.")
149+
self.session.close()
150+
# Clear the instance to allow for re-initialization if needed
151+
with TelemetryHttpClient._lock:
152+
TelemetryHttpClient._instance = None

src/databricks/sql/exc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
import logging
33

44
logger = logging.getLogger(__name__)
5-
from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
6-
75

86
### PEP-249 Mandated ###
97
# https://peps.python.org/pep-0249/#exceptions
@@ -22,6 +20,8 @@ def __init__(
2220

2321
error_name = self.__class__.__name__
2422
if session_id_hex:
23+
from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
24+
2525
telemetry_client = TelemetryClientFactory.get_telemetry_client(
2626
session_id_hex
2727
)

src/databricks/sql/telemetry/telemetry_client.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import threading
22
import time
3-
import requests
43
import logging
54
from concurrent.futures import ThreadPoolExecutor
65
from typing import Dict, Optional
6+
from databricks.sql.common.http import TelemetryHttpClient
77
from databricks.sql.telemetry.models.event import (
88
TelemetryEvent,
99
DriverSystemConfiguration,
@@ -159,6 +159,7 @@ def __init__(
159159
self._driver_connection_params = None
160160
self._host_url = host_url
161161
self._executor = executor
162+
self._http_client = TelemetryHttpClient.get_instance()
162163

163164
def _export_event(self, event):
164165
"""Add an event to the batch queue and flush if batch is full"""
@@ -207,7 +208,7 @@ def _send_telemetry(self, events):
207208
try:
208209
logger.debug("Submitting telemetry request to thread pool")
209210
future = self._executor.submit(
210-
requests.post,
211+
self._http_client.post,
211212
url,
212213
data=request.to_json(),
213214
headers=headers,
@@ -433,6 +434,7 @@ def close(session_id_hex):
433434
)
434435
try:
435436
TelemetryClientFactory._executor.shutdown(wait=True)
437+
TelemetryHttpClient.close()
436438
except Exception as e:
437439
logger.debug("Failed to shutdown thread pool executor: %s", e)
438440
TelemetryClientFactory._executor = None

tests/unit/test_telemetry.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import uuid
22
import pytest
3-
import requests
43
from unittest.mock import patch, MagicMock
54

65
from databricks.sql.telemetry.telemetry_client import (
@@ -90,7 +89,7 @@ def test_network_request_flow(self, mock_post, mock_telemetry_client):
9089
args, kwargs = client._executor.submit.call_args
9190

9291
# Verify correct function and URL
93-
assert args[0] == requests.post
92+
assert args[0] == client._http_client.post
9493
assert args[1] == "https://test-host.com/telemetry-ext"
9594
assert kwargs["headers"]["Authorization"] == "Bearer test-token"
9695

tests/unit/test_telemetry_retry.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import pytest
2+
from unittest.mock import patch, MagicMock
3+
import io
4+
import time
5+
6+
from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
7+
from databricks.sql.auth.retry import DatabricksRetryPolicy
8+
9+
PATCH_TARGET = 'urllib3.connectionpool.HTTPSConnectionPool._get_conn'
10+
11+
def create_mock_conn(responses):
12+
"""Creates a mock connection object whose getresponse() method yields a series of responses."""
13+
mock_conn = MagicMock()
14+
mock_http_responses = []
15+
for resp in responses:
16+
mock_http_response = MagicMock()
17+
mock_http_response.status = resp.get("status")
18+
mock_http_response.headers = resp.get("headers", {})
19+
body = resp.get("body", b'{}')
20+
mock_http_response.fp = io.BytesIO(body)
21+
def release():
22+
mock_http_response.fp.close()
23+
mock_http_response.release_conn = release
24+
mock_http_responses.append(mock_http_response)
25+
mock_conn.getresponse.side_effect = mock_http_responses
26+
return mock_conn
27+
28+
class TestTelemetryClientRetries:
29+
@pytest.fixture(autouse=True)
30+
def setup_and_teardown(self):
31+
TelemetryClientFactory._initialized = False
32+
TelemetryClientFactory._clients = {}
33+
TelemetryClientFactory._executor = None
34+
yield
35+
if TelemetryClientFactory._executor:
36+
TelemetryClientFactory._executor.shutdown(wait=True)
37+
TelemetryClientFactory._initialized = False
38+
TelemetryClientFactory._clients = {}
39+
TelemetryClientFactory._executor = None
40+
41+
def get_client(self, session_id, num_retries=3):
42+
"""
43+
Configures a client with a specific number of retries.
44+
"""
45+
TelemetryClientFactory.initialize_telemetry_client(
46+
telemetry_enabled=True,
47+
session_id_hex=session_id,
48+
auth_provider=None,
49+
host_url="test.databricks.com",
50+
)
51+
client = TelemetryClientFactory.get_telemetry_client(session_id)
52+
53+
retry_policy = DatabricksRetryPolicy(
54+
delay_min=0.01,
55+
delay_max=0.02,
56+
stop_after_attempts_duration=2.0,
57+
stop_after_attempts_count=num_retries,
58+
delay_default=0.1,
59+
force_dangerous_codes=[],
60+
urllib3_kwargs={'total': num_retries}
61+
)
62+
adapter = client._http_client.session.adapters.get("https://")
63+
adapter.max_retries = retry_policy
64+
return client
65+
66+
@pytest.mark.parametrize(
67+
"status_code, description",
68+
[
69+
(401, "Unauthorized"),
70+
(403, "Forbidden"),
71+
(501, "Not Implemented"),
72+
(200, "Success"),
73+
],
74+
)
75+
def test_non_retryable_status_codes_are_not_retried(self, status_code, description):
76+
"""
77+
Verifies that terminal error codes (401, 403, 501) and success codes (200) are not retried.
78+
"""
79+
# Use the status code in the session ID for easier debugging if it fails
80+
client = self.get_client(f"session-{status_code}")
81+
mock_responses = [{"status": status_code}]
82+
83+
with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn:
84+
client.export_failure_log("TestError", "Test message")
85+
TelemetryClientFactory.close(client._session_id_hex)
86+
87+
mock_get_conn.return_value.getresponse.assert_called_once()
88+
89+
def test_exceeds_retry_count_limit(self):
90+
"""
91+
Verifies that the client retries up to the specified number of times before giving up.
92+
Verifies that the client respects the Retry-After header and retries on 429, 502, 503.
93+
"""
94+
num_retries = 3
95+
expected_total_calls = num_retries + 1
96+
retry_after = 1
97+
client = self.get_client("session-exceed-limit", num_retries=num_retries)
98+
mock_responses = [{"status": 503, "headers": {"Retry-After": str(retry_after)}}, {"status": 429}, {"status": 502}, {"status": 503}]
99+
100+
with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn:
101+
start_time = time.time()
102+
client.export_failure_log("TestError", "Test message")
103+
TelemetryClientFactory.close(client._session_id_hex)
104+
end_time = time.time()
105+
106+
assert mock_get_conn.return_value.getresponse.call_count == expected_total_calls
107+
assert end_time - start_time > retry_after

0 commit comments

Comments
 (0)