1
+ import pytest
2
+ from unittest .mock import patch , MagicMock
3
+ import io
4
+ import time
5
+
6
+ from databricks .sql .telemetry .telemetry_client import TelemetryClientFactory
7
+ from databricks .sql .auth .retry import DatabricksRetryPolicy
8
+
9
+ PATCH_TARGET = 'urllib3.connectionpool.HTTPSConnectionPool._get_conn'
10
+
11
+ def create_mock_conn (responses ):
12
+ """Creates a mock connection object whose getresponse() method yields a series of responses."""
13
+ mock_conn = MagicMock ()
14
+ mock_http_responses = []
15
+ for resp in responses :
16
+ mock_http_response = MagicMock ()
17
+ mock_http_response .status = resp .get ("status" )
18
+ mock_http_response .headers = resp .get ("headers" , {})
19
+ body = resp .get ("body" , b'{}' )
20
+ mock_http_response .fp = io .BytesIO (body )
21
+ def release ():
22
+ mock_http_response .fp .close ()
23
+ mock_http_response .release_conn = release
24
+ mock_http_responses .append (mock_http_response )
25
+ mock_conn .getresponse .side_effect = mock_http_responses
26
+ return mock_conn
27
+
28
+ class TestTelemetryClientRetries :
29
+ @pytest .fixture (autouse = True )
30
+ def setup_and_teardown (self ):
31
+ TelemetryClientFactory ._initialized = False
32
+ TelemetryClientFactory ._clients = {}
33
+ TelemetryClientFactory ._executor = None
34
+ yield
35
+ if TelemetryClientFactory ._executor :
36
+ TelemetryClientFactory ._executor .shutdown (wait = True )
37
+ TelemetryClientFactory ._initialized = False
38
+ TelemetryClientFactory ._clients = {}
39
+ TelemetryClientFactory ._executor = None
40
+
41
+ def get_client (self , session_id , num_retries = 3 ):
42
+ """
43
+ Configures a client with a specific number of retries.
44
+ """
45
+ TelemetryClientFactory .initialize_telemetry_client (
46
+ telemetry_enabled = True ,
47
+ session_id_hex = session_id ,
48
+ auth_provider = None ,
49
+ host_url = "test.databricks.com" ,
50
+ )
51
+ client = TelemetryClientFactory .get_telemetry_client (session_id )
52
+
53
+ retry_policy = DatabricksRetryPolicy (
54
+ delay_min = 0.01 ,
55
+ delay_max = 0.02 ,
56
+ stop_after_attempts_duration = 2.0 ,
57
+ stop_after_attempts_count = num_retries ,
58
+ delay_default = 0.1 ,
59
+ force_dangerous_codes = [],
60
+ urllib3_kwargs = {'total' : num_retries }
61
+ )
62
+ adapter = client ._http_client .session .adapters .get ("https://" )
63
+ adapter .max_retries = retry_policy
64
+ return client
65
+
66
+ @pytest .mark .parametrize (
67
+ "status_code, description" ,
68
+ [
69
+ (401 , "Unauthorized" ),
70
+ (403 , "Forbidden" ),
71
+ (501 , "Not Implemented" ),
72
+ (200 , "Success" ),
73
+ ],
74
+ )
75
+ def test_non_retryable_status_codes_are_not_retried (self , status_code , description ):
76
+ """
77
+ Verifies that terminal error codes (401, 403, 501) and success codes (200) are not retried.
78
+ """
79
+ # Use the status code in the session ID for easier debugging if it fails
80
+ client = self .get_client (f"session-{ status_code } " )
81
+ mock_responses = [{"status" : status_code }]
82
+
83
+ with patch (PATCH_TARGET , return_value = create_mock_conn (mock_responses )) as mock_get_conn :
84
+ client .export_failure_log ("TestError" , "Test message" )
85
+ TelemetryClientFactory .close (client ._session_id_hex )
86
+
87
+ mock_get_conn .return_value .getresponse .assert_called_once ()
88
+
89
+ def test_exceeds_retry_count_limit (self ):
90
+ """
91
+ Verifies that the client retries up to the specified number of times before giving up.
92
+ Verifies that the client respects the Retry-After header and retries on 429, 502, 503.
93
+ """
94
+ num_retries = 3
95
+ expected_total_calls = num_retries + 1
96
+ retry_after = 1
97
+ client = self .get_client ("session-exceed-limit" , num_retries = num_retries )
98
+ mock_responses = [{"status" : 503 , "headers" : {"Retry-After" : str (retry_after )}}, {"status" : 429 }, {"status" : 502 }, {"status" : 503 }]
99
+
100
+ with patch (PATCH_TARGET , return_value = create_mock_conn (mock_responses )) as mock_get_conn :
101
+ start_time = time .time ()
102
+ client .export_failure_log ("TestError" , "Test message" )
103
+ TelemetryClientFactory .close (client ._session_id_hex )
104
+ end_time = time .time ()
105
+
106
+ assert mock_get_conn .return_value .getresponse .call_count == expected_total_calls
107
+ assert end_time - start_time > retry_after
0 commit comments