Skip to content

Add support to authenticate with Account-wide token federation #1219

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

### New Features and Improvements

- Add support to authenticate with Account-wide token federation from the
following auth methods: `env-oidc`, `file-oidc`, and `github-oidc`.
Comment on lines +7 to +8
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we document what is env-oidc and file-oidc somewhere?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not outside of the code as far as I can tell. Though, there is an ongoing effort to document these centrally.


### Bug Fixes

### Documentation
Expand Down
10 changes: 6 additions & 4 deletions config/auth_azure_github_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/databricks/databricks-sdk-go/config/credentials"
"github.com/databricks/databricks-sdk-go/config/experimental/auth/oidc"
"github.com/databricks/databricks-sdk-go/httpclient"
"golang.org/x/oauth2"
)
Expand All @@ -26,10 +27,11 @@ func (c AzureGithubOIDCCredentials) Configure(ctx context.Context, cfg *Config)
if !cfg.IsAzure() || cfg.AzureClientID == "" || cfg.Host == "" || cfg.AzureTenantID == "" || cfg.ActionsIDTokenRequestURL == "" || cfg.ActionsIDTokenRequestToken == "" {
return nil, nil
}
supplier := githubIDTokenSource{actionsIDTokenRequestURL: cfg.ActionsIDTokenRequestURL,
actionsIDTokenRequestToken: cfg.ActionsIDTokenRequestToken,
refreshClient: cfg.refreshClient,
}
supplier := oidc.NewGithubIDTokenSource(
cfg.refreshClient,
cfg.ActionsIDTokenRequestURL,
cfg.ActionsIDTokenRequestToken,
)

idToken, err := supplier.IDToken(ctx, "api://AzureADTokenExchange")
if err != nil {
Expand Down
16 changes: 8 additions & 8 deletions config/auth_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,28 +35,28 @@ func buildOidcTokenCredentialStrategies(cfg *Config) []CredentialsStrategy {
},
{
name: "github-oidc",
tokenSource: &githubIDTokenSource{
actionsIDTokenRequestURL: cfg.ActionsIDTokenRequestURL,
actionsIDTokenRequestToken: cfg.ActionsIDTokenRequestToken,
refreshClient: cfg.refreshClient,
},
tokenSource: oidc.NewGithubIDTokenSource(
cfg.refreshClient,
cfg.ActionsIDTokenRequestURL,
cfg.ActionsIDTokenRequestToken,
),
},
// Add new providers at the end of the list
}

