From 61ce3913adb109dc33aab0932bb7a2e5b222abe5 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 13 May 2025 13:24:42 +0000 Subject: [PATCH 01/10] draft implementation --- databricks/sdk/config.py | 32 +++- databricks/sdk/credentials_provider.py | 154 ++++++++++------- databricks/sdk/oauth.py | 5 - databricks/sdk/oidc.py | 230 +++++++++++++++++++++++++ 4 files changed, 349 insertions(+), 72 deletions(-) create mode 100644 databricks/sdk/oidc.py diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index 1e674806f..ca6eb21bc 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -14,11 +14,14 @@ from ._base_client import _fix_host_if_needed from .clock import Clock, RealClock from .credentials_provider import CredentialsStrategy, DefaultCredentials -from .environments import (ALL_ENVS, AzureEnvironment, Cloud, - DatabricksEnvironment, get_environment_for_hostname) -from .oauth import (OidcEndpoints, Token, get_account_endpoints, - get_azure_entra_id_workspace_endpoints, - get_workspace_endpoints) +from .environments import ALL_ENVS, AzureEnvironment, Cloud, DatabricksEnvironment, get_environment_for_hostname +from .oauth import ( + OidcEndpoints, + Token, + get_account_endpoints, + get_azure_entra_id_workspace_endpoints, + get_workspace_endpoints, +) logger = logging.getLogger("databricks.sdk") @@ -60,10 +63,21 @@ def with_user_agent_extra(key: str, value: str): class Config: host: str = ConfigAttribute(env="DATABRICKS_HOST") account_id: str = ConfigAttribute(env="DATABRICKS_ACCOUNT_ID") + + # PAT token. token: str = ConfigAttribute(env="DATABRICKS_TOKEN", auth="pat", sensitive=True) + + # Audience for OIDC ID token source accepting an audience as a parameter. + # For example, the GitHub action ID token source. token_audience: str = ConfigAttribute(env="DATABRICKS_TOKEN_AUDIENCE", auth="github-oidc") + + # Environment variable for OIDC token. + oidc_token_env: str = ConfigAttribute(env="DATABRICKS_OIDC_TOKEN_ENV", auth="env-oidc") + oidc_token_filepath: str = ConfigAttribute(env="DATABRICKS_OIDC_TOKEN_FILE", auth="file-oidc") + username: str = ConfigAttribute(env="DATABRICKS_USERNAME", auth="basic") password: str = ConfigAttribute(env="DATABRICKS_PASSWORD", auth="basic", sensitive=True) + client_id: str = ConfigAttribute(env="DATABRICKS_CLIENT_ID", auth="oauth") client_secret: str = ConfigAttribute(env="DATABRICKS_CLIENT_SECRET", auth="oauth", sensitive=True) profile: str = ConfigAttribute(env="DATABRICKS_CONFIG_PROFILE") @@ -194,7 +208,7 @@ def oauth_token(self) -> Token: def wrap_debug_info(self, message: str) -> str: debug_string = self.debug_string() if debug_string: - message = f'{message.rstrip(".")}. {debug_string}' + message = f"{message.rstrip('.')}. {debug_string}" return message @staticmethod @@ -337,9 +351,9 @@ def debug_string(self) -> str: safe = "***" if attr.sensitive else f"{value}" attrs_used.append(f"{attr.name}={safe}") if attrs_used: - buf.append(f'Config: {", ".join(attrs_used)}') + buf.append(f"Config: {', '.join(attrs_used)}") if envs_used: - buf.append(f'Env: {", ".join(envs_used)}') + buf.append(f"Env: {', '.join(envs_used)}") return ". ".join(buf) def to_dict(self) -> Dict[str, any]: @@ -481,7 +495,7 @@ def _known_file_config_loader(self): if profile not in profiles: raise ValueError(f"resolve: {config_path} has no {profile} profile configured") raw_config = profiles[profile] - logger.info(f'loading {profile} profile from {config_file}: {", ".join(raw_config.keys())}') + logger.info(f"loading {profile} profile from {config_file}: {', '.join(raw_config.keys())}") for k, v in raw_config.items(): if k in self._inner: # don't overwrite a value previously set diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index 86bd5c4d2..c6f3e8121 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -20,10 +20,7 @@ from google.auth.transport.requests import Request # type: ignore from google.oauth2 import service_account # type: ignore -from .azure import add_sp_management_token, add_workspace_id_header -from .oauth import (ClientCredentials, OAuthClient, Refreshable, Token, - TokenCache, TokenSource) -from .oidc_token_supplier import GitHubOIDCTokenSupplier +from . import azure, oauth, oidc, oidc_token_supplier CredentialsProvider = Callable[[], Dict[str, str]] @@ -36,7 +33,7 @@ class OAuthCredentialsProvider: def __init__( self, credentials_provider: CredentialsProvider, - token_provider: Callable[[], Token], + token_provider: Callable[[], oauth.Token], ): self._credentials_provider = credentials_provider self._token_provider = token_provider @@ -44,7 +41,7 @@ def __init__( def __call__(self) -> Dict[str, str]: return self._credentials_provider() - def oauth_token(self) -> Token: + def oauth_token(self) -> oauth.Token: return self._token_provider() @@ -77,7 +74,7 @@ def auth_type(self) -> str: def __call__(self, cfg: "Config") -> OAuthCredentialsProvider: return self._headers_provider(cfg) - def oauth_token(self, cfg: "Config") -> Token: + def oauth_token(self, cfg: "Config") -> oauth.Token: return self._headers_provider(cfg).oauth_token() @@ -89,7 +86,6 @@ def credentials_strategy(name: str, require: List[str]): def inner( func: Callable[["Config"], CredentialsProvider], ) -> CredentialsStrategy: - @functools.wraps(func) def wrapper(cfg: "Config") -> Optional[CredentialsProvider]: for attr in require: @@ -112,7 +108,6 @@ def oauth_credentials_strategy(name: str, require: List[str]): def inner( func: Callable[["Config"], OAuthCredentialsProvider], ) -> OauthCredentialsStrategy: - @functools.wraps(func) def wrapper(cfg: "Config") -> Optional[OAuthCredentialsProvider]: for attr in require: @@ -156,9 +151,7 @@ def runtime_native_auth(cfg: "Config") -> Optional[CredentialsProvider]: # This import MUST be after the "DATABRICKS_RUNTIME_VERSION" check # above, so that we are not throwing import errors when not in # runtime and no config variables are set. - from databricks.sdk.runtime import (init_runtime_legacy_auth, - init_runtime_native_auth, - init_runtime_repl_auth) + from databricks.sdk.runtime import init_runtime_legacy_auth, init_runtime_native_auth, init_runtime_repl_auth for init in [ init_runtime_native_auth, @@ -186,7 +179,7 @@ def oauth_service_principal(cfg: "Config") -> Optional[CredentialsProvider]: if oidc is None: return None - token_source = ClientCredentials( + token_source = oauth.ClientCredentials( client_id=cfg.client_id, client_secret=cfg.client_secret, token_url=oidc.token_endpoint, @@ -199,7 +192,7 @@ def inner() -> Dict[str, str]: token = token_source.token() return {"Authorization": f"{token.token_type} {token.access_token}"} - def token() -> Token: + def token() -> oauth.Token: return token_source.token() return OAuthCredentialsProvider(inner, token) @@ -224,7 +217,7 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]: # local to the Python SDK and not reused by other SDKs. oidc_endpoints = cfg.oidc_endpoints redirect_url = "http://localhost:8020" - token_cache = TokenCache( + token_cache = oauth.TokenCache( host=cfg.host, oidc_endpoints=oidc_endpoints, client_id=client_id, @@ -243,7 +236,7 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]: except Exception as e: logger.warning(f"Failed to refresh cached token: {e}. Initiating new OAuth login flow") - oauth_client = OAuthClient( + oauth_client = oauth.OAuthClient( oidc_endpoints=oidc_endpoints, client_id=client_id, redirect_url=redirect_url, @@ -258,7 +251,7 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]: return credentials(cfg) -def _ensure_host_present(cfg: "Config", token_source_for: Callable[[str], TokenSource]): +def _ensure_host_present(cfg: "Config", token_source_for: Callable[[str], oauth.TokenSource]): """Resolves Azure Databricks workspace URL from ARM Resource ID""" if cfg.host: return @@ -284,9 +277,9 @@ def azure_service_principal(cfg: "Config") -> CredentialsProvider: to every request, while automatically resolving different Azure environment endpoints. """ - def token_source_for(resource: str) -> TokenSource: + def token_source_for(resource: str) -> oauth.TokenSource: aad_endpoint = cfg.arm_environment.active_directory_endpoint - return ClientCredentials( + return oauth.ClientCredentials( client_id=cfg.azure_client_id, client_secret=cfg.azure_client_secret, token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token", @@ -305,18 +298,58 @@ def refreshed_headers() -> Dict[str, str]: headers = { "Authorization": f"Bearer {inner.token().access_token}", } - add_workspace_id_header(cfg, headers) - add_sp_management_token(cloud, headers) + azure.add_workspace_id_header(cfg, headers) + azure.add_sp_management_token(cloud, headers) return headers - def token() -> Token: + def token() -> oauth.Token: return inner.token() return OAuthCredentialsProvider(refreshed_headers, token) +@credentials_strategy("env-oidc", ["host"]) +def env_oidc(cfg) -> Optional[CredentialsProvider]: + # Search for an OIDC ID token in DATABRICKS_OIDC_TOKEN environment variable + # by default. This can be overridden by setting DATABRICKS_OIDC_TOKEN_ENV + # to the name of an environment variable that contains the OIDC ID token. + env_var = "DATABRICKS_OIDC_TOKEN" + if cfg.oidc_token_env: + env_var = cfg.oidc_token_env + + return _oidc_credentials_provider(cfg, oidc.EnvIdTokenSource(env_var)) + + +# This function is a helper function to create an OIDC CredentialsProvider +# that provides a Databricks token from an IdTokenSource. +def _oidc_credentials_provider(cfg, id_token_source: oidc.IdTokenSource) -> Optional[CredentialsProvider]: + try: + id_token = id_token_source.id_token() + except Exception as e: + logger.debug(f"Failed to get OIDC token: {e}") + return None + + token_source = oidc.DatabricksOidcTokenSource( + host=cfg.host, + token_endpoint=cfg.oidc_endpoints.token_endpoint, + client_id=cfg.client_id, + account_id=cfg.account_id, + id_token=id_token, + disable_async=cfg.disable_async_token_refresh, + ) + + def refreshed_headers() -> Dict[str, str]: + token = token_source.token() + return {"Authorization": f"{token.token_type} {token.access_token}"} + + def token() -> oauth.Token: + return token_source.token() + + return OAuthCredentialsProvider(refreshed_headers, token) + + @oauth_credentials_strategy("github-oidc", ["host", "client_id"]) -def databricks_wif(cfg: "Config") -> Optional[CredentialsProvider]: +def github_oidc(cfg: "Config") -> Optional[CredentialsProvider]: """ DatabricksWIFCredentials uses a Token Supplier to get a JWT Token and exchanges it for a Databricks Token. @@ -324,7 +357,7 @@ def databricks_wif(cfg: "Config") -> Optional[CredentialsProvider]: Supported suppliers: - GitHub OIDC """ - supplier = GitHubOIDCTokenSupplier() + supplier = oidc_token_supplier.GitHubOIDCTokenSupplier() audience = cfg.token_audience if audience is None and cfg.is_account_client: @@ -337,21 +370,21 @@ def databricks_wif(cfg: "Config") -> Optional[CredentialsProvider]: if not id_token: return None - def token_source_for(audience: str) -> TokenSource: + def token_source_for(audience: str) -> oauth.TokenSource: id_token = supplier.get_oidc_token(audience) if not id_token: # Should not happen, since we checked it above. raise Exception("Cannot get OIDC token") - params = { - "subject_token_type": "urn:ietf:params:oauth:token-type:jwt", - "subject_token": id_token, - "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", - } - return ClientCredentials( + + return oauth.ClientCredentials( client_id=cfg.client_id, client_secret="", # we have no (rotatable) secrets in OIDC flow token_url=cfg.oidc_endpoints.token_endpoint, - endpoint_params=params, + endpoint_params={ + "subject_token_type": "urn:ietf:params:oauth:token-type:jwt", + "subject_token": id_token, + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + }, scopes=["all-apis"], use_params=True, disable_async=cfg.disable_async_token_refresh, @@ -361,7 +394,7 @@ def refreshed_headers() -> Dict[str, str]: token = token_source_for(audience).token() return {"Authorization": f"{token.token_type} {token.access_token}"} - def token() -> Token: + def token() -> oauth.Token: return token_source_for(audience).token() return OAuthCredentialsProvider(refreshed_headers, token) @@ -378,7 +411,7 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]: if not cfg.is_azure: return None - token = GitHubOIDCTokenSupplier().get_oidc_token("api://AzureADTokenExchange") + token = oidc_token_supplier.GitHubOIDCTokenSupplier().get_oidc_token("api://AzureADTokenExchange") if not token: return None @@ -386,21 +419,22 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]: "Configured AAD token for GitHub Actions OIDC (%s)", cfg.azure_client_id, ) - params = { - "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", - "resource": cfg.effective_azure_login_app_id, - "client_assertion": token, - } + aad_endpoint = cfg.arm_environment.active_directory_endpoint if not cfg.azure_tenant_id: # detect Azure AD Tenant ID if it's not specified directly token_endpoint = cfg.oidc_endpoints.token_endpoint cfg.azure_tenant_id = token_endpoint.replace(aad_endpoint, "").split("/")[0] - inner = ClientCredentials( + + inner = oauth.ClientCredentials( client_id=cfg.azure_client_id, client_secret="", # we have no (rotatable) secrets in OIDC flow token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token", - endpoint_params=params, + endpoint_params={ + "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", + "resource": cfg.effective_azure_login_app_id, + "client_assertion": token, + }, use_params=True, disable_async=cfg.disable_async_token_refresh, ) @@ -409,7 +443,7 @@ def refreshed_headers() -> Dict[str, str]: token = inner.token() return {"Authorization": f"{token.token_type} {token.access_token}"} - def token() -> Token: + def token() -> oauth.Token: return inner.token() return OAuthCredentialsProvider(refreshed_headers, token) @@ -442,7 +476,7 @@ def google_credentials(cfg: "Config") -> Optional[CredentialsProvider]: gcp_credentials = service_account.Credentials.from_service_account_info(info=account_info, scopes=GcpScopes) - def token() -> Token: + def token() -> oauth.Token: credentials.refresh(request) return credentials.token @@ -483,7 +517,7 @@ def google_id(cfg: "Config") -> Optional[CredentialsProvider]: request = Request() - def token() -> Token: + def token() -> oauth.Token: id_creds.refresh(request) return id_creds.token @@ -498,8 +532,7 @@ def refreshed_headers() -> Dict[str, str]: return OAuthCredentialsProvider(refreshed_headers, token) -class CliTokenSource(Refreshable): - +class CliTokenSource(oauth.Refreshable): def __init__( self, cmd: List[str], @@ -525,12 +558,12 @@ def _parse_expiry(expiry: str) -> datetime: if last_e: raise last_e - def refresh(self) -> Token: + def refresh(self) -> oauth.Token: try: out = _run_subprocess(self._cmd, capture_output=True, check=True) it = json.loads(out.stdout.decode()) expires_on = self._parse_expiry(it[self._expiry_field]) - return Token( + return oauth.Token( access_token=it[self._access_token_field], token_type=it[self._token_type_field], expiry=expires_on, @@ -558,7 +591,7 @@ def _run_subprocess( kwargs["shell"] = sys.platform.startswith("win") # windows requires shell=True to be able to execute 'az login' or other commands # cannot use shell=True all the time, as it breaks macOS - logging.debug(f'Running command: {" ".join(popenargs)}') + logging.debug(f"Running command: {' '.join(popenargs)}") return subprocess.run( popenargs, input=input, @@ -689,7 +722,7 @@ def azure_cli(cfg: "Config") -> Optional[CredentialsProvider]: mgmt_token_source = AzureCliTokenSource.for_resource(cfg, management_endpoint) except Exception as e: logger.debug( - f"Not including service management token in headers", + "Not including service management token in headers", exc_info=e, ) mgmt_token_source = None @@ -700,9 +733,9 @@ def azure_cli(cfg: "Config") -> Optional[CredentialsProvider]: def inner() -> Dict[str, str]: token = token_source.token() headers = {"Authorization": f"{token.token_type} {token.access_token}"} - add_workspace_id_header(cfg, headers) + azure.add_workspace_id_header(cfg, headers) if mgmt_token_source: - add_sp_management_token(mgmt_token_source, headers) + azure.add_sp_management_token(mgmt_token_source, headers) return headers return inner @@ -784,13 +817,13 @@ def inner() -> Dict[str, str]: token = token_source.token() return {"Authorization": f"{token.token_type} {token.access_token}"} - def token() -> Token: + def token() -> oauth.Token: return token_source.token() return OAuthCredentialsProvider(inner, token) -class MetadataServiceTokenSource(Refreshable): +class MetadataServiceTokenSource(oauth.Refreshable): """Obtain the token granted by Databricks Metadata Service""" METADATA_SERVICE_VERSION = "1" @@ -803,7 +836,7 @@ def __init__(self, cfg: "Config"): self.url = cfg.metadata_service_url self.host = cfg.host - def refresh(self) -> Token: + def refresh(self) -> oauth.Token: resp = requests.get( self.url, timeout=self._metadata_service_timeout, @@ -831,7 +864,7 @@ def refresh(self) -> Token: except: raise ValueError("Metadata Service returned invalid expiry") - return Token(access_token=access_token, token_type=token_type, expiry=expiry) + return oauth.Token(access_token=access_token, token_type=token_type, expiry=expiry) @credentials_strategy("metadata-service", ["host", "metadata_service_url"]) @@ -972,7 +1005,8 @@ def __init__(self) -> None: basic_auth, metadata_service, oauth_service_principal, - databricks_wif, + env_oidc, + github_oidc, azure_service_principal, github_oidc_azure, azure_cli, @@ -987,7 +1021,7 @@ def __init__(self) -> None: def auth_type(self) -> str: return self._auth_type - def oauth_token(self, cfg: "Config") -> Token: + def oauth_token(self, cfg: "Config") -> oauth.Token: for provider in self._auth_providers: auth_type = provider.auth_type() if auth_type != self._auth_type: @@ -1004,9 +1038,13 @@ def __call__(self, cfg: "Config") -> CredentialsProvider: continue logger.debug(f"Attempting to configure auth: {auth_type}") try: + # The header factory might be None if the provider cannot be + # configured for the current environment. For example, if the + # provider requires some missing environment variables. header_factory = provider(cfg) if not header_factory: continue + self._auth_type = auth_type return header_factory except Exception as e: diff --git a/databricks/sdk/oauth.py b/databricks/sdk/oauth.py index e099dbf07..f18f0cd51 100644 --- a/databricks/sdk/oauth.py +++ b/databricks/sdk/oauth.py @@ -156,7 +156,6 @@ def jwt_claims(self) -> Dict[str, str]: class TokenSource: - @abstractmethod def token(self) -> Token: pass @@ -341,7 +340,6 @@ def refresh(self) -> Token: class _OAuthCallback(BaseHTTPRequestHandler): - def __init__(self, feedback: list, *args): self._feedback = feedback super().__init__(*args) @@ -418,7 +416,6 @@ def get_azure_entra_id_workspace_endpoints( class SessionCredentials(Refreshable): - def __init__( self, token: Token, @@ -493,7 +490,6 @@ def refresh(self) -> Token: class Consent: - def __init__( self, state: str, @@ -627,7 +623,6 @@ def __init__( scopes: List[str] = None, client_secret: str = None, ): - if not scopes: # all-apis ensures that the returned OAuth token can be used with all APIs, aside # from direct-to-dataplane APIs. diff --git a/databricks/sdk/oidc.py b/databricks/sdk/oidc.py new file mode 100644 index 000000000..7fc8748d9 --- /dev/null +++ b/databricks/sdk/oidc.py @@ -0,0 +1,230 @@ +""" +Package oidc provides utilities for working with OIDC ID tokens. + +This package is experimental and subject to change. +""" + +import logging +import os +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional + +import requests + +from . import oauth + +logger = logging.getLogger(__name__) + + +@dataclass +class IdToken: + """Represents an OIDC ID token that can be exchanged for a Databricks access token. + + Parameters + ---------- + jwt : str + The signed JWT token string. + """ + + jwt: str + + +class IdTokenSource(ABC): + """Abstract base class representing anything that returns an IDToken. + + This class defines the interface for token sources that can provide OIDC ID tokens. + """ + + @abstractmethod + def id_token(self) -> IdToken: + """Get an ID token. + + Returns + ------- + IdToken + An ID token. + + Raises + ------ + Exception + Implementation specific exceptions. + """ + + +class EnvIdTokenSource(IdTokenSource): + """IDTokenSource that reads the ID token from an environment variable. + + Parameters + ---------- + env_var : str + The name of the environment variable containing the ID token. + """ + + def __init__(self, env_var: str): + self.env_var = env_var + + def id_token(self) -> IdToken: + """Get an ID token from an environment variable. + + Returns + ------- + IdToken + An ID token. + + Raises + ------ + ValueError + If the environment variable is not set. + """ + token = os.getenv(self.env_var) + if not token: + raise ValueError(f"Missing env var {self.env_var!r}") + return IdToken(jwt=token) + + +class FileIdTokenSource(IdTokenSource): + """IDTokenSource that reads the ID token from a file. + + Parameters + ---------- + path : str + The path to the file containing the ID token. + """ + + def __init__(self, path: str): + self.path = path + + def id_token(self) -> IdToken: + """Get an ID token from a file. + + Returns + ------- + IdToken + An ID token. + + Raises + ------ + ValueError + If the file is empty, does not exist, or cannot be read. + """ + if not self.path: + raise ValueError("Missing path") + try: + with open(self.path, "r") as f: + token = f.read().strip() + if not token: + raise ValueError(f"File {self.path!r} is empty") + return IdToken(jwt=token) + except FileNotFoundError: + raise ValueError(f"File {self.path!r} does not exist") + except Exception as e: + raise ValueError(f"Error reading token file: {str(e)}") + + +class GitHubIdTokenSource(IdTokenSource): + """ + Supplies OIDC tokens from GitHub Actions. + """ + + def __init__(self, request_token: str, request_url: str): + self._request_token = request_token + self._request_url = request_url + + def id_token(self, audience: str) -> IdToken: + # See https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers + headers = {"Authorization": f"Bearer {self._request_token}"} + endpoint = f"{self._request_url}&audience={audience}" + response = requests.get(endpoint, headers=headers) + if not response.ok: + raise ValueError(f"Failed to get ID token: {response.status_code} {response.text}") + + # get the ID Token with aud=api://AzureADTokenExchange sub=repo:org/repo:environment:name + response_json = response.json() + if "value" not in response_json: + raise ValueError("Missing value in response") + + return IdToken(jwt=response_json["value"]) + + +class DatabricksOidcTokenSource(oauth.TokenSource): + """A TokenSource which exchanges a token using Workload Identity Federation. + + Parameters + ---------- + host : str + The host of the Databricks account or workspace. + id_token_source : IdTokenSource + IDTokenSource that returns the IDToken to be used for the token exchange. + token_endpoint_provider : Callable[[], dict] + Returns the token endpoint for the Databricks OIDC application. + client_id : Optional[str], optional + ClientID of the Databricks OIDC application. It corresponds to the + Application ID of the Databricks Service Principal. Only required for + Workload Identity Federation and should be empty for Account-wide token + federation. + account_id : Optional[str], optional + The account ID of the Databricks Account. Only required for + Account-wide token federation. + audience : Optional[str], optional + The audience of the Databricks OIDC application. Only used for + Workspace level tokens. + """ + + def __init__( + self, + host: str, + token_endpoint: str, + id_token_source: IdTokenSource, + client_id: Optional[str] = None, + account_id: Optional[str] = None, + audience: Optional[str] = None, + disable_async: bool = False, + ): + self._host = host + self._id_token_source = id_token_source + self._token_endpoint = token_endpoint + self._client_id = client_id + self._account_id = account_id + self._audience = audience + self._disable_async = disable_async + + def token(self) -> oauth.Token: + """Get a token by exchanging the ID token. + + Returns + ------- + dict + The exchanged token. + + Raises + ------ + ValueError + If the host is missing or other configuration errors occur. + """ + if not self._host: + logger.debug("Missing Host") + raise ValueError("missing Host") + + if not self._client_id: + logger.debug("No ClientID provided, authenticating with Account-wide token federation") + else: + logger.debug("Client ID provided, authenticating with Workload Identity Federation") + + id_token = self._id_token_source.id_token() + + client = oauth.ClientCredentials( + client_id=self._client_id, + client_secret="", # we have no (rotatable) secrets in OIDC flow + token_url=self._token_endpoint, + endpoint_params={ + "subject_token_type": "urn:ietf:params:oauth:token-type:jwt", + "subject_token": id_token, + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + }, + scopes=["all-apis"], + use_params=True, + disable_async=self._disable_async, + ) + + return client.token() From a2061735694ac56961cbcc907ef3c6fe86b83a65 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 23 May 2025 11:31:58 +0000 Subject: [PATCH 02/10] Add file_oidc --- databricks/sdk/config.py | 13 +++++-------- databricks/sdk/credentials_provider.py | 8 +++++++- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index ca6eb21bc..487527be7 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -14,14 +14,11 @@ from ._base_client import _fix_host_if_needed from .clock import Clock, RealClock from .credentials_provider import CredentialsStrategy, DefaultCredentials -from .environments import ALL_ENVS, AzureEnvironment, Cloud, DatabricksEnvironment, get_environment_for_hostname -from .oauth import ( - OidcEndpoints, - Token, - get_account_endpoints, - get_azure_entra_id_workspace_endpoints, - get_workspace_endpoints, -) +from .environments import (ALL_ENVS, AzureEnvironment, Cloud, + DatabricksEnvironment, get_environment_for_hostname) +from .oauth import (OidcEndpoints, Token, get_account_endpoints, + get_azure_entra_id_workspace_endpoints, + get_workspace_endpoints) logger = logging.getLogger("databricks.sdk") diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index c6f3e8121..fecb9f910 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -151,7 +151,9 @@ def runtime_native_auth(cfg: "Config") -> Optional[CredentialsProvider]: # This import MUST be after the "DATABRICKS_RUNTIME_VERSION" check # above, so that we are not throwing import errors when not in # runtime and no config variables are set. - from databricks.sdk.runtime import init_runtime_legacy_auth, init_runtime_native_auth, init_runtime_repl_auth + from databricks.sdk.runtime import (init_runtime_legacy_auth, + init_runtime_native_auth, + init_runtime_repl_auth) for init in [ init_runtime_native_auth, @@ -319,6 +321,9 @@ def env_oidc(cfg) -> Optional[CredentialsProvider]: return _oidc_credentials_provider(cfg, oidc.EnvIdTokenSource(env_var)) +@credentials_strategy("file-oidc", ["host", "oidc_token_filepath"]) +def file_oidc(cfg) -> Optional[CredentialsProvider]: + return _oidc_credentials_provider(cfg, oidc.FileIdTokenSource(cfg.oidc_token_filepath)) # This function is a helper function to create an OIDC CredentialsProvider # that provides a Databricks token from an IdTokenSource. @@ -1006,6 +1011,7 @@ def __init__(self) -> None: metadata_service, oauth_service_principal, env_oidc, + file_oidc, github_oidc, azure_service_principal, github_oidc_azure, From 89fafd471b57ecdf24b2648d16c80e86ffba4b8e Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 23 May 2025 12:59:51 +0000 Subject: [PATCH 03/10] Unit tests --- databricks/sdk/credentials_provider.py | 2 + tests/test_config.py | 5 +- tests/test_oidc.py | 114 +++++++++++++++++++++++++ 3 files changed, 118 insertions(+), 3 deletions(-) create mode 100644 tests/test_oidc.py diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index fecb9f910..2f5121180 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -321,10 +321,12 @@ def env_oidc(cfg) -> Optional[CredentialsProvider]: return _oidc_credentials_provider(cfg, oidc.EnvIdTokenSource(env_var)) + @credentials_strategy("file-oidc", ["host", "oidc_token_filepath"]) def file_oidc(cfg) -> Optional[CredentialsProvider]: return _oidc_credentials_provider(cfg, oidc.FileIdTokenSource(cfg.oidc_token_filepath)) + # This function is a helper function to create an OIDC CredentialsProvider # that provides a Databricks token from an IdTokenSource. def _oidc_credentials_provider(cfg, id_token_source: oidc.IdTokenSource) -> Optional[CredentialsProvider]: diff --git a/tests/test_config.py b/tests/test_config.py index dc9d8e410..b023123af 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -7,9 +7,8 @@ import pytest -from databricks.sdk import useragent +from databricks.sdk import oauth, useragent from databricks.sdk.config import Config, with_product, with_user_agent_extra -from databricks.sdk.credentials_provider import Token from databricks.sdk.version import __version__ from .conftest import noop_credentials, set_az_path @@ -114,7 +113,7 @@ def test_config_copy_deep_copies_user_agent_other_info(config): def test_config_deep_copy(monkeypatch, mocker, tmp_path): mocker.patch( "databricks.sdk.credentials_provider.CliTokenSource.refresh", - return_value=Token( + return_value=oauth.Token( access_token="token", token_type="Bearer", expiry=datetime(2023, 5, 22, 0, 0, 0), diff --git a/tests/test_oidc.py b/tests/test_oidc.py new file mode 100644 index 000000000..6fe0ea4a0 --- /dev/null +++ b/tests/test_oidc.py @@ -0,0 +1,114 @@ +from dataclasses import dataclass +from typing import Optional, Tuple + +import pytest + +from databricks.sdk import oidc + + +class MockIdTokenSource(oidc.IdTokenSource): + def __init__(self, id_token: str, exception: Exception = None): + self.id_token = id_token + self.exception = exception + + def id_token(self) -> oidc.IdToken: + if self.exception: + raise self.exception + return oidc.IdToken(jwt=self.id_token) + + +@dataclass +class EnvTestCase: + name: str + env_name: str = "" + env_value: str = "" + want: oidc.IdToken = None + wantException: Exception = None + + +_env_id_test_cases = [ + EnvTestCase( + name="success", + env_name="OIDC_TEST_TOKEN_SUCCESS", + env_value="test-token-123", + want=oidc.IdToken(jwt="test-token-123"), + ), + EnvTestCase( + name="missing_env_var", + env_name="OIDC_TEST_TOKEN_MISSING", + env_value="", + wantException=ValueError, + ), + EnvTestCase( + name="empty_env_var", + env_name="OIDC_TEST_TOKEN_EMPTY", + env_value="", + wantException=ValueError, + ), + EnvTestCase( + name="different_variable_name", + env_name="ANOTHER_OIDC_TOKEN", + env_value="another-token-456", + want=oidc.IdToken(jwt="another-token-456"), + ), +] + + +@pytest.mark.parametrize("test_case", _env_id_test_cases) +def test_env_id_token_source(test_case: EnvIdTestCase, monkeypatch): + monkeypatch.setenv(test_case.env_name, test_case.env_value) + + source = oidc.EnvIdTokenSource(test_case.env_name) + if test_case.wantException: + with pytest.raises(test_case.wantException): + source.id_token() + else: + assert source.id_token() == test_case.want + + +@dataclass +class FileTestCase: + name: str + file: Optional[Tuple[str, str]] = None # (name, content) + filepath: str = "" + want: oidc.IdToken = None + wantException: Exception = None + + +_file_id_test_cases = [ + FileTestCase( + name="missing_filepath", + file=("token", "content"), + filepath="", + wantException=ValueError, + ), + FileTestCase( + name="empty_file", + file=("token", ""), + filepath="token", + wantException=ValueError, + ), + FileTestCase( + name="file_does_not_exist", + ), + FileTestCase( + name="file_exists", + file=("token", "content"), + filepath="token", + want=oidc.IdToken(jwt="content"), + ), +] + + +@pytest.mark.parametrize("test_case", _file_id_test_cases) +def test_file_id_token_source(test_case: FileTestCase, tmp_path): + if test_case.file: + token_file = tmp_path / test_case.file[0] + token_file.write_text(test_case.file[1]) + + source = oidc.FileIdTokenSource(test_case.filepath) + if test_case.wantException: + with pytest.raises(test_case.wantException): + source.id_token() + else: + assert source.id_token() == test_case.want From 2ddc6274bfb3824c7a8b6b7570213607347b0581 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 23 May 2025 13:05:59 +0000 Subject: [PATCH 04/10] Add changelogs --- NEXT_CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 34a0d9f1c..59d8ad3ce 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -4,6 +4,11 @@ ### New Features and Improvements +- Add support for OIDC ID token authentication from an environment variable + ([PR #977](https://github.com/databricks/databricks-sdk-py/pull/977)). +- Add support for OIDC ID token authentication from a file + ([PR #977](https://github.com/databricks/databricks-sdk-py/pull/977)). + ### Bug Fixes ### Documentation From 47062c3698c6dd5a4bfdcc2127c481793c04fa5c Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 23 May 2025 13:06:38 +0000 Subject: [PATCH 05/10] Fix test --- tests/test_oidc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_oidc.py b/tests/test_oidc.py index 6fe0ea4a0..1ee6fe0fc 100644 --- a/tests/test_oidc.py +++ b/tests/test_oidc.py @@ -55,7 +55,7 @@ class EnvTestCase: @pytest.mark.parametrize("test_case", _env_id_test_cases) -def test_env_id_token_source(test_case: EnvIdTestCase, monkeypatch): +def test_env_id_token_source(test_case: EnvTestCase, monkeypatch): monkeypatch.setenv(test_case.env_name, test_case.env_value) source = oidc.EnvIdTokenSource(test_case.env_name) From 0ecaf3a88291dca42355c312953ee811d31302bd Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 23 May 2025 13:25:01 +0000 Subject: [PATCH 06/10] Fix test + exceptions --- databricks/sdk/oidc.py | 9 ++++++--- tests/test_credentials_provider.py | 14 +++++++------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/databricks/sdk/oidc.py b/databricks/sdk/oidc.py index 7fc8748d9..28b2b5349 100644 --- a/databricks/sdk/oidc.py +++ b/databricks/sdk/oidc.py @@ -110,16 +110,19 @@ def id_token(self) -> IdToken: """ if not self.path: raise ValueError("Missing path") + + token = None try: with open(self.path, "r") as f: token = f.read().strip() - if not token: - raise ValueError(f"File {self.path!r} is empty") - return IdToken(jwt=token) except FileNotFoundError: raise ValueError(f"File {self.path!r} does not exist") except Exception as e: raise ValueError(f"Error reading token file: {str(e)}") + + if not token: + raise ValueError(f"File {self.path!r} is empty") + return IdToken(jwt=token) class GitHubIdTokenSource(IdTokenSource): diff --git a/tests/test_credentials_provider.py b/tests/test_credentials_provider.py index fb24d9dc4..b23044d7c 100644 --- a/tests/test_credentials_provider.py +++ b/tests/test_credentials_provider.py @@ -26,7 +26,7 @@ def test_external_browser_refresh_success(mocker): # Inject the mock implementations. mocker.patch( - "databricks.sdk.credentials_provider.TokenCache", + "databricks.sdk.oauth.TokenCache", return_value=mock_token_cache, ) @@ -66,11 +66,11 @@ def test_external_browser_refresh_failure_new_oauth_flow(mocker): # Inject the mock implementations. mocker.patch( - "databricks.sdk.credentials_provider.TokenCache", + "databricks.sdk.oauth.TokenCache", return_value=mock_token_cache, ) mocker.patch( - "databricks.sdk.credentials_provider.OAuthClient", + "databricks.sdk.oauth.OAuthClient", return_value=mock_oauth_client, ) @@ -112,11 +112,11 @@ def test_external_browser_no_cached_credentials(mocker): # Inject the mock implementations. mocker.patch( - "databricks.sdk.credentials_provider.TokenCache", + "databricks.sdk.oauth.TokenCache", return_value=mock_token_cache, ) mocker.patch( - "databricks.sdk.credentials_provider.OAuthClient", + "databricks.sdk.oauth.OAuthClient", return_value=mock_oauth_client, ) @@ -150,11 +150,11 @@ def test_external_browser_consent_fails(mocker): # Inject the mock implementations. mocker.patch( - "databricks.sdk.credentials_provider.TokenCache", + "databricks.sdk.oauth.TokenCache", return_value=mock_token_cache, ) mocker.patch( - "databricks.sdk.credentials_provider.OAuthClient", + "databricks.sdk.oauth.OAuthClient", return_value=mock_oauth_client, ) From f19ed8ad4471fe9691dc49d7a067654de231633b Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 23 May 2025 13:43:58 +0000 Subject: [PATCH 07/10] Format --- databricks/sdk/oidc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/databricks/sdk/oidc.py b/databricks/sdk/oidc.py index 28b2b5349..d8a7b1cc0 100644 --- a/databricks/sdk/oidc.py +++ b/databricks/sdk/oidc.py @@ -110,7 +110,7 @@ def id_token(self) -> IdToken: """ if not self.path: raise ValueError("Missing path") - + token = None try: with open(self.path, "r") as f: @@ -119,7 +119,7 @@ def id_token(self) -> IdToken: raise ValueError(f"File {self.path!r} does not exist") except Exception as e: raise ValueError(f"Error reading token file: {str(e)}") - + if not token: raise ValueError(f"File {self.path!r} is empty") return IdToken(jwt=token) From 286efee5aae70758da46c9f69f137571c4c9feb3 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 23 May 2025 14:03:55 +0000 Subject: [PATCH 08/10] Update tests --- tests/test_oidc.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/tests/test_oidc.py b/tests/test_oidc.py index 1ee6fe0fc..302ef4c83 100644 --- a/tests/test_oidc.py +++ b/tests/test_oidc.py @@ -6,17 +6,6 @@ from databricks.sdk import oidc -class MockIdTokenSource(oidc.IdTokenSource): - def __init__(self, id_token: str, exception: Exception = None): - self.id_token = id_token - self.exception = exception - - def id_token(self) -> oidc.IdToken: - if self.exception: - raise self.exception - return oidc.IdToken(jwt=self.id_token) - - @dataclass class EnvTestCase: name: str @@ -90,6 +79,8 @@ class FileTestCase: ), FileTestCase( name="file_does_not_exist", + filepath="nonexistent-file", + wantException=ValueError, ), FileTestCase( name="file_exists", @@ -106,9 +97,24 @@ def test_file_id_token_source(test_case: FileTestCase, tmp_path): token_file = tmp_path / test_case.file[0] token_file.write_text(test_case.file[1]) - source = oidc.FileIdTokenSource(test_case.filepath) + fp = "" + if test_case.filepath: + fp = tmp_path / test_case.filepath + + source = oidc.FileIdTokenSource(fp) if test_case.wantException: with pytest.raises(test_case.wantException): source.id_token() else: assert source.id_token() == test_case.want + + +# class MockIdTokenSource(oidc.IdTokenSource): +# def __init__(self, id_token: str, exception: Exception = None): +# self.id_token = id_token +# self.exception = exception + +# def id_token(self) -> oidc.IdToken: +# if self.exception: +# raise self.exception +# return oidc.IdToken(jwt=self.id_token) From 9f7ea31300e63bee7ae3430d86b0697163b21987 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 23 May 2025 17:16:53 +0000 Subject: [PATCH 09/10] Remove Github id token source --- databricks/sdk/oidc.py | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/databricks/sdk/oidc.py b/databricks/sdk/oidc.py index d8a7b1cc0..5c0af2949 100644 --- a/databricks/sdk/oidc.py +++ b/databricks/sdk/oidc.py @@ -10,8 +10,6 @@ from dataclasses import dataclass from typing import Optional -import requests - from . import oauth logger = logging.getLogger(__name__) @@ -125,31 +123,6 @@ def id_token(self) -> IdToken: return IdToken(jwt=token) -class GitHubIdTokenSource(IdTokenSource): - """ - Supplies OIDC tokens from GitHub Actions. - """ - - def __init__(self, request_token: str, request_url: str): - self._request_token = request_token - self._request_url = request_url - - def id_token(self, audience: str) -> IdToken: - # See https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers - headers = {"Authorization": f"Bearer {self._request_token}"} - endpoint = f"{self._request_url}&audience={audience}" - response = requests.get(endpoint, headers=headers) - if not response.ok: - raise ValueError(f"Failed to get ID token: {response.status_code} {response.text}") - - # get the ID Token with aud=api://AzureADTokenExchange sub=repo:org/repo:environment:name - response_json = response.json() - if "value" not in response_json: - raise ValueError("Missing value in response") - - return IdToken(jwt=response_json["value"]) - - class DatabricksOidcTokenSource(oauth.TokenSource): """A TokenSource which exchanges a token using Workload Identity Federation. From 9c5090e5b1913d86c7836d447d1c073d9e8bfe8e Mon Sep 17 00:00:00 2001 From: Renaud Hartert Date: Mon, 26 May 2025 11:42:05 +0200 Subject: [PATCH 10/10] Update test_oidc.py Signed-off-by: Renaud Hartert --- tests/test_oidc.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/tests/test_oidc.py b/tests/test_oidc.py index 302ef4c83..4bed32e96 100644 --- a/tests/test_oidc.py +++ b/tests/test_oidc.py @@ -107,14 +107,3 @@ def test_file_id_token_source(test_case: FileTestCase, tmp_path): source.id_token() else: assert source.id_token() == test_case.want - - -# class MockIdTokenSource(oidc.IdTokenSource): -# def __init__(self, id_token: str, exception: Exception = None): -# self.id_token = id_token -# self.exception = exception - -# def id_token(self) -> oidc.IdToken: -# if self.exception: -# raise self.exception -# return oidc.IdToken(jwt=self.id_token)