From 67a8497861f1ed58cbc501f5af83b4a471da4284 Mon Sep 17 00:00:00 2001
From: Sai Shree Pradhan
Date: Wed, 25 Jun 2025 10:49:28 +0530
Subject: [PATCH 1/4] added multithreaded tests, exception handling tests

Signed-off-by: Sai Shree Pradhan
---
 .../sql/telemetry/telemetry_client.py |   8 +-
 tests/unit/test_telemetry.py          | 408 +++++++++++++++++-
 2 files changed, 407 insertions(+), 9 deletions(-)

diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py
index 10aa04ef..db9299ab 100644
--- a/src/databricks/sql/telemetry/telemetry_client.py
+++ b/src/databricks/sql/telemetry/telemetry_client.py
@@ -149,6 +149,7 @@ class TelemetryClient(BaseTelemetryClient):
     # Telemetry endpoint paths
     TELEMETRY_AUTHENTICATED_PATH = "/telemetry-ext"
     TELEMETRY_UNAUTHENTICATED_PATH = "/telemetry-unauth"
+    DEFAULT_BATCH_SIZE = 10
 
     def __init__(
         self,
     ):
         logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex)
         self._telemetry_enabled = telemetry_enabled
-        self._batch_size = 10  # TODO: Decide on batch size
+        self._batch_size = self.DEFAULT_BATCH_SIZE  # TODO: Decide on batch size
         self._session_id_hex = session_id_hex
         self._auth_provider = auth_provider
         self._user_agent = None
@@ -431,6 +432,9 @@ def close(session_id_hex):
             logger.debug(
                 "No more telemetry clients, shutting down thread pool executor"
             )
-            TelemetryClientFactory._executor.shutdown(wait=True)
+            try:
+                TelemetryClientFactory._executor.shutdown(wait=True)
+            except Exception as e:
+                logger.debug("Failed to shutdown thread pool executor: %s", e)
             TelemetryClientFactory._executor = None
             TelemetryClientFactory._initialized = False
diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py
index 699480bb..d1611909 100644
--- a/tests/unit/test_telemetry.py
+++ b/tests/unit/test_telemetry.py
@@ -1,7 +1,9 @@
 import uuid
 import pytest
 import requests
-from unittest.mock import patch, MagicMock, call
+from unittest.mock import patch, MagicMock
+import threading
+import random
 
 from databricks.sql.telemetry.telemetry_client import (
     TelemetryClient,
@@ -186,17 +188,16 @@ def test_export_event(self, telemetry_client_setup):
         client = telemetry_client_setup["client"]
         client._flush = MagicMock()
 
-        for i in range(5):
+        for i in range(TelemetryClient.DEFAULT_BATCH_SIZE-1):
             client._export_event(f"event-{i}")
 
         client._flush.assert_not_called()
-        assert len(client._events_batch) == 5
+        assert len(client._events_batch) == TelemetryClient.DEFAULT_BATCH_SIZE - 1
 
-        for i in range(5, 10):
-            client._export_event(f"event-{i}")
+        # Add one more event to reach batch size (this will trigger flush)
+        client._export_event(f"event-{TelemetryClient.DEFAULT_BATCH_SIZE - 1}")
 
         client._flush.assert_called_once()
-        assert len(client._events_batch) == 10
 
 @patch("requests.post")
 def test_send_telemetry_authenticated(self, mock_post, telemetry_client_setup):
@@ -498,4 +499,397 @@ def test_global_exception_hook(self, mock_handle_exception, telemetry_system_res
         test_exception = ValueError("Test exception")
         TelemetryClientFactory._handle_unhandled_exception(type(test_exception), test_exception, None)
 
-        mock_handle_exception.assert_called_once_with(type(test_exception), test_exception, None)
\ No newline at end of file
+        mock_handle_exception.assert_called_once_with(type(test_exception), test_exception, None)
+
+    def test_initialize_telemetry_client_exception_handling(self, telemetry_system_reset):
+        """Test that exceptions in initialize_telemetry_client don't cause connector to fail."""
+ session_id_hex = "test-uuid" + auth_provider = MagicMock() + host_url = "test-host" + + # Test exception during TelemetryClient creation + with patch('databricks.sql.telemetry.telemetry_client.TelemetryClient', side_effect=Exception("TelemetryClient creation failed")): + # Should not raise exception, should fallback to NoopTelemetryClient + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=True, + session_id_hex=session_id_hex, + auth_provider=auth_provider, + host_url=host_url, + ) + + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + assert isinstance(client, NoopTelemetryClient) + + def test_get_telemetry_client_exception_handling(self, telemetry_system_reset): + """Test that exceptions in get_telemetry_client don't cause connector to fail.""" + session_id_hex = "test-uuid" + + # Test exception during client lookup by mocking the clients dict + mock_clients = MagicMock() + mock_clients.__contains__.side_effect = Exception("Client lookup failed") + + with patch.object(TelemetryClientFactory, '_clients', mock_clients): + # Should not raise exception, should return NoopTelemetryClient + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + assert isinstance(client, NoopTelemetryClient) + + def test_get_telemetry_client_dict_access_exception(self, telemetry_system_reset): + """Test that exceptions during dictionary access don't cause connector to fail.""" + session_id_hex = "test-uuid" + + # Test exception during dictionary access + mock_clients = MagicMock() + mock_clients.__contains__.side_effect = Exception("Dictionary access failed") + TelemetryClientFactory._clients = mock_clients + + # Should not raise exception, should return NoopTelemetryClient + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + assert isinstance(client, NoopTelemetryClient) + + def test_close_telemetry_client_shutdown_executor_exception(self, telemetry_system_reset): + """Test that exceptions during executor shutdown don't cause connector to fail.""" + session_id_hex = "test-uuid" + auth_provider = MagicMock() + host_url = "test-host" + + # Initialize a client first + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=True, + session_id_hex=session_id_hex, + auth_provider=auth_provider, + host_url=host_url, + ) + + # Mock executor to raise exception during shutdown + mock_executor = MagicMock() + mock_executor.shutdown.side_effect = Exception("Executor shutdown failed") + TelemetryClientFactory._executor = mock_executor + + # Should not raise exception (executor shutdown is wrapped in try-catch) + TelemetryClientFactory.close(session_id_hex) + + # Verify executor shutdown was attempted + mock_executor.shutdown.assert_called_once_with(wait=True) + + + +class TestTelemetryRaceConditions: + """Tests for race conditions in multithreaded scenarios.""" + + @pytest.fixture + def race_condition_setup(self): + """Setup for race condition tests.""" + # Reset telemetry system + TelemetryClientFactory._clients.clear() + if TelemetryClientFactory._executor: + TelemetryClientFactory._executor.shutdown(wait=True) + TelemetryClientFactory._executor = None + TelemetryClientFactory._initialized = False + + yield + + # Cleanup + TelemetryClientFactory._clients.clear() + if TelemetryClientFactory._executor: + TelemetryClientFactory._executor.shutdown(wait=True) + TelemetryClientFactory._executor = None + TelemetryClientFactory._initialized = False + + def test_telemetry_client_concurrent_export_events(self, race_condition_setup): + 
"""Test race conditions in TelemetryClient._export_event with concurrent access.""" + session_id_hex = "test-race-uuid" + auth_provider = MagicMock() + host_url = "test-host" + executor = MagicMock() + + client = TelemetryClient( + telemetry_enabled=True, + session_id_hex=session_id_hex, + auth_provider=auth_provider, + host_url=host_url, + executor=executor, + ) + + # Mock _flush to avoid actual network calls + client._flush = MagicMock() + + # Track events added by each thread + thread_events = {} + lock = threading.Lock() + + def add_events(thread_id): + """Add events from a specific thread.""" + events = [] + for i in range(10): + event = f"event-{thread_id}-{i}" + client._export_event(event) + events.append(event) + + with lock: + thread_events[thread_id] = events + + # Start multiple threads adding events concurrently + threads = [] + for i in range(5): + thread = threading.Thread(target=add_events, args=(i,)) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Verify all events were added (no data loss due to race conditions) + total_expected_events = sum(len(events) for events in thread_events.values()) + assert len(client._events_batch) == total_expected_events + + def test_telemetry_client_concurrent_flush_operations(self, race_condition_setup): + """Test race conditions in TelemetryClient._flush with concurrent access.""" + session_id_hex = "test-flush-race-uuid" + auth_provider = MagicMock() + host_url = "test-host" + executor = MagicMock() + + client = TelemetryClient( + telemetry_enabled=True, + session_id_hex=session_id_hex, + auth_provider=auth_provider, + host_url=host_url, + executor=executor, + ) + + # Mock _send_telemetry to avoid actual network calls + client._send_telemetry = MagicMock() + + # Add events to trigger flush + for i in range(TelemetryClient.DEFAULT_BATCH_SIZE - 1): + client._export_event(f"event-{i}") + + # Track flush operations + flush_count = 0 + flush_lock = threading.Lock() + + def concurrent_flush(): + """Call flush concurrently.""" + nonlocal flush_count + client._flush() + with flush_lock: + flush_count += 1 + + # Start multiple threads calling flush concurrently + threads = [] + for i in range(10): + thread = threading.Thread(target=concurrent_flush) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Verify flush was called the expected number of times + assert flush_count == 10 + + # Verify _send_telemetry was called at least once (some calls may have empty batches due to lock) + assert client._send_telemetry.call_count >= 1 + + # Verify that the total events processed is correct (no data loss) + # The first flush should have processed all events, subsequent flushes should have empty batches + total_events_sent = sum(len(call.args[0]) for call in client._send_telemetry.call_args_list) + assert total_events_sent == TelemetryClient.DEFAULT_BATCH_SIZE - 1 + + def test_telemetry_client_factory_concurrent_initialization(self, race_condition_setup): + """Test race conditions in TelemetryClientFactory.initialize_telemetry_client with concurrent access.""" + session_id_hex = "test-factory-race-uuid" + auth_provider = MagicMock() + host_url = "test-host" + + # Track initialization attempts + init_results = [] + init_lock = threading.Lock() + + def concurrent_initialize(thread_id): + """Initialize telemetry client concurrently.""" + TelemetryClientFactory.initialize_telemetry_client( + 
telemetry_enabled=True, + session_id_hex=session_id_hex, + auth_provider=auth_provider, + host_url=host_url, + ) + + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + + with init_lock: + init_results.append({ + 'thread_id': thread_id, + 'client_type': type(client).__name__ + }) + + # Start multiple threads initializing concurrently + threads = [] + for i in range(10): + thread = threading.Thread(target=concurrent_initialize, args=(i,)) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + assert len(init_results) == 10 + + # Verify only one client was created (no duplicate clients due to race conditions) + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + assert isinstance(client, TelemetryClient) + + # Verify the client is the same for all threads (singleton behavior) + client_ids = set() + for result in init_results: + client_ids.add(id(TelemetryClientFactory.get_telemetry_client(session_id_hex))) + + assert len(client_ids) == 1 + + def test_telemetry_client_factory_concurrent_get_client(self, race_condition_setup): + """Test race conditions in TelemetryClientFactory.get_telemetry_client with concurrent access.""" + session_id_hex = "test-get-client-race-uuid" + auth_provider = MagicMock() + host_url = "test-host" + + # Initialize a client first + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=True, + session_id_hex=session_id_hex, + auth_provider=auth_provider, + host_url=host_url, + ) + + # Track get_client attempts + get_results = [] + get_lock = threading.Lock() + + def concurrent_get_client(thread_id): + """Get telemetry client concurrently.""" + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + + with get_lock: + get_results.append({ + 'thread_id': thread_id, + 'client_type': type(client).__name__, + 'client_id': id(client) + }) + + # Start multiple threads getting client concurrently + threads = [] + for i in range(20): + thread = threading.Thread(target=concurrent_get_client, args=(i,)) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Verify all get_client calls succeeded + assert len(get_results) == 20 + + # Verify all threads got the same client instance (no race conditions) + client_ids = set(result['client_id'] for result in get_results) + assert len(client_ids) == 1 # Only one client instance returned + + def test_telemetry_client_factory_concurrent_close(self, race_condition_setup): + """Test race conditions in TelemetryClientFactory.close with concurrent access.""" + session_id_hex = "test-close-race-uuid" + auth_provider = MagicMock() + host_url = "test-host" + + # Initialize a client first + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=True, + session_id_hex=session_id_hex, + auth_provider=auth_provider, + host_url=host_url, + ) + + def concurrent_close(thread_id): + """Close telemetry client concurrently.""" + + TelemetryClientFactory.close(session_id_hex) + + # Start multiple threads closing concurrently + threads = [] + for i in range(5): + thread = threading.Thread(target=concurrent_close, args=(i,)) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Verify client is no longer available after close + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + assert isinstance(client, NoopTelemetryClient) + + def 
test_telemetry_client_factory_mixed_concurrent_operations(self, race_condition_setup): + """Test race conditions with mixed concurrent operations on TelemetryClientFactory.""" + session_id_hex = "test-mixed-race-uuid" + auth_provider = MagicMock() + host_url = "test-host" + + # Track operation results + operation_results = [] + operation_lock = threading.Lock() + + def mixed_operations(thread_id): + """Perform mixed operations concurrently.""" + + # Randomly choose an operation + operation = random.choice(['init', 'get', 'close']) + + if operation == 'init': + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=True, + session_id_hex=session_id_hex, + auth_provider=auth_provider, + host_url=host_url, + ) + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + + with operation_lock: + operation_results.append({ + 'thread_id': thread_id, + 'operation': 'init', + 'client_type': type(client).__name__ + }) + + elif operation == 'get': + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + + with operation_lock: + operation_results.append({ + 'thread_id': thread_id, + 'operation': 'get', + 'client_type': type(client).__name__ + }) + + elif operation == 'close': + TelemetryClientFactory.close(session_id_hex) + + with operation_lock: + operation_results.append({ + 'thread_id': thread_id, + 'operation': 'close' + }) + + # Start multiple threads performing mixed operations + threads = [] + for i in range(15): + thread = threading.Thread(target=mixed_operations, args=(i,)) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + assert len(operation_results) == 15 From 70fd810270dfd7db7924e76bfe0f84b9f6299b34 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 1 Jul 2025 14:09:46 +0530 Subject: [PATCH 2/4] used batch size instead of default batch size Signed-off-by: Sai Shree Pradhan --- tests/unit/test_telemetry.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index d1611909..519f79f1 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -188,14 +188,15 @@ def test_export_event(self, telemetry_client_setup): client = telemetry_client_setup["client"] client._flush = MagicMock() - for i in range(TelemetryClient.DEFAULT_BATCH_SIZE-1): + batch_size = client._batch_size + + for i in range(batch_size - 1): client._export_event(f"event-{i}") client._flush.assert_not_called() - assert len(client._events_batch) == TelemetryClient.DEFAULT_BATCH_SIZE - 1 - - # Add one more event to reach batch size (this will trigger flush) - client._export_event(f"event-{TelemetryClient.DEFAULT_BATCH_SIZE - 1}") + assert len(client._events_batch) == batch_size - 1 + + client._export_event(f"event-{batch_size - 1}") client._flush.assert_called_once() @@ -658,11 +659,9 @@ def test_telemetry_client_concurrent_flush_operations(self, race_condition_setup executor=executor, ) - # Mock _send_telemetry to avoid actual network calls client._send_telemetry = MagicMock() - # Add events to trigger flush - for i in range(TelemetryClient.DEFAULT_BATCH_SIZE - 1): + for i in range(client._batch_size - 1): client._export_event(f"event-{i}") # Track flush operations @@ -690,13 +689,13 @@ def concurrent_flush(): # Verify flush was called the expected number of times assert flush_count == 10 - # Verify _send_telemetry was called at least once (some calls may have empty batches due to 
lock)
-        assert client._send_telemetry.call_count >= 1
+        # Verify _send_telemetry was called once
+        assert client._send_telemetry.call_count == 1
 
         # Verify that the total events processed is correct (no data loss)
         # The first flush should have processed all events, subsequent flushes should have empty batches
         total_events_sent = sum(len(call.args[0]) for call in client._send_telemetry.call_args_list)
-        assert total_events_sent == TelemetryClient.DEFAULT_BATCH_SIZE - 1
+        assert total_events_sent == client._batch_size - 1
 
     def test_telemetry_client_factory_concurrent_initialization(self, race_condition_setup):
         """Test race conditions in TelemetryClientFactory.initialize_telemetry_client with concurrent access."""

From 3e9b47d0d4b1f5da8c6bf8bb69a6733c65a59358 Mon Sep 17 00:00:00 2001
From: Sai Shree Pradhan
Date: Fri, 4 Jul 2025 14:31:15 +0530
Subject: [PATCH 3/4] reworked concurrency tests, removed DEFAULT_BATCH_SIZE

Signed-off-by: Sai Shree Pradhan
---
 .../sql/telemetry/telemetry_client.py |   5 +-
 tests/unit/test_telemetry.py          | 179 +++++++++++++++++-
 2 files changed, 179 insertions(+), 5 deletions(-)

diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py
index 842b491d..85dc4f66 100644
--- a/src/databricks/sql/telemetry/telemetry_client.py
+++ b/src/databricks/sql/telemetry/telemetry_client.py
@@ -156,7 +156,6 @@ class TelemetryClient(BaseTelemetryClient):
     # Telemetry endpoint paths
     TELEMETRY_AUTHENTICATED_PATH = "/telemetry-ext"
     TELEMETRY_UNAUTHENTICATED_PATH = "/telemetry-unauth"
-    DEFAULT_BATCH_SIZE = 10
 
     def __init__(
         self,
     ):
         logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex)
         self._telemetry_enabled = telemetry_enabled
-        self._batch_size = self.DEFAULT_BATCH_SIZE  # TODO: Decide on batch size
+        self._batch_size = 10  # TODO: Decide on batch size
         self._session_id_hex = session_id_hex
         self._auth_provider = auth_provider
         self._user_agent = None
@@ -403,7 +402,7 @@ def get_telemetry_client(session_id_hex):
             if session_id_hex in TelemetryClientFactory._clients:
                 return TelemetryClientFactory._clients[session_id_hex]
             else:
-                logger.error(
+                logger.debug(
                     "Telemetry client not initialized for connection %s",
                     session_id_hex,
                 )
diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py
index 271e8497..c485c555 100644
--- a/tests/unit/test_telemetry.py
+++ b/tests/unit/test_telemetry.py
@@ -2,13 +2,16 @@
 import pytest
 import requests
 from unittest.mock import patch, MagicMock
+import threading
+import random
+import time
+from concurrent.futures import ThreadPoolExecutor
 
 from databricks.sql.telemetry.telemetry_client import (
     TelemetryClient,
     NoopTelemetryClient,
     TelemetryClientFactory,
     TelemetryHelper,
-    BaseTelemetryClient
 )
 from databricks.sql.telemetry.models.enums import AuthMech, AuthFlow
 from databricks.sql.auth.authenticators import (
@@ -283,4 +286,176 @@ def test_factory_shutdown_flow(self, telemetry_system_reset):
         # Close second client - factory should shut down
         TelemetryClientFactory.close(session2)
         assert TelemetryClientFactory._initialized is False
-        assert TelemetryClientFactory._executor is None
\ No newline at end of file
+        assert TelemetryClientFactory._executor is None
+
+
+# A helper function to run a target in multiple threads and wait for them.
+def run_in_threads(target, num_threads, pass_index=False):
+    """Creates, starts, and joins a specified number of threads.
+
+    Args:
+        target: The function to run in each thread
+        num_threads: Number of threads to create
+        pass_index: If True, passes the thread index (0, 1, 2, ...) as first argument
+    """
+    threads = [
+        threading.Thread(target=target, args=(i,) if pass_index else ())
+        for i in range(num_threads)
+    ]
+    for t in threads:
+        t.start()
+    for t in threads:
+        t.join()
+
+
+class TestTelemetryRaceConditions:
+    """Tests for race conditions in multithreaded scenarios."""
+
+    @pytest.fixture(autouse=True)
+    def clean_factory(self):
+        """A fixture to automatically reset the factory's state before each test."""
+        # Clean up at the start of each test
+        if TelemetryClientFactory._executor:
+            TelemetryClientFactory._executor.shutdown(wait=True)
+        TelemetryClientFactory._clients.clear()
+        TelemetryClientFactory._executor = None
+        TelemetryClientFactory._initialized = False
+
+        yield
+
+        # Clean up at the end of each test
+        if TelemetryClientFactory._executor:
+            TelemetryClientFactory._executor.shutdown(wait=True)
+        TelemetryClientFactory._clients.clear()
+        TelemetryClientFactory._executor = None
+        TelemetryClientFactory._initialized = False
+
+    def test_factory_concurrent_initialization_of_DIFFERENT_clients(self):
+        """
+        Tests that multiple threads creating DIFFERENT clients concurrently
+        share a single ThreadPoolExecutor and all clients are created successfully.
+        """
+        num_threads = 20
+
+        def create_client(thread_id):
+            TelemetryClientFactory.initialize_telemetry_client(
+                telemetry_enabled=True,
+                session_id_hex=f"session_{thread_id}",
+                auth_provider=None,
+                host_url="test-host",
+            )
+
+        run_in_threads(create_client, num_threads, pass_index=True)
+
+        # ASSERT: The factory was properly initialized
+        assert TelemetryClientFactory._initialized is True
+        assert TelemetryClientFactory._executor is not None
+        assert isinstance(TelemetryClientFactory._executor, ThreadPoolExecutor)
+
+        # ASSERT: All clients were successfully created
+        assert len(TelemetryClientFactory._clients) == num_threads
+
+        # ASSERT: All TelemetryClient instances share the same executor
+        telemetry_clients = [
+            client for client in TelemetryClientFactory._clients.values()
+            if isinstance(client, TelemetryClient)
+        ]
+        assert len(telemetry_clients) == num_threads
+
+        shared_executor = TelemetryClientFactory._executor
+        for client in telemetry_clients:
+            assert client._executor is shared_executor
+
+    def test_factory_concurrent_initialization_of_SAME_client(self):
+        """
+        Tests that multiple threads trying to initialize the SAME client
+        result in only one client instance being created.
+        """
+        session_id = "shared-session"
+        num_threads = 20
+
+        def create_same_client():
+            TelemetryClientFactory.initialize_telemetry_client(
+                telemetry_enabled=True,
+                session_id_hex=session_id,
+                auth_provider=None,
+                host_url="test-host",
+            )
+
+        run_in_threads(create_same_client, num_threads)
+
+        # ASSERT: Only one client was created in the factory.
+        assert len(TelemetryClientFactory._clients) == 1
+        client = TelemetryClientFactory.get_telemetry_client(session_id)
+        assert isinstance(client, TelemetryClient)
+
+    def test_client_concurrent_event_export(self):
+        """
+        Tests that no events are lost when multiple threads call _export_event
+        on the same client instance concurrently.
+ """ + client = TelemetryClient(True, "session-1", None, "host", MagicMock()) + # Mock _flush to prevent auto-flushing when batch size threshold is reached + original_flush = client._flush + client._flush = MagicMock() + + num_threads = 5 + events_per_thread = 10 + + def add_events(): + for i in range(events_per_thread): + client._export_event(f"event-{i}") + + run_in_threads(add_events, num_threads) + + # ASSERT: The batch contains all events from all threads, none were lost. + total_expected_events = num_threads * events_per_thread + assert len(client._events_batch) == total_expected_events + + # Restore original flush method for cleanup + client._flush = original_flush + + def test_client_concurrent_flush(self): + """ + Tests that if multiple threads trigger _flush at the same time, + the underlying send operation is only called once for the batch. + """ + client = TelemetryClient(True, "session-1", None, "host", MagicMock()) + client._send_telemetry = MagicMock() + + # Pre-fill the batch so there's something to flush + client._events_batch = ["event"] * 5 + + def call_flush(): + client._flush() + + run_in_threads(call_flush, 10) + + # ASSERT: The send operation was called exactly once. + # This proves the lock prevents multiple threads from sending the same batch. + client._send_telemetry.assert_called_once() + # ASSERT: The event batch is now empty. + assert len(client._events_batch) == 0 + + def test_factory_concurrent_create_and_close(self): + """ + Tests that concurrently creating and closing different clients + doesn't corrupt the factory state and correctly shuts down the executor. + """ + num_ops = 50 + + def create_and_close_client(i): + session_id = f"session_{i}" + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=True, session_id_hex=session_id, auth_provider=None, host_url="host" + ) + # Small sleep to increase chance of interleaving operations + time.sleep(random.uniform(0, 0.01)) + TelemetryClientFactory.close(session_id) + + run_in_threads(create_and_close_client, num_ops, pass_index=True) + + # ASSERT: After all operations, the factory should be empty and reset. 
+        assert not TelemetryClientFactory._clients
+        assert TelemetryClientFactory._executor is None
+        assert not TelemetryClientFactory._initialized
\ No newline at end of file

From 11d41cea1d3a2dcdee90d6d4894771278a5a4933 Mon Sep 17 00:00:00 2001
From: Sai Shree Pradhan
Date: Fri, 4 Jul 2025 15:38:09 +0530
Subject: [PATCH 4/4] added e2e test for concurrent telemetry capture

Signed-off-by: Sai Shree Pradhan
---
 tests/e2e/test_concurrent_telemetry.py | 174 +++++++++++++++++++++++++
 1 file changed, 174 insertions(+)
 create mode 100644 tests/e2e/test_concurrent_telemetry.py

diff --git a/tests/e2e/test_concurrent_telemetry.py b/tests/e2e/test_concurrent_telemetry.py
new file mode 100644
index 00000000..bb148f23
--- /dev/null
+++ b/tests/e2e/test_concurrent_telemetry.py
@@ -0,0 +1,174 @@
+import threading
+from unittest.mock import patch, MagicMock
+
+from databricks.sql.client import Connection
+from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory, TelemetryClient
+from databricks.sql.thrift_backend import ThriftBackend
+from databricks.sql.utils import ExecuteResponse
+from databricks.sql.thrift_api.TCLIService.ttypes import TSessionHandle, TOperationHandle, TOperationState, THandleIdentifier
+
+try:
+    import pyarrow as pa
+except ImportError:
+    pa = None
+
+
+def run_in_threads(target, num_threads, pass_index=False):
+    """Helper to run target function in multiple threads."""
+    threads = [
+        threading.Thread(target=target, args=(i,) if pass_index else ())
+        for i in range(num_threads)
+    ]
+    for t in threads:
+        t.start()
+    for t in threads:
+        t.join()
+
+
+class MockArrowQueue:
+    """Mock queue that behaves like ArrowQueue but returns empty results."""
+
+    def __init__(self):
+        # Create an empty arrow table if pyarrow is available, otherwise use None
+        if pa is not None:
+            self.empty_table = pa.table({'column': pa.array([])})
+        else:
+            # Create a simple mock table-like object
+            self.empty_table = MagicMock()
+            self.empty_table.num_rows = 0
+            self.empty_table.num_columns = 0
+
+    def next_n_rows(self, num_rows: int):
+        """Return empty results."""
+        return self.empty_table
+
+    def remaining_rows(self):
+        """Return empty results."""
+        return self.empty_table
+
+
+def test_concurrent_queries_with_telemetry_capture():
+    """
+    Test showing concurrent threads executing queries with real telemetry capture.
+    Uses the actual Connection and Cursor classes, mocking only the ThriftBackend.
+ """ + num_threads = 5 + captured_telemetry = [] + connections = [] # Store connections to close them later + connections_lock = threading.Lock() # Thread safety for connections list + + def mock_send_telemetry(self, events): + """Capture telemetry events instead of sending them over network.""" + captured_telemetry.extend(events) + + # Clean up any existing state + if TelemetryClientFactory._executor: + TelemetryClientFactory._executor.shutdown(wait=True) + TelemetryClientFactory._clients.clear() + TelemetryClientFactory._executor = None + TelemetryClientFactory._initialized = False + + with patch.object(TelemetryClient, '_send_telemetry', mock_send_telemetry): + # Mock the ThriftBackend to avoid actual network calls + with patch.object(ThriftBackend, 'open_session') as mock_open_session, \ + patch.object(ThriftBackend, 'execute_command') as mock_execute_command, \ + patch.object(ThriftBackend, 'close_session') as mock_close_session, \ + patch.object(ThriftBackend, 'fetch_results') as mock_fetch_results, \ + patch.object(ThriftBackend, 'close_command') as mock_close_command, \ + patch.object(ThriftBackend, 'handle_to_hex_id') as mock_handle_to_hex_id, \ + patch('databricks.sql.auth.thrift_http_client.THttpClient.open') as mock_transport_open: + + # Mock transport.open() to prevent actual network connection + mock_transport_open.return_value = None + + # Set up mock responses with proper structure + mock_handle_identifier = THandleIdentifier() + mock_handle_identifier.guid = b'1234567890abcdef' + mock_handle_identifier.secret = b'test_secret_1234' + + mock_session_handle = TSessionHandle() + mock_session_handle.sessionId = mock_handle_identifier + mock_session_handle.serverProtocolVersion = 1 + + mock_open_session.return_value = MagicMock( + sessionHandle=mock_session_handle, + serverProtocolVersion=1 + ) + + mock_handle_to_hex_id.return_value = "test-session-id-12345678" + + mock_op_handle = TOperationHandle() + mock_op_handle.operationId = THandleIdentifier() + mock_op_handle.operationId.guid = b'abcdef1234567890' + mock_op_handle.operationId.secret = b'op_secret_abcd' + + # Create proper mock arrow_queue with required methods + mock_arrow_queue = MockArrowQueue() + + mock_execute_response = ExecuteResponse( + arrow_queue=mock_arrow_queue, + description=[], + command_handle=mock_op_handle, + status=TOperationState.FINISHED_STATE, + has_been_closed_server_side=False, + has_more_rows=False, + lz4_compressed=False, + arrow_schema_bytes=b'', + is_staging_operation=False + ) + mock_execute_command.return_value = mock_execute_response + + # Mock fetch_results to return empty results + mock_fetch_results.return_value = (mock_arrow_queue, False) + + # Mock close_command to do nothing + mock_close_command.return_value = None + + # Mock close_session to do nothing + mock_close_session.return_value = None + + def execute_query_worker(thread_id): + """Each thread creates a connection and executes a query.""" + + # Create real Connection and Cursor objects + conn = Connection( + server_hostname="test-host", + http_path="/test/path", + access_token="test-token", + enable_telemetry=True + ) + + # Thread-safe storage of connection + with connections_lock: + connections.append(conn) + + cursor = conn.cursor() + # This will trigger the @log_latency decorator naturally + cursor.execute(f"SELECT {thread_id} as thread_id") + result = cursor.fetchall() + conn.close() + + + run_in_threads(execute_query_worker, num_threads, pass_index=True) + + # We expect at least 2 events per thread (one for open_session and 
one for execute_command)
+    assert len(captured_telemetry) >= num_threads * 2
+    print(f"Captured telemetry: {captured_telemetry}")
+
+    # Verify the decorator was used (check some telemetry events have latency measurement)
+    events_with_latency = [
+        e for e in captured_telemetry
+        if hasattr(e, 'entry') and hasattr(e.entry, 'sql_driver_log')
+        and e.entry.sql_driver_log.operation_latency_ms is not None
+    ]
+    assert len(events_with_latency) >= num_threads
+
+    # Verify we have events with statement IDs (indicating @log_latency decorator worked)
+    events_with_statements = [
+        e for e in captured_telemetry
+        if hasattr(e, 'entry') and hasattr(e.entry, 'sql_driver_log')
+        and e.entry.sql_driver_log.sql_statement_id is not None
+    ]
+    assert len(events_with_statements) >= num_threads
\ No newline at end of file
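
Note on the locking pattern these tests exercise: test_client_concurrent_flush asserts that ten racing _flush calls produce exactly one send, which can only hold if the client snapshots and clears its batch under a lock and performs the send outside it. The _export_event/_flush bodies never appear in these diffs, so the sketch below is a minimal, assumed illustration of that pattern; the BatchingClient name and the _lock attribute are illustrative, not the real TelemetryClient implementation.

import threading


class BatchingClient:
    """Minimal sketch of a lock-guarded event batch (illustrative, not the real TelemetryClient)."""

    def __init__(self, batch_size=10):
        self._batch_size = batch_size
        self._events_batch = []
        self._lock = threading.Lock()  # assumed: serializes access to the batch

    def _export_event(self, event):
        # Append under the lock; decide whether to flush, but call _flush
        # outside the lock since _flush re-acquires it.
        with self._lock:
            self._events_batch.append(event)
            should_flush = len(self._events_batch) >= self._batch_size
        if should_flush:
            self._flush()

    def _flush(self):
        # Snapshot and clear under the lock: concurrent flushers find an
        # empty batch, so only one caller actually sends the events.
        with self._lock:
            events_to_send = self._events_batch
            self._events_batch = []
        if events_to_send:
            self._send_telemetry(events_to_send)

    def _send_telemetry(self, events):
        pass  # network call in the real client

Sending outside the lock keeps the critical section short; under this pattern, the ten concurrent _flush calls in the test reduce to one real send plus nine no-ops on an already-empty batch, which is exactly what _send_telemetry.assert_called_once() verifies.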