diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py
index dbf4fa0a2..31df063f3 100755
--- a/src/databricks/sql/client.py
+++ b/src/databricks/sql/client.py
@@ -21,7 +21,10 @@
     InterfaceError,
     NotSupportedError,
     ProgrammingError,
+    AuthenticationError,
+    ConnectionError,
 )
+from urllib3.exceptions import MaxRetryError
 from databricks.sql.thrift_api.TCLIService import ttypes
 from databricks.sql.thrift_backend import ThriftBackend
 from databricks.sql.utils import (
@@ -242,9 +245,18 @@ def read(self) -> Optional[OAuthToken]:
         self.disable_pandas = kwargs.get("_disable_pandas", False)
         self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True)
 
-        auth_provider = get_python_sql_connector_auth_provider(
-            server_hostname, **kwargs
-        )
+        try:
+            auth_provider = get_python_sql_connector_auth_provider(
+                server_hostname, **kwargs
+            )
+        except Exception as e:
+            raise AuthenticationError(
+                message=f"Failed to create authentication provider: {str(e)}",
+                host_url=server_hostname,
+                http_path=http_path,
+                port=self.port,
+                original_exception=e,
+            ) from e
 
         self.server_telemetry_enabled = True
         self.client_telemetry_enabled = kwargs.get("enable_telemetry", False)
@@ -282,20 +294,34 @@ def read(self) -> Optional[OAuthToken]:
             tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"),
         )
 
-        self.thrift_backend = ThriftBackend(
-            self.host,
-            self.port,
-            http_path,
-            (http_headers or []) + base_headers,
-            auth_provider,
-            ssl_options=self._ssl_options,
-            _use_arrow_native_complex_types=_use_arrow_native_complex_types,
-            **kwargs,
-        )
+        try:
+            self.thrift_backend = ThriftBackend(
+                self.host,
+                self.port,
+                http_path,
+                (http_headers or []) + base_headers,
+                auth_provider,
+                ssl_options=self._ssl_options,
+                _use_arrow_native_complex_types=_use_arrow_native_complex_types,
+                **kwargs,
+            )
+
+            self._open_session_resp = self.thrift_backend.open_session(
+                session_configuration, catalog, schema
+            )
+        # Request/retry errors are re-raised as-is rather than wrapped below.
+        except (RequestError, MaxRetryError, MaxRetryDurationError):
+            raise
+        except Exception as e:
+            raise ConnectionError(
+                message=f"Failed to establish connection: {str(e)}",
+                host_url=self.host,
+                http_path=http_path,
+                port=self.port,
+                user_agent=useragent_header,
+                original_exception=e,
+            ) from e
 
-        self._open_session_resp = self.thrift_backend.open_session(
-            session_configuration, catalog, schema
-        )
         self._session_handle = self._open_session_resp.sessionHandle
         self.protocol_version = self.get_protocol_version(self._open_session_resp)
         self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True)
diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py
index 30fd6c26d..d3a175e26 100644
--- a/src/databricks/sql/exc.py
+++ b/src/databricks/sql/exc.py
@@ -22,10 +22,30 @@ def __init__(
 
         error_name = self.__class__.__name__
         if session_id_hex:
+            # Normal case: we have a session, send to regular telemetry client
             telemetry_client = TelemetryClientFactory.get_telemetry_client(
                 session_id_hex
             )
             telemetry_client.export_failure_log(error_name, self.message)
+        elif (
+            isinstance(self, (ConnectionError, AuthenticationError))
+            and "host_url" in self.context
+        ):
+            # Connection error case: no session but we should still send telemetry
+            self._send_connection_error_telemetry(error_name)
+
+    def _send_connection_error_telemetry(self, error_name):
+        """Send connection error telemetry to unauthenticated endpoint"""
+
+        TelemetryClientFactory.send_connection_error_telemetry(
+            error_name=error_name,
+            # __str__ returns self.message, so str(self) is no usable fallback here
+            error_message=self.message or error_name,
+            host_url=self.context["host_url"],
+            http_path=self.context.get("http_path", ""),
+            port=self.context.get("port", 443),
+            user_agent=self.context.get("user_agent"),
+        )
 
     def __str__(self):
         return self.message
@@ -126,3 +146,57 @@ class SessionAlreadyClosedError(RequestError):
 
 class CursorAlreadyClosedError(RequestError):
     """Thrown if CancelOperation receives a code 404. ThriftBackend should gracefully proceed as this is expected."""
+
+
+class ConnectionError(OperationalError):
+    """Thrown when connection to Databricks fails during initial setup
+
+    NOTE: this name shadows Python's builtin ``ConnectionError`` in every
+    module that imports it from here.
+
+    When `host_url` is given, the connection details are stored in the error
+    context under the keys `host_url`, `http_path`, `port`, `user_agent` and
+    `original_exception` (stringified, or None when absent).
+    """
+
+    def __init__(
+        self,
+        message=None,
+        host_url=None,
+        http_path=None,
+        port=443,
+        user_agent=None,
+        original_exception=None,
+        **kwargs
+    ):
+        # Pop (not get) `context`: it is passed to super() explicitly below, so
+        # leaving it in kwargs would raise "got multiple values for keyword
+        # argument 'context'". Copy it so the caller's dict is never mutated.
+        context = dict(kwargs.pop("context", None) or {})
+        if host_url:
+            context.update(
+                {
+                    "host_url": host_url,
+                    "http_path": http_path or "",
+                    "port": port,
+                    "user_agent": user_agent,
+                    "original_exception": str(original_exception)
+                    if original_exception
+                    else None,
+                }
+            )
+
+        super().__init__(message=message, context=context, **kwargs)
+
+
+class AuthenticationError(ConnectionError):
+    """Thrown when authentication to Databricks fails"""
+
+    def __init__(self, message=None, auth_type=None, **kwargs):
+        # Take ownership of `context` and hand exactly one copy to the parent.
+        context = dict(kwargs.pop("context", None) or {})
+        if auth_type:
+            context["auth_type"] = auth_type
+        kwargs["context"] = context
+
+        super().__init__(message=message, **kwargs)
diff --git a/src/databricks/sql/telemetry/models/event.py b/src/databricks/sql/telemetry/models/event.py
index f5496deec..a155c7597 100644
--- a/src/databricks/sql/telemetry/models/event.py
+++ b/src/databricks/sql/telemetry/models/event.py
@@ -149,9 +149,9 @@ class TelemetryEvent(JsonSerializableMixin):
         operation_latency_ms (Optional[int]): Operation latency in milliseconds
     """
 
-    session_id: str
     system_configuration: DriverSystemConfiguration
     driver_connection_params: DriverConnectionParameters
+    session_id: Optional[str] = None
     sql_statement_id: Optional[str] = None
     auth_type: Optional[str] = None
     vol_operation: Optional[DriverVolumeOperation] = None
diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py
index 936a07683..bbfdbe0d9 100644
--- a/src/databricks/sql/telemetry/telemetry_client.py
+++ b/src/databricks/sql/telemetry/telemetry_client.py
@@ -9,6 +9,8 @@
     TelemetryEvent,
     DriverSystemConfiguration,
     DriverErrorInfo,
+    DriverConnectionParameters,
+    HostDetails,
 )
 from databricks.sql.telemetry.models.frontend_logs import (
     TelemetryFrontendLog,
@@ -16,7 +18,11 @@
     FrontendLogContext,
     FrontendLogEntry,
 )
-from databricks.sql.telemetry.models.enums import AuthMech, AuthFlow
+from databricks.sql.telemetry.models.enums import (
+    AuthMech,
+    AuthFlow,
+    DatabricksClientType,
+)
 from databricks.sql.auth.authenticators import (
     AccessTokenAuthProvider,
     DatabricksOAuthProvider,
@@ -437,3 +443,82 @@ def close(session_id_hex):
                     logger.debug("Failed to shutdown thread pool executor: %s", e)
                 TelemetryClientFactory._executor = None
                 TelemetryClientFactory._initialized = False
+
+    @staticmethod
+    def send_connection_error_telemetry(
+        error_name: str,
+        error_message: str,
+        host_url: str,
+        http_path: str,
+        port: int = 443,
+        user_agent: Optional[str] = None,
+    ):
+        """Send error telemetry when connection creation fails, without requiring a session
+
+        Builds one frontend log entry for the failure and POSTs it synchronously
+        to https://<host_url>/telemetry-unauth with a 5 second timeout. `port`
+        is only recorded in the event's host details; the URL does not use it.
+        Any exception raised along the way is logged at debug level and swallowed.
+        """
+        try:
+            logger.debug("Sending connection error telemetry for host: %s", host_url)
+
+            # Initialize factory if needed (with proper locking)
+            with TelemetryClientFactory._lock:
+                TelemetryClientFactory._initialize()
+
+            # Create driver connection params for the failed connection
+            driver_connection_params = DriverConnectionParameters(
+                http_path=http_path,
+                mode=DatabricksClientType.THRIFT,
+                host_info=HostDetails(host_url=host_url, port=port),
+            )
+
+            error_info = DriverErrorInfo(
+                error_name=error_name, stack_trace=error_message
+            )
+
+            telemetry_frontend_log = TelemetryFrontendLog(
+                frontend_log_event_id=str(uuid.uuid4()),
+                context=FrontendLogContext(
+                    client_context=TelemetryClientContext(
+                        timestamp_millis=int(time.time() * 1000),
+                        user_agent=user_agent or "PyDatabricksSqlConnector",
+                    )
+                ),
+                entry=FrontendLogEntry(
+                    sql_driver_log=TelemetryEvent(
+                        system_configuration=TelemetryHelper.get_driver_system_configuration(),
+                        driver_connection_params=driver_connection_params,
+                        error_info=error_info,
+                    )
+                ),
+            )
+
+            # Send to unauthenticated endpoint since we don't have working auth
+            request = {
+                "uploadTime": int(time.time() * 1000),
+                "items": [],
+                "protoLogs": [telemetry_frontend_log.to_json()],
+            }
+
+            url = f"https://{host_url}/telemetry-unauth"
+            headers = {"Accept": "application/json", "Content-Type": "application/json"}
+
+            # Send synchronously for connection errors since we're probably about to exit
+            response = requests.post(
+                url,
+                data=json.dumps(request),
+                headers=headers,
+                timeout=5,
+            )
+            if response.status_code == 200:
+                logger.debug("Connection error telemetry sent successfully")
+            else:
+                logger.debug(
+                    "Connection error telemetry failed with status: %s",
+                    response.status_code,
+                )
+
+        except Exception as e:
+            logger.debug("Failed to send connection error telemetry: %s", e)