diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 5539142bf..4d66a7489 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -4,6 +4,11 @@ ### New Features and Improvements +- Add support for OIDC ID token authentication from an environment variable + ([PR #977](https://github.com/databricks/databricks-sdk-py/pull/977)). +- Add support for OIDC ID token authentication from a file + ([PR #977](https://github.com/databricks/databricks-sdk-py/pull/977)). + ### Bug Fixes ### Documentation diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index 1e674806f..487527be7 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -60,10 +60,21 @@ def with_user_agent_extra(key: str, value: str): class Config: host: str = ConfigAttribute(env="DATABRICKS_HOST") account_id: str = ConfigAttribute(env="DATABRICKS_ACCOUNT_ID") + + # PAT token. token: str = ConfigAttribute(env="DATABRICKS_TOKEN", auth="pat", sensitive=True) + + # Audience for OIDC ID token source accepting an audience as a parameter. + # For example, the GitHub action ID token source. token_audience: str = ConfigAttribute(env="DATABRICKS_TOKEN_AUDIENCE", auth="github-oidc") + + # Environment variable for OIDC token. + oidc_token_env: str = ConfigAttribute(env="DATABRICKS_OIDC_TOKEN_ENV", auth="env-oidc") + oidc_token_filepath: str = ConfigAttribute(env="DATABRICKS_OIDC_TOKEN_FILE", auth="file-oidc") + username: str = ConfigAttribute(env="DATABRICKS_USERNAME", auth="basic") password: str = ConfigAttribute(env="DATABRICKS_PASSWORD", auth="basic", sensitive=True) + client_id: str = ConfigAttribute(env="DATABRICKS_CLIENT_ID", auth="oauth") client_secret: str = ConfigAttribute(env="DATABRICKS_CLIENT_SECRET", auth="oauth", sensitive=True) profile: str = ConfigAttribute(env="DATABRICKS_CONFIG_PROFILE") @@ -194,7 +205,7 @@ def oauth_token(self) -> Token: def wrap_debug_info(self, message: str) -> str: debug_string = self.debug_string() if debug_string: - message = f'{message.rstrip(".")}. {debug_string}' + message = f"{message.rstrip('.')}. {debug_string}" return message @staticmethod @@ -337,9 +348,9 @@ def debug_string(self) -> str: safe = "***" if attr.sensitive else f"{value}" attrs_used.append(f"{attr.name}={safe}") if attrs_used: - buf.append(f'Config: {", ".join(attrs_used)}') + buf.append(f"Config: {', '.join(attrs_used)}") if envs_used: - buf.append(f'Env: {", ".join(envs_used)}') + buf.append(f"Env: {', '.join(envs_used)}") return ". ".join(buf) def to_dict(self) -> Dict[str, any]: @@ -481,7 +492,7 @@ def _known_file_config_loader(self): if profile not in profiles: raise ValueError(f"resolve: {config_path} has no {profile} profile configured") raw_config = profiles[profile] - logger.info(f'loading {profile} profile from {config_file}: {", ".join(raw_config.keys())}') + logger.info(f"loading {profile} profile from {config_file}: {', '.join(raw_config.keys())}") for k, v in raw_config.items(): if k in self._inner: # don't overwrite a value previously set diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index 86bd5c4d2..2f5121180 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -20,10 +20,7 @@ from google.auth.transport.requests import Request # type: ignore from google.oauth2 import service_account # type: ignore -from .azure import add_sp_management_token, add_workspace_id_header -from .oauth import (ClientCredentials, OAuthClient, Refreshable, Token, - TokenCache, TokenSource) -from .oidc_token_supplier import GitHubOIDCTokenSupplier +from . import azure, oauth, oidc, oidc_token_supplier CredentialsProvider = Callable[[], Dict[str, str]] @@ -36,7 +33,7 @@ class OAuthCredentialsProvider: def __init__( self, credentials_provider: CredentialsProvider, - token_provider: Callable[[], Token], + token_provider: Callable[[], oauth.Token], ): self._credentials_provider = credentials_provider self._token_provider = token_provider @@ -44,7 +41,7 @@ def __init__( def __call__(self) -> Dict[str, str]: return self._credentials_provider() - def oauth_token(self) -> Token: + def oauth_token(self) -> oauth.Token: return self._token_provider() @@ -77,7 +74,7 @@ def auth_type(self) -> str: def __call__(self, cfg: "Config") -> OAuthCredentialsProvider: return self._headers_provider(cfg) - def oauth_token(self, cfg: "Config") -> Token: + def oauth_token(self, cfg: "Config") -> oauth.Token: return self._headers_provider(cfg).oauth_token() @@ -89,7 +86,6 @@ def credentials_strategy(name: str, require: List[str]): def inner( func: Callable[["Config"], CredentialsProvider], ) -> CredentialsStrategy: - @functools.wraps(func) def wrapper(cfg: "Config") -> Optional[CredentialsProvider]: for attr in require: @@ -112,7 +108,6 @@ def oauth_credentials_strategy(name: str, require: List[str]): def inner( func: Callable[["Config"], OAuthCredentialsProvider], ) -> OauthCredentialsStrategy: - @functools.wraps(func) def wrapper(cfg: "Config") -> Optional[OAuthCredentialsProvider]: for attr in require: @@ -186,7 +181,7 @@ def oauth_service_principal(cfg: "Config") -> Optional[CredentialsProvider]: if oidc is None: return None - token_source = ClientCredentials( + token_source = oauth.ClientCredentials( client_id=cfg.client_id, client_secret=cfg.client_secret, token_url=oidc.token_endpoint, @@ -199,7 +194,7 @@ def inner() -> Dict[str, str]: token = token_source.token() return {"Authorization": f"{token.token_type} {token.access_token}"} - def token() -> Token: + def token() -> oauth.Token: return token_source.token() return OAuthCredentialsProvider(inner, token) @@ -224,7 +219,7 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]: # local to the Python SDK and not reused by other SDKs. oidc_endpoints = cfg.oidc_endpoints redirect_url = "http://localhost:8020" - token_cache = TokenCache( + token_cache = oauth.TokenCache( host=cfg.host, oidc_endpoints=oidc_endpoints, client_id=client_id, @@ -243,7 +238,7 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]: except Exception as e: logger.warning(f"Failed to refresh cached token: {e}. Initiating new OAuth login flow") - oauth_client = OAuthClient( + oauth_client = oauth.OAuthClient( oidc_endpoints=oidc_endpoints, client_id=client_id, redirect_url=redirect_url, @@ -258,7 +253,7 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]: return credentials(cfg) -def _ensure_host_present(cfg: "Config", token_source_for: Callable[[str], TokenSource]): +def _ensure_host_present(cfg: "Config", token_source_for: Callable[[str], oauth.TokenSource]): """Resolves Azure Databricks workspace URL from ARM Resource ID""" if cfg.host: return @@ -284,9 +279,9 @@ def azure_service_principal(cfg: "Config") -> CredentialsProvider: to every request, while automatically resolving different Azure environment endpoints. """ - def token_source_for(resource: str) -> TokenSource: + def token_source_for(resource: str) -> oauth.TokenSource: aad_endpoint = cfg.arm_environment.active_directory_endpoint - return ClientCredentials( + return oauth.ClientCredentials( client_id=cfg.azure_client_id, client_secret=cfg.azure_client_secret, token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token", @@ -305,18 +300,63 @@ def refreshed_headers() -> Dict[str, str]: headers = { "Authorization": f"Bearer {inner.token().access_token}", } - add_workspace_id_header(cfg, headers) - add_sp_management_token(cloud, headers) + azure.add_workspace_id_header(cfg, headers) + azure.add_sp_management_token(cloud, headers) return headers - def token() -> Token: + def token() -> oauth.Token: return inner.token() return OAuthCredentialsProvider(refreshed_headers, token) +@credentials_strategy("env-oidc", ["host"]) +def env_oidc(cfg) -> Optional[CredentialsProvider]: + # Search for an OIDC ID token in DATABRICKS_OIDC_TOKEN environment variable + # by default. This can be overridden by setting DATABRICKS_OIDC_TOKEN_ENV + # to the name of an environment variable that contains the OIDC ID token. + env_var = "DATABRICKS_OIDC_TOKEN" + if cfg.oidc_token_env: + env_var = cfg.oidc_token_env + + return _oidc_credentials_provider(cfg, oidc.EnvIdTokenSource(env_var)) + + +@credentials_strategy("file-oidc", ["host", "oidc_token_filepath"]) +def file_oidc(cfg) -> Optional[CredentialsProvider]: + return _oidc_credentials_provider(cfg, oidc.FileIdTokenSource(cfg.oidc_token_filepath)) + + +# This function is a helper function to create an OIDC CredentialsProvider +# that provides a Databricks token from an IdTokenSource. +def _oidc_credentials_provider(cfg, id_token_source: oidc.IdTokenSource) -> Optional[CredentialsProvider]: + try: + id_token = id_token_source.id_token() + except Exception as e: + logger.debug(f"Failed to get OIDC token: {e}") + return None + + token_source = oidc.DatabricksOidcTokenSource( + host=cfg.host, + token_endpoint=cfg.oidc_endpoints.token_endpoint, + client_id=cfg.client_id, + account_id=cfg.account_id, + id_token=id_token, + disable_async=cfg.disable_async_token_refresh, + ) + + def refreshed_headers() -> Dict[str, str]: + token = token_source.token() + return {"Authorization": f"{token.token_type} {token.access_token}"} + + def token() -> oauth.Token: + return token_source.token() + + return OAuthCredentialsProvider(refreshed_headers, token) + + @oauth_credentials_strategy("github-oidc", ["host", "client_id"]) -def databricks_wif(cfg: "Config") -> Optional[CredentialsProvider]: +def github_oidc(cfg: "Config") -> Optional[CredentialsProvider]: """ DatabricksWIFCredentials uses a Token Supplier to get a JWT Token and exchanges it for a Databricks Token. @@ -324,7 +364,7 @@ def databricks_wif(cfg: "Config") -> Optional[CredentialsProvider]: Supported suppliers: - GitHub OIDC """ - supplier = GitHubOIDCTokenSupplier() + supplier = oidc_token_supplier.GitHubOIDCTokenSupplier() audience = cfg.token_audience if audience is None and cfg.is_account_client: @@ -337,21 +377,21 @@ def databricks_wif(cfg: "Config") -> Optional[CredentialsProvider]: if not id_token: return None - def token_source_for(audience: str) -> TokenSource: + def token_source_for(audience: str) -> oauth.TokenSource: id_token = supplier.get_oidc_token(audience) if not id_token: # Should not happen, since we checked it above. raise Exception("Cannot get OIDC token") - params = { - "subject_token_type": "urn:ietf:params:oauth:token-type:jwt", - "subject_token": id_token, - "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", - } - return ClientCredentials( + + return oauth.ClientCredentials( client_id=cfg.client_id, client_secret="", # we have no (rotatable) secrets in OIDC flow token_url=cfg.oidc_endpoints.token_endpoint, - endpoint_params=params, + endpoint_params={ + "subject_token_type": "urn:ietf:params:oauth:token-type:jwt", + "subject_token": id_token, + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + }, scopes=["all-apis"], use_params=True, disable_async=cfg.disable_async_token_refresh, @@ -361,7 +401,7 @@ def refreshed_headers() -> Dict[str, str]: token = token_source_for(audience).token() return {"Authorization": f"{token.token_type} {token.access_token}"} - def token() -> Token: + def token() -> oauth.Token: return token_source_for(audience).token() return OAuthCredentialsProvider(refreshed_headers, token) @@ -378,7 +418,7 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]: if not cfg.is_azure: return None - token = GitHubOIDCTokenSupplier().get_oidc_token("api://AzureADTokenExchange") + token = oidc_token_supplier.GitHubOIDCTokenSupplier().get_oidc_token("api://AzureADTokenExchange") if not token: return None @@ -386,21 +426,22 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]: "Configured AAD token for GitHub Actions OIDC (%s)", cfg.azure_client_id, ) - params = { - "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", - "resource": cfg.effective_azure_login_app_id, - "client_assertion": token, - } + aad_endpoint = cfg.arm_environment.active_directory_endpoint if not cfg.azure_tenant_id: # detect Azure AD Tenant ID if it's not specified directly token_endpoint = cfg.oidc_endpoints.token_endpoint cfg.azure_tenant_id = token_endpoint.replace(aad_endpoint, "").split("/")[0] - inner = ClientCredentials( + + inner = oauth.ClientCredentials( client_id=cfg.azure_client_id, client_secret="", # we have no (rotatable) secrets in OIDC flow token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token", - endpoint_params=params, + endpoint_params={ + "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", + "resource": cfg.effective_azure_login_app_id, + "client_assertion": token, + }, use_params=True, disable_async=cfg.disable_async_token_refresh, ) @@ -409,7 +450,7 @@ def refreshed_headers() -> Dict[str, str]: token = inner.token() return {"Authorization": f"{token.token_type} {token.access_token}"} - def token() -> Token: + def token() -> oauth.Token: return inner.token() return OAuthCredentialsProvider(refreshed_headers, token) @@ -442,7 +483,7 @@ def google_credentials(cfg: "Config") -> Optional[CredentialsProvider]: gcp_credentials = service_account.Credentials.from_service_account_info(info=account_info, scopes=GcpScopes) - def token() -> Token: + def token() -> oauth.Token: credentials.refresh(request) return credentials.token @@ -483,7 +524,7 @@ def google_id(cfg: "Config") -> Optional[CredentialsProvider]: request = Request() - def token() -> Token: + def token() -> oauth.Token: id_creds.refresh(request) return id_creds.token @@ -498,8 +539,7 @@ def refreshed_headers() -> Dict[str, str]: return OAuthCredentialsProvider(refreshed_headers, token) -class CliTokenSource(Refreshable): - +class CliTokenSource(oauth.Refreshable): def __init__( self, cmd: List[str], @@ -525,12 +565,12 @@ def _parse_expiry(expiry: str) -> datetime: if last_e: raise last_e - def refresh(self) -> Token: + def refresh(self) -> oauth.Token: try: out = _run_subprocess(self._cmd, capture_output=True, check=True) it = json.loads(out.stdout.decode()) expires_on = self._parse_expiry(it[self._expiry_field]) - return Token( + return oauth.Token( access_token=it[self._access_token_field], token_type=it[self._token_type_field], expiry=expires_on, @@ -558,7 +598,7 @@ def _run_subprocess( kwargs["shell"] = sys.platform.startswith("win") # windows requires shell=True to be able to execute 'az login' or other commands # cannot use shell=True all the time, as it breaks macOS - logging.debug(f'Running command: {" ".join(popenargs)}') + logging.debug(f"Running command: {' '.join(popenargs)}") return subprocess.run( popenargs, input=input, @@ -689,7 +729,7 @@ def azure_cli(cfg: "Config") -> Optional[CredentialsProvider]: mgmt_token_source = AzureCliTokenSource.for_resource(cfg, management_endpoint) except Exception as e: logger.debug( - f"Not including service management token in headers", + "Not including service management token in headers", exc_info=e, ) mgmt_token_source = None @@ -700,9 +740,9 @@ def azure_cli(cfg: "Config") -> Optional[CredentialsProvider]: def inner() -> Dict[str, str]: token = token_source.token() headers = {"Authorization": f"{token.token_type} {token.access_token}"} - add_workspace_id_header(cfg, headers) + azure.add_workspace_id_header(cfg, headers) if mgmt_token_source: - add_sp_management_token(mgmt_token_source, headers) + azure.add_sp_management_token(mgmt_token_source, headers) return headers return inner @@ -784,13 +824,13 @@ def inner() -> Dict[str, str]: token = token_source.token() return {"Authorization": f"{token.token_type} {token.access_token}"} - def token() -> Token: + def token() -> oauth.Token: return token_source.token() return OAuthCredentialsProvider(inner, token) -class MetadataServiceTokenSource(Refreshable): +class MetadataServiceTokenSource(oauth.Refreshable): """Obtain the token granted by Databricks Metadata Service""" METADATA_SERVICE_VERSION = "1" @@ -803,7 +843,7 @@ def __init__(self, cfg: "Config"): self.url = cfg.metadata_service_url self.host = cfg.host - def refresh(self) -> Token: + def refresh(self) -> oauth.Token: resp = requests.get( self.url, timeout=self._metadata_service_timeout, @@ -831,7 +871,7 @@ def refresh(self) -> Token: except: raise ValueError("Metadata Service returned invalid expiry") - return Token(access_token=access_token, token_type=token_type, expiry=expiry) + return oauth.Token(access_token=access_token, token_type=token_type, expiry=expiry) @credentials_strategy("metadata-service", ["host", "metadata_service_url"]) @@ -972,7 +1012,9 @@ def __init__(self) -> None: basic_auth, metadata_service, oauth_service_principal, - databricks_wif, + env_oidc, + file_oidc, + github_oidc, azure_service_principal, github_oidc_azure, azure_cli, @@ -987,7 +1029,7 @@ def __init__(self) -> None: def auth_type(self) -> str: return self._auth_type - def oauth_token(self, cfg: "Config") -> Token: + def oauth_token(self, cfg: "Config") -> oauth.Token: for provider in self._auth_providers: auth_type = provider.auth_type() if auth_type != self._auth_type: @@ -1004,9 +1046,13 @@ def __call__(self, cfg: "Config") -> CredentialsProvider: continue logger.debug(f"Attempting to configure auth: {auth_type}") try: + # The header factory might be None if the provider cannot be + # configured for the current environment. For example, if the + # provider requires some missing environment variables. header_factory = provider(cfg) if not header_factory: continue + self._auth_type = auth_type return header_factory except Exception as e: diff --git a/databricks/sdk/oauth.py b/databricks/sdk/oauth.py index e099dbf07..f18f0cd51 100644 --- a/databricks/sdk/oauth.py +++ b/databricks/sdk/oauth.py @@ -156,7 +156,6 @@ def jwt_claims(self) -> Dict[str, str]: class TokenSource: - @abstractmethod def token(self) -> Token: pass @@ -341,7 +340,6 @@ def refresh(self) -> Token: class _OAuthCallback(BaseHTTPRequestHandler): - def __init__(self, feedback: list, *args): self._feedback = feedback super().__init__(*args) @@ -418,7 +416,6 @@ def get_azure_entra_id_workspace_endpoints( class SessionCredentials(Refreshable): - def __init__( self, token: Token, @@ -493,7 +490,6 @@ def refresh(self) -> Token: class Consent: - def __init__( self, state: str, @@ -627,7 +623,6 @@ def __init__( scopes: List[str] = None, client_secret: str = None, ): - if not scopes: # all-apis ensures that the returned OAuth token can be used with all APIs, aside # from direct-to-dataplane APIs. diff --git a/databricks/sdk/oidc.py b/databricks/sdk/oidc.py new file mode 100644 index 000000000..5c0af2949 --- /dev/null +++ b/databricks/sdk/oidc.py @@ -0,0 +1,206 @@ +""" +Package oidc provides utilities for working with OIDC ID tokens. + +This package is experimental and subject to change. +""" + +import logging +import os +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional + +from . import oauth + +logger = logging.getLogger(__name__) + + +@dataclass +class IdToken: + """Represents an OIDC ID token that can be exchanged for a Databricks access token. + + Parameters + ---------- + jwt : str + The signed JWT token string. + """ + + jwt: str + + +class IdTokenSource(ABC): + """Abstract base class representing anything that returns an IDToken. + + This class defines the interface for token sources that can provide OIDC ID tokens. + """ + + @abstractmethod + def id_token(self) -> IdToken: + """Get an ID token. + + Returns + ------- + IdToken + An ID token. + + Raises + ------ + Exception + Implementation specific exceptions. + """ + + +class EnvIdTokenSource(IdTokenSource): + """IDTokenSource that reads the ID token from an environment variable. + + Parameters + ---------- + env_var : str + The name of the environment variable containing the ID token. + """ + + def __init__(self, env_var: str): + self.env_var = env_var + + def id_token(self) -> IdToken: + """Get an ID token from an environment variable. + + Returns + ------- + IdToken + An ID token. + + Raises + ------ + ValueError + If the environment variable is not set. + """ + token = os.getenv(self.env_var) + if not token: + raise ValueError(f"Missing env var {self.env_var!r}") + return IdToken(jwt=token) + + +class FileIdTokenSource(IdTokenSource): + """IDTokenSource that reads the ID token from a file. + + Parameters + ---------- + path : str + The path to the file containing the ID token. + """ + + def __init__(self, path: str): + self.path = path + + def id_token(self) -> IdToken: + """Get an ID token from a file. + + Returns + ------- + IdToken + An ID token. + + Raises + ------ + ValueError + If the file is empty, does not exist, or cannot be read. + """ + if not self.path: + raise ValueError("Missing path") + + token = None + try: + with open(self.path, "r") as f: + token = f.read().strip() + except FileNotFoundError: + raise ValueError(f"File {self.path!r} does not exist") + except Exception as e: + raise ValueError(f"Error reading token file: {str(e)}") + + if not token: + raise ValueError(f"File {self.path!r} is empty") + return IdToken(jwt=token) + + +class DatabricksOidcTokenSource(oauth.TokenSource): + """A TokenSource which exchanges a token using Workload Identity Federation. + + Parameters + ---------- + host : str + The host of the Databricks account or workspace. + id_token_source : IdTokenSource + IDTokenSource that returns the IDToken to be used for the token exchange. + token_endpoint_provider : Callable[[], dict] + Returns the token endpoint for the Databricks OIDC application. + client_id : Optional[str], optional + ClientID of the Databricks OIDC application. It corresponds to the + Application ID of the Databricks Service Principal. Only required for + Workload Identity Federation and should be empty for Account-wide token + federation. + account_id : Optional[str], optional + The account ID of the Databricks Account. Only required for + Account-wide token federation. + audience : Optional[str], optional + The audience of the Databricks OIDC application. Only used for + Workspace level tokens. + """ + + def __init__( + self, + host: str, + token_endpoint: str, + id_token_source: IdTokenSource, + client_id: Optional[str] = None, + account_id: Optional[str] = None, + audience: Optional[str] = None, + disable_async: bool = False, + ): + self._host = host + self._id_token_source = id_token_source + self._token_endpoint = token_endpoint + self._client_id = client_id + self._account_id = account_id + self._audience = audience + self._disable_async = disable_async + + def token(self) -> oauth.Token: + """Get a token by exchanging the ID token. + + Returns + ------- + dict + The exchanged token. + + Raises + ------ + ValueError + If the host is missing or other configuration errors occur. + """ + if not self._host: + logger.debug("Missing Host") + raise ValueError("missing Host") + + if not self._client_id: + logger.debug("No ClientID provided, authenticating with Account-wide token federation") + else: + logger.debug("Client ID provided, authenticating with Workload Identity Federation") + + id_token = self._id_token_source.id_token() + + client = oauth.ClientCredentials( + client_id=self._client_id, + client_secret="", # we have no (rotatable) secrets in OIDC flow + token_url=self._token_endpoint, + endpoint_params={ + "subject_token_type": "urn:ietf:params:oauth:token-type:jwt", + "subject_token": id_token, + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + }, + scopes=["all-apis"], + use_params=True, + disable_async=self._disable_async, + ) + + return client.token() diff --git a/tests/test_config.py b/tests/test_config.py index dc9d8e410..b023123af 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -7,9 +7,8 @@ import pytest -from databricks.sdk import useragent +from databricks.sdk import oauth, useragent from databricks.sdk.config import Config, with_product, with_user_agent_extra -from databricks.sdk.credentials_provider import Token from databricks.sdk.version import __version__ from .conftest import noop_credentials, set_az_path @@ -114,7 +113,7 @@ def test_config_copy_deep_copies_user_agent_other_info(config): def test_config_deep_copy(monkeypatch, mocker, tmp_path): mocker.patch( "databricks.sdk.credentials_provider.CliTokenSource.refresh", - return_value=Token( + return_value=oauth.Token( access_token="token", token_type="Bearer", expiry=datetime(2023, 5, 22, 0, 0, 0), diff --git a/tests/test_credentials_provider.py b/tests/test_credentials_provider.py index fb24d9dc4..b23044d7c 100644 --- a/tests/test_credentials_provider.py +++ b/tests/test_credentials_provider.py @@ -26,7 +26,7 @@ def test_external_browser_refresh_success(mocker): # Inject the mock implementations. mocker.patch( - "databricks.sdk.credentials_provider.TokenCache", + "databricks.sdk.oauth.TokenCache", return_value=mock_token_cache, ) @@ -66,11 +66,11 @@ def test_external_browser_refresh_failure_new_oauth_flow(mocker): # Inject the mock implementations. mocker.patch( - "databricks.sdk.credentials_provider.TokenCache", + "databricks.sdk.oauth.TokenCache", return_value=mock_token_cache, ) mocker.patch( - "databricks.sdk.credentials_provider.OAuthClient", + "databricks.sdk.oauth.OAuthClient", return_value=mock_oauth_client, ) @@ -112,11 +112,11 @@ def test_external_browser_no_cached_credentials(mocker): # Inject the mock implementations. mocker.patch( - "databricks.sdk.credentials_provider.TokenCache", + "databricks.sdk.oauth.TokenCache", return_value=mock_token_cache, ) mocker.patch( - "databricks.sdk.credentials_provider.OAuthClient", + "databricks.sdk.oauth.OAuthClient", return_value=mock_oauth_client, ) @@ -150,11 +150,11 @@ def test_external_browser_consent_fails(mocker): # Inject the mock implementations. mocker.patch( - "databricks.sdk.credentials_provider.TokenCache", + "databricks.sdk.oauth.TokenCache", return_value=mock_token_cache, ) mocker.patch( - "databricks.sdk.credentials_provider.OAuthClient", + "databricks.sdk.oauth.OAuthClient", return_value=mock_oauth_client, ) diff --git a/tests/test_oidc.py b/tests/test_oidc.py new file mode 100644 index 000000000..4bed32e96 --- /dev/null +++ b/tests/test_oidc.py @@ -0,0 +1,109 @@ +from dataclasses import dataclass +from typing import Optional, Tuple + +import pytest + +from databricks.sdk import oidc + + +@dataclass +class EnvTestCase: + name: str + env_name: str = "" + env_value: str = "" + want: oidc.IdToken = None + wantException: Exception = None + + +_env_id_test_cases = [ + EnvTestCase( + name="success", + env_name="OIDC_TEST_TOKEN_SUCCESS", + env_value="test-token-123", + want=oidc.IdToken(jwt="test-token-123"), + ), + EnvTestCase( + name="missing_env_var", + env_name="OIDC_TEST_TOKEN_MISSING", + env_value="", + wantException=ValueError, + ), + EnvTestCase( + name="empty_env_var", + env_name="OIDC_TEST_TOKEN_EMPTY", + env_value="", + wantException=ValueError, + ), + EnvTestCase( + name="different_variable_name", + env_name="ANOTHER_OIDC_TOKEN", + env_value="another-token-456", + want=oidc.IdToken(jwt="another-token-456"), + ), +] + + +@pytest.mark.parametrize("test_case", _env_id_test_cases) +def test_env_id_token_source(test_case: EnvTestCase, monkeypatch): + monkeypatch.setenv(test_case.env_name, test_case.env_value) + + source = oidc.EnvIdTokenSource(test_case.env_name) + if test_case.wantException: + with pytest.raises(test_case.wantException): + source.id_token() + else: + assert source.id_token() == test_case.want + + +@dataclass +class FileTestCase: + name: str + file: Optional[Tuple[str, str]] = None # (name, content) + filepath: str = "" + want: oidc.IdToken = None + wantException: Exception = None + + +_file_id_test_cases = [ + FileTestCase( + name="missing_filepath", + file=("token", "content"), + filepath="", + wantException=ValueError, + ), + FileTestCase( + name="empty_file", + file=("token", ""), + filepath="token", + wantException=ValueError, + ), + FileTestCase( + name="file_does_not_exist", + filepath="nonexistent-file", + wantException=ValueError, + ), + FileTestCase( + name="file_exists", + file=("token", "content"), + filepath="token", + want=oidc.IdToken(jwt="content"), + ), +] + + +@pytest.mark.parametrize("test_case", _file_id_test_cases) +def test_file_id_token_source(test_case: FileTestCase, tmp_path): + if test_case.file: + token_file = tmp_path / test_case.file[0] + token_file.write_text(test_case.file[1]) + + fp = "" + if test_case.filepath: + fp = tmp_path / test_case.filepath + + source = oidc.FileIdTokenSource(fp) + if test_case.wantException: + with pytest.raises(test_case.wantException): + source.id_token() + else: + assert source.id_token() == test_case.want