diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 9446d87a3..3a682984e 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -49,6 +49,15 @@ TSparkParameter, TOperationState, ) +from databricks.sql.telemetry.telemetry_client import ( + TelemetryClientFactory, + TelemetryHelper, +) +from databricks.sql.telemetry.models.enums import DatabricksClientType +from databricks.sql.telemetry.models.event import ( + DriverConnectionParameters, + HostDetails, +) logger = logging.getLogger(__name__) @@ -294,6 +303,31 @@ def read(self) -> Optional[OAuthToken]: kwargs.get("use_inline_params", False) ) + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=self.telemetry_enabled, + connection_uuid=self.get_session_id_hex(), + auth_provider=auth_provider, + host_url=self.host, + ) + + self._telemetry_client = TelemetryClientFactory.get_telemetry_client( + connection_uuid=self.get_session_id_hex() + ) + + driver_connection_params = DriverConnectionParameters( + http_path=http_path, + mode=DatabricksClientType.THRIFT, + host_info=HostDetails(host_url=server_hostname, port=self.port), + auth_mech=TelemetryHelper.get_auth_mechanism(auth_provider), + auth_flow=TelemetryHelper.get_auth_flow(auth_provider), + socket_timeout=kwargs.get("_socket_timeout", None), + ) + + self._telemetry_client.export_initial_telemetry_log( + driver_connection_params=driver_connection_params, + user_agent=useragent_header, + ) + def _set_use_inline_params_with_warning(self, value: Union[bool, str]): """Valid values are True, False, and "silent" @@ -430,6 +464,8 @@ def _close(self, close_cursors=True) -> None: self.open = False + self._telemetry_client.close() + def commit(self): """No-op because Databricks does not support transactions""" pass diff --git a/src/databricks/sql/telemetry/models/enums.py b/src/databricks/sql/telemetry/models/enums.py index cd7cd9a33..a5363a57e 100644 --- a/src/databricks/sql/telemetry/models/enums.py +++ b/src/databricks/sql/telemetry/models/enums.py @@ -3,15 +3,14 @@ class AuthFlow(Enum): TOKEN_PASSTHROUGH = "token_passthrough" - CLIENT_CREDENTIALS = "client_credentials" BROWSER_BASED_AUTHENTICATION = "browser_based_authentication" - AZURE_MANAGED_IDENTITIES = "azure_managed_identities" class AuthMech(Enum): - OTHER = "other" - PAT = "pat" - OAUTH = "oauth" + CLIENT_CERT = "CLIENT_CERT" # ssl certificate authentication + PAT = "PAT" # Personal Access Token authentication + DATABRICKS_OAUTH = "DATABRICKS_OAUTH" # Databricks-managed OAuth flow + EXTERNAL_AUTH = "EXTERNAL_AUTH" # External identity provider (AWS, Azure, etc.) class DatabricksClientType(Enum): diff --git a/src/databricks/sql/telemetry/models/event.py b/src/databricks/sql/telemetry/models/event.py index 03ce5c5db..4429a7626 100644 --- a/src/databricks/sql/telemetry/models/event.py +++ b/src/databricks/sql/telemetry/models/event.py @@ -9,6 +9,7 @@ ExecutionResultFormat, ) from typing import Optional +from databricks.sql.telemetry.utils import EnumEncoder @dataclass @@ -40,26 +41,18 @@ class DriverConnectionParameters: host_info (HostDetails): Details about the host connection auth_mech (AuthMech): The authentication mechanism used auth_flow (AuthFlow): The authentication flow type - auth_scope (str): The scope of authentication - discovery_url (str): URL for service discovery - allowed_volume_ingestion_paths (str): JSON string of allowed paths for volume operations - azure_tenant_id (str): Azure tenant ID for Azure authentication socket_timeout (int): Connection timeout in milliseconds """ http_path: str mode: DatabricksClientType host_info: HostDetails - auth_mech: AuthMech - auth_flow: AuthFlow - auth_scope: str - discovery_url: str - allowed_volume_ingestion_paths: str - azure_tenant_id: str - socket_timeout: int + auth_mech: Optional[AuthMech] = None + auth_flow: Optional[AuthFlow] = None + socket_timeout: Optional[int] = None def to_json(self): - return json.dumps(asdict(self)) + return json.dumps(asdict(self), cls=EnumEncoder) @dataclass @@ -89,13 +82,13 @@ class DriverSystemConfiguration: runtime_name: str runtime_version: str runtime_vendor: str - client_app_name: str - locale_name: str driver_name: str char_set_encoding: str + client_app_name: Optional[str] = None + locale_name: Optional[str] = None def to_json(self): - return json.dumps(asdict(self)) + return json.dumps(asdict(self), cls=EnumEncoder) @dataclass @@ -113,7 +106,7 @@ class DriverVolumeOperation: volume_path: str def to_json(self): - return json.dumps(asdict(self)) + return json.dumps(asdict(self), cls=EnumEncoder) @dataclass @@ -131,7 +124,7 @@ class DriverErrorInfo: stack_trace: str def to_json(self): - return json.dumps(asdict(self)) + return json.dumps(asdict(self), cls=EnumEncoder) @dataclass @@ -153,7 +146,7 @@ class SqlExecutionEvent: retry_count: int def to_json(self): - return json.dumps(asdict(self)) + return json.dumps(asdict(self), cls=EnumEncoder) @dataclass @@ -186,4 +179,4 @@ class TelemetryEvent: operation_latency_ms: Optional[int] = None def to_json(self): - return json.dumps(asdict(self)) + return json.dumps(asdict(self), cls=EnumEncoder) diff --git a/src/databricks/sql/telemetry/models/frontend_logs.py b/src/databricks/sql/telemetry/models/frontend_logs.py index 953e39b39..36086a7cc 100644 --- a/src/databricks/sql/telemetry/models/frontend_logs.py +++ b/src/databricks/sql/telemetry/models/frontend_logs.py @@ -1,6 +1,7 @@ import json from dataclasses import dataclass, asdict from databricks.sql.telemetry.models.event import TelemetryEvent +from databricks.sql.telemetry.utils import EnumEncoder from typing import Optional @@ -19,7 +20,7 @@ class TelemetryClientContext: user_agent: str def to_json(self): - return json.dumps(asdict(self)) + return json.dumps(asdict(self), cls=EnumEncoder) @dataclass @@ -35,7 +36,7 @@ class FrontendLogContext: client_context: TelemetryClientContext def to_json(self): - return json.dumps(asdict(self)) + return json.dumps(asdict(self), cls=EnumEncoder) @dataclass @@ -51,7 +52,7 @@ class FrontendLogEntry: sql_driver_log: TelemetryEvent def to_json(self): - return json.dumps(asdict(self)) + return json.dumps(asdict(self), cls=EnumEncoder) @dataclass @@ -74,4 +75,4 @@ class TelemetryFrontendLog: workspace_id: Optional[int] = None def to_json(self): - return json.dumps(asdict(self)) + return json.dumps(asdict(self), cls=EnumEncoder) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py new file mode 100644 index 000000000..d095d685c --- /dev/null +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -0,0 +1,349 @@ +import threading +import time +import json +import requests +import logging +from concurrent.futures import ThreadPoolExecutor +from typing import Dict, Optional +from databricks.sql.telemetry.models.event import ( + TelemetryEvent, + DriverSystemConfiguration, +) +from databricks.sql.telemetry.models.frontend_logs import ( + TelemetryFrontendLog, + TelemetryClientContext, + FrontendLogContext, + FrontendLogEntry, +) +from databricks.sql.telemetry.models.enums import AuthMech, AuthFlow +from databricks.sql.auth.authenticators import ( + AccessTokenAuthProvider, + DatabricksOAuthProvider, + ExternalAuthProvider, +) +import sys +import platform +import uuid +import locale +from abc import ABC, abstractmethod +from databricks.sql import __version__ + +logger = logging.getLogger(__name__) + + +class TelemetryHelper: + """Helper class for getting telemetry related information.""" + + _DRIVER_SYSTEM_CONFIGURATION = DriverSystemConfiguration( + driver_name="Databricks SQL Python Connector", + driver_version=__version__, + runtime_name=f"Python {sys.version.split()[0]}", + runtime_vendor=platform.python_implementation(), + runtime_version=platform.python_version(), + os_name=platform.system(), + os_version=platform.release(), + os_arch=platform.machine(), + client_app_name=None, # TODO: Add client app name + locale_name=locale.getlocale()[0] or locale.getdefaultlocale()[0], + char_set_encoding=sys.getdefaultencoding(), + ) + + @classmethod + def getDriverSystemConfiguration(cls) -> DriverSystemConfiguration: + return cls._DRIVER_SYSTEM_CONFIGURATION + + @staticmethod + def get_auth_mechanism(auth_provider): + """Get the auth mechanism for the auth provider.""" + # AuthMech is an enum with the following values: + # PAT, DATABRICKS_OAUTH, EXTERNAL_AUTH, CLIENT_CERT + + if not auth_provider: + return None + if isinstance(auth_provider, AccessTokenAuthProvider): + return AuthMech.PAT # Personal Access Token authentication + elif isinstance(auth_provider, DatabricksOAuthProvider): + return AuthMech.DATABRICKS_OAUTH # Databricks-managed OAuth flow + elif isinstance(auth_provider, ExternalAuthProvider): + return ( + AuthMech.EXTERNAL_AUTH + ) # External identity provider (AWS, Azure, etc.) + return AuthMech.CLIENT_CERT # Client certificate (ssl) + + @staticmethod + def get_auth_flow(auth_provider): + """Get the auth flow for the auth provider.""" + # AuthFlow is an enum with the following values: + # TOKEN_PASSTHROUGH, BROWSER_BASED_AUTHENTICATION + + if not auth_provider: + return None + + if isinstance(auth_provider, DatabricksOAuthProvider): + if auth_provider._access_token and auth_provider._refresh_token: + return ( + AuthFlow.TOKEN_PASSTHROUGH + ) # Has existing tokens, no user interaction needed + if hasattr(auth_provider, "oauth_manager"): + return ( + AuthFlow.BROWSER_BASED_AUTHENTICATION + ) # Will initiate OAuth flow requiring browser + + return None + + +class BaseTelemetryClient(ABC): + """ + Base class for telemetry clients. + It is used to define the interface for telemetry clients. + """ + + @abstractmethod + def export_initial_telemetry_log(self, **kwargs): + pass + + @abstractmethod + def close(self): + pass + + +class NoopTelemetryClient(BaseTelemetryClient): + """ + NoopTelemetryClient is a telemetry client that does not send any events to the server. + It is used when telemetry is disabled. + """ + + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super(NoopTelemetryClient, cls).__new__(cls) + return cls._instance + + def export_initial_telemetry_log(self, driver_connection_params, user_agent): + pass + + def close(self): + pass + + +class TelemetryClient(BaseTelemetryClient): + """ + Telemetry client class that handles sending telemetry events in batches to the server. + It uses a thread pool to handle asynchronous operations, that it gets from the TelemetryClientFactory. + """ + + # Telemetry endpoint paths + TELEMETRY_AUTHENTICATED_PATH = "/telemetry-ext" + TELEMETRY_UNAUTHENTICATED_PATH = "/telemetry-unauth" + + def __init__( + self, + telemetry_enabled, + connection_uuid, + auth_provider, + host_url, + executor, + ): + logger.debug("Initializing TelemetryClient for connection: %s", connection_uuid) + self._telemetry_enabled = telemetry_enabled + self._batch_size = 10 # TODO: Decide on batch size + self._connection_uuid = connection_uuid + self._auth_provider = auth_provider + self._user_agent = None + self._events_batch = [] + self._lock = threading.Lock() + self._driver_connection_params = None + self._host_url = host_url + self._executor = executor + + def export_event(self, event): + """Add an event to the batch queue and flush if batch is full""" + logger.debug("Exporting event for connection %s", self._connection_uuid) + with self._lock: + self._events_batch.append(event) + if len(self._events_batch) >= self._batch_size: + logger.debug( + "Batch size limit reached (%s), flushing events", self._batch_size + ) + self.flush() + + def flush(self): + """Flush the current batch of events to the server""" + with self._lock: + events_to_flush = self._events_batch.copy() + self._events_batch = [] + + if events_to_flush: + logger.debug("Flushing %s telemetry events to server", len(events_to_flush)) + self._send_telemetry(events_to_flush) + + def _send_telemetry(self, events): + """Send telemetry events to the server""" + + request = { + "uploadTime": int(time.time() * 1000), + "items": [], + "protoLogs": [event.to_json() for event in events], + } + + path = ( + self.TELEMETRY_AUTHENTICATED_PATH + if self._auth_provider + else self.TELEMETRY_UNAUTHENTICATED_PATH + ) + url = f"https://{self._host_url}{path}" + + headers = {"Accept": "application/json", "Content-Type": "application/json"} + + if self._auth_provider: + self._auth_provider.add_headers(headers) + + try: + logger.debug("Submitting telemetry request to thread pool") + future = self._executor.submit( + requests.post, + url, + data=json.dumps(request), + headers=headers, + timeout=10, + ) + future.add_done_callback(self._telemetry_request_callback) + except Exception as e: + logger.debug("Failed to submit telemetry request: %s", e) + + def _telemetry_request_callback(self, future): + """Callback function to handle telemetry request completion""" + try: + response = future.result() + + if response.status_code == 200: + logger.debug("Telemetry request completed successfully") + else: + logger.debug( + "Telemetry request failed with status code: %s", + response.status_code, + ) + + except Exception as e: + logger.debug("Telemetry request failed with exception: %s", e) + + def export_initial_telemetry_log(self, driver_connection_params, user_agent): + logger.debug( + "Exporting initial telemetry log for connection %s", self._connection_uuid + ) + + self._driver_connection_params = driver_connection_params + self._user_agent = user_agent + + telemetry_frontend_log = TelemetryFrontendLog( + frontend_log_event_id=str(uuid.uuid4()), + context=FrontendLogContext( + client_context=TelemetryClientContext( + timestamp_millis=int(time.time() * 1000), + user_agent=self._user_agent, + ) + ), + entry=FrontendLogEntry( + sql_driver_log=TelemetryEvent( + session_id=self._connection_uuid, + system_configuration=TelemetryHelper.getDriverSystemConfiguration(), + driver_connection_params=self._driver_connection_params, + ) + ), + ) + + self.export_event(telemetry_frontend_log) + + def close(self): + """Flush remaining events before closing""" + logger.debug("Closing TelemetryClient for connection %s", self._connection_uuid) + self.flush() + TelemetryClientFactory.close(self._connection_uuid) + + +class TelemetryClientFactory: + """ + Static factory class for creating and managing telemetry clients. + It uses a thread pool to handle asynchronous operations. + """ + + _clients: Dict[ + str, BaseTelemetryClient + ] = {} # Map of connection_uuid -> BaseTelemetryClient + _executor: Optional[ThreadPoolExecutor] = None + _initialized: bool = False + _lock = threading.Lock() # Thread safety for factory operations + + @classmethod + def _initialize(cls): + """Initialize the factory if not already initialized""" + with cls._lock: + if not cls._initialized: + cls._clients = {} + cls._executor = ThreadPoolExecutor( + max_workers=10 + ) # Thread pool for async operations TODO: Decide on max workers + cls._initialized = True + logger.debug( + "TelemetryClientFactory initialized with thread pool (max_workers=10)" + ) + + @staticmethod + def initialize_telemetry_client( + telemetry_enabled, + connection_uuid, + auth_provider, + host_url, + ): + """Initialize a telemetry client for a specific connection if telemetry is enabled""" + TelemetryClientFactory._initialize() + + with TelemetryClientFactory._lock: + if connection_uuid not in TelemetryClientFactory._clients: + logger.debug( + "Creating new TelemetryClient for connection %s", connection_uuid + ) + if telemetry_enabled: + TelemetryClientFactory._clients[connection_uuid] = TelemetryClient( + telemetry_enabled=telemetry_enabled, + connection_uuid=connection_uuid, + auth_provider=auth_provider, + host_url=host_url, + executor=TelemetryClientFactory._executor, + ) + else: + TelemetryClientFactory._clients[ + connection_uuid + ] = NoopTelemetryClient() + + @staticmethod + def get_telemetry_client(connection_uuid): + """Get the telemetry client for a specific connection""" + if connection_uuid in TelemetryClientFactory._clients: + return TelemetryClientFactory._clients[connection_uuid] + else: + logger.error( + "Telemetry client not initialized for connection %s", connection_uuid + ) + return NoopTelemetryClient() + + @staticmethod + def close(connection_uuid): + """Close and remove the telemetry client for a specific connection""" + + with TelemetryClientFactory._lock: + if connection_uuid in TelemetryClientFactory._clients: + logger.debug( + "Removing telemetry client for connection %s", connection_uuid + ) + TelemetryClientFactory._clients.pop(connection_uuid, None) + + # Shutdown executor if no more clients + if not TelemetryClientFactory._clients and TelemetryClientFactory._executor: + logger.debug( + "No more telemetry clients, shutting down thread pool executor" + ) + TelemetryClientFactory._executor.shutdown(wait=True) + TelemetryClientFactory._executor = None + TelemetryClientFactory._initialized = False diff --git a/src/databricks/sql/telemetry/utils.py b/src/databricks/sql/telemetry/utils.py new file mode 100644 index 000000000..6a4d64eba --- /dev/null +++ b/src/databricks/sql/telemetry/utils.py @@ -0,0 +1,15 @@ +import json +from enum import Enum + + +class EnumEncoder(json.JSONEncoder): + """ + Custom JSON encoder to handle Enum values. + This is used to convert Enum values to their string representations. + Default JSON encoder raises a TypeError for Enums. + """ + + def default(self, obj): + if isinstance(obj, Enum): + return obj.value + return super().default(obj) diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py new file mode 100644 index 000000000..478205b18 --- /dev/null +++ b/tests/unit/test_telemetry.py @@ -0,0 +1,329 @@ +import uuid +import pytest +import requests +from unittest.mock import patch, MagicMock, call + +from databricks.sql.telemetry.telemetry_client import ( + TelemetryClient, + NoopTelemetryClient, + TelemetryClientFactory, +) +from databricks.sql.telemetry.models.enums import ( + AuthMech, + DatabricksClientType, +) +from databricks.sql.telemetry.models.event import ( + DriverConnectionParameters, + HostDetails, +) +from databricks.sql.auth.authenticators import ( + AccessTokenAuthProvider, +) + + +@pytest.fixture +def noop_telemetry_client(): + """Fixture for NoopTelemetryClient.""" + return NoopTelemetryClient() + + +@pytest.fixture +def telemetry_client_setup(): + """Fixture for TelemetryClient setup data.""" + connection_uuid = str(uuid.uuid4()) + auth_provider = AccessTokenAuthProvider("test-token") + host_url = "test-host" + executor = MagicMock() + + client = TelemetryClient( + telemetry_enabled=True, + connection_uuid=connection_uuid, + auth_provider=auth_provider, + host_url=host_url, + executor=executor, + ) + + return { + "client": client, + "connection_uuid": connection_uuid, + "auth_provider": auth_provider, + "host_url": host_url, + "executor": executor, + } + + +@pytest.fixture +def telemetry_factory_reset(): + """Fixture to reset TelemetryClientFactory state before each test.""" + # Reset the static class state before each test + TelemetryClientFactory._clients = {} + TelemetryClientFactory._executor = None + TelemetryClientFactory._initialized = False + yield + # Cleanup after test if needed + TelemetryClientFactory._clients = {} + if TelemetryClientFactory._executor: + TelemetryClientFactory._executor.shutdown(wait=True) + TelemetryClientFactory._executor = None + TelemetryClientFactory._initialized = False + + +class TestNoopTelemetryClient: + """Tests for the NoopTelemetryClient class.""" + + def test_singleton(self): + """Test that NoopTelemetryClient is a singleton.""" + client1 = NoopTelemetryClient() + client2 = NoopTelemetryClient() + assert client1 is client2 + + def test_export_initial_telemetry_log(self, noop_telemetry_client): + """Test that export_initial_telemetry_log does nothing.""" + noop_telemetry_client.export_initial_telemetry_log( + driver_connection_params=MagicMock(), user_agent="test" + ) + + def test_close(self, noop_telemetry_client): + """Test that close does nothing.""" + noop_telemetry_client.close() + + +class TestTelemetryClient: + """Tests for the TelemetryClient class.""" + + @patch("databricks.sql.telemetry.telemetry_client.TelemetryFrontendLog") + @patch("databricks.sql.telemetry.telemetry_client.TelemetryHelper.getDriverSystemConfiguration") + @patch("databricks.sql.telemetry.telemetry_client.uuid.uuid4") + @patch("databricks.sql.telemetry.telemetry_client.time.time") + def test_export_initial_telemetry_log( + self, + mock_time, + mock_uuid4, + mock_get_driver_config, + mock_frontend_log, + telemetry_client_setup + ): + """Test exporting initial telemetry log.""" + mock_time.return_value = 1000 + mock_uuid4.return_value = "test-uuid" + mock_get_driver_config.return_value = "test-driver-config" + mock_frontend_log.return_value = MagicMock() + + client = telemetry_client_setup["client"] + host_url = telemetry_client_setup["host_url"] + client.export_event = MagicMock() + + driver_connection_params = DriverConnectionParameters( + http_path="test-path", + mode=DatabricksClientType.THRIFT, + host_info=HostDetails(host_url=host_url, port=443), + auth_mech=AuthMech.PAT, + auth_flow=None, + ) + user_agent = "test-user-agent" + + client.export_initial_telemetry_log(driver_connection_params, user_agent) + + mock_frontend_log.assert_called_once() + client.export_event.assert_called_once_with(mock_frontend_log.return_value) + + def test_export_event(self, telemetry_client_setup): + """Test exporting an event.""" + client = telemetry_client_setup["client"] + client.flush = MagicMock() + + for i in range(5): + client.export_event(f"event-{i}") + + client.flush.assert_not_called() + assert len(client._events_batch) == 5 + + for i in range(5, 10): + client.export_event(f"event-{i}") + + client.flush.assert_called_once() + assert len(client._events_batch) == 10 + + @patch("requests.post") + def test_send_telemetry_authenticated(self, mock_post, telemetry_client_setup): + """Test sending telemetry to the server with authentication.""" + client = telemetry_client_setup["client"] + executor = telemetry_client_setup["executor"] + + events = [MagicMock(), MagicMock()] + events[0].to_json.return_value = '{"event": "1"}' + events[1].to_json.return_value = '{"event": "2"}' + + client._send_telemetry(events) + + executor.submit.assert_called_once() + args, kwargs = executor.submit.call_args + assert args[0] == requests.post + assert kwargs["timeout"] == 10 + assert "Authorization" in kwargs["headers"] + assert kwargs["headers"]["Authorization"] == "Bearer test-token" + + @patch("requests.post") + def test_send_telemetry_unauthenticated(self, mock_post, telemetry_client_setup): + """Test sending telemetry to the server without authentication.""" + host_url = telemetry_client_setup["host_url"] + executor = telemetry_client_setup["executor"] + + unauthenticated_client = TelemetryClient( + telemetry_enabled=True, + connection_uuid=str(uuid.uuid4()), + auth_provider=None, # No auth provider + host_url=host_url, + executor=executor, + ) + + events = [MagicMock(), MagicMock()] + events[0].to_json.return_value = '{"event": "1"}' + events[1].to_json.return_value = '{"event": "2"}' + + unauthenticated_client._send_telemetry(events) + + executor.submit.assert_called_once() + args, kwargs = executor.submit.call_args + assert args[0] == requests.post + assert kwargs["timeout"] == 10 + assert "Authorization" not in kwargs["headers"] # No auth header + assert kwargs["headers"]["Accept"] == "application/json" + assert kwargs["headers"]["Content-Type"] == "application/json" + + def test_flush(self, telemetry_client_setup): + """Test flushing events.""" + client = telemetry_client_setup["client"] + client._events_batch = ["event1", "event2"] + client._send_telemetry = MagicMock() + + client.flush() + + client._send_telemetry.assert_called_once_with(["event1", "event2"]) + assert client._events_batch == [] + + @patch("databricks.sql.telemetry.telemetry_client.TelemetryClientFactory") + def test_close(self, mock_factory_class, telemetry_client_setup): + """Test closing the client.""" + client = telemetry_client_setup["client"] + connection_uuid = telemetry_client_setup["connection_uuid"] + client.flush = MagicMock() + + client.close() + + client.flush.assert_called_once() + mock_factory_class.close.assert_called_once_with(connection_uuid) + + +class TestTelemetryClientFactory: + """Tests for the TelemetryClientFactory static class.""" + + @patch("databricks.sql.telemetry.telemetry_client.TelemetryClient") + def test_initialize_telemetry_client_enabled(self, mock_client_class, telemetry_factory_reset): + """Test initializing a telemetry client when telemetry is enabled.""" + connection_uuid = "test-uuid" + auth_provider = MagicMock() + host_url = "test-host" + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=True, + connection_uuid=connection_uuid, + auth_provider=auth_provider, + host_url=host_url, + ) + + # Verify a new client was created and stored + mock_client_class.assert_called_once_with( + telemetry_enabled=True, + connection_uuid=connection_uuid, + auth_provider=auth_provider, + host_url=host_url, + executor=TelemetryClientFactory._executor, + ) + assert TelemetryClientFactory._clients[connection_uuid] == mock_client + + # Call again with the same connection_uuid + client2 = TelemetryClientFactory.get_telemetry_client(connection_uuid=connection_uuid) + + # Verify the same client was returned and no new client was created + assert client2 == mock_client + mock_client_class.assert_called_once() # Still only called once + + def test_initialize_telemetry_client_disabled(self, telemetry_factory_reset): + """Test initializing a telemetry client when telemetry is disabled.""" + connection_uuid = "test-uuid" + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=False, + connection_uuid=connection_uuid, + auth_provider=MagicMock(), + host_url="test-host", + ) + + # Verify a NoopTelemetryClient was stored + assert isinstance(TelemetryClientFactory._clients[connection_uuid], NoopTelemetryClient) + + client2 = TelemetryClientFactory.get_telemetry_client(connection_uuid) + assert isinstance(client2, NoopTelemetryClient) + + def test_get_telemetry_client_existing(self, telemetry_factory_reset): + """Test getting an existing telemetry client.""" + connection_uuid = "test-uuid" + mock_client = MagicMock() + TelemetryClientFactory._clients[connection_uuid] = mock_client + + client = TelemetryClientFactory.get_telemetry_client(connection_uuid) + + assert client == mock_client + + def test_get_telemetry_client_nonexistent(self, telemetry_factory_reset): + """Test getting a non-existent telemetry client.""" + client = TelemetryClientFactory.get_telemetry_client("nonexistent-uuid") + + assert isinstance(client, NoopTelemetryClient) + + @patch("databricks.sql.telemetry.telemetry_client.ThreadPoolExecutor") + @patch("databricks.sql.telemetry.telemetry_client.TelemetryClient") + def test_close(self, mock_client_class, mock_executor_class, telemetry_factory_reset): + """Test that factory reinitializes properly after complete shutdown.""" + connection_uuid1 = "test-uuid1" + mock_executor1 = MagicMock() + mock_client1 = MagicMock() + mock_executor_class.return_value = mock_executor1 + mock_client_class.return_value = mock_client1 + + TelemetryClientFactory._clients[connection_uuid1] = mock_client1 + TelemetryClientFactory._executor = mock_executor1 + TelemetryClientFactory._initialized = True + + TelemetryClientFactory.close(connection_uuid1) + + assert TelemetryClientFactory._clients == {} + assert TelemetryClientFactory._executor is None + assert TelemetryClientFactory._initialized is False + mock_executor1.shutdown.assert_called_once_with(wait=True) + + # Now create a new client - this should reinitialize the factory + connection_uuid2 = "test-uuid2" + mock_executor2 = MagicMock() + mock_client2 = MagicMock() + mock_executor_class.return_value = mock_executor2 + mock_client_class.return_value = mock_client2 + + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=True, + connection_uuid=connection_uuid2, + auth_provider=MagicMock(), + host_url="test-host", + ) + + # Verify factory was reinitialized + assert TelemetryClientFactory._initialized is True + assert TelemetryClientFactory._executor is not None + assert TelemetryClientFactory._executor == mock_executor2 + assert connection_uuid2 in TelemetryClientFactory._clients + assert TelemetryClientFactory._clients[connection_uuid2] == mock_client2 + + # Verify new ThreadPoolExecutor was created + assert mock_executor_class.call_count == 1 \ No newline at end of file