diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 9683a9fb6..36a2565bf 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -6,6 +6,9 @@ ### Bug Fixes +- Stop swallowing authentication errors when a non-default auth type is + explicitly set ([#1223](https://github.com/databricks/databricks-sdk-go/pull/1223)) + ### Documentation ### Internal Changes diff --git a/config/auth_azure_msi_test.go b/config/auth_azure_msi_test.go index ce2328153..f86a48266 100644 --- a/config/auth_azure_msi_test.go +++ b/config/auth_azure_msi_test.go @@ -47,6 +47,7 @@ func TestMsiHappyFlow(t *testing.T) { assertHeaders(t, &Config{ AzureUseMSI: true, AzureResourceID: "/a/b/c", + AuthType: "azure-msi", HTTPTransport: fixtures.MappingTransport{ "GET /metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.azure.com%2F": { ExpectedHeaders: map[string]string{ @@ -86,6 +87,7 @@ func TestMsiFailsOnResolveWorkspace(t *testing.T) { _, err := authenticateRequest(&Config{ AzureUseMSI: true, AzureResourceID: "/a/b/c", + AuthType: "azure-msi", HTTPTransport: fixtures.MappingTransport{ "GET /metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.azure.com%2F": { Response: someValidToken("bcd"), @@ -108,6 +110,7 @@ func TestMsiTokenNotFound(t *testing.T) { AzureUseMSI: true, AzureClientID: "abc", AzureResourceID: "/a/b/c", + AuthType: "azure-msi", HTTPTransport: fixtures.MappingTransport{ "GET /metadata/identity/oauth2/token?api-version=2018-02-01&client_id=abc&resource=https%3A%2F%2Fmanagement.azure.com%2F": { Status: 404, @@ -122,6 +125,7 @@ func TestMsiInvalidTokenExpiry(t *testing.T) { _, err := authenticateRequest(&Config{ AzureUseMSI: true, AzureResourceID: "/a/b/c", + AuthType: "azure-msi", HTTPTransport: fixtures.MappingTransport{ "GET /metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.azure.com%2F": { Response: map[string]any{ diff --git a/config/auth_default.go b/config/auth_default.go index e2be578fe..cb3b948e3 100644 --- a/config/auth_default.go +++ b/config/auth_default.go @@ -2,7 +2,6 @@ package config import ( "context" - "errors" "fmt" "github.com/databricks/databricks-sdk-go/config/credentials" @@ -10,117 +9,116 @@ import ( "github.com/databricks/databricks-sdk-go/logger" ) -// Constructs all Databricks OIDC Credentials Strategies -func buildOidcTokenCredentialStrategies(cfg *Config) []CredentialsStrategy { - type namedIdTokenSource struct { - name string - tokenSource oidc.IDTokenSource - } - idTokenSources := []namedIdTokenSource{ - { - name: "env-oidc", - // If the OIDCTokenEnv is not set, use DATABRICKS_OIDC_TOKEN as - // default value. - tokenSource: func() oidc.IDTokenSource { - v := cfg.OIDCTokenEnv - if v == "" { - v = "DATABRICKS_OIDC_TOKEN" - } - return oidc.NewEnvIDTokenSource(v) - }(), - }, - { - name: "file-oidc", - tokenSource: oidc.NewFileTokenSource(cfg.OIDCTokenFilepath), - }, - { - name: "github-oidc", - tokenSource: oidc.NewGithubIDTokenSource( - cfg.refreshClient, - cfg.ActionsIDTokenRequestURL, - cfg.ActionsIDTokenRequestToken, - ), - }, - // Add new providers at the end of the list - } +const authDocURL = "https://docs.databricks.com/en/dev-tools/auth.html#databricks-client-unified-authentication" - strategies := []CredentialsStrategy{} - for _, idTokenSource := range idTokenSources { - oidcConfig := oidc.DatabricksOIDCTokenSourceConfig{ - ClientID: cfg.ClientID, - Host: cfg.CanonicalHostName(), - TokenEndpointProvider: cfg.getOidcEndpoints, - Audience: cfg.TokenAudience, - IDTokenSource: idTokenSource.tokenSource, - } - if cfg.IsAccountClient() { - oidcConfig.AccountID = cfg.AccountID - } - tokenSource := oidc.NewDatabricksOIDCTokenSource(oidcConfig) - strategies = append(strategies, NewTokenSourceStrategy(idTokenSource.name, tokenSource)) +type DefaultCredentials struct { + name string +} + +func (c *DefaultCredentials) Name() string { + if c.name == "" { + return "default" } - return strategies + return c.name } -func buildDefaultStrategies(cfg *Config) []CredentialsStrategy { - strategies := []CredentialsStrategy{} - strategies = append(strategies, +func (c *DefaultCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) { + err := cfg.EnsureResolved() + if err != nil { + return nil, err + } + + // Order in which strategies are tested. Iteration proceeds from most + // specific to most generic, and the first strategy to return a non-nil + // credentials provider is selected. + // + // Modifying this order could break authentication for users whose + // environments are compatible with multiple strategies and who rely on the + // current priority for tie-breaking. While arguably an anti-pattern, this + // order is maintained for backward compatibility. + strategies := []CredentialsStrategy{ PatCredentials{}, BasicCredentials{}, M2mCredentials{}, - DatabricksCliCredentials, - MetadataServiceCredentials{}) - strategies = append(strategies, buildOidcTokenCredentialStrategies(cfg)...) - strategies = append(strategies, - // Attempt to configure auth from most specific to most generic (the Azure CLI). + u2mCredentials{}, + MetadataServiceCredentials{}, + // OIDC Strategies. + githubOIDC(cfg), + envOIDC(cfg), + fileOIDC(cfg), + // Azure strategies. AzureGithubOIDCCredentials{}, AzureMsiCredentials{}, AzureClientSecretCredentials{}, AzureCliCredentials{}, - // Attempt to configure auth from most specific to most generic (Google Application Default Credentials). + // Google strategies. GoogleCredentials{}, - GoogleDefaultCredentials{}) - return strategies + GoogleDefaultCredentials{}, + } + + // If an auth type is specified, try to configure the credentials for that + // specific auth type. If an error is encountered, return it. + if cfg.AuthType != "" { + for _, s := range strategies { + if s.Name() == cfg.AuthType { + logger.Tracef(ctx, "Attempting to configure auth: %q", s.Name()) + c.name = s.Name() + return s.Configure(ctx, cfg) + } + } + return nil, fmt.Errorf("auth type %q not found, please check %s for a list of supported auth types", cfg.AuthType, authDocURL) + } + + // If no auth type is specified, try the strategies in order. If a strategy + // succeeds, returns the credentials provider. If a strategy fails, swallow + // the error and try the next strategy. + for _, s := range strategies { + logger.Tracef(ctx, "Attempting to configure auth: %q", s.Name()) + cp, err := s.Configure(ctx, cfg) + if err != nil || cp == nil { + logger.Debugf(ctx, "Failed to configure auth: %q", s.Name()) + continue + } + c.name = s.Name() + return cp, nil + } + + return nil, fmt.Errorf("cannot configure default credentials, please check %s to configure credentials for your preferred authentication method", authDocURL) } -type DefaultCredentials struct { - name string +func githubOIDC(cfg *Config) CredentialsStrategy { + return oidcStrategy(cfg, "github-oidc", oidc.NewGithubIDTokenSource( + cfg.refreshClient, + cfg.ActionsIDTokenRequestURL, + cfg.ActionsIDTokenRequestToken, + )) } -func (c *DefaultCredentials) Name() string { - if c.name == "" { - return "default" +func envOIDC(cfg *Config) CredentialsStrategy { + v := cfg.OIDCTokenEnv + if v == "" { + v = "DATABRICKS_OIDC_TOKEN" } - return c.name + return oidcStrategy(cfg, "env-oidc", oidc.NewEnvIDTokenSource(v)) } -var authFlowUrl = "https://docs.databricks.com/en/dev-tools/auth.html#databricks-client-unified-authentication" -var errorMessage = fmt.Sprintf("cannot configure default credentials, please check %s to configure credentials for your preferred authentication method", authFlowUrl) - -// ErrCannotConfigureAuth (experimental) is returned when no auth is configured -var ErrCannotConfigureAuth = errors.New(errorMessage) +func fileOIDC(cfg *Config) CredentialsStrategy { + return oidcStrategy(cfg, "file-oidc", oidc.NewFileTokenSource(cfg.OIDCTokenFilepath)) +} -func (c *DefaultCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) { - err := cfg.EnsureResolved() - if err != nil { - return nil, err +// oidcStrategy returns a new CredentialsStrategy to authenticate with +// Databricks using the given OIDC IDTokenSource. +func oidcStrategy(cfg *Config, name string, ts oidc.IDTokenSource) CredentialsStrategy { + oidcConfig := oidc.DatabricksOIDCTokenSourceConfig{ + ClientID: cfg.ClientID, + Host: cfg.CanonicalHostName(), + TokenEndpointProvider: cfg.getOidcEndpoints, + Audience: cfg.TokenAudience, + IDTokenSource: ts, } - for _, p := range buildDefaultStrategies(cfg) { - if cfg.AuthType != "" && p.Name() != cfg.AuthType { - // ignore other auth types if one is explicitly enforced - logger.Infof(ctx, "Ignoring %s auth, because %s is preferred", p.Name(), cfg.AuthType) - continue - } - logger.Tracef(ctx, "Attempting to configure auth: %s", p.Name()) - credentialsProvider, err := p.Configure(ctx, cfg) - if err != nil { - return nil, fmt.Errorf("%s: %w", p.Name(), err) - } - if credentialsProvider == nil { - continue - } - c.name = p.Name() - return credentialsProvider, nil + if cfg.IsAccountClient() { + oidcConfig.AccountID = cfg.AccountID } - return nil, ErrCannotConfigureAuth + tokenSource := oidc.NewDatabricksOIDCTokenSource(oidcConfig) + return NewTokenSourceStrategy(name, tokenSource) } diff --git a/config/auth_default_test.go b/config/auth_default_test.go index 78642c200..fbcb67116 100644 --- a/config/auth_default_test.go +++ b/config/auth_default_test.go @@ -1,19 +1,49 @@ -package config_test +package config import ( "context" - "errors" + "strings" "testing" - - "github.com/databricks/databricks-sdk-go" - "github.com/databricks/databricks-sdk-go/config" - "github.com/databricks/databricks-sdk-go/internal/env" - "github.com/stretchr/testify/assert" ) -func TestErrCannotConfigureAuth(t *testing.T) { - env.CleanupEnvironment(t) - w := databricks.Must(databricks.NewWorkspaceClient()) - _, err := w.CurrentUser.Me(context.Background()) - assert.True(t, errors.Is(err, config.ErrCannotConfigureAuth)) +func TestDefaultCredentials_Configure(t *testing.T) { + testCases := []struct { + desc string + authType string + wantErr string + }{ + { + desc: "unknown auth type", + authType: "unknown-auth-type-1337", + wantErr: "auth type \"unknown-auth-type-1337\" not found", + }, + { + desc: "not valid auth", + authType: "", + wantErr: "cannot configure default credentials", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + ctx := context.Background() + cfg := &Config{ + AuthType: tc.authType, + resolved: true, // avoid calling EnsureResolved + } + + dc := DefaultCredentials{} + got, gotErr := dc.Configure(ctx, cfg) + + if got != nil { + t.Errorf("DefaultCredentials.Configure: got %v, want nil", got) + } + if gotErr == nil { + t.Errorf("DefaultCredentials.Configure: got error %v, want non-nil", gotErr) + } + if !strings.Contains(gotErr.Error(), tc.wantErr) { + t.Errorf("DefaultCredentials.Configure: got error %v, want error containing %q", gotErr, tc.wantErr) + } + }) + } } diff --git a/config/auth_m2m_test.go b/config/auth_m2m_test.go index 1d77a126a..4284cbe62 100644 --- a/config/auth_m2m_test.go +++ b/config/auth_m2m_test.go @@ -74,6 +74,7 @@ func TestM2mNotSupported(t *testing.T) { Host: "a", ClientID: "b", ClientSecret: "c", + AuthType: "oauth-m2m", HTTPTransport: fixtures.MappingTransport{ "GET /oidc/.well-known/oauth-authorization-server": { Status: 404, diff --git a/config/auth_metadata_service.go b/config/auth_metadata_service.go index beae500c4..3718b4a93 100644 --- a/config/auth_metadata_service.go +++ b/config/auth_metadata_service.go @@ -66,15 +66,15 @@ func (c MetadataServiceCredentials) Configure(ctx context.Context, cfg *Config) metadataServiceURL: parsedMetadataServiceURL, config: cfg, } - response, err := ms.Get() + response, err := ms.Get(ctx) if err != nil { return nil, err } if response == nil { return nil, nil } - visitor := refreshableVisitor(&ms) - return credentials.NewCredentialsProvider(visitor), nil + + return credentials.NewOAuthCredentialsProviderFromTokenSource(ms), nil } type metadataService struct { @@ -83,8 +83,8 @@ type metadataService struct { } // performs a request to the metadata service and returns the token -func (s metadataService) Get() (*oauth2.Token, error) { - ctx, cancel := context.WithTimeout(context.Background(), metadataServiceTimeout) +func (s metadataService) Get(ctx context.Context) (*oauth2.Token, error) { + ctx, cancel := context.WithTimeout(ctx, metadataServiceTimeout) defer cancel() var inner msiToken err := s.config.refreshClient.Do(ctx, http.MethodGet, @@ -99,15 +99,15 @@ func (s metadataService) Get() (*oauth2.Token, error) { return inner.Token() } -func (t metadataService) Token() (*oauth2.Token, error) { - token, err := t.Get() +func (t metadataService) Token(ctx context.Context) (*oauth2.Token, error) { + token, err := t.Get(ctx) if err != nil { return nil, err } if token == nil { return nil, fmt.Errorf("no token returned from metadata service") } - logger.Debugf(context.Background(), + logger.Debugf(ctx, "Refreshed access token from local metadata service, which expires on %s", token.Expiry.Format(time.RFC3339)) return token, nil diff --git a/config/auth_metadata_service_test.go b/config/auth_metadata_service_test.go index 0e0c70844..d6e54931f 100644 --- a/config/auth_metadata_service_test.go +++ b/config/auth_metadata_service_test.go @@ -12,6 +12,7 @@ import ( func TestAuthServerCheckHost(t *testing.T) { assertHeaders(t, &Config{ Host: "YYY", + AuthType: "metadata-service", MetadataServiceURL: "http://localhost:1234/metadata/token", HTTPTransport: fixtures.MappingTransport{ "GET /metadata/token": { @@ -31,6 +32,7 @@ func TestAuthServerCheckHost(t *testing.T) { func TestAuthServerRefresh(t *testing.T) { assertHeaders(t, &Config{ Host: "YYY", + AuthType: "metadata-service", MetadataServiceURL: "http://localhost:1234/metadata/token", HTTPTransport: fixtures.SliceTransport{ { @@ -56,6 +58,7 @@ func TestAuthServerRefresh(t *testing.T) { func TestAuthServerNotLocalhost(t *testing.T) { _, err := authenticateRequest(&Config{ Host: "YYY", + AuthType: "metadata-service", MetadataServiceURL: "http://otherhost/metadata/token", HTTPTransport: fixtures.Failures, }) @@ -65,6 +68,7 @@ func TestAuthServerNotLocalhost(t *testing.T) { func TestAuthServerMalformed(t *testing.T) { _, err := authenticateRequest(&Config{ Host: "YYY", + AuthType: "metadata-service", MetadataServiceURL: "#$%^&*", HTTPTransport: fixtures.Failures, }) @@ -74,6 +78,7 @@ func TestAuthServerMalformed(t *testing.T) { func TestAuthServerNoContent(t *testing.T) { _, err := authenticateRequest(&Config{ Host: "YYY", + AuthType: "metadata-service", MetadataServiceURL: "http://localhost:1234/metadata/token", HTTPTransport: fixtures.MappingTransport{ "GET /metadata/token": { @@ -87,6 +92,7 @@ func TestAuthServerNoContent(t *testing.T) { func TestAuthServerFailures(t *testing.T) { _, err := authenticateRequest(&Config{ Host: "YYY", + AuthType: "metadata-service", MetadataServiceURL: "http://localhost:1234/metadata/token", HTTPTransport: fixtures.Failures, }) diff --git a/config/auth_permutations_test.go b/config/auth_permutations_test.go index 1f0c4ed02..ad43b0eb8 100644 --- a/config/auth_permutations_test.go +++ b/config/auth_permutations_test.go @@ -385,7 +385,7 @@ func TestConfig_AzureCliHost_Fail(t *testing.T) { "HOME": "testdata/azure", "FAIL": "yes", }, - AssertError: "default auth: azure-cli: cannot get access token: This is just a failing script.\n. Config: azure_workspace_resource_id=/sub/rg/ws", + AssertError: fmt.Sprintf("%s. Config: azure_workspace_resource_id=/sub/rg/ws", defaultAuthBaseErrorMessage), }.apply(t) } diff --git a/config/credentials/credentials.go b/config/credentials/credentials.go index 0fe494d5c..580040fe2 100644 --- a/config/credentials/credentials.go +++ b/config/credentials/credentials.go @@ -1,8 +1,11 @@ package credentials import ( + "context" + "fmt" "net/http" + "github.com/databricks/databricks-sdk-go/config/experimental/auth" "golang.org/x/oauth2" ) @@ -32,22 +35,50 @@ type OAuthCredentialsProvider interface { Token() (*oauth2.Token, error) } -type oauthCredentialsProvider struct { - setHeaders func(r *http.Request) error - token func() (*oauth2.Token, error) +// NewOAuthCredentialsProviderFromTokenSource returns a new OAuthCredentialsProvider +// that uses the given TokenSource to get the token. +// +// The returned OAuthCredentialsProvider does not alter the behavior of the token +// source. For example, it does not implement any caching or retrying logic. It +// is the responsibility of the TokenSource to implement these behaviors. +func NewOAuthCredentialsProviderFromTokenSource(ts auth.TokenSource) OAuthCredentialsProvider { + return &tsOAuthCredentialsProvider{ts} } -func (c *oauthCredentialsProvider) SetHeaders(r *http.Request) error { - return c.setHeaders(r) +type tsOAuthCredentialsProvider struct { + ts auth.TokenSource } -func (c *oauthCredentialsProvider) Token() (*oauth2.Token, error) { - return c.token() +func (cp tsOAuthCredentialsProvider) SetHeaders(r *http.Request) error { + token, err := cp.ts.Token(context.Background()) + if err != nil { + return fmt.Errorf("error getting token: %w", err) + } + token.SetAuthHeader(r) + return nil +} + +func (cp tsOAuthCredentialsProvider) Token() (*oauth2.Token, error) { + return cp.ts.Token(context.Background()) } +// DEPRECATED: Use NewOAuthCredentialsProviderFromTokenSource instead. func NewOAuthCredentialsProvider(visitor func(r *http.Request) error, tokenProvider func() (*oauth2.Token, error)) OAuthCredentialsProvider { return &oauthCredentialsProvider{ setHeaders: visitor, token: tokenProvider, } } + +type oauthCredentialsProvider struct { + setHeaders func(r *http.Request) error + token func() (*oauth2.Token, error) +} + +func (c *oauthCredentialsProvider) SetHeaders(r *http.Request) error { + return c.setHeaders(r) +} + +func (c *oauthCredentialsProvider) Token() (*oauth2.Token, error) { + return c.token() +} diff --git a/config/token_source_strategy.go b/config/token_source_strategy.go index 45393ca31..5181a3260 100644 --- a/config/token_source_strategy.go +++ b/config/token_source_strategy.go @@ -2,46 +2,36 @@ package config import ( "context" - "fmt" "github.com/databricks/databricks-sdk-go/config/credentials" "github.com/databricks/databricks-sdk-go/config/experimental/auth" - "github.com/databricks/databricks-sdk-go/config/experimental/auth/authconv" - "github.com/databricks/databricks-sdk-go/logger" ) // Creates a CredentialsStrategy from a TokenSource. -func NewTokenSourceStrategy( - name string, - tokenSource auth.TokenSource, -) CredentialsStrategy { - return &tokenSourceStrategy{ - name: name, - tokenSource: tokenSource, - } +func NewTokenSourceStrategy(name string, ts auth.TokenSource) CredentialsStrategy { + return &tokenSourceStrategy{name: name, ts: ts} } -// tokenSourceStrategy is wrapper on a auth.TokenSource which converts it into a CredentialsStrategy +// tokenSourceStrategy is wrapper on a auth.TokenSource which converts it into +// a CredentialsStrategy. type tokenSourceStrategy struct { - tokenSource auth.TokenSource - name string + name string + ts auth.TokenSource } // Configure implements [CredentialsStrategy.Configure]. -func (t *tokenSourceStrategy) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) { - - // If we cannot get a token, skip this CredentialsStrategy. - // We don't want to fail here because it's possible that the supplier is enabled - // without the user action. For instance, jobs running in GitHub will have - // OIDC environment variables added automatically - cached := auth.NewCachedTokenSource(t.tokenSource) - if _, err := cached.Token(ctx); err != nil { - logger.Debugf(ctx, fmt.Sprintf("Skipping %s due to error: %v", t.name, err)) - return nil, nil +func (tss *tokenSourceStrategy) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) { + cp := credentials.NewOAuthCredentialsProviderFromTokenSource(auth.NewCachedTokenSource(tss.ts)) + + // Sanity check that a token can be obtained. + // + // TODO: Move this outside of this function. If credentials providers have + // to be tested, this should be done in the main default loop, not here. + if _, err := cp.Token(); err != nil { + return nil, err } - visitor := refreshableAuthVisitor(cached) - return credentials.NewOAuthCredentialsProvider(visitor, authconv.OAuth2TokenSource(cached).Token), nil + return cp, nil } // Name implements [CredentialsStrategy.Name]. diff --git a/config/token_source_strategy_test.go b/config/token_source_strategy_test.go index 48ecc38c7..713a757bc 100644 --- a/config/token_source_strategy_test.go +++ b/config/token_source_strategy_test.go @@ -6,64 +6,66 @@ import ( "net/http" "testing" + "github.com/databricks/databricks-sdk-go/config/credentials" "github.com/databricks/databricks-sdk-go/config/experimental/auth" - "github.com/google/go-cmp/cmp" "golang.org/x/oauth2" ) -func TestDatabricksTokenSourceStrategy(t *testing.T) { +// authHeader returns the Authorization header set by the given credentials +// provider. It returns an empty string if the provider is nil. +func authHeader(cp credentials.CredentialsProvider) string { + if cp == nil { + return "" + } + req := &http.Request{Header: http.Header{}} + cp.SetHeaders(req) + return req.Header.Get("Authorization") +} + +func TestTokenSourceStrategy_Configure(t *testing.T) { testCases := []struct { - desc string - token *oauth2.Token - tokenSourceError error - wantHeaders http.Header + desc string + ts auth.TokenSource + wantHeader string + wantError bool }{ { - desc: "token source error skips", - tokenSourceError: errors.New("random error"), + desc: "token source return error", + ts: auth.TokenSourceFn(func(_ context.Context) (*oauth2.Token, error) { + return nil, errors.New("test error") + }), + wantError: true, }, { - desc: "token source error skips", - token: &oauth2.Token{ - AccessToken: "token-123", - }, - wantHeaders: http.Header{"Authorization": {"Bearer token-123"}}, + desc: "token source return token", + ts: auth.TokenSourceFn(func(_ context.Context) (*oauth2.Token, error) { + return &oauth2.Token{ + AccessToken: "token-123", + }, nil + }), + wantHeader: "Bearer token-123", }, } for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { - strat := &tokenSourceStrategy{ - name: "github-oidc", - tokenSource: auth.TokenSourceFn(func(_ context.Context) (*oauth2.Token, error) { - return tc.token, tc.tokenSourceError - }), - } - provider, err := strat.Configure(context.Background(), &Config{}) - if tc.tokenSourceError == nil && provider == nil { - t.Error("Provider expected to not be nil, but it is") - } - if tc.tokenSourceError != nil && provider != nil { - t.Error("A failure in the TokenSource should cause the provider to be nil, but it's not") - } - if err != nil { - t.Errorf("Configure() got error %q, want none", err) + start := &tokenSourceStrategy{ + name: "test-strategy", + ts: tc.ts, } - if provider != nil { - req, _ := http.NewRequest("GET", "http://localhost", nil) - - gotErr := provider.SetHeaders(req) - - if gotErr != nil { - t.Errorf("SetHeaders(): got error %q, want none", gotErr) - } - if diff := cmp.Diff(tc.wantHeaders, req.Header); diff != "" { - t.Errorf("Authenticate(): mismatch (-want +got):\n%s", diff) - } + cp, err := start.Configure(context.Background(), &Config{}) + gotHeader := authHeader(cp) + if tc.wantError && err == nil { + t.Errorf("Expected error, but got none") + } + if !tc.wantError && err != nil { + t.Errorf("Expected no error, but got %q", err) + } + if gotHeader != tc.wantHeader { + t.Errorf("Expected header %q, but got %q", tc.wantHeader, gotHeader) } - }) } }