diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 36350246f..3d3adac6e 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -4,6 +4,7 @@ ### New Features and Improvements +* Support user-to-machine authentication in the SDK ([#1108](https://github.com/databricks/databricks-sdk-go/pull/1108)). - Instances of `ApiClient` now share the same connection pool by default ([PR #1190](https://github.com/databricks/databricks-sdk-go/pull/1190)). ### Bug Fixes diff --git a/config/auth_databricks_cli.go b/config/auth_databricks_cli.go deleted file mode 100644 index 1ec798307..000000000 --- a/config/auth_databricks_cli.go +++ /dev/null @@ -1,123 +0,0 @@ -package config - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "os" - "os/exec" - "path/filepath" - "strings" - - "github.com/databricks/databricks-sdk-go/config/credentials" - "github.com/databricks/databricks-sdk-go/logger" - "golang.org/x/oauth2" -) - -type DatabricksCliCredentials struct { -} - -func (c DatabricksCliCredentials) Name() string { - return "databricks-cli" -} - -func (c DatabricksCliCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) { - if cfg.Host == "" { - return nil, nil - } - - ts, err := newDatabricksCliTokenSource(ctx, cfg) - if err != nil { - if errors.Is(err, exec.ErrNotFound) { - logger.Debugf(ctx, "Most likely the Databricks CLI is not installed") - return nil, nil - } - if err == errLegacyDatabricksCli { - logger.Debugf(ctx, "Databricks CLI version <0.100.0 detected") - return nil, nil - } - return nil, err - } - - _, err = ts.Token() - if err != nil { - if strings.Contains(err.Error(), "no configuration file found at") { - // databricks auth token produced this error message between - // v0.207.1 and v0.209.1 - return nil, nil - } - if strings.Contains(err.Error(), "databricks OAuth is not") { - // OAuth is not configured or not supported - return nil, nil - } - return nil, err - } - logger.Debugf(ctx, "Using Databricks CLI 
authentication with Databricks OAuth tokens") - visitor := refreshableVisitor(ts) - return credentials.NewOAuthCredentialsProvider(visitor, ts.Token), nil -} - -var errLegacyDatabricksCli = errors.New("legacy Databricks CLI detected") - -type databricksCliTokenSource struct { - ctx context.Context - name string - args []string -} - -func newDatabricksCliTokenSource(ctx context.Context, cfg *Config) (*databricksCliTokenSource, error) { - args := []string{"auth", "token", "--host", cfg.Host} - - if cfg.IsAccountClient() { - args = append(args, "--account-id", cfg.AccountID) - } - - databricksCliPath := cfg.DatabricksCliPath - if databricksCliPath == "" { - databricksCliPath = "databricks" - } - - // Resolve absolute path to the Databricks CLI executable. - path, err := exec.LookPath(databricksCliPath) - if err != nil { - return nil, err - } - - // Resolve symlinks in order to figure out executable size. - path, err = filepath.EvalSymlinks(path) - if err != nil { - return nil, err - } - - // Determine executable size as signal to determine old/new Databricks CLI. - stat, err := os.Stat(path) - if err != nil { - return nil, err - } - - // The new Databricks CLI is a single binary with size > 1MB. - // We use the size as a signal to determine which Databricks CLI is installed. 
- if stat.Size() < (1024 * 1024) { - return nil, errLegacyDatabricksCli - } - - return &databricksCliTokenSource{ctx: ctx, name: path, args: args}, nil -} - -func (ts *databricksCliTokenSource) Token() (*oauth2.Token, error) { - out, err := runCommand(ts.ctx, ts.name, ts.args) - if ee, ok := err.(*exec.ExitError); ok { - return nil, fmt.Errorf("cannot get access token: %s", string(ee.Stderr)) - } - if err != nil { - return nil, fmt.Errorf("cannot get access token: %v", err) - } - var t oauth2.Token - err = json.Unmarshal(out, &t) - if err != nil { - return nil, fmt.Errorf("cannot unmarshal Databricks CLI result: %w", err) - } - logger.Infof(context.Background(), "Refreshed OAuth token from Databricks CLI, expires on %s", t.Expiry) - return &t, nil -} diff --git a/config/auth_databricks_cli_test.go b/config/auth_databricks_cli_test.go deleted file mode 100644 index 5566d93c0..000000000 --- a/config/auth_databricks_cli_test.go +++ /dev/null @@ -1,108 +0,0 @@ -package config - -import ( - "context" - "os" - "path/filepath" - "testing" - - "github.com/databricks/databricks-sdk-go/internal/env" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -var cliDummy = &Config{Host: "https://abc.cloud.databricks.com/"} - -func writeSmallDummyExecutable(t *testing.T, path string) { - f, err := os.Create(filepath.Join(path, "databricks")) - require.NoError(t, err) - defer f.Close() - err = os.Chmod(f.Name(), 0755) - require.NoError(t, err) - _, err = f.WriteString("#!/bin/sh\necho hello world\n") - require.NoError(t, err) -} - -func writeLargeDummyExecutable(t *testing.T, path string) { - f, err := os.Create(filepath.Join(path, "databricks")) - require.NoError(t, err) - defer f.Close() - err = os.Chmod(f.Name(), 0755) - require.NoError(t, err) - _, err = f.WriteString("#!/bin/sh\n") - require.NoError(t, err) - - f.WriteString(` -cat </oidc/accounts/". 
+func (a BasicAccountOAuthArgument) GetCacheKey() string { + return fmt.Sprintf("%s/oidc/accounts/%s", a.accountHost, a.accountID) +} diff --git a/credentials/u2m/cache/cache.go b/credentials/u2m/cache/cache.go new file mode 100644 index 000000000..6e4137211 --- /dev/null +++ b/credentials/u2m/cache/cache.go @@ -0,0 +1,32 @@ +/* +Package cache provides an interface for storing and looking up OAuth tokens. + +The cache should be primarily used for user-to-machine (U2M) OAuth flows. In U2M +OAuth flows, the application needs to store the token for later use, such as in +a separate process, and the cache provides a way to do so without requiring the +user to follow the OAuth flow again. + +In machine-to-machine (M2M) OAuth flows, the application is configured with a +secret and can fetch a new token on demand without user interaction, so the +token cache is not necessary. +*/ +package cache + +import ( + "errors" + + "golang.org/x/oauth2" +) + +// TokenCache is an interface for storing and looking up OAuth tokens. +type TokenCache interface { + // Store stores the token with the given key, replacing any existing token. + // If t is nil, it deletes the token. + Store(key string, t *oauth2.Token) error + + // Lookup looks up the token with the given key. If the token is not found, it + // returns ErrNotConfigured. + Lookup(key string) (*oauth2.Token, error) +} + +var ErrNotConfigured = errors.New("databricks OAuth is not configured for this host") diff --git a/credentials/u2m/cache/file.go b/credentials/u2m/cache/file.go new file mode 100644 index 000000000..5652f4833 --- /dev/null +++ b/credentials/u2m/cache/file.go @@ -0,0 +1,179 @@ +package cache + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "sync" + + "golang.org/x/oauth2" +) + +const ( + // tokenCacheFile is the path of the default token cache, relative to the + // user's home directory. 
+ tokenCacheFilePath = ".databricks/token-cache.json" + + // ownerExecReadWrite is the permission for the .databricks directory. + ownerExecReadWrite = 0o700 + + // ownerReadWrite is the permission for the token-cache.json file. + ownerReadWrite = 0o600 + + // tokenCacheVersion is the version of the token cache file format. + // + // Version 1 format: + // + // { + // "version": 1, + // "tokens": { + // "": { + // "access_token": "", + // "token_type": "", + // "refresh_token": "", + // "expiry": "" + // } + // } + // } + tokenCacheVersion = 1 +) + +// tokenCacheFile is the format of the token cache file. +type tokenCacheFile struct { + Version int `json:"version"` + Tokens map[string]*oauth2.Token `json:"tokens"` +} + +type FileTokenCacheOption func(*fileTokenCache) + +func WithFileLocation(fileLocation string) FileTokenCacheOption { + return func(c *fileTokenCache) { + c.fileLocation = fileLocation + } +} + +// fileTokenCache caches tokens in "~/.databricks/token-cache.json". fileTokenCache +// implements the TokenCache interface. +type fileTokenCache struct { + fileLocation string + + // locker protects the token cache file from concurrent reads and writes. + locker sync.Mutex +} + +// NewFileTokenCache creates a new FileTokenCache. By default, the cache is +// stored in "~/.databricks/token-cache.json". The cache file is created if it +// does not already exist. The cache file is created with owner permissions +// 0600 and the directory is created with owner permissions 0700. If the cache +// file is corrupt or if its version does not match tokenCacheVersion, an error +// is returned. +func NewFileTokenCache(opts ...FileTokenCacheOption) (TokenCache, error) { + c := &fileTokenCache{} + for _, opt := range opts { + opt(c) + } + if err := c.init(); err != nil { + return nil, err + } + // Fail fast if the cache is not working. 
+ if _, err := c.load(); err != nil { + return nil, fmt.Errorf("load: %w", err) + } + return c, nil +} + +// Store implements the TokenCache interface. +func (c *fileTokenCache) Store(key string, t *oauth2.Token) error { + c.locker.Lock() + defer c.locker.Unlock() + f, err := c.load() + if err != nil { + return fmt.Errorf("load: %w", err) + } + if f.Tokens == nil { + f.Tokens = map[string]*oauth2.Token{} + } + if t == nil { + delete(f.Tokens, key) + } else { + f.Tokens[key] = t + } + raw, err := json.MarshalIndent(f, "", " ") + if err != nil { + return fmt.Errorf("marshal: %w", err) + } + return os.WriteFile(c.fileLocation, raw, ownerReadWrite) +} + +// Lookup implements the TokenCache interface. +func (c *fileTokenCache) Lookup(key string) (*oauth2.Token, error) { + c.locker.Lock() + defer c.locker.Unlock() + f, err := c.load() + if err != nil { + return nil, fmt.Errorf("load: %w", err) + } + t, ok := f.Tokens[key] + if !ok { + return nil, ErrNotConfigured + } + return t, nil +} + +// init initializes the token cache file. It creates the file and directory if +// they do not already exist. +func (c *fileTokenCache) init() error { + // set the default file location + if c.fileLocation == "" { + home, err := os.UserHomeDir() + if err != nil { + return fmt.Errorf("failed loading home directory: %w", err) + } + c.fileLocation = filepath.Join(home, tokenCacheFilePath) + } + // Create the cache file if it does not exist. + if _, err := os.Stat(c.fileLocation); err != nil { + if !os.IsNotExist(err) { + return fmt.Errorf("stat file: %w", err) + } + // Create the parent directories if needed. + if err := os.MkdirAll(filepath.Dir(c.fileLocation), ownerExecReadWrite); err != nil { + return fmt.Errorf("mkdir: %w", err) + } + + // Create an empty cache file. 
+ f := &tokenCacheFile{ + Version: tokenCacheVersion, + Tokens: map[string]*oauth2.Token{}, + } + raw, err := json.MarshalIndent(f, "", " ") + if err != nil { + return fmt.Errorf("marshal: %w", err) + } + if err := os.WriteFile(c.fileLocation, raw, ownerReadWrite); err != nil { + return fmt.Errorf("write: %w", err) + } + } + return nil +} + +// load loads the token cache file from disk. If the file is corrupt or if its +// version does not match tokenCacheVersion, it returns an error. +func (c *fileTokenCache) load() (*tokenCacheFile, error) { + raw, err := os.ReadFile(c.fileLocation) + if err != nil { + return nil, fmt.Errorf("read: %w", err) + } + f := &tokenCacheFile{} + if err := json.Unmarshal(raw, &f); err != nil { + return nil, fmt.Errorf("parse: %w", err) + } + if f.Version != tokenCacheVersion { + // in the later iterations we could do state upgraders, + // so that we transform token cache from v1 to v2 without + // losing the tokens and asking the user to re-authenticate. + return nil, fmt.Errorf("needs version %d, got version %d", tokenCacheVersion, f.Version) + } + return f, nil +} diff --git a/credentials/u2m/cache/file_test.go b/credentials/u2m/cache/file_test.go new file mode 100644 index 000000000..ef45820b5 --- /dev/null +++ b/credentials/u2m/cache/file_test.go @@ -0,0 +1,66 @@ +package cache + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +func setup(t *testing.T) string { + tempHomeDir := t.TempDir() + return filepath.Join(tempHomeDir, "token-cache.json") +} + +func TestStoreAndLookup(t *testing.T) { + c, err := NewFileTokenCache(WithFileLocation(setup(t))) + require.NoError(t, err) + err = c.Store("x", &oauth2.Token{ + AccessToken: "abc", + }) + require.NoError(t, err) + + err = c.Store("y", &oauth2.Token{ + AccessToken: "bcd", + }) + require.NoError(t, err) + + tok, err := c.Lookup("x") + require.NoError(t, err) + assert.Equal(t, 
"abc", tok.AccessToken) + + _, err = c.Lookup("z") + assert.Equal(t, ErrNotConfigured, err) +} + +func TestNoCacheFileReturnsErrNotConfigured(t *testing.T) { + l, err := NewFileTokenCache(WithFileLocation(setup(t))) + require.NoError(t, err) + _, err = l.Lookup("x") + assert.Equal(t, ErrNotConfigured, err) +} + +func TestLoadCorruptFile(t *testing.T) { + f := setup(t) + err := os.MkdirAll(filepath.Dir(f), ownerExecReadWrite) + require.NoError(t, err) + err = os.WriteFile(f, []byte("abc"), ownerExecReadWrite) + require.NoError(t, err) + + _, err = NewFileTokenCache(WithFileLocation(f)) + assert.EqualError(t, err, "load: parse: invalid character 'a' looking for beginning of value") +} + +func TestLoadWrongVersion(t *testing.T) { + f := setup(t) + err := os.MkdirAll(filepath.Dir(f), ownerExecReadWrite) + require.NoError(t, err) + err = os.WriteFile(f, []byte(`{"version": 823, "things": []}`), ownerExecReadWrite) + require.NoError(t, err) + + _, err = NewFileTokenCache(WithFileLocation(f)) + assert.EqualError(t, err, "load: needs version 1, got version 823") +} diff --git a/credentials/u2m/callback.go b/credentials/u2m/callback.go new file mode 100644 index 000000000..aab576a04 --- /dev/null +++ b/credentials/u2m/callback.go @@ -0,0 +1,138 @@ +package u2m + +import ( + "context" + _ "embed" + "fmt" + "html/template" + "net/http" + "strings" + + "golang.org/x/text/cases" + "golang.org/x/text/language" +) + +//go:embed page.tmpl +var pageTmpl string + +type oauthResult struct { + Error string + ErrorDescription string + State string + Code string + Host string +} + +// callbackServer is a server that listens for the redirect from the Databricks +// identity provider. It renders a page.html template that shows the result of +// the authentication attempt. +type callbackServer struct { + // ctx is the context used when waiting for the redirect from the identity + // provider. 
This is needed because the Handler() method from the oauth2 + // library does not accept a context. + ctx context.Context + + // srv is the server that listens for the redirect from the identity provider. + srv http.Server + + // browser is a function that opens a browser to the given URL. + browser func(string) error + + // arg is the OAuth argument used to authenticate. + arg OAuthArgument + + // renderErrCh is a channel that receives an error if there is an error + // rendering the page.html template. + renderErrCh chan error + + // feedbackCh is a channel that receives the result of the authentication + // attempt. + feedbackCh chan oauthResult + + // tmpl is the template used to render the response page after the user is + // redirected back to the callback server. + tmpl *template.Template +} + +// newCallbackServer creates a new callback server that listens for the redirect +// from the Databricks identity provider. +func (a *PersistentAuth) newCallbackServer() (*callbackServer, error) { + tmpl, err := template.New("page").Funcs(template.FuncMap{ + "title": func(in string) string { + title := cases.Title(language.English) + return title.String(strings.ReplaceAll(in, "_", " ")) + }, + }).Parse(pageTmpl) + if err != nil { + return nil, err + } + cb := &callbackServer{ + feedbackCh: make(chan oauthResult), + renderErrCh: make(chan error), + tmpl: tmpl, + ctx: a.ctx, + browser: a.browser, + arg: a.oAuthArgument, + } + cb.srv.Handler = cb + go func() { + _ = cb.srv.Serve(a.ln) + }() + return cb, nil +} + +// Close closes the callback server. +func (cb *callbackServer) Close() error { + return cb.srv.Close() +} + +// ServeHTTP renders the page.html template. 
+func (cb *callbackServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + res := oauthResult{ + Error: r.FormValue("error"), + ErrorDescription: r.FormValue("error_description"), + Code: r.FormValue("code"), + State: r.FormValue("state"), + Host: cb.getHost(), + } + if res.Error != "" { + w.WriteHeader(http.StatusBadRequest) + } else { + w.WriteHeader(http.StatusOK) + } + err := cb.tmpl.Execute(w, res) + if err != nil { + cb.renderErrCh <- err + } + cb.feedbackCh <- res +} + +func (cb *callbackServer) getHost() string { + switch a := cb.arg.(type) { + case AccountOAuthArgument: + return a.GetAccountHost() + case WorkspaceOAuthArgument: + return a.GetWorkspaceHost() + default: + return "" + } +} + +// Handler opens up a browser waits for redirect to come back from the identity provider +func (cb *callbackServer) Handler(authCodeURL string) (string, string, error) { + err := cb.browser(authCodeURL) + if err != nil { + fmt.Printf("Please open %s in the browser to continue authentication", authCodeURL) + } + select { + case <-cb.ctx.Done(): + return "", "", cb.ctx.Err() + case renderErr := <-cb.renderErrCh: + return "", "", renderErr + case res := <-cb.feedbackCh: + if res.Error != "" { + return "", "", fmt.Errorf("%s: %s", res.Error, res.ErrorDescription) + } + return res.Code, res.State, nil + } +} diff --git a/credentials/u2m/doc.go b/credentials/u2m/doc.go new file mode 100644 index 000000000..85835347f --- /dev/null +++ b/credentials/u2m/doc.go @@ -0,0 +1,41 @@ +/* +Package u2m supports the user-to-machine (U2M) OAuth flow for authenticating with Databricks. + +Databricks uses the authorization code flow from OAuth 2.0 to authenticate users. This flow +consists of four steps: + 1. Retrieve an authorization code for a user by opening a browser and directing them to the + Databricks authorization URL. + 2. Exchange the authorization code for an access token. + 3. Use the access token to authenticate with Databricks. + 4. 
When the access token expires, use the refresh token to get a new access token. + +The token and authorization endpoints for Databricks vary depending on whether the host is +an account- or workspace-level host. Account-level endpoints are fixed based on the account +ID and host, while workspace-level endpoints are discovered using the OIDC discovery endpoint +at /oidc/.well-known/oauth-authorization-server. + +To trigger the authorization flow, construct a PersistentAuth object and call the +Challenge() method: + + auth, err := oauth.NewPersistentAuth(ctx) + if err != nil { + log.Fatalf("failed to create persistent auth: %v", err) + } + token, err := auth.Challenge(ctx, oauth.BasicAccountOAuthArgument{ + AccountHost: "https://accounts.cloud.databricks.com", + AccountID: "xyz", + }) + +Because the U2M flow requires user interaction, the resulting access token and refresh token +can be stored in a persistent cache to avoid prompting the user for credentials on every +authentication attempt. By default, the cache is stored in ~/.databricks/token-cache.json. +Retrieve the cached token by calling the Load() method: + + token, err := auth.Load(ctx, oauth.BasicAccountOAuthArgument{ + AccountHost: "https://accounts.cloud.databricks.com", + AccountID: "xyz", + }) + +See the cache package for more information on customizing the cache. +*/ +package u2m diff --git a/credentials/u2m/endpoint_supplier.go b/credentials/u2m/endpoint_supplier.go new file mode 100644 index 000000000..fb5c48c24 --- /dev/null +++ b/credentials/u2m/endpoint_supplier.go @@ -0,0 +1,60 @@ +package u2m + +import ( + "context" + "errors" + "fmt" + + "github.com/databricks/databricks-sdk-go/httpclient" +) + +// OAuthEndpointSupplier provides the http functionality needed for interacting with the +// Databricks OAuth APIs. +type OAuthEndpointSupplier interface { + // GetWorkspaceOAuthEndpoints returns the OAuth2 endpoints for the workspace. 
+ GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*OAuthAuthorizationServer, error) + + // GetAccountOAuthEndpoints returns the OAuth2 endpoints for the account. + GetAccountOAuthEndpoints(ctx context.Context, accountHost string, accountId string) (*OAuthAuthorizationServer, error) +} + +// BasicOAuthEndpointSupplier is an implementation of the OAuthEndpointSupplier interface. +type BasicOAuthEndpointSupplier struct { + // Client is the ApiClient to use for making HTTP requests. + Client *httpclient.ApiClient +} + +// GetWorkspaceOAuthEndpoints returns the OAuth endpoints for the given workspace. +// It queries the OIDC discovery endpoint to get the OAuth endpoints using the +// provided ApiClient. +func (c *BasicOAuthEndpointSupplier) GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*OAuthAuthorizationServer, error) { + oidc := fmt.Sprintf("%s/oidc/.well-known/oauth-authorization-server", workspaceHost) + var oauthEndpoints OAuthAuthorizationServer + if err := c.Client.Do(ctx, "GET", oidc, httpclient.WithResponseUnmarshal(&oauthEndpoints)); err != nil { + return nil, ErrOAuthNotSupported + } + return &oauthEndpoints, nil +} + +// GetAccountOAuthEndpoints returns the OAuth2 endpoints for the account. The +// account-level OAuth endpoints are fixed based on the account ID and host. +func (c *BasicOAuthEndpointSupplier) GetAccountOAuthEndpoints(ctx context.Context, accountHost string, accountId string) (*OAuthAuthorizationServer, error) { + return &OAuthAuthorizationServer{ + AuthorizationEndpoint: fmt.Sprintf("%s/oidc/accounts/%s/v1/authorize", accountHost, accountId), + TokenEndpoint: fmt.Sprintf("%s/oidc/accounts/%s/v1/token", accountHost, accountId), + }, nil +} + +var ErrOAuthNotSupported = errors.New("databricks OAuth is not supported for this host") + +// OAuthAuthorizationServer contains the OAuth endpoints for a Databricks account +// or workspace. 
+type OAuthAuthorizationServer struct { + // AuthorizationEndpoint is the URL to redirect users to for authorization. + // It typically ends with /v1/authorize. + AuthorizationEndpoint string `json:"authorization_endpoint"` + + // TokenEndpoint is the URL to exchange an authorization code for an access token. + // It typically ends with /v1/token. + TokenEndpoint string `json:"token_endpoint"` +} diff --git a/credentials/u2m/endpoint_supplier_test.go b/credentials/u2m/endpoint_supplier_test.go new file mode 100644 index 000000000..72106e91b --- /dev/null +++ b/credentials/u2m/endpoint_supplier_test.go @@ -0,0 +1,37 @@ +package u2m + +import ( + "context" + "testing" + + "github.com/databricks/databricks-sdk-go/httpclient" + "github.com/databricks/databricks-sdk-go/httpclient/fixtures" + "github.com/stretchr/testify/assert" +) + +func TestBasicOAuthClient_GetAccountOAuthEndpoints(t *testing.T) { + c := &BasicOAuthEndpointSupplier{} + s, err := c.GetAccountOAuthEndpoints(context.Background(), "https://abc", "xyz") + assert.NoError(t, err) + assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/authorize", s.AuthorizationEndpoint) + assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/token", s.TokenEndpoint) +} + +func TestGetWorkspaceOAuthEndpoints(t *testing.T) { + p := httpclient.NewApiClient(httpclient.ClientConfig{ + Transport: fixtures.MappingTransport{ + "GET /oidc/.well-known/oauth-authorization-server": { + Status: 200, + Response: map[string]string{ + "authorization_endpoint": "a", + "token_endpoint": "b", + }, + }, + }, + }) + c := &BasicOAuthEndpointSupplier{Client: p} + endpoints, err := c.GetWorkspaceOAuthEndpoints(context.Background(), "https://abc") + assert.NoError(t, err) + assert.Equal(t, "a", endpoints.AuthorizationEndpoint) + assert.Equal(t, "b", endpoints.TokenEndpoint) +} diff --git a/credentials/u2m/error.go b/credentials/u2m/error.go new file mode 100644 index 000000000..be953b0c1 --- /dev/null +++ b/credentials/u2m/error.go @@ -0,0 +1,8 @@ +package
u2m + +// InvalidRefreshTokenError is returned from PersistentAuth's Token() method +// if the access token has expired and the refresh token in the token cache +// is invalid. +type InvalidRefreshTokenError struct { + error +} diff --git a/credentials/u2m/oauth_argument.go b/credentials/u2m/oauth_argument.go new file mode 100644 index 000000000..f2d2ebc5d --- /dev/null +++ b/credentials/u2m/oauth_argument.go @@ -0,0 +1,11 @@ +package u2m + +// OAuthArgument is an interface that provides the necessary information to +// authenticate with PersistentAuth. Implementations of this interface must +// implement either the WorkspaceOAuthArgument or AccountOAuthArgument +// interface. +type OAuthArgument interface { + // GetCacheKey returns a unique key for the OAuthArgument. This key is used + // to store and retrieve the token from the token cache. + GetCacheKey() string +} diff --git a/credentials/u2m/page.tmpl b/credentials/u2m/page.tmpl new file mode 100644 index 000000000..1540222db --- /dev/null +++ b/credentials/u2m/page.tmpl @@ -0,0 +1,104 @@ + + + + + {{if .Error }}{{ .Error | title }}{{ else }}Success{{end}} + + + + + + + +
+
+ + +
{{ .Error | title }}
+
{{ .ErrorDescription }}
+ +
Authenticated
+ {{- if .Host }} +
Go to {{.Host}}
+ {{- end}} + +
+ You can close this tab. Or go to documentation +
+
+
+ + diff --git a/credentials/u2m/persistent_auth.go b/credentials/u2m/persistent_auth.go new file mode 100644 index 000000000..7e28a25f5 --- /dev/null +++ b/credentials/u2m/persistent_auth.go @@ -0,0 +1,371 @@ +package u2m + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "net" + "net/http" + "time" + + cache "github.com/databricks/databricks-sdk-go/credentials/u2m/cache" + "github.com/databricks/databricks-sdk-go/httpclient" + "github.com/databricks/databricks-sdk-go/logger" + "github.com/databricks/databricks-sdk-go/retries" + "github.com/pkg/browser" + "golang.org/x/oauth2" + "golang.org/x/oauth2/authhandler" +) + +const ( + // appClientId is the default client ID used by the SDK for U2M OAuth. + appClientID = "databricks-cli" + + // appRedirectAddr is the default address for the OAuth2 callback server. + appRedirectAddr = "localhost:8020" + + // listenerTimeout is the maximum amount of time to acquire listener on + // appRedirectAddr. + listenerTimeout = 45 * time.Second +) + +// PersistentAuth is an OAuth manager that handles the U2M OAuth flow. Tokens +// are stored in and looked up from the provided cache. Tokens include the +// refresh token. On load, if the access token is expired, it is refreshed +// using the refresh token. +// +// The PersistentAuth is safe for concurrent use. The token cache is locked +// during token retrieval, refresh and storage. +type PersistentAuth struct { + // cache is the token cache to store and lookup tokens. + cache cache.TokenCache + // client is the HTTP client to use for OAuth2 requests. + client *http.Client + // endpointSupplier is the HTTP endpointSupplier to use for OAuth2 requests. + endpointSupplier OAuthEndpointSupplier + // oAuthArgument defines the workspace or account to authenticate to and the + // cache key for the token. + oAuthArgument OAuthArgument + // browser is the function to open a URL in the default browser. 
+ browser func(url string) error + // ln is the listener for the OAuth2 callback server. + ln net.Listener + // ctx is the context to use for underlying operations. This is needed in + // order to implement the oauth2.TokenSource interface. + ctx context.Context +} + +type PersistentAuthOption func(*PersistentAuth) + +// WithTokenCache sets the token cache for the PersistentAuth. +func WithTokenCache(c cache.TokenCache) PersistentAuthOption { + return func(a *PersistentAuth) { + a.cache = c + } +} + +// WithHttpClient sets the HTTP client for the PersistentAuth. +func WithHttpClient(c *http.Client) PersistentAuthOption { + return func(a *PersistentAuth) { + a.client = c + } +} + +// WithOAuthEndpointSupplier sets the OAuth endpoint supplier for the +// PersistentAuth. +func WithOAuthEndpointSupplier(c OAuthEndpointSupplier) PersistentAuthOption { + return func(a *PersistentAuth) { + a.endpointSupplier = c + } +} + +// WithOAuthArgument sets the OAuthArgument for the PersistentAuth. +func WithOAuthArgument(arg OAuthArgument) PersistentAuthOption { + return func(a *PersistentAuth) { + a.oAuthArgument = arg + } +} + +// WithBrowser sets the browser function for the PersistentAuth. +func WithBrowser(b func(url string) error) PersistentAuthOption { + return func(a *PersistentAuth) { + a.browser = b + } +} + +// NewPersistentAuth creates a new PersistentAuth with the provided options. +func NewPersistentAuth(ctx context.Context, opts ...PersistentAuthOption) (*PersistentAuth, error) { + p := &PersistentAuth{} + for _, opt := range opts { + opt(p) + } + // By default, PersistentAuth uses the default ApiClient to make HTTP + // requests. Furthermore, if the endpointSupplier is not provided, it uses + // this same client to fetch the OAuth endpoints. If the HTTP client is + // provided but the endpointSupplier is not, we construct a default + // ApiClient for use with BasicOAuthClient. 
+ apiClient := httpclient.NewApiClient(httpclient.ClientConfig{}) + if p.client == nil { + p.client = &http.Client{ + Transport: apiClient, + // 30 seconds matches the default timeout of the ApiClient + Timeout: 30 * time.Second, + } + } + if p.endpointSupplier == nil { + p.endpointSupplier = &BasicOAuthEndpointSupplier{ + Client: apiClient, + } + } + if p.cache == nil { + var err error + p.cache, err = cache.NewFileTokenCache() + if err != nil { + return nil, fmt.Errorf("cache: %w", err) + } + } + if p.oAuthArgument == nil { + return nil, errors.New("missing OAuthArgument") + } + if err := p.validateArg(); err != nil { + return nil, err + } + if p.browser == nil { + p.browser = browser.OpenURL + } + p.ctx = ctx + return p, nil +} + +// Token loads the OAuth2 token for the given OAuthArgument from the cache. If +// the token is expired, it is refreshed using the refresh token. +func (a *PersistentAuth) Token() (t *oauth2.Token, err error) { + err = a.startListener(a.ctx) + if err != nil { + return nil, fmt.Errorf("starting listener: %w", err) + } + defer a.Close() + + key := a.oAuthArgument.GetCacheKey() + t, err = a.cache.Lookup(key) + if err != nil { + return nil, fmt.Errorf("cache: %w", err) + } + // refresh if invalid + if !t.Valid() { + t, err = a.refresh(t) + if err != nil { + return nil, fmt.Errorf("token refresh: %w", err) + } + } + // do not print refresh token to end-user + t.RefreshToken = "" + return t, nil +} + +// refresh refreshes the token for the given OAuthArgument, storing the new +// token in the cache. 
+//
+// When the token endpoint rejects the request with the description
+// "Refresh token is invalid", an *InvalidRefreshTokenError wrapping the
+// original error is returned. Other endpoint errors are reported as
+// "<description> (error code: <code>)".
+func (a *PersistentAuth) refresh(oldToken *oauth2.Token) (*oauth2.Token, error) {
+	// Description the token endpoint uses for an unusable refresh token.
+	const invalidRefreshToken = "Refresh token is invalid"
+
+	// OAuth2 config is invoked only for expired tokens to speed up
+	// the happy path in the token retrieval
+	cfg, err := a.oauth2Config()
+	if err != nil {
+		return nil, err
+	}
+	// make OAuth2 library use our client
+	ctx := a.setOAuthContext(a.ctx)
+	// eagerly refresh token
+	t, err := cfg.TokenSource(ctx, oldToken).Token()
+	if err != nil {
+		// The default RoundTripper of our httpclient.ApiClient returns an error
+		// if the response status code is not 2xx. This isn't compliant with the
+		// RoundTripper interface, so this error isn't handled by the oauth2
+		// library. We need to handle it here.
+		var internalHttpError *httpclient.HttpError
+		if errors.As(err, &internalHttpError) {
+			// error fields
+			// https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
+			var errResponse struct {
+				Error            string `json:"error"`
+				ErrorDescription string `json:"error_description"`
+			}
+			if unmarshalErr := json.Unmarshal([]byte(internalHttpError.Message), &errResponse); unmarshalErr != nil {
+				// The body was not an RFC 6749 JSON error document (for
+				// example an HTML page from a proxy). Keep the original
+				// error text so the real failure is not hidden behind a
+				// JSON syntax error.
+				return nil, fmt.Errorf("unmarshal: %w; original error: %v", unmarshalErr, err)
+			}
+			// Invalid refresh tokens get their own error type so they can be
+			// better presented to users.
+			if errResponse.ErrorDescription == invalidRefreshToken {
+				return nil, &InvalidRefreshTokenError{err}
+			}
+			return nil, fmt.Errorf("%s (error code: %s)", errResponse.ErrorDescription, errResponse.Error)
+		}
+
+		// Handle responses from well-behaved *http.Client implementations.
+		var httpErr *oauth2.RetrieveError
+		if errors.As(err, &httpErr) {
+			// Invalid refresh tokens get their own error type so they can be
+			// better presented to users.
+			if httpErr.ErrorDescription == invalidRefreshToken {
+				return nil, &InvalidRefreshTokenError{err}
+			}
+			return nil, fmt.Errorf("%s (error code: %s)", httpErr.ErrorDescription, httpErr.ErrorCode)
+		}
+		return nil, err
+	}
+	// Persist the refreshed token under the same key it was looked up with.
+	err = a.cache.Store(a.oAuthArgument.GetCacheKey(), t)
+	if err != nil {
+		return nil, fmt.Errorf("cache update: %w", err)
+	}
+	return t, nil
+}
+
+// Challenge initiates the OAuth2 login flow for the given OAuthArgument. The
+// OAuth2 flow is started by opening the browser to the OAuth2 authorization
+// URL. The user is redirected to the callback server on appRedirectAddr. The
+// callback server listens for the redirect from the identity provider and
+// exchanges the authorization code for an access token.
+//
+// On success the resulting token is stored in the cache under the
+// OAuthArgument's cache key; the token itself is not returned.
+func (a *PersistentAuth) Challenge() error {
+	err := a.startListener(a.ctx)
+	if err != nil {
+		return fmt.Errorf("starting listener: %w", err)
+	}
+	// The listener will be closed by the callback server automatically, but if
+	// the callback server is not created, we need to close the listener manually.
+	defer a.Close()
+
+	cfg, err := a.oauth2Config()
+	if err != nil {
+		return fmt.Errorf("fetching oauth config: %w", err)
+	}
+	cb, err := a.newCallbackServer()
+	if err != nil {
+		return fmt.Errorf("callback server: %w", err)
+	}
+	defer cb.Close()
+
+	state, pkce, err := a.stateAndPKCE()
+	if err != nil {
+		return fmt.Errorf("state and pkce: %w", err)
+	}
+	// make OAuth2 library use our client
+	ctx := a.setOAuthContext(a.ctx)
+	ts := authhandler.TokenSourceWithPKCE(ctx, cfg, state, cb.Handler, pkce)
+	t, err := ts.Token()
+	if err != nil {
+		return fmt.Errorf("authorize: %w", err)
+	}
+	// cache token identified by host (and possibly the account id)
+	err = a.cache.Store(a.oAuthArgument.GetCacheKey(), t)
+	if err != nil {
+		return fmt.Errorf("store: %w", err)
+	}
+	return nil
+}
+
+// startListener starts a listener on appRedirectAddr, retrying if the address
+// is already in use.
+//
+// Every failed listen attempt is logged at debug level and handed back to
+// retries.Poll as a retryable error. On success the listener is stored in
+// a.ln so that Close can release it.
+// NOTE(review): presumably Poll gives up once listenerTimeout elapses —
+// confirm against the retries package.
+func (a *PersistentAuth) startListener(ctx context.Context) error {
+	listener, err := retries.Poll(ctx, listenerTimeout,
+		func() (*net.Listener, *retries.Err) {
+			var lc net.ListenConfig
+			l, err := lc.Listen(ctx, "tcp", appRedirectAddr)
+			if err != nil {
+				logger.Debugf(ctx, "failed to listen on %s: %v, retrying", appRedirectAddr, err)
+				return nil, retries.Continue(err)
+			}
+			return &l, nil
+		})
+	if err != nil {
+		return fmt.Errorf("listener: %w", err)
+	}
+	a.ln = *listener
+	return nil
+}
+
+// Close releases the listener opened by startListener, if any. It is safe to
+// call more than once: the listener reference is cleared before closing, so a
+// second call (for example a caller's deferred Close after Token or Challenge
+// already closed it) is a no-op that returns nil.
+func (a *PersistentAuth) Close() error {
+	if a.ln == nil {
+		return nil
+	}
+	ln := a.ln
+	a.ln = nil
+	return ln.Close()
+}
+
+// validateArg ensures that the OAuthArgument is either a WorkspaceOAuthArgument
+// or an AccountOAuthArgument.
+func (a *PersistentAuth) validateArg() error {
+	_, isWorkspaceArg := a.oAuthArgument.(WorkspaceOAuthArgument)
+	_, isAccountArg := a.oAuthArgument.(AccountOAuthArgument)
+	if !isWorkspaceArg && !isAccountArg {
+		return fmt.Errorf("unsupported OAuthArgument type: %T, must implement either WorkspaceOAuthArgument or AccountOAuthArgument interface", a.oAuthArgument)
+	}
+	return nil
+}
+
+// oauth2Config returns the OAuth2 configuration for the given OAuthArgument.
+func (a *PersistentAuth) oauth2Config() (*oauth2.Config, error) { + scopes := []string{ + "offline_access", // ensures OAuth token includes refresh token + "all-apis", // ensures OAuth token has access to all control-plane APIs + } + var endpoints *OAuthAuthorizationServer + var err error + switch argg := a.oAuthArgument.(type) { + case WorkspaceOAuthArgument: + endpoints, err = a.endpointSupplier.GetWorkspaceOAuthEndpoints(a.ctx, argg.GetWorkspaceHost()) + case AccountOAuthArgument: + endpoints, err = a.endpointSupplier.GetAccountOAuthEndpoints( + a.ctx, argg.GetAccountHost(), argg.GetAccountId()) + default: + return nil, fmt.Errorf("unsupported OAuthArgument type: %T, must implement either WorkspaceOAuthArgument or AccountOAuthArgument interface", a.oAuthArgument) + } + if err != nil { + return nil, fmt.Errorf("fetching OAuth endpoints: %w", err) + } + return &oauth2.Config{ + ClientID: appClientID, + Endpoint: oauth2.Endpoint{ + AuthURL: endpoints.AuthorizationEndpoint, + TokenURL: endpoints.TokenEndpoint, + AuthStyle: oauth2.AuthStyleInParams, + }, + RedirectURL: fmt.Sprintf("http://%s", appRedirectAddr), + Scopes: scopes, + }, nil +} + +func (a *PersistentAuth) stateAndPKCE() (string, *authhandler.PKCEParams, error) { + verifier, err := a.randomString(64) + if err != nil { + return "", nil, fmt.Errorf("verifier: %w", err) + } + verifierSha256 := sha256.Sum256([]byte(verifier)) + challenge := base64.RawURLEncoding.EncodeToString(verifierSha256[:]) + state, err := a.randomString(16) + if err != nil { + return "", nil, fmt.Errorf("state: %w", err) + } + return state, &authhandler.PKCEParams{ + Challenge: challenge, + ChallengeMethod: "S256", + Verifier: verifier, + }, nil +} + +func (a *PersistentAuth) randomString(size int) (string, error) { + raw := make([]byte, size) + // ignore error as rand.Reader never returns an error + _, err := rand.Read(raw) + if err != nil { + return "", fmt.Errorf("rand.Read: %w", err) + } + return 
base64.RawURLEncoding.EncodeToString(raw), nil +} + +func (a *PersistentAuth) setOAuthContext(ctx context.Context) context.Context { + return context.WithValue(ctx, oauth2.HTTPClient, a.client) +} + +var _ oauth2.TokenSource = (*PersistentAuth)(nil) diff --git a/credentials/u2m/persistent_auth_test.go b/credentials/u2m/persistent_auth_test.go new file mode 100644 index 000000000..5aac4ecda --- /dev/null +++ b/credentials/u2m/persistent_auth_test.go @@ -0,0 +1,296 @@ +package u2m_test + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/url" + "testing" + "time" + + "github.com/databricks/databricks-sdk-go/credentials/u2m" + "github.com/databricks/databricks-sdk-go/httpclient/fixtures" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +type tokenCacheMock struct { + store func(key string, t *oauth2.Token) error + lookup func(key string) (*oauth2.Token, error) +} + +func (m *tokenCacheMock) Store(key string, t *oauth2.Token) error { + if m.store == nil { + panic("no store mock") + } + return m.store(key, t) +} + +func (m *tokenCacheMock) Lookup(key string) (*oauth2.Token, error) { + if m.lookup == nil { + panic("no lookup mock") + } + return m.lookup(key) +} + +func TestToken(t *testing.T) { + cache := &tokenCacheMock{ + lookup: func(key string) (*oauth2.Token, error) { + assert.Equal(t, "https://abc/oidc/accounts/xyz", key) + return &oauth2.Token{ + AccessToken: "bcd", + Expiry: time.Now().Add(1 * time.Minute), + }, nil + }, + } + arg, err := u2m.NewBasicAccountOAuthArgument("https://abc", "xyz") + assert.NoError(t, err) + p, err := u2m.NewPersistentAuth(context.Background(), u2m.WithTokenCache(cache), u2m.WithOAuthArgument(arg)) + require.NoError(t, err) + defer p.Close() + tok, err := p.Token() + assert.NoError(t, err) + assert.Equal(t, "bcd", tok.AccessToken) + assert.Equal(t, "", tok.RefreshToken) +} + +type MockOAuthEndpointSupplier struct{} + +func (m MockOAuthEndpointSupplier) 
GetAccountOAuthEndpoints(ctx context.Context, accountHost string, accountId string) (*u2m.OAuthAuthorizationServer, error) { + return &u2m.OAuthAuthorizationServer{ + AuthorizationEndpoint: fmt.Sprintf("%s/oidc/accounts/%s/v1/authorize", accountHost, accountId), + TokenEndpoint: fmt.Sprintf("%s/oidc/accounts/%s/v1/token", accountHost, accountId), + }, nil +} + +func (m MockOAuthEndpointSupplier) GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*u2m.OAuthAuthorizationServer, error) { + return &u2m.OAuthAuthorizationServer{ + AuthorizationEndpoint: fmt.Sprintf("%s/oidc/v1/authorize", workspaceHost), + TokenEndpoint: fmt.Sprintf("%s/oidc/v1/token", workspaceHost), + }, nil +} + +func TestToken_RefreshesExpiredAccessToken(t *testing.T) { + ctx := context.Background() + expectedKey := "https://accounts.cloud.databricks.com/oidc/accounts/xyz" + cache := &tokenCacheMock{ + lookup: func(key string) (*oauth2.Token, error) { + assert.Equal(t, expectedKey, key) + return &oauth2.Token{ + AccessToken: "expired", + RefreshToken: "cde", + Expiry: time.Now().Add(-1 * time.Minute), + }, nil + }, + store: func(key string, tok *oauth2.Token) error { + assert.Equal(t, expectedKey, key) + assert.Equal(t, "def", tok.RefreshToken) + return nil + }, + } + arg, err := u2m.NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "xyz") + assert.NoError(t, err) + p, err := u2m.NewPersistentAuth( + ctx, + u2m.WithTokenCache(cache), + u2m.WithHttpClient(&http.Client{ + Transport: fixtures.SliceTransport{ + { + Method: "POST", + Resource: "/oidc/accounts/xyz/v1/token", + Response: `access_token=refreshed&refresh_token=def`, + ResponseHeaders: map[string][]string{ + "Content-Type": {"application/x-www-form-urlencoded"}, + }, + }, + }, + }), + u2m.WithOAuthEndpointSupplier(MockOAuthEndpointSupplier{}), + u2m.WithOAuthArgument(arg), + ) + require.NoError(t, err) + defer p.Close() + tok, err := p.Token() + assert.NoError(t, err) + assert.Equal(t, "refreshed", 
tok.AccessToken) + assert.Equal(t, "", tok.RefreshToken) +} + +func TestToken_ReturnsError(t *testing.T) { + ctx := context.Background() + cache := &tokenCacheMock{ + lookup: func(key string) (*oauth2.Token, error) { + assert.Equal(t, "https://accounts.cloud.databricks.com/oidc/accounts/xyz", key) + return &oauth2.Token{ + AccessToken: "expired", + RefreshToken: "cde", + Expiry: time.Now().Add(-1 * time.Minute), + }, nil + }, + } + arg, err := u2m.NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "xyz") + assert.NoError(t, err) + p, err := u2m.NewPersistentAuth( + ctx, + u2m.WithTokenCache(cache), + u2m.WithHttpClient(&http.Client{ + Transport: fixtures.SliceTransport{ + { + Method: "POST", + Resource: "/oidc/accounts/xyz/v1/token", + Response: `{"error": "invalid_grant", "error_description": "Invalid Client"}`, + Status: 401, + }, + }, + }), + u2m.WithOAuthEndpointSupplier(MockOAuthEndpointSupplier{}), + u2m.WithOAuthArgument(arg), + ) + require.NoError(t, err) + defer p.Close() + tok, err := p.Token() + assert.Nil(t, tok) + assert.ErrorContains(t, err, "Invalid Client (error code: invalid_grant)") +} + +func TestToken_ReturnsInvalidRefreshTokenError(t *testing.T) { + ctx := context.Background() + cache := &tokenCacheMock{ + lookup: func(key string) (*oauth2.Token, error) { + assert.Equal(t, "https://accounts.cloud.databricks.com/oidc/accounts/xyz", key) + return &oauth2.Token{ + AccessToken: "expired", + RefreshToken: "cde", + Expiry: time.Now().Add(-1 * time.Minute), + }, nil + }, + } + arg, err := u2m.NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "xyz") + assert.NoError(t, err) + p, err := u2m.NewPersistentAuth( + ctx, + u2m.WithTokenCache(cache), + u2m.WithHttpClient(&http.Client{ + Transport: fixtures.SliceTransport{ + { + Method: "POST", + Resource: "/oidc/accounts/xyz/v1/token", + Response: `{"error": "invalid_grant", "error_description": "Refresh token is invalid"}`, + Status: 401, + }, + }, + }), + 
u2m.WithOAuthEndpointSupplier(MockOAuthEndpointSupplier{}), + u2m.WithOAuthArgument(arg), + ) + require.NoError(t, err) + defer p.Close() + tok, err := p.Token() + assert.Nil(t, tok) + target := &u2m.InvalidRefreshTokenError{} + assert.True(t, errors.As(err, &target)) +} + +func TestChallenge(t *testing.T) { + ctx := context.Background() + + browserOpened := make(chan string) + browser := func(redirect string) error { + u, err := url.ParseRequestURI(redirect) + if err != nil { + return err + } + assert.Equal(t, "/oidc/accounts/xyz/v1/authorize", u.Path) + // for now we're ignoring asserting the fields of the redirect + query := u.Query() + browserOpened <- query.Get("state") + return nil + } + cache := &tokenCacheMock{ + store: func(key string, tok *oauth2.Token) error { + assert.Equal(t, "https://accounts.cloud.databricks.com/oidc/accounts/xyz", key) + assert.Equal(t, "__THAT__", tok.AccessToken) + assert.Equal(t, "__SOMETHING__", tok.RefreshToken) + return nil + }, + } + arg, err := u2m.NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "xyz") + assert.NoError(t, err) + p, err := u2m.NewPersistentAuth( + ctx, + u2m.WithTokenCache(cache), + u2m.WithBrowser(browser), + u2m.WithHttpClient(&http.Client{ + Transport: fixtures.SliceTransport{ + { + Method: "POST", + Resource: "/oidc/accounts/xyz/v1/token", + Response: `access_token=__THAT__&refresh_token=__SOMETHING__`, + ResponseHeaders: map[string][]string{ + "Content-Type": {"application/x-www-form-urlencoded"}, + }, + }, + }, + }), + u2m.WithOAuthEndpointSupplier(MockOAuthEndpointSupplier{}), + u2m.WithOAuthArgument(arg), + ) + require.NoError(t, err) + defer p.Close() + + errc := make(chan error) + go func() { + err := p.Challenge() + errc <- err + close(errc) + }() + + state := <-browserOpened + resp, err := http.Get(fmt.Sprintf("http://localhost:8020?code=__THIS__&state=%s", state)) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, 200, resp.StatusCode) + + err = <-errc + 
assert.NoError(t, err) +} + +func TestChallenge_ReturnsErrorOnFailure(t *testing.T) { + ctx := context.Background() + browserOpened := make(chan string) + browser := func(redirect string) error { + u, err := url.ParseRequestURI(redirect) + if err != nil { + return err + } + assert.Equal(t, "/oidc/accounts/xyz/v1/authorize", u.Path) + // for now we're ignoring asserting the fields of the redirect + query := u.Query() + browserOpened <- query.Get("state") + return nil + } + arg, err := u2m.NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "xyz") + assert.NoError(t, err) + p, err := u2m.NewPersistentAuth(ctx, u2m.WithBrowser(browser), u2m.WithOAuthArgument(arg)) + require.NoError(t, err) + defer p.Close() + + errc := make(chan error) + go func() { + err := p.Challenge() + errc <- err + close(errc) + }() + + <-browserOpened + resp, err := http.Get( + "http://localhost:8020?error=access_denied&error_description=Policy%20evaluation%20failed%20for%20this%20request") + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, 400, resp.StatusCode) + + err = <-errc + assert.EqualError(t, err, "authorize: access_denied: Policy evaluation failed for this request") +} diff --git a/credentials/u2m/workspace_oauth_argument.go b/credentials/u2m/workspace_oauth_argument.go new file mode 100644 index 000000000..99a8a21e5 --- /dev/null +++ b/credentials/u2m/workspace_oauth_argument.go @@ -0,0 +1,51 @@ +package u2m + +import ( + "fmt" + "strings" +) + +// WorkspaceOAuthArgument is an interface that provides the necessary information +// to authenticate using OAuth to a specific workspace. +type WorkspaceOAuthArgument interface { + OAuthArgument + + // GetWorkspaceHost returns the host of the workspace to authenticate to. + GetWorkspaceHost() string +} + +// BasicWorkspaceOAuthArgument is a basic implementation of the WorkspaceOAuthArgument +// interface that links each host with exactly one OAuth token. 
+type BasicWorkspaceOAuthArgument struct { + // host is the host of the workspace to authenticate to. This must start + // with "https://" and must not have a trailing slash. + host string +} + +// NewBasicWorkspaceOAuthArgument creates a new BasicWorkspaceOAuthArgument. +func NewBasicWorkspaceOAuthArgument(host string) (BasicWorkspaceOAuthArgument, error) { + if !strings.HasPrefix(host, "https://") { + return BasicWorkspaceOAuthArgument{}, fmt.Errorf("host must start with 'https://': %s", host) + } + if strings.HasSuffix(host, "/") { + return BasicWorkspaceOAuthArgument{}, fmt.Errorf("host must not have a trailing slash: %s", host) + } + return BasicWorkspaceOAuthArgument{host: host}, nil +} + +// GetWorkspaceHost returns the host of the workspace to authenticate to. +func (a BasicWorkspaceOAuthArgument) GetWorkspaceHost() string { + return a.host +} + +// GetCacheKey returns a unique key for caching the OAuth token for the workspace. +// The key is in the format "". +func (a BasicWorkspaceOAuthArgument) GetCacheKey() string { + a.host = strings.TrimSuffix(a.host, "/") + if !strings.HasPrefix(a.host, "http") { + a.host = fmt.Sprintf("https://%s", a.host) + } + return a.host +} + +var _ WorkspaceOAuthArgument = BasicWorkspaceOAuthArgument{} diff --git a/go.mod b/go.mod index 0f7fbce9f..9dabdc5ba 100644 --- a/go.mod +++ b/go.mod @@ -6,11 +6,13 @@ require ( github.com/google/go-cmp v0.6.0 github.com/google/go-querystring v1.1.0 github.com/google/uuid v1.6.0 + github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c github.com/stretchr/testify v1.9.0 golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 golang.org/x/mod v0.17.0 golang.org/x/net v0.33.0 golang.org/x/oauth2 v0.20.0 + golang.org/x/text v0.21.0 golang.org/x/time v0.5.0 google.golang.org/api v0.182.0 gopkg.in/ini.v1 v1.67.0 @@ -37,7 +39,6 @@ require ( go.opentelemetry.io/otel/trace v1.24.0 // indirect golang.org/x/crypto v0.31.0 // indirect golang.org/x/sys v0.28.0 // indirect - golang.org/x/text v0.21.0 // 
indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240521202816-d264139d666e // indirect google.golang.org/grpc v1.64.1 // indirect google.golang.org/protobuf v1.34.1 // indirect diff --git a/go.sum b/go.sum index 42cd14133..44b258fcc 100644 --- a/go.sum +++ b/go.sum @@ -58,6 +58,8 @@ github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+ github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs= github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0= github.com/googleapis/gax-go/v2 v2.12.4 h1:9gWcmF85Wvq4ryPFvGFaOgPIs1AQX0d0bcbGw4Z96qg= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= @@ -113,6 +115,7 @@ golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/httpclient/fixtures/fixture.go 
b/httpclient/fixtures/fixture.go index 32b00fcbd..9a85c3c84 100644 --- a/httpclient/fixtures/fixture.go +++ b/httpclient/fixtures/fixture.go @@ -24,6 +24,7 @@ type HTTPFixture struct { Response any Status int + ResponseHeaders map[string][]string ExpectedRequest any ExpectedHeaders map[string]string PassFile string @@ -106,6 +107,7 @@ func (f HTTPFixture) replyWith(req *http.Request, body string) (*http.Response, StatusCode: f.Status, Status: http.StatusText(f.Status), Body: io.NopCloser(strings.NewReader(body)), + Header: f.ResponseHeaders, }, nil } diff --git a/httpclient/oauth_token.go b/httpclient/oauth_token.go index cb6ad5cc9..96f9a539e 100644 --- a/httpclient/oauth_token.go +++ b/httpclient/oauth_token.go @@ -5,15 +5,14 @@ import ( "net/http" "time" - "github.com/databricks/databricks-sdk-go/config/credentials" "golang.org/x/oauth2" ) const JWTGrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer" -// GetOAuthTokenRequest is the request to get an OAuth token. It follows the OAuth 2.0 Rich Authorization Requests specification. +// getOAuthTokenRequest is the request to get an OAuth token. It follows the OAuth 2.0 Rich Authorization Requests specification. // https://datatracker.ietf.org/doc/html/rfc9396 -type GetOAuthTokenRequest struct { +type getOAuthTokenRequest struct { // Defines the method used to get the token. GrantType string `url:"grant_type"` // An array of authorization details that the token should be scoped to. This needs to be passed in string format. @@ -22,6 +21,24 @@ type GetOAuthTokenRequest struct { Assertion string `url:"assertion"` } +// oAuthToken represents an OAuth token as defined by the OAuth 2.0 Authorization Framework. +// https://datatracker.ietf.org/doc/html/rfc6749. +// +// The Go SDK maintains its own implementation of OAuth because Go's oauth2 +// library lacks two features that we depend on: +// 1. The ability to use an arbitrary assertion with the JWT grant type. +// 2. 
The ability to set authorization_details when getting an OAuth token. +type oAuthToken struct { + // The access token issued by the authorization server. This is the token that will be used to authenticate requests. + AccessToken string `json:"access_token"` + // Time in seconds until the token expires. + ExpiresIn int `json:"expires_in"` + // The scope of the token. This is a space-separated list of strings that represent the permissions granted by the token. + Scope string `json:"scope"` + // The type of token that was issued. + TokenType string `json:"token_type"` +} + // Returns a new OAuth token using the provided token. The token must be a JWT token. // The resulting token is scoped to the authorization details provided. // @@ -29,12 +46,12 @@ type GetOAuthTokenRequest struct { // without warning. func (c *ApiClient) GetOAuthToken(ctx context.Context, authDetails string, token *oauth2.Token) (*oauth2.Token, error) { path := "/oidc/v1/token" - data := GetOAuthTokenRequest{ + data := getOAuthTokenRequest{ GrantType: JWTGrantType, AuthorizationDetails: authDetails, Assertion: token.AccessToken, } - var response credentials.OAuthToken + var response oAuthToken opts := []DoOption{ WithUrlEncodedData(data), WithResponseUnmarshal(&response), diff --git a/httpclient/request_test.go b/httpclient/request_test.go index 695875099..59f6e351d 100644 --- a/httpclient/request_test.go +++ b/httpclient/request_test.go @@ -55,7 +55,7 @@ func TestMakeRequestBodyFromReader(t *testing.T) { } func TestUrlEncoding(t *testing.T) { - data := GetOAuthTokenRequest{ + data := getOAuthTokenRequest{ Assertion: "assertion", AuthorizationDetails: "[{\"a\":\"b\"}]", GrantType: "grant",