1
+ import threading
2
+ from unittest .mock import patch
3
+ import pytest
4
+
5
+ from databricks .sql .telemetry .telemetry_client import TelemetryClient , TelemetryClientFactory
6
+ from tests .e2e .test_driver import PySQLPytestTestCase
7
+
8
+ def run_in_threads (target , num_threads , pass_index = False ):
9
+ """Helper to run target function in multiple threads."""
10
+ threads = [
11
+ threading .Thread (target = target , args = (i ,) if pass_index else ())
12
+ for i in range (num_threads )
13
+ ]
14
+ for t in threads :
15
+ t .start ()
16
+ for t in threads :
17
+ t .join ()
18
+
19
+
20
+ class TestE2ETelemetry (PySQLPytestTestCase ):
21
+
22
+ @pytest .fixture (autouse = True )
23
+ def telemetry_setup_teardown (self ):
24
+ """
25
+ This fixture ensures the TelemetryClientFactory is in a clean state
26
+ before each test and shuts it down afterward. Using a fixture makes
27
+ this robust and automatic.
28
+ """
29
+ # --- SETUP ---
30
+ if TelemetryClientFactory ._executor :
31
+ TelemetryClientFactory ._executor .shutdown (wait = True )
32
+ TelemetryClientFactory ._clients .clear ()
33
+ TelemetryClientFactory ._executor = None
34
+ TelemetryClientFactory ._initialized = False
35
+
36
+ yield # This is where the test runs
37
+
38
+ # --- TEARDOWN ---
39
+ if TelemetryClientFactory ._executor :
40
+ TelemetryClientFactory ._executor .shutdown (wait = True )
41
+ TelemetryClientFactory ._executor = None
42
+ TelemetryClientFactory ._initialized = False
43
+
44
+ def test_concurrent_queries_sends_telemetry (self ):
45
+ """
46
+ An E2E test where concurrent threads execute real queries against
47
+ the staging endpoint, while we capture and verify the generated telemetry.
48
+ """
49
+ num_threads = 5
50
+ captured_telemetry = []
51
+ captured_telemetry_lock = threading .Lock ()
52
+
53
+ original_send_telemetry = TelemetryClient ._send_telemetry
54
+
55
+ def send_telemetry_wrapper (self_client , events ):
56
+ with captured_telemetry_lock :
57
+ captured_telemetry .extend (events )
58
+ original_send_telemetry (self_client , events )
59
+
60
+ with patch .object (TelemetryClient , "_send_telemetry" , send_telemetry_wrapper ):
61
+
62
+ def execute_query_worker (thread_id ):
63
+ """Each thread creates a connection and executes a query."""
64
+ with self .connection (extra_params = {"enable_telemetry" : True }) as conn :
65
+ with conn .cursor () as cursor :
66
+ cursor .execute (f"SELECT { thread_id } " )
67
+ cursor .fetchall ()
68
+
69
+ # Run the workers concurrently
70
+ run_in_threads (execute_query_worker , num_threads , pass_index = True )
71
+
72
+ if TelemetryClientFactory ._executor :
73
+ TelemetryClientFactory ._executor .shutdown (wait = True )
74
+
75
+ # --- VERIFICATION ---
76
+ assert len (captured_telemetry ) == num_threads * 2 # 2 events per thread (initial_telemetry_log, latency_log (execute))
77
+
78
+ events_with_latency = [
79
+ e for e in captured_telemetry
80
+ if e .entry .sql_driver_log .operation_latency_ms is not None and e .entry .sql_driver_log .sql_statement_id is not None
81
+ ]
82
+ assert len (events_with_latency ) == num_threads # 1 event per thread (execute)
0 commit comments