Skip to content

Support Databricks Workload Identity Federation for GitHub tokens #933

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
* Enabled asynchronous token refreshes by default. A new `disable_async_token_refresh` configuration option has been added to allow disabling this feature if necessary ([#952](https://github.com/databricks/databricks-sdk-py/pull/952)).
To disable asynchronous token refresh, set the environment variable `DATABRICKS_DISABLE_ASYNC_TOKEN_REFRESH=true` or configure it within your configuration object.
The previous `enable_experimental_async_token_refresh` option has been removed as asynchronous refresh is now the default behavior.
* Introduce support for Databricks Workload Identity Federation in GitHub workflows ([933](https://github.com/databricks/databricks-sdk-py/pull/933)).
See README.md for instructions.
* [Breaking] Users running their workflows in GitHub Actions, which use Cloud native authentication and also have a `DATABRICKS_CLIENT_ID` and `DATABRICKS_HOST`
environment variables set may see their authentication start failing due to the order in which the SDK tries different authentication methods.

### Bug Fixes

Expand Down
20 changes: 10 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,18 +126,18 @@ Depending on the Databricks authentication method, the SDK uses the following in

### Databricks native authentication

By default, the Databricks SDK for Python initially tries [Databricks token authentication](https://docs.databricks.com/dev-tools/api/latest/authentication.html) (`auth_type='pat'` argument). If the SDK is unsuccessful, it then tries Databricks basic (username/password) authentication (`auth_type="basic"` argument).
By default, the Databricks SDK for Python initially tries [Databricks token authentication](https://docs.databricks.com/dev-tools/api/latest/authentication.html) (`auth_type='pat'` argument). If the SDK is unsuccessful, it then tries Databricks Workload Identity Federation (WIF) authentication using OIDC (`auth_type="github-oidc"` argument).

- For Databricks token authentication, you must provide `host` and `token`; or their environment variable or `.databrickscfg` file field equivalents.
- For Databricks basic authentication, you must provide `host`, `username`, and `password` _(for AWS workspace-level operations)_; or `host`, `account_id`, `username`, and `password` _(for AWS, Azure, or GCP account-level operations)_; or their environment variable or `.databrickscfg` file field equivalents.

| Argument | Description | Environment variable |
|--------------|-------------|-------------------|
| `host` | _(String)_ The Databricks host URL for either the Databricks workspace endpoint or the Databricks accounts endpoint. | `DATABRICKS_HOST` |
| `account_id` | _(String)_ The Databricks account ID for the Databricks accounts endpoint. Only has effect when `Host` is either `https://accounts.cloud.databricks.com/` _(AWS)_, `https://accounts.azuredatabricks.net/` _(Azure)_, or `https://accounts.gcp.databricks.com/` _(GCP)_. | `DATABRICKS_ACCOUNT_ID` |
| `token` | _(String)_ The Databricks personal access token (PAT) _(AWS, Azure, and GCP)_ or Azure Active Directory (Azure AD) token _(Azure)_. | `DATABRICKS_TOKEN` |
| `username` | _(String)_ The Databricks username part of basic authentication. Only possible when `Host` is `*.cloud.databricks.com` _(AWS)_. | `DATABRICKS_USERNAME` |
| `password` | _(String)_ The Databricks password part of basic authentication. Only possible when `Host` is `*.cloud.databricks.com` _(AWS)_. | `DATABRICKS_PASSWORD` |
- For Databricks OIDC authentication, you must provide the `host`, `client_id` and `token_audience` _(optional)_ either directly, through the corresponding environment variables, or in your `.databrickscfg` configuration file.

| Argument | Description | Environment variable |
|------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------|
| `host` | _(String)_ The Databricks host URL for either the Databricks workspace endpoint or the Databricks accounts endpoint. | `DATABRICKS_HOST` |
| `account_id` | _(String)_ The Databricks account ID for the Databricks accounts endpoint. Only has effect when `Host` is either `https://accounts.cloud.databricks.com/` _(AWS)_, `https://accounts.azuredatabricks.net/` _(Azure)_, or `https://accounts.gcp.databricks.com/` _(GCP)_. | `DATABRICKS_ACCOUNT_ID` |
| `token` | _(String)_ The Databricks personal access token (PAT) _(AWS, Azure, and GCP)_ or Azure Active Directory (Azure AD) token _(Azure)_. | `DATABRICKS_TOKEN` |
| `client_id` | _(String)_ The Databricks Service Principal Application ID. | `DATABRICKS_CLIENT_ID` |
| `token_audience` | _(String)_ When using Workload Identity Federation, the audience to specify when fetching an ID token from the ID token supplier. | `TOKEN_AUDIENCE` |

For example, to use Databricks token authentication:

Expand Down
4 changes: 4 additions & 0 deletions databricks/sdk/__init__.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions databricks/sdk/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class Config:
host: str = ConfigAttribute(env="DATABRICKS_HOST")
account_id: str = ConfigAttribute(env="DATABRICKS_ACCOUNT_ID")
token: str = ConfigAttribute(env="DATABRICKS_TOKEN", auth="pat", sensitive=True)
token_audience: str = ConfigAttribute(env="DATABRICKS_TOKEN_AUDIENCE", auth="github-oidc")
username: str = ConfigAttribute(env="DATABRICKS_USERNAME", auth="basic")
password: str = ConfigAttribute(env="DATABRICKS_PASSWORD", auth="basic", sensitive=True)
client_id: str = ConfigAttribute(env="DATABRICKS_CLIENT_ID", auth="oauth")
Expand Down
68 changes: 57 additions & 11 deletions databricks/sdk/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .azure import add_sp_management_token, add_workspace_id_header
from .oauth import (ClientCredentials, OAuthClient, Refreshable, Token,
TokenCache, TokenSource)
from .oidc_token_supplier import GitHubOIDCTokenSupplier

CredentialsProvider = Callable[[], Dict[str, str]]

Expand Down Expand Up @@ -314,6 +315,58 @@ def token() -> Token:
return OAuthCredentialsProvider(refreshed_headers, token)


@oauth_credentials_strategy("github-oidc", ["host", "client_id"])
def databricks_wif(cfg: "Config") -> Optional[CredentialsProvider]:
"""
DatabricksWIFCredentials uses a Token Supplier to get a JWT Token and exchanges
it for a Databricks Token.

Supported suppliers:
- GitHub OIDC
"""
supplier = GitHubOIDCTokenSupplier()

audience = cfg.token_audience
if audience is None and cfg.is_account_client:
audience = cfg.account_id
if audience is None and not cfg.is_account_client:
audience = cfg.oidc_endpoints.token_endpoint

# Try to get an idToken. If no supplier returns a token, we cannot use this authentication mode.
id_token = supplier.get_oidc_token(audience)
if not id_token:
return None

def token_source_for(audience: str) -> TokenSource:
id_token = supplier.get_oidc_token(audience)
if not id_token:
# Should not happen, since we checked it above.
raise Exception("Cannot get OIDC token")
params = {
"subject_token_type": "urn:ietf:params:oauth:token-type:jwt",
"subject_token": id_token,
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
}
return ClientCredentials(
client_id=cfg.client_id,
client_secret="", # we have no (rotatable) secrets in OIDC flow
token_url=cfg.oidc_endpoints.token_endpoint,
endpoint_params=params,
scopes=["all-apis"],
use_params=True,
disable_async=cfg.disable_async_token_refresh,
)

def refreshed_headers() -> Dict[str, str]:
token = token_source_for(audience).token()
return {"Authorization": f"{token.token_type} {token.access_token}"}

def token() -> Token:
return token_source_for(audience).token()

return OAuthCredentialsProvider(refreshed_headers, token)


@oauth_credentials_strategy("github-oidc-azure", ["host", "azure_client_id"])
def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]:
if "ACTIONS_ID_TOKEN_REQUEST_TOKEN" not in os.environ:
Expand All @@ -325,16 +378,8 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]:
if not cfg.is_azure:
return None

# See https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers
headers = {"Authorization": f"Bearer {os.environ['ACTIONS_ID_TOKEN_REQUEST_TOKEN']}"}
endpoint = f"{os.environ['ACTIONS_ID_TOKEN_REQUEST_URL']}&audience=api://AzureADTokenExchange"
response = requests.get(endpoint, headers=headers)
if not response.ok:
return None

# get the ID Token with aud=api://AzureADTokenExchange sub=repo:org/repo:environment:name
response_json = response.json()
if "value" not in response_json:
token = GitHubOIDCTokenSupplier().get_oidc_token("api://AzureADTokenExchange")
if not token:
return None

logger.info(
Expand All @@ -344,7 +389,7 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]:
params = {
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"resource": cfg.effective_azure_login_app_id,
"client_assertion": response_json["value"],
"client_assertion": token,
}
aad_endpoint = cfg.arm_environment.active_directory_endpoint
if not cfg.azure_tenant_id:
Expand Down Expand Up @@ -927,6 +972,7 @@ def __init__(self) -> None:
basic_auth,
metadata_service,
oauth_service_principal,
databricks_wif,
azure_service_principal,
github_oidc_azure,
azure_cli,
Expand Down
28 changes: 28 additions & 0 deletions databricks/sdk/oidc_token_supplier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import os
from typing import Optional

import requests


class GitHubOIDCTokenSupplier:
"""
Supplies OIDC tokens from GitHub Actions.
"""

def get_oidc_token(self, audience: str) -> Optional[str]:
if "ACTIONS_ID_TOKEN_REQUEST_TOKEN" not in os.environ or "ACTIONS_ID_TOKEN_REQUEST_URL" not in os.environ:
# not in GitHub actions
return None
# See https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers
headers = {"Authorization": f"Bearer {os.environ['ACTIONS_ID_TOKEN_REQUEST_TOKEN']}"}
endpoint = f"{os.environ['ACTIONS_ID_TOKEN_REQUEST_URL']}&audience={audience}"
response = requests.get(endpoint, headers=headers)
if not response.ok:
return None

# get the ID Token with aud=api://AzureADTokenExchange sub=repo:org/repo:environment:name
response_json = response.json()
if "value" not in response_json:
return None

return response_json["value"]
71 changes: 71 additions & 0 deletions tests/integration/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

import pytest

from databricks.sdk import AccountClient, WorkspaceClient
from databricks.sdk.service import iam, oauth2
from databricks.sdk.service.compute import (ClusterSpec, DataSecurityMode,
Library, ResultType, SparkVersion)
from databricks.sdk.service.jobs import NotebookTask, Task, ViewType
Expand Down Expand Up @@ -198,3 +200,72 @@ def _task_outputs(w, run):
output += data["data"]
task_outputs[task_run.task_key] = output
return task_outputs


def test_wif_account(ucacct, env_or_skip, random):

sp = ucacct.service_principals.create(
active=True,
display_name="py-sdk-test-" + random(),
roles=[iam.ComplexValue(value="account_admin")],
)

ucacct.service_principal_federation_policy.create(
policy=oauth2.FederationPolicy(
oidc_policy=oauth2.OidcFederationPolicy(
issuer="https://token.actions.githubusercontent.com",
audiences=["https://github.com/databricks-eng"],
subject="repo:databricks-eng/eng-dev-ecosystem:environment:integration-tests",
)
),
service_principal_id=sp.id,
)

ac = AccountClient(
host=ucacct.config.host,
account_id=ucacct.config.account_id,
client_id=sp.application_id,
auth_type="github-oidc",
token_audience="https://github.com/databricks-eng",
)

groups = ac.groups.list()

next(groups)


def test_wif_workspace(ucacct, env_or_skip, random):

workspace_id = env_or_skip("TEST_WORKSPACE_ID")
workspace_url = env_or_skip("TEST_WORKSPACE_URL")

sp = ucacct.service_principals.create(
active=True,
display_name="py-sdk-test-" + random(),
)

ucacct.service_principal_federation_policy.create(
policy=oauth2.FederationPolicy(
oidc_policy=oauth2.OidcFederationPolicy(
issuer="https://token.actions.githubusercontent.com",
audiences=["https://github.com/databricks-eng"],
subject="repo:databricks-eng/eng-dev-ecosystem:environment:integration-tests",
)
),
service_principal_id=sp.id,
)

ucacct.workspace_assignment.update(
workspace_id=workspace_id,
principal_id=sp.id,
permissions=[iam.WorkspacePermission.ADMIN],
)

ws = WorkspaceClient(
host=workspace_url,
client_id=sp.application_id,
auth_type="github-oidc",
token_audience="https://github.com/databricks-eng",
)

ws.current_user.me()
Loading