From dd7ec28723dbff32cdf70c63ff1e5b855adccefe Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 20 May 2025 12:45:27 +0000 Subject: [PATCH 01/10] Surface errors on specified auth type --- config/auth_default.go | 184 +++++++++++++-------------- config/auth_default_test.go | 53 ++++++-- config/token_source_strategy.go | 24 ++-- config/token_source_strategy_test.go | 2 +- 4 files changed, 142 insertions(+), 121 deletions(-) diff --git a/config/auth_default.go b/config/auth_default.go index e2be578fe..67da63446 100644 --- a/config/auth_default.go +++ b/config/auth_default.go @@ -2,7 +2,6 @@ package config import ( "context" - "errors" "fmt" "github.com/databricks/databricks-sdk-go/config/credentials" @@ -10,78 +9,7 @@ import ( "github.com/databricks/databricks-sdk-go/logger" ) -// Constructs all Databricks OIDC Credentials Strategies -func buildOidcTokenCredentialStrategies(cfg *Config) []CredentialsStrategy { - type namedIdTokenSource struct { - name string - tokenSource oidc.IDTokenSource - } - idTokenSources := []namedIdTokenSource{ - { - name: "env-oidc", - // If the OIDCTokenEnv is not set, use DATABRICKS_OIDC_TOKEN as - // default value. - tokenSource: func() oidc.IDTokenSource { - v := cfg.OIDCTokenEnv - if v == "" { - v = "DATABRICKS_OIDC_TOKEN" - } - return oidc.NewEnvIDTokenSource(v) - }(), - }, - { - name: "file-oidc", - tokenSource: oidc.NewFileTokenSource(cfg.OIDCTokenFilepath), - }, - { - name: "github-oidc", - tokenSource: oidc.NewGithubIDTokenSource( - cfg.refreshClient, - cfg.ActionsIDTokenRequestURL, - cfg.ActionsIDTokenRequestToken, - ), - }, - // Add new providers at the end of the list - } - - strategies := []CredentialsStrategy{} - for _, idTokenSource := range idTokenSources { - oidcConfig := oidc.DatabricksOIDCTokenSourceConfig{ - ClientID: cfg.ClientID, - Host: cfg.CanonicalHostName(), - TokenEndpointProvider: cfg.getOidcEndpoints, - Audience: cfg.TokenAudience, - IDTokenSource: idTokenSource.tokenSource, - } - if cfg.IsAccountClient() { - oidcConfig.AccountID = cfg.AccountID - } - tokenSource := oidc.NewDatabricksOIDCTokenSource(oidcConfig) - strategies = append(strategies, NewTokenSourceStrategy(idTokenSource.name, tokenSource)) - } - return strategies -} - -func buildDefaultStrategies(cfg *Config) []CredentialsStrategy { - strategies := []CredentialsStrategy{} - strategies = append(strategies, - PatCredentials{}, - BasicCredentials{}, - M2mCredentials{}, - DatabricksCliCredentials, - MetadataServiceCredentials{}) - strategies = append(strategies, buildOidcTokenCredentialStrategies(cfg)...) - strategies = append(strategies, - // Attempt to configure auth from most specific to most generic (the Azure CLI). - AzureGithubOIDCCredentials{}, - AzureMsiCredentials{}, - AzureClientSecretCredentials{}, - AzureCliCredentials{}, - // Attempt to configure auth from most specific to most generic (Google Application Default Credentials). - GoogleCredentials{}, - GoogleDefaultCredentials{}) - return strategies -} +const authDocURL = "https://docs.databricks.com/en/dev-tools/auth.html#databricks-client-unified-authentication" type DefaultCredentials struct { name string @@ -94,33 +22,103 @@ func (c *DefaultCredentials) Name() string { return c.name } -var authFlowUrl = "https://docs.databricks.com/en/dev-tools/auth.html#databricks-client-unified-authentication" -var errorMessage = fmt.Sprintf("cannot configure default credentials, please check %s to configure credentials for your preferred authentication method", authFlowUrl) - -// ErrCannotConfigureAuth (experimental) is returned when no auth is configured -var ErrCannotConfigureAuth = errors.New(errorMessage) - func (c *DefaultCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) { err := cfg.EnsureResolved() if err != nil { return nil, err } - for _, p := range buildDefaultStrategies(cfg) { - if cfg.AuthType != "" && p.Name() != cfg.AuthType { - // ignore other auth types if one is explicitly enforced - logger.Infof(ctx, "Ignoring %s auth, because %s is preferred", p.Name(), cfg.AuthType) - continue - } - logger.Tracef(ctx, "Attempting to configure auth: %s", p.Name()) - credentialsProvider, err := p.Configure(ctx, cfg) - if err != nil { - return nil, fmt.Errorf("%s: %w", p.Name(), err) + + // Order in which strategies are tested. Iteration proceeds from most + // specific to most generic, and the first strategy to return a non-nil + // credentials provider is selected. + // + // Modifying this order could break authentication for users whose + // environments are compatible with multiple strategies and who rely on the + // current priority for tie-breaking. While arguably an anti-pattern, this + // order is maintained for backward compatibility. + strategies := []CredentialsStrategy{ + PatCredentials{}, + BasicCredentials{}, + M2mCredentials{}, + u2mCredentials{}, + MetadataServiceCredentials{}, + // OIDC Strategies from most specific to most generic. + oidcStrategy(cfg, "github-oidc", githubOIDCTokenSource(cfg)), + oidcStrategy(cfg, "env-oidc", envOIDCTokenSource(cfg)), + oidcStrategy(cfg, "file-oidc", fileOIDCTokenSource(cfg)), + // Azure strategies from most specific to most generic. + AzureGithubOIDCCredentials{}, + AzureMsiCredentials{}, + AzureClientSecretCredentials{}, + AzureCliCredentials{}, + // Google strategies from most specific to most generic. + GoogleCredentials{}, + GoogleDefaultCredentials{}, + } + + // If an auth type is specified, try to configure the credentials for that + // specific auth type. If an error is encountered, return it. + if cfg.AuthType != "" { + for _, s := range strategies { + if s.Name() == cfg.AuthType { + logger.Tracef(ctx, "Attempting to configure auth: %q", s.Name()) + c.name = s.Name() + return s.Configure(ctx, cfg) + } } - if credentialsProvider == nil { + return nil, fmt.Errorf("auth type %q not found, please check %s for a list of supported auth types", cfg.AuthType, authDocURL) + } + + // If no auth type is specified, try the strategies in order. If a strategy + // succeeds, returns the credentials provider. If a strategy fails, swallow + // the error and try the next strategy. + for _, s := range strategies { + logger.Tracef(ctx, "Attempting to configure auth: %q", s.Name()) + cp, err := s.Configure(ctx, cfg) + if err != nil || cp == nil { + logger.Tracef(ctx, "Failed to configure auth: %q", s.Name()) continue } - c.name = p.Name() - return credentialsProvider, nil + c.name = s.Name() + return cp, nil } - return nil, ErrCannotConfigureAuth + + return nil, fmt.Errorf("cannot configure default credentials, please check %s to configure credentials for your preferred authentication method", authDocURL) +} + +// oidcStrategy returns a new CredentialsStrategy to authenticate with +// Databricks using the given OIDC IDTokenSource. +func oidcStrategy(cfg *Config, name string, ts oidc.IDTokenSource) CredentialsStrategy { + oidcConfig := oidc.DatabricksOIDCTokenSourceConfig{ + ClientID: cfg.ClientID, + Host: cfg.CanonicalHostName(), + TokenEndpointProvider: cfg.getOidcEndpoints, + Audience: cfg.TokenAudience, + IDTokenSource: ts, + } + if cfg.IsAccountClient() { + oidcConfig.AccountID = cfg.AccountID + } + tokenSource := oidc.NewDatabricksOIDCTokenSource(oidcConfig) + return NewTokenSourceStrategy(name, tokenSource) +} + +func githubOIDCTokenSource(cfg *Config) oidc.IDTokenSource { + return oidc.NewGithubIDTokenSource( + cfg.refreshClient, + cfg.ActionsIDTokenRequestURL, + cfg.ActionsIDTokenRequestToken, + ) +} + +func envOIDCTokenSource(cfg *Config) oidc.IDTokenSource { + v := cfg.OIDCTokenEnv + if v == "" { + v = "DATABRICKS_OIDC_TOKEN" + } + return oidc.NewEnvIDTokenSource(v) +} + +func fileOIDCTokenSource(cfg *Config) oidc.IDTokenSource { + return oidc.NewFileTokenSource(cfg.OIDCTokenFilepath) } diff --git a/config/auth_default_test.go b/config/auth_default_test.go index 78642c200..21fb5968f 100644 --- a/config/auth_default_test.go +++ b/config/auth_default_test.go @@ -1,19 +1,48 @@ -package config_test +package config import ( "context" - "errors" + "strings" "testing" - - "github.com/databricks/databricks-sdk-go" - "github.com/databricks/databricks-sdk-go/config" - "github.com/databricks/databricks-sdk-go/internal/env" - "github.com/stretchr/testify/assert" ) -func TestErrCannotConfigureAuth(t *testing.T) { - env.CleanupEnvironment(t) - w := databricks.Must(databricks.NewWorkspaceClient()) - _, err := w.CurrentUser.Me(context.Background()) - assert.True(t, errors.Is(err, config.ErrCannotConfigureAuth)) +func TestDefaultCredentials_Configure_unknownAuthType(t *testing.T) { + ctx := context.Background() + cfg := &Config{ + AuthType: "unknown-mode-for-test", + resolved: true, // avoid calling EnsureResolved + } + + dc := DefaultCredentials{} + got, gotErr := dc.Configure(ctx, cfg) + + if got != nil { + t.Errorf("DefaultCredentials.Configure: got %v, want nil", got) + } + if gotErr == nil { + t.Errorf("DefaultCredentials.Configure: got error %v, want non-nil", gotErr) + } + if !strings.Contains(gotErr.Error(), "auth type \"unknown-mode-for-test\" not found") { + t.Errorf("DefaultCredentials.Configure: got error %v, want error containing \"auth type \"unknown-mode-for-test\" not found\"", gotErr) + } +} + +func TestDefaultCredentials_Configure_noValidAuth(t *testing.T) { + ctx := context.Background() + cfg := &Config{ + resolved: true, // avoid calling EnsureResolved + } + + dc := DefaultCredentials{} + got, gotErr := dc.Configure(ctx, cfg) + + if got != nil { + t.Errorf("DefaultCredentials.Configure: got %v, want nil", got) + } + if gotErr == nil { + t.Errorf("DefaultCredentials.Configure: got error %v, want non-nil", gotErr) + } + if !strings.Contains(gotErr.Error(), "cannot configure default credentials") { + t.Errorf("DefaultCredentials.Configure: got error %v, want error containing \"cannot configure default credentials\"", gotErr) + } } diff --git a/config/token_source_strategy.go b/config/token_source_strategy.go index 45393ca31..5765bc0d6 100644 --- a/config/token_source_strategy.go +++ b/config/token_source_strategy.go @@ -11,32 +11,26 @@ import ( ) // Creates a CredentialsStrategy from a TokenSource. -func NewTokenSourceStrategy( - name string, - tokenSource auth.TokenSource, -) CredentialsStrategy { - return &tokenSourceStrategy{ - name: name, - tokenSource: tokenSource, - } +func NewTokenSourceStrategy(name string, ts auth.TokenSource) CredentialsStrategy { + return &tokenSourceStrategy{name: name, ts: ts} } -// tokenSourceStrategy is wrapper on a auth.TokenSource which converts it into a CredentialsStrategy +// tokenSourceStrategy is wrapper on a auth.TokenSource which converts it into +// a CredentialsStrategy. type tokenSourceStrategy struct { - tokenSource auth.TokenSource - name string + name string + ts auth.TokenSource } // Configure implements [CredentialsStrategy.Configure]. -func (t *tokenSourceStrategy) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) { - +func (tss *tokenSourceStrategy) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) { // If we cannot get a token, skip this CredentialsStrategy. // We don't want to fail here because it's possible that the supplier is enabled // without the user action. For instance, jobs running in GitHub will have // OIDC environment variables added automatically - cached := auth.NewCachedTokenSource(t.tokenSource) + cached := auth.NewCachedTokenSource(tss.ts) if _, err := cached.Token(ctx); err != nil { - logger.Debugf(ctx, fmt.Sprintf("Skipping %s due to error: %v", t.name, err)) + logger.Debugf(ctx, fmt.Sprintf("Skipping %s due to error: %v", tss.name, err)) return nil, nil } diff --git a/config/token_source_strategy_test.go b/config/token_source_strategy_test.go index 48ecc38c7..54315b133 100644 --- a/config/token_source_strategy_test.go +++ b/config/token_source_strategy_test.go @@ -35,7 +35,7 @@ func TestDatabricksTokenSourceStrategy(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { strat := &tokenSourceStrategy{ name: "github-oidc", - tokenSource: auth.TokenSourceFn(func(_ context.Context) (*oauth2.Token, error) { + ts: auth.TokenSourceFn(func(_ context.Context) (*oauth2.Token, error) { return tc.token, tc.tokenSourceError }), } From 879125efb0280d6e6f3473a6ed12be038c554fdf Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 20 May 2025 13:57:25 +0000 Subject: [PATCH 02/10] Clean-up --- config/credentials/credentials.go | 45 ++++++++++++++--- config/token_source_strategy.go | 19 +++----- config/token_source_strategy_test.go | 72 ++++++++++++++-------------- 3 files changed, 79 insertions(+), 57 deletions(-) diff --git a/config/credentials/credentials.go b/config/credentials/credentials.go index 0fe494d5c..580040fe2 100644 --- a/config/credentials/credentials.go +++ b/config/credentials/credentials.go @@ -1,8 +1,11 @@ package credentials import ( + "context" + "fmt" "net/http" + "github.com/databricks/databricks-sdk-go/config/experimental/auth" "golang.org/x/oauth2" ) @@ -32,22 +35,50 @@ type OAuthCredentialsProvider interface { Token() (*oauth2.Token, error) } -type oauthCredentialsProvider struct { - setHeaders func(r *http.Request) error - token func() (*oauth2.Token, error) +// NewOAuthCredentialsProviderFromTokenSource returns a new OAuthCredentialsProvider +// that uses the given TokenSource to get the token. +// +// The returned OAuthCredentialsProvider does not alter the behavior of the token +// source. For example, it does not implement any caching or retrying logic. It +// is the responsibility of the TokenSource to implement these behaviors. +func NewOAuthCredentialsProviderFromTokenSource(ts auth.TokenSource) OAuthCredentialsProvider { + return &tsOAuthCredentialsProvider{ts} } -func (c *oauthCredentialsProvider) SetHeaders(r *http.Request) error { - return c.setHeaders(r) +type tsOAuthCredentialsProvider struct { + ts auth.TokenSource } -func (c *oauthCredentialsProvider) Token() (*oauth2.Token, error) { - return c.token() +func (cp tsOAuthCredentialsProvider) SetHeaders(r *http.Request) error { + token, err := cp.ts.Token(context.Background()) + if err != nil { + return fmt.Errorf("error getting token: %w", err) + } + token.SetAuthHeader(r) + return nil +} + +func (cp tsOAuthCredentialsProvider) Token() (*oauth2.Token, error) { + return cp.ts.Token(context.Background()) } +// DEPRECATED: Use NewOAuthCredentialsProviderFromTokenSource instead. func NewOAuthCredentialsProvider(visitor func(r *http.Request) error, tokenProvider func() (*oauth2.Token, error)) OAuthCredentialsProvider { return &oauthCredentialsProvider{ setHeaders: visitor, token: tokenProvider, } } + +type oauthCredentialsProvider struct { + setHeaders func(r *http.Request) error + token func() (*oauth2.Token, error) +} + +func (c *oauthCredentialsProvider) SetHeaders(r *http.Request) error { + return c.setHeaders(r) +} + +func (c *oauthCredentialsProvider) Token() (*oauth2.Token, error) { + return c.token() +} diff --git a/config/token_source_strategy.go b/config/token_source_strategy.go index 5765bc0d6..e5e5180df 100644 --- a/config/token_source_strategy.go +++ b/config/token_source_strategy.go @@ -2,12 +2,9 @@ package config import ( "context" - "fmt" "github.com/databricks/databricks-sdk-go/config/credentials" "github.com/databricks/databricks-sdk-go/config/experimental/auth" - "github.com/databricks/databricks-sdk-go/config/experimental/auth/authconv" - "github.com/databricks/databricks-sdk-go/logger" ) // Creates a CredentialsStrategy from a TokenSource. @@ -24,18 +21,14 @@ type tokenSourceStrategy struct { // Configure implements [CredentialsStrategy.Configure]. func (tss *tokenSourceStrategy) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) { - // If we cannot get a token, skip this CredentialsStrategy. - // We don't want to fail here because it's possible that the supplier is enabled - // without the user action. For instance, jobs running in GitHub will have - // OIDC environment variables added automatically - cached := auth.NewCachedTokenSource(tss.ts) - if _, err := cached.Token(ctx); err != nil { - logger.Debugf(ctx, fmt.Sprintf("Skipping %s due to error: %v", tss.name, err)) - return nil, nil + cp := credentials.NewOAuthCredentialsProviderFromTokenSource(auth.NewCachedTokenSource(tss.ts)) + + // Sanity check that a token can be obtained. + if _, err := cp.Token(); err != nil { + return nil, err } - visitor := refreshableAuthVisitor(cached) - return credentials.NewOAuthCredentialsProvider(visitor, authconv.OAuth2TokenSource(cached).Token), nil + return cp, nil } // Name implements [CredentialsStrategy.Name]. diff --git a/config/token_source_strategy_test.go b/config/token_source_strategy_test.go index 54315b133..ea9265029 100644 --- a/config/token_source_strategy_test.go +++ b/config/token_source_strategy_test.go @@ -7,63 +7,61 @@ import ( "testing" "github.com/databricks/databricks-sdk-go/config/experimental/auth" - "github.com/google/go-cmp/cmp" "golang.org/x/oauth2" ) -func TestDatabricksTokenSourceStrategy(t *testing.T) { +func TestTokenSourceStrategy_Configure(t *testing.T) { testCases := []struct { - desc string - token *oauth2.Token - tokenSourceError error - wantHeaders http.Header + desc string + ts auth.TokenSource + wantHeader string + wantError bool }{ { - desc: "token source error skips", - tokenSourceError: errors.New("random error"), + desc: "token source return error", + ts: auth.TokenSourceFn(func(_ context.Context) (*oauth2.Token, error) { + return nil, errors.New("test error") + }), + wantError: true, }, { - desc: "token source error skips", - token: &oauth2.Token{ - AccessToken: "token-123", - }, - wantHeaders: http.Header{"Authorization": {"Bearer token-123"}}, + desc: "token source return token", + ts: auth.TokenSourceFn(func(_ context.Context) (*oauth2.Token, error) { + return &oauth2.Token{ + AccessToken: "token-123", + }, nil + }), + wantHeader: "Bearer token-123", }, } for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { - strat := &tokenSourceStrategy{ - name: "github-oidc", - ts: auth.TokenSourceFn(func(_ context.Context) (*oauth2.Token, error) { - return tc.token, tc.tokenSourceError - }), + start := &tokenSourceStrategy{ + name: "test-strategy", + ts: tc.ts, } - provider, err := strat.Configure(context.Background(), &Config{}) - if tc.tokenSourceError == nil && provider == nil { - t.Error("Provider expected to not be nil, but it is") - } - if tc.tokenSourceError != nil && provider != nil { - t.Error("A failure in the TokenSource should cause the provider to be nil, but it's not") - } - if err != nil { - t.Errorf("Configure() got error %q, want none", err) - } - - if provider != nil { - req, _ := http.NewRequest("GET", "http://localhost", nil) - gotErr := provider.SetHeaders(req) + provider, err := start.Configure(context.Background(), &Config{}) - if gotErr != nil { - t.Errorf("SetHeaders(): got error %q, want none", gotErr) + if tc.wantError { + if err == nil { + t.Errorf("Expected error, but got none") } - if diff := cmp.Diff(tc.wantHeaders, req.Header); diff != "" { - t.Errorf("Authenticate(): mismatch (-want +got):\n%s", diff) + if provider != nil { + t.Errorf("Expected provider to be nil, but it's not") } - + return } + if err != nil { + t.Errorf("Expected no error, but got %q", err) + } + req := &http.Request{Header: make(http.Header)} + provider.SetHeaders(req) + if req.Header.Get("Authorization") != tc.wantHeader { + t.Errorf("Expected header %q, but got %q", tc.wantHeader, req.Header.Get("Authorization")) + } }) } } From 0be50b034f00b2ef0ca96ceb4b691dc732204ce3 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 20 May 2025 13:59:54 +0000 Subject: [PATCH 03/10] Changelogs --- NEXT_CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 9683a9fb6..36a2565bf 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -6,6 +6,9 @@ ### Bug Fixes +- Stop swallowing authentication errors when a non-default auth type is + explicitly set ([#1223](https://github.com/databricks/databricks-sdk-go/pull/1223)) + ### Documentation ### Internal Changes From 44bb109c5890c1593fc7e78e030e6cb09f1d96ea Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 20 May 2025 14:04:26 +0000 Subject: [PATCH 04/10] Fix test --- config/auth_permutations_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/auth_permutations_test.go b/config/auth_permutations_test.go index 1f0c4ed02..ad43b0eb8 100644 --- a/config/auth_permutations_test.go +++ b/config/auth_permutations_test.go @@ -385,7 +385,7 @@ func TestConfig_AzureCliHost_Fail(t *testing.T) { "HOME": "testdata/azure", "FAIL": "yes", }, - AssertError: "default auth: azure-cli: cannot get access token: This is just a failing script.\n. Config: azure_workspace_resource_id=/sub/rg/ws", + AssertError: fmt.Sprintf("%s. Config: azure_workspace_resource_id=/sub/rg/ws", defaultAuthBaseErrorMessage), }.apply(t) } From 3546674599ece6f12c357cb8f24a3341227b08cc Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 20 May 2025 14:07:48 +0000 Subject: [PATCH 05/10] Simplify --- config/auth_default.go | 52 +++++++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/config/auth_default.go b/config/auth_default.go index 67da63446..db026d618 100644 --- a/config/auth_default.go +++ b/config/auth_default.go @@ -42,16 +42,16 @@ func (c *DefaultCredentials) Configure(ctx context.Context, cfg *Config) (creden M2mCredentials{}, u2mCredentials{}, MetadataServiceCredentials{}, - // OIDC Strategies from most specific to most generic. - oidcStrategy(cfg, "github-oidc", githubOIDCTokenSource(cfg)), - oidcStrategy(cfg, "env-oidc", envOIDCTokenSource(cfg)), - oidcStrategy(cfg, "file-oidc", fileOIDCTokenSource(cfg)), - // Azure strategies from most specific to most generic. + // OIDC Strategies. + githubOIDC(cfg), + envOIDC(cfg), + fileOIDC(cfg), + // Azure strategies. AzureGithubOIDCCredentials{}, AzureMsiCredentials{}, AzureClientSecretCredentials{}, AzureCliCredentials{}, - // Google strategies from most specific to most generic. + // Google strategies. GoogleCredentials{}, GoogleDefaultCredentials{}, } @@ -86,6 +86,26 @@ func (c *DefaultCredentials) Configure(ctx context.Context, cfg *Config) (creden return nil, fmt.Errorf("cannot configure default credentials, please check %s to configure credentials for your preferred authentication method", authDocURL) } +func githubOIDC(cfg *Config) CredentialsStrategy { + return oidcStrategy(cfg, "github-oidc", oidc.NewGithubIDTokenSource( + cfg.refreshClient, + cfg.ActionsIDTokenRequestURL, + cfg.ActionsIDTokenRequestToken, + )) +} + +func envOIDC(cfg *Config) CredentialsStrategy { + v := cfg.OIDCTokenEnv + if v == "" { + v = "DATABRICKS_OIDC_TOKEN" + } + return oidcStrategy(cfg, "env-oidc", oidc.NewEnvIDTokenSource(v)) +} + +func fileOIDC(cfg *Config) CredentialsStrategy { + return oidcStrategy(cfg, "file-oidc", oidc.NewFileTokenSource(cfg.OIDCTokenFilepath)) +} + // oidcStrategy returns a new CredentialsStrategy to authenticate with // Databricks using the given OIDC IDTokenSource. func oidcStrategy(cfg *Config, name string, ts oidc.IDTokenSource) CredentialsStrategy { @@ -102,23 +122,3 @@ func oidcStrategy(cfg *Config, name string, ts oidc.IDTokenSource) CredentialsSt tokenSource := oidc.NewDatabricksOIDCTokenSource(oidcConfig) return NewTokenSourceStrategy(name, tokenSource) } - -func githubOIDCTokenSource(cfg *Config) oidc.IDTokenSource { - return oidc.NewGithubIDTokenSource( - cfg.refreshClient, - cfg.ActionsIDTokenRequestURL, - cfg.ActionsIDTokenRequestToken, - ) -} - -func envOIDCTokenSource(cfg *Config) oidc.IDTokenSource { - v := cfg.OIDCTokenEnv - if v == "" { - v = "DATABRICKS_OIDC_TOKEN" - } - return oidc.NewEnvIDTokenSource(v) -} - -func fileOIDCTokenSource(cfg *Config) oidc.IDTokenSource { - return oidc.NewFileTokenSource(cfg.OIDCTokenFilepath) -} From 24910514230b22ba76055995e25e6b262ddfe8a4 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 20 May 2025 19:05:36 +0000 Subject: [PATCH 06/10] Update tests --- config/auth_azure_msi_test.go | 4 ++++ config/auth_metadata_service.go | 16 ++++++++-------- config/auth_metadata_service_test.go | 6 ++++++ 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/config/auth_azure_msi_test.go b/config/auth_azure_msi_test.go index ce2328153..f86a48266 100644 --- a/config/auth_azure_msi_test.go +++ b/config/auth_azure_msi_test.go @@ -47,6 +47,7 @@ func TestMsiHappyFlow(t *testing.T) { assertHeaders(t, &Config{ AzureUseMSI: true, AzureResourceID: "/a/b/c", + AuthType: "azure-msi", HTTPTransport: fixtures.MappingTransport{ "GET /metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.azure.com%2F": { ExpectedHeaders: map[string]string{ @@ -86,6 +87,7 @@ func TestMsiFailsOnResolveWorkspace(t *testing.T) { _, err := authenticateRequest(&Config{ AzureUseMSI: true, AzureResourceID: "/a/b/c", + AuthType: "azure-msi", HTTPTransport: fixtures.MappingTransport{ "GET /metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.azure.com%2F": { Response: someValidToken("bcd"), @@ -108,6 +110,7 @@ func TestMsiTokenNotFound(t *testing.T) { AzureUseMSI: true, AzureClientID: "abc", AzureResourceID: "/a/b/c", + AuthType: "azure-msi", HTTPTransport: fixtures.MappingTransport{ "GET /metadata/identity/oauth2/token?api-version=2018-02-01&client_id=abc&resource=https%3A%2F%2Fmanagement.azure.com%2F": { Status: 404, @@ -122,6 +125,7 @@ func TestMsiInvalidTokenExpiry(t *testing.T) { _, err := authenticateRequest(&Config{ AzureUseMSI: true, AzureResourceID: "/a/b/c", + AuthType: "azure-msi", HTTPTransport: fixtures.MappingTransport{ "GET /metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.azure.com%2F": { Response: map[string]any{ diff --git a/config/auth_metadata_service.go b/config/auth_metadata_service.go index beae500c4..3718b4a93 100644 --- a/config/auth_metadata_service.go +++ b/config/auth_metadata_service.go @@ -66,15 +66,15 @@ func (c MetadataServiceCredentials) Configure(ctx context.Context, cfg *Config) metadataServiceURL: parsedMetadataServiceURL, config: cfg, } - response, err := ms.Get() + response, err := ms.Get(ctx) if err != nil { return nil, err } if response == nil { return nil, nil } - visitor := refreshableVisitor(&ms) - return credentials.NewCredentialsProvider(visitor), nil + + return credentials.NewOAuthCredentialsProviderFromTokenSource(ms), nil } type metadataService struct { @@ -83,8 +83,8 @@ type metadataService struct { } // performs a request to the metadata service and returns the token -func (s metadataService) Get() (*oauth2.Token, error) { - ctx, cancel := context.WithTimeout(context.Background(), metadataServiceTimeout) +func (s metadataService) Get(ctx context.Context) (*oauth2.Token, error) { + ctx, cancel := context.WithTimeout(ctx, metadataServiceTimeout) defer cancel() var inner msiToken err := s.config.refreshClient.Do(ctx, http.MethodGet, @@ -99,15 +99,15 @@ func (s metadataService) Get() (*oauth2.Token, error) { return inner.Token() } -func (t metadataService) Token() (*oauth2.Token, error) { - token, err := t.Get() +func (t metadataService) Token(ctx context.Context) (*oauth2.Token, error) { + token, err := t.Get(ctx) if err != nil { return nil, err } if token == nil { return nil, fmt.Errorf("no token returned from metadata service") } - logger.Debugf(context.Background(), + logger.Debugf(ctx, "Refreshed access token from local metadata service, which expires on %s", token.Expiry.Format(time.RFC3339)) return token, nil diff --git a/config/auth_metadata_service_test.go b/config/auth_metadata_service_test.go index 0e0c70844..d6e54931f 100644 --- a/config/auth_metadata_service_test.go +++ b/config/auth_metadata_service_test.go @@ -12,6 +12,7 @@ import ( func TestAuthServerCheckHost(t *testing.T) { assertHeaders(t, &Config{ Host: "YYY", + AuthType: "metadata-service", MetadataServiceURL: "http://localhost:1234/metadata/token", HTTPTransport: fixtures.MappingTransport{ "GET /metadata/token": { @@ -31,6 +32,7 @@ func TestAuthServerCheckHost(t *testing.T) { func TestAuthServerRefresh(t *testing.T) { assertHeaders(t, &Config{ Host: "YYY", + AuthType: "metadata-service", MetadataServiceURL: "http://localhost:1234/metadata/token", HTTPTransport: fixtures.SliceTransport{ { @@ -56,6 +58,7 @@ func TestAuthServerRefresh(t *testing.T) { func TestAuthServerNotLocalhost(t *testing.T) { _, err := authenticateRequest(&Config{ Host: "YYY", + AuthType: "metadata-service", MetadataServiceURL: "http://otherhost/metadata/token", HTTPTransport: fixtures.Failures, }) @@ -65,6 +68,7 @@ func TestAuthServerNotLocalhost(t *testing.T) { func TestAuthServerMalformed(t *testing.T) { _, err := authenticateRequest(&Config{ Host: "YYY", + AuthType: "metadata-service", MetadataServiceURL: "#$%^&*", HTTPTransport: fixtures.Failures, }) @@ -74,6 +78,7 @@ func TestAuthServerMalformed(t *testing.T) { func TestAuthServerNoContent(t *testing.T) { _, err := authenticateRequest(&Config{ Host: "YYY", + AuthType: "metadata-service", MetadataServiceURL: "http://localhost:1234/metadata/token", HTTPTransport: fixtures.MappingTransport{ "GET /metadata/token": { @@ -87,6 +92,7 @@ func TestAuthServerNoContent(t *testing.T) { func TestAuthServerFailures(t *testing.T) { _, err := authenticateRequest(&Config{ Host: "YYY", + AuthType: "metadata-service", MetadataServiceURL: "http://localhost:1234/metadata/token", HTTPTransport: fixtures.Failures, }) From 7b198460a067cd6c04f99ac15988bb493a15e3ef Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 20 May 2025 19:20:33 +0000 Subject: [PATCH 07/10] Small simplification --- config/token_source_strategy_test.go | 34 ++++++++++++++++------------ 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/config/token_source_strategy_test.go b/config/token_source_strategy_test.go index ea9265029..713a757bc 100644 --- a/config/token_source_strategy_test.go +++ b/config/token_source_strategy_test.go @@ -6,10 +6,22 @@ import ( "net/http" "testing" + "github.com/databricks/databricks-sdk-go/config/credentials" "github.com/databricks/databricks-sdk-go/config/experimental/auth" "golang.org/x/oauth2" ) +// authHeader returns the Authorization header set by the given credentials +// provider. It returns an empty string if the provider is nil. +func authHeader(cp credentials.CredentialsProvider) string { + if cp == nil { + return "" + } + req := &http.Request{Header: http.Header{}} + cp.SetHeaders(req) + return req.Header.Get("Authorization") +} + func TestTokenSourceStrategy_Configure(t *testing.T) { testCases := []struct { desc string @@ -42,25 +54,17 @@ func TestTokenSourceStrategy_Configure(t *testing.T) { ts: tc.ts, } - provider, err := start.Configure(context.Background(), &Config{}) + cp, err := start.Configure(context.Background(), &Config{}) + gotHeader := authHeader(cp) - if tc.wantError { - if err == nil { - t.Errorf("Expected error, but got none") - } - if provider != nil { - t.Errorf("Expected provider to be nil, but it's not") - } - return + if tc.wantError && err == nil { + t.Errorf("Expected error, but got none") } - - if err != nil { + if !tc.wantError && err != nil { t.Errorf("Expected no error, but got %q", err) } - req := &http.Request{Header: make(http.Header)} - provider.SetHeaders(req) - if req.Header.Get("Authorization") != tc.wantHeader { - t.Errorf("Expected header %q, but got %q", tc.wantHeader, req.Header.Get("Authorization")) + if gotHeader != tc.wantHeader { + t.Errorf("Expected header %q, but got %q", tc.wantHeader, gotHeader) } }) } From 5c8e8933336876c926e22759b4dcdfe32ad2e6a5 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 20 May 2025 19:22:25 +0000 Subject: [PATCH 08/10] Fix test --- config/auth_m2m_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/config/auth_m2m_test.go b/config/auth_m2m_test.go index 1d77a126a..4284cbe62 100644 --- a/config/auth_m2m_test.go +++ b/config/auth_m2m_test.go @@ -74,6 +74,7 @@ func TestM2mNotSupported(t *testing.T) { Host: "a", ClientID: "b", ClientSecret: "c", + AuthType: "oauth-m2m", HTTPTransport: fixtures.MappingTransport{ "GET /oidc/.well-known/oauth-authorization-server": { Status: 404, From 47aa5e5ff7f06cecf8f93cb6c73ea98e7dd89024 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 20 May 2025 19:25:50 +0000 Subject: [PATCH 09/10] Add comment --- config/token_source_strategy.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/config/token_source_strategy.go b/config/token_source_strategy.go index e5e5180df..5181a3260 100644 --- a/config/token_source_strategy.go +++ b/config/token_source_strategy.go @@ -24,6 +24,9 @@ func (tss *tokenSourceStrategy) Configure(ctx context.Context, cfg *Config) (cre cp := credentials.NewOAuthCredentialsProviderFromTokenSource(auth.NewCachedTokenSource(tss.ts)) // Sanity check that a token can be obtained. + // + // TODO: Move this outside of this function. If credentials providers have + // to be tested, this should be done in the main default loop, not here. if _, err := cp.Token(); err != nil { return nil, err } From 2441754eecda11438338f9df8e98642ea7362f29 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 21 May 2025 07:23:16 +0000 Subject: [PATCH 10/10] Address comments --- config/auth_default.go | 2 +- config/auth_default_test.go | 69 +++++++++++++++++++------------------ 2 files changed, 36 insertions(+), 35 deletions(-) diff --git a/config/auth_default.go b/config/auth_default.go index db026d618..cb3b948e3 100644 --- a/config/auth_default.go +++ b/config/auth_default.go @@ -76,7 +76,7 @@ func (c *DefaultCredentials) Configure(ctx context.Context, cfg *Config) (creden logger.Tracef(ctx, "Attempting to configure auth: %q", s.Name()) cp, err := s.Configure(ctx, cfg) if err != nil || cp == nil { - logger.Tracef(ctx, "Failed to configure auth: %q", s.Name()) + logger.Debugf(ctx, "Failed to configure auth: %q", s.Name()) continue } c.name = s.Name() diff --git a/config/auth_default_test.go b/config/auth_default_test.go index 21fb5968f..fbcb67116 100644 --- a/config/auth_default_test.go +++ b/config/auth_default_test.go @@ -6,43 +6,44 @@ import ( "testing" ) -func TestDefaultCredentials_Configure_unknownAuthType(t *testing.T) { - ctx := context.Background() - cfg := &Config{ - AuthType: "unknown-mode-for-test", - resolved: true, // avoid calling EnsureResolved +func TestDefaultCredentials_Configure(t *testing.T) { + testCases := []struct { + desc string + authType string + wantErr string + }{ + { + desc: "unknown auth type", + authType: "unknown-auth-type-1337", + wantErr: "auth type \"unknown-auth-type-1337\" not found", + }, + { + desc: "not valid auth", + authType: "", + wantErr: "cannot configure default credentials", + }, } - dc := DefaultCredentials{} - got, gotErr := dc.Configure(ctx, cfg) + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + ctx := context.Background() + cfg := &Config{ + AuthType: tc.authType, + resolved: true, // avoid calling EnsureResolved + } - if got != nil { - t.Errorf("DefaultCredentials.Configure: got %v, want nil", got) - } - if gotErr == nil { - t.Errorf("DefaultCredentials.Configure: got error %v, want non-nil", gotErr) - } - if !strings.Contains(gotErr.Error(), "auth type \"unknown-mode-for-test\" not found") { - t.Errorf("DefaultCredentials.Configure: got error %v, want error containing \"auth type \"unknown-mode-for-test\" not found\"", gotErr) - } -} + dc := DefaultCredentials{} + got, gotErr := dc.Configure(ctx, cfg) -func TestDefaultCredentials_Configure_noValidAuth(t *testing.T) { - ctx := context.Background() - cfg := &Config{ - resolved: true, // avoid calling EnsureResolved - } - - dc := DefaultCredentials{} - got, gotErr := dc.Configure(ctx, cfg) - - if got != nil { - t.Errorf("DefaultCredentials.Configure: got %v, want nil", got) - } - if gotErr == nil { - t.Errorf("DefaultCredentials.Configure: got error %v, want non-nil", gotErr) - } - if !strings.Contains(gotErr.Error(), "cannot configure default credentials") { - t.Errorf("DefaultCredentials.Configure: got error %v, want error containing \"cannot configure default credentials\"", gotErr) + if got != nil { + t.Errorf("DefaultCredentials.Configure: got %v, want nil", got) + } + if gotErr == nil { + t.Errorf("DefaultCredentials.Configure: got error %v, want non-nil", gotErr) + } + if !strings.Contains(gotErr.Error(), tc.wantErr) { + t.Errorf("DefaultCredentials.Configure: got error %v, want error containing %q", gotErr, tc.wantErr) + } + }) } }