diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 3a682984e..26705f3f8 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -18,6 +18,9 @@ OperationalError, SessionAlreadyClosedError, CursorAlreadyClosedError, + InterfaceError, + NotSupportedError, + ProgrammingError, ) from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.thrift_backend import ThriftBackend @@ -50,8 +53,8 @@ TOperationState, ) from databricks.sql.telemetry.telemetry_client import ( - TelemetryClientFactory, TelemetryHelper, + TelemetryClientFactory, ) from databricks.sql.telemetry.models.enums import DatabricksClientType from databricks.sql.telemetry.models.event import ( @@ -305,13 +308,13 @@ def read(self) -> Optional[OAuthToken]: TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=self.telemetry_enabled, - connection_uuid=self.get_session_id_hex(), + session_id_hex=self.get_session_id_hex(), auth_provider=auth_provider, host_url=self.host, ) self._telemetry_client = TelemetryClientFactory.get_telemetry_client( - connection_uuid=self.get_session_id_hex() + session_id_hex=self.get_session_id_hex() ) driver_connection_params = DriverConnectionParameters( @@ -421,7 +424,10 @@ def cursor( Will throw an Error if the connection has been closed. """ if not self.open: - raise Error("Cannot create cursor from closed connection") + raise InterfaceError( + "Cannot create cursor from closed connection", + session_id_hex=self.get_session_id_hex(), + ) cursor = Cursor( self, @@ -464,14 +470,17 @@ def _close(self, close_cursors=True) -> None: self.open = False - self._telemetry_client.close() + TelemetryClientFactory.close(self.get_session_id_hex()) def commit(self): """No-op because Databricks does not support transactions""" pass def rollback(self): - raise NotSupportedError("Transactions are not supported on Databricks") + raise NotSupportedError( + "Transactions are not supported on Databricks", + session_id_hex=self.get_session_id_hex(), + ) class Cursor: @@ -523,7 +532,10 @@ def __iter__(self): for row in self.active_result_set: yield row else: - raise Error("There is no active result set") + raise ProgrammingError( + "There is no active result set", + session_id_hex=self.connection.get_session_id_hex(), + ) def _determine_parameter_approach( self, params: Optional[TParameterCollection] @@ -660,7 +672,10 @@ def _close_and_clear_active_result_set(self): def _check_not_closed(self): if not self.open: - raise Error("Attempting operation on closed cursor") + raise InterfaceError( + "Attempting operation on closed cursor", + session_id_hex=self.connection.get_session_id_hex(), + ) def _handle_staging_operation( self, staging_allowed_local_path: Union[None, str, List[str]] @@ -677,8 +692,9 @@ def _handle_staging_operation( elif isinstance(staging_allowed_local_path, type(list())): _staging_allowed_local_paths = staging_allowed_local_path else: - raise Error( - "You must provide at least one staging_allowed_local_path when initialising a connection to perform ingestion commands" + raise ProgrammingError( + "You must provide at least one staging_allowed_local_path when initialising a connection to perform ingestion commands", + session_id_hex=self.connection.get_session_id_hex(), ) abs_staging_allowed_local_paths = [ @@ -706,8 +722,9 @@ def _handle_staging_operation( else: continue if not allow_operation: - raise Error( - "Local file operations are restricted to paths within the configured staging_allowed_local_path" + raise ProgrammingError( + "Local file operations are restricted to paths within the configured staging_allowed_local_path", + session_id_hex=self.connection.get_session_id_hex(), ) # May be real headers, or could be json string @@ -735,9 +752,10 @@ def _handle_staging_operation( handler_args.pop("local_file") return self._handle_staging_remove(**handler_args) else: - raise Error( + raise ProgrammingError( f"Operation {row.operation} is not supported. " - + "Supported operations are GET, PUT, and REMOVE" + + "Supported operations are GET, PUT, and REMOVE", + session_id_hex=self.connection.get_session_id_hex(), ) def _handle_staging_put( @@ -749,7 +767,10 @@ def _handle_staging_put( """ if local_file is None: - raise Error("Cannot perform PUT without specifying a local_file") + raise ProgrammingError( + "Cannot perform PUT without specifying a local_file", + session_id_hex=self.connection.get_session_id_hex(), + ) with open(local_file, "rb") as fh: r = requests.put(url=presigned_url, data=fh, headers=headers) @@ -765,8 +786,9 @@ def _handle_staging_put( # fmt: on if r.status_code not in [OK, CREATED, NO_CONTENT, ACCEPTED]: - raise Error( - f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}" + raise OperationalError( + f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}", + session_id_hex=self.connection.get_session_id_hex(), ) if r.status_code == ACCEPTED: @@ -784,15 +806,19 @@ def _handle_staging_get( """ if local_file is None: - raise Error("Cannot perform GET without specifying a local_file") + raise ProgrammingError( + "Cannot perform GET without specifying a local_file", + session_id_hex=self.connection.get_session_id_hex(), + ) r = requests.get(url=presigned_url, headers=headers) # response.ok verifies the status code is not between 400-600. # Any 2xx or 3xx will evaluate r.ok == True if not r.ok: - raise Error( - f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}" + raise OperationalError( + f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}", + session_id_hex=self.connection.get_session_id_hex(), ) with open(local_file, "wb") as fp: @@ -806,8 +832,9 @@ def _handle_staging_remove( r = requests.delete(url=presigned_url, headers=headers) if not r.ok: - raise Error( - f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}" + raise OperationalError( + f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}", + session_id_hex=self.connection.get_session_id_hex(), ) def execute( @@ -1005,8 +1032,9 @@ def get_async_execution_result(self): return self else: - raise Error( - f"get_execution_result failed with Operation status {operation_state}" + raise OperationalError( + f"get_execution_result failed with Operation status {operation_state}", + session_id_hex=self.connection.get_session_id_hex(), ) def executemany(self, operation, seq_of_parameters): @@ -1156,7 +1184,10 @@ def fetchall(self) -> List[Row]: if self.active_result_set: return self.active_result_set.fetchall() else: - raise Error("There is no active result set") + raise ProgrammingError( + "There is no active result set", + session_id_hex=self.connection.get_session_id_hex(), + ) def fetchone(self) -> Optional[Row]: """ @@ -1170,7 +1201,10 @@ def fetchone(self) -> Optional[Row]: if self.active_result_set: return self.active_result_set.fetchone() else: - raise Error("There is no active result set") + raise ProgrammingError( + "There is no active result set", + session_id_hex=self.connection.get_session_id_hex(), + ) def fetchmany(self, size: int) -> List[Row]: """ @@ -1192,21 +1226,30 @@ def fetchmany(self, size: int) -> List[Row]: if self.active_result_set: return self.active_result_set.fetchmany(size) else: - raise Error("There is no active result set") + raise ProgrammingError( + "There is no active result set", + session_id_hex=self.connection.get_session_id_hex(), + ) def fetchall_arrow(self) -> "pyarrow.Table": self._check_not_closed() if self.active_result_set: return self.active_result_set.fetchall_arrow() else: - raise Error("There is no active result set") + raise ProgrammingError( + "There is no active result set", + session_id_hex=self.connection.get_session_id_hex(), + ) def fetchmany_arrow(self, size) -> "pyarrow.Table": self._check_not_closed() if self.active_result_set: return self.active_result_set.fetchmany_arrow(size) else: - raise Error("There is no active result set") + raise ProgrammingError( + "There is no active result set", + session_id_hex=self.connection.get_session_id_hex(), + ) def cancel(self) -> None: """ diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index 3b27283a4..30fd6c26d 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -1,21 +1,32 @@ import json import logging -logger = logging.getLogger(__name__) +from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory +logger = logging.getLogger(__name__) ### PEP-249 Mandated ### +# https://peps.python.org/pep-0249/#exceptions class Error(Exception): """Base class for DB-API2.0 exceptions. `message`: An optional user-friendly error message. It should be short, actionable and stable `context`: Optional extra context about the error. MUST be JSON serializable """ - def __init__(self, message=None, context=None, *args, **kwargs): + def __init__( + self, message=None, context=None, session_id_hex=None, *args, **kwargs + ): super().__init__(message, *args, **kwargs) self.message = message self.context = context or {} + error_name = self.__class__.__name__ + if session_id_hex: + telemetry_client = TelemetryClientFactory.get_telemetry_client( + session_id_hex + ) + telemetry_client.export_failure_log(error_name, self.message) + def __str__(self): return self.message diff --git a/src/databricks/sql/telemetry/models/event.py b/src/databricks/sql/telemetry/models/event.py index 4429a7626..f5496deec 100644 --- a/src/databricks/sql/telemetry/models/event.py +++ b/src/databricks/sql/telemetry/models/event.py @@ -1,5 +1,4 @@ -import json -from dataclasses import dataclass, asdict +from dataclasses import dataclass from databricks.sql.telemetry.models.enums import ( AuthMech, AuthFlow, @@ -9,11 +8,11 @@ ExecutionResultFormat, ) from typing import Optional -from databricks.sql.telemetry.utils import EnumEncoder +from databricks.sql.telemetry.utils import JsonSerializableMixin @dataclass -class HostDetails: +class HostDetails(JsonSerializableMixin): """ Represents the host connection details for a Databricks workspace. @@ -25,12 +24,9 @@ class HostDetails: host_url: str port: int - def to_json(self): - return json.dumps(asdict(self)) - @dataclass -class DriverConnectionParameters: +class DriverConnectionParameters(JsonSerializableMixin): """ Contains all connection parameters used to establish a connection to Databricks SQL. This includes authentication details, host information, and connection settings. @@ -51,12 +47,9 @@ class DriverConnectionParameters: auth_flow: Optional[AuthFlow] = None socket_timeout: Optional[int] = None - def to_json(self): - return json.dumps(asdict(self), cls=EnumEncoder) - @dataclass -class DriverSystemConfiguration: +class DriverSystemConfiguration(JsonSerializableMixin): """ Contains system-level configuration information about the client environment. This includes details about the operating system, runtime, and driver version. @@ -87,12 +80,9 @@ class DriverSystemConfiguration: client_app_name: Optional[str] = None locale_name: Optional[str] = None - def to_json(self): - return json.dumps(asdict(self), cls=EnumEncoder) - @dataclass -class DriverVolumeOperation: +class DriverVolumeOperation(JsonSerializableMixin): """ Represents a volume operation performed by the driver. Used for tracking volume-related operations in telemetry. @@ -105,12 +95,9 @@ class DriverVolumeOperation: volume_operation_type: DriverVolumeOperationType volume_path: str - def to_json(self): - return json.dumps(asdict(self), cls=EnumEncoder) - @dataclass -class DriverErrorInfo: +class DriverErrorInfo(JsonSerializableMixin): """ Contains detailed information about errors that occur during driver operations. Used for error tracking and debugging in telemetry. @@ -123,12 +110,9 @@ class DriverErrorInfo: error_name: str stack_trace: str - def to_json(self): - return json.dumps(asdict(self), cls=EnumEncoder) - @dataclass -class SqlExecutionEvent: +class SqlExecutionEvent(JsonSerializableMixin): """ Represents a SQL query execution event. Contains details about the query execution, including type, compression, and result format. @@ -145,12 +129,9 @@ class SqlExecutionEvent: execution_result: ExecutionResultFormat retry_count: int - def to_json(self): - return json.dumps(asdict(self), cls=EnumEncoder) - @dataclass -class TelemetryEvent: +class TelemetryEvent(JsonSerializableMixin): """ Main telemetry event class that aggregates all telemetry data. Contains information about the session, system configuration, connection parameters, @@ -177,6 +158,3 @@ class TelemetryEvent: sql_operation: Optional[SqlExecutionEvent] = None error_info: Optional[DriverErrorInfo] = None operation_latency_ms: Optional[int] = None - - def to_json(self): - return json.dumps(asdict(self), cls=EnumEncoder) diff --git a/src/databricks/sql/telemetry/models/frontend_logs.py b/src/databricks/sql/telemetry/models/frontend_logs.py index 36086a7cc..4cc314ec3 100644 --- a/src/databricks/sql/telemetry/models/frontend_logs.py +++ b/src/databricks/sql/telemetry/models/frontend_logs.py @@ -1,12 +1,11 @@ -import json -from dataclasses import dataclass, asdict +from dataclasses import dataclass from databricks.sql.telemetry.models.event import TelemetryEvent -from databricks.sql.telemetry.utils import EnumEncoder +from databricks.sql.telemetry.utils import JsonSerializableMixin from typing import Optional @dataclass -class TelemetryClientContext: +class TelemetryClientContext(JsonSerializableMixin): """ Contains client-side context information for telemetry events. This includes timestamp and user agent information for tracking when and how the client is being used. @@ -19,12 +18,9 @@ class TelemetryClientContext: timestamp_millis: int user_agent: str - def to_json(self): - return json.dumps(asdict(self), cls=EnumEncoder) - @dataclass -class FrontendLogContext: +class FrontendLogContext(JsonSerializableMixin): """ Wrapper for client context information in frontend logs. Provides additional context about the client environment for telemetry events. @@ -35,12 +31,9 @@ class FrontendLogContext: client_context: TelemetryClientContext - def to_json(self): - return json.dumps(asdict(self), cls=EnumEncoder) - @dataclass -class FrontendLogEntry: +class FrontendLogEntry(JsonSerializableMixin): """ Contains the actual telemetry event data in a frontend log. Wraps the SQL driver log information for frontend processing. @@ -51,12 +44,9 @@ class FrontendLogEntry: sql_driver_log: TelemetryEvent - def to_json(self): - return json.dumps(asdict(self), cls=EnumEncoder) - @dataclass -class TelemetryFrontendLog: +class TelemetryFrontendLog(JsonSerializableMixin): """ Main container for frontend telemetry data. Aggregates workspace information, event ID, context, and the actual log entry. @@ -73,6 +63,3 @@ class TelemetryFrontendLog: context: FrontendLogContext entry: FrontendLogEntry workspace_id: Optional[int] = None - - def to_json(self): - return json.dumps(asdict(self), cls=EnumEncoder) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index d095d685c..f7fccf47a 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -8,6 +8,7 @@ from databricks.sql.telemetry.models.event import ( TelemetryEvent, DriverSystemConfiguration, + DriverErrorInfo, ) from databricks.sql.telemetry.models.frontend_logs import ( TelemetryFrontendLog, @@ -26,7 +27,6 @@ import uuid import locale from abc import ABC, abstractmethod -from databricks.sql import __version__ logger = logging.getLogger(__name__) @@ -34,22 +34,26 @@ class TelemetryHelper: """Helper class for getting telemetry related information.""" - _DRIVER_SYSTEM_CONFIGURATION = DriverSystemConfiguration( - driver_name="Databricks SQL Python Connector", - driver_version=__version__, - runtime_name=f"Python {sys.version.split()[0]}", - runtime_vendor=platform.python_implementation(), - runtime_version=platform.python_version(), - os_name=platform.system(), - os_version=platform.release(), - os_arch=platform.machine(), - client_app_name=None, # TODO: Add client app name - locale_name=locale.getlocale()[0] or locale.getdefaultlocale()[0], - char_set_encoding=sys.getdefaultencoding(), - ) + _DRIVER_SYSTEM_CONFIGURATION = None @classmethod - def getDriverSystemConfiguration(cls) -> DriverSystemConfiguration: + def get_driver_system_configuration(cls) -> DriverSystemConfiguration: + if cls._DRIVER_SYSTEM_CONFIGURATION is None: + from databricks.sql import __version__ + + cls._DRIVER_SYSTEM_CONFIGURATION = DriverSystemConfiguration( + driver_name="Databricks SQL Python Connector", + driver_version=__version__, + runtime_name=f"Python {sys.version.split()[0]}", + runtime_vendor=platform.python_implementation(), + runtime_version=platform.python_version(), + os_name=platform.system(), + os_version=platform.release(), + os_arch=platform.machine(), + client_app_name=None, # TODO: Add client app name + locale_name=locale.getlocale()[0] or locale.getdefaultlocale()[0], + char_set_encoding=sys.getdefaultencoding(), + ) return cls._DRIVER_SYSTEM_CONFIGURATION @staticmethod @@ -99,12 +103,18 @@ class BaseTelemetryClient(ABC): """ @abstractmethod - def export_initial_telemetry_log(self, **kwargs): - pass + def export_initial_telemetry_log(self, driver_connection_params, user_agent): + raise NotImplementedError( + "Subclasses must implement export_initial_telemetry_log" + ) + + @abstractmethod + def export_failure_log(self, error_name, error_message): + raise NotImplementedError("Subclasses must implement export_failure_log") @abstractmethod def close(self): - pass + raise NotImplementedError("Subclasses must implement close") class NoopTelemetryClient(BaseTelemetryClient): @@ -123,6 +133,9 @@ def __new__(cls): def export_initial_telemetry_log(self, driver_connection_params, user_agent): pass + def export_failure_log(self, error_name, error_message): + pass + def close(self): pass @@ -140,15 +153,15 @@ class TelemetryClient(BaseTelemetryClient): def __init__( self, telemetry_enabled, - connection_uuid, + session_id_hex, auth_provider, host_url, executor, ): - logger.debug("Initializing TelemetryClient for connection: %s", connection_uuid) + logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex) self._telemetry_enabled = telemetry_enabled self._batch_size = 10 # TODO: Decide on batch size - self._connection_uuid = connection_uuid + self._session_id_hex = session_id_hex self._auth_provider = auth_provider self._user_agent = None self._events_batch = [] @@ -157,18 +170,18 @@ def __init__( self._host_url = host_url self._executor = executor - def export_event(self, event): + def _export_event(self, event): """Add an event to the batch queue and flush if batch is full""" - logger.debug("Exporting event for connection %s", self._connection_uuid) + logger.debug("Exporting event for connection %s", self._session_id_hex) with self._lock: self._events_batch.append(event) if len(self._events_batch) >= self._batch_size: logger.debug( "Batch size limit reached (%s), flushing events", self._batch_size ) - self.flush() + self._flush() - def flush(self): + def _flush(self): """Flush the current batch of events to the server""" with self._lock: events_to_flush = self._events_batch.copy() @@ -230,36 +243,66 @@ def _telemetry_request_callback(self, future): def export_initial_telemetry_log(self, driver_connection_params, user_agent): logger.debug( - "Exporting initial telemetry log for connection %s", self._connection_uuid + "Exporting initial telemetry log for connection %s", self._session_id_hex ) - self._driver_connection_params = driver_connection_params - self._user_agent = user_agent + try: + self._driver_connection_params = driver_connection_params + self._user_agent = user_agent + + telemetry_frontend_log = TelemetryFrontendLog( + frontend_log_event_id=str(uuid.uuid4()), + context=FrontendLogContext( + client_context=TelemetryClientContext( + timestamp_millis=int(time.time() * 1000), + user_agent=self._user_agent, + ) + ), + entry=FrontendLogEntry( + sql_driver_log=TelemetryEvent( + session_id=self._session_id_hex, + system_configuration=TelemetryHelper.get_driver_system_configuration(), + driver_connection_params=self._driver_connection_params, + ) + ), + ) + + self._export_event(telemetry_frontend_log) - telemetry_frontend_log = TelemetryFrontendLog( - frontend_log_event_id=str(uuid.uuid4()), - context=FrontendLogContext( - client_context=TelemetryClientContext( - timestamp_millis=int(time.time() * 1000), - user_agent=self._user_agent, - ) - ), - entry=FrontendLogEntry( - sql_driver_log=TelemetryEvent( - session_id=self._connection_uuid, - system_configuration=TelemetryHelper.getDriverSystemConfiguration(), - driver_connection_params=self._driver_connection_params, - ) - ), - ) + except Exception as e: + logger.debug("Failed to export initial telemetry log: %s", e) - self.export_event(telemetry_frontend_log) + def export_failure_log(self, error_name, error_message): + logger.debug("Exporting failure log for connection %s", self._session_id_hex) + try: + error_info = DriverErrorInfo( + error_name=error_name, stack_trace=error_message + ) + telemetry_frontend_log = TelemetryFrontendLog( + frontend_log_event_id=str(uuid.uuid4()), + context=FrontendLogContext( + client_context=TelemetryClientContext( + timestamp_millis=int(time.time() * 1000), + user_agent=self._user_agent, + ) + ), + entry=FrontendLogEntry( + sql_driver_log=TelemetryEvent( + session_id=self._session_id_hex, + system_configuration=TelemetryHelper.get_driver_system_configuration(), + driver_connection_params=self._driver_connection_params, + error_info=error_info, + ) + ), + ) + self._export_event(telemetry_frontend_log) + except Exception as e: + logger.debug("Failed to export failure log: %s", e) def close(self): """Flush remaining events before closing""" - logger.debug("Closing TelemetryClient for connection %s", self._connection_uuid) - self.flush() - TelemetryClientFactory.close(self._connection_uuid) + logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex) + self._flush() class TelemetryClientFactory: @@ -270,74 +313,117 @@ class TelemetryClientFactory: _clients: Dict[ str, BaseTelemetryClient - ] = {} # Map of connection_uuid -> BaseTelemetryClient + ] = {} # Map of session_id_hex -> BaseTelemetryClient _executor: Optional[ThreadPoolExecutor] = None _initialized: bool = False _lock = threading.Lock() # Thread safety for factory operations + _original_excepthook = None + _excepthook_installed = False @classmethod def _initialize(cls): """Initialize the factory if not already initialized""" - with cls._lock: - if not cls._initialized: - cls._clients = {} - cls._executor = ThreadPoolExecutor( - max_workers=10 - ) # Thread pool for async operations TODO: Decide on max workers - cls._initialized = True - logger.debug( - "TelemetryClientFactory initialized with thread pool (max_workers=10)" - ) + + if not cls._initialized: + cls._clients = {} + cls._executor = ThreadPoolExecutor( + max_workers=10 + ) # Thread pool for async operations TODO: Decide on max workers + cls._install_exception_hook() + cls._initialized = True + logger.debug( + "TelemetryClientFactory initialized with thread pool (max_workers=10)" + ) + + @classmethod + def _install_exception_hook(cls): + """Install global exception handler for unhandled exceptions""" + if not cls._excepthook_installed: + cls._original_excepthook = sys.excepthook + sys.excepthook = cls._handle_unhandled_exception + cls._excepthook_installed = True + logger.debug("Global exception handler installed for telemetry") + + @classmethod + def _handle_unhandled_exception(cls, exc_type, exc_value, exc_traceback): + """Handle unhandled exceptions by sending telemetry and flushing thread pool""" + logger.debug("Handling unhandled exception: %s", exc_type.__name__) + + clients_to_close = list(cls._clients.values()) + for client in clients_to_close: + client.close() + + # Call the original exception handler to maintain normal behavior + if cls._original_excepthook: + cls._original_excepthook(exc_type, exc_value, exc_traceback) @staticmethod def initialize_telemetry_client( telemetry_enabled, - connection_uuid, + session_id_hex, auth_provider, host_url, ): """Initialize a telemetry client for a specific connection if telemetry is enabled""" - TelemetryClientFactory._initialize() + try: - with TelemetryClientFactory._lock: - if connection_uuid not in TelemetryClientFactory._clients: - logger.debug( - "Creating new TelemetryClient for connection %s", connection_uuid - ) - if telemetry_enabled: - TelemetryClientFactory._clients[connection_uuid] = TelemetryClient( - telemetry_enabled=telemetry_enabled, - connection_uuid=connection_uuid, - auth_provider=auth_provider, - host_url=host_url, - executor=TelemetryClientFactory._executor, + with TelemetryClientFactory._lock: + TelemetryClientFactory._initialize() + + if session_id_hex not in TelemetryClientFactory._clients: + logger.debug( + "Creating new TelemetryClient for connection %s", + session_id_hex, ) - else: - TelemetryClientFactory._clients[ - connection_uuid - ] = NoopTelemetryClient() + if telemetry_enabled: + TelemetryClientFactory._clients[ + session_id_hex + ] = TelemetryClient( + telemetry_enabled=telemetry_enabled, + session_id_hex=session_id_hex, + auth_provider=auth_provider, + host_url=host_url, + executor=TelemetryClientFactory._executor, + ) + else: + TelemetryClientFactory._clients[ + session_id_hex + ] = NoopTelemetryClient() + except Exception as e: + logger.debug("Failed to initialize telemetry client: %s", e) + # Fallback to NoopTelemetryClient to ensure connection doesn't fail + TelemetryClientFactory._clients[session_id_hex] = NoopTelemetryClient() @staticmethod - def get_telemetry_client(connection_uuid): + def get_telemetry_client(session_id_hex): """Get the telemetry client for a specific connection""" - if connection_uuid in TelemetryClientFactory._clients: - return TelemetryClientFactory._clients[connection_uuid] - else: - logger.error( - "Telemetry client not initialized for connection %s", connection_uuid - ) + try: + if session_id_hex in TelemetryClientFactory._clients: + return TelemetryClientFactory._clients[session_id_hex] + else: + logger.error( + "Telemetry client not initialized for connection %s", + session_id_hex, + ) + return NoopTelemetryClient() + except Exception as e: + logger.debug("Failed to get telemetry client: %s", e) return NoopTelemetryClient() @staticmethod - def close(connection_uuid): + def close(session_id_hex): """Close and remove the telemetry client for a specific connection""" with TelemetryClientFactory._lock: - if connection_uuid in TelemetryClientFactory._clients: + if ( + telemetry_client := TelemetryClientFactory._clients.pop( + session_id_hex, None + ) + ) is not None: logger.debug( - "Removing telemetry client for connection %s", connection_uuid + "Removing telemetry client for connection %s", session_id_hex ) - TelemetryClientFactory._clients.pop(connection_uuid, None) + telemetry_client.close() # Shutdown executor if no more clients if not TelemetryClientFactory._clients and TelemetryClientFactory._executor: diff --git a/src/databricks/sql/telemetry/utils.py b/src/databricks/sql/telemetry/utils.py index 6a4d64eba..df7acf28c 100644 --- a/src/databricks/sql/telemetry/utils.py +++ b/src/databricks/sql/telemetry/utils.py @@ -1,5 +1,28 @@ import json from enum import Enum +from dataclasses import asdict, is_dataclass + + +class JsonSerializableMixin: + """Mixin class to provide JSON serialization capabilities to dataclasses.""" + + def to_json(self) -> str: + """ + Convert the object to a JSON string, excluding None values. + Handles Enum serialization and filters out None values from the output. + """ + if not is_dataclass(self): + raise TypeError( + f"{self.__class__.__name__} must be a dataclass to use JsonSerializableMixin" + ) + + return json.dumps( + asdict( + self, + dict_factory=lambda data: {k: v for k, v in data if v is not None}, + ), + cls=EnumEncoder, + ) class EnumEncoder(json.JSONEncoder): diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index e3dc38ad5..78683ac31 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -223,6 +223,7 @@ def __init__( raise self._request_lock = threading.RLock() + self._session_id_hex = None # TODO: Move this bounding logic into DatabricksRetryPolicy for v3 (PECO-918) def _initialize_retry_args(self, kwargs): @@ -255,12 +256,15 @@ def _initialize_retry_args(self, kwargs): ) @staticmethod - def _check_response_for_error(response): + def _check_response_for_error(response, session_id_hex=None): if response.status and response.status.statusCode in [ ttypes.TStatusCode.ERROR_STATUS, ttypes.TStatusCode.INVALID_HANDLE_STATUS, ]: - raise DatabaseError(response.status.errorMessage) + raise DatabaseError( + response.status.errorMessage, + session_id_hex=session_id_hex, + ) @staticmethod def _extract_error_message_from_headers(headers): @@ -311,7 +315,10 @@ def _handle_request_error(self, error_info, attempt, elapsed): no_retry_reason, attempt, elapsed ) network_request_error = RequestError( - user_friendly_error_message, full_error_info_context, error_info.error + user_friendly_error_message, + full_error_info_context, + self._session_id_hex, + error_info.error, ) logger.info(network_request_error.message_with_context()) @@ -483,7 +490,7 @@ def attempt_request(attempt): if not isinstance(response_or_error_info, RequestErrorInfo): # log nothing here, presume that main request logging covers response = response_or_error_info - ThriftBackend._check_response_for_error(response) + ThriftBackend._check_response_for_error(response, self._session_id_hex) return response error_info = response_or_error_info @@ -497,7 +504,8 @@ def _check_protocol_version(self, t_open_session_resp): raise OperationalError( "Error: expected server to use a protocol version >= " "SPARK_CLI_SERVICE_PROTOCOL_V2, " - "instead got: {}".format(protocol_version) + "instead got: {}".format(protocol_version), + session_id_hex=self._session_id_hex, ) def _check_initial_namespace(self, catalog, schema, response): @@ -510,14 +518,16 @@ def _check_initial_namespace(self, catalog, schema, response): ): raise InvalidServerResponseError( "Setting initial namespace not supported by the DBR version, " - "Please use a Databricks SQL endpoint or a cluster with DBR >= 9.0." + "Please use a Databricks SQL endpoint or a cluster with DBR >= 9.0.", + session_id_hex=self._session_id_hex, ) if catalog: if not response.canUseMultipleCatalogs: raise InvalidServerResponseError( "Unexpected response from server: Trying to set initial catalog to {}, " - + "but server does not support multiple catalogs.".format(catalog) # type: ignore + + "but server does not support multiple catalogs.".format(catalog), # type: ignore + session_id_hex=self._session_id_hex, ) def _check_session_configuration(self, session_configuration): @@ -531,7 +541,8 @@ def _check_session_configuration(self, session_configuration): "while using the Databricks SQL connector, it must be false not {}".format( TIMESTAMP_AS_STRING_CONFIG, session_configuration[TIMESTAMP_AS_STRING_CONFIG], - ) + ), + session_id_hex=self._session_id_hex, ) def open_session(self, session_configuration, catalog, schema): @@ -562,6 +573,11 @@ def open_session(self, session_configuration, catalog, schema): response = self.make_request(self._client.OpenSession, open_session_req) self._check_initial_namespace(catalog, schema, response) self._check_protocol_version(response) + self._session_id_hex = ( + self.handle_to_hex_id(response.sessionHandle) + if response.sessionHandle + else None + ) return response except: self._transport.close() @@ -586,6 +602,7 @@ def _check_command_not_in_error_or_closed_state( and self.guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": get_operations_resp.diagnosticInfo, }, + session_id_hex=self._session_id_hex, ) else: raise ServerOperationError( @@ -595,6 +612,7 @@ def _check_command_not_in_error_or_closed_state( and self.guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": None, }, + session_id_hex=self._session_id_hex, ) elif get_operations_resp.operationState == ttypes.TOperationState.CLOSED_STATE: raise DatabaseError( @@ -605,6 +623,7 @@ def _check_command_not_in_error_or_closed_state( "operation-id": op_handle and self.guid_to_hex_id(op_handle.operationId.guid) }, + session_id_hex=self._session_id_hex, ) def _poll_for_status(self, op_handle): @@ -625,7 +644,10 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti t_row_set.arrowBatches, lz4_compressed, schema_bytes ) else: - raise OperationalError("Unsupported TRowSet instance {}".format(t_row_set)) + raise OperationalError( + "Unsupported TRowSet instance {}".format(t_row_set), + session_id_hex=self._session_id_hex, + ) return convert_decimals_in_arrow_table(arrow_table, description), num_rows def _get_metadata_resp(self, op_handle): @@ -633,7 +655,7 @@ def _get_metadata_resp(self, op_handle): return self.make_request(self._client.GetResultSetMetadata, req) @staticmethod - def _hive_schema_to_arrow_schema(t_table_schema): + def _hive_schema_to_arrow_schema(t_table_schema, session_id_hex=None): def map_type(t_type_entry): if t_type_entry.primitiveEntry: return { @@ -664,7 +686,8 @@ def map_type(t_type_entry): # Current thriftserver implementation should always return a primitiveEntry, # even for complex types raise OperationalError( - "Thrift protocol error: t_type_entry not a primitiveEntry" + "Thrift protocol error: t_type_entry not a primitiveEntry", + session_id_hex=session_id_hex, ) def convert_col(t_column_desc): @@ -675,7 +698,7 @@ def convert_col(t_column_desc): return pyarrow.schema([convert_col(col) for col in t_table_schema.columns]) @staticmethod - def _col_to_description(col): + def _col_to_description(col, session_id_hex=None): type_entry = col.typeDesc.types[0] if type_entry.primitiveEntry: @@ -684,7 +707,8 @@ def _col_to_description(col): cleaned_type = (name[:-5] if name.endswith("_TYPE") else name).lower() else: raise OperationalError( - "Thrift protocol error: t_type_entry not a primitiveEntry" + "Thrift protocol error: t_type_entry not a primitiveEntry", + session_id_hex=session_id_hex, ) if type_entry.primitiveEntry.type == ttypes.TTypeId.DECIMAL_TYPE: @@ -697,7 +721,8 @@ def _col_to_description(col): else: raise OperationalError( "Decimal type did not provide typeQualifier precision, scale in " - "primitiveEntry {}".format(type_entry.primitiveEntry) + "primitiveEntry {}".format(type_entry.primitiveEntry), + session_id_hex=session_id_hex, ) else: precision, scale = None, None @@ -705,9 +730,10 @@ def _col_to_description(col): return col.columnName, cleaned_type, None, None, precision, scale, None @staticmethod - def _hive_schema_to_description(t_table_schema): + def _hive_schema_to_description(t_table_schema, session_id_hex=None): return [ - ThriftBackend._col_to_description(col) for col in t_table_schema.columns + ThriftBackend._col_to_description(col, session_id_hex) + for col in t_table_schema.columns ] def _results_message_to_execute_response(self, resp, operation_state): @@ -727,7 +753,8 @@ def _results_message_to_execute_response(self, resp, operation_state): ttypes.TSparkRowSetType._VALUES_TO_NAMES[ t_result_set_metadata_resp.resultFormat ] - ) + ), + session_id_hex=self._session_id_hex, ) direct_results = resp.directResults has_been_closed_server_side = direct_results and direct_results.closeOperation @@ -737,13 +764,16 @@ def _results_message_to_execute_response(self, resp, operation_state): or direct_results.resultSet.hasMoreRows ) description = self._hive_schema_to_description( - t_result_set_metadata_resp.schema + t_result_set_metadata_resp.schema, + self._session_id_hex, ) if pyarrow: schema_bytes = ( t_result_set_metadata_resp.arrowSchema - or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema) + or self._hive_schema_to_arrow_schema( + t_result_set_metadata_resp.schema, self._session_id_hex + ) .serialize() .to_pybytes() ) @@ -804,13 +834,16 @@ def get_execution_result(self, op_handle, cursor): is_staging_operation = t_result_set_metadata_resp.isStagingOperation has_more_rows = resp.hasMoreRows description = self._hive_schema_to_description( - t_result_set_metadata_resp.schema + t_result_set_metadata_resp.schema, + self._session_id_hex, ) if pyarrow: schema_bytes = ( t_result_set_metadata_resp.arrowSchema - or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema) + or self._hive_schema_to_arrow_schema( + t_result_set_metadata_resp.schema, self._session_id_hex + ) .serialize() .to_pybytes() ) @@ -864,23 +897,27 @@ def get_query_state(self, op_handle) -> "TOperationState": return operation_state @staticmethod - def _check_direct_results_for_error(t_spark_direct_results): + def _check_direct_results_for_error(t_spark_direct_results, session_id_hex=None): if t_spark_direct_results: if t_spark_direct_results.operationStatus: ThriftBackend._check_response_for_error( - t_spark_direct_results.operationStatus + t_spark_direct_results.operationStatus, + session_id_hex, ) if t_spark_direct_results.resultSetMetadata: ThriftBackend._check_response_for_error( - t_spark_direct_results.resultSetMetadata + t_spark_direct_results.resultSetMetadata, + session_id_hex, ) if t_spark_direct_results.resultSet: ThriftBackend._check_response_for_error( - t_spark_direct_results.resultSet + t_spark_direct_results.resultSet, + session_id_hex, ) if t_spark_direct_results.closeOperation: ThriftBackend._check_response_for_error( - t_spark_direct_results.closeOperation + t_spark_direct_results.closeOperation, + session_id_hex, ) def execute_command( @@ -1029,7 +1066,7 @@ def get_columns( def _handle_execute_response(self, resp, cursor): cursor.active_op_handle = resp.operationHandle - self._check_direct_results_for_error(resp.directResults) + self._check_direct_results_for_error(resp.directResults, self._session_id_hex) final_operation_state = self._wait_until_command_done( resp.operationHandle, @@ -1040,7 +1077,7 @@ def _handle_execute_response(self, resp, cursor): def _handle_execute_response_async(self, resp, cursor): cursor.active_op_handle = resp.operationHandle - self._check_direct_results_for_error(resp.directResults) + self._check_direct_results_for_error(resp.directResults, self._session_id_hex) def fetch_results( self, @@ -1074,7 +1111,8 @@ def fetch_results( raise DataError( "fetch_results failed due to inconsistency in the state between the client and the server. Expected results to start from {} but they instead start at {}, some result batches must have been skipped".format( expected_row_start_offset, resp.results.startRowOffset - ) + ), + session_id_hex=self._session_id_hex, ) queue = ResultSetQueueFactory.build_queue( diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 478205b18..699480bbe 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -7,10 +7,13 @@ TelemetryClient, NoopTelemetryClient, TelemetryClientFactory, + TelemetryHelper, + BaseTelemetryClient ) from databricks.sql.telemetry.models.enums import ( AuthMech, DatabricksClientType, + AuthFlow, ) from databricks.sql.telemetry.models.event import ( DriverConnectionParameters, @@ -18,6 +21,8 @@ ) from databricks.sql.auth.authenticators import ( AccessTokenAuthProvider, + DatabricksOAuthProvider, + ExternalAuthProvider, ) @@ -30,14 +35,14 @@ def noop_telemetry_client(): @pytest.fixture def telemetry_client_setup(): """Fixture for TelemetryClient setup data.""" - connection_uuid = str(uuid.uuid4()) + session_id_hex = str(uuid.uuid4()) auth_provider = AccessTokenAuthProvider("test-token") host_url = "test-host" executor = MagicMock() client = TelemetryClient( telemetry_enabled=True, - connection_uuid=connection_uuid, + session_id_hex=session_id_hex, auth_provider=auth_provider, host_url=host_url, executor=executor, @@ -45,7 +50,7 @@ def telemetry_client_setup(): return { "client": client, - "connection_uuid": connection_uuid, + "session_id_hex": session_id_hex, "auth_provider": auth_provider, "host_url": host_url, "executor": executor, @@ -53,36 +58,44 @@ def telemetry_client_setup(): @pytest.fixture -def telemetry_factory_reset(): - """Fixture to reset TelemetryClientFactory state before each test.""" - # Reset the static class state before each test - TelemetryClientFactory._clients = {} +def telemetry_system_reset(): + """Fixture to reset telemetry system state before each test.""" + # Reset the static state before each test + TelemetryClientFactory._clients.clear() + if TelemetryClientFactory._executor: + TelemetryClientFactory._executor.shutdown(wait=True) TelemetryClientFactory._executor = None TelemetryClientFactory._initialized = False yield # Cleanup after test if needed - TelemetryClientFactory._clients = {} + TelemetryClientFactory._clients.clear() if TelemetryClientFactory._executor: TelemetryClientFactory._executor.shutdown(wait=True) - TelemetryClientFactory._executor = None + TelemetryClientFactory._executor = None TelemetryClientFactory._initialized = False class TestNoopTelemetryClient: - """Tests for the NoopTelemetryClient class.""" + """Tests for the NoopTelemetryClient.""" def test_singleton(self): """Test that NoopTelemetryClient is a singleton.""" client1 = NoopTelemetryClient() client2 = NoopTelemetryClient() assert client1 is client2 - + def test_export_initial_telemetry_log(self, noop_telemetry_client): """Test that export_initial_telemetry_log does nothing.""" noop_telemetry_client.export_initial_telemetry_log( driver_connection_params=MagicMock(), user_agent="test" ) + def test_export_failure_log(self, noop_telemetry_client): + """Test that export_failure_log does nothing.""" + noop_telemetry_client.export_failure_log( + error_name="TestError", error_message="Test error message" + ) + def test_close(self, noop_telemetry_client): """Test that close does nothing.""" noop_telemetry_client.close() @@ -92,7 +105,7 @@ class TestTelemetryClient: """Tests for the TelemetryClient class.""" @patch("databricks.sql.telemetry.telemetry_client.TelemetryFrontendLog") - @patch("databricks.sql.telemetry.telemetry_client.TelemetryHelper.getDriverSystemConfiguration") + @patch("databricks.sql.telemetry.telemetry_client.TelemetryHelper.get_driver_system_configuration") @patch("databricks.sql.telemetry.telemetry_client.uuid.uuid4") @patch("databricks.sql.telemetry.telemetry_client.time.time") def test_export_initial_telemetry_log( @@ -111,7 +124,7 @@ def test_export_initial_telemetry_log( client = telemetry_client_setup["client"] host_url = telemetry_client_setup["host_url"] - client.export_event = MagicMock() + client._export_event = MagicMock() driver_connection_params = DriverConnectionParameters( http_path="test-path", @@ -125,23 +138,64 @@ def test_export_initial_telemetry_log( client.export_initial_telemetry_log(driver_connection_params, user_agent) mock_frontend_log.assert_called_once() - client.export_event.assert_called_once_with(mock_frontend_log.return_value) + client._export_event.assert_called_once_with(mock_frontend_log.return_value) + + @patch("databricks.sql.telemetry.telemetry_client.TelemetryFrontendLog") + @patch("databricks.sql.telemetry.telemetry_client.TelemetryHelper.get_driver_system_configuration") + @patch("databricks.sql.telemetry.telemetry_client.DriverErrorInfo") + @patch("databricks.sql.telemetry.telemetry_client.uuid.uuid4") + @patch("databricks.sql.telemetry.telemetry_client.time.time") + def test_export_failure_log( + self, + mock_time, + mock_uuid4, + mock_driver_error_info, + mock_get_driver_config, + mock_frontend_log, + telemetry_client_setup + ): + """Test exporting failure telemetry log.""" + mock_time.return_value = 2000 + mock_uuid4.return_value = "test-error-uuid" + mock_get_driver_config.return_value = "test-driver-config" + mock_driver_error_info.return_value = MagicMock() + mock_frontend_log.return_value = MagicMock() + + client = telemetry_client_setup["client"] + client._export_event = MagicMock() + + client._driver_connection_params = "test-connection-params" + client._user_agent = "test-user-agent" + + error_name = "TestError" + error_message = "This is a test error message" + + client.export_failure_log(error_name, error_message) + + mock_driver_error_info.assert_called_once_with( + error_name=error_name, + stack_trace=error_message + ) + + mock_frontend_log.assert_called_once() + + client._export_event.assert_called_once_with(mock_frontend_log.return_value) def test_export_event(self, telemetry_client_setup): """Test exporting an event.""" client = telemetry_client_setup["client"] - client.flush = MagicMock() + client._flush = MagicMock() for i in range(5): - client.export_event(f"event-{i}") + client._export_event(f"event-{i}") - client.flush.assert_not_called() + client._flush.assert_not_called() assert len(client._events_batch) == 5 for i in range(5, 10): - client.export_event(f"event-{i}") + client._export_event(f"event-{i}") - client.flush.assert_called_once() + client._flush.assert_called_once() assert len(client._events_batch) == 10 @patch("requests.post") @@ -171,7 +225,7 @@ def test_send_telemetry_unauthenticated(self, mock_post, telemetry_client_setup) unauthenticated_client = TelemetryClient( telemetry_enabled=True, - connection_uuid=str(uuid.uuid4()), + session_id_hex=str(uuid.uuid4()), auth_provider=None, # No auth provider host_url=host_url, executor=executor, @@ -197,133 +251,251 @@ def test_flush(self, telemetry_client_setup): client._events_batch = ["event1", "event2"] client._send_telemetry = MagicMock() - client.flush() + client._flush() client._send_telemetry.assert_called_once_with(["event1", "event2"]) assert client._events_batch == [] - @patch("databricks.sql.telemetry.telemetry_client.TelemetryClientFactory") - def test_close(self, mock_factory_class, telemetry_client_setup): + def test_close(self, telemetry_client_setup): """Test closing the client.""" client = telemetry_client_setup["client"] - connection_uuid = telemetry_client_setup["connection_uuid"] - client.flush = MagicMock() + client._flush = MagicMock() client.close() - client.flush.assert_called_once() - mock_factory_class.close.assert_called_once_with(connection_uuid) + client._flush.assert_called_once() + + @patch("requests.post") + def test_telemetry_request_callback_success(self, mock_post, telemetry_client_setup): + """Test successful telemetry request callback.""" + client = telemetry_client_setup["client"] + + mock_response = MagicMock() + mock_response.status_code = 200 + + mock_future = MagicMock() + mock_future.result.return_value = mock_response + + client._telemetry_request_callback(mock_future) + + mock_future.result.assert_called_once() + + @patch("requests.post") + def test_telemetry_request_callback_failure(self, mock_post, telemetry_client_setup): + """Test telemetry request callback with failure""" + client = telemetry_client_setup["client"] + + # Test with non-200 status code + mock_response = MagicMock() + mock_response.status_code = 500 + future = MagicMock() + future.result.return_value = mock_response + client._telemetry_request_callback(future) + + # Test with exception + future = MagicMock() + future.result.side_effect = Exception("Test error") + client._telemetry_request_callback(future) + + def test_telemetry_client_exception_handling(self, telemetry_client_setup): + """Test exception handling in telemetry client methods.""" + client = telemetry_client_setup["client"] + + # Test export_initial_telemetry_log with exception + with patch.object(client, '_export_event', side_effect=Exception("Test error")): + # Should not raise exception + client.export_initial_telemetry_log(MagicMock(), "test-agent") + + # Test export_failure_log with exception + with patch.object(client, '_export_event', side_effect=Exception("Test error")): + # Should not raise exception + client.export_failure_log("TestError", "Test error message") + + # Test _send_telemetry with exception + with patch.object(client._executor, 'submit', side_effect=Exception("Test error")): + # Should not raise exception + client._send_telemetry([MagicMock()]) + def test_send_telemetry_thread_pool_failure(self, telemetry_client_setup): + """Test handling of thread pool submission failure""" + client = telemetry_client_setup["client"] + client._executor.submit.side_effect = Exception("Thread pool error") + event = MagicMock() + client._send_telemetry([event]) + + def test_base_telemetry_client_abstract_methods(self): + """Test that BaseTelemetryClient cannot be instantiated without implementing all abstract methods""" + class TestBaseClient(BaseTelemetryClient): + pass + + with pytest.raises(TypeError): + TestBaseClient() # Can't instantiate abstract class + + +class TestTelemetryHelper: + """Tests for the TelemetryHelper class.""" + + def test_get_driver_system_configuration(self): + """Test getting driver system configuration.""" + config = TelemetryHelper.get_driver_system_configuration() + + assert isinstance(config.driver_name, str) + assert isinstance(config.driver_version, str) + assert isinstance(config.runtime_name, str) + assert isinstance(config.runtime_vendor, str) + assert isinstance(config.runtime_version, str) + assert isinstance(config.os_name, str) + assert isinstance(config.os_version, str) + assert isinstance(config.os_arch, str) + assert isinstance(config.locale_name, str) + assert isinstance(config.char_set_encoding, str) + + assert config.driver_name == "Databricks SQL Python Connector" + assert "Python" in config.runtime_name + assert config.runtime_vendor in ["CPython", "PyPy", "Jython", "IronPython"] + assert config.os_name in ["Darwin", "Linux", "Windows"] + + # Verify caching behavior + config2 = TelemetryHelper.get_driver_system_configuration() + assert config is config2 # Should return same instance -class TestTelemetryClientFactory: - """Tests for the TelemetryClientFactory static class.""" + def test_get_auth_mechanism(self): + """Test getting auth mechanism for different auth providers.""" + # Test PAT auth + pat_auth = AccessTokenAuthProvider("test-token") + assert TelemetryHelper.get_auth_mechanism(pat_auth) == AuthMech.PAT + + # Test OAuth auth + oauth_auth = MagicMock(spec=DatabricksOAuthProvider) + assert TelemetryHelper.get_auth_mechanism(oauth_auth) == AuthMech.DATABRICKS_OAUTH + + # Test External auth + external_auth = MagicMock(spec=ExternalAuthProvider) + assert TelemetryHelper.get_auth_mechanism(external_auth) == AuthMech.EXTERNAL_AUTH + + # Test None auth provider + assert TelemetryHelper.get_auth_mechanism(None) is None + + # Test unknown auth provider + unknown_auth = MagicMock() + assert TelemetryHelper.get_auth_mechanism(unknown_auth) == AuthMech.CLIENT_CERT - @patch("databricks.sql.telemetry.telemetry_client.TelemetryClient") - def test_initialize_telemetry_client_enabled(self, mock_client_class, telemetry_factory_reset): + def test_get_auth_flow(self): + """Test getting auth flow for different OAuth providers.""" + # Test OAuth with existing tokens + oauth_with_tokens = MagicMock(spec=DatabricksOAuthProvider) + oauth_with_tokens._access_token = "test-access-token" + oauth_with_tokens._refresh_token = "test-refresh-token" + assert TelemetryHelper.get_auth_flow(oauth_with_tokens) == AuthFlow.TOKEN_PASSTHROUGH + + # Test OAuth with browser-based auth + oauth_with_browser = MagicMock(spec=DatabricksOAuthProvider) + oauth_with_browser._access_token = None + oauth_with_browser._refresh_token = None + oauth_with_browser.oauth_manager = MagicMock() + assert TelemetryHelper.get_auth_flow(oauth_with_browser) == AuthFlow.BROWSER_BASED_AUTHENTICATION + + # Test non-OAuth provider + pat_auth = AccessTokenAuthProvider("test-token") + assert TelemetryHelper.get_auth_flow(pat_auth) is None + + # Test None auth provider + assert TelemetryHelper.get_auth_flow(None) is None + + +class TestTelemetrySystem: + """Tests for the telemetry system functions.""" + + def test_initialize_telemetry_client_enabled(self, telemetry_system_reset): """Test initializing a telemetry client when telemetry is enabled.""" - connection_uuid = "test-uuid" + session_id_hex = "test-uuid" auth_provider = MagicMock() host_url = "test-host" - mock_client = MagicMock() - mock_client_class.return_value = mock_client TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, - connection_uuid=connection_uuid, - auth_provider=auth_provider, - host_url=host_url, - ) - - # Verify a new client was created and stored - mock_client_class.assert_called_once_with( - telemetry_enabled=True, - connection_uuid=connection_uuid, + session_id_hex=session_id_hex, auth_provider=auth_provider, host_url=host_url, - executor=TelemetryClientFactory._executor, ) - assert TelemetryClientFactory._clients[connection_uuid] == mock_client - # Call again with the same connection_uuid - client2 = TelemetryClientFactory.get_telemetry_client(connection_uuid=connection_uuid) - - # Verify the same client was returned and no new client was created - assert client2 == mock_client - mock_client_class.assert_called_once() # Still only called once + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + assert isinstance(client, TelemetryClient) + assert client._session_id_hex == session_id_hex + assert client._auth_provider == auth_provider + assert client._host_url == host_url - def test_initialize_telemetry_client_disabled(self, telemetry_factory_reset): + def test_initialize_telemetry_client_disabled(self, telemetry_system_reset): """Test initializing a telemetry client when telemetry is disabled.""" - connection_uuid = "test-uuid" + session_id_hex = "test-uuid" TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=False, - connection_uuid=connection_uuid, + session_id_hex=session_id_hex, auth_provider=MagicMock(), host_url="test-host", ) - # Verify a NoopTelemetryClient was stored - assert isinstance(TelemetryClientFactory._clients[connection_uuid], NoopTelemetryClient) - - client2 = TelemetryClientFactory.get_telemetry_client(connection_uuid) - assert isinstance(client2, NoopTelemetryClient) - - def test_get_telemetry_client_existing(self, telemetry_factory_reset): - """Test getting an existing telemetry client.""" - connection_uuid = "test-uuid" - mock_client = MagicMock() - TelemetryClientFactory._clients[connection_uuid] = mock_client - - client = TelemetryClientFactory.get_telemetry_client(connection_uuid) - - assert client == mock_client + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + assert isinstance(client, NoopTelemetryClient) - def test_get_telemetry_client_nonexistent(self, telemetry_factory_reset): + def test_get_telemetry_client_nonexistent(self, telemetry_system_reset): """Test getting a non-existent telemetry client.""" client = TelemetryClientFactory.get_telemetry_client("nonexistent-uuid") - assert isinstance(client, NoopTelemetryClient) - @patch("databricks.sql.telemetry.telemetry_client.ThreadPoolExecutor") - @patch("databricks.sql.telemetry.telemetry_client.TelemetryClient") - def test_close(self, mock_client_class, mock_executor_class, telemetry_factory_reset): - """Test that factory reinitializes properly after complete shutdown.""" - connection_uuid1 = "test-uuid1" - mock_executor1 = MagicMock() - mock_client1 = MagicMock() - mock_executor_class.return_value = mock_executor1 - mock_client_class.return_value = mock_client1 - - TelemetryClientFactory._clients[connection_uuid1] = mock_client1 - TelemetryClientFactory._executor = mock_executor1 - TelemetryClientFactory._initialized = True - - TelemetryClientFactory.close(connection_uuid1) - - assert TelemetryClientFactory._clients == {} - assert TelemetryClientFactory._executor is None - assert TelemetryClientFactory._initialized is False - mock_executor1.shutdown.assert_called_once_with(wait=True) - - # Now create a new client - this should reinitialize the factory - connection_uuid2 = "test-uuid2" - mock_executor2 = MagicMock() - mock_client2 = MagicMock() - mock_executor_class.return_value = mock_executor2 - mock_client_class.return_value = mock_client2 + def test_close_telemetry_client(self, telemetry_system_reset): + """Test closing a telemetry client.""" + session_id_hex = "test-uuid" + auth_provider = MagicMock() + host_url = "test-host" TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, - connection_uuid=connection_uuid2, + session_id_hex=session_id_hex, + auth_provider=auth_provider, + host_url=host_url, + ) + + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + assert isinstance(client, TelemetryClient) + + client.close = MagicMock() + + TelemetryClientFactory.close(session_id_hex) + + client.close.assert_called_once() + + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + assert isinstance(client, NoopTelemetryClient) + + def test_close_telemetry_client_noop(self, telemetry_system_reset): + """Test closing a no-op telemetry client.""" + session_id_hex = "test-uuid" + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=False, + session_id_hex=session_id_hex, auth_provider=MagicMock(), host_url="test-host", ) - # Verify factory was reinitialized - assert TelemetryClientFactory._initialized is True - assert TelemetryClientFactory._executor is not None - assert TelemetryClientFactory._executor == mock_executor2 - assert connection_uuid2 in TelemetryClientFactory._clients - assert TelemetryClientFactory._clients[connection_uuid2] == mock_client2 + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + assert isinstance(client, NoopTelemetryClient) + + client.close = MagicMock() + + TelemetryClientFactory.close(session_id_hex) + + client.close.assert_called_once() + + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + assert isinstance(client, NoopTelemetryClient) + + @patch("databricks.sql.telemetry.telemetry_client.TelemetryClientFactory._handle_unhandled_exception") + def test_global_exception_hook(self, mock_handle_exception, telemetry_system_reset): + """Test that global exception hook is installed and handles exceptions.""" + TelemetryClientFactory._install_exception_hook() + + test_exception = ValueError("Test exception") + TelemetryClientFactory._handle_unhandled_exception(type(test_exception), test_exception, None) - # Verify new ThreadPoolExecutor was created - assert mock_executor_class.call_count == 1 \ No newline at end of file + mock_handle_exception.assert_called_once_with(type(test_exception), test_exception, None) \ No newline at end of file