Skip to content

Commit f971360

Browse files
committed
tests
Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent a13155b commit f971360

File tree

1 file changed

+107
-0
lines changed

1 file changed

+107
-0
lines changed

tests/unit/test_telemetry_retry.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import pytest
2+
from unittest.mock import patch, MagicMock
3+
import io
4+
import time
5+
6+
from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
7+
from databricks.sql.auth.retry import DatabricksRetryPolicy
8+
9+
# Dotted path passed to `unittest.mock.patch` by the tests in this module.
# Replacing this pool method lets each test supply the mock connection built by
# `create_mock_conn` as its return value, so the test controls what
# `getresponse()` yields.
PATCH_TARGET = 'urllib3.connectionpool.HTTPSConnectionPool._get_conn'
10+
11+
def create_mock_conn(responses):
    """Create a mock connection whose ``getresponse()`` yields canned responses.

    Args:
        responses: Iterable of dicts, consumed in order, one per response.
            Recognised keys are ``"status"`` (``None`` when absent),
            ``"headers"`` (``{}`` when absent) and ``"body"`` as bytes
            (``b'{}'`` when absent).

    Returns:
        A ``MagicMock`` whose ``getresponse`` has the prepared responses as
        its ``side_effect``, so successive calls return successive responses.
    """

    def build_response(spec):
        """Build one mock HTTP response from a single response dict."""
        response = MagicMock()
        response.status = spec.get("status")
        response.headers = spec.get("headers", {})
        response.fp = io.BytesIO(spec.get("body", b'{}'))

        # `release` must close *this* response's stream. Defining it inside a
        # per-response function scope avoids the late-binding closure trap of
        # defining it directly in a loop, where every callback would end up
        # closing the stream of the last response built.
        def release():
            response.fp.close()

        response.release_conn = release
        return response

    mock_conn = MagicMock()
    mock_conn.getresponse.side_effect = [build_response(spec) for spec in responses]
    return mock_conn
27+
28+
class TestTelemetryClientRetries:
    """Retry behaviour of the telemetry client's HTTP layer against mocked connections."""

    @staticmethod
    def _reset_factory():
        """Clear the class-level state this module touches on TelemetryClientFactory."""
        TelemetryClientFactory._initialized = False
        TelemetryClientFactory._clients = {}
        TelemetryClientFactory._executor = None

    @pytest.fixture(autouse=True)
    def setup_and_teardown(self):
        """Reset the factory before each test; shut its executor down and reset again after."""
        self._reset_factory()
        yield
        if TelemetryClientFactory._executor:
            # Wait for queued work so it cannot run after the patch is gone.
            TelemetryClientFactory._executor.shutdown(wait=True)
        self._reset_factory()

    def get_client(self, session_id, num_retries=3):
        """
        Configures a client with a specific number of retries.

        Initializes a telemetry client for ``session_id`` through the factory,
        then replaces the retry policy on its "https://" adapter with one that
        uses short delays and ``num_retries`` as the attempt limit.

        Returns the configured client.
        """
        TelemetryClientFactory.initialize_telemetry_client(
            telemetry_enabled=True,
            session_id_hex=session_id,
            auth_provider=None,
            host_url="test.databricks.com",
        )
        client = TelemetryClientFactory.get_telemetry_client(session_id)

        retry_policy = DatabricksRetryPolicy(
            delay_min=0.01,
            delay_max=0.02,
            stop_after_attempts_duration=2.0,
            stop_after_attempts_count=num_retries,
            delay_default=0.1,
            force_dangerous_codes=[],
            urllib3_kwargs={'total': num_retries},
        )
        adapter = client._http_client.session.adapters.get("https://")
        adapter.max_retries = retry_policy
        return client

    @pytest.mark.parametrize(
        "status_code, description",
        [
            (401, "Unauthorized"),
            (403, "Forbidden"),
            (501, "Not Implemented"),
            (200, "Success"),
        ],
    )
    def test_non_retryable_status_codes_are_not_retried(self, status_code, description):
        """
        Verifies that terminal error codes (401, 403, 501) and success codes (200) are not retried.
        """
        # Use the status code in the session ID for easier debugging if it fails
        client = self.get_client(f"session-{status_code}")
        mock_responses = [{"status": status_code}]

        with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn:
            client.export_failure_log("TestError", "Test message")
            TelemetryClientFactory.close(client._session_id_hex)

        mock_get_conn.return_value.getresponse.assert_called_once()

    def test_exceeds_retry_count_limit(self):
        """
        Verifies that the client retries up to the specified number of times before giving up.
        Verifies that the client respects the Retry-After header and retries on 429, 502, 503.
        """
        num_retries = 3
        expected_total_calls = num_retries + 1
        retry_after = 1
        client = self.get_client("session-exceed-limit", num_retries=num_retries)
        mock_responses = [
            {"status": 503, "headers": {"Retry-After": str(retry_after)}},
            {"status": 429},
            {"status": 502},
            {"status": 503},
        ]

        with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn:
            # Monotonic clock: the wall clock may be adjusted mid-test, which
            # would make the elapsed-time assertion below unreliable.
            start_time = time.monotonic()
            client.export_failure_log("TestError", "Test message")
            TelemetryClientFactory.close(client._session_id_hex)
            end_time = time.monotonic()

        assert mock_get_conn.return_value.getresponse.call_count == expected_total_calls
        assert end_time - start_time > retry_after

0 commit comments

Comments
 (0)