diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index a80f60d2b..7e31bcafc 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -3,6 +3,11 @@ ## Release v0.65.0 ### New Features and Improvements +* Introduce support for Databricks Workload Identity Federation in GitHub workflows ([1177](https://github.com/databricks/databricks-sdk-go/pull/1177)). + See README.md for instructions. +* [Breaking] Users running their workflows in GitHub Actions that use cloud-native authentication and also have the `DATABRICKS_CLIENT_ID` and `DATABRICKS_HOST` + environment variables set may see their authentication start failing due to the order in which the SDK tries different authentication methods. + In such a case, the `DATABRICKS_AUTH_TYPE` environment variable must be set to match the previously used authentication method. ### Bug Fixes diff --git a/README.md b/README.md index a299507ee..ae8413273 100644 --- a/README.md +++ b/README.md @@ -14,19 +14,35 @@ The Databricks SDK for Go includes functionality to accelerate development with ## Contents -- [Getting started](#getting-started) -- [Authentication](#authentication) -- [Code examples](#code-examples) -- [Long running operations](#long-running-operations) -- [Paginated responses](#paginated-responses) -- [GetByName utility methods](#getbyname-utility-methods) -- [Node type and Databricks Runtime selectors](#node-type-and-databricks-runtime-selectors) -- [Integration with `io` interfaces for DBFS](#integration-with-io-interfaces-for-dbfs) -- [User Agent Request Attribution](#user-agent-request-attribution) -- [Error Handling](#error-handling) -- [Logging](#logging) +- [Databricks SDK for Go](#databricks-sdk-for-go) + - [Contents](#contents) + - [Getting started](#getting-started) + - [Authentication](#authentication) + - [In this section](#in-this-section) + - [Default authentication flow](#default-authentication-flow) + - [Databricks native authentication](#databricks-native-authentication) + - [Azure native 
authentication](#azure-native-authentication) + - [Google Cloud Platform native authentication](#google-cloud-platform-native-authentication) + - [Overriding `.databrickscfg`](#overriding-databrickscfg) + - [Additional authentication configuration options](#additional-authentication-configuration-options) + - [Custom credentials provider](#custom-credentials-provider) + - [Code examples](#code-examples) + - [Long-running operations](#long-running-operations) + - [In this section](#in-this-section-1) + - [Command execution on clusters](#command-execution-on-clusters) + - [Cluster library management](#cluster-library-management) + - [Advanced usage](#advanced-usage) + - [Paginated responses](#paginated-responses) + - [`GetByName` utility methods](#getbyname-utility-methods) + - [Node type and Databricks Runtime selectors](#node-type-and-databricks-runtime-selectors) + - [Integration with `io` interfaces for DBFS](#integration-with-io-interfaces-for-dbfs) + - [Reading into and writing from buffers](#reading-into-and-writing-from-buffers) + - [`pflag.Value` for enums](#pflagvalue-for-enums) + - [User Agent Request Attribution](#user-agent-request-attribution) + - [Error handling](#error-handling) + - [Logging](#logging) - [Testing](#testing) -- [Interface stability](#interface-stability) + - [Interface stability](#interface-stability) ## Getting started @@ -158,18 +174,17 @@ Depending on the Databricks authentication method, the SDK uses the following in ### Databricks native authentication -By default, the Databricks SDK for Go initially tries Databricks token authentication (`AuthType: "pat"` in `*databricks.Config`). If the SDK is unsuccessful, it then tries Databricks basic (username/password) authentication (`AuthType: "basic"` in `*databricks.Config`). +By default, the Databricks SDK for Go initially tries Databricks token authentication (`AuthType: "pat"` in `*databricks.Config`). 
If the SDK is unsuccessful, it then tries Workload Identity Federation (WIF) based authentication (`AuthType: "github-oidc"` in `*databricks.Config`). Currently, only GitHub-provided JWT tokens are supported. - For Databricks token authentication, you must provide `Host` and `Token`; or their environment variable or `.databrickscfg` file field equivalents. -- For Databricks basic authentication, you must provide `Host`, `Username`, and `Password` _(for AWS workspace-level operations)_; or `Host`, `AccountID`, `Username`, and `Password` _(for AWS, Azure, or GCP account-level operations)_; or their environment variable or `.databrickscfg` file field equivalents. +- For Databricks OIDC authentication, you must provide the `Host`, `ClientID` and `TokenAudience` _(optional)_ either directly, through the corresponding environment variables, or in your `.databrickscfg` configuration file. More information can be found in the [Databricks documentation](https://docs.databricks.com/aws/en/dev-tools/auth/oauth-federation#workload-identity-federation). | `*databricks.Config` argument | Description | Environment variable / `.databrickscfg` file field | | ----------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -------------------------------------------------- | | `Host` | _(String)_ The Databricks host URL for either the Databricks workspace endpoint or the Databricks accounts endpoint. | `DATABRICKS_HOST` / `host` | | `AccountID` | _(String)_ The Databricks account ID for the Databricks accounts endpoint. Only has effect when `Host` is either `https://accounts.cloud.databricks.com/` _(AWS)_, `https://accounts.azuredatabricks.net/` _(Azure)_, or `https://accounts.gcp.databricks.com/` _(GCP)_. 
| `DATABRICKS_ACCOUNT_ID` / `account_id` | | `Token` | _(String)_ The Databricks personal access token (PAT) _(AWS, Azure, and GCP)_ or Azure Active Directory (Azure AD) token _(Azure)_. | `DATABRICKS_TOKEN` / `token` | -| `Username` | _(String)_ The Databricks username part of basic authentication. Only possible when `Host` is `*.cloud.databricks.com` _(AWS)_. | `DATABRICKS_USERNAME` / `username` | -| `Password` | _(String)_ The Databricks password part of basic authentication. Only possible when `Host` is `*.cloud.databricks.com` _(AWS)_. | `DATABRICKS_PASSWORD` / `password` | +| `TokenAudience` | _(String)_ When using Workload Identity Federation, the audience to specify when fetching an ID token from the ID token supplier. | `DATABRICKS_TOKEN_AUDIENCE` / `token_audience` | For example, to use Databricks token authentication: diff --git a/config/auth_azure_github_oidc.go b/config/auth_azure_github_oidc.go index 2f82214f2..7be69563f 100644 --- a/config/auth_azure_github_oidc.go +++ b/config/auth_azure_github_oidc.go @@ -8,7 +8,6 @@ import ( "github.com/databricks/databricks-sdk-go/config/credentials" "github.com/databricks/databricks-sdk-go/httpclient" - "github.com/databricks/databricks-sdk-go/logger" "golang.org/x/oauth2" ) @@ -24,15 +23,19 @@ func (c AzureGithubOIDCCredentials) Name() string { // Configure implements [CredentialsStrategy.Configure]. func (c AzureGithubOIDCCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) { // Sanity check that the config is configured for Azure Databricks. 
- if !cfg.IsAzure() || cfg.AzureClientID == "" || cfg.Host == "" || cfg.AzureTenantID == "" { + if !cfg.IsAzure() || cfg.AzureClientID == "" || cfg.Host == "" || cfg.AzureTenantID == "" || cfg.ActionsIDTokenRequestURL == "" || cfg.ActionsIDTokenRequestToken == "" { return nil, nil } + supplier := githubIDTokenSource{actionsIDTokenRequestURL: cfg.ActionsIDTokenRequestURL, + actionsIDTokenRequestToken: cfg.ActionsIDTokenRequestToken, + refreshClient: cfg.refreshClient, + } - idToken, err := requestIDToken(ctx, cfg) + idToken, err := supplier.IDToken(ctx, "api://AzureADTokenExchange") if err != nil { return nil, err } - if idToken == "" { + if idToken.Value == "" { return nil, nil } @@ -40,38 +43,13 @@ func (c AzureGithubOIDCCredentials) Configure(ctx context.Context, cfg *Config) aadEndpoint: fmt.Sprintf("%s%s/oauth2/token", cfg.Environment().AzureActiveDirectoryEndpoint(), cfg.AzureTenantID), clientID: cfg.AzureClientID, applicationID: cfg.Environment().AzureApplicationID, - idToken: idToken, + idToken: idToken.Value, httpClient: cfg.refreshClient, } return credentials.NewOAuthCredentialsProvider(refreshableVisitor(ts), ts.Token), nil } -// requestIDToken requests an ID token from the Github Action. 
-func requestIDToken(ctx context.Context, cfg *Config) (string, error) { - if cfg.ActionsIDTokenRequestURL == "" { - logger.Debugf(ctx, "Missing cfg.ActionsIDTokenRequestURL, likely not calling from a Github action") - return "", nil - } - if cfg.ActionsIDTokenRequestToken == "" { - logger.Debugf(ctx, "Missing cfg.ActionsIDTokenRequestToken, likely not calling from a Github action") - return "", nil - } - - resp := struct { // anonymous struct to parse the response - Value string `json:"value"` - }{} - err := cfg.refreshClient.Do(ctx, "GET", fmt.Sprintf("%s&audience=api://AzureADTokenExchange", cfg.ActionsIDTokenRequestURL), - httpclient.WithRequestHeader("Authorization", fmt.Sprintf("Bearer %s", cfg.ActionsIDTokenRequestToken)), - httpclient.WithResponseUnmarshal(&resp), - ) - if err != nil { - return "", fmt.Errorf("failed to request ID token from %s: %w", cfg.ActionsIDTokenRequestURL, err) - } - - return resp.Value, nil -} - // azureOIDCTokenSource implements [oauth2.TokenSource] to obtain Azure auth // tokens from an ID token. type azureOIDCTokenSource struct { diff --git a/config/auth_databricks_oidc.go b/config/auth_databricks_oidc.go new file mode 100644 index 000000000..434eb49b1 --- /dev/null +++ b/config/auth_databricks_oidc.go @@ -0,0 +1,91 @@ +package config + +import ( + "context" + "errors" + "net/url" + + "github.com/databricks/databricks-sdk-go/config/experimental/auth" + "github.com/databricks/databricks-sdk-go/credentials/u2m" + "github.com/databricks/databricks-sdk-go/logger" + "golang.org/x/oauth2" + "golang.org/x/oauth2/clientcredentials" +) + +// Creates a new Databricks OIDC TokenSource. +func NewDatabricksOIDCTokenSource(cfg DatabricksOIDCTokenSourceConfig) auth.TokenSource { + return &databricksOIDCTokenSource{ + cfg: cfg, + } +} + +// Config for Databricks OIDC TokenSource. +type DatabricksOIDCTokenSourceConfig struct { + // ClientID is the client ID of the Databricks OIDC application. 
For + // Databricks Service Principal, this is the Application ID of the Service Principal. + ClientID string + // [Optional] AccountID is the account ID of the Databricks Account. + // This is only used for Account level tokens. + AccountID string + // Host is the host of the Databricks account or workspace. + Host string + // TokenEndpointProvider returns the token endpoint for the Databricks OIDC application. + TokenEndpointProvider func(ctx context.Context) (*u2m.OAuthAuthorizationServer, error) + // [Optional] Audience is the audience requested for the ID token. When empty, + // it defaults to AccountID if that is set, otherwise to the token endpoint. + Audience string + // IdTokenSource returns the IDToken to be used for the token exchange. + IdTokenSource IDTokenSource +} + +// databricksOIDCTokenSource is an auth.TokenSource that exchanges a token using +// Workload Identity Federation. +type databricksOIDCTokenSource struct { + cfg DatabricksOIDCTokenSourceConfig +} + +// Token implements [TokenSource.Token] +func (w *databricksOIDCTokenSource) Token(ctx context.Context) (*oauth2.Token, error) { + if w.cfg.ClientID == "" { + logger.Debugf(ctx, "Missing ClientID") + return nil, errors.New("missing ClientID") + } + if w.cfg.Host == "" { + logger.Debugf(ctx, "Missing Host") + return nil, errors.New("missing Host") + } + endpoints, err := w.cfg.TokenEndpointProvider(ctx) + if err != nil { + return nil, err + } + audience := w.determineAudience(endpoints) + idToken, err := w.cfg.IdTokenSource.IDToken(ctx, audience) + if err != nil { + return nil, err + } + + c := &clientcredentials.Config{ + ClientID: w.cfg.ClientID, + AuthStyle: oauth2.AuthStyleInParams, + TokenURL: endpoints.TokenEndpoint, + Scopes: []string{"all-apis"}, + EndpointParams: url.Values{ + "subject_token_type": {"urn:ietf:params:oauth:token-type:jwt"}, + "subject_token": {idToken.Value}, + "grant_type": {"urn:ietf:params:oauth:grant-type:token-exchange"}, + }, + } + return c.Token(ctx) +} + +func (w *databricksOIDCTokenSource) 
determineAudience(endpoints *u2m.OAuthAuthorizationServer) string { + if w.cfg.Audience != "" { + return w.cfg.Audience + } + // For Databricks Accounts, the account id is the default audience. + if w.cfg.AccountID != "" { + return w.cfg.AccountID + } + // For Databricks Workspaces, the auth endpoint is the default audience. + return endpoints.TokenEndpoint +} diff --git a/config/auth_databricks_oidc_test.go b/config/auth_databricks_oidc_test.go new file mode 100644 index 000000000..388766e14 --- /dev/null +++ b/config/auth_databricks_oidc_test.go @@ -0,0 +1,298 @@ +package config + +import ( + "context" + "errors" + "net/http" + "net/url" + "testing" + + "github.com/databricks/databricks-sdk-go/credentials/u2m" + "github.com/databricks/databricks-sdk-go/httpclient/fixtures" + "github.com/google/go-cmp/cmp" + "golang.org/x/oauth2" +) + +type mockIdTokenProvider struct { + // input + audience string + // output + idToken string + err error +} + +func (m *mockIdTokenProvider) IDToken(ctx context.Context, audience string) (*IDToken, error) { + m.audience = audience + return &IDToken{Value: m.idToken}, m.err +} + +func TestDatabricksOidcTokenSource(t *testing.T) { + testCases := []struct { + desc string + clientID string + accountID string + host string + tokenAudience string + httpTransport http.RoundTripper + oidcEndpointProvider func(context.Context) (*u2m.OAuthAuthorizationServer, error) + idToken string + expectedAudience string + tokenProviderError error + wantToken string + wantErrPrefix *string + }{ + { + desc: "missing host", + clientID: "client-id", + tokenAudience: "token-audience", + wantErrPrefix: errPrefix("missing Host"), + }, + { + desc: "missing client ID", + host: "http://host.com", + tokenAudience: "token-audience", + wantErrPrefix: errPrefix("missing ClientID"), + }, + { + desc: "token provider error", + + clientID: "client-id", + host: "http://host.com", + tokenAudience: "token-audience", + oidcEndpointProvider: func(ctx context.Context) 
(*u2m.OAuthAuthorizationServer, error) { + return &u2m.OAuthAuthorizationServer{ + TokenEndpoint: "https://host.com/oidc/v1/token", + }, nil + }, + expectedAudience: "token-audience", + tokenProviderError: errors.New("error getting id token"), + wantErrPrefix: errPrefix("error getting id token"), + }, + { + desc: "databricks workspace server error", + clientID: "client-id", + host: "http://host.com", + tokenAudience: "token-audience", + oidcEndpointProvider: func(ctx context.Context) (*u2m.OAuthAuthorizationServer, error) { + return &u2m.OAuthAuthorizationServer{ + TokenEndpoint: "https://host.com/oidc/v1/token", + }, nil + }, + httpTransport: fixtures.MappingTransport{ + "POST /oidc/v1/token": { + Status: http.StatusInternalServerError, + ExpectedHeaders: map[string]string{ + "Content-Type": "application/x-www-form-urlencoded", + }, + }, + }, + expectedAudience: "token-audience", + idToken: "id-token-42", + wantErrPrefix: errPrefix("oauth2: cannot fetch token: Internal Server Error"), + }, + { + desc: "invalid auth token", + clientID: "client-id", + host: "http://host.com", + tokenAudience: "token-audience", + oidcEndpointProvider: func(ctx context.Context) (*u2m.OAuthAuthorizationServer, error) { + return &u2m.OAuthAuthorizationServer{ + TokenEndpoint: "https://host.com/oidc/v1/token", + }, nil + }, + httpTransport: fixtures.MappingTransport{ + "POST /oidc/v1/token": { + Status: http.StatusOK, + ExpectedHeaders: map[string]string{ + "Content-Type": "application/x-www-form-urlencoded", + }, + Response: map[string]string{ + "foo": "bar", + }, + }, + }, + expectedAudience: "token-audience", + idToken: "id-token-42", + wantErrPrefix: errPrefix("oauth2: server response missing access_token"), + }, + { + desc: "success workspace", + clientID: "client-id", + host: "http://host.com", + tokenAudience: "token-audience", + oidcEndpointProvider: func(ctx context.Context) (*u2m.OAuthAuthorizationServer, error) { + return &u2m.OAuthAuthorizationServer{ + TokenEndpoint: 
"https://host.com/oidc/v1/token", + }, nil + }, + httpTransport: fixtures.MappingTransport{ + "POST /oidc/v1/token": { + + Status: http.StatusOK, + ExpectedHeaders: map[string]string{ + "Content-Type": "application/x-www-form-urlencoded", + }, + ExpectedRequest: url.Values{ + "client_id": {"client-id"}, + "scope": {"all-apis"}, + "subject_token_type": {"urn:ietf:params:oauth:token-type:jwt"}, + "subject_token": {"id-token-42"}, + "grant_type": {"urn:ietf:params:oauth:grant-type:token-exchange"}, + }, + Response: map[string]string{ + "token_type": "access-token", + "access_token": "test-auth-token", + "refresh_token": "refresh", + "expires_on": "0", + }, + }, + }, + expectedAudience: "token-audience", + idToken: "id-token-42", + wantToken: "test-auth-token", + }, + { + desc: "success account", + clientID: "client-id", + accountID: "ac123", + host: "https://accounts.databricks.com", + tokenAudience: "token-audience", + oidcEndpointProvider: func(ctx context.Context) (*u2m.OAuthAuthorizationServer, error) { + return &u2m.OAuthAuthorizationServer{ + TokenEndpoint: "https://host.com/oidc/v1/token", + }, nil + }, + httpTransport: fixtures.MappingTransport{ + "POST /oidc/v1/token": { + Status: http.StatusOK, + ExpectedHeaders: map[string]string{ + "Content-Type": "application/x-www-form-urlencoded", + }, + ExpectedRequest: url.Values{ + "client_id": {"client-id"}, + "scope": {"all-apis"}, + "subject_token_type": {"urn:ietf:params:oauth:token-type:jwt"}, + "subject_token": {"id-token-42"}, + "grant_type": {"urn:ietf:params:oauth:grant-type:token-exchange"}, + }, + Response: map[string]string{ + "token_type": "access-token", + "access_token": "test-auth-token", + "refresh_token": "refresh", + "expires_on": "0", + }, + }, + }, + expectedAudience: "token-audience", + idToken: "id-token-42", + wantToken: "test-auth-token", + }, + { + desc: "default token audience account", + clientID: "client-id", + accountID: "ac123", + host: "https://accounts.databricks.com", + 
oidcEndpointProvider: func(ctx context.Context) (*u2m.OAuthAuthorizationServer, error) { + return &u2m.OAuthAuthorizationServer{ + TokenEndpoint: "https://host.com/oidc/v1/token", + }, nil + }, + httpTransport: fixtures.MappingTransport{ + "POST /oidc/v1/token": { + Status: http.StatusOK, + ExpectedHeaders: map[string]string{ + "Content-Type": "application/x-www-form-urlencoded", + }, + Response: map[string]string{ + "token_type": "access-token", + "access_token": "test-auth-token", + "refresh_token": "refresh", + "expires_on": "0", + }, + }, + }, + expectedAudience: "ac123", + idToken: "id-token-42", + wantToken: "test-auth-token", + }, + { + desc: "default token audience workspace", + clientID: "client-id", + host: "https://host.com", + oidcEndpointProvider: func(ctx context.Context) (*u2m.OAuthAuthorizationServer, error) { + return &u2m.OAuthAuthorizationServer{ + TokenEndpoint: "https://host.com/oidc/v1/token", + }, nil + }, + httpTransport: fixtures.MappingTransport{ + "POST /oidc/v1/token": { + Status: http.StatusOK, + ExpectedHeaders: map[string]string{ + "Content-Type": "application/x-www-form-urlencoded", + }, + Response: map[string]string{ + "token_type": "access-token", + "access_token": "test-auth-token", + "refresh_token": "refresh", + "expires_on": "0", + }, + }, + }, + expectedAudience: "https://host.com/oidc/v1/token", + idToken: "id-token-42", + wantToken: "test-auth-token", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + p := &mockIdTokenProvider{ + idToken: tc.idToken, + err: tc.tokenProviderError, + } + + cfg := DatabricksOIDCTokenSourceConfig{ + ClientID: tc.clientID, + AccountID: tc.accountID, + Host: tc.host, + TokenEndpointProvider: tc.oidcEndpointProvider, + Audience: tc.tokenAudience, + IdTokenSource: p, + } + + ts := NewDatabricksOIDCTokenSource(cfg) + if tc.httpTransport != nil { + ts.(*databricksOIDCTokenSource).cfg.TokenEndpointProvider = func(ctx context.Context) (*u2m.OAuthAuthorizationServer, 
error) { + return &u2m.OAuthAuthorizationServer{ + TokenEndpoint: "https://host.com/oidc/v1/token", + }, nil + } + } + + ctx := context.Background() + if tc.httpTransport != nil { + ctx = context.WithValue(ctx, oauth2.HTTPClient, &http.Client{ + Transport: tc.httpTransport, + }) + } + + token, err := ts.Token(ctx) + if tc.wantErrPrefix == nil && err != nil { + t.Errorf("Token(ctx): got error %q, want none", err) + } + if tc.wantErrPrefix != nil && !hasPrefix(err, *tc.wantErrPrefix) { + t.Errorf("Token(ctx): got error %q, want error with prefix %q", err, *tc.wantErrPrefix) + } + if tc.expectedAudience != p.audience { + t.Errorf("mockTokenProvider: got audience %s, want %s", p.audience, tc.expectedAudience) + } + tokenValue := "" + if token != nil { + tokenValue = token.AccessToken + } + if diff := cmp.Diff(tc.wantToken, tokenValue); diff != "" { + t.Errorf("Authenticate(): mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/config/auth_default.go b/config/auth_default.go index 2170f3d2b..a49757d02 100644 --- a/config/auth_default.go +++ b/config/auth_default.go @@ -9,22 +9,60 @@ import ( "github.com/databricks/databricks-sdk-go/logger" ) -var authProviders = []CredentialsStrategy{ - PatCredentials{}, - BasicCredentials{}, - M2mCredentials{}, - DatabricksCliCredentials, - MetadataServiceCredentials{}, - - // Attempt to configure auth from most specific to most generic (the Azure CLI). 
- AzureGithubOIDCCredentials{}, - AzureMsiCredentials{}, - AzureClientSecretCredentials{}, - AzureCliCredentials{}, +// Constructs all Databricks OIDC Credentials Strategies +func buildOidcTokenCredentialStrategies(cfg *Config) []CredentialsStrategy { + type namedIdTokenSource struct { + name string + tokenSource IDTokenSource + } + idTokenSources := []namedIdTokenSource{ + { + name: "github-oidc", + tokenSource: &githubIDTokenSource{ + actionsIDTokenRequestURL: cfg.ActionsIDTokenRequestURL, + actionsIDTokenRequestToken: cfg.ActionsIDTokenRequestToken, + refreshClient: cfg.refreshClient, + }, + }, + // Add new providers at the end of the list + } + strategies := []CredentialsStrategy{} + for _, idTokenSource := range idTokenSources { + oidcConfig := DatabricksOIDCTokenSourceConfig{ + ClientID: cfg.ClientID, + Host: cfg.CanonicalHostName(), + TokenEndpointProvider: cfg.getOidcEndpoints, + Audience: cfg.TokenAudience, + IdTokenSource: idTokenSource.tokenSource, + } + if cfg.IsAccountClient() { + oidcConfig.AccountID = cfg.AccountID + } + tokenSource := NewDatabricksOIDCTokenSource(oidcConfig) + strategies = append(strategies, NewTokenSourceStrategy(idTokenSource.name, tokenSource)) + } + return strategies +} - // Attempt to configure auth from most specific to most generic (Google Application Default Credentials). - GoogleCredentials{}, - GoogleDefaultCredentials{}, +func buildDefaultStrategies(cfg *Config) []CredentialsStrategy { + strategies := []CredentialsStrategy{} + strategies = append(strategies, + PatCredentials{}, + BasicCredentials{}, + M2mCredentials{}, + DatabricksCliCredentials, + MetadataServiceCredentials{}) + strategies = append(strategies, buildOidcTokenCredentialStrategies(cfg)...) + strategies = append(strategies, + // Attempt to configure auth from most specific to most generic (the Azure CLI). 
+ AzureGithubOIDCCredentials{}, + AzureMsiCredentials{}, + AzureClientSecretCredentials{}, + AzureCliCredentials{}, + // Attempt to configure auth from most specific to most generic (Google Application Default Credentials). + GoogleCredentials{}, + GoogleDefaultCredentials{}) + return strategies } type DefaultCredentials struct { @@ -45,7 +83,11 @@ var errorMessage = fmt.Sprintf("cannot configure default credentials, please che var ErrCannotConfigureAuth = errors.New(errorMessage) func (c *DefaultCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) { - for _, p := range authProviders { + err := cfg.EnsureResolved() + if err != nil { + return nil, err + } + for _, p := range buildDefaultStrategies(cfg) { if cfg.AuthType != "" && p.Name() != cfg.AuthType { // ignore other auth types if one is explicitly enforced logger.Infof(ctx, "Ignoring %s auth, because %s is preferred", p.Name(), cfg.AuthType) diff --git a/config/config.go b/config/config.go index 6c14f0a27..6346d6cf8 100644 --- a/config/config.go +++ b/config/config.go @@ -134,6 +134,9 @@ type Config struct { // Environment override to return when resolving the current environment. DatabricksEnvironment *environment.DatabricksEnvironment + // When using Workload Identity Federation, the audience to specify when fetching an ID token from the ID token supplier. + TokenAudience string `name:"audience" env:"DATABRICKS_TOKEN_AUDIENCE" auth:"-"` + Loaders []Loader // marker for configuration resolving diff --git a/config/id_token_source_github_oidc.go b/config/id_token_source_github_oidc.go new file mode 100644 index 000000000..6f4048226 --- /dev/null +++ b/config/id_token_source_github_oidc.go @@ -0,0 +1,45 @@ +package config + +import ( + "context" + "errors" + "fmt" + + "github.com/databricks/databricks-sdk-go/httpclient" + "github.com/databricks/databricks-sdk-go/logger" +) + +// githubIDTokenSource retrieves JWT Tokens from Github Actions. 
+type githubIDTokenSource struct { + actionsIDTokenRequestURL string + actionsIDTokenRequestToken string + refreshClient *httpclient.ApiClient +} + +// IDToken returns a JWT Token for the specified audience. It will return +// an error if not running in GitHub Actions. +func (g *githubIDTokenSource) IDToken(ctx context.Context, audience string) (*IDToken, error) { + if g.actionsIDTokenRequestURL == "" { + logger.Debugf(ctx, "Missing ActionsIDTokenRequestURL, likely not calling from a Github action") + return nil, errors.New("missing ActionsIDTokenRequestURL") + } + if g.actionsIDTokenRequestToken == "" { + logger.Debugf(ctx, "Missing ActionsIDTokenRequestToken, likely not calling from a Github action") + return nil, errors.New("missing ActionsIDTokenRequestToken") + } + + resp := &IDToken{} + requestUrl := g.actionsIDTokenRequestURL + if audience != "" { + requestUrl = fmt.Sprintf("%s&audience=%s", requestUrl, audience) + } + err := g.refreshClient.Do(ctx, "GET", requestUrl, + httpclient.WithRequestHeader("Authorization", fmt.Sprintf("Bearer %s", g.actionsIDTokenRequestToken)), + httpclient.WithResponseUnmarshal(resp), + ) + if err != nil { + return nil, fmt.Errorf("failed to request ID token from %s: %w", g.actionsIDTokenRequestURL, err) + } + + return resp, nil +} diff --git a/config/id_token_source_github_oidc_test.go b/config/id_token_source_github_oidc_test.go new file mode 100644 index 000000000..58a1bbc2b --- /dev/null +++ b/config/id_token_source_github_oidc_test.go @@ -0,0 +1,91 @@ +package config + +import ( + "context" + "net/http" + "testing" + + "github.com/databricks/databricks-sdk-go/httpclient" + "github.com/databricks/databricks-sdk-go/httpclient/fixtures" + "github.com/google/go-cmp/cmp" +) + +func TestGithubIDTokenSource(t *testing.T) { + testCases := []struct { + desc string + tokenRequestUrl string + tokenRequestToken string + audience string + httpTransport http.RoundTripper + wantToken *IDToken + wantErrPrefix *string + }{ + { + desc: 
"missing request token url", + tokenRequestToken: "token-1337", + wantErrPrefix: errPrefix("missing ActionsIDTokenRequestURL"), + }, + { + desc: "missing request token token", + tokenRequestUrl: "http://endpoint.com/test?version=1", + wantErrPrefix: errPrefix("missing ActionsIDTokenRequestToken"), + }, + { + desc: "error getting token", + tokenRequestToken: "token-1337", + tokenRequestUrl: "http://endpoint.com/test?version=1", + httpTransport: fixtures.MappingTransport{ + "GET /test?version=1": { + Status: http.StatusInternalServerError, + ExpectedHeaders: map[string]string{ + "Authorization": "Bearer token-1337", + "Accept": "application/json", + }, + }, + }, + wantErrPrefix: errPrefix("failed to request ID token from"), + }, + { + desc: "success", + tokenRequestToken: "token-1337", + tokenRequestUrl: "http://endpoint.com/test?version=1", + httpTransport: fixtures.MappingTransport{ + "GET /test?version=1": { + Status: http.StatusOK, + ExpectedHeaders: map[string]string{ + "Authorization": "Bearer token-1337", + "Accept": "application/json", + }, + Response: `{"value": "id-token-42"}`, + }, + }, + wantToken: &IDToken{ + Value: "id-token-42", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + cli := httpclient.NewApiClient(httpclient.ClientConfig{ + Transport: tc.httpTransport, + }) + p := &githubIDTokenSource{ + actionsIDTokenRequestURL: tc.tokenRequestUrl, + actionsIDTokenRequestToken: tc.tokenRequestToken, + refreshClient: cli, + } + token, gotErr := p.IDToken(context.Background(), tc.audience) + + if tc.wantErrPrefix == nil && gotErr != nil { + t.Errorf("Authenticate(): got error %q, want none", gotErr) + } + if tc.wantErrPrefix != nil && !hasPrefix(gotErr, *tc.wantErrPrefix) { + t.Errorf("Authenticate(): got error %q, want error with prefix %q", gotErr, *tc.wantErrPrefix) + } + if diff := cmp.Diff(tc.wantToken, token); diff != "" { + t.Errorf("Authenticate(): mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git 
a/config/oauth_visitors.go b/config/oauth_visitors.go index fc7a3d153..69fadc03f 100644 --- a/config/oauth_visitors.go +++ b/config/oauth_visitors.go @@ -35,7 +35,12 @@ func serviceToServiceVisitor(primary, secondary oauth2.TokenSource, secondaryHea // The same as serviceToServiceVisitor, but without a secondary token source. func refreshableVisitor(inner oauth2.TokenSource) func(r *http.Request) error { - cts := auth.NewCachedTokenSource(authconv.AuthTokenSource(inner)) + return refreshableAuthVisitor(authconv.AuthTokenSource(inner)) +} + +// The same as refreshableVisitor, but takes an auth.TokenSource instead of an oauth2.TokenSource. +func refreshableAuthVisitor(inner auth.TokenSource) func(r *http.Request) error { + cts := auth.NewCachedTokenSource(inner) return func(r *http.Request) error { inner, err := cts.Token(context.Background()) if err != nil { diff --git a/config/token_source_strategy.go b/config/token_source_strategy.go new file mode 100644 index 000000000..fd5d995ce --- /dev/null +++ b/config/token_source_strategy.go @@ -0,0 +1,62 @@ +package config + +import ( + "context" + "fmt" + + "github.com/databricks/databricks-sdk-go/config/credentials" + "github.com/databricks/databricks-sdk-go/config/experimental/auth" + "github.com/databricks/databricks-sdk-go/config/experimental/auth/authconv" + "github.com/databricks/databricks-sdk-go/logger" +) + +// IDToken is a token that can be exchanged for an access token. +// Value is the token string. +type IDToken struct { + Value string +} + +// IDTokenSource is anything that returns an IDToken given an audience. +type IDTokenSource interface { + // IDToken returns an ID token for the given audience. + IDToken(ctx context.Context, audience string) (*IDToken, error) +} + +// Creates a CredentialsStrategy from a TokenSource. 
+func NewTokenSourceStrategy( + name string, + tokenSource auth.TokenSource, +) CredentialsStrategy { + return &tokenSourceStrategy{ + name: name, + tokenSource: tokenSource, + } +} + +// tokenSourceStrategy is a wrapper around an auth.TokenSource that converts it into a CredentialsStrategy. +type tokenSourceStrategy struct { + tokenSource auth.TokenSource + name string +} + +// Configure implements [CredentialsStrategy.Configure]. +func (t *tokenSourceStrategy) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) { + + // If we cannot get a token, skip this CredentialsStrategy. + // We don't want to fail here because it's possible that the supplier is enabled + // without any user action. For instance, jobs running in GitHub will have + // OIDC environment variables added automatically. + cached := auth.NewCachedTokenSource(t.tokenSource) + if _, err := cached.Token(ctx); err != nil { + logger.Debugf(ctx, fmt.Sprintf("Skipping %s due to error: %v", t.name, err)) + return nil, nil + } + + visitor := refreshableAuthVisitor(cached) + return credentials.NewOAuthCredentialsProvider(visitor, authconv.OAuth2TokenSource(cached).Token), nil +} + +// Name implements [CredentialsStrategy.Name]. 
+func (t *tokenSourceStrategy) Name() string { + return t.name +} diff --git a/config/token_source_strategy_test.go b/config/token_source_strategy_test.go new file mode 100644 index 000000000..48ecc38c7 --- /dev/null +++ b/config/token_source_strategy_test.go @@ -0,0 +1,69 @@ +package config + +import ( + "context" + "errors" + "net/http" + "testing" + + "github.com/databricks/databricks-sdk-go/config/experimental/auth" + "github.com/google/go-cmp/cmp" + "golang.org/x/oauth2" +) + +func TestDatabricksTokenSourceStrategy(t *testing.T) { + testCases := []struct { + desc string + token *oauth2.Token + tokenSourceError error + wantHeaders http.Header + }{ + { + desc: "token source error skips", + tokenSourceError: errors.New("random error"), + }, + { + desc: "token source error skips", + token: &oauth2.Token{ + AccessToken: "token-123", + }, + wantHeaders: http.Header{"Authorization": {"Bearer token-123"}}, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + strat := &tokenSourceStrategy{ + name: "github-oidc", + tokenSource: auth.TokenSourceFn(func(_ context.Context) (*oauth2.Token, error) { + return tc.token, tc.tokenSourceError + }), + } + provider, err := strat.Configure(context.Background(), &Config{}) + if tc.tokenSourceError == nil && provider == nil { + t.Error("Provider expected to not be nil, but it is") + } + if tc.tokenSourceError != nil && provider != nil { + t.Error("A failure in the TokenSource should cause the provider to be nil, but it's not") + } + if err != nil { + t.Errorf("Configure() got error %q, want none", err) + } + + if provider != nil { + req, _ := http.NewRequest("GET", "http://localhost", nil) + + gotErr := provider.SetHeaders(req) + + if gotErr != nil { + t.Errorf("SetHeaders(): got error %q, want none", gotErr) + } + if diff := cmp.Diff(tc.wantHeaders, req.Header); diff != "" { + t.Errorf("Authenticate(): mismatch (-want +got):\n%s", diff) + } + + } + + }) + } +} diff --git a/internal/auth_test.go 
b/internal/auth_test.go new file mode 100644 index 000000000..ef12d40a5 --- /dev/null +++ b/internal/auth_test.go @@ -0,0 +1,153 @@ +package internal + +import ( + "strconv" + "testing" + + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/apierr" + "github.com/databricks/databricks-sdk-go/service/iam" + "github.com/databricks/databricks-sdk-go/service/oauth2" + "github.com/stretchr/testify/require" +) + +func TestUcAccWifAuth(t *testing.T) { + // This test cannot be run locally. It can only be run from GitHub Workflows. + _ = GetEnvOrSkipTest(t, "ACTIONS_ID_TOKEN_REQUEST_URL") + ctx, a := ucacctTest(t) + + // Create SP with the account-level admin role + sp, err := a.ServicePrincipals.Create(ctx, iam.ServicePrincipal{ + Active: true, + DisplayName: RandomName("go-sdk-sp-"), + Roles: []iam.ComplexValue{ + {Value: "account_admin"}, // Assigning account-level admin role + }, + }) + require.NoError(t, err) + t.Cleanup(func() { + err := a.ServicePrincipals.Delete(ctx, iam.DeleteAccountServicePrincipalRequest{Id: sp.Id}) + require.True(t, err == nil || apierr.IsMissing(err)) + }) + + spId, err := strconv.ParseInt(sp.Id, 10, 64) + require.NoError(t, err) + + // Setup Federation Policy + p, err := a.ServicePrincipalFederationPolicy.Create(ctx, oauth2.CreateServicePrincipalFederationPolicyRequest{ + Policy: &oauth2.FederationPolicy{ + OidcPolicy: &oauth2.OidcFederationPolicy{ + Issuer: "https://token.actions.githubusercontent.com", + Audiences: []string{ + "https://github.com/databricks-eng", + }, + Subject: "repo:databricks-eng/eng-dev-ecosystem:environment:integration-tests", + }, + }, + ServicePrincipalId: spId, + }) + + require.NoError(t, err) + t.Cleanup(func() { + err := a.ServicePrincipalFederationPolicy.Delete(ctx, oauth2.DeleteServicePrincipalFederationPolicyRequest{ + ServicePrincipalId: spId, + PolicyId: p.Uid, + }) + require.True(t, err == nil || apierr.IsMissing(err)) + }) + + // Test Workload Identity Federation at Account Level
+ + accCfg := &databricks.Config{ + Host: a.Config.Host, + AccountID: a.Config.AccountID, + ClientID: sp.ApplicationId, + AuthType: "github-oidc", + TokenAudience: "https://github.com/databricks-eng", + } + + wifAccClient, err := databricks.NewAccountClient(accCfg) + + require.NoError(t, err) + it := wifAccClient.Groups.List(ctx, iam.ListAccountGroupsRequest{}) + _, err = it.Next(ctx) + require.NoError(t, err) + +} + +func TestUcAccWifAuthWorkspace(t *testing.T) { + // This test cannot be run locally. It can only be run from GitHub Workflows. + _ = GetEnvOrSkipTest(t, "ACTIONS_ID_TOKEN_REQUEST_URL") + ctx, a := ucacctTest(t) + + workspaceIdEnvVar := GetEnvOrSkipTest(t, "TEST_WORKSPACE_ID") + workspaceId, err := strconv.ParseInt(workspaceIdEnvVar, 10, 64) + require.NoError(t, err) + + workspaceUrl := GetEnvOrSkipTest(t, "TEST_WORKSPACE_URL") + + // Create SP with access to the workspace + sp, err := a.ServicePrincipals.Create(ctx, iam.ServicePrincipal{ + Active: true, + DisplayName: RandomName("go-sdk-sp-"), + }) + require.NoError(t, err) + t.Cleanup(func() { + err := a.ServicePrincipals.Delete(ctx, iam.DeleteAccountServicePrincipalRequest{Id: sp.Id}) + require.True(t, err == nil || apierr.IsMissing(err)) + }) + + spId, err := strconv.ParseInt(sp.Id, 10, 64) + require.NoError(t, err) + + _, err = a.WorkspaceAssignment.Update(ctx, iam.UpdateWorkspaceAssignments{ + WorkspaceId: workspaceId, + PrincipalId: spId, + Permissions: []iam.WorkspacePermission{iam.WorkspacePermissionAdmin}, + }) + + require.NoError(t, err) + t.Cleanup(func() { + err := a.WorkspaceAssignment.Delete(ctx, iam.DeleteWorkspaceAssignmentRequest{ + PrincipalId: spId, + WorkspaceId: workspaceId, + }) + require.True(t, err == nil || apierr.IsMissing(err)) + }) + + // Setup Federation Policy + p, err := a.ServicePrincipalFederationPolicy.Create(ctx, oauth2.CreateServicePrincipalFederationPolicyRequest{ + Policy: &oauth2.FederationPolicy{ + OidcPolicy: &oauth2.OidcFederationPolicy{ + Issuer: 
"https://token.actions.githubusercontent.com", + Audiences: []string{ + "https://github.com/databricks-eng", + }, + Subject: "repo:databricks-eng/eng-dev-ecosystem:environment:integration-tests", + }, + }, + ServicePrincipalId: spId, + }) + + require.NoError(t, err) + t.Cleanup(func() { + err := a.ServicePrincipalFederationPolicy.Delete(ctx, oauth2.DeleteServicePrincipalFederationPolicyRequest{ + ServicePrincipalId: spId, + PolicyId: p.Uid, + }) + require.True(t, err == nil || apierr.IsMissing(err)) + }) + + wsCfg := &databricks.Config{ + Host: workspaceUrl, + ClientID: sp.ApplicationId, + AuthType: "github-oidc", + TokenAudience: "https://github.com/databricks-eng", + } + + wifWsClient, err := databricks.NewWorkspaceClient(wsCfg) + + require.NoError(t, err) + _, err = wifWsClient.CurrentUser.Me(ctx) + require.NoError(t, err) +}