Skip to content

Testing for telemetry #616

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: telemetry
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/databricks/sql/telemetry/telemetry_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def get_telemetry_client(session_id_hex):
if session_id_hex in TelemetryClientFactory._clients:
return TelemetryClientFactory._clients[session_id_hex]
else:
logger.error(
logger.debug(
"Telemetry client not initialized for connection %s",
session_id_hex,
)
Expand Down
174 changes: 174 additions & 0 deletions tests/e2e/test_concurrent_telemetry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
import threading
from unittest.mock import patch, MagicMock

from databricks.sql.client import Connection
from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory, TelemetryClient
from databricks.sql.thrift_backend import ThriftBackend
from databricks.sql.utils import ExecuteResponse
from databricks.sql.thrift_api.TCLIService.ttypes import TSessionHandle, TOperationHandle, TOperationState, THandleIdentifier

try:
import pyarrow as pa
except ImportError:
pa = None


def run_in_threads(target, num_threads, pass_index=False):
"""Helper to run target function in multiple threads."""
threads = [
threading.Thread(target=target, args=(i,) if pass_index else ())
for i in range(num_threads)
]
for t in threads:
t.start()
for t in threads:
t.join()


class MockArrowQueue:
"""Mock queue that behaves like ArrowQueue but returns empty results."""

def __init__(self):
# Create an empty arrow table if pyarrow is available, otherwise use None
if pa is not None:
self.empty_table = pa.table({'column': pa.array([])})
else:
# Create a simple mock table-like object
self.empty_table = MagicMock()
self.empty_table.num_rows = 0
self.empty_table.num_columns = 0

def next_n_rows(self, num_rows: int):
"""Return empty results."""
return self.empty_table

def remaining_rows(self):
"""Return empty results."""
return self.empty_table


def test_concurrent_queries_with_telemetry_capture():
"""
Test showing concurrent threads executing queries with real telemetry capture.
Uses the actual Connection and Cursor classes, mocking only the ThriftBackend.
"""
num_threads = 5
captured_telemetry = []
connections = [] # Store connections to close them later
connections_lock = threading.Lock() # Thread safety for connections list

def mock_send_telemetry(self, events):
"""Capture telemetry events instead of sending them over network."""
captured_telemetry.extend(events)

# Clean up any existing state
if TelemetryClientFactory._executor:
TelemetryClientFactory._executor.shutdown(wait=True)
TelemetryClientFactory._clients.clear()
TelemetryClientFactory._executor = None
TelemetryClientFactory._initialized = False

with patch.object(TelemetryClient, '_send_telemetry', mock_send_telemetry):
# Mock the ThriftBackend to avoid actual network calls
with patch.object(ThriftBackend, 'open_session') as mock_open_session, \
patch.object(ThriftBackend, 'execute_command') as mock_execute_command, \
patch.object(ThriftBackend, 'close_session') as mock_close_session, \
patch.object(ThriftBackend, 'fetch_results') as mock_fetch_results, \
patch.object(ThriftBackend, 'close_command') as mock_close_command, \
patch.object(ThriftBackend, 'handle_to_hex_id') as mock_handle_to_hex_id, \
patch('databricks.sql.auth.thrift_http_client.THttpClient.open') as mock_transport_open:

# Mock transport.open() to prevent actual network connection
mock_transport_open.return_value = None

# Set up mock responses with proper structure
mock_handle_identifier = THandleIdentifier()
mock_handle_identifier.guid = b'1234567890abcdef'
mock_handle_identifier.secret = b'test_secret_1234'

mock_session_handle = TSessionHandle()
mock_session_handle.sessionId = mock_handle_identifier
mock_session_handle.serverProtocolVersion = 1

mock_open_session.return_value = MagicMock(
sessionHandle=mock_session_handle,
serverProtocolVersion=1
)

mock_handle_to_hex_id.return_value = "test-session-id-12345678"

mock_op_handle = TOperationHandle()
mock_op_handle.operationId = THandleIdentifier()
mock_op_handle.operationId.guid = b'abcdef1234567890'
mock_op_handle.operationId.secret = b'op_secret_abcd'

# Create proper mock arrow_queue with required methods
mock_arrow_queue = MockArrowQueue()

mock_execute_response = ExecuteResponse(
arrow_queue=mock_arrow_queue,
description=[],
command_handle=mock_op_handle,
status=TOperationState.FINISHED_STATE,
has_been_closed_server_side=False,
has_more_rows=False,
lz4_compressed=False,
arrow_schema_bytes=b'',
is_staging_operation=False
)
mock_execute_command.return_value = mock_execute_response

# Mock fetch_results to return empty results
mock_fetch_results.return_value = (mock_arrow_queue, False)

# Mock close_command to do nothing
mock_close_command.return_value = None

# Mock close_session to do nothing
mock_close_session.return_value = None

def execute_query_worker(thread_id):
"""Each thread creates a connection and executes a query."""

# Create real Connection and Cursor objects
conn = Connection(
server_hostname="test-host",
http_path="/test/path",
access_token="test-token",
enable_telemetry=True
)

# Thread-safe storage of connection
with connections_lock:
connections.append(conn)

cursor = conn.cursor()
# This will trigger the @log_latency decorator naturally
cursor.execute(f"SELECT {thread_id} as thread_id")
result = cursor.fetchall()
conn.close()


run_in_threads(execute_query_worker, num_threads, pass_index=True)

# We expect at least 2 events per thread (one for open_session and one for execute_command)
assert len(captured_telemetry) >= num_threads*2
print(f"Captured telemetry: {captured_telemetry}")

# Verify the decorator was used (check some telemetry events have latency measurement)
events_with_latency = [
e for e in captured_telemetry
if hasattr(e, 'entry') and hasattr(e.entry, 'sql_driver_log')
and e.entry.sql_driver_log.operation_latency_ms is not None
]
assert len(events_with_latency) >= num_threads

# Verify we have events with statement IDs (indicating @log_latency decorator worked)
events_with_statements = [
e for e in captured_telemetry
if hasattr(e, 'entry') and hasattr(e.entry, 'sql_driver_log')
and e.entry.sql_driver_log.sql_statement_id is not None
]
assert len(events_with_statements) >= num_threads


179 changes: 177 additions & 2 deletions tests/unit/test_telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
import pytest
import requests
from unittest.mock import patch, MagicMock
import threading
import random
import time
from concurrent.futures import ThreadPoolExecutor

from databricks.sql.telemetry.telemetry_client import (
TelemetryClient,
NoopTelemetryClient,
TelemetryClientFactory,
TelemetryHelper,
BaseTelemetryClient
)
from databricks.sql.telemetry.models.enums import AuthMech, AuthFlow
from databricks.sql.auth.authenticators import (
Expand Down Expand Up @@ -283,4 +286,176 @@ def test_factory_shutdown_flow(self, telemetry_system_reset):
# Close second client - factory should shut down
TelemetryClientFactory.close(session2)
assert TelemetryClientFactory._initialized is False
assert TelemetryClientFactory._executor is None
assert TelemetryClientFactory._executor is None


# A helper function to run a target in multiple threads and wait for them.
def run_in_threads(target, num_threads, pass_index=False):
"""Creates, starts, and joins a specified number of threads.

Args:
target: The function to run in each thread
num_threads: Number of threads to create
pass_index: If True, passes the thread index (0, 1, 2, ...) as first argument
"""
threads = [
threading.Thread(target=target, args=(i,) if pass_index else ())
for i in range(num_threads)
]
for t in threads:
t.start()
for t in threads:
t.join()


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we add these in a separate file? @jprakash-db what's the sop in python?

class TestTelemetryRaceConditions:
"""Tests for race conditions in multithreaded scenarios."""

@pytest.fixture(autouse=True)
def clean_factory(self):
"""A fixture to automatically reset the factory's state before each test."""
# Clean up at the start of each test
if TelemetryClientFactory._executor:
TelemetryClientFactory._executor.shutdown(wait=True)
TelemetryClientFactory._clients.clear()
TelemetryClientFactory._executor = None
TelemetryClientFactory._initialized = False

yield

# Clean up at the end of each test
if TelemetryClientFactory._executor:
TelemetryClientFactory._executor.shutdown(wait=True)
TelemetryClientFactory._clients.clear()
TelemetryClientFactory._executor = None
TelemetryClientFactory._initialized = False

def test_factory_concurrent_initialization_of_DIFFERENT_clients(self):
"""
Tests that multiple threads creating DIFFERENT clients concurrently
share a single ThreadPoolExecutor and all clients are created successfully.
"""
num_threads = 20

def create_client(thread_id):
TelemetryClientFactory.initialize_telemetry_client(
telemetry_enabled=True,
session_id_hex=f"session_{thread_id}",
auth_provider=None,
host_url="test-host",
)

run_in_threads(create_client, 20, pass_index=True)

# ASSERT: The factory was properly initialized
assert TelemetryClientFactory._initialized is True
assert TelemetryClientFactory._executor is not None
assert isinstance(TelemetryClientFactory._executor, ThreadPoolExecutor)

# ASSERT: All clients were successfully created
assert len(TelemetryClientFactory._clients) == num_threads

# ASSERT: All TelemetryClient instances share the same executor
telemetry_clients = [
client for client in TelemetryClientFactory._clients.values()
if isinstance(client, TelemetryClient)
]
assert len(telemetry_clients) == num_threads

shared_executor = TelemetryClientFactory._executor
for client in telemetry_clients:
assert client._executor is shared_executor

def test_factory_concurrent_initialization_of_SAME_client(self):
"""
Tests that multiple threads trying to initialize the SAME client
result in only one client instance being created.
"""
session_id = "shared-session"
num_threads = 20

def create_same_client():
TelemetryClientFactory.initialize_telemetry_client(
telemetry_enabled=True,
session_id_hex=session_id,
auth_provider=None,
host_url="test-host",
)

run_in_threads(create_same_client, num_threads)

# ASSERT: Only one client was created in the factory.
assert len(TelemetryClientFactory._clients) == 1
client = TelemetryClientFactory.get_telemetry_client(session_id)
assert isinstance(client, TelemetryClient)

def test_client_concurrent_event_export(self):
"""
Tests that no events are lost when multiple threads call _export_event
on the same client instance concurrently.
"""
client = TelemetryClient(True, "session-1", None, "host", MagicMock())
# Mock _flush to prevent auto-flushing when batch size threshold is reached
original_flush = client._flush
client._flush = MagicMock()

num_threads = 5
events_per_thread = 10

def add_events():
for i in range(events_per_thread):
client._export_event(f"event-{i}")

run_in_threads(add_events, num_threads)

# ASSERT: The batch contains all events from all threads, none were lost.
total_expected_events = num_threads * events_per_thread
assert len(client._events_batch) == total_expected_events

# Restore original flush method for cleanup
client._flush = original_flush

def test_client_concurrent_flush(self):
"""
Tests that if multiple threads trigger _flush at the same time,
the underlying send operation is only called once for the batch.
"""
client = TelemetryClient(True, "session-1", None, "host", MagicMock())
client._send_telemetry = MagicMock()

# Pre-fill the batch so there's something to flush
client._events_batch = ["event"] * 5

def call_flush():
client._flush()

run_in_threads(call_flush, 10)

# ASSERT: The send operation was called exactly once.
# This proves the lock prevents multiple threads from sending the same batch.
client._send_telemetry.assert_called_once()
# ASSERT: The event batch is now empty.
assert len(client._events_batch) == 0

def test_factory_concurrent_create_and_close(self):
"""
Tests that concurrently creating and closing different clients
doesn't corrupt the factory state and correctly shuts down the executor.
"""
num_ops = 50

def create_and_close_client(i):
session_id = f"session_{i}"
TelemetryClientFactory.initialize_telemetry_client(
telemetry_enabled=True, session_id_hex=session_id, auth_provider=None, host_url="host"
)
# Small sleep to increase chance of interleaving operations
time.sleep(random.uniform(0, 0.01))
TelemetryClientFactory.close(session_id)

run_in_threads(create_and_close_client, num_ops, pass_index=True)

# ASSERT: After all operations, the factory should be empty and reset.
assert not TelemetryClientFactory._clients
assert TelemetryClientFactory._executor is None
assert not TelemetryClientFactory._initialized
Loading