From 5063be36a5e4a69f0c77f004b301b842e744271a Mon Sep 17 00:00:00 2001 From: Hector Castejon Diaz Date: Fri, 21 Mar 2025 09:50:05 +0100 Subject: [PATCH] Tests --- Makefile | 2 +- NEXT_CHANGELOG.md | 4 ++ README.md | 20 ++++---- databricks/sdk/__init__.py | 4 ++ databricks/sdk/config.py | 1 + databricks/sdk/credentials_provider.py | 61 ++++++++++++++++++---- databricks/sdk/oidc_token_supplier.py | 33 ++++++++++++ tests/integration/test_auth.py | 71 ++++++++++++++++++++++++++ 8 files changed, 174 insertions(+), 22 deletions(-) create mode 100644 databricks/sdk/oidc_token_supplier.py diff --git a/Makefile b/Makefile index c147f4074..5d5b5db2e 100644 --- a/Makefile +++ b/Makefile @@ -28,7 +28,7 @@ test: pytest -m 'not integration and not benchmark' --cov=databricks --cov-report html tests integration: - pytest -n auto -m 'integration and not benchmark' --reruns 2 --dist loadgroup --cov=databricks --cov-report html tests + pytest -n auto --dist loadgroup --cov=databricks --cov-report html tests/integration/test_auth.py benchmark: pytest -m 'benchmark' tests diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 395e1b0f7..726132f33 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -3,6 +3,10 @@ ## Release v0.47.0 ### New Features and Improvements +* Introduce support for Databricks Workload Identity Federation in GitHub workflows ([933](https://github.com/databricks/databricks-sdk-py/pull/933)). + See README.md for instructions. +* [Breaking] Users running their workflows in GitHub Actions that use cloud-native authentication and also have the `DATABRICKS_CLIENT_ID` and `DATABRICKS_HOST` + environment variables set may see their authentication start failing due to the order in which the SDK tries different authentication methods. 
### Bug Fixes diff --git a/README.md b/README.md index 9991c9cd0..2913bc02d 100644 --- a/README.md +++ b/README.md @@ -126,18 +126,18 @@ Depending on the Databricks authentication method, the SDK uses the following in ### Databricks native authentication -By default, the Databricks SDK for Python initially tries [Databricks token authentication](https://docs.databricks.com/dev-tools/api/latest/authentication.html) (`auth_type='pat'` argument). If the SDK is unsuccessful, it then tries Databricks basic (username/password) authentication (`auth_type="basic"` argument). +By default, the Databricks SDK for Python initially tries [Databricks token authentication](https://docs.databricks.com/dev-tools/api/latest/authentication.html) (`auth_type='pat'` argument). If the SDK is unsuccessful, it then tries Databricks Workload Identity Federation (WIF) authentication (`auth_type="databricks-wif"` argument). - For Databricks token authentication, you must provide `host` and `token`; or their environment variable or `.databrickscfg` file field equivalents. -- For Databricks basic authentication, you must provide `host`, `username`, and `password` _(for AWS workspace-level operations)_; or `host`, `account_id`, `username`, and `password` _(for AWS, Azure, or GCP account-level operations)_; or their environment variable or `.databrickscfg` file field equivalents. - -| Argument | Description | Environment variable | -|--------------|-------------|-------------------| -| `host` | _(String)_ The Databricks host URL for either the Databricks workspace endpoint or the Databricks accounts endpoint. | `DATABRICKS_HOST` | -| `account_id` | _(String)_ The Databricks account ID for the Databricks accounts endpoint. Only has effect when `Host` is either `https://accounts.cloud.databricks.com/` _(AWS)_, `https://accounts.azuredatabricks.net/` _(Azure)_, or `https://accounts.gcp.databricks.com/` _(GCP)_. 
| `DATABRICKS_ACCOUNT_ID` | -| `token` | _(String)_ The Databricks personal access token (PAT) _(AWS, Azure, and GCP)_ or Azure Active Directory (Azure AD) token _(Azure)_. | `DATABRICKS_TOKEN` | -| `username` | _(String)_ The Databricks username part of basic authentication. Only possible when `Host` is `*.cloud.databricks.com` _(AWS)_. | `DATABRICKS_USERNAME` | -| `password` | _(String)_ The Databricks password part of basic authentication. Only possible when `Host` is `*.cloud.databricks.com` _(AWS)_. | `DATABRICKS_PASSWORD` | +- For Databricks wif authentication, you must provide `host`, `client_id` and `token_audience` _(optional)_; or their environment variable or `.databrickscfg` file field equivalents. + +| Argument | Description | Environment variable | +|------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------| +| `host` | _(String)_ The Databricks host URL for either the Databricks workspace endpoint or the Databricks accounts endpoint. | `DATABRICKS_HOST` | +| `account_id` | _(String)_ The Databricks account ID for the Databricks accounts endpoint. Only has effect when `Host` is either `https://accounts.cloud.databricks.com/` _(AWS)_, `https://accounts.azuredatabricks.net/` _(Azure)_, or `https://accounts.gcp.databricks.com/` _(GCP)_. | `DATABRICKS_ACCOUNT_ID` | +| `token` | _(String)_ The Databricks personal access token (PAT) _(AWS, Azure, and GCP)_ or Azure Active Directory (Azure AD) token _(Azure)_. | `DATABRICKS_TOKEN` | +| `client_id` | _(String)_ The Databricks Service Principal Application ID. | `DATABRICKS_CLIENT_ID` | +| `token_audience` | _(String)_ When using Workload Identity Federation, the audience to specify when fetching an ID token from the ID token supplier. 
| `DATABRICKS_TOKEN_AUDIENCE` | For example, to use Databricks token authentication: diff --git a/databricks/sdk/__init__.py b/databricks/sdk/__init__.py index 806d8c584..fa8306045 100755 --- a/databricks/sdk/__init__.py +++ b/databricks/sdk/__init__.py @@ -168,6 +168,7 @@ def __init__( product_version="0.0.0", credentials_strategy: Optional[CredentialsStrategy] = None, credentials_provider: Optional[CredentialsStrategy] = None, + token_audience: Optional[str] = None, config: Optional[client.Config] = None, ): if not config: @@ -196,6 +197,7 @@ def __init__( debug_headers=debug_headers, product=product, product_version=product_version, + token_audience=token_audience, ) self._config = config.copy() self._dbutils = _make_dbutils(self._config) @@ -860,6 +862,7 @@ def __init__( product_version="0.0.0", credentials_strategy: Optional[CredentialsStrategy] = None, credentials_provider: Optional[CredentialsStrategy] = None, + token_audience: Optional[str] = None, config: Optional[client.Config] = None, ): if not config: @@ -888,6 +891,7 @@ def __init__( debug_headers=debug_headers, product=product, product_version=product_version, + token_audience=token_audience, ) self._config = config.copy() self._api_client = client.ApiClient(self._config) diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index 2a05cf6ba..f6e760a4f 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -61,6 +61,7 @@ class Config: host: str = ConfigAttribute(env="DATABRICKS_HOST") account_id: str = ConfigAttribute(env="DATABRICKS_ACCOUNT_ID") token: str = ConfigAttribute(env="DATABRICKS_TOKEN", auth="pat", sensitive=True) + token_audience: str = ConfigAttribute(env="DATABRICKS_TOKEN_AUDIENCE", auth="databricks-wif") username: str = ConfigAttribute(env="DATABRICKS_USERNAME", auth="basic") password: str = ConfigAttribute(env="DATABRICKS_PASSWORD", auth="basic", sensitive=True) client_id: str = ConfigAttribute(env="DATABRICKS_CLIENT_ID", auth="oauth") diff --git 
a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index eac7c9697..c780a4cc0 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -23,6 +23,7 @@ from .azure import add_sp_management_token, add_workspace_id_header from .oauth import (ClientCredentials, OAuthClient, Refreshable, Token, TokenCache, TokenSource) +from .oidc_token_supplier import GitHubOIDCTokenSupplier CredentialsProvider = Callable[[], Dict[str, str]] @@ -314,6 +315,51 @@ def token() -> Token: return OAuthCredentialsProvider(refreshed_headers, token) +@oauth_credentials_strategy("databricks-wif", ["host", "client_id"]) +def databricks_wif(cfg: "Config") -> Optional[CredentialsProvider]: + """ + DatabricksWIFCredentials uses a Token Supplier to get a JWT Token and exchanges + it for a Databricks Token. + + Supported suppliers: + - GitHub OIDC + """ + supplier = GitHubOIDCTokenSupplier() + # Try to get an idToken. If no supplier returns a token, we cannot use this authentication mode. + idToken = supplier.get_oidc_token(cfg.token_audience) + if not idToken: + return None + + def token_source_for(audience: str) -> TokenSource: + idToken = supplier.get_oidc_token(audience) + if not idToken: + # Should not happen, since we checked it above. 
+ raise Exception("Cannot get OIDC token") + params = { + "subject_token_type": "urn:ietf:params:oauth:token-type:jwt", + "subject_token": idToken, + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + } + return ClientCredentials( + client_id=cfg.client_id, + client_secret="", # we have no (rotatable) secrets in OIDC flow + token_url=cfg.oidc_endpoints.token_endpoint, + endpoint_params=params, + scopes=["all-apis"], + use_params=True, + disable_async=not cfg.enable_experimental_async_token_refresh, + ) + + def refreshed_headers() -> Dict[str, str]: + token = token_source_for(cfg.token_audience).token() + return {"Authorization": f"{token.token_type} {token.access_token}"} + + def token() -> Token: + return token_source_for(cfg.token_audience).token() + + return OAuthCredentialsProvider(refreshed_headers, token) + + @oauth_credentials_strategy("github-oidc-azure", ["host", "azure_client_id"]) def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]: if "ACTIONS_ID_TOKEN_REQUEST_TOKEN" not in os.environ: @@ -325,16 +371,8 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]: if not cfg.is_azure: return None - # See https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers - headers = {"Authorization": f"Bearer {os.environ['ACTIONS_ID_TOKEN_REQUEST_TOKEN']}"} - endpoint = f"{os.environ['ACTIONS_ID_TOKEN_REQUEST_URL']}&audience=api://AzureADTokenExchange" - response = requests.get(endpoint, headers=headers) - if not response.ok: - return None - - # get the ID Token with aud=api://AzureADTokenExchange sub=repo:org/repo:environment:name - response_json = response.json() - if "value" not in response_json: + token = GitHubOIDCTokenSupplier().get_oidc_token("api://AzureADTokenExchange") + if not token: return None logger.info( @@ -344,7 +382,7 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]: params = { "client_assertion_type": 
"urn:ietf:params:oauth:client-assertion-type:jwt-bearer", "resource": cfg.effective_azure_login_app_id, - "client_assertion": response_json["value"], + "client_assertion": token, } aad_endpoint = cfg.arm_environment.active_directory_endpoint if not cfg.azure_tenant_id: @@ -927,6 +965,7 @@ def __init__(self) -> None: basic_auth, metadata_service, oauth_service_principal, + databricks_wif, azure_service_principal, github_oidc_azure, azure_cli, diff --git a/databricks/sdk/oidc_token_supplier.py b/databricks/sdk/oidc_token_supplier.py new file mode 100644 index 000000000..174dbf128 --- /dev/null +++ b/databricks/sdk/oidc_token_supplier.py @@ -0,0 +1,33 @@ +import os +from time import sleep +from typing import Optional + +import requests + + +class GitHubOIDCTokenSupplier: + """ + Supplies OIDC tokens from GitHub Actions. + """ + + def get_oidc_token(self, audience: str) -> Optional[str]: + if "ACTIONS_ID_TOKEN_REQUEST_TOKEN" not in os.environ or "ACTIONS_ID_TOKEN_REQUEST_URL" not in os.environ: + # not in GitHub actions + return None + # See https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers + headers = {"Authorization": f"Bearer {os.environ['ACTIONS_ID_TOKEN_REQUEST_TOKEN']}"} + endpoint = f"{os.environ['ACTIONS_ID_TOKEN_REQUEST_URL']}&audience={audience}" + response = requests.get(endpoint, headers=headers) + if not response.ok: + return None + + # get the ID Token with aud=<requested audience> sub=repo:org/repo:environment:name + response_json = response.json() + if "value" not in response_json: + return None + + # GitHub's issued-at time is not always in sync, so it can return tokens which are not yet valid. + # TODO: Remove this after Databricks API is updated to handle such cases. 
+ sleep(2) + + return response_json["value"] diff --git a/tests/integration/test_auth.py b/tests/integration/test_auth.py index b50c54f1b..d15f4dc68 100644 --- a/tests/integration/test_auth.py +++ b/tests/integration/test_auth.py @@ -12,6 +12,8 @@ import pytest +from databricks.sdk import AccountClient, WorkspaceClient +from databricks.sdk.service import iam, oauth2 from databricks.sdk.service.compute import (ClusterSpec, DataSecurityMode, Library, ResultType, SparkVersion) from databricks.sdk.service.jobs import NotebookTask, Task, ViewType @@ -198,3 +200,72 @@ def _task_outputs(w, run): output += data["data"] task_outputs[task_run.task_key] = output return task_outputs + + +def test_wif_account(ucacct, env_or_skip, random): + + sp = ucacct.service_principals.create( + active=True, + display_name="py-sdk-test-" + random(), + roles=[iam.ComplexValue(value="account_admin")], + ) + + ucacct.service_principal_federation_policy.create( + policy=oauth2.FederationPolicy( + oidc_policy=oauth2.OidcFederationPolicy( + issuer="https://token.actions.githubusercontent.com", + audiences=["https://github.com/databricks-eng"], + subject="repo:databricks-eng/eng-dev-ecosystem:environment:integration-tests", + ) + ), + service_principal_id=sp.id, + ) + + ac = AccountClient( + host=ucacct.config.host, + account_id=ucacct.config.account_id, + client_id=sp.application_id, + auth_type="databricks-wif", + token_audience="https://github.com/databricks-eng", + ) + + groups = ac.groups.list() + + next(groups) + + +def test_wif_workspace(ucacct, env_or_skip, random): + + workspace_id = env_or_skip("TEST_WORKSPACE_ID") + workspace_url = env_or_skip("TEST_WORKSPACE_URL") + + sp = ucacct.service_principals.create( + active=True, + display_name="py-sdk-test-" + random(), + ) + + ucacct.service_principal_federation_policy.create( + policy=oauth2.FederationPolicy( + oidc_policy=oauth2.OidcFederationPolicy( + issuer="https://token.actions.githubusercontent.com", + 
audiences=["https://github.com/databricks-eng"], + subject="repo:databricks-eng/eng-dev-ecosystem:environment:integration-tests", + ) + ), + service_principal_id=sp.id, + ) + + ucacct.workspace_assignment.update( + workspace_id=workspace_id, + principal_id=sp.id, + permissions=[iam.WorkspacePermission.ADMIN], + ) + + ws = WorkspaceClient( + host=workspace_url, + client_id=sp.application_id, + auth_type="databricks-wif", + token_audience="https://github.com/databricks-eng", + ) + + ws.current_user.me()