From 65a75f44e37c49359edf0e32d4549e02d4a31879 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 10 Jun 2025 17:29:56 +0530 Subject: [PATCH 01/48] added functionality for export of failure logs Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/client.py | 76 ++++++++++++---- src/databricks/sql/exc.py | 19 +++- .../sql/telemetry/telemetry_client.py | 64 +++++++++---- src/databricks/sql/thrift_backend.py | 89 +++++++++++++------ 4 files changed, 184 insertions(+), 64 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 3a682984e..f0f1f5b8a 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -421,7 +421,10 @@ def cursor( Will throw an Error if the connection has been closed. """ if not self.open: - raise Error("Cannot create cursor from closed connection") + raise Error( + "Cannot create cursor from closed connection", + connection_uuid=self.get_session_id_hex(), + ) cursor = Cursor( self, @@ -471,7 +474,10 @@ def commit(self): pass def rollback(self): - raise NotSupportedError("Transactions are not supported on Databricks") + raise NotSupportedError( + "Transactions are not supported on Databricks", + connection_uuid=self.get_session_id_hex(), + ) class Cursor: @@ -523,7 +529,10 @@ def __iter__(self): for row in self.active_result_set: yield row else: - raise Error("There is no active result set") + raise Error( + "There is no active result set", + connection_uuid=self.connection.get_session_id_hex(), + ) def _determine_parameter_approach( self, params: Optional[TParameterCollection] @@ -660,7 +669,10 @@ def _close_and_clear_active_result_set(self): def _check_not_closed(self): if not self.open: - raise Error("Attempting operation on closed cursor") + raise Error( + "Attempting operation on closed cursor", + connection_uuid=self.connection.get_session_id_hex(), + ) def _handle_staging_operation( self, staging_allowed_local_path: Union[None, str, List[str]] @@ -678,7 +690,8 @@ def _handle_staging_operation( _staging_allowed_local_paths = staging_allowed_local_path else: raise Error( - "You must provide at least one staging_allowed_local_path when initialising a connection to perform ingestion commands" + "You must provide at least one staging_allowed_local_path when initialising a connection to perform ingestion commands", + connection_uuid=self.connection.get_session_id_hex(), ) abs_staging_allowed_local_paths = [ @@ -707,7 +720,8 @@ def _handle_staging_operation( continue if not allow_operation: raise Error( - "Local file operations are restricted to paths within the configured staging_allowed_local_path" + "Local file operations are restricted to paths within the configured staging_allowed_local_path", + connection_uuid=self.connection.get_session_id_hex(), ) # May be real headers, or could be json string @@ -737,7 +751,8 @@ def _handle_staging_operation( else: raise Error( f"Operation {row.operation} is not supported. " - + "Supported operations are GET, PUT, and REMOVE" + + "Supported operations are GET, PUT, and REMOVE", + connection_uuid=self.connection.get_session_id_hex(), ) def _handle_staging_put( @@ -749,7 +764,10 @@ def _handle_staging_put( """ if local_file is None: - raise Error("Cannot perform PUT without specifying a local_file") + raise Error( + "Cannot perform PUT without specifying a local_file", + connection_uuid=self.connection.get_session_id_hex(), + ) with open(local_file, "rb") as fh: r = requests.put(url=presigned_url, data=fh, headers=headers) @@ -766,7 +784,8 @@ def _handle_staging_put( if r.status_code not in [OK, CREATED, NO_CONTENT, ACCEPTED]: raise Error( - f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}" + f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}", + connection_uuid=self.connection.get_session_id_hex(), ) if r.status_code == ACCEPTED: @@ -784,7 +803,10 @@ def _handle_staging_get( """ if local_file is None: - raise Error("Cannot perform GET without specifying a local_file") + raise Error( + "Cannot perform GET without specifying a local_file", + connection_uuid=self.connection.get_session_id_hex(), + ) r = requests.get(url=presigned_url, headers=headers) @@ -792,7 +814,8 @@ def _handle_staging_get( # Any 2xx or 3xx will evaluate r.ok == True if not r.ok: raise Error( - f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}" + f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}", + connection_uuid=self.connection.get_session_id_hex(), ) with open(local_file, "wb") as fp: @@ -807,7 +830,8 @@ def _handle_staging_remove( if not r.ok: raise Error( - f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}" + f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}", + connection_uuid=self.connection.get_session_id_hex(), ) def execute( @@ -1006,7 +1030,8 @@ def get_async_execution_result(self): return self else: raise Error( - f"get_execution_result failed with Operation status {operation_state}" + f"get_execution_result failed with Operation status {operation_state}", + connection_uuid=self.connection.get_session_id_hex(), ) def executemany(self, operation, seq_of_parameters): @@ -1156,7 +1181,10 @@ def fetchall(self) -> List[Row]: if self.active_result_set: return self.active_result_set.fetchall() else: - raise Error("There is no active result set") + raise Error( + "There is no active result set", + connection_uuid=self.connection.get_session_id_hex(), + ) def fetchone(self) -> Optional[Row]: """ @@ -1170,7 +1198,10 @@ def fetchone(self) -> Optional[Row]: if self.active_result_set: return self.active_result_set.fetchone() else: - raise Error("There is no active result set") + raise Error( + "There is no active result set", + connection_uuid=self.connection.get_session_id_hex(), + ) def fetchmany(self, size: int) -> List[Row]: """ @@ -1192,21 +1223,30 @@ def fetchmany(self, size: int) -> List[Row]: if self.active_result_set: return self.active_result_set.fetchmany(size) else: - raise Error("There is no active result set") + raise Error( + "There is no active result set", + connection_uuid=self.connection.get_session_id_hex(), + ) def fetchall_arrow(self) -> "pyarrow.Table": self._check_not_closed() if self.active_result_set: return self.active_result_set.fetchall_arrow() else: - raise Error("There is no active result set") + raise Error( + "There is no active result set", + connection_uuid=self.connection.get_session_id_hex(), + ) def fetchmany_arrow(self, size) -> "pyarrow.Table": self._check_not_closed() if self.active_result_set: return self.active_result_set.fetchmany_arrow(size) else: - raise Error("There is no active result set") + raise Error( + "There is no active result set", + connection_uuid=self.connection.get_session_id_hex(), + ) def cancel(self) -> None: """ diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index 3b27283a4..92577d548 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -1,8 +1,10 @@ import json import logging +import traceback -logger = logging.getLogger(__name__) +from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory +logger = logging.getLogger(__name__) ### PEP-249 Mandated ### class Error(Exception): @@ -11,10 +13,23 @@ class Error(Exception): `context`: Optional extra context about the error. MUST be JSON serializable """ - def __init__(self, message=None, context=None, *args, **kwargs): + def __init__( + self, message=None, context=None, connection_uuid=None, *args, **kwargs + ): super().__init__(message, *args, **kwargs) self.message = message self.context = context or {} + self.connection_uuid = connection_uuid + + error_name = self.__class__.__name__ + if self.connection_uuid: + try: + telemetry_client = TelemetryClientFactory.get_telemetry_client( + self.connection_uuid + ) + telemetry_client.export_failure_log(error_name, self.message) + except Exception as telemetry_error: + logger.error(f"Failed to send error to telemetry: {telemetry_error}") def __str__(self): return self.message diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index d095d685c..b2caa6c1f 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -8,6 +8,7 @@ from databricks.sql.telemetry.models.event import ( TelemetryEvent, DriverSystemConfiguration, + DriverErrorInfo, ) from databricks.sql.telemetry.models.frontend_logs import ( TelemetryFrontendLog, @@ -26,7 +27,6 @@ import uuid import locale from abc import ABC, abstractmethod -from databricks.sql import __version__ logger = logging.getLogger(__name__) @@ -34,22 +34,26 @@ class TelemetryHelper: """Helper class for getting telemetry related information.""" - _DRIVER_SYSTEM_CONFIGURATION = DriverSystemConfiguration( - driver_name="Databricks SQL Python Connector", - driver_version=__version__, - runtime_name=f"Python {sys.version.split()[0]}", - runtime_vendor=platform.python_implementation(), - runtime_version=platform.python_version(), - os_name=platform.system(), - os_version=platform.release(), - os_arch=platform.machine(), - client_app_name=None, # TODO: Add client app name - locale_name=locale.getlocale()[0] or locale.getdefaultlocale()[0], - char_set_encoding=sys.getdefaultencoding(), - ) + _DRIVER_SYSTEM_CONFIGURATION = None @classmethod def getDriverSystemConfiguration(cls) -> DriverSystemConfiguration: + if cls._DRIVER_SYSTEM_CONFIGURATION is None: + from databricks.sql import __version__ + + cls._DRIVER_SYSTEM_CONFIGURATION = DriverSystemConfiguration( + driver_name="Databricks SQL Python Connector", + driver_version=__version__, + runtime_name=f"Python {sys.version.split()[0]}", + runtime_vendor=platform.python_implementation(), + runtime_version=platform.python_version(), + os_name=platform.system(), + os_version=platform.release(), + os_arch=platform.machine(), + client_app_name=None, # TODO: Add client app name + locale_name=locale.getlocale()[0] or locale.getdefaultlocale()[0], + char_set_encoding=sys.getdefaultencoding(), + ) return cls._DRIVER_SYSTEM_CONFIGURATION @staticmethod @@ -99,7 +103,11 @@ class BaseTelemetryClient(ABC): """ @abstractmethod - def export_initial_telemetry_log(self, **kwargs): + def export_initial_telemetry_log(self, driver_connection_params, user_agent): + pass + + @abstractmethod + def export_failure_log(self, error_name, error_message): pass @abstractmethod @@ -123,6 +131,9 @@ def __new__(cls): def export_initial_telemetry_log(self, driver_connection_params, user_agent): pass + def export_failure_log(self, error_name, error_message): + pass + def close(self): pass @@ -255,10 +266,33 @@ def export_initial_telemetry_log(self, driver_connection_params, user_agent): self.export_event(telemetry_frontend_log) + def export_failure_log(self, error_name, error_message): + logger.debug("Exporting failure log for connection %s", self._connection_uuid) + error_info = DriverErrorInfo(error_name=error_name, stack_trace=error_message) + telemetry_frontend_log = TelemetryFrontendLog( + frontend_log_event_id=str(uuid.uuid4()), + context=FrontendLogContext( + client_context=TelemetryClientContext( + timestamp_millis=int(time.time() * 1000), + user_agent=self._user_agent, + ) + ), + entry=FrontendLogEntry( + sql_driver_log=TelemetryEvent( + session_id=self._connection_uuid, + system_configuration=TelemetryHelper.getDriverSystemConfiguration(), + driver_connection_params=self._driver_connection_params, + error_info=error_info, + ) + ), + ) + self.export_event(telemetry_frontend_log) + def close(self): """Flush remaining events before closing""" logger.debug("Closing TelemetryClient for connection %s", self._connection_uuid) self.flush() + TelemetryClientFactory.close(self._connection_uuid) diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index e3dc38ad5..233b4f55c 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -223,6 +223,7 @@ def __init__( raise self._request_lock = threading.RLock() + self._connection_uuid = None # TODO: Move this bounding logic into DatabricksRetryPolicy for v3 (PECO-918) def _initialize_retry_args(self, kwargs): @@ -255,12 +256,14 @@ def _initialize_retry_args(self, kwargs): ) @staticmethod - def _check_response_for_error(response): + def _check_response_for_error(response, connection_uuid=None): if response.status and response.status.statusCode in [ ttypes.TStatusCode.ERROR_STATUS, ttypes.TStatusCode.INVALID_HANDLE_STATUS, ]: - raise DatabaseError(response.status.errorMessage) + raise DatabaseError( + response.status.errorMessage, connection_uuid=connection_uuid + ) @staticmethod def _extract_error_message_from_headers(headers): @@ -311,7 +314,10 @@ def _handle_request_error(self, error_info, attempt, elapsed): no_retry_reason, attempt, elapsed ) network_request_error = RequestError( - user_friendly_error_message, full_error_info_context, error_info.error + user_friendly_error_message, + full_error_info_context, + self._connection_uuid, + error_info.error, ) logger.info(network_request_error.message_with_context()) @@ -483,7 +489,7 @@ def attempt_request(attempt): if not isinstance(response_or_error_info, RequestErrorInfo): # log nothing here, presume that main request logging covers response = response_or_error_info - ThriftBackend._check_response_for_error(response) + ThriftBackend._check_response_for_error(response, self._connection_uuid) return response error_info = response_or_error_info @@ -497,7 +503,8 @@ def _check_protocol_version(self, t_open_session_resp): raise OperationalError( "Error: expected server to use a protocol version >= " "SPARK_CLI_SERVICE_PROTOCOL_V2, " - "instead got: {}".format(protocol_version) + "instead got: {}".format(protocol_version), + connection_uuid=self._connection_uuid, ) def _check_initial_namespace(self, catalog, schema, response): @@ -510,14 +517,16 @@ def _check_initial_namespace(self, catalog, schema, response): ): raise InvalidServerResponseError( "Setting initial namespace not supported by the DBR version, " - "Please use a Databricks SQL endpoint or a cluster with DBR >= 9.0." + "Please use a Databricks SQL endpoint or a cluster with DBR >= 9.0.", + connection_uuid=self._connection_uuid, ) if catalog: if not response.canUseMultipleCatalogs: raise InvalidServerResponseError( "Unexpected response from server: Trying to set initial catalog to {}, " - + "but server does not support multiple catalogs.".format(catalog) # type: ignore + + "but server does not support multiple catalogs.".format(catalog), # type: ignore + connection_uuid=self._connection_uuid, ) def _check_session_configuration(self, session_configuration): @@ -531,7 +540,8 @@ def _check_session_configuration(self, session_configuration): "while using the Databricks SQL connector, it must be false not {}".format( TIMESTAMP_AS_STRING_CONFIG, session_configuration[TIMESTAMP_AS_STRING_CONFIG], - ) + ), + connection_uuid=self._connection_uuid, ) def open_session(self, session_configuration, catalog, schema): @@ -562,6 +572,11 @@ def open_session(self, session_configuration, catalog, schema): response = self.make_request(self._client.OpenSession, open_session_req) self._check_initial_namespace(catalog, schema, response) self._check_protocol_version(response) + self._connection_uuid = ( + self.handle_to_hex_id(response.sessionHandle) + if response.sessionHandle + else None + ) return response except: self._transport.close() @@ -586,6 +601,7 @@ def _check_command_not_in_error_or_closed_state( and self.guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": get_operations_resp.diagnosticInfo, }, + connection_uuid=self._connection_uuid, ) else: raise ServerOperationError( @@ -595,6 +611,7 @@ def _check_command_not_in_error_or_closed_state( and self.guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": None, }, + connection_uuid=self._connection_uuid, ) elif get_operations_resp.operationState == ttypes.TOperationState.CLOSED_STATE: raise DatabaseError( @@ -605,6 +622,7 @@ def _check_command_not_in_error_or_closed_state( "operation-id": op_handle and self.guid_to_hex_id(op_handle.operationId.guid) }, + connection_uuid=self._connection_uuid, ) def _poll_for_status(self, op_handle): @@ -625,7 +643,10 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti t_row_set.arrowBatches, lz4_compressed, schema_bytes ) else: - raise OperationalError("Unsupported TRowSet instance {}".format(t_row_set)) + raise OperationalError( + "Unsupported TRowSet instance {}".format(t_row_set), + connection_uuid=self._connection_uuid, + ) return convert_decimals_in_arrow_table(arrow_table, description), num_rows def _get_metadata_resp(self, op_handle): @@ -633,7 +654,7 @@ def _get_metadata_resp(self, op_handle): return self.make_request(self._client.GetResultSetMetadata, req) @staticmethod - def _hive_schema_to_arrow_schema(t_table_schema): + def _hive_schema_to_arrow_schema(t_table_schema, connection_uuid=None): def map_type(t_type_entry): if t_type_entry.primitiveEntry: return { @@ -664,7 +685,8 @@ def map_type(t_type_entry): # Current thriftserver implementation should always return a primitiveEntry, # even for complex types raise OperationalError( - "Thrift protocol error: t_type_entry not a primitiveEntry" + "Thrift protocol error: t_type_entry not a primitiveEntry", + connection_uuid=connection_uuid, ) def convert_col(t_column_desc): @@ -675,7 +697,7 @@ def convert_col(t_column_desc): return pyarrow.schema([convert_col(col) for col in t_table_schema.columns]) @staticmethod - def _col_to_description(col): + def _col_to_description(col, connection_uuid=None): type_entry = col.typeDesc.types[0] if type_entry.primitiveEntry: @@ -684,7 +706,8 @@ def _col_to_description(col): cleaned_type = (name[:-5] if name.endswith("_TYPE") else name).lower() else: raise OperationalError( - "Thrift protocol error: t_type_entry not a primitiveEntry" + "Thrift protocol error: t_type_entry not a primitiveEntry", + connection_uuid=connection_uuid, ) if type_entry.primitiveEntry.type == ttypes.TTypeId.DECIMAL_TYPE: @@ -697,7 +720,8 @@ def _col_to_description(col): else: raise OperationalError( "Decimal type did not provide typeQualifier precision, scale in " - "primitiveEntry {}".format(type_entry.primitiveEntry) + "primitiveEntry {}".format(type_entry.primitiveEntry), + connection_uuid=connection_uuid, ) else: precision, scale = None, None @@ -705,9 +729,10 @@ def _col_to_description(col): return col.columnName, cleaned_type, None, None, precision, scale, None @staticmethod - def _hive_schema_to_description(t_table_schema): + def _hive_schema_to_description(t_table_schema, connection_uuid=None): return [ - ThriftBackend._col_to_description(col) for col in t_table_schema.columns + ThriftBackend._col_to_description(col, connection_uuid) + for col in t_table_schema.columns ] def _results_message_to_execute_response(self, resp, operation_state): @@ -727,7 +752,8 @@ def _results_message_to_execute_response(self, resp, operation_state): ttypes.TSparkRowSetType._VALUES_TO_NAMES[ t_result_set_metadata_resp.resultFormat ] - ) + ), + connection_uuid=self._connection_uuid, ) direct_results = resp.directResults has_been_closed_server_side = direct_results and direct_results.closeOperation @@ -737,13 +763,15 @@ def _results_message_to_execute_response(self, resp, operation_state): or direct_results.resultSet.hasMoreRows ) description = self._hive_schema_to_description( - t_result_set_metadata_resp.schema + t_result_set_metadata_resp.schema, self._connection_uuid ) if pyarrow: schema_bytes = ( t_result_set_metadata_resp.arrowSchema - or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema) + or self._hive_schema_to_arrow_schema( + t_result_set_metadata_resp.schema, self._connection_uuid + ) .serialize() .to_pybytes() ) @@ -804,13 +832,15 @@ def get_execution_result(self, op_handle, cursor): is_staging_operation = t_result_set_metadata_resp.isStagingOperation has_more_rows = resp.hasMoreRows description = self._hive_schema_to_description( - t_result_set_metadata_resp.schema + t_result_set_metadata_resp.schema, self._connection_uuid ) if pyarrow: schema_bytes = ( t_result_set_metadata_resp.arrowSchema - or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema) + or self._hive_schema_to_arrow_schema( + t_result_set_metadata_resp.schema, self._connection_uuid + ) .serialize() .to_pybytes() ) @@ -864,23 +894,23 @@ def get_query_state(self, op_handle) -> "TOperationState": return operation_state @staticmethod - def _check_direct_results_for_error(t_spark_direct_results): + def _check_direct_results_for_error(t_spark_direct_results, connection_uuid=None): if t_spark_direct_results: if t_spark_direct_results.operationStatus: ThriftBackend._check_response_for_error( - t_spark_direct_results.operationStatus + t_spark_direct_results.operationStatus, connection_uuid ) if t_spark_direct_results.resultSetMetadata: ThriftBackend._check_response_for_error( - t_spark_direct_results.resultSetMetadata + t_spark_direct_results.resultSetMetadata, connection_uuid ) if t_spark_direct_results.resultSet: ThriftBackend._check_response_for_error( - t_spark_direct_results.resultSet + t_spark_direct_results.resultSet, connection_uuid ) if t_spark_direct_results.closeOperation: ThriftBackend._check_response_for_error( - t_spark_direct_results.closeOperation + t_spark_direct_results.closeOperation, connection_uuid ) def execute_command( @@ -1029,7 +1059,7 @@ def get_columns( def _handle_execute_response(self, resp, cursor): cursor.active_op_handle = resp.operationHandle - self._check_direct_results_for_error(resp.directResults) + self._check_direct_results_for_error(resp.directResults, self._connection_uuid) final_operation_state = self._wait_until_command_done( resp.operationHandle, @@ -1040,7 +1070,7 @@ def _handle_execute_response(self, resp, cursor): def _handle_execute_response_async(self, resp, cursor): cursor.active_op_handle = resp.operationHandle - self._check_direct_results_for_error(resp.directResults) + self._check_direct_results_for_error(resp.directResults, self._connection_uuid) def fetch_results( self, @@ -1074,7 +1104,8 @@ def fetch_results( raise DataError( "fetch_results failed due to inconsistency in the state between the client and the server. Expected results to start from {} but they instead start at {}, some result batches must have been skipped".format( expected_row_start_offset, resp.results.startRowOffset - ) + ), + connection_uuid=self._connection_uuid, ) queue = ResultSetQueueFactory.build_queue( From 5305308994e3ef3f1a2f76c0a8f147638a83a91c Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Wed, 11 Jun 2025 09:36:41 +0530 Subject: [PATCH 02/48] changed logger.error to logger.debug in exc.py Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/exc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index 92577d548..61d3a6234 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -29,7 +29,7 @@ def __init__( ) telemetry_client.export_failure_log(error_name, self.message) except Exception as telemetry_error: - logger.error(f"Failed to send error to telemetry: {telemetry_error}") + logger.debug(f"Failed to send error to telemetry: {telemetry_error}") def __str__(self): return self.message From ba83c33561f6f7e86b55bec3443be26fc8fc1c63 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Wed, 11 Jun 2025 11:27:53 +0530 Subject: [PATCH 03/48] Fix telemetry loss during Python shutdown Signed-off-by: Sai Shree Pradhan --- .../sql/telemetry/telemetry_client.py | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index b2caa6c1f..eb0dee82c 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -308,6 +308,8 @@ class TelemetryClientFactory: _executor: Optional[ThreadPoolExecutor] = None _initialized: bool = False _lock = threading.Lock() # Thread safety for factory operations + _original_excepthook = None + _excepthook_installed = False @classmethod def _initialize(cls): @@ -318,11 +320,58 @@ def _initialize(cls): cls._executor = ThreadPoolExecutor( max_workers=10 ) # Thread pool for async operations TODO: Decide on max workers + cls._install_exception_hook() cls._initialized = True logger.debug( "TelemetryClientFactory initialized with thread pool (max_workers=10)" ) + @classmethod + def _install_exception_hook(cls): + """Install global exception handler for unhandled exceptions""" + if not cls._excepthook_installed: + import sys + + cls._original_excepthook = sys.excepthook + sys.excepthook = cls._handle_unhandled_exception + cls._excepthook_installed = True + logger.debug("Global exception handler installed for telemetry") + + @classmethod + def _handle_unhandled_exception(cls, exc_type, exc_value, exc_traceback): + """Handle unhandled exceptions by sending telemetry and flushing thread pool""" + logger.debug("Handling unhandled exception: %s", exc_type.__name__) + + try: + # Flush existing thread pool work and wait for completion + logger.debug( + "Flushing pending telemetry and waiting for thread pool completion..." + ) + for uuid, client in cls._clients.items(): + if hasattr(client, "flush"): + try: + client.flush() # Submit any pending events + except Exception as e: + logger.debug( + "Failed to flush telemetry for connection %s: %s", uuid, e + ) + + if cls._executor: + try: + cls._executor.shutdown( + wait=True + ) # This waits for all submitted work to complete + logger.debug("Thread pool shutdown completed successfully") + except Exception as e: + logger.debug("Thread pool shutdown failed: %s", e) + + except Exception as e: + logger.debug("Exception in excepthook telemetry handler: %s", e) + + # Call the original exception handler to maintain normal behavior + if cls._original_excepthook: + cls._original_excepthook(exc_type, exc_value, exc_traceback) + @staticmethod def initialize_telemetry_client( telemetry_enabled, From 131db92293771bda983e3305371c8de460281704 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Thu, 12 Jun 2025 10:21:47 +0530 Subject: [PATCH 04/48] unit tests for export_failure_log Signed-off-by: Sai Shree Pradhan --- tests/unit/test_telemetry.py | 47 ++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 478205b18..b210d61b8 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -83,6 +83,12 @@ def test_export_initial_telemetry_log(self, noop_telemetry_client): driver_connection_params=MagicMock(), user_agent="test" ) + def test_export_failure_log(self, noop_telemetry_client): + """Test that export_failure_log does nothing.""" + noop_telemetry_client.export_failure_log( + error_name="TestError", error_message="Test error message" + ) + def test_close(self, noop_telemetry_client): """Test that close does nothing.""" noop_telemetry_client.close() @@ -127,6 +133,47 @@ def test_export_initial_telemetry_log( mock_frontend_log.assert_called_once() client.export_event.assert_called_once_with(mock_frontend_log.return_value) + @patch("databricks.sql.telemetry.telemetry_client.TelemetryFrontendLog") + @patch("databricks.sql.telemetry.telemetry_client.TelemetryHelper.getDriverSystemConfiguration") + @patch("databricks.sql.telemetry.telemetry_client.DriverErrorInfo") + @patch("databricks.sql.telemetry.telemetry_client.uuid.uuid4") + @patch("databricks.sql.telemetry.telemetry_client.time.time") + def test_export_failure_log( + self, + mock_time, + mock_uuid4, + mock_driver_error_info, + mock_get_driver_config, + mock_frontend_log, + telemetry_client_setup + ): + """Test exporting failure telemetry log.""" + mock_time.return_value = 2000 + mock_uuid4.return_value = "test-error-uuid" + mock_get_driver_config.return_value = "test-driver-config" + mock_driver_error_info.return_value = MagicMock() + mock_frontend_log.return_value = MagicMock() + + client = telemetry_client_setup["client"] + client.export_event = MagicMock() + + client._driver_connection_params = "test-connection-params" + client._user_agent = "test-user-agent" + + error_name = "TestError" + error_message = "This is a test error message" + + client.export_failure_log(error_name, error_message) + + mock_driver_error_info.assert_called_once_with( + error_name=error_name, + stack_trace=error_message + ) + + mock_frontend_log.assert_called_once() + + client.export_event.assert_called_once_with(mock_frontend_log.return_value) + def test_export_event(self, telemetry_client_setup): """Test exporting an event.""" client = telemetry_client_setup["client"] From 3abc40dcaa39e6ebfb527a0019f355d95a53164f Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Thu, 12 Jun 2025 10:56:34 +0530 Subject: [PATCH 05/48] try-catch blocks to make telemetry failures non-blocking for connector operations Signed-off-by: Sai Shree Pradhan --- .../sql/telemetry/telemetry_client.py | 98 +++++++++++-------- 1 file changed, 57 insertions(+), 41 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index eb0dee82c..eb4edec18 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -244,56 +244,72 @@ def export_initial_telemetry_log(self, driver_connection_params, user_agent): "Exporting initial telemetry log for connection %s", self._connection_uuid ) - self._driver_connection_params = driver_connection_params - self._user_agent = user_agent - - telemetry_frontend_log = TelemetryFrontendLog( - frontend_log_event_id=str(uuid.uuid4()), - context=FrontendLogContext( - client_context=TelemetryClientContext( - timestamp_millis=int(time.time() * 1000), - user_agent=self._user_agent, - ) - ), - entry=FrontendLogEntry( - sql_driver_log=TelemetryEvent( - session_id=self._connection_uuid, - system_configuration=TelemetryHelper.getDriverSystemConfiguration(), - driver_connection_params=self._driver_connection_params, - ) - ), - ) + try: + self._driver_connection_params = driver_connection_params + self._user_agent = user_agent + + telemetry_frontend_log = TelemetryFrontendLog( + frontend_log_event_id=str(uuid.uuid4()), + context=FrontendLogContext( + client_context=TelemetryClientContext( + timestamp_millis=int(time.time() * 1000), + user_agent=self._user_agent, + ) + ), + entry=FrontendLogEntry( + sql_driver_log=TelemetryEvent( + session_id=self._connection_uuid, + system_configuration=TelemetryHelper.getDriverSystemConfiguration(), + driver_connection_params=self._driver_connection_params, + ) + ), + ) - self.export_event(telemetry_frontend_log) + self.export_event(telemetry_frontend_log) + except Exception as e: + logger.debug("Failed to export initial telemetry log: %s", e) def export_failure_log(self, error_name, error_message): logger.debug("Exporting failure log for connection %s", self._connection_uuid) - error_info = DriverErrorInfo(error_name=error_name, stack_trace=error_message) - telemetry_frontend_log = TelemetryFrontendLog( - frontend_log_event_id=str(uuid.uuid4()), - context=FrontendLogContext( - client_context=TelemetryClientContext( - timestamp_millis=int(time.time() * 1000), - user_agent=self._user_agent, - ) - ), - entry=FrontendLogEntry( - sql_driver_log=TelemetryEvent( - session_id=self._connection_uuid, - system_configuration=TelemetryHelper.getDriverSystemConfiguration(), - driver_connection_params=self._driver_connection_params, - error_info=error_info, - ) - ), - ) - self.export_event(telemetry_frontend_log) + try: + error_info = DriverErrorInfo( + error_name=error_name, stack_trace=error_message + ) + telemetry_frontend_log = TelemetryFrontendLog( + frontend_log_event_id=str(uuid.uuid4()), + context=FrontendLogContext( + client_context=TelemetryClientContext( + timestamp_millis=int(time.time() * 1000), + user_agent=self._user_agent, + ) + ), + entry=FrontendLogEntry( + sql_driver_log=TelemetryEvent( + session_id=self._connection_uuid, + system_configuration=TelemetryHelper.getDriverSystemConfiguration(), + driver_connection_params=self._driver_connection_params, + error_info=error_info, + ) + ), + ) + self.export_event(telemetry_frontend_log) + except Exception as e: + logger.debug("Failed to export failure log: %s", e) def close(self): """Flush remaining events before closing""" logger.debug("Closing TelemetryClient for connection %s", self._connection_uuid) - self.flush() + try: + self.flush() + except Exception as e: + logger.debug("Failed to flush telemetry during close: %s", e) - TelemetryClientFactory.close(self._connection_uuid) + try: + TelemetryClientFactory.close(self._connection_uuid) + except Exception as e: + logger.debug( + "Failed to remove telemetry client from telemetry clientfactory: %s", e + ) class TelemetryClientFactory: From ffa47872a063ba8bdc93da126b67a3baecaa07a7 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Thu, 12 Jun 2025 11:55:38 +0530 Subject: [PATCH 06/48] removed redundant try/catch blocks, added try/catch block to initialize and get telemetry client Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/exc.py | 11 +- .../sql/telemetry/telemetry_client.py | 107 +++++++++--------- 2 files changed, 59 insertions(+), 59 deletions(-) diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index 61d3a6234..d7bcd5c61 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -23,13 +23,10 @@ def __init__( error_name = self.__class__.__name__ if self.connection_uuid: - try: - telemetry_client = TelemetryClientFactory.get_telemetry_client( - self.connection_uuid - ) - telemetry_client.export_failure_log(error_name, self.message) - except Exception as telemetry_error: - logger.debug(f"Failed to send error to telemetry: {telemetry_error}") + telemetry_client = TelemetryClientFactory.get_telemetry_client( + self.connection_uuid + ) + telemetry_client.export_failure_log(error_name, self.message) def __str__(self): return self.message diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index eb4edec18..216262b31 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -301,15 +301,9 @@ def close(self): logger.debug("Closing TelemetryClient for connection %s", self._connection_uuid) try: self.flush() - except Exception as e: - logger.debug("Failed to flush telemetry during close: %s", e) - - try: TelemetryClientFactory.close(self._connection_uuid) except Exception as e: - logger.debug( - "Failed to remove telemetry client from telemetry clientfactory: %s", e - ) + logger.debug("Failed to close telemetry client: %s", e) class TelemetryClientFactory: @@ -358,31 +352,27 @@ def _handle_unhandled_exception(cls, exc_type, exc_value, exc_traceback): """Handle unhandled exceptions by sending telemetry and flushing thread pool""" logger.debug("Handling unhandled exception: %s", exc_type.__name__) - try: - # Flush existing thread pool work and wait for completion - logger.debug( - "Flushing pending telemetry and waiting for thread pool completion..." - ) - for uuid, client in cls._clients.items(): - if hasattr(client, "flush"): - try: - client.flush() # Submit any pending events - except Exception as e: - logger.debug( - "Failed to flush telemetry for connection %s: %s", uuid, e - ) - - if cls._executor: + # Flush existing thread pool work and wait for completion + logger.debug( + "Flushing pending telemetry and waiting for thread pool completion..." + ) + for uuid, client in cls._clients.items(): + if hasattr(client, "flush"): try: - cls._executor.shutdown( - wait=True - ) # This waits for all submitted work to complete - logger.debug("Thread pool shutdown completed successfully") + client.flush() # Submit any pending events except Exception as e: - logger.debug("Thread pool shutdown failed: %s", e) + logger.debug( + "Failed to flush telemetry for connection %s: %s", uuid, e + ) - except Exception as e: - logger.debug("Exception in excepthook telemetry handler: %s", e) + if cls._executor: + try: + cls._executor.shutdown( + wait=True + ) # This waits for all submitted work to complete + logger.debug("Thread pool shutdown completed successfully") + except Exception as e: + logger.debug("Thread pool shutdown failed: %s", e) # Call the original exception handler to maintain normal behavior if cls._original_excepthook: @@ -396,35 +386,48 @@ def initialize_telemetry_client( host_url, ): """Initialize a telemetry client for a specific connection if telemetry is enabled""" - TelemetryClientFactory._initialize() + try: + TelemetryClientFactory._initialize() - with TelemetryClientFactory._lock: - if connection_uuid not in TelemetryClientFactory._clients: - logger.debug( - "Creating new TelemetryClient for connection %s", connection_uuid - ) - if telemetry_enabled: - TelemetryClientFactory._clients[connection_uuid] = TelemetryClient( - telemetry_enabled=telemetry_enabled, - connection_uuid=connection_uuid, - auth_provider=auth_provider, - host_url=host_url, - executor=TelemetryClientFactory._executor, + with TelemetryClientFactory._lock: + if connection_uuid not in TelemetryClientFactory._clients: + logger.debug( + "Creating new TelemetryClient for connection %s", + connection_uuid, ) - else: - TelemetryClientFactory._clients[ - connection_uuid - ] = NoopTelemetryClient() + if telemetry_enabled: + TelemetryClientFactory._clients[ + connection_uuid + ] = TelemetryClient( + telemetry_enabled=telemetry_enabled, + connection_uuid=connection_uuid, + auth_provider=auth_provider, + host_url=host_url, + executor=TelemetryClientFactory._executor, + ) + else: + TelemetryClientFactory._clients[ + connection_uuid + ] = NoopTelemetryClient() + except Exception as e: + logger.debug("Failed to initialize telemetry client: %s", e) + # Fallback to NoopTelemetryClient to ensure connection doesn't fail + TelemetryClientFactory._clients[connection_uuid] = NoopTelemetryClient() @staticmethod def get_telemetry_client(connection_uuid): """Get the telemetry client for a specific connection""" - if connection_uuid in TelemetryClientFactory._clients: - return TelemetryClientFactory._clients[connection_uuid] - else: - logger.error( - "Telemetry client not initialized for connection %s", connection_uuid - ) + try: + if connection_uuid in TelemetryClientFactory._clients: + return TelemetryClientFactory._clients[connection_uuid] + else: + logger.error( + "Telemetry client not initialized for connection %s", + connection_uuid, + ) + return NoopTelemetryClient() + except Exception as e: + logger.debug("Failed to get telemetry client: %s", e) return NoopTelemetryClient() @staticmethod From cc077f3b6032bee52e99dfe948fd26b3dc9911be Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Thu, 12 Jun 2025 13:54:05 +0530 Subject: [PATCH 07/48] skip null fields in telemetry request Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/telemetry/models/event.py | 16 +++++++------- .../sql/telemetry/models/frontend_logs.py | 10 ++++----- src/databricks/sql/telemetry/utils.py | 21 +++++++++++++++++++ 3 files changed, 34 insertions(+), 13 deletions(-) diff --git a/src/databricks/sql/telemetry/models/event.py b/src/databricks/sql/telemetry/models/event.py index 4429a7626..c00738810 100644 --- a/src/databricks/sql/telemetry/models/event.py +++ b/src/databricks/sql/telemetry/models/event.py @@ -9,7 +9,7 @@ ExecutionResultFormat, ) from typing import Optional -from databricks.sql.telemetry.utils import EnumEncoder +from databricks.sql.telemetry.utils import to_json_compact @dataclass @@ -26,7 +26,7 @@ class HostDetails: port: int def to_json(self): - return json.dumps(asdict(self)) + return to_json_compact(self) @dataclass @@ -52,7 +52,7 @@ class DriverConnectionParameters: socket_timeout: Optional[int] = None def to_json(self): - return json.dumps(asdict(self), cls=EnumEncoder) + return to_json_compact(self) @dataclass @@ -88,7 +88,7 @@ class DriverSystemConfiguration: locale_name: Optional[str] = None def to_json(self): - return json.dumps(asdict(self), cls=EnumEncoder) + return to_json_compact(self) @dataclass @@ -106,7 +106,7 @@ class DriverVolumeOperation: volume_path: str def to_json(self): - return json.dumps(asdict(self), cls=EnumEncoder) + return to_json_compact(self) @dataclass @@ -124,7 +124,7 @@ class DriverErrorInfo: stack_trace: str def to_json(self): - return json.dumps(asdict(self), cls=EnumEncoder) + return to_json_compact(self) @dataclass @@ -146,7 +146,7 @@ class SqlExecutionEvent: retry_count: int def to_json(self): - return json.dumps(asdict(self), cls=EnumEncoder) + return to_json_compact(self) @dataclass @@ -179,4 +179,4 @@ class TelemetryEvent: operation_latency_ms: Optional[int] = None def to_json(self): - return json.dumps(asdict(self), cls=EnumEncoder) + return to_json_compact(self) diff --git a/src/databricks/sql/telemetry/models/frontend_logs.py b/src/databricks/sql/telemetry/models/frontend_logs.py index 36086a7cc..f5d58a4be 100644 --- a/src/databricks/sql/telemetry/models/frontend_logs.py +++ b/src/databricks/sql/telemetry/models/frontend_logs.py @@ -1,7 +1,7 @@ import json from dataclasses import dataclass, asdict from databricks.sql.telemetry.models.event import TelemetryEvent -from databricks.sql.telemetry.utils import EnumEncoder +from databricks.sql.telemetry.utils import to_json_compact from typing import Optional @@ -20,7 +20,7 @@ class TelemetryClientContext: user_agent: str def to_json(self): - return json.dumps(asdict(self), cls=EnumEncoder) + return to_json_compact(self) @dataclass @@ -36,7 +36,7 @@ class FrontendLogContext: client_context: TelemetryClientContext def to_json(self): - return json.dumps(asdict(self), cls=EnumEncoder) + return to_json_compact(self) @dataclass @@ -52,7 +52,7 @@ class FrontendLogEntry: sql_driver_log: TelemetryEvent def to_json(self): - return json.dumps(asdict(self), cls=EnumEncoder) + return to_json_compact(self) @dataclass @@ -75,4 +75,4 @@ class TelemetryFrontendLog: workspace_id: Optional[int] = None def to_json(self): - return json.dumps(asdict(self), cls=EnumEncoder) + return to_json_compact(self) diff --git a/src/databricks/sql/telemetry/utils.py b/src/databricks/sql/telemetry/utils.py index 6a4d64eba..8be2c9873 100644 --- a/src/databricks/sql/telemetry/utils.py +++ b/src/databricks/sql/telemetry/utils.py @@ -1,5 +1,6 @@ import json from enum import Enum +from dataclasses import asdict class EnumEncoder(json.JSONEncoder): @@ -13,3 +14,23 @@ def default(self, obj): if isinstance(obj, Enum): return obj.value return super().default(obj) + + +def filter_none_values(data): + """ + Recursively remove None values from dictionaries. + This reduces telemetry payload size by excluding null fields. + """ + if isinstance(data, dict): + return {k: filter_none_values(v) for k, v in data.items() if v is not None} + else: + return data + + +def to_json_compact(dataclass_obj): + """ + Convert a dataclass to JSON string, excluding None values. + """ + data_dict = asdict(dataclass_obj) + filtered_dict = filter_none_values(data_dict) + return json.dumps(filtered_dict, cls=EnumEncoder) From 2c6fd44cb18b9c8b8f910d34bc7e02a24077e58b Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Thu, 12 Jun 2025 23:12:39 +0530 Subject: [PATCH 08/48] removed dup import, renamed func, changed a filter_null_values to lamda Signed-off-by: Sai Shree Pradhan --- .../sql/telemetry/telemetry_client.py | 8 +++---- src/databricks/sql/telemetry/utils.py | 21 +++++++------------ tests/unit/test_telemetry.py | 4 ++-- 3 files changed, 12 insertions(+), 21 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 216262b31..1099c81cd 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -37,7 +37,7 @@ class TelemetryHelper: _DRIVER_SYSTEM_CONFIGURATION = None @classmethod - def getDriverSystemConfiguration(cls) -> DriverSystemConfiguration: + def get_driver_system_configuration(cls) -> DriverSystemConfiguration: if cls._DRIVER_SYSTEM_CONFIGURATION is None: from databricks.sql import __version__ @@ -259,7 +259,7 @@ def export_initial_telemetry_log(self, driver_connection_params, user_agent): entry=FrontendLogEntry( sql_driver_log=TelemetryEvent( session_id=self._connection_uuid, - system_configuration=TelemetryHelper.getDriverSystemConfiguration(), + system_configuration=TelemetryHelper.get_driver_system_configuration(), driver_connection_params=self._driver_connection_params, ) ), @@ -286,7 +286,7 @@ def export_failure_log(self, error_name, error_message): entry=FrontendLogEntry( sql_driver_log=TelemetryEvent( session_id=self._connection_uuid, - system_configuration=TelemetryHelper.getDriverSystemConfiguration(), + system_configuration=TelemetryHelper.get_driver_system_configuration(), driver_connection_params=self._driver_connection_params, error_info=error_info, ) @@ -340,8 +340,6 @@ def _initialize(cls): def _install_exception_hook(cls): """Install global exception handler for unhandled exceptions""" if not cls._excepthook_installed: - import sys - cls._original_excepthook = sys.excepthook sys.excepthook = cls._handle_unhandled_exception cls._excepthook_installed = True diff --git a/src/databricks/sql/telemetry/utils.py b/src/databricks/sql/telemetry/utils.py index 8be2c9873..2ae87b96e 100644 --- a/src/databricks/sql/telemetry/utils.py +++ b/src/databricks/sql/telemetry/utils.py @@ -16,21 +16,14 @@ def default(self, obj): return super().default(obj) -def filter_none_values(data): - """ - Recursively remove None values from dictionaries. - This reduces telemetry payload size by excluding null fields. - """ - if isinstance(data, dict): - return {k: filter_none_values(v) for k, v in data.items() if v is not None} - else: - return data - - def to_json_compact(dataclass_obj): """ Convert a dataclass to JSON string, excluding None values. """ - data_dict = asdict(dataclass_obj) - filtered_dict = filter_none_values(data_dict) - return json.dumps(filtered_dict, cls=EnumEncoder) + return json.dumps( + asdict( + dataclass_obj, + dict_factory=lambda data: {k: v for k, v in data if v is not None}, + ), + cls=EnumEncoder, + ) diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index b210d61b8..a3e0239db 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -98,7 +98,7 @@ class TestTelemetryClient: """Tests for the TelemetryClient class.""" @patch("databricks.sql.telemetry.telemetry_client.TelemetryFrontendLog") - @patch("databricks.sql.telemetry.telemetry_client.TelemetryHelper.getDriverSystemConfiguration") + @patch("databricks.sql.telemetry.telemetry_client.TelemetryHelper.get_driver_system_configuration") @patch("databricks.sql.telemetry.telemetry_client.uuid.uuid4") @patch("databricks.sql.telemetry.telemetry_client.time.time") def test_export_initial_telemetry_log( @@ -134,7 +134,7 @@ def test_export_initial_telemetry_log( client.export_event.assert_called_once_with(mock_frontend_log.return_value) @patch("databricks.sql.telemetry.telemetry_client.TelemetryFrontendLog") - @patch("databricks.sql.telemetry.telemetry_client.TelemetryHelper.getDriverSystemConfiguration") + @patch("databricks.sql.telemetry.telemetry_client.TelemetryHelper.get_driver_system_configuration") @patch("databricks.sql.telemetry.telemetry_client.DriverErrorInfo") @patch("databricks.sql.telemetry.telemetry_client.uuid.uuid4") @patch("databricks.sql.telemetry.telemetry_client.time.time") From 89540a169e101c368334f8065e3bf1b1573cf2b7 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Fri, 13 Jun 2025 02:26:24 +0530 Subject: [PATCH 09/48] removed unnecassary class variable and a redundant try/except block Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/exc.py | 5 ++--- src/databricks/sql/telemetry/telemetry_client.py | 7 +------ 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index d7bcd5c61..cc7a47cb4 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -19,12 +19,11 @@ def __init__( super().__init__(message, *args, **kwargs) self.message = message self.context = context or {} - self.connection_uuid = connection_uuid error_name = self.__class__.__name__ - if self.connection_uuid: + if connection_uuid: telemetry_client = TelemetryClientFactory.get_telemetry_client( - self.connection_uuid + connection_uuid ) telemetry_client.export_failure_log(error_name, self.message) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 1099c81cd..3402cfd70 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -356,12 +356,7 @@ def _handle_unhandled_exception(cls, exc_type, exc_value, exc_traceback): ) for uuid, client in cls._clients.items(): if hasattr(client, "flush"): - try: - client.flush() # Submit any pending events - except Exception as e: - logger.debug( - "Failed to flush telemetry for connection %s: %s", uuid, e - ) + client.flush() # Submit any pending events if cls._executor: try: From 52a1152b33a39be7a7b155ffcacb6b482829ddd2 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Fri, 13 Jun 2025 03:16:04 +0530 Subject: [PATCH 10/48] public functions defined at interface level Signed-off-by: Sai Shree Pradhan --- .../sql/telemetry/telemetry_client.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 3402cfd70..f2212d7a4 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -102,6 +102,14 @@ class BaseTelemetryClient(ABC): It is used to define the interface for telemetry clients. """ + @abstractmethod + def export_event(self, event): + pass + + @abstractmethod + def flush(self): + pass + @abstractmethod def export_initial_telemetry_log(self, driver_connection_params, user_agent): pass @@ -128,6 +136,12 @@ def __new__(cls): cls._instance = super(NoopTelemetryClient, cls).__new__(cls) return cls._instance + def export_event(self, event): + pass + + def flush(self): + pass + def export_initial_telemetry_log(self, driver_connection_params, user_agent): pass @@ -354,9 +368,8 @@ def _handle_unhandled_exception(cls, exc_type, exc_value, exc_traceback): logger.debug( "Flushing pending telemetry and waiting for thread pool completion..." ) - for uuid, client in cls._clients.items(): - if hasattr(client, "flush"): - client.flush() # Submit any pending events + for client in cls._clients.items(): + client.flush() # Submit any pending events if cls._executor: try: From 3dcdcfa70ff4562da24fa41e006392f2825602fb Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Fri, 13 Jun 2025 11:06:58 +0530 Subject: [PATCH 11/48] changed export_event and flush to private functions Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/client.py | 2 +- .../sql/telemetry/telemetry_client.py | 44 ++++--------------- tests/unit/test_telemetry.py | 18 ++++---- 3 files changed, 17 insertions(+), 47 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index f0f1f5b8a..a4f56c4f9 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -467,7 +467,7 @@ def _close(self, close_cursors=True) -> None: self.open = False - self._telemetry_client.close() + TelemetryClientFactory.close(self.get_session_id_hex()) def commit(self): """No-op because Databricks does not support transactions""" diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index f2212d7a4..f6e3daad1 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -102,14 +102,6 @@ class BaseTelemetryClient(ABC): It is used to define the interface for telemetry clients. """ - @abstractmethod - def export_event(self, event): - pass - - @abstractmethod - def flush(self): - pass - @abstractmethod def export_initial_telemetry_log(self, driver_connection_params, user_agent): pass @@ -136,12 +128,6 @@ def __new__(cls): cls._instance = super(NoopTelemetryClient, cls).__new__(cls) return cls._instance - def export_event(self, event): - pass - - def flush(self): - pass - def export_initial_telemetry_log(self, driver_connection_params, user_agent): pass @@ -182,7 +168,7 @@ def __init__( self._host_url = host_url self._executor = executor - def export_event(self, event): + def _export_event(self, event): """Add an event to the batch queue and flush if batch is full""" logger.debug("Exporting event for connection %s", self._connection_uuid) with self._lock: @@ -191,9 +177,9 @@ def export_event(self, event): logger.debug( "Batch size limit reached (%s), flushing events", self._batch_size ) - self.flush() + self._flush() - def flush(self): + def _flush(self): """Flush the current batch of events to the server""" with self._lock: events_to_flush = self._events_batch.copy() @@ -313,11 +299,7 @@ def export_failure_log(self, error_name, error_message): def close(self): """Flush remaining events before closing""" logger.debug("Closing TelemetryClient for connection %s", self._connection_uuid) - try: - self.flush() - TelemetryClientFactory.close(self._connection_uuid) - except Exception as e: - logger.debug("Failed to close telemetry client: %s", e) + self._flush() class TelemetryClientFactory: @@ -365,20 +347,8 @@ def _handle_unhandled_exception(cls, exc_type, exc_value, exc_traceback): logger.debug("Handling unhandled exception: %s", exc_type.__name__) # Flush existing thread pool work and wait for completion - logger.debug( - "Flushing pending telemetry and waiting for thread pool completion..." - ) - for client in cls._clients.items(): - client.flush() # Submit any pending events - - if cls._executor: - try: - cls._executor.shutdown( - wait=True - ) # This waits for all submitted work to complete - logger.debug("Thread pool shutdown completed successfully") - except Exception as e: - logger.debug("Thread pool shutdown failed: %s", e) + for uuid, _ in cls._clients.items(): + cls.close(uuid) # Call the original exception handler to maintain normal behavior if cls._original_excepthook: @@ -445,6 +415,7 @@ def close(connection_uuid): logger.debug( "Removing telemetry client for connection %s", connection_uuid ) + TelemetryClientFactory.get_telemetry_client(connection_uuid).close() TelemetryClientFactory._clients.pop(connection_uuid, None) # Shutdown executor if no more clients @@ -455,3 +426,4 @@ def close(connection_uuid): TelemetryClientFactory._executor.shutdown(wait=True) TelemetryClientFactory._executor = None TelemetryClientFactory._initialized = False + diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index a3e0239db..97b8f276b 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -177,18 +177,18 @@ def test_export_failure_log( def test_export_event(self, telemetry_client_setup): """Test exporting an event.""" client = telemetry_client_setup["client"] - client.flush = MagicMock() + client._flush = MagicMock() for i in range(5): - client.export_event(f"event-{i}") + client._export_event(f"event-{i}") - client.flush.assert_not_called() + client._flush.assert_not_called() assert len(client._events_batch) == 5 for i in range(5, 10): - client.export_event(f"event-{i}") + client._export_event(f"event-{i}") - client.flush.assert_called_once() + client._flush.assert_called_once() assert len(client._events_batch) == 10 @patch("requests.post") @@ -244,7 +244,7 @@ def test_flush(self, telemetry_client_setup): client._events_batch = ["event1", "event2"] client._send_telemetry = MagicMock() - client.flush() + client._flush() client._send_telemetry.assert_called_once_with(["event1", "event2"]) assert client._events_batch == [] @@ -253,13 +253,11 @@ def test_flush(self, telemetry_client_setup): def test_close(self, mock_factory_class, telemetry_client_setup): """Test closing the client.""" client = telemetry_client_setup["client"] - connection_uuid = telemetry_client_setup["connection_uuid"] - client.flush = MagicMock() + client._flush = MagicMock() client.close() - client.flush.assert_called_once() - mock_factory_class.close.assert_called_once_with(connection_uuid) + client._flush.assert_called_once() class TestTelemetryClientFactory: From b2714c9738439d7cb2c947fd2654c5ed7444fe99 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Fri, 13 Jun 2025 11:10:10 +0530 Subject: [PATCH 12/48] formatting Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/telemetry/telemetry_client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index f6e3daad1..fe1c0e191 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -426,4 +426,3 @@ def close(connection_uuid): TelemetryClientFactory._executor.shutdown(wait=True) TelemetryClientFactory._executor = None TelemetryClientFactory._initialized = False - From 377a87bb2f493b6442c8a921b7c3e51e3c72b44d Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Fri, 13 Jun 2025 11:28:34 +0530 Subject: [PATCH 13/48] changed connection_uuid to thread local in thrift backend Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/thrift_backend.py | 77 ++++++++++++++-------------- 1 file changed, 38 insertions(+), 39 deletions(-) diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 233b4f55c..79fc7f1b0 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -72,6 +72,9 @@ "_retry_delay_default": (float, 5, 1, 60), } +# Add thread local storage +_connection_uuid = threading.local() + class ThriftBackend: CLOSED_OP_STATE = ttypes.TOperationState.CLOSED_STATE @@ -223,7 +226,7 @@ def __init__( raise self._request_lock = threading.RLock() - self._connection_uuid = None + _connection_uuid.value = None # TODO: Move this bounding logic into DatabricksRetryPolicy for v3 (PECO-918) def _initialize_retry_args(self, kwargs): @@ -256,13 +259,14 @@ def _initialize_retry_args(self, kwargs): ) @staticmethod - def _check_response_for_error(response, connection_uuid=None): + def _check_response_for_error(response): if response.status and response.status.statusCode in [ ttypes.TStatusCode.ERROR_STATUS, ttypes.TStatusCode.INVALID_HANDLE_STATUS, ]: raise DatabaseError( - response.status.errorMessage, connection_uuid=connection_uuid + response.status.errorMessage, + connection_uuid=getattr(_connection_uuid, "value", None), ) @staticmethod @@ -316,7 +320,7 @@ def _handle_request_error(self, error_info, attempt, elapsed): network_request_error = RequestError( user_friendly_error_message, full_error_info_context, - self._connection_uuid, + getattr(_connection_uuid, "value", None), error_info.error, ) logger.info(network_request_error.message_with_context()) @@ -489,7 +493,7 @@ def attempt_request(attempt): if not isinstance(response_or_error_info, RequestErrorInfo): # log nothing here, presume that main request logging covers response = response_or_error_info - ThriftBackend._check_response_for_error(response, self._connection_uuid) + ThriftBackend._check_response_for_error(response) return response error_info = response_or_error_info @@ -504,7 +508,7 @@ def _check_protocol_version(self, t_open_session_resp): "Error: expected server to use a protocol version >= " "SPARK_CLI_SERVICE_PROTOCOL_V2, " "instead got: {}".format(protocol_version), - connection_uuid=self._connection_uuid, + connection_uuid=getattr(_connection_uuid, "value", None), ) def _check_initial_namespace(self, catalog, schema, response): @@ -518,7 +522,7 @@ def _check_initial_namespace(self, catalog, schema, response): raise InvalidServerResponseError( "Setting initial namespace not supported by the DBR version, " "Please use a Databricks SQL endpoint or a cluster with DBR >= 9.0.", - connection_uuid=self._connection_uuid, + connection_uuid=getattr(_connection_uuid, "value", None), ) if catalog: @@ -526,7 +530,7 @@ def _check_initial_namespace(self, catalog, schema, response): raise InvalidServerResponseError( "Unexpected response from server: Trying to set initial catalog to {}, " + "but server does not support multiple catalogs.".format(catalog), # type: ignore - connection_uuid=self._connection_uuid, + connection_uuid=getattr(_connection_uuid, "value", None), ) def _check_session_configuration(self, session_configuration): @@ -541,7 +545,7 @@ def _check_session_configuration(self, session_configuration): TIMESTAMP_AS_STRING_CONFIG, session_configuration[TIMESTAMP_AS_STRING_CONFIG], ), - connection_uuid=self._connection_uuid, + connection_uuid=getattr(_connection_uuid, "value", None), ) def open_session(self, session_configuration, catalog, schema): @@ -572,7 +576,7 @@ def open_session(self, session_configuration, catalog, schema): response = self.make_request(self._client.OpenSession, open_session_req) self._check_initial_namespace(catalog, schema, response) self._check_protocol_version(response) - self._connection_uuid = ( + _connection_uuid.value = ( self.handle_to_hex_id(response.sessionHandle) if response.sessionHandle else None @@ -601,7 +605,7 @@ def _check_command_not_in_error_or_closed_state( and self.guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": get_operations_resp.diagnosticInfo, }, - connection_uuid=self._connection_uuid, + connection_uuid=getattr(_connection_uuid, "value", None), ) else: raise ServerOperationError( @@ -611,7 +615,7 @@ def _check_command_not_in_error_or_closed_state( and self.guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": None, }, - connection_uuid=self._connection_uuid, + connection_uuid=getattr(_connection_uuid, "value", None), ) elif get_operations_resp.operationState == ttypes.TOperationState.CLOSED_STATE: raise DatabaseError( @@ -622,7 +626,7 @@ def _check_command_not_in_error_or_closed_state( "operation-id": op_handle and self.guid_to_hex_id(op_handle.operationId.guid) }, - connection_uuid=self._connection_uuid, + connection_uuid=getattr(_connection_uuid, "value", None), ) def _poll_for_status(self, op_handle): @@ -645,7 +649,7 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti else: raise OperationalError( "Unsupported TRowSet instance {}".format(t_row_set), - connection_uuid=self._connection_uuid, + connection_uuid=getattr(_connection_uuid, "value", None), ) return convert_decimals_in_arrow_table(arrow_table, description), num_rows @@ -654,7 +658,7 @@ def _get_metadata_resp(self, op_handle): return self.make_request(self._client.GetResultSetMetadata, req) @staticmethod - def _hive_schema_to_arrow_schema(t_table_schema, connection_uuid=None): + def _hive_schema_to_arrow_schema(t_table_schema): def map_type(t_type_entry): if t_type_entry.primitiveEntry: return { @@ -686,7 +690,7 @@ def map_type(t_type_entry): # even for complex types raise OperationalError( "Thrift protocol error: t_type_entry not a primitiveEntry", - connection_uuid=connection_uuid, + connection_uuid=getattr(_connection_uuid, "value", None), ) def convert_col(t_column_desc): @@ -697,7 +701,7 @@ def convert_col(t_column_desc): return pyarrow.schema([convert_col(col) for col in t_table_schema.columns]) @staticmethod - def _col_to_description(col, connection_uuid=None): + def _col_to_description(col): type_entry = col.typeDesc.types[0] if type_entry.primitiveEntry: @@ -707,7 +711,7 @@ def _col_to_description(col, connection_uuid=None): else: raise OperationalError( "Thrift protocol error: t_type_entry not a primitiveEntry", - connection_uuid=connection_uuid, + connection_uuid=getattr(_connection_uuid, "value", None), ) if type_entry.primitiveEntry.type == ttypes.TTypeId.DECIMAL_TYPE: @@ -721,7 +725,7 @@ def _col_to_description(col, connection_uuid=None): raise OperationalError( "Decimal type did not provide typeQualifier precision, scale in " "primitiveEntry {}".format(type_entry.primitiveEntry), - connection_uuid=connection_uuid, + connection_uuid=getattr(_connection_uuid, "value", None), ) else: precision, scale = None, None @@ -729,10 +733,9 @@ def _col_to_description(col, connection_uuid=None): return col.columnName, cleaned_type, None, None, precision, scale, None @staticmethod - def _hive_schema_to_description(t_table_schema, connection_uuid=None): + def _hive_schema_to_description(t_table_schema): return [ - ThriftBackend._col_to_description(col, connection_uuid) - for col in t_table_schema.columns + ThriftBackend._col_to_description(col) for col in t_table_schema.columns ] def _results_message_to_execute_response(self, resp, operation_state): @@ -753,7 +756,7 @@ def _results_message_to_execute_response(self, resp, operation_state): t_result_set_metadata_resp.resultFormat ] ), - connection_uuid=self._connection_uuid, + connection_uuid=getattr(_connection_uuid, "value", None), ) direct_results = resp.directResults has_been_closed_server_side = direct_results and direct_results.closeOperation @@ -763,15 +766,13 @@ def _results_message_to_execute_response(self, resp, operation_state): or direct_results.resultSet.hasMoreRows ) description = self._hive_schema_to_description( - t_result_set_metadata_resp.schema, self._connection_uuid + t_result_set_metadata_resp.schema ) if pyarrow: schema_bytes = ( t_result_set_metadata_resp.arrowSchema - or self._hive_schema_to_arrow_schema( - t_result_set_metadata_resp.schema, self._connection_uuid - ) + or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema) .serialize() .to_pybytes() ) @@ -832,15 +833,13 @@ def get_execution_result(self, op_handle, cursor): is_staging_operation = t_result_set_metadata_resp.isStagingOperation has_more_rows = resp.hasMoreRows description = self._hive_schema_to_description( - t_result_set_metadata_resp.schema, self._connection_uuid + t_result_set_metadata_resp.schema ) if pyarrow: schema_bytes = ( t_result_set_metadata_resp.arrowSchema - or self._hive_schema_to_arrow_schema( - t_result_set_metadata_resp.schema, self._connection_uuid - ) + or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema) .serialize() .to_pybytes() ) @@ -894,23 +893,23 @@ def get_query_state(self, op_handle) -> "TOperationState": return operation_state @staticmethod - def _check_direct_results_for_error(t_spark_direct_results, connection_uuid=None): + def _check_direct_results_for_error(t_spark_direct_results): if t_spark_direct_results: if t_spark_direct_results.operationStatus: ThriftBackend._check_response_for_error( - t_spark_direct_results.operationStatus, connection_uuid + t_spark_direct_results.operationStatus ) if t_spark_direct_results.resultSetMetadata: ThriftBackend._check_response_for_error( - t_spark_direct_results.resultSetMetadata, connection_uuid + t_spark_direct_results.resultSetMetadata ) if t_spark_direct_results.resultSet: ThriftBackend._check_response_for_error( - t_spark_direct_results.resultSet, connection_uuid + t_spark_direct_results.resultSet ) if t_spark_direct_results.closeOperation: ThriftBackend._check_response_for_error( - t_spark_direct_results.closeOperation, connection_uuid + t_spark_direct_results.closeOperation ) def execute_command( @@ -1059,7 +1058,7 @@ def get_columns( def _handle_execute_response(self, resp, cursor): cursor.active_op_handle = resp.operationHandle - self._check_direct_results_for_error(resp.directResults, self._connection_uuid) + self._check_direct_results_for_error(resp.directResults) final_operation_state = self._wait_until_command_done( resp.operationHandle, @@ -1070,7 +1069,7 @@ def _handle_execute_response(self, resp, cursor): def _handle_execute_response_async(self, resp, cursor): cursor.active_op_handle = resp.operationHandle - self._check_direct_results_for_error(resp.directResults, self._connection_uuid) + self._check_direct_results_for_error(resp.directResults) def fetch_results( self, @@ -1105,7 +1104,7 @@ def fetch_results( "fetch_results failed due to inconsistency in the state between the client and the server. Expected results to start from {} but they instead start at {}, some result batches must have been skipped".format( expected_row_start_offset, resp.results.startRowOffset ), - connection_uuid=self._connection_uuid, + connection_uuid=getattr(_connection_uuid, "value", None), ) queue = ResultSetQueueFactory.build_queue( From c9376b8b8ff36f9f4f3809cf4f19bf13d0b52c4b Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Fri, 13 Jun 2025 12:04:42 +0530 Subject: [PATCH 14/48] made errors more specific Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/client.py | 37 +++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index a4f56c4f9..f9a011b11 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -18,6 +18,9 @@ OperationalError, SessionAlreadyClosedError, CursorAlreadyClosedError, + InterfaceError, + NotSupportedError, + ProgrammingError, ) from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.thrift_backend import ThriftBackend @@ -421,7 +424,7 @@ def cursor( Will throw an Error if the connection has been closed. """ if not self.open: - raise Error( + raise InterfaceError( "Cannot create cursor from closed connection", connection_uuid=self.get_session_id_hex(), ) @@ -529,7 +532,7 @@ def __iter__(self): for row in self.active_result_set: yield row else: - raise Error( + raise ProgrammingError( "There is no active result set", connection_uuid=self.connection.get_session_id_hex(), ) @@ -669,7 +672,7 @@ def _close_and_clear_active_result_set(self): def _check_not_closed(self): if not self.open: - raise Error( + raise InterfaceError( "Attempting operation on closed cursor", connection_uuid=self.connection.get_session_id_hex(), ) @@ -689,7 +692,7 @@ def _handle_staging_operation( elif isinstance(staging_allowed_local_path, type(list())): _staging_allowed_local_paths = staging_allowed_local_path else: - raise Error( + raise ProgrammingError( "You must provide at least one staging_allowed_local_path when initialising a connection to perform ingestion commands", connection_uuid=self.connection.get_session_id_hex(), ) @@ -719,7 +722,7 @@ def _handle_staging_operation( else: continue if not allow_operation: - raise Error( + raise ProgrammingError( "Local file operations are restricted to paths within the configured staging_allowed_local_path", connection_uuid=self.connection.get_session_id_hex(), ) @@ -749,7 +752,7 @@ def _handle_staging_operation( handler_args.pop("local_file") return self._handle_staging_remove(**handler_args) else: - raise Error( + raise ProgrammingError( f"Operation {row.operation} is not supported. " + "Supported operations are GET, PUT, and REMOVE", connection_uuid=self.connection.get_session_id_hex(), @@ -764,7 +767,7 @@ def _handle_staging_put( """ if local_file is None: - raise Error( + raise ProgrammingError( "Cannot perform PUT without specifying a local_file", connection_uuid=self.connection.get_session_id_hex(), ) @@ -783,7 +786,7 @@ def _handle_staging_put( # fmt: on if r.status_code not in [OK, CREATED, NO_CONTENT, ACCEPTED]: - raise Error( + raise OperationalError( f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}", connection_uuid=self.connection.get_session_id_hex(), ) @@ -803,7 +806,7 @@ def _handle_staging_get( """ if local_file is None: - raise Error( + raise ProgrammingError( "Cannot perform GET without specifying a local_file", connection_uuid=self.connection.get_session_id_hex(), ) @@ -813,7 +816,7 @@ def _handle_staging_get( # response.ok verifies the status code is not between 400-600. # Any 2xx or 3xx will evaluate r.ok == True if not r.ok: - raise Error( + raise OperationalError( f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}", connection_uuid=self.connection.get_session_id_hex(), ) @@ -829,7 +832,7 @@ def _handle_staging_remove( r = requests.delete(url=presigned_url, headers=headers) if not r.ok: - raise Error( + raise OperationalError( f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}", connection_uuid=self.connection.get_session_id_hex(), ) @@ -1029,7 +1032,7 @@ def get_async_execution_result(self): return self else: - raise Error( + raise OperationalError( f"get_execution_result failed with Operation status {operation_state}", connection_uuid=self.connection.get_session_id_hex(), ) @@ -1181,7 +1184,7 @@ def fetchall(self) -> List[Row]: if self.active_result_set: return self.active_result_set.fetchall() else: - raise Error( + raise ProgrammingError( "There is no active result set", connection_uuid=self.connection.get_session_id_hex(), ) @@ -1198,7 +1201,7 @@ def fetchone(self) -> Optional[Row]: if self.active_result_set: return self.active_result_set.fetchone() else: - raise Error( + raise ProgrammingError( "There is no active result set", connection_uuid=self.connection.get_session_id_hex(), ) @@ -1223,7 +1226,7 @@ def fetchmany(self, size: int) -> List[Row]: if self.active_result_set: return self.active_result_set.fetchmany(size) else: - raise Error( + raise ProgrammingError( "There is no active result set", connection_uuid=self.connection.get_session_id_hex(), ) @@ -1233,7 +1236,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": if self.active_result_set: return self.active_result_set.fetchall_arrow() else: - raise Error( + raise ProgrammingError( "There is no active result set", connection_uuid=self.connection.get_session_id_hex(), ) @@ -1243,7 +1246,7 @@ def fetchmany_arrow(self, size) -> "pyarrow.Table": if self.active_result_set: return self.active_result_set.fetchmany_arrow(size) else: - raise Error( + raise ProgrammingError( "There is no active result set", connection_uuid=self.connection.get_session_id_hex(), ) From bbfadf2b16f53d317e0ca4ba36b5b20f30eea533 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Fri, 13 Jun 2025 14:14:27 +0530 Subject: [PATCH 15/48] revert change to connection_uuid Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/thrift_backend.py | 82 +++++++++++++++------------- 1 file changed, 45 insertions(+), 37 deletions(-) diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 79fc7f1b0..7c47da2b1 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -72,9 +72,6 @@ "_retry_delay_default": (float, 5, 1, 60), } -# Add thread local storage -_connection_uuid = threading.local() - class ThriftBackend: CLOSED_OP_STATE = ttypes.TOperationState.CLOSED_STATE @@ -226,7 +223,7 @@ def __init__( raise self._request_lock = threading.RLock() - _connection_uuid.value = None + self._connection_uuid = None # TODO: Move this bounding logic into DatabricksRetryPolicy for v3 (PECO-918) def _initialize_retry_args(self, kwargs): @@ -259,14 +256,14 @@ def _initialize_retry_args(self, kwargs): ) @staticmethod - def _check_response_for_error(response): + def _check_response_for_error(response, connection_uuid=None): if response.status and response.status.statusCode in [ ttypes.TStatusCode.ERROR_STATUS, ttypes.TStatusCode.INVALID_HANDLE_STATUS, ]: raise DatabaseError( response.status.errorMessage, - connection_uuid=getattr(_connection_uuid, "value", None), + connection_uuid=connection_uuid, ) @staticmethod @@ -320,7 +317,7 @@ def _handle_request_error(self, error_info, attempt, elapsed): network_request_error = RequestError( user_friendly_error_message, full_error_info_context, - getattr(_connection_uuid, "value", None), + self._connection_uuid, error_info.error, ) logger.info(network_request_error.message_with_context()) @@ -493,7 +490,7 @@ def attempt_request(attempt): if not isinstance(response_or_error_info, RequestErrorInfo): # log nothing here, presume that main request logging covers response = response_or_error_info - ThriftBackend._check_response_for_error(response) + ThriftBackend._check_response_for_error(response, self._connection_uuid) return response error_info = response_or_error_info @@ -508,7 +505,7 @@ def _check_protocol_version(self, t_open_session_resp): "Error: expected server to use a protocol version >= " "SPARK_CLI_SERVICE_PROTOCOL_V2, " "instead got: {}".format(protocol_version), - connection_uuid=getattr(_connection_uuid, "value", None), + connection_uuid=self._connection_uuid, ) def _check_initial_namespace(self, catalog, schema, response): @@ -522,7 +519,7 @@ def _check_initial_namespace(self, catalog, schema, response): raise InvalidServerResponseError( "Setting initial namespace not supported by the DBR version, " "Please use a Databricks SQL endpoint or a cluster with DBR >= 9.0.", - connection_uuid=getattr(_connection_uuid, "value", None), + connection_uuid=self._connection_uuid, ) if catalog: @@ -530,7 +527,7 @@ def _check_initial_namespace(self, catalog, schema, response): raise InvalidServerResponseError( "Unexpected response from server: Trying to set initial catalog to {}, " + "but server does not support multiple catalogs.".format(catalog), # type: ignore - connection_uuid=getattr(_connection_uuid, "value", None), + connection_uuid=self._connection_uuid, ) def _check_session_configuration(self, session_configuration): @@ -545,7 +542,7 @@ def _check_session_configuration(self, session_configuration): TIMESTAMP_AS_STRING_CONFIG, session_configuration[TIMESTAMP_AS_STRING_CONFIG], ), - connection_uuid=getattr(_connection_uuid, "value", None), + connection_uuid=self._connection_uuid, ) def open_session(self, session_configuration, catalog, schema): @@ -576,7 +573,7 @@ def open_session(self, session_configuration, catalog, schema): response = self.make_request(self._client.OpenSession, open_session_req) self._check_initial_namespace(catalog, schema, response) self._check_protocol_version(response) - _connection_uuid.value = ( + self._connection_uuid = ( self.handle_to_hex_id(response.sessionHandle) if response.sessionHandle else None @@ -605,7 +602,7 @@ def _check_command_not_in_error_or_closed_state( and self.guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": get_operations_resp.diagnosticInfo, }, - connection_uuid=getattr(_connection_uuid, "value", None), + connection_uuid=self._connection_uuid, ) else: raise ServerOperationError( @@ -615,7 +612,7 @@ def _check_command_not_in_error_or_closed_state( and self.guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": None, }, - connection_uuid=getattr(_connection_uuid, "value", None), + connection_uuid=self._connection_uuid, ) elif get_operations_resp.operationState == ttypes.TOperationState.CLOSED_STATE: raise DatabaseError( @@ -626,7 +623,7 @@ def _check_command_not_in_error_or_closed_state( "operation-id": op_handle and self.guid_to_hex_id(op_handle.operationId.guid) }, - connection_uuid=getattr(_connection_uuid, "value", None), + connection_uuid=self._connection_uuid, ) def _poll_for_status(self, op_handle): @@ -649,7 +646,7 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti else: raise OperationalError( "Unsupported TRowSet instance {}".format(t_row_set), - connection_uuid=getattr(_connection_uuid, "value", None), + connection_uuid=self._connection_uuid, ) return convert_decimals_in_arrow_table(arrow_table, description), num_rows @@ -658,7 +655,7 @@ def _get_metadata_resp(self, op_handle): return self.make_request(self._client.GetResultSetMetadata, req) @staticmethod - def _hive_schema_to_arrow_schema(t_table_schema): + def _hive_schema_to_arrow_schema(t_table_schema, connection_uuid=None): def map_type(t_type_entry): if t_type_entry.primitiveEntry: return { @@ -690,7 +687,7 @@ def map_type(t_type_entry): # even for complex types raise OperationalError( "Thrift protocol error: t_type_entry not a primitiveEntry", - connection_uuid=getattr(_connection_uuid, "value", None), + connection_uuid=connection_uuid, ) def convert_col(t_column_desc): @@ -701,7 +698,7 @@ def convert_col(t_column_desc): return pyarrow.schema([convert_col(col) for col in t_table_schema.columns]) @staticmethod - def _col_to_description(col): + def _col_to_description(col, connection_uuid=None): type_entry = col.typeDesc.types[0] if type_entry.primitiveEntry: @@ -711,7 +708,7 @@ def _col_to_description(col): else: raise OperationalError( "Thrift protocol error: t_type_entry not a primitiveEntry", - connection_uuid=getattr(_connection_uuid, "value", None), + connection_uuid=connection_uuid, ) if type_entry.primitiveEntry.type == ttypes.TTypeId.DECIMAL_TYPE: @@ -725,7 +722,7 @@ def _col_to_description(col): raise OperationalError( "Decimal type did not provide typeQualifier precision, scale in " "primitiveEntry {}".format(type_entry.primitiveEntry), - connection_uuid=getattr(_connection_uuid, "value", None), + connection_uuid=connection_uuid, ) else: precision, scale = None, None @@ -733,9 +730,10 @@ def _col_to_description(col): return col.columnName, cleaned_type, None, None, precision, scale, None @staticmethod - def _hive_schema_to_description(t_table_schema): + def _hive_schema_to_description(t_table_schema, connection_uuid=None): return [ - ThriftBackend._col_to_description(col) for col in t_table_schema.columns + ThriftBackend._col_to_description(col, connection_uuid) + for col in t_table_schema.columns ] def _results_message_to_execute_response(self, resp, operation_state): @@ -756,7 +754,7 @@ def _results_message_to_execute_response(self, resp, operation_state): t_result_set_metadata_resp.resultFormat ] ), - connection_uuid=getattr(_connection_uuid, "value", None), + connection_uuid=self._connection_uuid, ) direct_results = resp.directResults has_been_closed_server_side = direct_results and direct_results.closeOperation @@ -766,13 +764,16 @@ def _results_message_to_execute_response(self, resp, operation_state): or direct_results.resultSet.hasMoreRows ) description = self._hive_schema_to_description( - t_result_set_metadata_resp.schema + t_result_set_metadata_resp.schema, + self._connection_uuid, ) if pyarrow: schema_bytes = ( t_result_set_metadata_resp.arrowSchema - or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema) + or self._hive_schema_to_arrow_schema( + t_result_set_metadata_resp.schema, self._connection_uuid + ) .serialize() .to_pybytes() ) @@ -833,13 +834,16 @@ def get_execution_result(self, op_handle, cursor): is_staging_operation = t_result_set_metadata_resp.isStagingOperation has_more_rows = resp.hasMoreRows description = self._hive_schema_to_description( - t_result_set_metadata_resp.schema + t_result_set_metadata_resp.schema, + self._connection_uuid, ) if pyarrow: schema_bytes = ( t_result_set_metadata_resp.arrowSchema - or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema) + or self._hive_schema_to_arrow_schema( + t_result_set_metadata_resp.schema, self._connection_uuid + ) .serialize() .to_pybytes() ) @@ -893,23 +897,27 @@ def get_query_state(self, op_handle) -> "TOperationState": return operation_state @staticmethod - def _check_direct_results_for_error(t_spark_direct_results): + def _check_direct_results_for_error(t_spark_direct_results, connection_uuid=None): if t_spark_direct_results: if t_spark_direct_results.operationStatus: ThriftBackend._check_response_for_error( - t_spark_direct_results.operationStatus + t_spark_direct_results.operationStatus, + connection_uuid, ) if t_spark_direct_results.resultSetMetadata: ThriftBackend._check_response_for_error( - t_spark_direct_results.resultSetMetadata + t_spark_direct_results.resultSetMetadata, + connection_uuid, ) if t_spark_direct_results.resultSet: ThriftBackend._check_response_for_error( - t_spark_direct_results.resultSet + t_spark_direct_results.resultSet, + connection_uuid, ) if t_spark_direct_results.closeOperation: ThriftBackend._check_response_for_error( - t_spark_direct_results.closeOperation + t_spark_direct_results.closeOperation, + connection_uuid, ) def execute_command( @@ -1058,7 +1066,7 @@ def get_columns( def _handle_execute_response(self, resp, cursor): cursor.active_op_handle = resp.operationHandle - self._check_direct_results_for_error(resp.directResults) + self._check_direct_results_for_error(resp.directResults, self._connection_uuid) final_operation_state = self._wait_until_command_done( resp.operationHandle, @@ -1069,7 +1077,7 @@ def _handle_execute_response(self, resp, cursor): def _handle_execute_response_async(self, resp, cursor): cursor.active_op_handle = resp.operationHandle - self._check_direct_results_for_error(resp.directResults) + self._check_direct_results_for_error(resp.directResults, self._connection_uuid) def fetch_results( self, @@ -1104,7 +1112,7 @@ def fetch_results( "fetch_results failed due to inconsistency in the state between the client and the server. Expected results to start from {} but they instead start at {}, some result batches must have been skipped".format( expected_row_start_offset, resp.results.startRowOffset ), - connection_uuid=getattr(_connection_uuid, "value", None), + connection_uuid=self._connection_uuid, ) queue = ResultSetQueueFactory.build_queue( From 9bce26b3eac736bf896850ca5448d4c831febde9 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Fri, 13 Jun 2025 14:24:21 +0530 Subject: [PATCH 16/48] reverting change in close in telemetry client Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/client.py | 2 +- src/databricks/sql/telemetry/telemetry_client.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index f9a011b11..04ef4584f 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -470,7 +470,7 @@ def _close(self, close_cursors=True) -> None: self.open = False - TelemetryClientFactory.close(self.get_session_id_hex()) + self._telemetry_client.close() def commit(self): """No-op because Databricks does not support transactions""" diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index fe1c0e191..ddb0a3974 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -300,6 +300,7 @@ def close(self): """Flush remaining events before closing""" logger.debug("Closing TelemetryClient for connection %s", self._connection_uuid) self._flush() + TelemetryClientFactory.close(self._connection_uuid) class TelemetryClientFactory: @@ -415,7 +416,6 @@ def close(connection_uuid): logger.debug( "Removing telemetry client for connection %s", connection_uuid ) - TelemetryClientFactory.get_telemetry_client(connection_uuid).close() TelemetryClientFactory._clients.pop(connection_uuid, None) # Shutdown executor if no more clients From ef4514d4020e164fbc02025c9ff7bc8c831fb360 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Fri, 13 Jun 2025 16:23:35 +0530 Subject: [PATCH 17/48] JsonSerializableMixin Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/telemetry/models/event.py | 40 +++++-------------- .../sql/telemetry/models/frontend_logs.py | 25 +++--------- .../sql/telemetry/telemetry_client.py | 11 ++--- src/databricks/sql/telemetry/utils.py | 32 +++++++++------ tests/unit/test_telemetry.py | 8 ++-- 5 files changed, 44 insertions(+), 72 deletions(-) diff --git a/src/databricks/sql/telemetry/models/event.py b/src/databricks/sql/telemetry/models/event.py index c00738810..f5496deec 100644 --- a/src/databricks/sql/telemetry/models/event.py +++ b/src/databricks/sql/telemetry/models/event.py @@ -1,5 +1,4 @@ -import json -from dataclasses import dataclass, asdict +from dataclasses import dataclass from databricks.sql.telemetry.models.enums import ( AuthMech, AuthFlow, @@ -9,11 +8,11 @@ ExecutionResultFormat, ) from typing import Optional -from databricks.sql.telemetry.utils import to_json_compact +from databricks.sql.telemetry.utils import JsonSerializableMixin @dataclass -class HostDetails: +class HostDetails(JsonSerializableMixin): """ Represents the host connection details for a Databricks workspace. @@ -25,12 +24,9 @@ class HostDetails: host_url: str port: int - def to_json(self): - return to_json_compact(self) - @dataclass -class DriverConnectionParameters: +class DriverConnectionParameters(JsonSerializableMixin): """ Contains all connection parameters used to establish a connection to Databricks SQL. This includes authentication details, host information, and connection settings. @@ -51,12 +47,9 @@ class DriverConnectionParameters: auth_flow: Optional[AuthFlow] = None socket_timeout: Optional[int] = None - def to_json(self): - return to_json_compact(self) - @dataclass -class DriverSystemConfiguration: +class DriverSystemConfiguration(JsonSerializableMixin): """ Contains system-level configuration information about the client environment. This includes details about the operating system, runtime, and driver version. @@ -87,12 +80,9 @@ class DriverSystemConfiguration: client_app_name: Optional[str] = None locale_name: Optional[str] = None - def to_json(self): - return to_json_compact(self) - @dataclass -class DriverVolumeOperation: +class DriverVolumeOperation(JsonSerializableMixin): """ Represents a volume operation performed by the driver. Used for tracking volume-related operations in telemetry. @@ -105,12 +95,9 @@ class DriverVolumeOperation: volume_operation_type: DriverVolumeOperationType volume_path: str - def to_json(self): - return to_json_compact(self) - @dataclass -class DriverErrorInfo: +class DriverErrorInfo(JsonSerializableMixin): """ Contains detailed information about errors that occur during driver operations. Used for error tracking and debugging in telemetry. @@ -123,12 +110,9 @@ class DriverErrorInfo: error_name: str stack_trace: str - def to_json(self): - return to_json_compact(self) - @dataclass -class SqlExecutionEvent: +class SqlExecutionEvent(JsonSerializableMixin): """ Represents a SQL query execution event. Contains details about the query execution, including type, compression, and result format. @@ -145,12 +129,9 @@ class SqlExecutionEvent: execution_result: ExecutionResultFormat retry_count: int - def to_json(self): - return to_json_compact(self) - @dataclass -class TelemetryEvent: +class TelemetryEvent(JsonSerializableMixin): """ Main telemetry event class that aggregates all telemetry data. Contains information about the session, system configuration, connection parameters, @@ -177,6 +158,3 @@ class TelemetryEvent: sql_operation: Optional[SqlExecutionEvent] = None error_info: Optional[DriverErrorInfo] = None operation_latency_ms: Optional[int] = None - - def to_json(self): - return to_json_compact(self) diff --git a/src/databricks/sql/telemetry/models/frontend_logs.py b/src/databricks/sql/telemetry/models/frontend_logs.py index f5d58a4be..4cc314ec3 100644 --- a/src/databricks/sql/telemetry/models/frontend_logs.py +++ b/src/databricks/sql/telemetry/models/frontend_logs.py @@ -1,12 +1,11 @@ -import json -from dataclasses import dataclass, asdict +from dataclasses import dataclass from databricks.sql.telemetry.models.event import TelemetryEvent -from databricks.sql.telemetry.utils import to_json_compact +from databricks.sql.telemetry.utils import JsonSerializableMixin from typing import Optional @dataclass -class TelemetryClientContext: +class TelemetryClientContext(JsonSerializableMixin): """ Contains client-side context information for telemetry events. This includes timestamp and user agent information for tracking when and how the client is being used. @@ -19,12 +18,9 @@ class TelemetryClientContext: timestamp_millis: int user_agent: str - def to_json(self): - return to_json_compact(self) - @dataclass -class FrontendLogContext: +class FrontendLogContext(JsonSerializableMixin): """ Wrapper for client context information in frontend logs. Provides additional context about the client environment for telemetry events. @@ -35,12 +31,9 @@ class FrontendLogContext: client_context: TelemetryClientContext - def to_json(self): - return to_json_compact(self) - @dataclass -class FrontendLogEntry: +class FrontendLogEntry(JsonSerializableMixin): """ Contains the actual telemetry event data in a frontend log. Wraps the SQL driver log information for frontend processing. @@ -51,12 +44,9 @@ class FrontendLogEntry: sql_driver_log: TelemetryEvent - def to_json(self): - return to_json_compact(self) - @dataclass -class TelemetryFrontendLog: +class TelemetryFrontendLog(JsonSerializableMixin): """ Main container for frontend telemetry data. Aggregates workspace information, event ID, context, and the actual log entry. @@ -73,6 +63,3 @@ class TelemetryFrontendLog: context: FrontendLogContext entry: FrontendLogEntry workspace_id: Optional[int] = None - - def to_json(self): - return to_json_compact(self) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index ddb0a3974..403379aff 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -265,7 +265,8 @@ def export_initial_telemetry_log(self, driver_connection_params, user_agent): ), ) - self.export_event(telemetry_frontend_log) + self._export_event(telemetry_frontend_log) + except Exception as e: logger.debug("Failed to export initial telemetry log: %s", e) @@ -292,7 +293,7 @@ def export_failure_log(self, error_name, error_message): ) ), ) - self.export_event(telemetry_frontend_log) + self._export_event(telemetry_frontend_log) except Exception as e: logger.debug("Failed to export failure log: %s", e) @@ -347,9 +348,9 @@ def _handle_unhandled_exception(cls, exc_type, exc_value, exc_traceback): """Handle unhandled exceptions by sending telemetry and flushing thread pool""" logger.debug("Handling unhandled exception: %s", exc_type.__name__) - # Flush existing thread pool work and wait for completion - for uuid, _ in cls._clients.items(): - cls.close(uuid) + clients_to_close = list(cls._clients.values()) + for client in clients_to_close: + client.close() # Call the original exception handler to maintain normal behavior if cls._original_excepthook: diff --git a/src/databricks/sql/telemetry/utils.py b/src/databricks/sql/telemetry/utils.py index 2ae87b96e..d14e2fcd4 100644 --- a/src/databricks/sql/telemetry/utils.py +++ b/src/databricks/sql/telemetry/utils.py @@ -1,6 +1,25 @@ import json from enum import Enum from dataclasses import asdict +from abc import ABC +from typing import Any + + +class JsonSerializableMixin(ABC): + """Mixin class to provide JSON serialization capabilities to dataclasses.""" + + def to_json(self) -> str: + """ + Convert the object to a JSON string, excluding None values. + Handles Enum serialization and filters out None values from the output. + """ + return json.dumps( + asdict( + self, + dict_factory=lambda data: {k: v for k, v in data if v is not None}, + ), + cls=EnumEncoder, + ) class EnumEncoder(json.JSONEncoder): @@ -14,16 +33,3 @@ def default(self, obj): if isinstance(obj, Enum): return obj.value return super().default(obj) - - -def to_json_compact(dataclass_obj): - """ - Convert a dataclass to JSON string, excluding None values. - """ - return json.dumps( - asdict( - dataclass_obj, - dict_factory=lambda data: {k: v for k, v in data if v is not None}, - ), - cls=EnumEncoder, - ) diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 97b8f276b..35eba8157 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -117,7 +117,7 @@ def test_export_initial_telemetry_log( client = telemetry_client_setup["client"] host_url = telemetry_client_setup["host_url"] - client.export_event = MagicMock() + client._export_event = MagicMock() driver_connection_params = DriverConnectionParameters( http_path="test-path", @@ -131,7 +131,7 @@ def test_export_initial_telemetry_log( client.export_initial_telemetry_log(driver_connection_params, user_agent) mock_frontend_log.assert_called_once() - client.export_event.assert_called_once_with(mock_frontend_log.return_value) + client._export_event.assert_called_once_with(mock_frontend_log.return_value) @patch("databricks.sql.telemetry.telemetry_client.TelemetryFrontendLog") @patch("databricks.sql.telemetry.telemetry_client.TelemetryHelper.get_driver_system_configuration") @@ -155,7 +155,7 @@ def test_export_failure_log( mock_frontend_log.return_value = MagicMock() client = telemetry_client_setup["client"] - client.export_event = MagicMock() + client._export_event = MagicMock() client._driver_connection_params = "test-connection-params" client._user_agent = "test-user-agent" @@ -172,7 +172,7 @@ def test_export_failure_log( mock_frontend_log.assert_called_once() - client.export_event.assert_called_once_with(mock_frontend_log.return_value) + client._export_event.assert_called_once_with(mock_frontend_log.return_value) def test_export_event(self, telemetry_client_setup): """Test exporting an event.""" From 8924835e59aed8df7922903fe187485d4d976aee Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Fri, 13 Jun 2025 16:48:57 +0530 Subject: [PATCH 18/48] isdataclass check in JsonSerializableMixin Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/telemetry/utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/telemetry/utils.py b/src/databricks/sql/telemetry/utils.py index d14e2fcd4..6d95526b8 100644 --- a/src/databricks/sql/telemetry/utils.py +++ b/src/databricks/sql/telemetry/utils.py @@ -1,8 +1,7 @@ import json from enum import Enum -from dataclasses import asdict +from dataclasses import asdict, is_dataclass from abc import ABC -from typing import Any class JsonSerializableMixin(ABC): @@ -13,6 +12,11 @@ def to_json(self) -> str: Convert the object to a JSON string, excluding None values. Handles Enum serialization and filters out None values from the output. """ + if not is_dataclass(self): + raise TypeError( + f"{self.__class__.__name__} must be a dataclass to use JsonSerializableMixin" + ) + return json.dumps( asdict( self, From 65361e76f32e3199c94e809798169b6b9fe29c72 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Mon, 16 Jun 2025 11:10:20 +0530 Subject: [PATCH 19/48] convert TelemetryClientFactory to module-level functions, replace NoopTelemetryClient class with NOOP_TELEMETRY_CLIENT singleton, updated tests accordingly Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/client.py | 7 +- src/databricks/sql/exc.py | 6 +- .../sql/telemetry/telemetry_client.py | 243 ++++++++---------- tests/unit/test_telemetry.py | 162 ++++-------- 4 files changed, 167 insertions(+), 251 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 04ef4584f..bee60d317 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -53,8 +53,9 @@ TOperationState, ) from databricks.sql.telemetry.telemetry_client import ( - TelemetryClientFactory, TelemetryHelper, + initialize_telemetry_client, + get_telemetry_client, ) from databricks.sql.telemetry.models.enums import DatabricksClientType from databricks.sql.telemetry.models.event import ( @@ -306,14 +307,14 @@ def read(self) -> Optional[OAuthToken]: kwargs.get("use_inline_params", False) ) - TelemetryClientFactory.initialize_telemetry_client( + initialize_telemetry_client( telemetry_enabled=self.telemetry_enabled, connection_uuid=self.get_session_id_hex(), auth_provider=auth_provider, host_url=self.host, ) - self._telemetry_client = TelemetryClientFactory.get_telemetry_client( + self._telemetry_client = get_telemetry_client( connection_uuid=self.get_session_id_hex() ) diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index cc7a47cb4..443d5605f 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -2,7 +2,7 @@ import logging import traceback -from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory +from databricks.sql.telemetry.telemetry_client import get_telemetry_client logger = logging.getLogger(__name__) @@ -22,9 +22,7 @@ def __init__( error_name = self.__class__.__name__ if connection_uuid: - telemetry_client = TelemetryClientFactory.get_telemetry_client( - connection_uuid - ) + telemetry_client = get_telemetry_client(connection_uuid) telemetry_client.export_failure_log(error_name, self.message) def __str__(self): diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 403379aff..728220789 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -115,27 +115,16 @@ def close(self): pass -class NoopTelemetryClient(BaseTelemetryClient): - """ - NoopTelemetryClient is a telemetry client that does not send any events to the server. - It is used when telemetry is disabled. - """ - - _instance = None - - def __new__(cls): - if cls._instance is None: - cls._instance = super(NoopTelemetryClient, cls).__new__(cls) - return cls._instance - - def export_initial_telemetry_log(self, driver_connection_params, user_agent): - pass - - def export_failure_log(self, error_name, error_message): - pass - - def close(self): - pass +# A single instance of the no-op client that can be reused +NOOP_TELEMETRY_CLIENT = type( + "NoopTelemetryClient", + (BaseTelemetryClient,), + { + "export_initial_telemetry_log": lambda self, *args, **kwargs: None, + "export_failure_log": lambda self, *args, **kwargs: None, + "close": lambda self: None, + }, +)() class TelemetryClient(BaseTelemetryClient): @@ -301,129 +290,111 @@ def close(self): """Flush remaining events before closing""" logger.debug("Closing TelemetryClient for connection %s", self._connection_uuid) self._flush() - TelemetryClientFactory.close(self._connection_uuid) - - -class TelemetryClientFactory: - """ - Static factory class for creating and managing telemetry clients. - It uses a thread pool to handle asynchronous operations. - """ - - _clients: Dict[ - str, BaseTelemetryClient - ] = {} # Map of connection_uuid -> BaseTelemetryClient - _executor: Optional[ThreadPoolExecutor] = None - _initialized: bool = False - _lock = threading.Lock() # Thread safety for factory operations - _original_excepthook = None - _excepthook_installed = False - - @classmethod - def _initialize(cls): - """Initialize the factory if not already initialized""" - with cls._lock: - if not cls._initialized: - cls._clients = {} - cls._executor = ThreadPoolExecutor( - max_workers=10 - ) # Thread pool for async operations TODO: Decide on max workers - cls._install_exception_hook() - cls._initialized = True - logger.debug( - "TelemetryClientFactory initialized with thread pool (max_workers=10)" - ) - - @classmethod - def _install_exception_hook(cls): - """Install global exception handler for unhandled exceptions""" - if not cls._excepthook_installed: - cls._original_excepthook = sys.excepthook - sys.excepthook = cls._handle_unhandled_exception - cls._excepthook_installed = True - logger.debug("Global exception handler installed for telemetry") + _remove_telemetry_client(self._connection_uuid) + + +# Module-level state +_clients: Dict[str, BaseTelemetryClient] = {} +_executor: Optional[ThreadPoolExecutor] = None +_initialized: bool = False +_lock = threading.Lock() +_original_excepthook = None +_excepthook_installed = False + + +def _initialize(): + """Initialize the telemetry system if not already initialized""" + global _initialized, _executor + with _lock: + if not _initialized: + _clients.clear() + _executor = ThreadPoolExecutor(max_workers=10) + _install_exception_hook() + _initialized = True + logger.debug( + "Telemetry system initialized with thread pool (max_workers=10)" + ) - @classmethod - def _handle_unhandled_exception(cls, exc_type, exc_value, exc_traceback): - """Handle unhandled exceptions by sending telemetry and flushing thread pool""" - logger.debug("Handling unhandled exception: %s", exc_type.__name__) - clients_to_close = list(cls._clients.values()) - for client in clients_to_close: - client.close() +def _install_exception_hook(): + """Install global exception handler for unhandled exceptions""" + global _excepthook_installed, _original_excepthook + if not _excepthook_installed: + _original_excepthook = sys.excepthook + sys.excepthook = _handle_unhandled_exception + _excepthook_installed = True + logger.debug("Global exception handler installed for telemetry") - # Call the original exception handler to maintain normal behavior - if cls._original_excepthook: - cls._original_excepthook(exc_type, exc_value, exc_traceback) - @staticmethod - def initialize_telemetry_client( - telemetry_enabled, - connection_uuid, - auth_provider, - host_url, - ): - """Initialize a telemetry client for a specific connection if telemetry is enabled""" - try: - TelemetryClientFactory._initialize() +def _handle_unhandled_exception(exc_type, exc_value, exc_traceback): + """Handle unhandled exceptions by sending telemetry and flushing thread pool""" + logger.debug("Handling unhandled exception: %s", exc_type.__name__) - with TelemetryClientFactory._lock: - if connection_uuid not in TelemetryClientFactory._clients: - logger.debug( - "Creating new TelemetryClient for connection %s", - connection_uuid, - ) - if telemetry_enabled: - TelemetryClientFactory._clients[ - connection_uuid - ] = TelemetryClient( - telemetry_enabled=telemetry_enabled, - connection_uuid=connection_uuid, - auth_provider=auth_provider, - host_url=host_url, - executor=TelemetryClientFactory._executor, - ) - else: - TelemetryClientFactory._clients[ - connection_uuid - ] = NoopTelemetryClient() - except Exception as e: - logger.debug("Failed to initialize telemetry client: %s", e) - # Fallback to NoopTelemetryClient to ensure connection doesn't fail - TelemetryClientFactory._clients[connection_uuid] = NoopTelemetryClient() + clients_to_close = list(_clients.values()) + for client in clients_to_close: + client.close() - @staticmethod - def get_telemetry_client(connection_uuid): - """Get the telemetry client for a specific connection""" - try: - if connection_uuid in TelemetryClientFactory._clients: - return TelemetryClientFactory._clients[connection_uuid] - else: - logger.error( - "Telemetry client not initialized for connection %s", - connection_uuid, - ) - return NoopTelemetryClient() - except Exception as e: - logger.debug("Failed to get telemetry client: %s", e) - return NoopTelemetryClient() + # Call the original exception handler to maintain normal behavior + if _original_excepthook: + _original_excepthook(exc_type, exc_value, exc_traceback) - @staticmethod - def close(connection_uuid): - """Close and remove the telemetry client for a specific connection""" - with TelemetryClientFactory._lock: - if connection_uuid in TelemetryClientFactory._clients: - logger.debug( - "Removing telemetry client for connection %s", connection_uuid - ) - TelemetryClientFactory._clients.pop(connection_uuid, None) +def initialize_telemetry_client( + telemetry_enabled, connection_uuid, auth_provider, host_url +): + """Initialize a telemetry client for a specific connection if telemetry is enabled""" + try: + _initialize() - # Shutdown executor if no more clients - if not TelemetryClientFactory._clients and TelemetryClientFactory._executor: + with _lock: + if connection_uuid not in _clients: logger.debug( - "No more telemetry clients, shutting down thread pool executor" + "Creating new TelemetryClient for connection %s", connection_uuid ) - TelemetryClientFactory._executor.shutdown(wait=True) - TelemetryClientFactory._executor = None - TelemetryClientFactory._initialized = False + if telemetry_enabled: + _clients[connection_uuid] = TelemetryClient( + telemetry_enabled=telemetry_enabled, + connection_uuid=connection_uuid, + auth_provider=auth_provider, + host_url=host_url, + executor=_executor, + ) + else: + _clients[connection_uuid] = NOOP_TELEMETRY_CLIENT + except Exception as e: + logger.debug("Failed to initialize telemetry client: %s", e) + # Fallback to NoopTelemetryClient to ensure connection doesn't fail + _clients[connection_uuid] = NOOP_TELEMETRY_CLIENT + + +def get_telemetry_client(connection_uuid): + """Get the telemetry client for a specific connection""" + try: + if connection_uuid in _clients: + return _clients[connection_uuid] + else: + logger.error( + "Telemetry client not initialized for connection %s", connection_uuid + ) + return NOOP_TELEMETRY_CLIENT + except Exception as e: + logger.debug("Failed to get telemetry client: %s", e) + return NOOP_TELEMETRY_CLIENT + + +def _remove_telemetry_client(connection_uuid): + """Remove the telemetry client for a specific connection""" + global _initialized, _executor + with _lock: + if connection_uuid in _clients: + logger.debug("Removing telemetry client for connection %s", connection_uuid) + _clients.pop(connection_uuid, None) + + # Shutdown executor if no more clients + if not _clients and _executor: + logger.debug( + "No more telemetry clients, shutting down thread pool executor" + ) + _executor.shutdown(wait=True) + _executor = None + _initialized = False diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 35eba8157..975febd20 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -5,8 +5,10 @@ from databricks.sql.telemetry.telemetry_client import ( TelemetryClient, - NoopTelemetryClient, - TelemetryClientFactory, + NOOP_TELEMETRY_CLIENT, + initialize_telemetry_client, + get_telemetry_client, + _remove_telemetry_client, ) from databricks.sql.telemetry.models.enums import ( AuthMech, @@ -23,8 +25,8 @@ @pytest.fixture def noop_telemetry_client(): - """Fixture for NoopTelemetryClient.""" - return NoopTelemetryClient() + """Fixture for NOOP_TELEMETRY_CLIENT.""" + return NOOP_TELEMETRY_CLIENT @pytest.fixture @@ -53,30 +55,27 @@ def telemetry_client_setup(): @pytest.fixture -def telemetry_factory_reset(): - """Fixture to reset TelemetryClientFactory state before each test.""" - # Reset the static class state before each test - TelemetryClientFactory._clients = {} - TelemetryClientFactory._executor = None - TelemetryClientFactory._initialized = False +def telemetry_system_reset(): + """Fixture to reset telemetry system state before each test.""" + # Reset the static state before each test + from databricks.sql.telemetry.telemetry_client import _clients, _executor, _initialized + _clients.clear() + if _executor: + _executor.shutdown(wait=True) + _executor = None + _initialized = False yield # Cleanup after test if needed - TelemetryClientFactory._clients = {} - if TelemetryClientFactory._executor: - TelemetryClientFactory._executor.shutdown(wait=True) - TelemetryClientFactory._executor = None - TelemetryClientFactory._initialized = False + _clients.clear() + if _executor: + _executor.shutdown(wait=True) + _executor = None + _initialized = False class TestNoopTelemetryClient: - """Tests for the NoopTelemetryClient class.""" - - def test_singleton(self): - """Test that NoopTelemetryClient is a singleton.""" - client1 = NoopTelemetryClient() - client2 = NoopTelemetryClient() - assert client1 is client2 - + """Tests for the NOOP_TELEMETRY_CLIENT.""" + def test_export_initial_telemetry_log(self, noop_telemetry_client): """Test that export_initial_telemetry_log does nothing.""" noop_telemetry_client.export_initial_telemetry_log( @@ -249,8 +248,7 @@ def test_flush(self, telemetry_client_setup): client._send_telemetry.assert_called_once_with(["event1", "event2"]) assert client._events_batch == [] - @patch("databricks.sql.telemetry.telemetry_client.TelemetryClientFactory") - def test_close(self, mock_factory_class, telemetry_client_setup): + def test_close(self, telemetry_client_setup): """Test closing the client.""" client = telemetry_client_setup["client"] client._flush = MagicMock() @@ -260,115 +258,63 @@ def test_close(self, mock_factory_class, telemetry_client_setup): client._flush.assert_called_once() -class TestTelemetryClientFactory: - """Tests for the TelemetryClientFactory static class.""" +class TestTelemetrySystem: + """Tests for the telemetry system functions.""" - @patch("databricks.sql.telemetry.telemetry_client.TelemetryClient") - def test_initialize_telemetry_client_enabled(self, mock_client_class, telemetry_factory_reset): + def test_initialize_telemetry_client_enabled(self, telemetry_system_reset): """Test initializing a telemetry client when telemetry is enabled.""" connection_uuid = "test-uuid" auth_provider = MagicMock() host_url = "test-host" - mock_client = MagicMock() - mock_client_class.return_value = mock_client - - TelemetryClientFactory.initialize_telemetry_client( - telemetry_enabled=True, - connection_uuid=connection_uuid, - auth_provider=auth_provider, - host_url=host_url, - ) - # Verify a new client was created and stored - mock_client_class.assert_called_once_with( + initialize_telemetry_client( telemetry_enabled=True, connection_uuid=connection_uuid, auth_provider=auth_provider, host_url=host_url, - executor=TelemetryClientFactory._executor, ) - assert TelemetryClientFactory._clients[connection_uuid] == mock_client - - # Call again with the same connection_uuid - client2 = TelemetryClientFactory.get_telemetry_client(connection_uuid=connection_uuid) - # Verify the same client was returned and no new client was created - assert client2 == mock_client - mock_client_class.assert_called_once() # Still only called once + client = get_telemetry_client(connection_uuid) + assert isinstance(client, TelemetryClient) + assert client._connection_uuid == connection_uuid + assert client._auth_provider == auth_provider + assert client._host_url == host_url - def test_initialize_telemetry_client_disabled(self, telemetry_factory_reset): + def test_initialize_telemetry_client_disabled(self, telemetry_system_reset): """Test initializing a telemetry client when telemetry is disabled.""" connection_uuid = "test-uuid" - TelemetryClientFactory.initialize_telemetry_client( + initialize_telemetry_client( telemetry_enabled=False, connection_uuid=connection_uuid, auth_provider=MagicMock(), host_url="test-host", ) - # Verify a NoopTelemetryClient was stored - assert isinstance(TelemetryClientFactory._clients[connection_uuid], NoopTelemetryClient) - - client2 = TelemetryClientFactory.get_telemetry_client(connection_uuid) - assert isinstance(client2, NoopTelemetryClient) - - def test_get_telemetry_client_existing(self, telemetry_factory_reset): - """Test getting an existing telemetry client.""" - connection_uuid = "test-uuid" - mock_client = MagicMock() - TelemetryClientFactory._clients[connection_uuid] = mock_client - - client = TelemetryClientFactory.get_telemetry_client(connection_uuid) - - assert client == mock_client + client = get_telemetry_client(connection_uuid) + assert client is NOOP_TELEMETRY_CLIENT - def test_get_telemetry_client_nonexistent(self, telemetry_factory_reset): + def test_get_telemetry_client_nonexistent(self, telemetry_system_reset): """Test getting a non-existent telemetry client.""" - client = TelemetryClientFactory.get_telemetry_client("nonexistent-uuid") - - assert isinstance(client, NoopTelemetryClient) + client = get_telemetry_client("nonexistent-uuid") + assert client is NOOP_TELEMETRY_CLIENT - @patch("databricks.sql.telemetry.telemetry_client.ThreadPoolExecutor") - @patch("databricks.sql.telemetry.telemetry_client.TelemetryClient") - def test_close(self, mock_client_class, mock_executor_class, telemetry_factory_reset): - """Test that factory reinitializes properly after complete shutdown.""" - connection_uuid1 = "test-uuid1" - mock_executor1 = MagicMock() - mock_client1 = MagicMock() - mock_executor_class.return_value = mock_executor1 - mock_client_class.return_value = mock_client1 - - TelemetryClientFactory._clients[connection_uuid1] = mock_client1 - TelemetryClientFactory._executor = mock_executor1 - TelemetryClientFactory._initialized = True - - TelemetryClientFactory.close(connection_uuid1) - - assert TelemetryClientFactory._clients == {} - assert TelemetryClientFactory._executor is None - assert TelemetryClientFactory._initialized is False - mock_executor1.shutdown.assert_called_once_with(wait=True) - - # Now create a new client - this should reinitialize the factory - connection_uuid2 = "test-uuid2" - mock_executor2 = MagicMock() - mock_client2 = MagicMock() - mock_executor_class.return_value = mock_executor2 - mock_client_class.return_value = mock_client2 + def test_close_telemetry_client(self, telemetry_system_reset): + """Test closing a telemetry client.""" + connection_uuid = "test-uuid" + auth_provider = MagicMock() + host_url = "test-host" - TelemetryClientFactory.initialize_telemetry_client( + initialize_telemetry_client( telemetry_enabled=True, - connection_uuid=connection_uuid2, - auth_provider=MagicMock(), - host_url="test-host", + connection_uuid=connection_uuid, + auth_provider=auth_provider, + host_url=host_url, ) - # Verify factory was reinitialized - assert TelemetryClientFactory._initialized is True - assert TelemetryClientFactory._executor is not None - assert TelemetryClientFactory._executor == mock_executor2 - assert connection_uuid2 in TelemetryClientFactory._clients - assert TelemetryClientFactory._clients[connection_uuid2] == mock_client2 + client = get_telemetry_client(connection_uuid) + assert isinstance(client, TelemetryClient) + + _remove_telemetry_client(connection_uuid) - # Verify new ThreadPoolExecutor was created - assert mock_executor_class.call_count == 1 \ No newline at end of file + client = get_telemetry_client(connection_uuid) + assert client is NOOP_TELEMETRY_CLIENT \ No newline at end of file From 1722a7799ed98b7dacc1cbc31b054a201fb6106b Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Mon, 16 Jun 2025 12:33:47 +0530 Subject: [PATCH 20/48] renamed connection_uuid as session_id_hex Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/client.py | 40 +++++------ src/databricks/sql/exc.py | 6 +- .../sql/telemetry/telemetry_client.py | 50 +++++++------- src/databricks/sql/thrift_backend.py | 68 +++++++++---------- tests/unit/test_telemetry.py | 32 ++++----- 5 files changed, 98 insertions(+), 98 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index bee60d317..23e4e38b1 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -309,13 +309,13 @@ def read(self) -> Optional[OAuthToken]: initialize_telemetry_client( telemetry_enabled=self.telemetry_enabled, - connection_uuid=self.get_session_id_hex(), + session_id_hex=self.get_session_id_hex(), auth_provider=auth_provider, host_url=self.host, ) self._telemetry_client = get_telemetry_client( - connection_uuid=self.get_session_id_hex() + session_id_hex=self.get_session_id_hex() ) driver_connection_params = DriverConnectionParameters( @@ -427,7 +427,7 @@ def cursor( if not self.open: raise InterfaceError( "Cannot create cursor from closed connection", - connection_uuid=self.get_session_id_hex(), + session_id_hex=self.get_session_id_hex(), ) cursor = Cursor( @@ -480,7 +480,7 @@ def commit(self): def rollback(self): raise NotSupportedError( "Transactions are not supported on Databricks", - connection_uuid=self.get_session_id_hex(), + session_id_hex=self.get_session_id_hex(), ) @@ -535,7 +535,7 @@ def __iter__(self): else: raise ProgrammingError( "There is no active result set", - connection_uuid=self.connection.get_session_id_hex(), + session_id_hex=self.connection.get_session_id_hex(), ) def _determine_parameter_approach( @@ -675,7 +675,7 @@ def _check_not_closed(self): if not self.open: raise InterfaceError( "Attempting operation on closed cursor", - connection_uuid=self.connection.get_session_id_hex(), + session_id_hex=self.connection.get_session_id_hex(), ) def _handle_staging_operation( @@ -695,7 +695,7 @@ def _handle_staging_operation( else: raise ProgrammingError( "You must provide at least one staging_allowed_local_path when initialising a connection to perform ingestion commands", - connection_uuid=self.connection.get_session_id_hex(), + session_id_hex=self.connection.get_session_id_hex(), ) abs_staging_allowed_local_paths = [ @@ -725,7 +725,7 @@ def _handle_staging_operation( if not allow_operation: raise ProgrammingError( "Local file operations are restricted to paths within the configured staging_allowed_local_path", - connection_uuid=self.connection.get_session_id_hex(), + session_id_hex=self.connection.get_session_id_hex(), ) # May be real headers, or could be json string @@ -756,7 +756,7 @@ def _handle_staging_operation( raise ProgrammingError( f"Operation {row.operation} is not supported. " + "Supported operations are GET, PUT, and REMOVE", - connection_uuid=self.connection.get_session_id_hex(), + session_id_hex=self.connection.get_session_id_hex(), ) def _handle_staging_put( @@ -770,7 +770,7 @@ def _handle_staging_put( if local_file is None: raise ProgrammingError( "Cannot perform PUT without specifying a local_file", - connection_uuid=self.connection.get_session_id_hex(), + session_id_hex=self.connection.get_session_id_hex(), ) with open(local_file, "rb") as fh: @@ -789,7 +789,7 @@ def _handle_staging_put( if r.status_code not in [OK, CREATED, NO_CONTENT, ACCEPTED]: raise OperationalError( f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}", - connection_uuid=self.connection.get_session_id_hex(), + session_id_hex=self.connection.get_session_id_hex(), ) if r.status_code == ACCEPTED: @@ -809,7 +809,7 @@ def _handle_staging_get( if local_file is None: raise ProgrammingError( "Cannot perform GET without specifying a local_file", - connection_uuid=self.connection.get_session_id_hex(), + session_id_hex=self.connection.get_session_id_hex(), ) r = requests.get(url=presigned_url, headers=headers) @@ -819,7 +819,7 @@ def _handle_staging_get( if not r.ok: raise OperationalError( f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}", - connection_uuid=self.connection.get_session_id_hex(), + session_id_hex=self.connection.get_session_id_hex(), ) with open(local_file, "wb") as fp: @@ -835,7 +835,7 @@ def _handle_staging_remove( if not r.ok: raise OperationalError( f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}", - connection_uuid=self.connection.get_session_id_hex(), + session_id_hex=self.connection.get_session_id_hex(), ) def execute( @@ -1035,7 +1035,7 @@ def get_async_execution_result(self): else: raise OperationalError( f"get_execution_result failed with Operation status {operation_state}", - connection_uuid=self.connection.get_session_id_hex(), + session_id_hex=self.connection.get_session_id_hex(), ) def executemany(self, operation, seq_of_parameters): @@ -1187,7 +1187,7 @@ def fetchall(self) -> List[Row]: else: raise ProgrammingError( "There is no active result set", - connection_uuid=self.connection.get_session_id_hex(), + session_id_hex=self.connection.get_session_id_hex(), ) def fetchone(self) -> Optional[Row]: @@ -1204,7 +1204,7 @@ def fetchone(self) -> Optional[Row]: else: raise ProgrammingError( "There is no active result set", - connection_uuid=self.connection.get_session_id_hex(), + session_id_hex=self.connection.get_session_id_hex(), ) def fetchmany(self, size: int) -> List[Row]: @@ -1229,7 +1229,7 @@ def fetchmany(self, size: int) -> List[Row]: else: raise ProgrammingError( "There is no active result set", - connection_uuid=self.connection.get_session_id_hex(), + session_id_hex=self.connection.get_session_id_hex(), ) def fetchall_arrow(self) -> "pyarrow.Table": @@ -1239,7 +1239,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": else: raise ProgrammingError( "There is no active result set", - connection_uuid=self.connection.get_session_id_hex(), + session_id_hex=self.connection.get_session_id_hex(), ) def fetchmany_arrow(self, size) -> "pyarrow.Table": @@ -1249,7 +1249,7 @@ def fetchmany_arrow(self, size) -> "pyarrow.Table": else: raise ProgrammingError( "There is no active result set", - connection_uuid=self.connection.get_session_id_hex(), + session_id_hex=self.connection.get_session_id_hex(), ) def cancel(self) -> None: diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index 443d5605f..e7b2dad23 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -14,15 +14,15 @@ class Error(Exception): """ def __init__( - self, message=None, context=None, connection_uuid=None, *args, **kwargs + self, message=None, context=None, session_id_hex=None, *args, **kwargs ): super().__init__(message, *args, **kwargs) self.message = message self.context = context or {} error_name = self.__class__.__name__ - if connection_uuid: - telemetry_client = get_telemetry_client(connection_uuid) + if session_id_hex: + telemetry_client = get_telemetry_client(session_id_hex) telemetry_client.export_failure_log(error_name, self.message) def __str__(self): diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 728220789..a918beb09 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -140,15 +140,15 @@ class TelemetryClient(BaseTelemetryClient): def __init__( self, telemetry_enabled, - connection_uuid, + session_id_hex, auth_provider, host_url, executor, ): - logger.debug("Initializing TelemetryClient for connection: %s", connection_uuid) + logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex) self._telemetry_enabled = telemetry_enabled self._batch_size = 10 # TODO: Decide on batch size - self._connection_uuid = connection_uuid + self._session_id_hex = session_id_hex self._auth_provider = auth_provider self._user_agent = None self._events_batch = [] @@ -159,7 +159,7 @@ def __init__( def _export_event(self, event): """Add an event to the batch queue and flush if batch is full""" - logger.debug("Exporting event for connection %s", self._connection_uuid) + logger.debug("Exporting event for connection %s", self._session_id_hex) with self._lock: self._events_batch.append(event) if len(self._events_batch) >= self._batch_size: @@ -230,7 +230,7 @@ def _telemetry_request_callback(self, future): def export_initial_telemetry_log(self, driver_connection_params, user_agent): logger.debug( - "Exporting initial telemetry log for connection %s", self._connection_uuid + "Exporting initial telemetry log for connection %s", self._session_id_hex ) try: @@ -247,7 +247,7 @@ def export_initial_telemetry_log(self, driver_connection_params, user_agent): ), entry=FrontendLogEntry( sql_driver_log=TelemetryEvent( - session_id=self._connection_uuid, + session_id=self._session_id_hex, system_configuration=TelemetryHelper.get_driver_system_configuration(), driver_connection_params=self._driver_connection_params, ) @@ -260,7 +260,7 @@ def export_initial_telemetry_log(self, driver_connection_params, user_agent): logger.debug("Failed to export initial telemetry log: %s", e) def export_failure_log(self, error_name, error_message): - logger.debug("Exporting failure log for connection %s", self._connection_uuid) + logger.debug("Exporting failure log for connection %s", self._session_id_hex) try: error_info = DriverErrorInfo( error_name=error_name, stack_trace=error_message @@ -275,7 +275,7 @@ def export_failure_log(self, error_name, error_message): ), entry=FrontendLogEntry( sql_driver_log=TelemetryEvent( - session_id=self._connection_uuid, + session_id=self._session_id_hex, system_configuration=TelemetryHelper.get_driver_system_configuration(), driver_connection_params=self._driver_connection_params, error_info=error_info, @@ -288,9 +288,9 @@ def export_failure_log(self, error_name, error_message): def close(self): """Flush remaining events before closing""" - logger.debug("Closing TelemetryClient for connection %s", self._connection_uuid) + logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex) self._flush() - _remove_telemetry_client(self._connection_uuid) + _remove_telemetry_client(self._session_id_hex) # Module-level state @@ -340,41 +340,41 @@ def _handle_unhandled_exception(exc_type, exc_value, exc_traceback): def initialize_telemetry_client( - telemetry_enabled, connection_uuid, auth_provider, host_url + telemetry_enabled, session_id_hex, auth_provider, host_url ): """Initialize a telemetry client for a specific connection if telemetry is enabled""" try: _initialize() with _lock: - if connection_uuid not in _clients: + if session_id_hex not in _clients: logger.debug( - "Creating new TelemetryClient for connection %s", connection_uuid + "Creating new TelemetryClient for connection %s", session_id_hex ) if telemetry_enabled: - _clients[connection_uuid] = TelemetryClient( + _clients[session_id_hex] = TelemetryClient( telemetry_enabled=telemetry_enabled, - connection_uuid=connection_uuid, + session_id_hex=session_id_hex, auth_provider=auth_provider, host_url=host_url, executor=_executor, ) else: - _clients[connection_uuid] = NOOP_TELEMETRY_CLIENT + _clients[session_id_hex] = NOOP_TELEMETRY_CLIENT except Exception as e: logger.debug("Failed to initialize telemetry client: %s", e) # Fallback to NoopTelemetryClient to ensure connection doesn't fail - _clients[connection_uuid] = NOOP_TELEMETRY_CLIENT + _clients[session_id_hex] = NOOP_TELEMETRY_CLIENT -def get_telemetry_client(connection_uuid): +def get_telemetry_client(session_id_hex): """Get the telemetry client for a specific connection""" try: - if connection_uuid in _clients: - return _clients[connection_uuid] + if session_id_hex in _clients: + return _clients[session_id_hex] else: logger.error( - "Telemetry client not initialized for connection %s", connection_uuid + "Telemetry client not initialized for connection %s", session_id_hex ) return NOOP_TELEMETRY_CLIENT except Exception as e: @@ -382,13 +382,13 @@ def get_telemetry_client(connection_uuid): return NOOP_TELEMETRY_CLIENT -def _remove_telemetry_client(connection_uuid): +def _remove_telemetry_client(session_id_hex): """Remove the telemetry client for a specific connection""" global _initialized, _executor with _lock: - if connection_uuid in _clients: - logger.debug("Removing telemetry client for connection %s", connection_uuid) - _clients.pop(connection_uuid, None) + if session_id_hex in _clients: + logger.debug("Removing telemetry client for connection %s", session_id_hex) + _clients.pop(session_id_hex, None) # Shutdown executor if no more clients if not _clients and _executor: diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 7c47da2b1..78683ac31 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -223,7 +223,7 @@ def __init__( raise self._request_lock = threading.RLock() - self._connection_uuid = None + self._session_id_hex = None # TODO: Move this bounding logic into DatabricksRetryPolicy for v3 (PECO-918) def _initialize_retry_args(self, kwargs): @@ -256,14 +256,14 @@ def _initialize_retry_args(self, kwargs): ) @staticmethod - def _check_response_for_error(response, connection_uuid=None): + def _check_response_for_error(response, session_id_hex=None): if response.status and response.status.statusCode in [ ttypes.TStatusCode.ERROR_STATUS, ttypes.TStatusCode.INVALID_HANDLE_STATUS, ]: raise DatabaseError( response.status.errorMessage, - connection_uuid=connection_uuid, + session_id_hex=session_id_hex, ) @staticmethod @@ -317,7 +317,7 @@ def _handle_request_error(self, error_info, attempt, elapsed): network_request_error = RequestError( user_friendly_error_message, full_error_info_context, - self._connection_uuid, + self._session_id_hex, error_info.error, ) logger.info(network_request_error.message_with_context()) @@ -490,7 +490,7 @@ def attempt_request(attempt): if not isinstance(response_or_error_info, RequestErrorInfo): # log nothing here, presume that main request logging covers response = response_or_error_info - ThriftBackend._check_response_for_error(response, self._connection_uuid) + ThriftBackend._check_response_for_error(response, self._session_id_hex) return response error_info = response_or_error_info @@ -505,7 +505,7 @@ def _check_protocol_version(self, t_open_session_resp): "Error: expected server to use a protocol version >= " "SPARK_CLI_SERVICE_PROTOCOL_V2, " "instead got: {}".format(protocol_version), - connection_uuid=self._connection_uuid, + session_id_hex=self._session_id_hex, ) def _check_initial_namespace(self, catalog, schema, response): @@ -519,7 +519,7 @@ def _check_initial_namespace(self, catalog, schema, response): raise InvalidServerResponseError( "Setting initial namespace not supported by the DBR version, " "Please use a Databricks SQL endpoint or a cluster with DBR >= 9.0.", - connection_uuid=self._connection_uuid, + session_id_hex=self._session_id_hex, ) if catalog: @@ -527,7 +527,7 @@ def _check_initial_namespace(self, catalog, schema, response): raise InvalidServerResponseError( "Unexpected response from server: Trying to set initial catalog to {}, " + "but server does not support multiple catalogs.".format(catalog), # type: ignore - connection_uuid=self._connection_uuid, + session_id_hex=self._session_id_hex, ) def _check_session_configuration(self, session_configuration): @@ -542,7 +542,7 @@ def _check_session_configuration(self, session_configuration): TIMESTAMP_AS_STRING_CONFIG, session_configuration[TIMESTAMP_AS_STRING_CONFIG], ), - connection_uuid=self._connection_uuid, + session_id_hex=self._session_id_hex, ) def open_session(self, session_configuration, catalog, schema): @@ -573,7 +573,7 @@ def open_session(self, session_configuration, catalog, schema): response = self.make_request(self._client.OpenSession, open_session_req) self._check_initial_namespace(catalog, schema, response) self._check_protocol_version(response) - self._connection_uuid = ( + self._session_id_hex = ( self.handle_to_hex_id(response.sessionHandle) if response.sessionHandle else None @@ -602,7 +602,7 @@ def _check_command_not_in_error_or_closed_state( and self.guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": get_operations_resp.diagnosticInfo, }, - connection_uuid=self._connection_uuid, + session_id_hex=self._session_id_hex, ) else: raise ServerOperationError( @@ -612,7 +612,7 @@ def _check_command_not_in_error_or_closed_state( and self.guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": None, }, - connection_uuid=self._connection_uuid, + session_id_hex=self._session_id_hex, ) elif get_operations_resp.operationState == ttypes.TOperationState.CLOSED_STATE: raise DatabaseError( @@ -623,7 +623,7 @@ def _check_command_not_in_error_or_closed_state( "operation-id": op_handle and self.guid_to_hex_id(op_handle.operationId.guid) }, - connection_uuid=self._connection_uuid, + session_id_hex=self._session_id_hex, ) def _poll_for_status(self, op_handle): @@ -646,7 +646,7 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti else: raise OperationalError( "Unsupported TRowSet instance {}".format(t_row_set), - connection_uuid=self._connection_uuid, + session_id_hex=self._session_id_hex, ) return convert_decimals_in_arrow_table(arrow_table, description), num_rows @@ -655,7 +655,7 @@ def _get_metadata_resp(self, op_handle): return self.make_request(self._client.GetResultSetMetadata, req) @staticmethod - def _hive_schema_to_arrow_schema(t_table_schema, connection_uuid=None): + def _hive_schema_to_arrow_schema(t_table_schema, session_id_hex=None): def map_type(t_type_entry): if t_type_entry.primitiveEntry: return { @@ -687,7 +687,7 @@ def map_type(t_type_entry): # even for complex types raise OperationalError( "Thrift protocol error: t_type_entry not a primitiveEntry", - connection_uuid=connection_uuid, + session_id_hex=session_id_hex, ) def convert_col(t_column_desc): @@ -698,7 +698,7 @@ def convert_col(t_column_desc): return pyarrow.schema([convert_col(col) for col in t_table_schema.columns]) @staticmethod - def _col_to_description(col, connection_uuid=None): + def _col_to_description(col, session_id_hex=None): type_entry = col.typeDesc.types[0] if type_entry.primitiveEntry: @@ -708,7 +708,7 @@ def _col_to_description(col, connection_uuid=None): else: raise OperationalError( "Thrift protocol error: t_type_entry not a primitiveEntry", - connection_uuid=connection_uuid, + session_id_hex=session_id_hex, ) if type_entry.primitiveEntry.type == ttypes.TTypeId.DECIMAL_TYPE: @@ -722,7 +722,7 @@ def _col_to_description(col, connection_uuid=None): raise OperationalError( "Decimal type did not provide typeQualifier precision, scale in " "primitiveEntry {}".format(type_entry.primitiveEntry), - connection_uuid=connection_uuid, + session_id_hex=session_id_hex, ) else: precision, scale = None, None @@ -730,9 +730,9 @@ def _col_to_description(col, connection_uuid=None): return col.columnName, cleaned_type, None, None, precision, scale, None @staticmethod - def _hive_schema_to_description(t_table_schema, connection_uuid=None): + def _hive_schema_to_description(t_table_schema, session_id_hex=None): return [ - ThriftBackend._col_to_description(col, connection_uuid) + ThriftBackend._col_to_description(col, session_id_hex) for col in t_table_schema.columns ] @@ -754,7 +754,7 @@ def _results_message_to_execute_response(self, resp, operation_state): t_result_set_metadata_resp.resultFormat ] ), - connection_uuid=self._connection_uuid, + session_id_hex=self._session_id_hex, ) direct_results = resp.directResults has_been_closed_server_side = direct_results and direct_results.closeOperation @@ -765,14 +765,14 @@ def _results_message_to_execute_response(self, resp, operation_state): ) description = self._hive_schema_to_description( t_result_set_metadata_resp.schema, - self._connection_uuid, + self._session_id_hex, ) if pyarrow: schema_bytes = ( t_result_set_metadata_resp.arrowSchema or self._hive_schema_to_arrow_schema( - t_result_set_metadata_resp.schema, self._connection_uuid + t_result_set_metadata_resp.schema, self._session_id_hex ) .serialize() .to_pybytes() @@ -835,14 +835,14 @@ def get_execution_result(self, op_handle, cursor): has_more_rows = resp.hasMoreRows description = self._hive_schema_to_description( t_result_set_metadata_resp.schema, - self._connection_uuid, + self._session_id_hex, ) if pyarrow: schema_bytes = ( t_result_set_metadata_resp.arrowSchema or self._hive_schema_to_arrow_schema( - t_result_set_metadata_resp.schema, self._connection_uuid + t_result_set_metadata_resp.schema, self._session_id_hex ) .serialize() .to_pybytes() @@ -897,27 +897,27 @@ def get_query_state(self, op_handle) -> "TOperationState": return operation_state @staticmethod - def _check_direct_results_for_error(t_spark_direct_results, connection_uuid=None): + def _check_direct_results_for_error(t_spark_direct_results, session_id_hex=None): if t_spark_direct_results: if t_spark_direct_results.operationStatus: ThriftBackend._check_response_for_error( t_spark_direct_results.operationStatus, - connection_uuid, + session_id_hex, ) if t_spark_direct_results.resultSetMetadata: ThriftBackend._check_response_for_error( t_spark_direct_results.resultSetMetadata, - connection_uuid, + session_id_hex, ) if t_spark_direct_results.resultSet: ThriftBackend._check_response_for_error( t_spark_direct_results.resultSet, - connection_uuid, + session_id_hex, ) if t_spark_direct_results.closeOperation: ThriftBackend._check_response_for_error( t_spark_direct_results.closeOperation, - connection_uuid, + session_id_hex, ) def execute_command( @@ -1066,7 +1066,7 @@ def get_columns( def _handle_execute_response(self, resp, cursor): cursor.active_op_handle = resp.operationHandle - self._check_direct_results_for_error(resp.directResults, self._connection_uuid) + self._check_direct_results_for_error(resp.directResults, self._session_id_hex) final_operation_state = self._wait_until_command_done( resp.operationHandle, @@ -1077,7 +1077,7 @@ def _handle_execute_response(self, resp, cursor): def _handle_execute_response_async(self, resp, cursor): cursor.active_op_handle = resp.operationHandle - self._check_direct_results_for_error(resp.directResults, self._connection_uuid) + self._check_direct_results_for_error(resp.directResults, self._session_id_hex) def fetch_results( self, @@ -1112,7 +1112,7 @@ def fetch_results( "fetch_results failed due to inconsistency in the state between the client and the server. Expected results to start from {} but they instead start at {}, some result batches must have been skipped".format( expected_row_start_offset, resp.results.startRowOffset ), - connection_uuid=self._connection_uuid, + session_id_hex=self._session_id_hex, ) queue = ResultSetQueueFactory.build_queue( diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 975febd20..f89191ca4 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -32,14 +32,14 @@ def noop_telemetry_client(): @pytest.fixture def telemetry_client_setup(): """Fixture for TelemetryClient setup data.""" - connection_uuid = str(uuid.uuid4()) + session_id_hex = str(uuid.uuid4()) auth_provider = AccessTokenAuthProvider("test-token") host_url = "test-host" executor = MagicMock() client = TelemetryClient( telemetry_enabled=True, - connection_uuid=connection_uuid, + session_id_hex=session_id_hex, auth_provider=auth_provider, host_url=host_url, executor=executor, @@ -47,7 +47,7 @@ def telemetry_client_setup(): return { "client": client, - "connection_uuid": connection_uuid, + "session_id_hex": session_id_hex, "auth_provider": auth_provider, "host_url": host_url, "executor": executor, @@ -217,7 +217,7 @@ def test_send_telemetry_unauthenticated(self, mock_post, telemetry_client_setup) unauthenticated_client = TelemetryClient( telemetry_enabled=True, - connection_uuid=str(uuid.uuid4()), + session_id_hex=str(uuid.uuid4()), auth_provider=None, # No auth provider host_url=host_url, executor=executor, @@ -263,34 +263,34 @@ class TestTelemetrySystem: def test_initialize_telemetry_client_enabled(self, telemetry_system_reset): """Test initializing a telemetry client when telemetry is enabled.""" - connection_uuid = "test-uuid" + session_id_hex = "test-uuid" auth_provider = MagicMock() host_url = "test-host" initialize_telemetry_client( telemetry_enabled=True, - connection_uuid=connection_uuid, + session_id_hex=session_id_hex, auth_provider=auth_provider, host_url=host_url, ) - client = get_telemetry_client(connection_uuid) + client = get_telemetry_client(session_id_hex) assert isinstance(client, TelemetryClient) - assert client._connection_uuid == connection_uuid + assert client._session_id_hex == session_id_hex assert client._auth_provider == auth_provider assert client._host_url == host_url def test_initialize_telemetry_client_disabled(self, telemetry_system_reset): """Test initializing a telemetry client when telemetry is disabled.""" - connection_uuid = "test-uuid" + session_id_hex = "test-uuid" initialize_telemetry_client( telemetry_enabled=False, - connection_uuid=connection_uuid, + session_id_hex=session_id_hex, auth_provider=MagicMock(), host_url="test-host", ) - client = get_telemetry_client(connection_uuid) + client = get_telemetry_client(session_id_hex) assert client is NOOP_TELEMETRY_CLIENT def test_get_telemetry_client_nonexistent(self, telemetry_system_reset): @@ -300,21 +300,21 @@ def test_get_telemetry_client_nonexistent(self, telemetry_system_reset): def test_close_telemetry_client(self, telemetry_system_reset): """Test closing a telemetry client.""" - connection_uuid = "test-uuid" + session_id_hex = "test-uuid" auth_provider = MagicMock() host_url = "test-host" initialize_telemetry_client( telemetry_enabled=True, - connection_uuid=connection_uuid, + session_id_hex=session_id_hex, auth_provider=auth_provider, host_url=host_url, ) - client = get_telemetry_client(connection_uuid) + client = get_telemetry_client(session_id_hex) assert isinstance(client, TelemetryClient) - _remove_telemetry_client(connection_uuid) + _remove_telemetry_client(session_id_hex) - client = get_telemetry_client(connection_uuid) + client = get_telemetry_client(session_id_hex) assert client is NOOP_TELEMETRY_CLIENT \ No newline at end of file From e84143419a4ab67e88fcf317145fb74d0ed12a46 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Mon, 16 Jun 2025 15:43:52 +0530 Subject: [PATCH 21/48] added NotImplementedError to abstract class, added unit tests Signed-off-by: Sai Shree Pradhan --- .../sql/telemetry/telemetry_client.py | 6 +- tests/unit/test_telemetry.py | 156 +++++++++++++++++- 2 files changed, 158 insertions(+), 4 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index a918beb09..585945cc6 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -104,15 +104,15 @@ class BaseTelemetryClient(ABC): @abstractmethod def export_initial_telemetry_log(self, driver_connection_params, user_agent): - pass + raise NotImplementedError("Subclasses must implement export_initial_telemetry_log") @abstractmethod def export_failure_log(self, error_name, error_message): - pass + raise NotImplementedError("Subclasses must implement export_failure_log") @abstractmethod def close(self): - pass + raise NotImplementedError("Subclasses must implement close") # A single instance of the no-op client that can be reused diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index f89191ca4..b82780f4b 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -9,10 +9,13 @@ initialize_telemetry_client, get_telemetry_client, _remove_telemetry_client, + TelemetryHelper, + BaseTelemetryClient ) from databricks.sql.telemetry.models.enums import ( AuthMech, DatabricksClientType, + AuthFlow, ) from databricks.sql.telemetry.models.event import ( DriverConnectionParameters, @@ -20,6 +23,8 @@ ) from databricks.sql.auth.authenticators import ( AccessTokenAuthProvider, + DatabricksOAuthProvider, + ExternalAuthProvider, ) @@ -257,6 +262,72 @@ def test_close(self, telemetry_client_setup): client._flush.assert_called_once() + @patch("requests.post") + def test_telemetry_request_callback_success(self, mock_post, telemetry_client_setup): + """Test successful telemetry request callback.""" + client = telemetry_client_setup["client"] + + mock_response = MagicMock() + mock_response.status_code = 200 + + mock_future = MagicMock() + mock_future.result.return_value = mock_response + + client._telemetry_request_callback(mock_future) + + mock_future.result.assert_called_once() + + @patch("requests.post") + def test_telemetry_request_callback_failure(self, mock_post, telemetry_client_setup): + """Test telemetry request callback with failure""" + client = telemetry_client_setup["client"] + + # Test with non-200 status code + mock_response = MagicMock() + mock_response.status_code = 500 + future = MagicMock() + future.result.return_value = mock_response + client._telemetry_request_callback(future) + + # Test with exception + future = MagicMock() + future.result.side_effect = Exception("Test error") + client._telemetry_request_callback(future) + + def test_telemetry_client_exception_handling(self, telemetry_client_setup): + """Test exception handling in telemetry client methods.""" + client = telemetry_client_setup["client"] + + # Test export_initial_telemetry_log with exception + with patch.object(client, '_export_event', side_effect=Exception("Test error")): + # Should not raise exception + client.export_initial_telemetry_log(MagicMock(), "test-agent") + + # Test export_failure_log with exception + with patch.object(client, '_export_event', side_effect=Exception("Test error")): + # Should not raise exception + client.export_failure_log("TestError", "Test error message") + + # Test _send_telemetry with exception + with patch.object(client._executor, 'submit', side_effect=Exception("Test error")): + # Should not raise exception + client._send_telemetry([MagicMock()]) + + def test_send_telemetry_thread_pool_failure(self, telemetry_client_setup): + """Test handling of thread pool submission failure""" + client = telemetry_client_setup["client"] + client._executor.submit.side_effect = Exception("Thread pool error") + event = MagicMock() + client._send_telemetry([event]) + + def test_base_telemetry_client_abstract_methods(self): + """Test that BaseTelemetryClient cannot be instantiated without implementing all abstract methods""" + class TestBaseClient(BaseTelemetryClient): + pass + + with pytest.raises(TypeError): + TestBaseClient() # Can't instantiate abstract class + class TestTelemetrySystem: """Tests for the telemetry system functions.""" @@ -317,4 +388,87 @@ def test_close_telemetry_client(self, telemetry_system_reset): _remove_telemetry_client(session_id_hex) client = get_telemetry_client(session_id_hex) - assert client is NOOP_TELEMETRY_CLIENT \ No newline at end of file + assert client is NOOP_TELEMETRY_CLIENT + + @patch("databricks.sql.telemetry.telemetry_client._handle_unhandled_exception") + def test_global_exception_hook(self, mock_handle_exception, telemetry_system_reset): + """Test that global exception hook is installed and handles exceptions.""" + from databricks.sql.telemetry.telemetry_client import _install_exception_hook, _handle_unhandled_exception + + _install_exception_hook() + + test_exception = ValueError("Test exception") + _handle_unhandled_exception(type(test_exception), test_exception, None) + + mock_handle_exception.assert_called_once_with(type(test_exception), test_exception, None) + + +class TestTelemetryHelper: + """Tests for the TelemetryHelper class.""" + + def test_get_driver_system_configuration(self): + """Test getting driver system configuration.""" + config = TelemetryHelper.get_driver_system_configuration() + + assert isinstance(config.driver_name, str) + assert isinstance(config.driver_version, str) + assert isinstance(config.runtime_name, str) + assert isinstance(config.runtime_vendor, str) + assert isinstance(config.runtime_version, str) + assert isinstance(config.os_name, str) + assert isinstance(config.os_version, str) + assert isinstance(config.os_arch, str) + assert isinstance(config.locale_name, str) + assert isinstance(config.char_set_encoding, str) + + assert config.driver_name == "Databricks SQL Python Connector" + assert "Python" in config.runtime_name + assert config.runtime_vendor in ["CPython", "PyPy", "Jython", "IronPython"] + assert config.os_name in ["Darwin", "Linux", "Windows"] + + # Verify caching behavior + config2 = TelemetryHelper.get_driver_system_configuration() + assert config is config2 # Should return same instance + + def test_get_auth_mechanism(self): + """Test getting auth mechanism for different auth providers.""" + # Test PAT auth + pat_auth = AccessTokenAuthProvider("test-token") + assert TelemetryHelper.get_auth_mechanism(pat_auth) == AuthMech.PAT + + # Test OAuth auth + oauth_auth = MagicMock(spec=DatabricksOAuthProvider) + assert TelemetryHelper.get_auth_mechanism(oauth_auth) == AuthMech.DATABRICKS_OAUTH + + # Test External auth + external_auth = MagicMock(spec=ExternalAuthProvider) + assert TelemetryHelper.get_auth_mechanism(external_auth) == AuthMech.EXTERNAL_AUTH + + # Test None auth provider + assert TelemetryHelper.get_auth_mechanism(None) is None + + # Test unknown auth provider + unknown_auth = MagicMock() + assert TelemetryHelper.get_auth_mechanism(unknown_auth) == AuthMech.CLIENT_CERT + + def test_get_auth_flow(self): + """Test getting auth flow for different OAuth providers.""" + # Test OAuth with existing tokens + oauth_with_tokens = MagicMock(spec=DatabricksOAuthProvider) + oauth_with_tokens._access_token = "test-access-token" + oauth_with_tokens._refresh_token = "test-refresh-token" + assert TelemetryHelper.get_auth_flow(oauth_with_tokens) == AuthFlow.TOKEN_PASSTHROUGH + + # Test OAuth with browser-based auth + oauth_with_browser = MagicMock(spec=DatabricksOAuthProvider) + oauth_with_browser._access_token = None + oauth_with_browser._refresh_token = None + oauth_with_browser.oauth_manager = MagicMock() + assert TelemetryHelper.get_auth_flow(oauth_with_browser) == AuthFlow.BROWSER_BASED_AUTHENTICATION + + # Test non-OAuth provider + pat_auth = AccessTokenAuthProvider("test-token") + assert TelemetryHelper.get_auth_flow(pat_auth) is None + + # Test None auth provider + assert TelemetryHelper.get_auth_flow(None) is None \ No newline at end of file From 2f89266cd2d44745fcf4a528aa07b8b2050299a8 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Mon, 16 Jun 2025 15:45:42 +0530 Subject: [PATCH 22/48] formatting Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/telemetry/telemetry_client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 585945cc6..a918beb09 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -104,15 +104,15 @@ class BaseTelemetryClient(ABC): @abstractmethod def export_initial_telemetry_log(self, driver_connection_params, user_agent): - raise NotImplementedError("Subclasses must implement export_initial_telemetry_log") + pass @abstractmethod def export_failure_log(self, error_name, error_message): - raise NotImplementedError("Subclasses must implement export_failure_log") + pass @abstractmethod def close(self): - raise NotImplementedError("Subclasses must implement close") + pass # A single instance of the no-op client that can be reused From 5564bbb9a0c40123caf9c13058fbc7b030bea44a Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 10:31:00 +0530 Subject: [PATCH 23/48] added PEP-249 link, changed NoopTelemetryClient implementation Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/exc.py | 1 + .../sql/telemetry/telemetry_client.py | 32 ++++++++++++------- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index e7b2dad23..20a898999 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -7,6 +7,7 @@ logger = logging.getLogger(__name__) ### PEP-249 Mandated ### +# https://peps.python.org/pep-0249/#exceptions class Error(Exception): """Base class for DB-API2.0 exceptions. `message`: An optional user-friendly error message. It should be short, actionable and stable diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index a918beb09..06a6813c1 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -104,27 +104,37 @@ class BaseTelemetryClient(ABC): @abstractmethod def export_initial_telemetry_log(self, driver_connection_params, user_agent): - pass + raise NotImplementedError( + "Subclasses must implement export_initial_telemetry_log" + ) @abstractmethod def export_failure_log(self, error_name, error_message): - pass + raise NotImplementedError("Subclasses must implement export_failure_log") @abstractmethod + def close(self): + raise NotImplementedError("Subclasses must implement close") + + +class NoopTelemetryClient(BaseTelemetryClient): + """ + NoopTelemetryClient is a telemetry client that does not send any events to the server. + It is used when telemetry is disabled. + """ + + def export_initial_telemetry_log(self, driver_connection_params, user_agent): + pass + + def export_failure_log(self, error_name, error_message): + pass + def close(self): pass # A single instance of the no-op client that can be reused -NOOP_TELEMETRY_CLIENT = type( - "NoopTelemetryClient", - (BaseTelemetryClient,), - { - "export_initial_telemetry_log": lambda self, *args, **kwargs: None, - "export_failure_log": lambda self, *args, **kwargs: None, - "close": lambda self: None, - }, -)() +NOOP_TELEMETRY_CLIENT = NoopTelemetryClient() class TelemetryClient(BaseTelemetryClient): From 1e4e8cfb07dc2ecef5b73ba36bde487b4aaec965 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 12:10:28 +0530 Subject: [PATCH 24/48] removed unused import Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/exc.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index 20a898999..9ca662126 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -1,6 +1,5 @@ import json import logging -import traceback from databricks.sql.telemetry.telemetry_client import get_telemetry_client From 55b29bceafcc01541e77f9d69264d214656c60e2 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 14:45:43 +0530 Subject: [PATCH 25/48] made telemetry client close a module-level function Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/client.py | 3 +- .../sql/telemetry/telemetry_client.py | 6 ++-- tests/unit/test_telemetry.py | 30 +++++++++++++++++-- 3 files changed, 33 insertions(+), 6 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 23e4e38b1..9359c4272 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -56,6 +56,7 @@ TelemetryHelper, initialize_telemetry_client, get_telemetry_client, + close_telemetry_client, ) from databricks.sql.telemetry.models.enums import DatabricksClientType from databricks.sql.telemetry.models.event import ( @@ -471,7 +472,7 @@ def _close(self, close_cursors=True) -> None: self.open = False - self._telemetry_client.close() + close_telemetry_client(self.get_session_id_hex()) def commit(self): """No-op because Databricks does not support transactions""" diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 06a6813c1..1fd850dbc 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -300,7 +300,6 @@ def close(self): """Flush remaining events before closing""" logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex) self._flush() - _remove_telemetry_client(self._session_id_hex) # Module-level state @@ -392,13 +391,14 @@ def get_telemetry_client(session_id_hex): return NOOP_TELEMETRY_CLIENT -def _remove_telemetry_client(session_id_hex): +def close_telemetry_client(session_id_hex): """Remove the telemetry client for a specific connection""" global _initialized, _executor with _lock: if session_id_hex in _clients: logger.debug("Removing telemetry client for connection %s", session_id_hex) - _clients.pop(session_id_hex, None) + telemetry_client = _clients.pop(session_id_hex, None) + telemetry_client.close() # Shutdown executor if no more clients if not _clients and _executor: diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index b82780f4b..217f4cfaa 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -8,7 +8,7 @@ NOOP_TELEMETRY_CLIENT, initialize_telemetry_client, get_telemetry_client, - _remove_telemetry_client, + close_telemetry_client, TelemetryHelper, BaseTelemetryClient ) @@ -385,7 +385,33 @@ def test_close_telemetry_client(self, telemetry_system_reset): client = get_telemetry_client(session_id_hex) assert isinstance(client, TelemetryClient) - _remove_telemetry_client(session_id_hex) + client.close = MagicMock() + + close_telemetry_client(session_id_hex) + + client.close.assert_called_once() + + client = get_telemetry_client(session_id_hex) + assert client is NOOP_TELEMETRY_CLIENT + + def test_close_telemetry_client_noop(self, telemetry_system_reset): + """Test closing a no-op telemetry client.""" + session_id_hex = "test-uuid" + initialize_telemetry_client( + telemetry_enabled=False, + session_id_hex=session_id_hex, + auth_provider=MagicMock(), + host_url="test-host", + ) + + client = get_telemetry_client(session_id_hex) + assert client is NOOP_TELEMETRY_CLIENT + + client.close = MagicMock() + + close_telemetry_client(session_id_hex) + + client.close.assert_called_once() client = get_telemetry_client(session_id_hex) assert client is NOOP_TELEMETRY_CLIENT From 93bf170f46cda84f8658ef4149ed84328b14778e Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 15:15:47 +0530 Subject: [PATCH 26/48] unit tests verbose Signed-off-by: Sai Shree Pradhan --- .github/workflows/code-quality-checks.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index 462d22369..265f8a829 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -61,7 +61,7 @@ jobs: # run test suite #---------------------------------------------- - name: Run tests - run: poetry run python -m pytest tests/unit + run: poetry run python -m pytest tests/unit -v run-unit-tests-with-arrow: runs-on: ubuntu-latest strategy: @@ -112,7 +112,7 @@ jobs: # run test suite #---------------------------------------------- - name: Run tests - run: poetry run python -m pytest tests/unit + run: poetry run python -m pytest tests/unit -v check-linting: runs-on: ubuntu-latest strategy: From 45f5ccf0e97196b1232397756e1faecf67e98a9c Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 15:34:55 +0530 Subject: [PATCH 27/48] debug logs in unit tests Signed-off-by: Sai Shree Pradhan --- .github/workflows/code-quality-checks.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index 265f8a829..d78854671 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -61,7 +61,7 @@ jobs: # run test suite #---------------------------------------------- - name: Run tests - run: poetry run python -m pytest tests/unit -v + run: LOG_LEVEL=DEBUG poetry run python -m pytest tests/unit -v run-unit-tests-with-arrow: runs-on: ubuntu-latest strategy: @@ -112,7 +112,7 @@ jobs: # run test suite #---------------------------------------------- - name: Run tests - run: poetry run python -m pytest tests/unit -v + run: LOG_LEVEL=DEBUG poetry run python -m pytest tests/unit -v check-linting: runs-on: ubuntu-latest strategy: From 8ff1c1fa26dc3b0d62419e21faf2ffa03d136f9e Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 15:57:42 +0530 Subject: [PATCH 28/48] debug logs in unit tests Signed-off-by: Sai Shree Pradhan --- .github/workflows/code-quality-checks.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index d78854671..7a221c53a 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -61,7 +61,7 @@ jobs: # run test suite #---------------------------------------------- - name: Run tests - run: LOG_LEVEL=DEBUG poetry run python -m pytest tests/unit -v + run: poetry run python -m pytest tests/unit -v -s --log-cli-level=DEBUG run-unit-tests-with-arrow: runs-on: ubuntu-latest strategy: @@ -112,7 +112,7 @@ jobs: # run test suite #---------------------------------------------- - name: Run tests - run: LOG_LEVEL=DEBUG poetry run python -m pytest tests/unit -v + run: poetry run python -m pytest tests/unit -v -s --log-cli-level=DEBUG check-linting: runs-on: ubuntu-latest strategy: From 8bdd3243bb4394a7aa9a1ea930f27142b30b6d1e Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 18:40:34 +0530 Subject: [PATCH 29/48] removed ABC from mixin, added try/catch block around executor shutdown Signed-off-by: Sai Shree Pradhan --- .github/workflows/code-quality-checks.yml | 4 +- .../sql/telemetry/telemetry_client.py | 36 +++++++++------- src/databricks/sql/telemetry/utils.py | 3 +- tests/unit/test_client.py | 41 ++++++++++--------- tests/unit/test_telemetry.py | 24 +++++++---- 5 files changed, 61 insertions(+), 47 deletions(-) diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index 7a221c53a..e23072d3e 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -61,7 +61,7 @@ jobs: # run test suite #---------------------------------------------- - name: Run tests - run: poetry run python -m pytest tests/unit -v -s --log-cli-level=DEBUG + run: poetry run python -m pytest tests/unit run-unit-tests-with-arrow: runs-on: ubuntu-latest strategy: @@ -112,7 +112,7 @@ jobs: # run test suite #---------------------------------------------- - name: Run tests - run: poetry run python -m pytest tests/unit -v -s --log-cli-level=DEBUG + run: poetry run python -m pytest tests/unit -v -s check-linting: runs-on: ubuntu-latest strategy: diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 1fd850dbc..7d6f7b404 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -123,6 +123,13 @@ class NoopTelemetryClient(BaseTelemetryClient): It is used when telemetry is disabled. """ + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super(NoopTelemetryClient, cls).__new__(cls) + return cls._instance + def export_initial_telemetry_log(self, driver_connection_params, user_agent): pass @@ -133,10 +140,6 @@ def close(self): pass -# A single instance of the no-op client that can be reused -NOOP_TELEMETRY_CLIENT = NoopTelemetryClient() - - class TelemetryClient(BaseTelemetryClient): """ Telemetry client class that handles sending telemetry events in batches to the server. @@ -369,11 +372,11 @@ def initialize_telemetry_client( executor=_executor, ) else: - _clients[session_id_hex] = NOOP_TELEMETRY_CLIENT + _clients[session_id_hex] = NoopTelemetryClient() except Exception as e: logger.debug("Failed to initialize telemetry client: %s", e) # Fallback to NoopTelemetryClient to ensure connection doesn't fail - _clients[session_id_hex] = NOOP_TELEMETRY_CLIENT + _clients[session_id_hex] = NoopTelemetryClient() def get_telemetry_client(session_id_hex): @@ -385,10 +388,10 @@ def get_telemetry_client(session_id_hex): logger.error( "Telemetry client not initialized for connection %s", session_id_hex ) - return NOOP_TELEMETRY_CLIENT + return NoopTelemetryClient() except Exception as e: logger.debug("Failed to get telemetry client: %s", e) - return NOOP_TELEMETRY_CLIENT + return NoopTelemetryClient() def close_telemetry_client(session_id_hex): @@ -401,10 +404,13 @@ def close_telemetry_client(session_id_hex): telemetry_client.close() # Shutdown executor if no more clients - if not _clients and _executor: - logger.debug( - "No more telemetry clients, shutting down thread pool executor" - ) - _executor.shutdown(wait=True) - _executor = None - _initialized = False + try: + if not _clients and _executor: + logger.debug( + "No more telemetry clients, shutting down thread pool executor" + ) + _executor.shutdown(wait=True) + _executor = None + _initialized = False + except Exception as e: + logger.debug("Failed to shutdown thread pool executor: %s", e) diff --git a/src/databricks/sql/telemetry/utils.py b/src/databricks/sql/telemetry/utils.py index 6d95526b8..df7acf28c 100644 --- a/src/databricks/sql/telemetry/utils.py +++ b/src/databricks/sql/telemetry/utils.py @@ -1,10 +1,9 @@ import json from enum import Enum from dataclasses import asdict, is_dataclass -from abc import ABC -class JsonSerializableMixin(ABC): +class JsonSerializableMixin: """Mixin class to provide JSON serialization capabilities to dataclasses.""" def to_json(self) -> str: diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 588b0d70e..427a7d7bd 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -337,6 +337,7 @@ def test_negative_fetch_throws_exception(self): result_set.fetchmany(-1) def test_context_manager_closes_cursor(self): + print("hellow") mock_close = Mock() with client.Cursor(Mock(), Mock()) as cursor: cursor.close = mock_close @@ -351,29 +352,30 @@ def test_context_manager_closes_cursor(self): finally: cursor.close.assert_called() - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_context_manager_closes_connection(self, mock_client_class): - instance = mock_client_class.return_value + # @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + # def test_context_manager_closes_connection(self, mock_client_class): + # print("hellow1") + # instance = mock_client_class.return_value - mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b"\x22" - instance.open_session.return_value = mock_open_session_resp + # mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() + # mock_open_session_resp.sessionHandle.sessionId = b"\x22" + # instance.open_session.return_value = mock_open_session_resp - with databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) as connection: - pass + # with databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) as connection: + # pass - # Check the close session request has an id of x22 - close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b"\x22") + # # Check the close session request has an id of x22 + # close_session_id = instance.close_session.call_args[0][0].sessionId + # self.assertEqual(close_session_id, b"\x22") - connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - connection.close = Mock() - try: - with self.assertRaises(KeyboardInterrupt): - with connection: - raise KeyboardInterrupt("Simulated interrupt") - finally: - connection.close.assert_called() + # connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + # connection.close = Mock() + # try: + # with self.assertRaises(KeyboardInterrupt): + # with connection: + # raise KeyboardInterrupt("Simulated interrupt") + # finally: + # connection.close.assert_called() def dict_product(self, dicts): """ @@ -791,6 +793,7 @@ def test_cursor_context_manager_handles_exit_exception(self): def test_connection_close_handles_cursor_close_exception(self): """Test that _close handles exceptions from cursor.close() properly.""" + print("banana") cursors_closed = [] def mock_close_with_exception(): diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 217f4cfaa..84833dafd 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -5,7 +5,7 @@ from databricks.sql.telemetry.telemetry_client import ( TelemetryClient, - NOOP_TELEMETRY_CLIENT, + NoopTelemetryClient, initialize_telemetry_client, get_telemetry_client, close_telemetry_client, @@ -30,8 +30,8 @@ @pytest.fixture def noop_telemetry_client(): - """Fixture for NOOP_TELEMETRY_CLIENT.""" - return NOOP_TELEMETRY_CLIENT + """Fixture for NoopTelemetryClient.""" + return NoopTelemetryClient() @pytest.fixture @@ -79,7 +79,13 @@ def telemetry_system_reset(): class TestNoopTelemetryClient: - """Tests for the NOOP_TELEMETRY_CLIENT.""" + """Tests for the NoopTelemetryClient.""" + + def test_singleton(self): + """Test that NoopTelemetryClient is a singleton.""" + client1 = NoopTelemetryClient() + client2 = NoopTelemetryClient() + assert client1 is client2 def test_export_initial_telemetry_log(self, noop_telemetry_client): """Test that export_initial_telemetry_log does nothing.""" @@ -362,12 +368,12 @@ def test_initialize_telemetry_client_disabled(self, telemetry_system_reset): ) client = get_telemetry_client(session_id_hex) - assert client is NOOP_TELEMETRY_CLIENT + assert isinstance(client, NoopTelemetryClient) def test_get_telemetry_client_nonexistent(self, telemetry_system_reset): """Test getting a non-existent telemetry client.""" client = get_telemetry_client("nonexistent-uuid") - assert client is NOOP_TELEMETRY_CLIENT + assert isinstance(client, NoopTelemetryClient) def test_close_telemetry_client(self, telemetry_system_reset): """Test closing a telemetry client.""" @@ -392,7 +398,7 @@ def test_close_telemetry_client(self, telemetry_system_reset): client.close.assert_called_once() client = get_telemetry_client(session_id_hex) - assert client is NOOP_TELEMETRY_CLIENT + assert isinstance(client, NoopTelemetryClient) def test_close_telemetry_client_noop(self, telemetry_system_reset): """Test closing a no-op telemetry client.""" @@ -405,7 +411,7 @@ def test_close_telemetry_client_noop(self, telemetry_system_reset): ) client = get_telemetry_client(session_id_hex) - assert client is NOOP_TELEMETRY_CLIENT + assert isinstance(client, NoopTelemetryClient) client.close = MagicMock() @@ -414,7 +420,7 @@ def test_close_telemetry_client_noop(self, telemetry_system_reset): client.close.assert_called_once() client = get_telemetry_client(session_id_hex) - assert client is NOOP_TELEMETRY_CLIENT + assert isinstance(client, NoopTelemetryClient) @patch("databricks.sql.telemetry.telemetry_client._handle_unhandled_exception") def test_global_exception_hook(self, mock_handle_exception, telemetry_system_reset): From f99f7ea98f1385c07855f0e139e023da049a1cc8 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 20:30:53 +0530 Subject: [PATCH 30/48] checking stuff Signed-off-by: Sai Shree Pradhan --- .github/workflows/code-quality-checks.yml | 2 +- tests/unit/test_client.py | 49 +++++++++++------------ 2 files changed, 25 insertions(+), 26 deletions(-) diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index e23072d3e..f5b23dd28 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -75,7 +75,7 @@ jobs: uses: actions/checkout@v2 - name: Set up python ${{ matrix.python-version }} id: setup-python - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} #---------------------------------------------- diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 427a7d7bd..d64b06b5f 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -52,7 +52,6 @@ def new(cls): ) ThriftBackendMock.execute_command.return_value = MockTExecuteStatementResp - return ThriftBackendMock @classmethod @@ -352,30 +351,30 @@ def test_context_manager_closes_cursor(self): finally: cursor.close.assert_called() - # @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - # def test_context_manager_closes_connection(self, mock_client_class): - # print("hellow1") - # instance = mock_client_class.return_value - - # mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - # mock_open_session_resp.sessionHandle.sessionId = b"\x22" - # instance.open_session.return_value = mock_open_session_resp - - # with databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) as connection: - # pass - - # # Check the close session request has an id of x22 - # close_session_id = instance.close_session.call_args[0][0].sessionId - # self.assertEqual(close_session_id, b"\x22") - - # connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - # connection.close = Mock() - # try: - # with self.assertRaises(KeyboardInterrupt): - # with connection: - # raise KeyboardInterrupt("Simulated interrupt") - # finally: - # connection.close.assert_called() + @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + def test_context_manager_closes_connection(self, mock_client_class): + print("hellow1") + instance = mock_client_class.return_value + + mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() + mock_open_session_resp.sessionHandle.sessionId = b"\x22" + instance.open_session.return_value = mock_open_session_resp + + with databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) as connection: + pass + + # Check the close session request has an id of x22 + close_session_id = instance.close_session.call_args[0][0].sessionId + self.assertEqual(close_session_id, b"\x22") + + connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + connection.close = Mock() + try: + with self.assertRaises(KeyboardInterrupt): + with connection: + raise KeyboardInterrupt("Simulated interrupt") + finally: + connection.close.assert_called() def dict_product(self, dicts): """ From b972c8a36e29414d832ee5806d4ccedf9e98dce3 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 20:35:25 +0530 Subject: [PATCH 31/48] finding out --- .github/workflows/code-quality-checks.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index f5b23dd28..8f8a2278a 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -117,7 +117,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13t", "3.14"] steps: #---------------------------------------------- # check-out repo and set-up python From 7ca36363e3d5c324287def891f9a261189d4931a Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 20:37:20 +0530 Subject: [PATCH 32/48] finding out more --- .github/workflows/code-quality-checks.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index 8f8a2278a..158ac64a6 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -117,7 +117,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13t", "3.14"] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13t"] steps: #---------------------------------------------- # check-out repo and set-up python From 0ac8ed2d7ad3904c5a1312a4489a7b913ab35a74 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 20:49:31 +0530 Subject: [PATCH 33/48] more more finding out more nice --- src/databricks/sql/telemetry/telemetry_client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 7d6f7b404..9d1a38909 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -183,9 +183,9 @@ def _export_event(self, event): def _flush(self): """Flush the current batch of events to the server""" - with self._lock: - events_to_flush = self._events_batch.copy() - self._events_batch = [] + # with self._lock: + events_to_flush = self._events_batch.copy() + self._events_batch = [] if events_to_flush: logger.debug("Flushing %s telemetry events to server", len(events_to_flush)) From c457a0970d99604347d40478a7f1ef3f3e53674d Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 20:53:19 +0530 Subject: [PATCH 34/48] locks are useless anyways --- .../sql/telemetry/telemetry_client.py | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 9d1a38909..9176cdcaf 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -183,9 +183,9 @@ def _export_event(self, event): def _flush(self): """Flush the current batch of events to the server""" - # with self._lock: - events_to_flush = self._events_batch.copy() - self._events_batch = [] + with self._lock: + events_to_flush = self._events_batch.copy() + self._events_batch = [] if events_to_flush: logger.debug("Flushing %s telemetry events to server", len(events_to_flush)) @@ -397,20 +397,20 @@ def get_telemetry_client(session_id_hex): def close_telemetry_client(session_id_hex): """Remove the telemetry client for a specific connection""" global _initialized, _executor - with _lock: - if session_id_hex in _clients: - logger.debug("Removing telemetry client for connection %s", session_id_hex) - telemetry_client = _clients.pop(session_id_hex, None) - telemetry_client.close() - # Shutdown executor if no more clients - try: - if not _clients and _executor: - logger.debug( - "No more telemetry clients, shutting down thread pool executor" - ) - _executor.shutdown(wait=True) - _executor = None - _initialized = False - except Exception as e: - logger.debug("Failed to shutdown thread pool executor: %s", e) + if session_id_hex in _clients: + logger.debug("Removing telemetry client for connection %s", session_id_hex) + telemetry_client = _clients.pop(session_id_hex, None) + telemetry_client.close() + + # Shutdown executor if no more clients + try: + if not _clients and _executor: + logger.debug( + "No more telemetry clients, shutting down thread pool executor" + ) + _executor.shutdown(wait=True) + _executor = None + _initialized = False + except Exception as e: + logger.debug("Failed to shutdown thread pool executor: %s", e) From 5f07a84bcc54aa28c117a8fe5771a5302a019ea5 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 21:00:51 +0530 Subject: [PATCH 35/48] haha --- .../sql/telemetry/telemetry_client.py | 54 +++++++++---------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 9176cdcaf..754772235 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -317,15 +317,14 @@ def close(self): def _initialize(): """Initialize the telemetry system if not already initialized""" global _initialized, _executor - with _lock: - if not _initialized: - _clients.clear() - _executor = ThreadPoolExecutor(max_workers=10) - _install_exception_hook() - _initialized = True - logger.debug( - "Telemetry system initialized with thread pool (max_workers=10)" - ) + if not _initialized: + _clients.clear() + _executor = ThreadPoolExecutor(max_workers=10) + _install_exception_hook() + _initialized = True + logger.debug( + "Telemetry system initialized with thread pool (max_workers=10)" + ) def _install_exception_hook(): @@ -356,9 +355,8 @@ def initialize_telemetry_client( ): """Initialize a telemetry client for a specific connection if telemetry is enabled""" try: - _initialize() - with _lock: + _initialize() if session_id_hex not in _clients: logger.debug( "Creating new TelemetryClient for connection %s", session_id_hex @@ -371,8 +369,10 @@ def initialize_telemetry_client( host_url=host_url, executor=_executor, ) + print("i have initialized the telemetry client yes") else: _clients[session_id_hex] = NoopTelemetryClient() + print("i have initialized the noop client yes") except Exception as e: logger.debug("Failed to initialize telemetry client: %s", e) # Fallback to NoopTelemetryClient to ensure connection doesn't fail @@ -397,20 +397,20 @@ def get_telemetry_client(session_id_hex): def close_telemetry_client(session_id_hex): """Remove the telemetry client for a specific connection""" global _initialized, _executor + with _lock: + if session_id_hex in _clients: + logger.debug("Removing telemetry client for connection %s", session_id_hex) + telemetry_client = _clients.pop(session_id_hex, None) + telemetry_client.close() - if session_id_hex in _clients: - logger.debug("Removing telemetry client for connection %s", session_id_hex) - telemetry_client = _clients.pop(session_id_hex, None) - telemetry_client.close() - - # Shutdown executor if no more clients - try: - if not _clients and _executor: - logger.debug( - "No more telemetry clients, shutting down thread pool executor" - ) - _executor.shutdown(wait=True) - _executor = None - _initialized = False - except Exception as e: - logger.debug("Failed to shutdown thread pool executor: %s", e) + # Shutdown executor if no more clients + try: + if not _clients and _executor: + logger.debug( + "No more telemetry clients, shutting down thread pool executor" + ) + _executor.shutdown(wait=True) + _executor = None + _initialized = False + except Exception as e: + logger.debug("Failed to shutdown thread pool executor: %s", e) From 1115e2523f1419016aa41b2a710867eb488bacf7 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 21:07:03 +0530 Subject: [PATCH 36/48] normal --- tests/unit/test_client.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index d64b06b5f..981543552 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -497,17 +497,17 @@ def test_version_is_canonical(self): ) self.assertIsNotNone(re.match(canonical_version_re, version)) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_configuration_passthrough(self, mock_client_class): - mock_session_config = Mock() - databricks.sql.connect( - session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS - ) - - self.assertEqual( - mock_client_class.return_value.open_session.call_args[0][0], - mock_session_config, - ) + # @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + # def test_configuration_passthrough(self, mock_client_class): + # mock_session_config = Mock() + # databricks.sql.connect( + # session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS + # ) + + # self.assertEqual( + # mock_client_class.return_value.open_session.call_args[0][0], + # mock_session_config, + # ) @patch("%s.client.ThriftBackend" % PACKAGE_NAME) def test_initial_namespace_passthrough(self, mock_client_class): From de1ed87b6b4b17e9cadf0ac9b961bb4c2f730704 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 21:17:32 +0530 Subject: [PATCH 37/48] := looks like walrus horizontally --- src/databricks/sql/telemetry/telemetry_client.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 754772235..c194f9f90 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -398,9 +398,8 @@ def close_telemetry_client(session_id_hex): """Remove the telemetry client for a specific connection""" global _initialized, _executor with _lock: - if session_id_hex in _clients: + if (telemetry_client := _clients.pop(session_id_hex, None)) is not None: logger.debug("Removing telemetry client for connection %s", session_id_hex) - telemetry_client = _clients.pop(session_id_hex, None) telemetry_client.close() # Shutdown executor if no more clients From 554aeaf02ef36628797b99370464b15c31c56144 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 21:35:16 +0530 Subject: [PATCH 38/48] one more --- .../sql/telemetry/telemetry_client.py | 4 +++- tests/unit/test_client.py | 22 +++++++++---------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index c194f9f90..e36f05ea7 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -398,7 +398,9 @@ def close_telemetry_client(session_id_hex): """Remove the telemetry client for a specific connection""" global _initialized, _executor with _lock: - if (telemetry_client := _clients.pop(session_id_hex, None)) is not None: + # if (telemetry_client := _clients.pop(session_id_hex, None)) is not None: + if session_id_hex in _clients: + telemetry_client = _clients.pop(session_id_hex) logger.debug("Removing telemetry client for connection %s", session_id_hex) telemetry_client.close() diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 981543552..d64b06b5f 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -497,17 +497,17 @@ def test_version_is_canonical(self): ) self.assertIsNotNone(re.match(canonical_version_re, version)) - # @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - # def test_configuration_passthrough(self, mock_client_class): - # mock_session_config = Mock() - # databricks.sql.connect( - # session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS - # ) - - # self.assertEqual( - # mock_client_class.return_value.open_session.call_args[0][0], - # mock_session_config, - # ) + @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + def test_configuration_passthrough(self, mock_client_class): + mock_session_config = Mock() + databricks.sql.connect( + session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS + ) + + self.assertEqual( + mock_client_class.return_value.open_session.call_args[0][0], + mock_session_config, + ) @patch("%s.client.ThriftBackend" % PACKAGE_NAME) def test_initial_namespace_passthrough(self, mock_client_class): From fffac5f70ebfecfcb55fc330cc6ebeab67a2c5e0 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 21:38:48 +0530 Subject: [PATCH 39/48] walrus again --- src/databricks/sql/telemetry/telemetry_client.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index e36f05ea7..c194f9f90 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -398,9 +398,7 @@ def close_telemetry_client(session_id_hex): """Remove the telemetry client for a specific connection""" global _initialized, _executor with _lock: - # if (telemetry_client := _clients.pop(session_id_hex, None)) is not None: - if session_id_hex in _clients: - telemetry_client = _clients.pop(session_id_hex) + if (telemetry_client := _clients.pop(session_id_hex, None)) is not None: logger.debug("Removing telemetry client for connection %s", session_id_hex) telemetry_client.close() From b77208a8231fccec5b03590dc1d5cb6fbcfb5418 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 21:42:46 +0530 Subject: [PATCH 40/48] old stuff without walrus seems to fail --- src/databricks/sql/telemetry/telemetry_client.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index c194f9f90..e36f05ea7 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -398,7 +398,9 @@ def close_telemetry_client(session_id_hex): """Remove the telemetry client for a specific connection""" global _initialized, _executor with _lock: - if (telemetry_client := _clients.pop(session_id_hex, None)) is not None: + # if (telemetry_client := _clients.pop(session_id_hex, None)) is not None: + if session_id_hex in _clients: + telemetry_client = _clients.pop(session_id_hex) logger.debug("Removing telemetry client for connection %s", session_id_hex) telemetry_client.close() From 733c288e36854398a03baa57060493ac1cb6474a Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 21:44:45 +0530 Subject: [PATCH 41/48] manually do the walrussing --- src/databricks/sql/telemetry/telemetry_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index e36f05ea7..af32de489 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -399,8 +399,8 @@ def close_telemetry_client(session_id_hex): global _initialized, _executor with _lock: # if (telemetry_client := _clients.pop(session_id_hex, None)) is not None: - if session_id_hex in _clients: - telemetry_client = _clients.pop(session_id_hex) + telemetry_client = _clients.pop(session_id_hex, None) + if telemetry_client is not None: logger.debug("Removing telemetry client for connection %s", session_id_hex) telemetry_client.close() From ca8b9586e360e6a70de88657a76f952f0d3cdd71 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 23:36:57 +0530 Subject: [PATCH 42/48] change 3.13t, v2 Signed-off-by: Sai Shree Pradhan --- .github/workflows/code-quality-checks.yml | 6 +++--- src/databricks/sql/telemetry/telemetry_client.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index 158ac64a6..df6a0e169 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -61,7 +61,7 @@ jobs: # run test suite #---------------------------------------------- - name: Run tests - run: poetry run python -m pytest tests/unit + run: poetry run python -m pytest tests/unit run-unit-tests-with-arrow: runs-on: ubuntu-latest strategy: @@ -75,7 +75,7 @@ jobs: uses: actions/checkout@v2 - name: Set up python ${{ matrix.python-version }} id: setup-python - uses: actions/setup-python@v5 + uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} #---------------------------------------------- @@ -117,7 +117,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13t"] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] steps: #---------------------------------------------- # check-out repo and set-up python diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index af32de489..e36f05ea7 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -399,8 +399,8 @@ def close_telemetry_client(session_id_hex): global _initialized, _executor with _lock: # if (telemetry_client := _clients.pop(session_id_hex, None)) is not None: - telemetry_client = _clients.pop(session_id_hex, None) - if telemetry_client is not None: + if session_id_hex in _clients: + telemetry_client = _clients.pop(session_id_hex) logger.debug("Removing telemetry client for connection %s", session_id_hex) telemetry_client.close() From 3eabac9641b893927ca417e2ad2d8007914c4455 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 23:45:24 +0530 Subject: [PATCH 43/48] formatting, added walrus Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/telemetry/telemetry_client.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index e36f05ea7..5ff1a63d9 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -322,9 +322,7 @@ def _initialize(): _executor = ThreadPoolExecutor(max_workers=10) _install_exception_hook() _initialized = True - logger.debug( - "Telemetry system initialized with thread pool (max_workers=10)" - ) + logger.debug("Telemetry system initialized with thread pool (max_workers=10)") def _install_exception_hook(): @@ -398,9 +396,9 @@ def close_telemetry_client(session_id_hex): """Remove the telemetry client for a specific connection""" global _initialized, _executor with _lock: - # if (telemetry_client := _clients.pop(session_id_hex, None)) is not None: - if session_id_hex in _clients: - telemetry_client = _clients.pop(session_id_hex) + if (telemetry_client := _clients.pop(session_id_hex, None)) is not None: + # if session_id_hex in _clients: + # telemetry_client = _clients.pop(session_id_hex) logger.debug("Removing telemetry client for connection %s", session_id_hex) telemetry_client.close() From fb9ef43b3857625d73977ff98c70992cbc9dcc12 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 23:55:49 +0530 Subject: [PATCH 44/48] formatting Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/telemetry/telemetry_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 5ff1a63d9..c10dd4083 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -397,8 +397,8 @@ def close_telemetry_client(session_id_hex): global _initialized, _executor with _lock: if (telemetry_client := _clients.pop(session_id_hex, None)) is not None: - # if session_id_hex in _clients: - # telemetry_client = _clients.pop(session_id_hex) + # if session_id_hex in _clients: + # telemetry_client = _clients.pop(session_id_hex) logger.debug("Removing telemetry client for connection %s", session_id_hex) telemetry_client.close() From 1e795aa4f7a5bb2e12e79748a1f76fc368bc8645 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 23:59:53 +0530 Subject: [PATCH 45/48] removed walrus, removed test before stalling test Signed-off-by: Sai Shree Pradhan --- .../sql/telemetry/telemetry_client.py | 6 +-- tests/unit/test_client.py | 54 +++++++++---------- 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index c10dd4083..c7daff14f 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -396,9 +396,9 @@ def close_telemetry_client(session_id_hex): """Remove the telemetry client for a specific connection""" global _initialized, _executor with _lock: - if (telemetry_client := _clients.pop(session_id_hex, None)) is not None: - # if session_id_hex in _clients: - # telemetry_client = _clients.pop(session_id_hex) + # if (telemetry_client := _clients.pop(session_id_hex, None)) is not None: + if session_id_hex in _clients: + telemetry_client = _clients.pop(session_id_hex) logger.debug("Removing telemetry client for connection %s", session_id_hex) telemetry_client.close() diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index d64b06b5f..cc41e6c87 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -790,41 +790,41 @@ def test_cursor_context_manager_handles_exit_exception(self): cursor.close.assert_called_once() - def test_connection_close_handles_cursor_close_exception(self): - """Test that _close handles exceptions from cursor.close() properly.""" - print("banana") - cursors_closed = [] + # def test_connection_close_handles_cursor_close_exception(self): + # """Test that _close handles exceptions from cursor.close() properly.""" + # print("banana") + # cursors_closed = [] - def mock_close_with_exception(): - cursors_closed.append(1) - raise Exception("Test error during close") + # def mock_close_with_exception(): + # cursors_closed.append(1) + # raise Exception("Test error during close") - cursor1 = Mock() - cursor1.close = mock_close_with_exception + # cursor1 = Mock() + # cursor1.close = mock_close_with_exception - def mock_close_normal(): - cursors_closed.append(2) + # def mock_close_normal(): + # cursors_closed.append(2) - cursor2 = Mock() - cursor2.close = mock_close_normal + # cursor2 = Mock() + # cursor2.close = mock_close_normal - mock_backend = Mock() - mock_session_handle = Mock() + # mock_backend = Mock() + # mock_session_handle = Mock() - try: - for cursor in [cursor1, cursor2]: - try: - cursor.close() - except Exception: - pass + # try: + # for cursor in [cursor1, cursor2]: + # try: + # cursor.close() + # except Exception: + # pass - mock_backend.close_session(mock_session_handle) - except Exception as e: - self.fail(f"Connection close should handle exceptions: {e}") + # mock_backend.close_session(mock_session_handle) + # except Exception as e: + # self.fail(f"Connection close should handle exceptions: {e}") - self.assertEqual( - cursors_closed, [1, 2], "Both cursors should have close called" - ) + # self.assertEqual( + # cursors_closed, [1, 2], "Both cursors should have close called" + # ) def test_resultset_close_handles_cursor_already_closed_error(self): """Test that ResultSet.close() handles CursorAlreadyClosedError properly.""" From 2c293a5f66294bd0c5ded64825ec774eba000763 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Wed, 18 Jun 2025 09:27:03 +0530 Subject: [PATCH 46/48] changed order of stalling test Signed-off-by: Sai Shree Pradhan --- tests/unit/test_client.py | 56 +++++++++++++++++++-------------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index cc41e6c87..d3a0c9866 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -352,7 +352,7 @@ def test_context_manager_closes_cursor(self): cursor.close.assert_called() @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_context_manager_closes_connection(self, mock_client_class): + def test_a_context_manager_closes_connection(self, mock_client_class): print("hellow1") instance = mock_client_class.return_value @@ -790,41 +790,41 @@ def test_cursor_context_manager_handles_exit_exception(self): cursor.close.assert_called_once() - # def test_connection_close_handles_cursor_close_exception(self): - # """Test that _close handles exceptions from cursor.close() properly.""" - # print("banana") - # cursors_closed = [] + def test_connection_close_handles_cursor_close_exception(self): + """Test that _close handles exceptions from cursor.close() properly.""" + print("banana") + cursors_closed = [] - # def mock_close_with_exception(): - # cursors_closed.append(1) - # raise Exception("Test error during close") + def mock_close_with_exception(): + cursors_closed.append(1) + raise Exception("Test error during close") - # cursor1 = Mock() - # cursor1.close = mock_close_with_exception + cursor1 = Mock() + cursor1.close = mock_close_with_exception - # def mock_close_normal(): - # cursors_closed.append(2) + def mock_close_normal(): + cursors_closed.append(2) - # cursor2 = Mock() - # cursor2.close = mock_close_normal + cursor2 = Mock() + cursor2.close = mock_close_normal - # mock_backend = Mock() - # mock_session_handle = Mock() + mock_backend = Mock() + mock_session_handle = Mock() - # try: - # for cursor in [cursor1, cursor2]: - # try: - # cursor.close() - # except Exception: - # pass + try: + for cursor in [cursor1, cursor2]: + try: + cursor.close() + except Exception: + pass - # mock_backend.close_session(mock_session_handle) - # except Exception as e: - # self.fail(f"Connection close should handle exceptions: {e}") + mock_backend.close_session(mock_session_handle) + except Exception as e: + self.fail(f"Connection close should handle exceptions: {e}") - # self.assertEqual( - # cursors_closed, [1, 2], "Both cursors should have close called" - # ) + self.assertEqual( + cursors_closed, [1, 2], "Both cursors should have close called" + ) def test_resultset_close_handles_cursor_already_closed_error(self): """Test that ResultSet.close() handles CursorAlreadyClosedError properly.""" From d237255818ef9246e417430d86be07475cf0a786 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Wed, 18 Jun 2025 10:40:40 +0530 Subject: [PATCH 47/48] removed debugging, added TelemetryClientFactory Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/client.py | 10 +- src/databricks/sql/exc.py | 6 +- .../sql/telemetry/telemetry_client.py | 226 ++++++++++-------- tests/unit/test_client.py | 5 +- tests/unit/test_telemetry.py | 203 ++++++++-------- 5 files changed, 231 insertions(+), 219 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 9359c4272..26705f3f8 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -54,9 +54,7 @@ ) from databricks.sql.telemetry.telemetry_client import ( TelemetryHelper, - initialize_telemetry_client, - get_telemetry_client, - close_telemetry_client, + TelemetryClientFactory, ) from databricks.sql.telemetry.models.enums import DatabricksClientType from databricks.sql.telemetry.models.event import ( @@ -308,14 +306,14 @@ def read(self) -> Optional[OAuthToken]: kwargs.get("use_inline_params", False) ) - initialize_telemetry_client( + TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=self.telemetry_enabled, session_id_hex=self.get_session_id_hex(), auth_provider=auth_provider, host_url=self.host, ) - self._telemetry_client = get_telemetry_client( + self._telemetry_client = TelemetryClientFactory.get_telemetry_client( session_id_hex=self.get_session_id_hex() ) @@ -472,7 +470,7 @@ def _close(self, close_cursors=True) -> None: self.open = False - close_telemetry_client(self.get_session_id_hex()) + TelemetryClientFactory.close(self.get_session_id_hex()) def commit(self): """No-op because Databricks does not support transactions""" diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index 9ca662126..30fd6c26d 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -1,7 +1,7 @@ import json import logging -from databricks.sql.telemetry.telemetry_client import get_telemetry_client +from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory logger = logging.getLogger(__name__) @@ -22,7 +22,9 @@ def __init__( error_name = self.__class__.__name__ if session_id_hex: - telemetry_client = get_telemetry_client(session_id_hex) + telemetry_client = TelemetryClientFactory.get_telemetry_client( + session_id_hex + ) telemetry_client.export_failure_log(error_name, self.message) def __str__(self): diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index c7daff14f..f7fccf47a 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -305,111 +305,131 @@ def close(self): self._flush() -# Module-level state -_clients: Dict[str, BaseTelemetryClient] = {} -_executor: Optional[ThreadPoolExecutor] = None -_initialized: bool = False -_lock = threading.Lock() -_original_excepthook = None -_excepthook_installed = False - - -def _initialize(): - """Initialize the telemetry system if not already initialized""" - global _initialized, _executor - if not _initialized: - _clients.clear() - _executor = ThreadPoolExecutor(max_workers=10) - _install_exception_hook() - _initialized = True - logger.debug("Telemetry system initialized with thread pool (max_workers=10)") - - -def _install_exception_hook(): - """Install global exception handler for unhandled exceptions""" - global _excepthook_installed, _original_excepthook - if not _excepthook_installed: - _original_excepthook = sys.excepthook - sys.excepthook = _handle_unhandled_exception - _excepthook_installed = True - logger.debug("Global exception handler installed for telemetry") - - -def _handle_unhandled_exception(exc_type, exc_value, exc_traceback): - """Handle unhandled exceptions by sending telemetry and flushing thread pool""" - logger.debug("Handling unhandled exception: %s", exc_type.__name__) - - clients_to_close = list(_clients.values()) - for client in clients_to_close: - client.close() - - # Call the original exception handler to maintain normal behavior - if _original_excepthook: - _original_excepthook(exc_type, exc_value, exc_traceback) - - -def initialize_telemetry_client( - telemetry_enabled, session_id_hex, auth_provider, host_url -): - """Initialize a telemetry client for a specific connection if telemetry is enabled""" - try: - with _lock: - _initialize() - if session_id_hex not in _clients: - logger.debug( - "Creating new TelemetryClient for connection %s", session_id_hex - ) - if telemetry_enabled: - _clients[session_id_hex] = TelemetryClient( - telemetry_enabled=telemetry_enabled, - session_id_hex=session_id_hex, - auth_provider=auth_provider, - host_url=host_url, - executor=_executor, - ) - print("i have initialized the telemetry client yes") - else: - _clients[session_id_hex] = NoopTelemetryClient() - print("i have initialized the noop client yes") - except Exception as e: - logger.debug("Failed to initialize telemetry client: %s", e) - # Fallback to NoopTelemetryClient to ensure connection doesn't fail - _clients[session_id_hex] = NoopTelemetryClient() - - -def get_telemetry_client(session_id_hex): - """Get the telemetry client for a specific connection""" - try: - if session_id_hex in _clients: - return _clients[session_id_hex] - else: - logger.error( - "Telemetry client not initialized for connection %s", session_id_hex +class TelemetryClientFactory: + """ + Static factory class for creating and managing telemetry clients. + It uses a thread pool to handle asynchronous operations. + """ + + _clients: Dict[ + str, BaseTelemetryClient + ] = {} # Map of session_id_hex -> BaseTelemetryClient + _executor: Optional[ThreadPoolExecutor] = None + _initialized: bool = False + _lock = threading.Lock() # Thread safety for factory operations + _original_excepthook = None + _excepthook_installed = False + + @classmethod + def _initialize(cls): + """Initialize the factory if not already initialized""" + + if not cls._initialized: + cls._clients = {} + cls._executor = ThreadPoolExecutor( + max_workers=10 + ) # Thread pool for async operations TODO: Decide on max workers + cls._install_exception_hook() + cls._initialized = True + logger.debug( + "TelemetryClientFactory initialized with thread pool (max_workers=10)" ) - return NoopTelemetryClient() - except Exception as e: - logger.debug("Failed to get telemetry client: %s", e) - return NoopTelemetryClient() - - -def close_telemetry_client(session_id_hex): - """Remove the telemetry client for a specific connection""" - global _initialized, _executor - with _lock: - # if (telemetry_client := _clients.pop(session_id_hex, None)) is not None: - if session_id_hex in _clients: - telemetry_client = _clients.pop(session_id_hex) - logger.debug("Removing telemetry client for connection %s", session_id_hex) - telemetry_client.close() - - # Shutdown executor if no more clients + + @classmethod + def _install_exception_hook(cls): + """Install global exception handler for unhandled exceptions""" + if not cls._excepthook_installed: + cls._original_excepthook = sys.excepthook + sys.excepthook = cls._handle_unhandled_exception + cls._excepthook_installed = True + logger.debug("Global exception handler installed for telemetry") + + @classmethod + def _handle_unhandled_exception(cls, exc_type, exc_value, exc_traceback): + """Handle unhandled exceptions by sending telemetry and flushing thread pool""" + logger.debug("Handling unhandled exception: %s", exc_type.__name__) + + clients_to_close = list(cls._clients.values()) + for client in clients_to_close: + client.close() + + # Call the original exception handler to maintain normal behavior + if cls._original_excepthook: + cls._original_excepthook(exc_type, exc_value, exc_traceback) + + @staticmethod + def initialize_telemetry_client( + telemetry_enabled, + session_id_hex, + auth_provider, + host_url, + ): + """Initialize a telemetry client for a specific connection if telemetry is enabled""" try: - if not _clients and _executor: + + with TelemetryClientFactory._lock: + TelemetryClientFactory._initialize() + + if session_id_hex not in TelemetryClientFactory._clients: + logger.debug( + "Creating new TelemetryClient for connection %s", + session_id_hex, + ) + if telemetry_enabled: + TelemetryClientFactory._clients[ + session_id_hex + ] = TelemetryClient( + telemetry_enabled=telemetry_enabled, + session_id_hex=session_id_hex, + auth_provider=auth_provider, + host_url=host_url, + executor=TelemetryClientFactory._executor, + ) + else: + TelemetryClientFactory._clients[ + session_id_hex + ] = NoopTelemetryClient() + except Exception as e: + logger.debug("Failed to initialize telemetry client: %s", e) + # Fallback to NoopTelemetryClient to ensure connection doesn't fail + TelemetryClientFactory._clients[session_id_hex] = NoopTelemetryClient() + + @staticmethod + def get_telemetry_client(session_id_hex): + """Get the telemetry client for a specific connection""" + try: + if session_id_hex in TelemetryClientFactory._clients: + return TelemetryClientFactory._clients[session_id_hex] + else: + logger.error( + "Telemetry client not initialized for connection %s", + session_id_hex, + ) + return NoopTelemetryClient() + except Exception as e: + logger.debug("Failed to get telemetry client: %s", e) + return NoopTelemetryClient() + + @staticmethod + def close(session_id_hex): + """Close and remove the telemetry client for a specific connection""" + + with TelemetryClientFactory._lock: + if ( + telemetry_client := TelemetryClientFactory._clients.pop( + session_id_hex, None + ) + ) is not None: + logger.debug( + "Removing telemetry client for connection %s", session_id_hex + ) + telemetry_client.close() + + # Shutdown executor if no more clients + if not TelemetryClientFactory._clients and TelemetryClientFactory._executor: logger.debug( "No more telemetry clients, shutting down thread pool executor" ) - _executor.shutdown(wait=True) - _executor = None - _initialized = False - except Exception as e: - logger.debug("Failed to shutdown thread pool executor: %s", e) + TelemetryClientFactory._executor.shutdown(wait=True) + TelemetryClientFactory._executor = None + TelemetryClientFactory._initialized = False diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index d3a0c9866..f9206dc27 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -336,7 +336,6 @@ def test_negative_fetch_throws_exception(self): result_set.fetchmany(-1) def test_context_manager_closes_cursor(self): - print("hellow") mock_close = Mock() with client.Cursor(Mock(), Mock()) as cursor: cursor.close = mock_close @@ -352,8 +351,7 @@ def test_context_manager_closes_cursor(self): cursor.close.assert_called() @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_a_context_manager_closes_connection(self, mock_client_class): - print("hellow1") + def test_context_manager_closes_connection(self, mock_client_class): instance = mock_client_class.return_value mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() @@ -792,7 +790,6 @@ def test_cursor_context_manager_handles_exit_exception(self): def test_connection_close_handles_cursor_close_exception(self): """Test that _close handles exceptions from cursor.close() properly.""" - print("banana") cursors_closed = [] def mock_close_with_exception(): diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 84833dafd..699480bbe 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -6,9 +6,7 @@ from databricks.sql.telemetry.telemetry_client import ( TelemetryClient, NoopTelemetryClient, - initialize_telemetry_client, - get_telemetry_client, - close_telemetry_client, + TelemetryClientFactory, TelemetryHelper, BaseTelemetryClient ) @@ -63,19 +61,18 @@ def telemetry_client_setup(): def telemetry_system_reset(): """Fixture to reset telemetry system state before each test.""" # Reset the static state before each test - from databricks.sql.telemetry.telemetry_client import _clients, _executor, _initialized - _clients.clear() - if _executor: - _executor.shutdown(wait=True) - _executor = None - _initialized = False + TelemetryClientFactory._clients.clear() + if TelemetryClientFactory._executor: + TelemetryClientFactory._executor.shutdown(wait=True) + TelemetryClientFactory._executor = None + TelemetryClientFactory._initialized = False yield # Cleanup after test if needed - _clients.clear() - if _executor: - _executor.shutdown(wait=True) - _executor = None - _initialized = False + TelemetryClientFactory._clients.clear() + if TelemetryClientFactory._executor: + TelemetryClientFactory._executor.shutdown(wait=True) + TelemetryClientFactory._executor = None + TelemetryClientFactory._initialized = False class TestNoopTelemetryClient: @@ -335,6 +332,77 @@ class TestBaseClient(BaseTelemetryClient): TestBaseClient() # Can't instantiate abstract class +class TestTelemetryHelper: + """Tests for the TelemetryHelper class.""" + + def test_get_driver_system_configuration(self): + """Test getting driver system configuration.""" + config = TelemetryHelper.get_driver_system_configuration() + + assert isinstance(config.driver_name, str) + assert isinstance(config.driver_version, str) + assert isinstance(config.runtime_name, str) + assert isinstance(config.runtime_vendor, str) + assert isinstance(config.runtime_version, str) + assert isinstance(config.os_name, str) + assert isinstance(config.os_version, str) + assert isinstance(config.os_arch, str) + assert isinstance(config.locale_name, str) + assert isinstance(config.char_set_encoding, str) + + assert config.driver_name == "Databricks SQL Python Connector" + assert "Python" in config.runtime_name + assert config.runtime_vendor in ["CPython", "PyPy", "Jython", "IronPython"] + assert config.os_name in ["Darwin", "Linux", "Windows"] + + # Verify caching behavior + config2 = TelemetryHelper.get_driver_system_configuration() + assert config is config2 # Should return same instance + + def test_get_auth_mechanism(self): + """Test getting auth mechanism for different auth providers.""" + # Test PAT auth + pat_auth = AccessTokenAuthProvider("test-token") + assert TelemetryHelper.get_auth_mechanism(pat_auth) == AuthMech.PAT + + # Test OAuth auth + oauth_auth = MagicMock(spec=DatabricksOAuthProvider) + assert TelemetryHelper.get_auth_mechanism(oauth_auth) == AuthMech.DATABRICKS_OAUTH + + # Test External auth + external_auth = MagicMock(spec=ExternalAuthProvider) + assert TelemetryHelper.get_auth_mechanism(external_auth) == AuthMech.EXTERNAL_AUTH + + # Test None auth provider + assert TelemetryHelper.get_auth_mechanism(None) is None + + # Test unknown auth provider + unknown_auth = MagicMock() + assert TelemetryHelper.get_auth_mechanism(unknown_auth) == AuthMech.CLIENT_CERT + + def test_get_auth_flow(self): + """Test getting auth flow for different OAuth providers.""" + # Test OAuth with existing tokens + oauth_with_tokens = MagicMock(spec=DatabricksOAuthProvider) + oauth_with_tokens._access_token = "test-access-token" + oauth_with_tokens._refresh_token = "test-refresh-token" + assert TelemetryHelper.get_auth_flow(oauth_with_tokens) == AuthFlow.TOKEN_PASSTHROUGH + + # Test OAuth with browser-based auth + oauth_with_browser = MagicMock(spec=DatabricksOAuthProvider) + oauth_with_browser._access_token = None + oauth_with_browser._refresh_token = None + oauth_with_browser.oauth_manager = MagicMock() + assert TelemetryHelper.get_auth_flow(oauth_with_browser) == AuthFlow.BROWSER_BASED_AUTHENTICATION + + # Test non-OAuth provider + pat_auth = AccessTokenAuthProvider("test-token") + assert TelemetryHelper.get_auth_flow(pat_auth) is None + + # Test None auth provider + assert TelemetryHelper.get_auth_flow(None) is None + + class TestTelemetrySystem: """Tests for the telemetry system functions.""" @@ -344,14 +412,14 @@ def test_initialize_telemetry_client_enabled(self, telemetry_system_reset): auth_provider = MagicMock() host_url = "test-host" - initialize_telemetry_client( + TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, session_id_hex=session_id_hex, auth_provider=auth_provider, host_url=host_url, ) - client = get_telemetry_client(session_id_hex) + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) assert isinstance(client, TelemetryClient) assert client._session_id_hex == session_id_hex assert client._auth_provider == auth_provider @@ -360,19 +428,19 @@ def test_initialize_telemetry_client_enabled(self, telemetry_system_reset): def test_initialize_telemetry_client_disabled(self, telemetry_system_reset): """Test initializing a telemetry client when telemetry is disabled.""" session_id_hex = "test-uuid" - initialize_telemetry_client( + TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=False, session_id_hex=session_id_hex, auth_provider=MagicMock(), host_url="test-host", ) - client = get_telemetry_client(session_id_hex) + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) assert isinstance(client, NoopTelemetryClient) def test_get_telemetry_client_nonexistent(self, telemetry_system_reset): """Test getting a non-existent telemetry client.""" - client = get_telemetry_client("nonexistent-uuid") + client = TelemetryClientFactory.get_telemetry_client("nonexistent-uuid") assert isinstance(client, NoopTelemetryClient) def test_close_telemetry_client(self, telemetry_system_reset): @@ -381,126 +449,53 @@ def test_close_telemetry_client(self, telemetry_system_reset): auth_provider = MagicMock() host_url = "test-host" - initialize_telemetry_client( + TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, session_id_hex=session_id_hex, auth_provider=auth_provider, host_url=host_url, ) - client = get_telemetry_client(session_id_hex) + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) assert isinstance(client, TelemetryClient) client.close = MagicMock() - close_telemetry_client(session_id_hex) + TelemetryClientFactory.close(session_id_hex) client.close.assert_called_once() - client = get_telemetry_client(session_id_hex) + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) assert isinstance(client, NoopTelemetryClient) def test_close_telemetry_client_noop(self, telemetry_system_reset): """Test closing a no-op telemetry client.""" session_id_hex = "test-uuid" - initialize_telemetry_client( + TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=False, session_id_hex=session_id_hex, auth_provider=MagicMock(), host_url="test-host", ) - client = get_telemetry_client(session_id_hex) + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) assert isinstance(client, NoopTelemetryClient) client.close = MagicMock() - close_telemetry_client(session_id_hex) + TelemetryClientFactory.close(session_id_hex) client.close.assert_called_once() - client = get_telemetry_client(session_id_hex) + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) assert isinstance(client, NoopTelemetryClient) - @patch("databricks.sql.telemetry.telemetry_client._handle_unhandled_exception") + @patch("databricks.sql.telemetry.telemetry_client.TelemetryClientFactory._handle_unhandled_exception") def test_global_exception_hook(self, mock_handle_exception, telemetry_system_reset): """Test that global exception hook is installed and handles exceptions.""" - from databricks.sql.telemetry.telemetry_client import _install_exception_hook, _handle_unhandled_exception - - _install_exception_hook() + TelemetryClientFactory._install_exception_hook() test_exception = ValueError("Test exception") - _handle_unhandled_exception(type(test_exception), test_exception, None) + TelemetryClientFactory._handle_unhandled_exception(type(test_exception), test_exception, None) - mock_handle_exception.assert_called_once_with(type(test_exception), test_exception, None) - - -class TestTelemetryHelper: - """Tests for the TelemetryHelper class.""" - - def test_get_driver_system_configuration(self): - """Test getting driver system configuration.""" - config = TelemetryHelper.get_driver_system_configuration() - - assert isinstance(config.driver_name, str) - assert isinstance(config.driver_version, str) - assert isinstance(config.runtime_name, str) - assert isinstance(config.runtime_vendor, str) - assert isinstance(config.runtime_version, str) - assert isinstance(config.os_name, str) - assert isinstance(config.os_version, str) - assert isinstance(config.os_arch, str) - assert isinstance(config.locale_name, str) - assert isinstance(config.char_set_encoding, str) - - assert config.driver_name == "Databricks SQL Python Connector" - assert "Python" in config.runtime_name - assert config.runtime_vendor in ["CPython", "PyPy", "Jython", "IronPython"] - assert config.os_name in ["Darwin", "Linux", "Windows"] - - # Verify caching behavior - config2 = TelemetryHelper.get_driver_system_configuration() - assert config is config2 # Should return same instance - - def test_get_auth_mechanism(self): - """Test getting auth mechanism for different auth providers.""" - # Test PAT auth - pat_auth = AccessTokenAuthProvider("test-token") - assert TelemetryHelper.get_auth_mechanism(pat_auth) == AuthMech.PAT - - # Test OAuth auth - oauth_auth = MagicMock(spec=DatabricksOAuthProvider) - assert TelemetryHelper.get_auth_mechanism(oauth_auth) == AuthMech.DATABRICKS_OAUTH - - # Test External auth - external_auth = MagicMock(spec=ExternalAuthProvider) - assert TelemetryHelper.get_auth_mechanism(external_auth) == AuthMech.EXTERNAL_AUTH - - # Test None auth provider - assert TelemetryHelper.get_auth_mechanism(None) is None - - # Test unknown auth provider - unknown_auth = MagicMock() - assert TelemetryHelper.get_auth_mechanism(unknown_auth) == AuthMech.CLIENT_CERT - - def test_get_auth_flow(self): - """Test getting auth flow for different OAuth providers.""" - # Test OAuth with existing tokens - oauth_with_tokens = MagicMock(spec=DatabricksOAuthProvider) - oauth_with_tokens._access_token = "test-access-token" - oauth_with_tokens._refresh_token = "test-refresh-token" - assert TelemetryHelper.get_auth_flow(oauth_with_tokens) == AuthFlow.TOKEN_PASSTHROUGH - - # Test OAuth with browser-based auth - oauth_with_browser = MagicMock(spec=DatabricksOAuthProvider) - oauth_with_browser._access_token = None - oauth_with_browser._refresh_token = None - oauth_with_browser.oauth_manager = MagicMock() - assert TelemetryHelper.get_auth_flow(oauth_with_browser) == AuthFlow.BROWSER_BASED_AUTHENTICATION - - # Test non-OAuth provider - pat_auth = AccessTokenAuthProvider("test-token") - assert TelemetryHelper.get_auth_flow(pat_auth) is None - - # Test None auth provider - assert TelemetryHelper.get_auth_flow(None) is None \ No newline at end of file + mock_handle_exception.assert_called_once_with(type(test_exception), test_exception, None) \ No newline at end of file From f101b198d7e23b51f6758fbd0869c658910e9a65 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Wed, 18 Jun 2025 10:44:09 +0530 Subject: [PATCH 48/48] remove more debugging Signed-off-by: Sai Shree Pradhan --- .github/workflows/code-quality-checks.yml | 2 +- tests/unit/test_client.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index df6a0e169..462d22369 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -112,7 +112,7 @@ jobs: # run test suite #---------------------------------------------- - name: Run tests - run: poetry run python -m pytest tests/unit -v -s + run: poetry run python -m pytest tests/unit check-linting: runs-on: ubuntu-latest strategy: diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index f9206dc27..588b0d70e 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -52,6 +52,7 @@ def new(cls): ) ThriftBackendMock.execute_command.return_value = MockTExecuteStatementResp + return ThriftBackendMock @classmethod