strategies := []CredentialsStrategy{}
for _, idTokenSource := range idTokenSources {
oidcConfig := DatabricksOIDCTokenSourceConfig{
oidcConfig := oidc.DatabricksOIDCTokenSourceConfig{
ClientID: cfg.ClientID,
Host: cfg.CanonicalHostName(),
TokenEndpointProvider: cfg.getOidcEndpoints,
Audience: cfg.TokenAudience,
IdTokenSource: idTokenSource.tokenSource,
IDTokenSource: idTokenSource.tokenSource,
}
if cfg.IsAccountClient() {
oidcConfig.AccountID = cfg.AccountID
}
tokenSource := NewDatabricksOIDCTokenSource(oidcConfig)
tokenSource := oidc.NewDatabricksOIDCTokenSource(oidcConfig)
strategies = append(strategies, NewTokenSourceStrategy(idTokenSource.name, tokenSource))
}
return strategies
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
package config
package oidc

import (
"context"
"errors"
"fmt"

"github.com/databricks/databricks-sdk-go/config/experimental/auth/oidc"
"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/databricks/databricks-sdk-go/logger"
)

// NewGithubIDTokenSource returns a new IDTokenSource that retrieves an IDToken
// from the Github Actions environment. This IDTokenSource is only valid when
// running in Github Actions with OIDC enabled.
func NewGithubIDTokenSource(client *httpclient.ApiClient, actionsIDTokenRequestURL, actionsIDTokenRequestToken string) IDTokenSource {
return &githubIDTokenSource{
actionsIDTokenRequestURL: actionsIDTokenRequestURL,
actionsIDTokenRequestToken: actionsIDTokenRequestToken,
refreshClient: client,
}
}

// githubIDTokenSource retrieves JWT Tokens from Github Actions.
type githubIDTokenSource struct {
actionsIDTokenRequestURL string
Expand All @@ -19,7 +29,7 @@ type githubIDTokenSource struct {

// IDToken returns a JWT Token for the specified audience. It will return
// an error if not running in GitHub Actions.
func (g *githubIDTokenSource) IDToken(ctx context.Context, audience string) (*oidc.IDToken, error) {
func (g *githubIDTokenSource) IDToken(ctx context.Context, audience string) (*IDToken, error) {
if g.actionsIDTokenRequestURL == "" {
logger.Debugf(ctx, "Missing ActionsIDTokenRequestURL, likely not calling from a Github action")
return nil, errors.New("missing ActionsIDTokenRequestURL")
Expand All @@ -29,7 +39,7 @@ func (g *githubIDTokenSource) IDToken(ctx context.Context, audience string) (*oi
return nil, errors.New("missing ActionsIDTokenRequestToken")
}

resp := &oidc.IDToken{}
resp := &IDToken{}
requestUrl := g.actionsIDTokenRequestURL
if audience != "" {
requestUrl = fmt.Sprintf("%s&audience=%s", requestUrl, audience)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
package config
package oidc

import (
"context"
"net/http"
"testing"

"github.com/databricks/databricks-sdk-go/config/experimental/auth/oidc"
"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/databricks/databricks-sdk-go/httpclient/fixtures"
"github.com/google/go-cmp/cmp"
Expand All @@ -18,7 +17,7 @@ func TestGithubIDTokenSource(t *testing.T) {
tokenRequestToken string
audience string
httpTransport http.RoundTripper
wantToken *oidc.IDToken
wantToken *IDToken
wantErrPrefix *string
}{
{
Expand Down Expand Up @@ -60,7 +59,7 @@ func TestGithubIDTokenSource(t *testing.T) {
Response: `{"value": "id-token-42"}`,
},
},
wantToken: &oidc.IDToken{
wantToken: &IDToken{
Value: "id-token-42",
},
},
Expand Down
Original file line number Diff line number Diff line change
@@ -1,42 +1,49 @@
package config
package oidc

import (
"context"
"errors"
"net/url"

"github.com/databricks/databricks-sdk-go/config/experimental/auth"
"github.com/databricks/databricks-sdk-go/config/experimental/auth/oidc"
"github.com/databricks/databricks-sdk-go/credentials/u2m"
"github.com/databricks/databricks-sdk-go/logger"
"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
)

// Creates a new Databricks OIDC TokenSource.
func NewDatabricksOIDCTokenSource(cfg DatabricksOIDCTokenSourceConfig) auth.TokenSource {
return &databricksOIDCTokenSource{
cfg: cfg,
}
}

// Config for Databricks OIDC TokenSource.
// DatabricksOIDCTokenSourceConfig is the configuration for a Databricks OIDC
// TokenSource.
type DatabricksOIDCTokenSourceConfig struct {
// ClientID is the client ID of the Databricks OIDC application. For
// Databricks Service Principal, this is the Application ID of the Service Principal.
// ClientID of the Databricks OIDC application. It corresponds to the
// Application ID of the Databricks Service Principal.
//
// This field is only required for Workload Identity Federation and should
// be empty for Account-wide token federation.
ClientID string
// [Optional] AccountID is the account ID of the Databricks Account.
// This is only used for Account level tokens.

// AccountID is the account ID of the Databricks Account. This field is
// only required for Account-wide token federation.
AccountID string

// Host is the host of the Databricks account or workspace.
Host string
// TokenEndpointProvider returns the token endpoint for the Databricks OIDC application.

// TokenEndpointProvider returns the token endpoint for the Databricks OIDC
// application.
TokenEndpointProvider func(ctx context.Context) (*u2m.OAuthAuthorizationServer, error)

// Audience is the audience of the Databricks OIDC application.
// This is only used for Workspace level tokens.
Audience string
// IdTokenSource returns the IDToken to be used for the token exchange.
IdTokenSource oidc.IDTokenSource

// IDTokenSource returns the IDToken to be used for the token exchange.
IDTokenSource IDTokenSource
}

// NewDatabricksOIDCTokenSource returns a new Databricks OIDC TokenSource.
func NewDatabricksOIDCTokenSource(cfg DatabricksOIDCTokenSourceConfig) auth.TokenSource {
return &databricksOIDCTokenSource{cfg: cfg}
}

// databricksOIDCTokenSource is a auth.TokenSource which exchanges a token using
Expand All @@ -47,10 +54,6 @@ type databricksOIDCTokenSource struct {

// Token implements [TokenSource.Token]
func (w *databricksOIDCTokenSource) Token(ctx context.Context) (*oauth2.Token, error) {
if w.cfg.ClientID == "" {
logger.Debugf(ctx, "Missing ClientID")
return nil, errors.New("missing ClientID")
}
if w.cfg.Host == "" {
logger.Debugf(ctx, "Missing Host")
return nil, errors.New("missing Host")
Expand All @@ -59,8 +62,17 @@ func (w *databricksOIDCTokenSource) Token(ctx context.Context) (*oauth2.Token, e
if err != nil {
return nil, err
}

if w.cfg.ClientID == "" {
logger.Debugf(ctx, "No ClientID provided, authenticating with Account-wide token federation")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is more of a question than a comment: Should account-wide token federation also be added to the Java SDK? Currently, ClientID is not an optional field in the Java SDK.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question! Yes, we will have to add that in the Java SDK too.

} else {
logger.Debugf(ctx, "Client ID provided, authenticating with Workload Identity Federation")
}

// TODO: The audience is a concept of the IDToken that should likely be
// configured when the IDTokenSource is created.
audience := w.determineAudience(endpoints)
idToken, err := w.cfg.IdTokenSource.IDToken(ctx, audience)
idToken, err := w.cfg.IDTokenSource.IDToken(ctx, audience)
if err != nil {
return nil, err
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
package config
package oidc

import (
"context"
"errors"
"net/http"
"net/url"
"strings"
"testing"

"github.com/databricks/databricks-sdk-go/config/experimental/auth/oidc"
"github.com/databricks/databricks-sdk-go/credentials/u2m"
"github.com/databricks/databricks-sdk-go/httpclient/fixtures"
"github.com/google/go-cmp/cmp"
"golang.org/x/oauth2"
)

func errPrefix(s string) *string {
return &s
}

func hasPrefix(err error, prefix string) bool {
return strings.HasPrefix(err.Error(), prefix)
}

func TestDatabricksOidcTokenSource(t *testing.T) {
testCases := []struct {
desc string
Expand All @@ -35,12 +43,6 @@ func TestDatabricksOidcTokenSource(t *testing.T) {
tokenAudience: "token-audience",
wantErrPrefix: errPrefix("missing Host"),
},
{
desc: "missing client ID",
host: "http://host.com",
tokenAudience: "token-audience",
wantErrPrefix: errPrefix("missing ClientID"),
},
{
desc: "token provider error",

Expand Down Expand Up @@ -104,7 +106,7 @@ func TestDatabricksOidcTokenSource(t *testing.T) {
wantErrPrefix: errPrefix("oauth2: server response missing access_token"),
},
{
desc: "success workspace",
desc: "success WIF workspace",
clientID: "client-id",
host: "http://host.com",
tokenAudience: "token-audience",
Expand Down Expand Up @@ -140,7 +142,7 @@ func TestDatabricksOidcTokenSource(t *testing.T) {
wantToken: "test-auth-token",
},
{
desc: "success account",
desc: "success WIF account",
clientID: "client-id",
accountID: "ac123",
host: "https://accounts.databricks.com",
Expand Down Expand Up @@ -230,6 +232,40 @@ func TestDatabricksOidcTokenSource(t *testing.T) {
idToken: "id-token-42",
wantToken: "test-auth-token",
},
{
desc: "success account-wide",
host: "http://host.com",
tokenAudience: "token-audience",
oidcEndpointProvider: func(ctx context.Context) (*u2m.OAuthAuthorizationServer, error) {
return &u2m.OAuthAuthorizationServer{
TokenEndpoint: "https://host.com/oidc/v1/token",
}, nil
},
httpTransport: fixtures.MappingTransport{
"POST /oidc/v1/token": {

Status: http.StatusOK,
ExpectedHeaders: map[string]string{
"Content-Type": "application/x-www-form-urlencoded",
},
ExpectedRequest: url.Values{
"scope": {"all-apis"},
"subject_token_type": {"urn:ietf:params:oauth:token-type:jwt"},
"subject_token": {"id-token-42"},
"grant_type": {"urn:ietf:params:oauth:grant-type:token-exchange"},
},
Response: map[string]string{
"token_type": "access-token",
"access_token": "test-auth-token",
"refresh_token": "refresh",
"expires_on": "0",
},
},
},
wantAudience: "token-audience",
idToken: "id-token-42",
wantToken: "test-auth-token",
},
}

for _, tc := range testCases {
Expand All @@ -241,9 +277,9 @@ func TestDatabricksOidcTokenSource(t *testing.T) {
Host: tc.host,
TokenEndpointProvider: tc.oidcEndpointProvider,
Audience: tc.tokenAudience,
IdTokenSource: oidc.IDTokenSourceFn(func(ctx context.Context, aud string) (*oidc.IDToken, error) {
IDTokenSource: IDTokenSourceFn(func(ctx context.Context, aud string) (*IDToken, error) {
gotAudience = aud
return &oidc.IDToken{Value: tc.idToken}, tc.tokenProviderError
return &IDToken{Value: tc.idToken}, tc.tokenProviderError
}),
}

Expand Down
Loading