diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index b81416e1..1f409bb0 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1,6 +1,5 @@ import time from typing import Dict, Tuple, List, Optional, Any, Union, Sequence - import pandas try: @@ -19,6 +18,9 @@ OperationalError, SessionAlreadyClosedError, CursorAlreadyClosedError, + InterfaceError, + NotSupportedError, + ProgrammingError, ) from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.thrift_backend import ThriftBackend @@ -50,7 +52,17 @@ TSparkParameter, TOperationState, ) - +from databricks.sql.telemetry.telemetry_client import ( + TelemetryHelper, + TelemetryClientFactory, +) +from databricks.sql.telemetry.models.enums import DatabricksClientType +from databricks.sql.telemetry.models.event import ( + DriverConnectionParameters, + HostDetails, +) +from databricks.sql.telemetry.latency_logger import log_latency +from databricks.sql.telemetry.models.enums import StatementType logger = logging.getLogger(__name__) @@ -234,6 +246,12 @@ def read(self) -> Optional[OAuthToken]: server_hostname, **kwargs ) + self.server_telemetry_enabled = True + self.client_telemetry_enabled = kwargs.get("enable_telemetry", False) + self.telemetry_enabled = ( + self.client_telemetry_enabled and self.server_telemetry_enabled + ) + user_agent_entry = kwargs.get("user_agent_entry") if user_agent_entry is None: user_agent_entry = kwargs.get("_user_agent_entry") @@ -289,6 +307,31 @@ def read(self) -> Optional[OAuthToken]: kwargs.get("use_inline_params", False) ) + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=self.telemetry_enabled, + session_id_hex=self.get_session_id_hex(), + auth_provider=auth_provider, + host_url=self.host, + ) + + self._telemetry_client = TelemetryClientFactory.get_telemetry_client( + session_id_hex=self.get_session_id_hex() + ) + + driver_connection_params = DriverConnectionParameters( + http_path=http_path, + mode=DatabricksClientType.THRIFT, + host_info=HostDetails(host_url=server_hostname, port=self.port), + auth_mech=TelemetryHelper.get_auth_mechanism(auth_provider), + auth_flow=TelemetryHelper.get_auth_flow(auth_provider), + socket_timeout=kwargs.get("_socket_timeout", None), + ) + + self._telemetry_client.export_initial_telemetry_log( + driver_connection_params=driver_connection_params, + user_agent=useragent_header, + ) + def _set_use_inline_params_with_warning(self, value: Union[bool, str]): """Valid values are True, False, and "silent" @@ -376,7 +419,10 @@ def cursor( Will throw an Error if the connection has been closed. 
""" if not self.open: - raise Error("Cannot create cursor from closed connection") + raise InterfaceError( + "Cannot create cursor from closed connection", + session_id_hex=self.get_session_id_hex(), + ) cursor = Cursor( self, @@ -419,12 +465,17 @@ def _close(self, close_cursors=True) -> None: self.open = False + TelemetryClientFactory.close(self.get_session_id_hex()) + def commit(self): """No-op because Databricks does not support transactions""" pass def rollback(self): - raise NotSupportedError("Transactions are not supported on Databricks") + raise NotSupportedError( + "Transactions are not supported on Databricks", + session_id_hex=self.get_session_id_hex(), + ) class Cursor: @@ -469,7 +520,10 @@ def __iter__(self): for row in self.active_result_set: yield row else: - raise Error("There is no active result set") + raise ProgrammingError( + "There is no active result set", + session_id_hex=self.connection.get_session_id_hex(), + ) def _determine_parameter_approach( self, params: Optional[TParameterCollection] @@ -606,7 +660,10 @@ def _close_and_clear_active_result_set(self): def _check_not_closed(self): if not self.open: - raise Error("Attempting operation on closed cursor") + raise InterfaceError( + "Attempting operation on closed cursor", + session_id_hex=self.connection.get_session_id_hex(), + ) def _handle_staging_operation( self, staging_allowed_local_path: Union[None, str, List[str]] @@ -623,8 +680,9 @@ def _handle_staging_operation( elif isinstance(staging_allowed_local_path, type(list())): _staging_allowed_local_paths = staging_allowed_local_path else: - raise Error( - "You must provide at least one staging_allowed_local_path when initialising a connection to perform ingestion commands" + raise ProgrammingError( + "You must provide at least one staging_allowed_local_path when initialising a connection to perform ingestion commands", + session_id_hex=self.connection.get_session_id_hex(), ) abs_staging_allowed_local_paths = [ @@ -652,8 +710,9 @@ def _handle_staging_operation( else: continue if not allow_operation: - raise Error( - "Local file operations are restricted to paths within the configured staging_allowed_local_path" + raise ProgrammingError( + "Local file operations are restricted to paths within the configured staging_allowed_local_path", + session_id_hex=self.connection.get_session_id_hex(), ) # May be real headers, or could be json string @@ -681,11 +740,13 @@ def _handle_staging_operation( handler_args.pop("local_file") return self._handle_staging_remove(**handler_args) else: - raise Error( + raise ProgrammingError( f"Operation {row.operation} is not supported. 
" - + "Supported operations are GET, PUT, and REMOVE" + + "Supported operations are GET, PUT, and REMOVE", + session_id_hex=self.connection.get_session_id_hex(), ) + @log_latency(StatementType.SQL) def _handle_staging_put( self, presigned_url: str, local_file: str, headers: Optional[dict] = None ): @@ -695,7 +756,10 @@ def _handle_staging_put( """ if local_file is None: - raise Error("Cannot perform PUT without specifying a local_file") + raise ProgrammingError( + "Cannot perform PUT without specifying a local_file", + session_id_hex=self.connection.get_session_id_hex(), + ) with open(local_file, "rb") as fh: r = requests.put(url=presigned_url, data=fh, headers=headers) @@ -711,8 +775,9 @@ def _handle_staging_put( # fmt: on if r.status_code not in [OK, CREATED, NO_CONTENT, ACCEPTED]: - raise Error( - f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}" + raise OperationalError( + f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}", + session_id_hex=self.connection.get_session_id_hex(), ) if r.status_code == ACCEPTED: @@ -721,6 +786,7 @@ def _handle_staging_put( + "but not yet applied on the server. It's possible this command may fail later." ) + @log_latency(StatementType.SQL) def _handle_staging_get( self, local_file: str, presigned_url: str, headers: Optional[dict] = None ): @@ -730,20 +796,25 @@ def _handle_staging_get( """ if local_file is None: - raise Error("Cannot perform GET without specifying a local_file") + raise ProgrammingError( + "Cannot perform GET without specifying a local_file", + session_id_hex=self.connection.get_session_id_hex(), + ) r = requests.get(url=presigned_url, headers=headers) # response.ok verifies the status code is not between 400-600. # Any 2xx or 3xx will evaluate r.ok == True if not r.ok: - raise Error( - f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}" + raise OperationalError( + f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}", + session_id_hex=self.connection.get_session_id_hex(), ) with open(local_file, "wb") as fp: fp.write(r.content) + @log_latency(StatementType.SQL) def _handle_staging_remove( self, presigned_url: str, headers: Optional[dict] = None ): @@ -752,10 +823,12 @@ def _handle_staging_remove( r = requests.delete(url=presigned_url, headers=headers) if not r.ok: - raise Error( - f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}" + raise OperationalError( + f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}", + session_id_hex=self.connection.get_session_id_hex(), ) + @log_latency(StatementType.QUERY) def execute( self, operation: str, @@ -846,6 +919,7 @@ def execute( return self + @log_latency(StatementType.QUERY) def execute_async( self, operation: str, @@ -951,8 +1025,9 @@ def get_async_execution_result(self): return self else: - raise Error( - f"get_execution_result failed with Operation status {operation_state}" + raise OperationalError( + f"get_execution_result failed with Operation status {operation_state}", + session_id_hex=self.connection.get_session_id_hex(), ) def executemany(self, operation, seq_of_parameters): @@ -970,6 +1045,7 @@ def executemany(self, operation, seq_of_parameters): self.execute(operation, parameters) return self + @log_latency(StatementType.METADATA) def catalogs(self) -> "Cursor": """ Get all available catalogs. 
@@ -993,6 +1069,7 @@ def catalogs(self) -> "Cursor": ) return self + @log_latency(StatementType.METADATA) def schemas( self, catalog_name: Optional[str] = None, schema_name: Optional[str] = None ) -> "Cursor": @@ -1021,6 +1098,7 @@ def schemas( ) return self + @log_latency(StatementType.METADATA) def tables( self, catalog_name: Optional[str] = None, @@ -1056,6 +1134,7 @@ def tables( ) return self + @log_latency(StatementType.METADATA) def columns( self, catalog_name: Optional[str] = None, @@ -1102,7 +1181,10 @@ def fetchall(self) -> List[Row]: if self.active_result_set: return self.active_result_set.fetchall() else: - raise Error("There is no active result set") + raise ProgrammingError( + "There is no active result set", + session_id_hex=self.connection.get_session_id_hex(), + ) def fetchone(self) -> Optional[Row]: """ @@ -1116,7 +1198,10 @@ def fetchone(self) -> Optional[Row]: if self.active_result_set: return self.active_result_set.fetchone() else: - raise Error("There is no active result set") + raise ProgrammingError( + "There is no active result set", + session_id_hex=self.connection.get_session_id_hex(), + ) def fetchmany(self, size: int) -> List[Row]: """ @@ -1138,21 +1223,30 @@ def fetchmany(self, size: int) -> List[Row]: if self.active_result_set: return self.active_result_set.fetchmany(size) else: - raise Error("There is no active result set") + raise ProgrammingError( + "There is no active result set", + session_id_hex=self.connection.get_session_id_hex(), + ) def fetchall_arrow(self) -> "pyarrow.Table": self._check_not_closed() if self.active_result_set: return self.active_result_set.fetchall_arrow() else: - raise Error("There is no active result set") + raise ProgrammingError( + "There is no active result set", + session_id_hex=self.connection.get_session_id_hex(), + ) def fetchmany_arrow(self, size) -> "pyarrow.Table": self._check_not_closed() if self.active_result_set: return self.active_result_set.fetchmany_arrow(size) else: - raise Error("There is no active result set") + raise ProgrammingError( + "There is no active result set", + session_id_hex=self.connection.get_session_id_hex(), + ) def cancel(self) -> None: """ @@ -1455,6 +1549,7 @@ def fetchall_columnar(self): return results + @log_latency() def fetchone(self) -> Optional[Row]: """ Fetch the next row of a query result set, returning a single sequence, @@ -1471,6 +1566,7 @@ def fetchone(self) -> Optional[Row]: else: return None + @log_latency() def fetchall(self) -> List[Row]: """ Fetch all (remaining) rows of a query result, returning them as a list of rows. @@ -1480,6 +1576,7 @@ def fetchall(self) -> List[Row]: else: return self._convert_arrow_table(self.fetchall_arrow()) + @log_latency() def fetchmany(self, size: int) -> List[Row]: """ Fetch the next set of rows of a query result, returning a list of rows. diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index 3b27283a..65235f63 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -2,20 +2,31 @@ import logging logger = logging.getLogger(__name__) +from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory ### PEP-249 Mandated ### +# https://peps.python.org/pep-0249/#exceptions class Error(Exception): """Base class for DB-API2.0 exceptions. `message`: An optional user-friendly error message. It should be short, actionable and stable `context`: Optional extra context about the error. 
MUST be JSON serializable """ - def __init__(self, message=None, context=None, *args, **kwargs): + def __init__( + self, message=None, context=None, session_id_hex=None, *args, **kwargs + ): super().__init__(message, *args, **kwargs) self.message = message self.context = context or {} + error_name = self.__class__.__name__ + if session_id_hex: + telemetry_client = TelemetryClientFactory.get_telemetry_client( + session_id_hex + ) + telemetry_client.export_failure_log(error_name, self.message) + def __str__(self): return self.message diff --git a/src/databricks/sql/telemetry/latency_logger.py b/src/databricks/sql/telemetry/latency_logger.py new file mode 100644 index 00000000..406b684f --- /dev/null +++ b/src/databricks/sql/telemetry/latency_logger.py @@ -0,0 +1,237 @@ +import time +import functools +from typing import Optional +import logging +from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory +from databricks.sql.telemetry.models.event import ( + SqlExecutionEvent, +) +from databricks.sql.telemetry.models.enums import ExecutionResultFormat, StatementType +from databricks.sql.utils import ColumnQueue, CloudFetchQueue, ArrowQueue +from uuid import UUID + +logger = logging.getLogger(__name__) + + +class TelemetryExtractor: + """ + Base class for extracting telemetry information from various object types. + + This class serves as a proxy that delegates attribute access to the wrapped object + while providing a common interface for extracting telemetry-related data. + """ + + def __init__(self, obj): + """ + Initialize the extractor with an object to wrap. + + Args: + obj: The object to extract telemetry information from. + """ + self._obj = obj + + def __getattr__(self, name): + """ + Delegate attribute access to the wrapped object. + + Args: + name (str): The name of the attribute to access. + + Returns: + The attribute value from the wrapped object. + """ + return getattr(self._obj, name) + + def get_session_id_hex(self): + pass + + def get_statement_id(self): + pass + + def get_is_compressed(self): + pass + + def get_execution_result(self): + pass + + def get_retry_count(self): + pass + + +class CursorExtractor(TelemetryExtractor): + """ + Telemetry extractor specialized for Cursor objects. + + Extracts telemetry information from database cursor objects, including + statement IDs, session information, compression settings, and result formats. + """ + + def get_statement_id(self) -> Optional[str]: + return self.query_id + + def get_session_id_hex(self) -> Optional[str]: + return self.connection.get_session_id_hex() + + def get_is_compressed(self) -> bool: + return self.connection.lz4_compression + + def get_execution_result(self) -> ExecutionResultFormat: + if self.active_result_set is None: + return ExecutionResultFormat.FORMAT_UNSPECIFIED + + if isinstance(self.active_result_set.results, ColumnQueue): + return ExecutionResultFormat.COLUMNAR_INLINE + elif isinstance(self.active_result_set.results, CloudFetchQueue): + return ExecutionResultFormat.EXTERNAL_LINKS + elif isinstance(self.active_result_set.results, ArrowQueue): + return ExecutionResultFormat.INLINE_ARROW + return ExecutionResultFormat.FORMAT_UNSPECIFIED + + def get_retry_count(self) -> int: + if ( + hasattr(self.thrift_backend, "retry_policy") + and self.thrift_backend.retry_policy + ): + return len(self.thrift_backend.retry_policy.history) + return 0 + + +class ResultSetExtractor(TelemetryExtractor): + """ + Telemetry extractor specialized for ResultSet objects. 
+ + Extracts telemetry information from database result set objects, including + operation IDs, session information, compression settings, and result formats. + """ + + def get_statement_id(self) -> Optional[str]: + if self.command_id: + return str(UUID(bytes=self.command_id.operationId.guid)) + return None + + def get_session_id_hex(self) -> Optional[str]: + return self.connection.get_session_id_hex() + + def get_is_compressed(self) -> bool: + return self.lz4_compressed + + def get_execution_result(self) -> ExecutionResultFormat: + if isinstance(self.results, ColumnQueue): + return ExecutionResultFormat.COLUMNAR_INLINE + elif isinstance(self.results, CloudFetchQueue): + return ExecutionResultFormat.EXTERNAL_LINKS + elif isinstance(self.results, ArrowQueue): + return ExecutionResultFormat.INLINE_ARROW + return ExecutionResultFormat.FORMAT_UNSPECIFIED + + def get_retry_count(self) -> int: + if ( + hasattr(self.thrift_backend, "retry_policy") + and self.thrift_backend.retry_policy + ): + return len(self.thrift_backend.retry_policy.history) + return 0 + + +def get_extractor(obj): + """ + Factory function to create the appropriate telemetry extractor for an object. + + Determines the object type and returns the corresponding specialized extractor + that can extract telemetry information from that object type. + + Args: + obj: The object to create an extractor for. Can be a Cursor, ResultSet, + or any other object. + + Returns: + TelemetryExtractor: A specialized extractor instance: + - CursorExtractor for Cursor objects + - ResultSetExtractor for ResultSet objects + - None for all other objects + """ + if obj.__class__.__name__ == "Cursor": + return CursorExtractor(obj) + elif obj.__class__.__name__ == "ResultSet": + return ResultSetExtractor(obj) + else: + logger.error(f"No extractor found for {obj.__class__.__name__}") + return None + + +def log_latency(statement_type: StatementType = StatementType.NONE): + """ + Decorator for logging execution latency and telemetry information. + + This decorator measures the execution time of a method and sends telemetry + data about the operation, including latency, statement information, and + execution context. + + The decorator automatically: + - Measures execution time using high-precision performance counters + - Extracts telemetry information from the method's object (self) + - Creates a SqlExecutionEvent with execution details + - Sends the telemetry data asynchronously via TelemetryClient + + Args: + statement_type (StatementType): The type of SQL statement being executed. + + Usage: + @log_latency(StatementType.SQL) + def execute(self, query): + # Method implementation + pass + + Returns: + function: A decorator that wraps methods to add latency logging. + + Note: + The wrapped method's object (self) must be compatible with the + telemetry extractor system (e.g., Cursor or ResultSet objects). 
+ """ + + def decorator(func): + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + start_time = time.perf_counter() + result = None + try: + result = func(self, *args, **kwargs) + return result + finally: + + def _safe_call(func_to_call): + """Calls a function and returns a default value on any exception.""" + try: + return func_to_call() + except Exception: + return None + + end_time = time.perf_counter() + duration_ms = int((end_time - start_time) * 1000) + + extractor = get_extractor(self) + + if extractor is not None: + session_id_hex = _safe_call(extractor.get_session_id_hex) + statement_id = _safe_call(extractor.get_statement_id) + + sql_exec_event = SqlExecutionEvent( + statement_type=statement_type, + is_compressed=_safe_call(extractor.get_is_compressed), + execution_result=_safe_call(extractor.get_execution_result), + retry_count=_safe_call(extractor.get_retry_count), + ) + + telemetry_client = TelemetryClientFactory.get_telemetry_client( + session_id_hex + ) + telemetry_client.export_latency_log( + latency_ms=duration_ms, + sql_execution_event=sql_exec_event, + sql_statement_id=statement_id, + ) + + return wrapper + + return decorator diff --git a/src/databricks/sql/telemetry/models/endpoint_models.py b/src/databricks/sql/telemetry/models/endpoint_models.py new file mode 100644 index 00000000..a940d933 --- /dev/null +++ b/src/databricks/sql/telemetry/models/endpoint_models.py @@ -0,0 +1,43 @@ +import json +from dataclasses import dataclass, asdict +from typing import List, Optional + + +@dataclass +class TelemetryRequest: + """ + Represents a request to send telemetry data to the server side. + Contains the telemetry items to be uploaded and optional protocol buffer logs. + + Attributes: + uploadTime (int): Unix timestamp in milliseconds when the request is made + items (List[str]): List of telemetry event items to be uploaded + protoLogs (Optional[List[str]]): Optional list of protocol buffer formatted logs + """ + + uploadTime: int + items: List[str] + protoLogs: Optional[List[str]] + + def to_json(self): + return json.dumps(asdict(self)) + + +@dataclass +class TelemetryResponse: + """ + Represents the response from the telemetry backend after processing a request. + Contains information about the success or failure of the telemetry upload. 
+ + Attributes: + errors (List[str]): List of error messages if any occurred during processing + numSuccess (int): Number of successfully processed telemetry items + numProtoSuccess (int): Number of successfully processed protocol buffer logs + """ + + errors: List[str] + numSuccess: int + numProtoSuccess: int + + def to_json(self): + return json.dumps(asdict(self)) diff --git a/src/databricks/sql/telemetry/models/enums.py b/src/databricks/sql/telemetry/models/enums.py new file mode 100644 index 00000000..dd8f26eb --- /dev/null +++ b/src/databricks/sql/telemetry/models/enums.py @@ -0,0 +1,44 @@ +from enum import Enum + + +class AuthFlow(Enum): + TYPE_UNSPECIFIED = "TYPE_UNSPECIFIED" + TOKEN_PASSTHROUGH = "TOKEN_PASSTHROUGH" + CLIENT_CREDENTIALS = "CLIENT_CREDENTIALS" + BROWSER_BASED_AUTHENTICATION = "BROWSER_BASED_AUTHENTICATION" + + +class AuthMech(Enum): + TYPE_UNSPECIFIED = "TYPE_UNSPECIFIED" + OTHER = "OTHER" + PAT = "PAT" + OAUTH = "OAUTH" + + +class DatabricksClientType(Enum): + SEA = "SEA" + THRIFT = "THRIFT" + + +class DriverVolumeOperationType(Enum): + TYPE_UNSPECIFIED = "TYPE_UNSPECIFIED" + PUT = "PUT" + GET = "GET" + DELETE = "DELETE" + LIST = "LIST" + QUERY = "QUERY" + + +class ExecutionResultFormat(Enum): + FORMAT_UNSPECIFIED = "FORMAT_UNSPECIFIED" + INLINE_ARROW = "INLINE_ARROW" + EXTERNAL_LINKS = "EXTERNAL_LINKS" + COLUMNAR_INLINE = "COLUMNAR_INLINE" + + +class StatementType(Enum): + NONE = "NONE" + QUERY = "QUERY" + SQL = "SQL" + UPDATE = "UPDATE" + METADATA = "METADATA" diff --git a/src/databricks/sql/telemetry/models/event.py b/src/databricks/sql/telemetry/models/event.py new file mode 100644 index 00000000..f5496dee --- /dev/null +++ b/src/databricks/sql/telemetry/models/event.py @@ -0,0 +1,160 @@ +from dataclasses import dataclass +from databricks.sql.telemetry.models.enums import ( + AuthMech, + AuthFlow, + DatabricksClientType, + DriverVolumeOperationType, + StatementType, + ExecutionResultFormat, +) +from typing import Optional +from databricks.sql.telemetry.utils import JsonSerializableMixin + + +@dataclass +class HostDetails(JsonSerializableMixin): + """ + Represents the host connection details for a Databricks workspace. + + Attributes: + host_url (str): The URL of the Databricks workspace (e.g., https://my-workspace.cloud.databricks.com) + port (int): The port number for the connection (typically 443 for HTTPS) + """ + + host_url: str + port: int + + +@dataclass +class DriverConnectionParameters(JsonSerializableMixin): + """ + Contains all connection parameters used to establish a connection to Databricks SQL. + This includes authentication details, host information, and connection settings. + + Attributes: + http_path (str): The HTTP path for the SQL endpoint + mode (DatabricksClientType): The type of client connection (e.g., THRIFT) + host_info (HostDetails): Details about the host connection + auth_mech (AuthMech): The authentication mechanism used + auth_flow (AuthFlow): The authentication flow type + socket_timeout (int): Connection timeout in milliseconds + """ + + http_path: str + mode: DatabricksClientType + host_info: HostDetails + auth_mech: Optional[AuthMech] = None + auth_flow: Optional[AuthFlow] = None + socket_timeout: Optional[int] = None + + +@dataclass +class DriverSystemConfiguration(JsonSerializableMixin): + """ + Contains system-level configuration information about the client environment. + This includes details about the operating system, runtime, and driver version. 
+ + Attributes: + driver_version (str): Version of the Databricks SQL driver + os_name (str): Name of the operating system + os_version (str): Version of the operating system + os_arch (str): Architecture of the operating system + runtime_name (str): Name of the Python runtime (e.g., CPython) + runtime_version (str): Version of the Python runtime + runtime_vendor (str): Vendor of the Python runtime + client_app_name (str): Name of the client application + locale_name (str): System locale setting + driver_name (str): Name of the driver + char_set_encoding (str): Character set encoding used + """ + + driver_version: str + os_name: str + os_version: str + os_arch: str + runtime_name: str + runtime_version: str + runtime_vendor: str + driver_name: str + char_set_encoding: str + client_app_name: Optional[str] = None + locale_name: Optional[str] = None + + +@dataclass +class DriverVolumeOperation(JsonSerializableMixin): + """ + Represents a volume operation performed by the driver. + Used for tracking volume-related operations in telemetry. + + Attributes: + volume_operation_type (DriverVolumeOperationType): Type of volume operation (e.g., LIST) + volume_path (str): Path to the volume being operated on + """ + + volume_operation_type: DriverVolumeOperationType + volume_path: str + + +@dataclass +class DriverErrorInfo(JsonSerializableMixin): + """ + Contains detailed information about errors that occur during driver operations. + Used for error tracking and debugging in telemetry. + + Attributes: + error_name (str): Name/type of the error + stack_trace (str): Full stack trace of the error + """ + + error_name: str + stack_trace: str + + +@dataclass +class SqlExecutionEvent(JsonSerializableMixin): + """ + Represents a SQL query execution event. + Contains details about the query execution, including type, compression, and result format. + + Attributes: + statement_type (StatementType): Type of SQL statement + is_compressed (bool): Whether the result is compressed + execution_result (ExecutionResultFormat): Format of the execution result + retry_count (int): Number of retry attempts made + """ + + statement_type: StatementType + is_compressed: bool + execution_result: ExecutionResultFormat + retry_count: int + + +@dataclass +class TelemetryEvent(JsonSerializableMixin): + """ + Main telemetry event class that aggregates all telemetry data. + Contains information about the session, system configuration, connection parameters, + and any operations or errors that occurred. 
+ + Attributes: + session_id (str): Unique identifier for the session + sql_statement_id (Optional[str]): ID of the SQL statement if applicable + system_configuration (DriverSystemConfiguration): System configuration details + driver_connection_params (DriverConnectionParameters): Connection parameters + auth_type (Optional[str]): Type of authentication used + vol_operation (Optional[DriverVolumeOperation]): Volume operation details if applicable + sql_operation (Optional[SqlExecutionEvent]): SQL execution details if applicable + error_info (Optional[DriverErrorInfo]): Error information if an error occurred + operation_latency_ms (Optional[int]): Operation latency in milliseconds + """ + + session_id: str + system_configuration: DriverSystemConfiguration + driver_connection_params: DriverConnectionParameters + sql_statement_id: Optional[str] = None + auth_type: Optional[str] = None + vol_operation: Optional[DriverVolumeOperation] = None + sql_operation: Optional[SqlExecutionEvent] = None + error_info: Optional[DriverErrorInfo] = None + operation_latency_ms: Optional[int] = None diff --git a/src/databricks/sql/telemetry/models/frontend_logs.py b/src/databricks/sql/telemetry/models/frontend_logs.py new file mode 100644 index 00000000..4cc314ec --- /dev/null +++ b/src/databricks/sql/telemetry/models/frontend_logs.py @@ -0,0 +1,65 @@ +from dataclasses import dataclass +from databricks.sql.telemetry.models.event import TelemetryEvent +from databricks.sql.telemetry.utils import JsonSerializableMixin +from typing import Optional + + +@dataclass +class TelemetryClientContext(JsonSerializableMixin): + """ + Contains client-side context information for telemetry events. + This includes timestamp and user agent information for tracking when and how the client is being used. + + Attributes: + timestamp_millis (int): Unix timestamp in milliseconds when the event occurred + user_agent (str): Identifier for the client application making the request + """ + + timestamp_millis: int + user_agent: str + + +@dataclass +class FrontendLogContext(JsonSerializableMixin): + """ + Wrapper for client context information in frontend logs. + Provides additional context about the client environment for telemetry events. + + Attributes: + client_context (TelemetryClientContext): Client-specific context information + """ + + client_context: TelemetryClientContext + + +@dataclass +class FrontendLogEntry(JsonSerializableMixin): + """ + Contains the actual telemetry event data in a frontend log. + Wraps the SQL driver log information for frontend processing. + + Attributes: + sql_driver_log (TelemetryEvent): The telemetry event containing SQL driver information + """ + + sql_driver_log: TelemetryEvent + + +@dataclass +class TelemetryFrontendLog(JsonSerializableMixin): + """ + Main container for frontend telemetry data. + Aggregates workspace information, event ID, context, and the actual log entry. + Used for sending telemetry data to the server side. 
+ + Attributes: + workspace_id (int): Unique identifier for the Databricks workspace + frontend_log_event_id (str): Unique identifier for this telemetry event + context (FrontendLogContext): Context information about the client + entry (FrontendLogEntry): The actual telemetry event data + """ + + frontend_log_event_id: str + context: FrontendLogContext + entry: FrontendLogEntry + workspace_id: Optional[int] = None diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py new file mode 100644 index 00000000..dab57113 --- /dev/null +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -0,0 +1,437 @@ +import threading +import time +import json +import requests +import logging +from concurrent.futures import ThreadPoolExecutor +from typing import Dict, Optional, List +from databricks.sql.telemetry.models.event import ( + TelemetryEvent, + DriverSystemConfiguration, + DriverErrorInfo, +) +from databricks.sql.telemetry.models.frontend_logs import ( + TelemetryFrontendLog, + TelemetryClientContext, + FrontendLogContext, + FrontendLogEntry, +) +from databricks.sql.telemetry.models.enums import AuthMech, AuthFlow +from databricks.sql.auth.authenticators import ( + AccessTokenAuthProvider, + DatabricksOAuthProvider, + ExternalAuthProvider, +) +import sys +import platform +import uuid +import locale +from abc import ABC, abstractmethod + +logger = logging.getLogger(__name__) + + +class TelemetryHelper: + """Helper class for getting telemetry related information.""" + + _DRIVER_SYSTEM_CONFIGURATION = None + + @classmethod + def get_driver_system_configuration(cls) -> DriverSystemConfiguration: + if cls._DRIVER_SYSTEM_CONFIGURATION is None: + from databricks.sql import __version__ + + cls._DRIVER_SYSTEM_CONFIGURATION = DriverSystemConfiguration( + driver_name="Databricks SQL Python Connector", + driver_version=__version__, + runtime_name=f"Python {sys.version.split()[0]}", + runtime_vendor=platform.python_implementation(), + runtime_version=platform.python_version(), + os_name=platform.system(), + os_version=platform.release(), + os_arch=platform.machine(), + client_app_name=None, # TODO: Add client app name + locale_name=locale.getlocale()[0] or locale.getdefaultlocale()[0], + char_set_encoding=sys.getdefaultencoding(), + ) + return cls._DRIVER_SYSTEM_CONFIGURATION + + @staticmethod + def get_auth_mechanism(auth_provider): + """Get the auth mechanism for the auth provider.""" + # AuthMech is an enum with the following values: + # TYPE_UNSPECIFIED, OTHER, PAT, OAUTH + + if not auth_provider: + return None + if isinstance(auth_provider, AccessTokenAuthProvider): + return AuthMech.PAT + elif isinstance(auth_provider, DatabricksOAuthProvider): + return AuthMech.OAUTH + else: + return AuthMech.OTHER + + @staticmethod + def get_auth_flow(auth_provider): + """Get the auth flow for the auth provider.""" + # AuthFlow is an enum with the following values: + # TYPE_UNSPECIFIED, TOKEN_PASSTHROUGH, CLIENT_CREDENTIALS, BROWSER_BASED_AUTHENTICATION + + if not auth_provider: + return None + if isinstance(auth_provider, DatabricksOAuthProvider): + if auth_provider._access_token and auth_provider._refresh_token: + return AuthFlow.TOKEN_PASSTHROUGH + else: + return AuthFlow.BROWSER_BASED_AUTHENTICATION + elif isinstance(auth_provider, ExternalAuthProvider): + return AuthFlow.CLIENT_CREDENTIALS + else: + return None + + +class BaseTelemetryClient(ABC): + """ + Base class for telemetry clients. + It is used to define the interface for telemetry clients. 
+ """ + + @abstractmethod + def export_initial_telemetry_log(self, driver_connection_params, user_agent): + logger.debug("subclass must implement export_initial_telemetry_log") + pass + + @abstractmethod + def export_failure_log(self, error_name, error_message): + logger.debug("subclass must implement export_failure_log") + pass + + @abstractmethod + def export_latency_log(self, latency_ms, sql_execution_event, sql_statement_id): + logger.debug("subclass must implement export_latency_log") + pass + + @abstractmethod + def close(self): + logger.debug("subclass must implement close") + pass + + +class NoopTelemetryClient(BaseTelemetryClient): + """ + NoopTelemetryClient is a telemetry client that does not send any events to the server. + It is used when telemetry is disabled. + """ + + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super(NoopTelemetryClient, cls).__new__(cls) + return cls._instance + + def export_initial_telemetry_log(self, driver_connection_params, user_agent): + pass + + def export_failure_log(self, error_name, error_message): + pass + + def export_latency_log(self, latency_ms, sql_execution_event, sql_statement_id): + pass + + def close(self): + pass + + +class TelemetryClient(BaseTelemetryClient): + """ + Telemetry client class that handles sending telemetry events in batches to the server. + It uses a thread pool to handle asynchronous operations, that it gets from the TelemetryClientFactory. + """ + + # Telemetry endpoint paths + TELEMETRY_AUTHENTICATED_PATH = "/telemetry-ext" + TELEMETRY_UNAUTHENTICATED_PATH = "/telemetry-unauth" + + DEFAULT_BATCH_SIZE = 100 + + def __init__( + self, + telemetry_enabled, + session_id_hex, + auth_provider, + host_url, + executor, + ): + logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex) + self._telemetry_enabled = telemetry_enabled + self._batch_size = self.DEFAULT_BATCH_SIZE + self._session_id_hex = session_id_hex + self._auth_provider = auth_provider + self._user_agent = None + self._events_batch = [] + self._lock = threading.RLock() + self._driver_connection_params = None + self._host_url = host_url + self._executor = executor + + def _export_event(self, event): + """Add an event to the batch queue and flush if batch is full""" + logger.debug("Exporting event for connection %s", self._session_id_hex) + with self._lock: + self._events_batch.append(event) + if len(self._events_batch) >= self._batch_size: + logger.debug( + "Batch size limit reached (%s), flushing events", self._batch_size + ) + self._flush() + + def _flush(self): + """Flush the current batch of events to the server""" + with self._lock: + events_to_flush = self._events_batch.copy() + self._events_batch = [] + + if events_to_flush: + logger.debug("Flushing %s telemetry events to server", len(events_to_flush)) + self._send_telemetry(events_to_flush) + + def _send_telemetry(self, events): + """Send telemetry events to the server""" + + request = { + "uploadTime": int(time.time() * 1000), + "items": [], + "protoLogs": [event.to_json() for event in events], + } + + path = ( + self.TELEMETRY_AUTHENTICATED_PATH + if self._auth_provider + else self.TELEMETRY_UNAUTHENTICATED_PATH + ) + url = f"https://{self._host_url}{path}" + + headers = {"Accept": "application/json", "Content-Type": "application/json"} + + if self._auth_provider: + self._auth_provider.add_headers(headers) + + try: + logger.debug("Submitting telemetry request to thread pool") + future = self._executor.submit( + requests.post, + url, + 
data=json.dumps(request), + headers=headers, + timeout=10, + ) + future.add_done_callback(self._telemetry_request_callback) + except Exception as e: + logger.debug("Failed to submit telemetry request: %s", e) + + def _telemetry_request_callback(self, future): + """Callback function to handle telemetry request completion""" + try: + response = future.result() + + if response.status_code == 200: + logger.debug("Telemetry request completed successfully") + else: + logger.debug( + "Telemetry request failed with status code: %s", + response.status_code, + ) + + except Exception as e: + logger.debug("Telemetry request failed with exception: %s", e) + + def _export_telemetry_log(self, **telemetry_event_kwargs): + """ + Common helper method for exporting telemetry logs. + + Args: + **telemetry_event_kwargs: Keyword arguments to pass to TelemetryEvent constructor + """ + logger.debug("Exporting telemetry log for connection %s", self._session_id_hex) + + try: + # Set common fields for all telemetry events + event_kwargs = { + "session_id": self._session_id_hex, + "system_configuration": TelemetryHelper.get_driver_system_configuration(), + "driver_connection_params": self._driver_connection_params, + } + # Add any additional fields passed in + event_kwargs.update(telemetry_event_kwargs) + + telemetry_frontend_log = TelemetryFrontendLog( + frontend_log_event_id=str(uuid.uuid4()), + context=FrontendLogContext( + client_context=TelemetryClientContext( + timestamp_millis=int(time.time() * 1000), + user_agent=self._user_agent, + ) + ), + entry=FrontendLogEntry(sql_driver_log=TelemetryEvent(**event_kwargs)), + ) + + self._export_event(telemetry_frontend_log) + + except Exception as e: + logger.debug("Failed to export telemetry log: %s", e) + + def export_initial_telemetry_log(self, driver_connection_params, user_agent): + self._driver_connection_params = driver_connection_params + self._user_agent = user_agent + self._export_telemetry_log() + + def export_failure_log(self, error_name, error_message): + error_info = DriverErrorInfo(error_name=error_name, stack_trace=error_message) + self._export_telemetry_log(error_info=error_info) + + def export_latency_log(self, latency_ms, sql_execution_event, sql_statement_id): + self._export_telemetry_log( + sql_statement_id=sql_statement_id, + sql_operation=sql_execution_event, + operation_latency_ms=latency_ms, + ) + + def close(self): + """Flush remaining events before closing""" + logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex) + self._flush() + + +class TelemetryClientFactory: + """ + Static factory class for creating and managing telemetry clients. + It uses a thread pool to handle asynchronous operations. 
+ """ + + _clients: Dict[ + str, BaseTelemetryClient + ] = {} # Map of session_id_hex -> BaseTelemetryClient + _executor: Optional[ThreadPoolExecutor] = None + _initialized: bool = False + _lock = threading.RLock() # Thread safety for factory operations + # used RLock instead of Lock to avoid deadlocks when garbage collection is triggered + _original_excepthook = None + _excepthook_installed = False + + @classmethod + def _initialize(cls): + """Initialize the factory if not already initialized""" + + if not cls._initialized: + cls._clients = {} + cls._executor = ThreadPoolExecutor( + max_workers=10 + ) # Thread pool for async operations TODO: Decide on max workers + cls._install_exception_hook() + cls._initialized = True + logger.debug( + "TelemetryClientFactory initialized with thread pool (max_workers=10)" + ) + + @classmethod + def _install_exception_hook(cls): + """Install global exception handler for unhandled exceptions""" + if not cls._excepthook_installed: + cls._original_excepthook = sys.excepthook + sys.excepthook = cls._handle_unhandled_exception + cls._excepthook_installed = True + logger.debug("Global exception handler installed for telemetry") + + @classmethod + def _handle_unhandled_exception(cls, exc_type, exc_value, exc_traceback): + """Handle unhandled exceptions by sending telemetry and flushing thread pool""" + logger.debug("Handling unhandled exception: %s", exc_type.__name__) + + clients_to_close = list(cls._clients.values()) + for client in clients_to_close: + client.close() + + # Call the original exception handler to maintain normal behavior + if cls._original_excepthook: + cls._original_excepthook(exc_type, exc_value, exc_traceback) + + @staticmethod + def initialize_telemetry_client( + telemetry_enabled, + session_id_hex, + auth_provider, + host_url, + ): + """Initialize a telemetry client for a specific connection if telemetry is enabled""" + try: + + with TelemetryClientFactory._lock: + TelemetryClientFactory._initialize() + + if session_id_hex not in TelemetryClientFactory._clients: + logger.debug( + "Creating new TelemetryClient for connection %s", + session_id_hex, + ) + if telemetry_enabled: + TelemetryClientFactory._clients[ + session_id_hex + ] = TelemetryClient( + telemetry_enabled=telemetry_enabled, + session_id_hex=session_id_hex, + auth_provider=auth_provider, + host_url=host_url, + executor=TelemetryClientFactory._executor, + ) + else: + TelemetryClientFactory._clients[ + session_id_hex + ] = NoopTelemetryClient() + except Exception as e: + logger.debug("Failed to initialize telemetry client: %s", e) + # Fallback to NoopTelemetryClient to ensure connection doesn't fail + TelemetryClientFactory._clients[session_id_hex] = NoopTelemetryClient() + + @staticmethod + def get_telemetry_client(session_id_hex): + """Get the telemetry client for a specific connection""" + try: + if session_id_hex in TelemetryClientFactory._clients: + return TelemetryClientFactory._clients[session_id_hex] + else: + logger.debug( + "Telemetry client not initialized for connection %s", + session_id_hex, + ) + return NoopTelemetryClient() + except Exception as e: + logger.debug("Failed to get telemetry client: %s", e) + return NoopTelemetryClient() + + @staticmethod + def close(session_id_hex): + """Close and remove the telemetry client for a specific connection""" + + with TelemetryClientFactory._lock: + if ( + telemetry_client := TelemetryClientFactory._clients.pop( + session_id_hex, None + ) + ) is not None: + logger.debug( + "Removing telemetry client for connection %s", 
session_id_hex + ) + telemetry_client.close() + + # Shutdown executor if no more clients + if not TelemetryClientFactory._clients and TelemetryClientFactory._executor: + logger.debug( + "No more telemetry clients, shutting down thread pool executor" + ) + try: + TelemetryClientFactory._executor.shutdown(wait=True) + except Exception as e: + logger.debug("Failed to shutdown thread pool executor: %s", e) + TelemetryClientFactory._executor = None + TelemetryClientFactory._initialized = False diff --git a/src/databricks/sql/telemetry/utils.py b/src/databricks/sql/telemetry/utils.py new file mode 100644 index 00000000..df7acf28 --- /dev/null +++ b/src/databricks/sql/telemetry/utils.py @@ -0,0 +1,38 @@ +import json +from enum import Enum +from dataclasses import asdict, is_dataclass + + +class JsonSerializableMixin: + """Mixin class to provide JSON serialization capabilities to dataclasses.""" + + def to_json(self) -> str: + """ + Convert the object to a JSON string, excluding None values. + Handles Enum serialization and filters out None values from the output. + """ + if not is_dataclass(self): + raise TypeError( + f"{self.__class__.__name__} must be a dataclass to use JsonSerializableMixin" + ) + + return json.dumps( + asdict( + self, + dict_factory=lambda data: {k: v for k, v in data if v is not None}, + ), + cls=EnumEncoder, + ) + + +class EnumEncoder(json.JSONEncoder): + """ + Custom JSON encoder to handle Enum values. + This is used to convert Enum values to their string representations. + Default JSON encoder raises a TypeError for Enums. + """ + + def default(self, obj): + if isinstance(obj, Enum): + return obj.value + return super().default(obj) diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index e3dc38ad..78683ac3 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -223,6 +223,7 @@ def __init__( raise self._request_lock = threading.RLock() + self._session_id_hex = None # TODO: Move this bounding logic into DatabricksRetryPolicy for v3 (PECO-918) def _initialize_retry_args(self, kwargs): @@ -255,12 +256,15 @@ def _initialize_retry_args(self, kwargs): ) @staticmethod - def _check_response_for_error(response): + def _check_response_for_error(response, session_id_hex=None): if response.status and response.status.statusCode in [ ttypes.TStatusCode.ERROR_STATUS, ttypes.TStatusCode.INVALID_HANDLE_STATUS, ]: - raise DatabaseError(response.status.errorMessage) + raise DatabaseError( + response.status.errorMessage, + session_id_hex=session_id_hex, + ) @staticmethod def _extract_error_message_from_headers(headers): @@ -311,7 +315,10 @@ def _handle_request_error(self, error_info, attempt, elapsed): no_retry_reason, attempt, elapsed ) network_request_error = RequestError( - user_friendly_error_message, full_error_info_context, error_info.error + user_friendly_error_message, + full_error_info_context, + self._session_id_hex, + error_info.error, ) logger.info(network_request_error.message_with_context()) @@ -483,7 +490,7 @@ def attempt_request(attempt): if not isinstance(response_or_error_info, RequestErrorInfo): # log nothing here, presume that main request logging covers response = response_or_error_info - ThriftBackend._check_response_for_error(response) + ThriftBackend._check_response_for_error(response, self._session_id_hex) return response error_info = response_or_error_info @@ -497,7 +504,8 @@ def _check_protocol_version(self, t_open_session_resp): raise OperationalError( "Error: expected server to use a 
protocol version >= " "SPARK_CLI_SERVICE_PROTOCOL_V2, " - "instead got: {}".format(protocol_version) + "instead got: {}".format(protocol_version), + session_id_hex=self._session_id_hex, ) def _check_initial_namespace(self, catalog, schema, response): @@ -510,14 +518,16 @@ def _check_initial_namespace(self, catalog, schema, response): ): raise InvalidServerResponseError( "Setting initial namespace not supported by the DBR version, " - "Please use a Databricks SQL endpoint or a cluster with DBR >= 9.0." + "Please use a Databricks SQL endpoint or a cluster with DBR >= 9.0.", + session_id_hex=self._session_id_hex, ) if catalog: if not response.canUseMultipleCatalogs: raise InvalidServerResponseError( "Unexpected response from server: Trying to set initial catalog to {}, " - + "but server does not support multiple catalogs.".format(catalog) # type: ignore + + "but server does not support multiple catalogs.".format(catalog), # type: ignore + session_id_hex=self._session_id_hex, ) def _check_session_configuration(self, session_configuration): @@ -531,7 +541,8 @@ def _check_session_configuration(self, session_configuration): "while using the Databricks SQL connector, it must be false not {}".format( TIMESTAMP_AS_STRING_CONFIG, session_configuration[TIMESTAMP_AS_STRING_CONFIG], - ) + ), + session_id_hex=self._session_id_hex, ) def open_session(self, session_configuration, catalog, schema): @@ -562,6 +573,11 @@ def open_session(self, session_configuration, catalog, schema): response = self.make_request(self._client.OpenSession, open_session_req) self._check_initial_namespace(catalog, schema, response) self._check_protocol_version(response) + self._session_id_hex = ( + self.handle_to_hex_id(response.sessionHandle) + if response.sessionHandle + else None + ) return response except: self._transport.close() @@ -586,6 +602,7 @@ def _check_command_not_in_error_or_closed_state( and self.guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": get_operations_resp.diagnosticInfo, }, + session_id_hex=self._session_id_hex, ) else: raise ServerOperationError( @@ -595,6 +612,7 @@ def _check_command_not_in_error_or_closed_state( and self.guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": None, }, + session_id_hex=self._session_id_hex, ) elif get_operations_resp.operationState == ttypes.TOperationState.CLOSED_STATE: raise DatabaseError( @@ -605,6 +623,7 @@ def _check_command_not_in_error_or_closed_state( "operation-id": op_handle and self.guid_to_hex_id(op_handle.operationId.guid) }, + session_id_hex=self._session_id_hex, ) def _poll_for_status(self, op_handle): @@ -625,7 +644,10 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti t_row_set.arrowBatches, lz4_compressed, schema_bytes ) else: - raise OperationalError("Unsupported TRowSet instance {}".format(t_row_set)) + raise OperationalError( + "Unsupported TRowSet instance {}".format(t_row_set), + session_id_hex=self._session_id_hex, + ) return convert_decimals_in_arrow_table(arrow_table, description), num_rows def _get_metadata_resp(self, op_handle): @@ -633,7 +655,7 @@ def _get_metadata_resp(self, op_handle): return self.make_request(self._client.GetResultSetMetadata, req) @staticmethod - def _hive_schema_to_arrow_schema(t_table_schema): + def _hive_schema_to_arrow_schema(t_table_schema, session_id_hex=None): def map_type(t_type_entry): if t_type_entry.primitiveEntry: return { @@ -664,7 +686,8 @@ def map_type(t_type_entry): # Current thriftserver implementation should always return a primitiveEntry, # even for 
complex types raise OperationalError( - "Thrift protocol error: t_type_entry not a primitiveEntry" + "Thrift protocol error: t_type_entry not a primitiveEntry", + session_id_hex=session_id_hex, ) def convert_col(t_column_desc): @@ -675,7 +698,7 @@ def convert_col(t_column_desc): return pyarrow.schema([convert_col(col) for col in t_table_schema.columns]) @staticmethod - def _col_to_description(col): + def _col_to_description(col, session_id_hex=None): type_entry = col.typeDesc.types[0] if type_entry.primitiveEntry: @@ -684,7 +707,8 @@ def _col_to_description(col): cleaned_type = (name[:-5] if name.endswith("_TYPE") else name).lower() else: raise OperationalError( - "Thrift protocol error: t_type_entry not a primitiveEntry" + "Thrift protocol error: t_type_entry not a primitiveEntry", + session_id_hex=session_id_hex, ) if type_entry.primitiveEntry.type == ttypes.TTypeId.DECIMAL_TYPE: @@ -697,7 +721,8 @@ def _col_to_description(col): else: raise OperationalError( "Decimal type did not provide typeQualifier precision, scale in " - "primitiveEntry {}".format(type_entry.primitiveEntry) + "primitiveEntry {}".format(type_entry.primitiveEntry), + session_id_hex=session_id_hex, ) else: precision, scale = None, None @@ -705,9 +730,10 @@ def _col_to_description(col): return col.columnName, cleaned_type, None, None, precision, scale, None @staticmethod - def _hive_schema_to_description(t_table_schema): + def _hive_schema_to_description(t_table_schema, session_id_hex=None): return [ - ThriftBackend._col_to_description(col) for col in t_table_schema.columns + ThriftBackend._col_to_description(col, session_id_hex) + for col in t_table_schema.columns ] def _results_message_to_execute_response(self, resp, operation_state): @@ -727,7 +753,8 @@ def _results_message_to_execute_response(self, resp, operation_state): ttypes.TSparkRowSetType._VALUES_TO_NAMES[ t_result_set_metadata_resp.resultFormat ] - ) + ), + session_id_hex=self._session_id_hex, ) direct_results = resp.directResults has_been_closed_server_side = direct_results and direct_results.closeOperation @@ -737,13 +764,16 @@ def _results_message_to_execute_response(self, resp, operation_state): or direct_results.resultSet.hasMoreRows ) description = self._hive_schema_to_description( - t_result_set_metadata_resp.schema + t_result_set_metadata_resp.schema, + self._session_id_hex, ) if pyarrow: schema_bytes = ( t_result_set_metadata_resp.arrowSchema - or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema) + or self._hive_schema_to_arrow_schema( + t_result_set_metadata_resp.schema, self._session_id_hex + ) .serialize() .to_pybytes() ) @@ -804,13 +834,16 @@ def get_execution_result(self, op_handle, cursor): is_staging_operation = t_result_set_metadata_resp.isStagingOperation has_more_rows = resp.hasMoreRows description = self._hive_schema_to_description( - t_result_set_metadata_resp.schema + t_result_set_metadata_resp.schema, + self._session_id_hex, ) if pyarrow: schema_bytes = ( t_result_set_metadata_resp.arrowSchema - or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema) + or self._hive_schema_to_arrow_schema( + t_result_set_metadata_resp.schema, self._session_id_hex + ) .serialize() .to_pybytes() ) @@ -864,23 +897,27 @@ def get_query_state(self, op_handle) -> "TOperationState": return operation_state @staticmethod - def _check_direct_results_for_error(t_spark_direct_results): + def _check_direct_results_for_error(t_spark_direct_results, session_id_hex=None): if t_spark_direct_results: if 
t_spark_direct_results.operationStatus:
                 ThriftBackend._check_response_for_error(
-                    t_spark_direct_results.operationStatus
+                    t_spark_direct_results.operationStatus,
+                    session_id_hex,
                 )
             if t_spark_direct_results.resultSetMetadata:
                 ThriftBackend._check_response_for_error(
-                    t_spark_direct_results.resultSetMetadata
+                    t_spark_direct_results.resultSetMetadata,
+                    session_id_hex,
                 )
             if t_spark_direct_results.resultSet:
                 ThriftBackend._check_response_for_error(
-                    t_spark_direct_results.resultSet
+                    t_spark_direct_results.resultSet,
+                    session_id_hex,
                 )
             if t_spark_direct_results.closeOperation:
                 ThriftBackend._check_response_for_error(
-                    t_spark_direct_results.closeOperation
+                    t_spark_direct_results.closeOperation,
+                    session_id_hex,
                 )
 
     def execute_command(
@@ -1029,7 +1066,7 @@ def get_columns(
     def _handle_execute_response(self, resp, cursor):
         cursor.active_op_handle = resp.operationHandle
 
-        self._check_direct_results_for_error(resp.directResults)
+        self._check_direct_results_for_error(resp.directResults, self._session_id_hex)
 
         final_operation_state = self._wait_until_command_done(
             resp.operationHandle,
@@ -1040,7 +1077,7 @@ def _handle_execute_response(self, resp, cursor):
     def _handle_execute_response_async(self, resp, cursor):
         cursor.active_op_handle = resp.operationHandle
 
-        self._check_direct_results_for_error(resp.directResults)
+        self._check_direct_results_for_error(resp.directResults, self._session_id_hex)
 
     def fetch_results(
         self,
@@ -1074,7 +1111,8 @@ def fetch_results(
             raise DataError(
                 "fetch_results failed due to inconsistency in the state between the client and the server. Expected results to start from {} but they instead start at {}, some result batches must have been skipped".format(
                     expected_row_start_offset, resp.results.startRowOffset
-                )
+                ),
+                session_id_hex=self._session_id_hex,
             )
 
         queue = ResultSetQueueFactory.build_queue(
diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py
new file mode 100644
index 00000000..fcf3fa70
--- /dev/null
+++ b/tests/unit/test_telemetry.py
@@ -0,0 +1,286 @@
+import uuid
+import pytest
+import requests
+from unittest.mock import patch, MagicMock
+
+from databricks.sql.telemetry.telemetry_client import (
+    TelemetryClient,
+    NoopTelemetryClient,
+    TelemetryClientFactory,
+    TelemetryHelper,
+    BaseTelemetryClient
+)
+from databricks.sql.telemetry.models.enums import AuthMech, AuthFlow
+from databricks.sql.auth.authenticators import (
+    AccessTokenAuthProvider,
+    DatabricksOAuthProvider,
+    ExternalAuthProvider,
+)
+
+
+@pytest.fixture
+def telemetry_system_reset():
+    """Reset telemetry system state before each test."""
+    TelemetryClientFactory._clients.clear()
+    if TelemetryClientFactory._executor:
+        TelemetryClientFactory._executor.shutdown(wait=True)
+        TelemetryClientFactory._executor = None
+    TelemetryClientFactory._initialized = False
+    yield
+    TelemetryClientFactory._clients.clear()
+    if TelemetryClientFactory._executor:
+        TelemetryClientFactory._executor.shutdown(wait=True)
+        TelemetryClientFactory._executor = None
+    TelemetryClientFactory._initialized = False
+
+
+@pytest.fixture
+def mock_telemetry_client():
+    """Create a mock telemetry client for testing."""
+    session_id = str(uuid.uuid4())
+    auth_provider = AccessTokenAuthProvider("test-token")
+    executor = MagicMock()
+
+    return TelemetryClient(
+        telemetry_enabled=True,
+        session_id_hex=session_id,
+        auth_provider=auth_provider,
+        host_url="test-host.com",
+        executor=executor,
+    )
+
+
+class TestNoopTelemetryClient:
+    """Tests for NoopTelemetryClient - should do nothing safely."""
+
+    def test_noop_client_behavior(self):
+        """Test that NoopTelemetryClient is a singleton and all methods are safe no-ops."""
+        # Test singleton behavior
+        client1 = NoopTelemetryClient()
+        client2 = NoopTelemetryClient()
+        assert client1 is client2
+
+        # Test that all methods can be called without exceptions
+        client1.export_initial_telemetry_log(MagicMock(), "test-agent")
+        client1.export_failure_log("TestError", "Test message")
+        client1.export_latency_log(100, "EXECUTE_STATEMENT", "test-id")
+        client1.close()
+
+
+class TestTelemetryClient:
+    """Tests for actual telemetry client functionality and flows."""
+
+    def test_event_batching_and_flushing_flow(self, mock_telemetry_client):
+        """Test the complete event batching and flushing flow."""
+        client = mock_telemetry_client
+        client._batch_size = 3  # Small batch for testing
+
+        # Mock the network call
+        with patch.object(client, '_send_telemetry') as mock_send:
+            # Add events one by one - should not flush yet
+            client._export_event("event1")
+            client._export_event("event2")
+            mock_send.assert_not_called()
+            assert len(client._events_batch) == 2
+
+            # Third event should trigger flush
+            client._export_event("event3")
+            mock_send.assert_called_once()
+            assert len(client._events_batch) == 0  # Batch cleared after flush
+
+    @patch('requests.post')
+    def test_network_request_flow(self, mock_post, mock_telemetry_client):
+        """Test the complete network request flow with authentication."""
+        mock_post.return_value.status_code = 200
+        client = mock_telemetry_client
+
+        # Create mock events
+        mock_events = [MagicMock() for _ in range(2)]
+        for i, event in enumerate(mock_events):
+            event.to_json.return_value = f'{{"event": "{i}"}}'
+
+        # Send telemetry
+        client._send_telemetry(mock_events)
+
+        # Verify request was submitted to executor
+        client._executor.submit.assert_called_once()
+        args, kwargs = client._executor.submit.call_args
+
+        # Verify correct function and URL
+        assert args[0] == requests.post
+        assert args[1] == 'https://test-host.com/telemetry-ext'
+        assert kwargs['headers']['Authorization'] == 'Bearer test-token'
+        assert kwargs['timeout'] == 10
+
+        # Verify request body structure
+        request_data = kwargs['data']
+        assert '"uploadTime"' in request_data
+        assert '"protoLogs"' in request_data
+
+    def test_telemetry_logging_flows(self, mock_telemetry_client):
+        """Test all telemetry logging methods work end-to-end."""
+        client = mock_telemetry_client
+
+        with patch.object(client, '_export_event') as mock_export:
+            # Test initial log
+            client.export_initial_telemetry_log(MagicMock(), "test-agent")
+            assert mock_export.call_count == 1
+
+            # Test failure log
+            client.export_failure_log("TestError", "Error message")
+            assert mock_export.call_count == 2
+
+            # Test latency log
+            client.export_latency_log(150, "EXECUTE_STATEMENT", "stmt-123")
+            assert mock_export.call_count == 3
+
+    def test_error_handling_resilience(self, mock_telemetry_client):
+        """Test that telemetry errors don't break the client."""
+        client = mock_telemetry_client
+
+        # Test that exceptions in telemetry don't propagate
+        with patch.object(client, '_export_event', side_effect=Exception("Test error")):
+            # These should not raise exceptions
+            client.export_initial_telemetry_log(MagicMock(), "test-agent")
+            client.export_failure_log("TestError", "Error message")
+            client.export_latency_log(100, "EXECUTE_STATEMENT", "stmt-123")
+
+        # Test executor submission failure
+        client._executor.submit.side_effect = Exception("Thread pool error")
+        client._send_telemetry([MagicMock()])  # Should not raise
+
+
+class TestTelemetryHelper:
+    """Tests for TelemetryHelper utility functions."""
+
+    def test_system_configuration_caching(self):
+        """Test that system configuration is cached and contains expected data."""
+        config1 = TelemetryHelper.get_driver_system_configuration()
+        config2 = TelemetryHelper.get_driver_system_configuration()
+
+        # Should be cached (same instance)
+        assert config1 is config2
+
+    def test_auth_mechanism_detection(self):
+        """Test authentication mechanism detection for different providers."""
+        test_cases = [
+            (AccessTokenAuthProvider("token"), AuthMech.PAT),
+            (MagicMock(spec=DatabricksOAuthProvider), AuthMech.OAUTH),
+            (MagicMock(spec=ExternalAuthProvider), AuthMech.OTHER),
+            (MagicMock(), AuthMech.OTHER),  # Unknown provider
+            (None, None),
+        ]
+
+        for provider, expected in test_cases:
+            assert TelemetryHelper.get_auth_mechanism(provider) == expected
+
+    def test_auth_flow_detection(self):
+        """Test authentication flow detection for OAuth providers."""
+        # OAuth with existing tokens
+        oauth_with_tokens = MagicMock(spec=DatabricksOAuthProvider)
+        oauth_with_tokens._access_token = "test-access-token"
+        oauth_with_tokens._refresh_token = "test-refresh-token"
+        assert TelemetryHelper.get_auth_flow(oauth_with_tokens) == AuthFlow.TOKEN_PASSTHROUGH
+
+        # Test OAuth with browser-based auth
+        oauth_with_browser = MagicMock(spec=DatabricksOAuthProvider)
+        oauth_with_browser._access_token = None
+        oauth_with_browser._refresh_token = None
+        oauth_with_browser.oauth_manager = MagicMock()
+        assert TelemetryHelper.get_auth_flow(oauth_with_browser) == AuthFlow.BROWSER_BASED_AUTHENTICATION
+
+        # Test non-OAuth provider
+        pat_auth = AccessTokenAuthProvider("test-token")
+        assert TelemetryHelper.get_auth_flow(pat_auth) is None
+
+        # Test None auth provider
+        assert TelemetryHelper.get_auth_flow(None) is None
+
+
+class TestTelemetryFactory:
+    """Tests for TelemetryClientFactory lifecycle and management."""
+
+    def test_client_lifecycle_flow(self, telemetry_system_reset):
+        """Test complete client lifecycle: initialize -> use -> close."""
+        session_id_hex = "test-session"
+        auth_provider = AccessTokenAuthProvider("token")
+
+        # Initialize enabled client
+        TelemetryClientFactory.initialize_telemetry_client(
+            telemetry_enabled=True,
+            session_id_hex=session_id_hex,
+            auth_provider=auth_provider,
+            host_url="test-host.com"
+        )
+
+        client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
+        assert isinstance(client, TelemetryClient)
+        assert client._session_id_hex == session_id_hex
+
+        # Close client
+        with patch.object(client, 'close') as mock_close:
+            TelemetryClientFactory.close(session_id_hex)
+            mock_close.assert_called_once()
+
+        # Should get NoopTelemetryClient after close
+        client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
+        assert isinstance(client, NoopTelemetryClient)
+
+    def test_disabled_telemetry_flow(self, telemetry_system_reset):
+        """Test that disabled telemetry uses NoopTelemetryClient."""
+        session_id_hex = "test-session"
+
+        TelemetryClientFactory.initialize_telemetry_client(
+            telemetry_enabled=False,
+            session_id_hex=session_id_hex,
+            auth_provider=None,
+            host_url="test-host.com"
+        )
+
+        client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
+        assert isinstance(client, NoopTelemetryClient)
+
+    def test_factory_error_handling(self, telemetry_system_reset):
+        """Test that factory errors fall back to NoopTelemetryClient."""
+        session_id = "test-session"
+
+        # Simulate initialization error
+        with patch('databricks.sql.telemetry.telemetry_client.TelemetryClient',
+                   side_effect=Exception("Init error")):
+            TelemetryClientFactory.initialize_telemetry_client(
+                telemetry_enabled=True,
+                session_id_hex=session_id,
+                auth_provider=AccessTokenAuthProvider("token"),
+                host_url="test-host.com"
+            )
+
+        # Should fall back to NoopTelemetryClient
+        client = TelemetryClientFactory.get_telemetry_client(session_id)
+        assert isinstance(client, NoopTelemetryClient)
+
+    def test_factory_shutdown_flow(self, telemetry_system_reset):
+        """Test factory shutdown when last client is removed."""
+        session1 = "session-1"
+        session2 = "session-2"
+
+        # Initialize multiple clients
+        for session in [session1, session2]:
+            TelemetryClientFactory.initialize_telemetry_client(
+                telemetry_enabled=True,
+                session_id_hex=session,
+                auth_provider=AccessTokenAuthProvider("token"),
+                host_url="test-host.com"
+            )
+
+        # Factory should be initialized
+        assert TelemetryClientFactory._initialized is True
+        assert TelemetryClientFactory._executor is not None
+
+        # Close first client - factory should stay initialized
+        TelemetryClientFactory.close(session1)
+        assert TelemetryClientFactory._initialized is True
+
+        # Close second client - factory should shut down
+        TelemetryClientFactory.close(session2)
+        assert TelemetryClientFactory._initialized is False
+        assert TelemetryClientFactory._executor is None
\ No newline at end of file
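
Reviewer note (not part of the patch): for readers unfamiliar with the pattern exercised by TestNoopTelemetryClient above, the sketch below shows the shape those assertions assume: a process-wide singleton whose telemetry methods accept calls and do nothing. Only the method names and call signatures are taken from the tests; the class name, argument names, and comments are illustrative, not the driver's actual implementation.

# Minimal sketch of the no-op client shape assumed by TestNoopTelemetryClient.
# "NoopTelemetryClientSketch" is an illustrative name, not the driver class.

class NoopTelemetryClientSketch:
    """Process-wide singleton that accepts every telemetry call and does nothing."""

    _instance = None

    def __new__(cls):
        # Always hand back the same instance, so `a is b` holds for any two constructions.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def export_initial_telemetry_log(self, driver_connection_params, user_agent):
        pass  # deliberate no-op

    def export_failure_log(self, error_name, error_message):
        pass  # deliberate no-op

    def export_latency_log(self, latency_ms, statement_type, statement_id):
        pass  # deliberate no-op

    def close(self):
        pass  # deliberate no-op


if __name__ == "__main__":
    a, b = NoopTelemetryClientSketch(), NoopTelemetryClientSketch()
    assert a is b  # the singleton property the unit test checks
    a.export_failure_log("TestError", "Test message")  # safe to call, does nothing
    a.close()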