Skip to content

Add support for OIDC ID token authentication using an environment variable. #1215

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

### New Features and Improvements

- Add support for OIDC ID token authentication using an environment variable
([PR #1215](https://github.com/databricks/databricks-sdk-go/pull/1215)).

### Bug Fixes

### Documentation
Expand Down
3 changes: 2 additions & 1 deletion config/auth_databricks_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net/url"

"github.com/databricks/databricks-sdk-go/config/experimental/auth"
"github.com/databricks/databricks-sdk-go/config/experimental/auth/oidc"
"github.com/databricks/databricks-sdk-go/credentials/u2m"
"github.com/databricks/databricks-sdk-go/logger"
"golang.org/x/oauth2"
Expand Down Expand Up @@ -35,7 +36,7 @@ type DatabricksOIDCTokenSourceConfig struct {
// This is only used for Workspace level tokens.
Audience string
// IdTokenSource returns the IDToken to be used for the token exchange.
IdTokenSource IDTokenSource
IdTokenSource oidc.IDTokenSource
}

// databricksOIDCTokenSource is a auth.TokenSource which exchanges a token using
Expand Down
69 changes: 28 additions & 41 deletions config/auth_databricks_oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,13 @@ import (
"net/url"
"testing"

"github.com/databricks/databricks-sdk-go/config/experimental/auth/oidc"
"github.com/databricks/databricks-sdk-go/credentials/u2m"
"github.com/databricks/databricks-sdk-go/httpclient/fixtures"
"github.com/google/go-cmp/cmp"
"golang.org/x/oauth2"
)

type mockIdTokenProvider struct {
// input
audience string
// output
idToken string
err error
}

func (m *mockIdTokenProvider) IDToken(ctx context.Context, audience string) (*IDToken, error) {
m.audience = audience
return &IDToken{Value: m.idToken}, m.err
}

func TestDatabricksOidcTokenSource(t *testing.T) {
testCases := []struct {
desc string
Expand All @@ -36,7 +24,7 @@ func TestDatabricksOidcTokenSource(t *testing.T) {
httpTransport http.RoundTripper
oidcEndpointProvider func(context.Context) (*u2m.OAuthAuthorizationServer, error)
idToken string
expectedAudience string
wantAudience string
tokenProviderError error
wantToken string
wantErrPrefix *string
Expand Down Expand Up @@ -64,7 +52,7 @@ func TestDatabricksOidcTokenSource(t *testing.T) {
TokenEndpoint: "https://host.com/oidc/v1/token",
}, nil
},
expectedAudience: "token-audience",
wantAudience: "token-audience",
tokenProviderError: errors.New("error getting id token"),
wantErrPrefix: errPrefix("error getting id token"),
},
Expand All @@ -86,9 +74,9 @@ func TestDatabricksOidcTokenSource(t *testing.T) {
},
},
},
expectedAudience: "token-audience",
idToken: "id-token-42",
wantErrPrefix: errPrefix("oauth2: cannot fetch token: Internal Server Error"),
wantAudience: "token-audience",
idToken: "id-token-42",
wantErrPrefix: errPrefix("oauth2: cannot fetch token: Internal Server Error"),
},
{
desc: "invalid auth token",
Expand All @@ -111,9 +99,9 @@ func TestDatabricksOidcTokenSource(t *testing.T) {
},
},
},
expectedAudience: "token-audience",
idToken: "id-token-42",
wantErrPrefix: errPrefix("oauth2: server response missing access_token"),
wantAudience: "token-audience",
idToken: "id-token-42",
wantErrPrefix: errPrefix("oauth2: server response missing access_token"),
},
{
desc: "success workspace",
Expand Down Expand Up @@ -147,9 +135,9 @@ func TestDatabricksOidcTokenSource(t *testing.T) {
},
},
},
expectedAudience: "token-audience",
idToken: "id-token-42",
wantToken: "test-auth-token",
wantAudience: "token-audience",
idToken: "id-token-42",
wantToken: "test-auth-token",
},
{
desc: "success account",
Expand Down Expand Up @@ -183,9 +171,9 @@ func TestDatabricksOidcTokenSource(t *testing.T) {
},
},
},
expectedAudience: "token-audience",
idToken: "id-token-42",
wantToken: "test-auth-token",
wantAudience: "token-audience",
idToken: "id-token-42",
wantToken: "test-auth-token",
},
{
desc: "default token audience account",
Expand All @@ -211,9 +199,9 @@ func TestDatabricksOidcTokenSource(t *testing.T) {
},
},
},
expectedAudience: "ac123",
idToken: "id-token-42",
wantToken: "test-auth-token",
wantAudience: "ac123",
idToken: "id-token-42",
wantToken: "test-auth-token",
},
{
desc: "default token audience workspace",
Expand All @@ -238,26 +226,25 @@ func TestDatabricksOidcTokenSource(t *testing.T) {
},
},
},
expectedAudience: "https://host.com/oidc/v1/token",
idToken: "id-token-42",
wantToken: "test-auth-token",
wantAudience: "https://host.com/oidc/v1/token",
idToken: "id-token-42",
wantToken: "test-auth-token",
},
}

for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
p := &mockIdTokenProvider{
idToken: tc.idToken,
err: tc.tokenProviderError,
}

var gotAudience string // set when IDTokenSource is called
cfg := DatabricksOIDCTokenSourceConfig{
ClientID: tc.clientID,
AccountID: tc.accountID,
Host: tc.host,
TokenEndpointProvider: tc.oidcEndpointProvider,
Audience: tc.tokenAudience,
IdTokenSource: p,
IdTokenSource: oidc.IDTokenSourceFn(func(ctx context.Context, aud string) (*oidc.IDToken, error) {
gotAudience = aud
return &oidc.IDToken{Value: tc.idToken}, tc.tokenProviderError
}),
}

ts := NewDatabricksOIDCTokenSource(cfg)
Expand All @@ -283,8 +270,8 @@ func TestDatabricksOidcTokenSource(t *testing.T) {
if tc.wantErrPrefix != nil && !hasPrefix(err, *tc.wantErrPrefix) {
t.Errorf("Token(ctx): got error %q, want error with prefix %q", err, *tc.wantErrPrefix)
}
if tc.expectedAudience != p.audience {
t.Errorf("mockTokenProvider: got audience %s, want %s", p.audience, tc.expectedAudience)
if tc.wantAudience != gotAudience {
t.Errorf("mockTokenProvider: got audience %s, want %s", gotAudience, tc.wantAudience)
}
tokenValue := ""
if token != nil {
Expand Down
16 changes: 15 additions & 1 deletion config/auth_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,29 @@ import (
"fmt"

"github.com/databricks/databricks-sdk-go/config/credentials"
"github.com/databricks/databricks-sdk-go/config/experimental/auth/oidc"
"github.com/databricks/databricks-sdk-go/logger"
)

// Constructs all Databricks OIDC Credentials Strategies
func buildOidcTokenCredentialStrategies(cfg *Config) []CredentialsStrategy {
type namedIdTokenSource struct {
name string
tokenSource IDTokenSource
tokenSource oidc.IDTokenSource
}
idTokenSources := []namedIdTokenSource{
{
name: "env-oidc",
// If the OIDCTokenEnv is not set, use DATABRICKS_OIDC_TOKEN as
// default value.
tokenSource: func() oidc.IDTokenSource {
v := cfg.OIDCTokenEnv
if v == "" {
v = "DATABRICKS_OIDC_TOKEN"
}
return oidc.NewEnvIDTokenSource(v)
}(),
},
{
name: "github-oidc",
tokenSource: &githubIDTokenSource{
Expand All @@ -26,6 +39,7 @@ func buildOidcTokenCredentialStrategies(cfg *Config) []CredentialsStrategy {
},
// Add new providers at the end of the list
}

strategies := []CredentialsStrategy{}
for _, idTokenSource := range idTokenSources {
oidcConfig := DatabricksOIDCTokenSourceConfig{
Expand Down
3 changes: 3 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ type Config struct {
// specified by this argument. This argument also holds currently selected auth.
AuthType string `name:"auth_type" env:"DATABRICKS_AUTH_TYPE" auth:"-"`

// Environment variable name that contains an OIDC ID token.
OIDCTokenEnv string `name:"oidc_token_env" env:"DATABRICKS_OIDC_TOKEN_ENV" auth:"-"`
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it make sense to read this dynamically from the environment to support "refreshes"?
Note that neither AWS, GCP or Azure seem to support the refresh.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The actual environment variable referred to by this environment variable is read each time the IDTokenSource is called. I'm not sure it's worth also reading this one dynamically.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True. I missread this as if DATABRICKS_OIDC_TOKEN_ENV itself was the token.


// Skip SSL certificate verification for HTTP calls.
// Use at your own risk or for unit testing purposes.
InsecureSkipVerify bool `name:"skip_verify" auth:"-"`
Expand Down
50 changes: 50 additions & 0 deletions config/experimental/auth/oidc/oidc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Package oidc provides utilities for working with OIDC ID tokens.
//
// This package is experimental and subject to change.
package oidc

import (
"context"
"fmt"
"os"
)

// IDToken represents an OIDC ID token that can be exchanged for a Databricks
// access token.
type IDToken struct {
Value string
}

// IDTokenSource is anything that returns an IDToken given an audience.
type IDTokenSource interface {
IDToken(ctx context.Context, audience string) (*IDToken, error)
}

// IDTokenSourceFn is an adapter to allow the use of ordinary functions as
// IDTokenSource.
//
// Example:
//
// ts := IDTokenSourceFn(func(ctx context.Context, aud string) (*IDToken, error) {
// return &IDToken{}, nil
// })
type IDTokenSourceFn func(ctx context.Context, audience string) (*IDToken, error)

func (fn IDTokenSourceFn) IDToken(ctx context.Context, audience string) (*IDToken, error) {
return fn(ctx, audience)
}

// NewEnvIDTokenSource returns an IDTokenSource that reads the token from
// environment variable env.
//
// Note that the IDTokenSource does not cache the token and will read the token
// from environment variable env each time.
func NewEnvIDTokenSource(env string) IDTokenSource {
return IDTokenSourceFn(func(ctx context.Context, _ string) (*IDToken, error) {
t := os.Getenv(env)
if t == "" {
return nil, fmt.Errorf("missing env var %q", env)
}
return &IDToken{Value: t}, nil
})
}
106 changes: 106 additions & 0 deletions config/experimental/auth/oidc/oidc_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package oidc

import (
"context"
"fmt"
"testing"

"github.com/google/go-cmp/cmp"
)

func TestIDTokenSourceFn(t *testing.T) {
wantToken := &IDToken{Value: "from-func"}
wantErr := fmt.Errorf("test error")
wantAud := "func-audience"
wantCtx := context.Background()

ts := IDTokenSourceFn(func(gotCtx context.Context, gotAud string) (*IDToken, error) {
if gotCtx != wantCtx {
t.Errorf("unexpected context: got %v, want %v", gotCtx, wantCtx)
}
if gotAud != wantAud {
t.Errorf("unexpected audience: got %q, want %q", gotAud, wantAud)
}
return wantToken, wantErr
})

gotToken, gotErr := ts.IDToken(wantCtx, wantAud)

if gotErr != wantErr {
t.Errorf("IDToken() want error: %v, got error: %v", wantErr, gotErr)
}
if !cmp.Equal(gotToken, wantToken) {
t.Errorf("IDToken() token = %v, want %v", gotToken, wantToken)
}
}

func TestNewEnvIDTokenSource(t *testing.T) {
testCases := []struct {
desc string
envName string
envValue string
audience string
want *IDToken
wantErr bool
}{
{
desc: "Success - variable set",
envName: "OIDC_TEST_TOKEN_SUCCESS",
envValue: "test-token-123",
audience: "test-audience-1",
want: &IDToken{Value: "test-token-123"},
wantErr: false,
},
{
desc: "Failure - variable not set",
envName: "OIDC_TEST_TOKEN_MISSING",
envValue: "",
audience: "test-audience-2",
want: nil,
wantErr: true,
},
{
desc: "Failure - variable set to empty string",
envName: "OIDC_TEST_TOKEN_EMPTY",
envValue: "",
audience: "test-audience-3",
want: nil,
wantErr: true,
},
{
desc: "Success - different variable name",
envName: "ANOTHER_OIDC_TOKEN",
envValue: "another-token-456",
audience: "test-audience-4",
want: &IDToken{Value: "another-token-456"},
wantErr: false,
},
{
desc: "Success - empty audience string",
envName: "OIDC_TEST_TOKEN_NO_AUDIENCE",
envValue: "token-no-audience",
audience: "",
want: &IDToken{Value: "token-no-audience"},
wantErr: false,
},
}

for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
t.Setenv(tc.envName, tc.envValue)

ts := NewEnvIDTokenSource(tc.envName)
got, gotErr := ts.IDToken(context.Background(), tc.audience)

if tc.wantErr && gotErr == nil {
t.Fatalf("IDToken() want error, got none")
}
if !tc.wantErr && gotErr != nil {
t.Fatalf("IDToken() want no error, got error: %v", gotErr)
}
if !cmp.Equal(got, tc.want) {
t.Errorf("IDToken() token = %v, want %v", got, tc.want)
}
})
}
}
Loading
Loading