From 4a538e37f557ca33def258f94dae6bef55d4def1 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Fri, 3 Jan 2025 10:42:51 +0100 Subject: [PATCH 01/44] work --- config/auth_databricks_cli.go | 123 ----------- config/auth_databricks_cli_test.go | 108 ---------- config/auth_default.go | 2 +- config/auth_m2m.go | 30 +-- config/auth_m2m_test.go | 3 +- config/auth_u2m.go | 49 +++++ credentials/cache/cache.go | 10 + credentials/cache/file.go | 108 ++++++++++ credentials/cache/file_test.go | 105 ++++++++++ credentials/cache/in_memory.go | 26 +++ credentials/cache/in_memory_test.go | 44 ++++ credentials/oauth/callback.go | 106 ++++++++++ credentials/oauth/lock.go | 37 ++++ credentials/oauth/oauth_argument.go | 51 +++++ credentials/oauth/page.tmpl | 102 ++++++++++ credentials/oauth/persistent_auth.go | 237 ++++++++++++++++++++++ credentials/oauth/persistent_auth_test.go | 201 ++++++++++++++++++ credentials/oauth_token.go | 14 -- go.mod | 4 +- go.sum | 8 + httpclient/oauth_token.go | 16 +- httpclient/oidc.go | 35 ++++ httpclient/oidc_test.go | 35 ++++ 23 files changed, 1175 insertions(+), 279 deletions(-) delete mode 100644 config/auth_databricks_cli.go delete mode 100644 config/auth_databricks_cli_test.go create mode 100644 config/auth_u2m.go create mode 100644 credentials/cache/cache.go create mode 100644 credentials/cache/file.go create mode 100644 credentials/cache/file_test.go create mode 100644 credentials/cache/in_memory.go create mode 100644 credentials/cache/in_memory_test.go create mode 100644 credentials/oauth/callback.go create mode 100644 credentials/oauth/lock.go create mode 100644 credentials/oauth/oauth_argument.go create mode 100644 credentials/oauth/page.tmpl create mode 100644 credentials/oauth/persistent_auth.go create mode 100644 credentials/oauth/persistent_auth_test.go delete mode 100644 credentials/oauth_token.go create mode 100644 httpclient/oidc.go create mode 100644 httpclient/oidc_test.go diff --git a/config/auth_databricks_cli.go b/config/auth_databricks_cli.go deleted file mode 100644 index 7f054d2e8..000000000 --- a/config/auth_databricks_cli.go +++ /dev/null @@ -1,123 +0,0 @@ -package config - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "os" - "os/exec" - "path/filepath" - "strings" - - "github.com/databricks/databricks-sdk-go/credentials" - "github.com/databricks/databricks-sdk-go/logger" - "golang.org/x/oauth2" -) - -type DatabricksCliCredentials struct { -} - -func (c DatabricksCliCredentials) Name() string { - return "databricks-cli" -} - -func (c DatabricksCliCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) { - if cfg.Host == "" { - return nil, nil - } - - ts, err := newDatabricksCliTokenSource(ctx, cfg) - if err != nil { - if errors.Is(err, exec.ErrNotFound) { - logger.Debugf(ctx, "Most likely the Databricks CLI is not installed") - return nil, nil - } - if err == errLegacyDatabricksCli { - logger.Debugf(ctx, "Databricks CLI version <0.100.0 detected") - return nil, nil - } - return nil, err - } - - _, err = ts.Token() - if err != nil { - if strings.Contains(err.Error(), "no configuration file found at") { - // databricks auth token produced this error message between - // v0.207.1 and v0.209.1 - return nil, nil - } - if strings.Contains(err.Error(), "databricks OAuth is not") { - // OAuth is not configured or not supported - return nil, nil - } - return nil, err - } - logger.Debugf(ctx, "Using Databricks CLI authentication with Databricks OAuth tokens") - visitor := refreshableVisitor(ts) - return credentials.NewOAuthCredentialsProvider(visitor, ts.Token), nil -} - -var errLegacyDatabricksCli = errors.New("legacy Databricks CLI detected") - -type databricksCliTokenSource struct { - ctx context.Context - name string - args []string -} - -func newDatabricksCliTokenSource(ctx context.Context, cfg *Config) (*databricksCliTokenSource, error) { - args := []string{"auth", "token", "--host", cfg.Host} - - if cfg.IsAccountClient() { - args = append(args, "--account-id", cfg.AccountID) - } - - databricksCliPath := cfg.DatabricksCliPath - if databricksCliPath == "" { - databricksCliPath = "databricks" - } - - // Resolve absolute path to the Databricks CLI executable. - path, err := exec.LookPath(databricksCliPath) - if err != nil { - return nil, err - } - - // Resolve symlinks in order to figure out executable size. - path, err = filepath.EvalSymlinks(path) - if err != nil { - return nil, err - } - - // Determine executable size as signal to determine old/new Databricks CLI. - stat, err := os.Stat(path) - if err != nil { - return nil, err - } - - // The new Databricks CLI is a single binary with size > 1MB. - // We use the size as a signal to determine which Databricks CLI is installed. - if stat.Size() < (1024 * 1024) { - return nil, errLegacyDatabricksCli - } - - return &databricksCliTokenSource{ctx: ctx, name: path, args: args}, nil -} - -func (ts *databricksCliTokenSource) Token() (*oauth2.Token, error) { - out, err := runCommand(ts.ctx, ts.name, ts.args) - if ee, ok := err.(*exec.ExitError); ok { - return nil, fmt.Errorf("cannot get access token: %s", string(ee.Stderr)) - } - if err != nil { - return nil, fmt.Errorf("cannot get access token: %v", err) - } - var t oauth2.Token - err = json.Unmarshal(out, &t) - if err != nil { - return nil, fmt.Errorf("cannot unmarshal Databricks CLI result: %w", err) - } - logger.Infof(context.Background(), "Refreshed OAuth token from Databricks CLI, expires on %s", t.Expiry) - return &t, nil -} diff --git a/config/auth_databricks_cli_test.go b/config/auth_databricks_cli_test.go deleted file mode 100644 index 5566d93c0..000000000 --- a/config/auth_databricks_cli_test.go +++ /dev/null @@ -1,108 +0,0 @@ -package config - -import ( - "context" - "os" - "path/filepath" - "testing" - - "github.com/databricks/databricks-sdk-go/internal/env" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -var cliDummy = &Config{Host: "https://abc.cloud.databricks.com/"} - -func writeSmallDummyExecutable(t *testing.T, path string) { - f, err := os.Create(filepath.Join(path, "databricks")) - require.NoError(t, err) - defer f.Close() - err = os.Chmod(f.Name(), 0755) - require.NoError(t, err) - _, err = f.WriteString("#!/bin/sh\necho hello world\n") - require.NoError(t, err) -} - -func writeLargeDummyExecutable(t *testing.T, path string) { - f, err := os.Create(filepath.Join(path, "databricks")) - require.NoError(t, err) - defer f.Close() - err = os.Chmod(f.Name(), 0755) - require.NoError(t, err) - _, err = f.WriteString("#!/bin/sh\n") - require.NoError(t, err) - - f.WriteString(` -cat < + + + + {{if .Error }}{{ .Error | title }}{{ else }}Success{{end}} + + + + + + + +
+
+ + +
{{ .Error | title }}
+
{{ .ErrorDescription }}
+ +
Authenticated
+
Go to {{.Host}}
+ +
+ You can close this tab. Or go to documentation +
+
+
+ + diff --git a/credentials/oauth/persistent_auth.go b/credentials/oauth/persistent_auth.go new file mode 100644 index 000000000..4f67337b6 --- /dev/null +++ b/credentials/oauth/persistent_auth.go @@ -0,0 +1,237 @@ +package oauth + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "fmt" + "net" + "os" + "path/filepath" + "sync" + "time" + + "github.com/databricks/databricks-sdk-go/credentials/cache" + "github.com/databricks/databricks-sdk-go/httpclient" + "github.com/databricks/databricks-sdk-go/retries" + "github.com/pkg/browser" + "golang.org/x/oauth2" + "golang.org/x/oauth2/authhandler" +) + +const ( + // these values are predefined by Databricks as a public client + // and is specific to this application only. Using these values + // for other applications is not allowed. + appClientID = "databricks-cli" + appRedirectAddr = "localhost:8020" + + // lockfile location + lockFilePath = ".databricks/token-cache.lock" + + // maximum amount of time to acquire listener on appRedirectAddr + listenerTimeout = 45 * time.Second +) + +// PersistentAuth is an OAuth manager that handles the U2M OAuth flow. Tokens +// are stored in and looked up from the provided cache. Tokens include the +// refresh token. On load, if the access token is expired, it is refreshed +// using the refresh token. +type PersistentAuth struct { + // Cache is the token cache to store and lookup tokens. + cache cache.TokenCache + // Locker is the lock to synchronize token cache access. + locker sync.Locker + // Client is the HTTP client to use for OAuth2 requests. + client *httpclient.ApiClient + // Browser is the function to open a URL in the default browser. + browser func(url string) error + // ln is the listener for the OAuth2 callback server. + ln net.Listener +} + +type PersistentAuthOption func(*PersistentAuth) + +// WithTokenCache sets the token cache for the PersistentAuth. +func WithTokenCache(c cache.TokenCache) PersistentAuthOption { + return func(a *PersistentAuth) { + a.cache = c + } +} + +// WithLocker sets the locker for the PersistentAuth. +func WithLocker(l sync.Locker) PersistentAuthOption { + return func(a *PersistentAuth) { + a.locker = l + } +} + +// WithApiClient sets the HTTP client for the PersistentAuth. +func WithApiClient(c *httpclient.ApiClient) PersistentAuthOption { + return func(a *PersistentAuth) { + a.client = c + } +} + +// WithBrowser sets the browser function for the PersistentAuth. +func WithBrowser(b func(url string) error) PersistentAuthOption { + return func(a *PersistentAuth) { + a.browser = b + } +} + +func NewPersistentAuth(ctx context.Context, opts ...PersistentAuthOption) (*PersistentAuth, error) { + p := &PersistentAuth{} + for _, opt := range opts { + opt(p) + } + if p.client == nil { + p.client = httpclient.NewApiClient(httpclient.ClientConfig{}) + } + if p.cache == nil { + p.cache = &cache.FileTokenCache{} + } + if p.locker == nil { + home, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("home: %w", err) + } + + p.locker, err = newLocker(filepath.Join(home, lockFilePath)) + if err != nil { + return nil, fmt.Errorf("locker: %w", err) + } + } + if p.browser == nil { + p.browser = browser.OpenURL + } + return p, nil +} + +func (a *PersistentAuth) Load(ctx context.Context, arg OAuthArgument) (*oauth2.Token, error) { + key := arg.GetCacheKey(ctx) + t, err := a.cache.Lookup(key) + if err != nil { + return nil, fmt.Errorf("cache: %w", err) + } + // refresh if invalid + if !t.Valid() { + // OAuth2 config is invoked only for expired tokens to speed up + // the happy path in the token retrieval + cfg, err := a.oauth2Config(ctx, arg.GetHost(ctx), arg.GetAccountId(ctx)) + if err != nil { + return nil, err + } + // make OAuth2 library use our client + ctx = a.client.InContextForOAuth2(ctx) + // eagerly refresh token + t, err = cfg.TokenSource(ctx, t).Token() + if err != nil { + return nil, fmt.Errorf("token refresh: %w", err) + } + err = a.cache.Store(key, t) + if err != nil { + return nil, fmt.Errorf("cache refresh: %w", err) + } + } + // do not print refresh token to end-user + t.RefreshToken = "" + return t, nil +} + +func (a *PersistentAuth) Challenge(ctx context.Context, arg OAuthArgument) error { + err := a.startListener(ctx) + if err != nil { + return fmt.Errorf("starting listener: %w", err) + } + cfg, err := a.oauth2Config(ctx, arg.GetHost(ctx), arg.GetAccountId(ctx)) + if err != nil { + return fmt.Errorf("fetching oauth config: %w", err) + } + cb, err := a.newCallback(ctx, arg) + if err != nil { + return fmt.Errorf("callback server: %w", err) + } + defer cb.Close() + state, pkce := a.stateAndPKCE() + // make OAuth2 library use our client + ctx = a.client.InContextForOAuth2(ctx) + ts := authhandler.TokenSourceWithPKCE(ctx, cfg, state, cb.Handler, pkce) + t, err := ts.Token() + if err != nil { + return fmt.Errorf("authorize: %w", err) + } + // cache token identified by host (and possibly the account id) + err = a.cache.Store(arg.GetCacheKey(ctx), t) + if err != nil { + return fmt.Errorf("store: %w", err) + } + return nil +} + +func (a *PersistentAuth) startListener(ctx context.Context) error { + listener, err := retries.Poll(ctx, listenerTimeout, + func() (*net.Listener, *retries.Err) { + var lc net.ListenConfig + l, err := lc.Listen(ctx, "tcp", appRedirectAddr) + if err != nil { + return nil, retries.Continue(err) + } + return &l, nil + }) + if err != nil { + return fmt.Errorf("listener: %w", err) + } + a.ln = *listener + return nil +} + +func (a *PersistentAuth) Close() error { + if a.ln == nil { + return nil + } + return a.ln.Close() +} + +func (a *PersistentAuth) oauth2Config(ctx context.Context, host string, accountId string) (*oauth2.Config, error) { + // in this iteration of CLI, we're using all scopes by default, + // because tools like CLI and Terraform do use all apis. This + // decision may be reconsidered later, once we have a proper + // taxonomy of all scopes ready and implemented. + scopes := []string{ + "offline_access", + "all-apis", + } + endpoints, err := a.client.GetOidcEndpoints(ctx, host, accountId) + if err != nil { + return nil, fmt.Errorf("oidc: %w", err) + } + return &oauth2.Config{ + ClientID: appClientID, + Endpoint: oauth2.Endpoint{ + AuthURL: endpoints.AuthorizationEndpoint, + TokenURL: endpoints.TokenEndpoint, + AuthStyle: oauth2.AuthStyleInParams, + }, + RedirectURL: fmt.Sprintf("http://%s", appRedirectAddr), + Scopes: scopes, + }, nil +} + +func (a *PersistentAuth) stateAndPKCE() (string, *authhandler.PKCEParams) { + verifier := a.randomString(64) + verifierSha256 := sha256.Sum256([]byte(verifier)) + challenge := base64.RawURLEncoding.EncodeToString(verifierSha256[:]) + return a.randomString(16), &authhandler.PKCEParams{ + Challenge: challenge, + ChallengeMethod: "S256", + Verifier: verifier, + } +} + +func (a *PersistentAuth) randomString(size int) string { + raw := make([]byte, size) + _, _ = rand.Read(raw) + return base64.RawURLEncoding.EncodeToString(raw) +} diff --git a/credentials/oauth/persistent_auth_test.go b/credentials/oauth/persistent_auth_test.go new file mode 100644 index 000000000..cd8613cb0 --- /dev/null +++ b/credentials/oauth/persistent_auth_test.go @@ -0,0 +1,201 @@ +package oauth_test + +import ( + "context" + "crypto/tls" + _ "embed" + "fmt" + "net/http" + "net/url" + "testing" + "time" + + "github.com/databricks/databricks-sdk-go/client" + "github.com/databricks/databricks-sdk-go/credentials/oauth" + "github.com/databricks/databricks-sdk-go/qa" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +type tokenCacheMock struct { + store func(key string, t *oauth2.Token) error + lookup func(key string) (*oauth2.Token, error) +} + +func (m *tokenCacheMock) Store(key string, t *oauth2.Token) error { + if m.store == nil { + panic("no store mock") + } + return m.store(key, t) +} + +func (m *tokenCacheMock) Lookup(key string) (*oauth2.Token, error) { + if m.lookup == nil { + panic("no lookup mock") + } + return m.lookup(key) +} + +func TestLoad(t *testing.T) { + cache := &tokenCacheMock{ + lookup: func(key string) (*oauth2.Token, error) { + assert.Equal(t, "https://abc/oidc/accounts/xyz", key) + return &oauth2.Token{ + AccessToken: "bcd", + Expiry: time.Now().Add(1 * time.Minute), + }, nil + }, + } + p, err := oauth.NewPersistentAuth(context.Background(), oauth.WithTokenCache(cache)) + require.NoError(t, err) + defer p.Close() + tok, err := p.Load(context.Background(), oauth.BasicOAuthArgument{ + Host: "https://abc", + AccountID: "xyz", + }) + assert.NoError(t, err) + assert.Equal(t, "bcd", tok.AccessToken) + assert.Equal(t, "", tok.RefreshToken) +} + +func useInsecureOAuthHttpClientForTests(ctx context.Context) context.Context { + return context.WithValue(ctx, oauth2.HTTPClient, &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, + }) +} + +func TestLoadRefresh(t *testing.T) { + qa.HTTPFixtures{ + { + Method: "POST", + Resource: "/oidc/accounts/xyz/v1/token", + Response: `access_token=refreshed&refresh_token=def`, + }, + }.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) { + ctx = useInsecureOAuthHttpClientForTests(ctx) + expectedKey := fmt.Sprintf("%s/oidc/accounts/xyz", c.Config.Host) + cache := &tokenCacheMock{ + lookup: func(key string) (*oauth2.Token, error) { + assert.Equal(t, expectedKey, key) + return &oauth2.Token{ + AccessToken: "expired", + RefreshToken: "cde", + Expiry: time.Now().Add(-1 * time.Minute), + }, nil + }, + store: func(key string, tok *oauth2.Token) error { + assert.Equal(t, expectedKey, key) + assert.Equal(t, "def", tok.RefreshToken) + return nil + }, + } + p, err := oauth.NewPersistentAuth(context.Background(), oauth.WithTokenCache(cache)) + require.NoError(t, err) + defer p.Close() + tok, err := p.Load(ctx, oauth.BasicOAuthArgument{ + Host: c.Config.Host, + AccountID: "xyz", + }) + assert.NoError(t, err) + assert.Equal(t, "refreshed", tok.AccessToken) + assert.Equal(t, "", tok.RefreshToken) + }) +} + +func TestChallenge(t *testing.T) { + qa.HTTPFixtures{ + { + Method: "POST", + Resource: "/oidc/accounts/xyz/v1/token", + Response: `access_token=__THAT__&refresh_token=__SOMETHING__`, + }, + }.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) { + ctx = useInsecureOAuthHttpClientForTests(ctx) + expectedKey := fmt.Sprintf("%s/oidc/accounts/xyz", c.Config.Host) + + browserOpened := make(chan string) + browser := func(redirect string) error { + u, err := url.ParseRequestURI(redirect) + if err != nil { + return err + } + assert.Equal(t, "/oidc/accounts/xyz/v1/authorize", u.Path) + // for now we're ignoring asserting the fields of the redirect + query := u.Query() + browserOpened <- query.Get("state") + return nil + } + cache := &tokenCacheMock{ + store: func(key string, tok *oauth2.Token) error { + assert.Equal(t, expectedKey, key) + assert.Equal(t, "__SOMETHING__", tok.RefreshToken) + return nil + }, + } + p, err := oauth.NewPersistentAuth(context.Background(), oauth.WithTokenCache(cache), oauth.WithBrowser(browser)) + require.NoError(t, err) + defer p.Close() + + errc := make(chan error) + go func() { + errc <- p.Challenge(ctx, oauth.BasicOAuthArgument{ + Host: c.Config.Host, + AccountID: "xyz", + }) + }() + + state := <-browserOpened + resp, err := http.Get(fmt.Sprintf("http://localhost:8020?code=__THIS__&state=%s", state)) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, 200, resp.StatusCode) + + err = <-errc + assert.NoError(t, err) + }) +} + +func TestChallengeFailed(t *testing.T) { + qa.HTTPFixtures{}.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) { + ctx = useInsecureOAuthHttpClientForTests(ctx) + + browserOpened := make(chan string) + browser := func(redirect string) error { + u, err := url.ParseRequestURI(redirect) + if err != nil { + return err + } + assert.Equal(t, "/oidc/accounts/xyz/v1/authorize", u.Path) + // for now we're ignoring asserting the fields of the redirect + query := u.Query() + browserOpened <- query.Get("state") + return nil + } + p, err := oauth.NewPersistentAuth(context.Background(), oauth.WithBrowser(browser)) + require.NoError(t, err) + defer p.Close() + + errc := make(chan error) + go func() { + errc <- p.Challenge(ctx, oauth.BasicOAuthArgument{ + Host: c.Config.Host, + AccountID: "xyz", + }) + }() + + <-browserOpened + resp, err := http.Get( + "http://localhost:8020?error=access_denied&error_description=Policy%%20evaluation%%20failed%%20for%%20this%%20request") + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, 400, resp.StatusCode) + + err = <-errc + assert.EqualError(t, err, "authorize: access_denied: Policy evaluation failed for this request") + }) +} diff --git a/credentials/oauth_token.go b/credentials/oauth_token.go deleted file mode 100644 index a1f6c131e..000000000 --- a/credentials/oauth_token.go +++ /dev/null @@ -1,14 +0,0 @@ -package credentials - -// OAuthToken represents an OAuth token as defined by the OAuth 2.0 Authorization Framework. -// https://datatracker.ietf.org/doc/html/rfc6749 -type OAuthToken struct { - // The access token issued by the authorization server. This is the token that will be used to authenticate requests. - AccessToken string `json:"access_token" auth:",sensitive"` - // Time in seconds until the token expires. - ExpiresIn int `json:"expires_in"` - // The scope of the token. This is a space-separated list of strings that represent the permissions granted by the token. - Scope string `json:"scope"` - // The type of token that was issued. - TokenType string `json:"token_type"` -} diff --git a/go.mod b/go.mod index b79183940..9aece8c4d 100644 --- a/go.mod +++ b/go.mod @@ -3,14 +3,17 @@ module github.com/databricks/databricks-sdk-go go 1.18 require ( + github.com/alexflint/go-filemutex v1.3.0 github.com/google/go-cmp v0.6.0 github.com/google/go-querystring v1.1.0 github.com/google/uuid v1.6.0 + github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c github.com/stretchr/testify v1.9.0 golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 golang.org/x/mod v0.17.0 golang.org/x/net v0.26.0 golang.org/x/oauth2 v0.20.0 + golang.org/x/text v0.16.0 golang.org/x/time v0.5.0 google.golang.org/api v0.182.0 gopkg.in/ini.v1 v1.67.0 @@ -37,7 +40,6 @@ require ( go.opentelemetry.io/otel/trace v1.24.0 // indirect golang.org/x/crypto v0.24.0 // indirect golang.org/x/sys v0.21.0 // indirect - golang.org/x/text v0.16.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240521202816-d264139d666e // indirect google.golang.org/grpc v1.64.1 // indirect google.golang.org/protobuf v1.34.1 // indirect diff --git a/go.sum b/go.sum index 95e9089ac..77455aed3 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,8 @@ cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRk cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/alexflint/go-filemutex v1.3.0 h1:LgE+nTUWnQCyRKbpoceKZsPQbs84LivvgwUymZXdOcM= +github.com/alexflint/go-filemutex v1.3.0/go.mod h1:U0+VA/i30mGBlLCrFPGtTe9y6wGQfNAWPBTekHQ+c8A= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= @@ -58,6 +60,8 @@ github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+ github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs= github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0= github.com/googleapis/gax-go/v2 v2.12.4 h1:9gWcmF85Wvq4ryPFvGFaOgPIs1AQX0d0bcbGw4Z96qg= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= @@ -66,6 +70,7 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= @@ -113,6 +118,8 @@ golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -158,6 +165,7 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/httpclient/oauth_token.go b/httpclient/oauth_token.go index 142afd48d..abbabb123 100644 --- a/httpclient/oauth_token.go +++ b/httpclient/oauth_token.go @@ -5,7 +5,6 @@ import ( "net/http" "time" - "github.com/databricks/databricks-sdk-go/credentials" "golang.org/x/oauth2" ) @@ -22,6 +21,19 @@ type GetOAuthTokenRequest struct { Assertion string `url:"assertion"` } +// OAuthToken represents an OAuth token as defined by the OAuth 2.0 Authorization Framework. +// https://datatracker.ietf.org/doc/html/rfc6749 +type OAuthToken struct { + // The access token issued by the authorization server. This is the token that will be used to authenticate requests. + AccessToken string `json:"access_token" auth:",sensitive"` + // Time in seconds until the token expires. + ExpiresIn int `json:"expires_in"` + // The scope of the token. This is a space-separated list of strings that represent the permissions granted by the token. + Scope string `json:"scope"` + // The type of token that was issued. + TokenType string `json:"token_type"` +} + // Returns a new OAuth token using the provided token. The token must be a JWT token. // The resulting token is scoped to the authorization details provided. // @@ -34,7 +46,7 @@ func (c *ApiClient) GetOAuthToken(ctx context.Context, authDetails string, token AuthorizationDetails: authDetails, Assertion: token.AccessToken, } - var response credentials.OAuthToken + var response OAuthToken opts := []DoOption{ WithUrlEncodedData(data), WithResponseUnmarshal(&response), diff --git a/httpclient/oidc.go b/httpclient/oidc.go new file mode 100644 index 000000000..b6d9b0a83 --- /dev/null +++ b/httpclient/oidc.go @@ -0,0 +1,35 @@ +package httpclient + +import ( + "context" + "errors" + "fmt" +) + +var ErrOAuthNotSupported = errors.New("databricks OAuth is not supported for this host") + +func (c *ApiClient) GetOidcEndpoints(ctx context.Context, host, accountId string) (*OAuthAuthorizationServer, error) { + prefix := host + if /* cfg.IsAccountClient() && */ accountId != "" { + // TODO: technically, we could use the same config profile for both workspace + // and account, but we have to add logic for determining accounts host from + // workspace host. + prefix := fmt.Sprintf("%s/oidc/accounts/%s", host, accountId) + return &OAuthAuthorizationServer{ + AuthorizationEndpoint: fmt.Sprintf("%s/v1/authorize", prefix), + TokenEndpoint: fmt.Sprintf("%s/v1/token", prefix), + }, nil + } + oidc := fmt.Sprintf("%s/oidc/.well-known/oauth-authorization-server", prefix) + var oauthEndpoints OAuthAuthorizationServer + err := c.Do(ctx, "GET", oidc, WithResponseUnmarshal(&oauthEndpoints)) + if err != nil { + return nil, ErrOAuthNotSupported + } + return &oauthEndpoints, nil +} + +type OAuthAuthorizationServer struct { + AuthorizationEndpoint string `json:"authorization_endpoint"` // ../v1/authorize + TokenEndpoint string `json:"token_endpoint"` // ../v1/token +} diff --git a/httpclient/oidc_test.go b/httpclient/oidc_test.go new file mode 100644 index 000000000..4d6e941f2 --- /dev/null +++ b/httpclient/oidc_test.go @@ -0,0 +1,35 @@ +package httpclient + +import ( + "context" + "testing" + + "github.com/databricks/databricks-sdk-go/httpclient/fixtures" + "github.com/stretchr/testify/assert" +) + +func TestOidcEndpointsForAccounts(t *testing.T) { + p := NewApiClient(ClientConfig{}) + s, err := p.GetOidcEndpoints(context.Background(), "https://abc", "xyz") + assert.NoError(t, err) + assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/authorize", s.AuthorizationEndpoint) + assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/token", s.TokenEndpoint) +} + +func TestOidcForWorkspace(t *testing.T) { + p := NewApiClient(ClientConfig{ + Transport: fixtures.MappingTransport{ + "GET /oidc/.well-known/oauth-authorization-server": { + Status: 200, + Response: map[string]string{ + "authorization_endpoint": "a", + "token_endpoint": "b", + }, + }, + }, + }) + endpoints, err := p.GetOidcEndpoints(context.Background(), "https://abc", "") + assert.NoError(t, err) + assert.Equal(t, "a", endpoints.AuthorizationEndpoint) + assert.Equal(t, "b", endpoints.TokenEndpoint) +} From 35503294ba2293b19d32ed5c3d2cd22a72f89fd2 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Fri, 3 Jan 2025 12:05:43 +0100 Subject: [PATCH 02/44] fix tests --- config/auth_m2m.go | 3 --- config/auth_m2m_test.go | 2 +- config/auth_u2m.go | 14 ++++++++++++-- credentials/oauth/persistent_auth_test.go | 2 +- 4 files changed, 14 insertions(+), 7 deletions(-) diff --git a/config/auth_m2m.go b/config/auth_m2m.go index 030dfe5b8..0ae99d6fe 100644 --- a/config/auth_m2m.go +++ b/config/auth_m2m.go @@ -2,7 +2,6 @@ package config import ( "context" - "errors" "fmt" "golang.org/x/oauth2" @@ -12,8 +11,6 @@ import ( "github.com/databricks/databricks-sdk-go/logger" ) -var errOAuthNotSupported = errors.New("databricks OAuth is not supported for this host") - type M2mCredentials struct { } diff --git a/config/auth_m2m_test.go b/config/auth_m2m_test.go index fc2ecff5b..f05264a47 100644 --- a/config/auth_m2m_test.go +++ b/config/auth_m2m_test.go @@ -81,5 +81,5 @@ func TestM2mNotSupported(t *testing.T) { }, }, }) - require.ErrorIs(t, err, errOAuthNotSupported) + require.ErrorIs(t, err, httpclient.ErrOAuthNotSupported) } diff --git a/config/auth_u2m.go b/config/auth_u2m.go index 3a968ab1e..7170d46b5 100644 --- a/config/auth_u2m.go +++ b/config/auth_u2m.go @@ -7,10 +7,11 @@ import ( "github.com/databricks/databricks-sdk-go/credentials" "github.com/databricks/databricks-sdk-go/credentials/oauth" + "github.com/databricks/databricks-sdk-go/logger" ) type U2MCredentials struct { - Auth oauth.PersistentAuth + Auth *oauth.PersistentAuth } // Name implements CredentialsStrategy. @@ -20,12 +21,21 @@ func (u U2MCredentials) Name() string { // Configure implements CredentialsStrategy. func (u U2MCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) { + a := u.Auth + if a == nil { + var err error + a, err = oauth.NewPersistentAuth(ctx) + if err != nil { + logger.Debugf(ctx, "failed to create persistent auth: %v, continuing", err) + return nil, nil + } + } f := func(r *http.Request) error { arg := oauth.BasicOAuthArgument{ Host: cfg.Host, AccountID: cfg.AccountID, } - token, err := u.Auth.Load(r.Context(), arg) + token, err := a.Load(r.Context(), arg) if err != nil { return fmt.Errorf("oidc: %w", err) } diff --git a/credentials/oauth/persistent_auth_test.go b/credentials/oauth/persistent_auth_test.go index cd8613cb0..5c9e22578 100644 --- a/credentials/oauth/persistent_auth_test.go +++ b/credentials/oauth/persistent_auth_test.go @@ -190,7 +190,7 @@ func TestChallengeFailed(t *testing.T) { <-browserOpened resp, err := http.Get( - "http://localhost:8020?error=access_denied&error_description=Policy%%20evaluation%%20failed%%20for%%20this%%20request") + "http://localhost:8020?error=access_denied&error_description=Policy%20evaluation%20failed%20for%20this%20request") require.NoError(t, err) defer resp.Body.Close() assert.Equal(t, 400, resp.StatusCode) From 0688c3b8afac35c2c45ab224524c8270339933c1 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Fri, 3 Jan 2025 14:37:40 +0100 Subject: [PATCH 03/44] fix test --- credentials/oauth/lock.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/credentials/oauth/lock.go b/credentials/oauth/lock.go index d200fbfe8..f2d8efebc 100644 --- a/credentials/oauth/lock.go +++ b/credentials/oauth/lock.go @@ -2,6 +2,8 @@ package oauth import ( "fmt" + "os" + "path/filepath" "sync" "github.com/alexflint/go-filemutex" @@ -29,6 +31,10 @@ func (l *lockerAdaptor) Unlock() { } func newLocker(path string) (sync.Locker, error) { + dirName := filepath.Dir(path) + if _, err := os.Stat(dirName); err != nil && os.IsNotExist(err) { + os.MkdirAll(dirName, 0750) + } m, err := filemutex.New(path) if err != nil { return nil, err From 35b83655f1316aa8140196cd3610bd14ff5162e5 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Mon, 6 Jan 2025 13:20:39 +0100 Subject: [PATCH 04/44] respond to some changes --- config/auth_default.go | 2 +- config/auth_u2m.go | 51 +++++++-- credentials/cache/cache.go | 4 + credentials/cache/file.go | 48 +++++--- credentials/oauth/callback.go | 13 ++- credentials/oauth/error.go | 15 +++ credentials/oauth/oauth_argument.go | 94 +++++++++++---- credentials/oauth/oidc.go | 32 ++++++ .../oauth}/oidc_test.go | 10 +- credentials/oauth/page.tmpl | 2 + credentials/oauth/persistent_auth.go | 107 ++++++++++++++---- credentials/oauth/persistent_auth_test.go | 28 ++--- httpclient/oidc.go | 35 ------ 13 files changed, 310 insertions(+), 131 deletions(-) create mode 100644 credentials/oauth/error.go create mode 100644 credentials/oauth/oidc.go rename {httpclient => credentials/oauth}/oidc_test.go (73%) delete mode 100644 httpclient/oidc.go diff --git a/config/auth_default.go b/config/auth_default.go index 7ccc48bbc..6f5889fb7 100644 --- a/config/auth_default.go +++ b/config/auth_default.go @@ -13,7 +13,7 @@ var authProviders = []CredentialsStrategy{ PatCredentials{}, BasicCredentials{}, M2mCredentials{}, - U2MCredentials{}, + databricksCliCredentials, MetadataServiceCredentials{}, // Attempt to configure auth from most specific to most generic (the Azure CLI). diff --git a/config/auth_u2m.go b/config/auth_u2m.go index 7170d46b5..6aeea162b 100644 --- a/config/auth_u2m.go +++ b/config/auth_u2m.go @@ -2,21 +2,35 @@ package config import ( "context" + "errors" "fmt" "net/http" "github.com/databricks/databricks-sdk-go/credentials" + "github.com/databricks/databricks-sdk-go/credentials/cache" "github.com/databricks/databricks-sdk-go/credentials/oauth" "github.com/databricks/databricks-sdk-go/logger" ) type U2MCredentials struct { + // Auth is the persistent auth object to use. If not specified, a new one will + // be created, using the default cache and locker. Auth *oauth.PersistentAuth + + // ErrorHandler controls the behavior of Configure() when loading the OAuth + // token fails. If not specified, any error will cause Configure() to return + // said error. + ErrorHandler func(context.Context, error) error + + name string } // Name implements CredentialsStrategy. func (u U2MCredentials) Name() string { - return "oauth-u2m" + if u.name != "" { + return "oauth-u2m" + } + return u.name } // Configure implements CredentialsStrategy. @@ -30,6 +44,12 @@ func (u U2MCredentials) Configure(ctx context.Context, cfg *Config) (credentials return nil, nil } } + + r, err := http.NewRequestWithContext(ctx, http.MethodGet, "", nil) + if err != nil { + return nil, fmt.Errorf("http request: %w", err) + } + f := func(r *http.Request) error { arg := oauth.BasicOAuthArgument{ Host: cfg.Host, @@ -43,17 +63,32 @@ func (u U2MCredentials) Configure(ctx context.Context, cfg *Config) (credentials return nil } - r, err := http.NewRequestWithContext(ctx, http.MethodGet, "", nil) - if err != nil { - return nil, fmt.Errorf("http request: %w", err) - } - // Try to load the credential from the token cache. If absent, fall back - // to the next credentials strategy. + // Try to load the credential from the token cache. If absent, fall back to + // the next credentials strategy. If a token is present but cannot be loaded + // (e.g. expired), return an error. Otherwise, fall back to the next + // credentials strategy. if err := f(r); err != nil { - return nil, nil + if u.ErrorHandler != nil { + return nil, u.ErrorHandler(ctx, err) + } + return nil, err } return credentials.NewCredentialsProvider(f), nil } var _ CredentialsStrategy = U2MCredentials{} + +var databricksCliCredentials = U2MCredentials{ + ErrorHandler: func(ctx context.Context, err error) error { + if errors.Is(err, cache.ErrNotConfigured) { + return nil + } + if _, ok := err.(*oauth.InvalidRefreshTokenError); ok { + return err + } + logger.Debugf(ctx, "failed to load token: %v, continuing", err) + return nil + }, + name: "databricks-cli", +} diff --git a/credentials/cache/cache.go b/credentials/cache/cache.go index 271849207..e20edd23d 100644 --- a/credentials/cache/cache.go +++ b/credentials/cache/cache.go @@ -1,6 +1,8 @@ package cache import ( + "errors" + "golang.org/x/oauth2" ) @@ -8,3 +10,5 @@ type TokenCache interface { Store(key string, t *oauth2.Token) error Lookup(key string) (*oauth2.Token, error) } + +var ErrNotConfigured = errors.New("databricks OAuth is not configured for this host") diff --git a/credentials/cache/file.go b/credentials/cache/file.go index 38dfea9f2..f8ba51f58 100644 --- a/credentials/cache/file.go +++ b/credentials/cache/file.go @@ -12,7 +12,8 @@ import ( ) const ( - // where the token cache is stored + // tokenCacheFile is the path of the default token cache, relative to the + // user's home directory. tokenCacheFile = ".databricks/token-cache.json" // only the owner of the file has full execute, read, and write access @@ -21,12 +22,31 @@ const ( // only the owner of the file has full read and write access ownerReadWrite = 0o600 - // format versioning leaves some room for format improvement + // tokenCacheVersion is the version of the token cache file format. + // + // Version 1 format: + // + // { + // "version": 1, + // "tokens": { + // "": { + // "access_token": "", + // "token_type": "", + // "refresh_token": "" + // } + // } + // } + // + // The format of "" depends on whether the token is account- or + // workspace-scoped: + // - Account-scoped: "https:///oidc/accounts/" + // - Workspace-scoped: "https://" tokenCacheVersion = 1 ) -var ErrNotConfigured = errors.New("databricks OAuth is not configured for this host") - +// FileTokenCache caches tokens in "~/.databricks/token-cache.json". FileTokenCache +// implements the TokenCache interface. // this implementation requires the calling code to do a machine-wide lock, // otherwise the file might get corrupt. type FileTokenCache struct { @@ -38,14 +58,14 @@ type FileTokenCache struct { func (c *FileTokenCache) Store(key string, t *oauth2.Token) error { err := c.load() - if errors.Is(err, fs.ErrNotExist) { + if err != nil { + if !errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("load: %w", err) + } dir := filepath.Dir(c.fileLocation) - err = os.MkdirAll(dir, ownerExecReadWrite) - if err != nil { + if err := os.MkdirAll(dir, ownerExecReadWrite); err != nil { return fmt.Errorf("mkdir: %w", err) } - } else if err != nil { - return fmt.Errorf("load: %w", err) } c.Version = tokenCacheVersion if c.Tokens == nil { @@ -73,16 +93,12 @@ func (c *FileTokenCache) Lookup(key string) (*oauth2.Token, error) { return t, nil } -func (c *FileTokenCache) location() (string, error) { +func (c *FileTokenCache) load() error { home, err := os.UserHomeDir() if err != nil { - return "", fmt.Errorf("home: %w", err) + return fmt.Errorf("failed loading home directory: %w", err) } - return filepath.Join(home, tokenCacheFile), nil -} - -func (c *FileTokenCache) load() error { - loc, err := c.location() + loc := filepath.Join(home, tokenCacheFile) if err != nil { return err } diff --git a/credentials/oauth/callback.go b/credentials/oauth/callback.go index 20d45d430..cb7861be7 100644 --- a/credentials/oauth/callback.go +++ b/credentials/oauth/callback.go @@ -72,7 +72,7 @@ func (cb *callbackServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { ErrorDescription: r.FormValue("error_description"), Code: r.FormValue("code"), State: r.FormValue("state"), - Host: cb.arg.GetHost(cb.ctx), + Host: cb.getHost(r.Context()), } if res.Error != "" { w.WriteHeader(http.StatusBadRequest) @@ -86,6 +86,17 @@ func (cb *callbackServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { cb.feedbackCh <- res } +func (cb *callbackServer) getHost(ctx context.Context) string { + switch a := cb.arg.(type) { + case AccountOAuthArgument: + return a.GetAccountHost(ctx) + case WorkspaceOAuthArgument: + return a.GetWorkspaceHost(ctx) + default: + return "" + } +} + // Handler opens up a browser waits for redirect to come back from the identity provider func (cb *callbackServer) Handler(authCodeURL string) (string, string, error) { err := cb.a.browser(authCodeURL) diff --git a/credentials/oauth/error.go b/credentials/oauth/error.go new file mode 100644 index 000000000..f29858f8d --- /dev/null +++ b/credentials/oauth/error.go @@ -0,0 +1,15 @@ +package oauth + +type InvalidRefreshTokenError struct { + err error +} + +func (e *InvalidRefreshTokenError) Error() string { + return e.err.Error() +} + +func (e *InvalidRefreshTokenError) Unwrap() error { + return e.err +} + +var _ error = &InvalidRefreshTokenError{} diff --git a/credentials/oauth/oauth_argument.go b/credentials/oauth/oauth_argument.go index 2ff66f963..344b069d3 100644 --- a/credentials/oauth/oauth_argument.go +++ b/credentials/oauth/oauth_argument.go @@ -9,43 +9,89 @@ import ( // OAuthArgument is an interface that provides the necessary information to // authenticate with PersistentAuth. type OAuthArgument interface { - // GetHost returns the host of the account or workspace to authenticate to. - GetHost(ctx context.Context) string - - // GetAccountId returns the account ID of the account to authenticate to. - GetAccountId(ctx context.Context) string - - // GetCacheKey returns the key to use for caching the token. On Challenge, - // this key is used to store the token. On Load, this key is used to lookup - // the token. + // GetCacheKey returns a unique key for the OAuthArgument. This key is used + // to store and retrieve the token from the token cache. GetCacheKey(ctx context.Context) string } -type BasicOAuthArgument struct { - Host string - AccountID string +type WorkspaceOAuthArgument interface { + OAuthArgument + + // GetWorkspaceHost returns the host of the workspace to authenticate to. + GetWorkspaceHost(ctx context.Context) string } -var _ OAuthArgument = BasicOAuthArgument{} +type BasicWorkspaceOAuthArgument struct { + // host is the host of the workspace to authenticate to. This must start + // with "https://" and must not have a trailing slash. + host string +} -func (a BasicOAuthArgument) GetHost(ctx context.Context) string { - return a.Host +func NewBasicWorkspaceOAuthArgument(host string) (BasicWorkspaceOAuthArgument, error) { + if !strings.HasPrefix(host, "https://") { + return BasicWorkspaceOAuthArgument{}, fmt.Errorf("host must start with 'https://': %s", host) + } + if strings.HasSuffix(host, "/") { + return BasicWorkspaceOAuthArgument{}, fmt.Errorf("host must not have a trailing slash: %s", host) + } + return BasicWorkspaceOAuthArgument{host: host}, nil } -func (a BasicOAuthArgument) GetAccountId(ctx context.Context) string { - return a.AccountID +func (a BasicWorkspaceOAuthArgument) GetHost(ctx context.Context) string { + return a.host } // key is currently used for two purposes: OIDC URL prefix and token cache key. // once we decide to start storing scopes in the token cache, we should change // this approach. -func (a BasicOAuthArgument) GetCacheKey(ctx context.Context) string { - a.Host = strings.TrimSuffix(a.Host, "/") - if !strings.HasPrefix(a.Host, "http") { - a.Host = fmt.Sprintf("https://%s", a.Host) +func (a BasicWorkspaceOAuthArgument) GetCacheKey(ctx context.Context) string { + a.host = strings.TrimSuffix(a.host, "/") + if !strings.HasPrefix(a.host, "http") { + a.host = fmt.Sprintf("https://%s", a.host) + } + return a.host +} + +var _ OAuthArgument = BasicWorkspaceOAuthArgument{} + +type AccountOAuthArgument interface { + OAuthArgument + + // GetAccountHost returns the host of the account to authenticate to. + GetAccountHost(ctx context.Context) string + + // GetAccountId returns the account ID of the account to authenticate to. + GetAccountId(ctx context.Context) string +} + +type BasicAccountOAuthArgument struct { + accountHost string + accountID string +} + +var _ OAuthArgument = BasicAccountOAuthArgument{} + +func NewBasicAccountOAuthArgument(accountsHost, accountID string) (BasicAccountOAuthArgument, error) { + if !strings.HasPrefix(accountsHost, "https://") { + return BasicAccountOAuthArgument{}, fmt.Errorf("accountsHost must start with 'https://': %s", accountsHost) } - if a.AccountID != "" { - return fmt.Sprintf("%s/oidc/accounts/%s", a.Host, a.AccountID) + if strings.HasSuffix(accountsHost, "/") { + return BasicAccountOAuthArgument{}, fmt.Errorf("accountsHost must not have a trailing slash: %s", accountsHost) } - return a.Host + return BasicAccountOAuthArgument{accountHost: accountsHost, accountID: accountID}, nil +} + +func (a BasicAccountOAuthArgument) GetHost(ctx context.Context) string { + return a.accountHost +} + +func (a BasicAccountOAuthArgument) GetAccountId(ctx context.Context) string { + return a.accountID +} + +// key is currently used for two purposes: OIDC URL prefix and token cache key. +// once we decide to start storing scopes in the token cache, we should change +// this approach. +func (a BasicAccountOAuthArgument) GetCacheKey(ctx context.Context) string { + return fmt.Sprintf("%s/oidc/accounts/%s", a.accountHost, a.accountID) } diff --git a/credentials/oauth/oidc.go b/credentials/oauth/oidc.go new file mode 100644 index 000000000..3b4c77fe2 --- /dev/null +++ b/credentials/oauth/oidc.go @@ -0,0 +1,32 @@ +package oauth + +import ( + "context" + "errors" + "fmt" + + "github.com/databricks/databricks-sdk-go/httpclient" +) + +var ErrOAuthNotSupported = errors.New("databricks OAuth is not supported for this host") + +func GetAccountOAuthEndpoints(ctx context.Context, accountsHost, accountId string) (*OAuthAuthorizationServer, error) { + return &OAuthAuthorizationServer{ + AuthorizationEndpoint: fmt.Sprintf("%s/oidc/accounts/%s/v1/authorize", accountsHost, accountId), + TokenEndpoint: fmt.Sprintf("%s/oidc/accounts/%s/v1/token", accountsHost, accountId), + }, nil +} + +func GetWorkspaceOAuthEndpoints(ctx context.Context, c *httpclient.ApiClient, host string) (*OAuthAuthorizationServer, error) { + oidc := fmt.Sprintf("%s/oidc/.well-known/oauth-authorization-server", host) + var oauthEndpoints OAuthAuthorizationServer + if err := c.Do(ctx, "GET", oidc, httpclient.WithResponseUnmarshal(&oauthEndpoints)); err != nil { + return nil, ErrOAuthNotSupported + } + return &oauthEndpoints, nil +} + +type OAuthAuthorizationServer struct { + AuthorizationEndpoint string `json:"authorization_endpoint"` // ../v1/authorize + TokenEndpoint string `json:"token_endpoint"` // ../v1/token +} diff --git a/httpclient/oidc_test.go b/credentials/oauth/oidc_test.go similarity index 73% rename from httpclient/oidc_test.go rename to credentials/oauth/oidc_test.go index 4d6e941f2..7d84c9bda 100644 --- a/httpclient/oidc_test.go +++ b/credentials/oauth/oidc_test.go @@ -1,23 +1,23 @@ -package httpclient +package oauth import ( "context" "testing" + "github.com/databricks/databricks-sdk-go/httpclient" "github.com/databricks/databricks-sdk-go/httpclient/fixtures" "github.com/stretchr/testify/assert" ) func TestOidcEndpointsForAccounts(t *testing.T) { - p := NewApiClient(ClientConfig{}) - s, err := p.GetOidcEndpoints(context.Background(), "https://abc", "xyz") + s, err := GetAccountOAuthEndpoints(context.Background(), "https://abc", "xyz") assert.NoError(t, err) assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/authorize", s.AuthorizationEndpoint) assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/token", s.TokenEndpoint) } func TestOidcForWorkspace(t *testing.T) { - p := NewApiClient(ClientConfig{ + p := httpclient.NewApiClient(httpclient.ClientConfig{ Transport: fixtures.MappingTransport{ "GET /oidc/.well-known/oauth-authorization-server": { Status: 200, @@ -28,7 +28,7 @@ func TestOidcForWorkspace(t *testing.T) { }, }, }) - endpoints, err := p.GetOidcEndpoints(context.Background(), "https://abc", "") + endpoints, err := GetWorkspaceOAuthEndpoints(context.Background(), p, "https://abc") assert.NoError(t, err) assert.Equal(t, "a", endpoints.AuthorizationEndpoint) assert.Equal(t, "b", endpoints.TokenEndpoint) diff --git a/credentials/oauth/page.tmpl b/credentials/oauth/page.tmpl index 4642bb3d4..1540222db 100644 --- a/credentials/oauth/page.tmpl +++ b/credentials/oauth/page.tmpl @@ -91,7 +91,9 @@
{{ .ErrorDescription }}
Authenticated
+ {{- if .Host }}
Go to {{.Host}}
+ {{- end}}
You can close this tab. Or go to documentation diff --git a/credentials/oauth/persistent_auth.go b/credentials/oauth/persistent_auth.go index 4f67337b6..619853d2d 100644 --- a/credentials/oauth/persistent_auth.go +++ b/credentials/oauth/persistent_auth.go @@ -5,6 +5,8 @@ import ( "crypto/rand" "crypto/sha256" "encoding/base64" + "encoding/json" + "errors" "fmt" "net" "os" @@ -38,6 +40,9 @@ const ( // are stored in and looked up from the provided cache. Tokens include the // refresh token. On load, if the access token is expired, it is refreshed // using the refresh token. +// +// The PersistentAuth is safe for concurrent use. The token cache is locked +// during token retrieval, refresh and storage. type PersistentAuth struct { // Cache is the token cache to store and lookup tokens. cache cache.TokenCache @@ -81,6 +86,7 @@ func WithBrowser(b func(url string) error) PersistentAuthOption { } } +// NewPersistentAuth creates a new PersistentAuth with the provided options. func NewPersistentAuth(ctx context.Context, opts ...PersistentAuthOption) (*PersistentAuth, error) { p := &PersistentAuth{} for _, opt := range opts { @@ -109,43 +115,87 @@ func NewPersistentAuth(ctx context.Context, opts ...PersistentAuthOption) (*Pers return p, nil } -func (a *PersistentAuth) Load(ctx context.Context, arg OAuthArgument) (*oauth2.Token, error) { +type tokenErrorResponse struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` +} + +func (a *PersistentAuth) Load(ctx context.Context, arg OAuthArgument) (t *oauth2.Token, err error) { + a.locker.Lock() + defer a.locker.Unlock() + + // TODO: remove this listener after several releases. + err = a.startListener(ctx) + if err != nil { + return nil, fmt.Errorf("starting listener: %w", err) + } + defer a.Close() + key := arg.GetCacheKey(ctx) - t, err := a.cache.Lookup(key) + t, err = a.cache.Lookup(key) if err != nil { return nil, fmt.Errorf("cache: %w", err) } // refresh if invalid if !t.Valid() { - // OAuth2 config is invoked only for expired tokens to speed up - // the happy path in the token retrieval - cfg, err := a.oauth2Config(ctx, arg.GetHost(ctx), arg.GetAccountId(ctx)) - if err != nil { - return nil, err - } - // make OAuth2 library use our client - ctx = a.client.InContextForOAuth2(ctx) - // eagerly refresh token - t, err = cfg.TokenSource(ctx, t).Token() + t, err = a.refresh(ctx, arg, t) if err != nil { return nil, fmt.Errorf("token refresh: %w", err) } - err = a.cache.Store(key, t) - if err != nil { - return nil, fmt.Errorf("cache refresh: %w", err) - } } // do not print refresh token to end-user t.RefreshToken = "" return t, nil } +func (a *PersistentAuth) refresh(ctx context.Context, arg OAuthArgument, oldToken *oauth2.Token) (*oauth2.Token, error) { + // OAuth2 config is invoked only for expired tokens to speed up + // the happy path in the token retrieval + cfg, err := a.oauth2Config(ctx, arg) + if err != nil { + return nil, err + } + // make OAuth2 library use our client + ctx = a.client.InContextForOAuth2(ctx) + // eagerly refresh token + t, err := cfg.TokenSource(ctx, oldToken).Token() + if err != nil { + var httpErr *httpclient.HttpError + if errors.As(err, &httpErr) { + resp := &tokenErrorResponse{} + err = json.Unmarshal([]byte(httpErr.Message), resp) + if err != nil { + return nil, fmt.Errorf("unexpected parsing token response: %w", err) + } + // Invalid refresh tokens get their own error type so they can be + // better presented to users. + if resp.ErrorDescription == "Refresh token is invalid" { + return nil, &InvalidRefreshTokenError{err} + } else { + return nil, fmt.Errorf("unexpected error refreshing token: %s", resp.ErrorDescription) + } + } + return nil, fmt.Errorf("token refresh: %w", err) + } + err = a.cache.Store(arg.GetCacheKey(ctx), t) + if err != nil { + return nil, fmt.Errorf("cache refresh: %w", err) + } + return t, nil +} + func (a *PersistentAuth) Challenge(ctx context.Context, arg OAuthArgument) error { + a.locker.Lock() + defer a.locker.Unlock() err := a.startListener(ctx) if err != nil { return fmt.Errorf("starting listener: %w", err) } - cfg, err := a.oauth2Config(ctx, arg.GetHost(ctx), arg.GetAccountId(ctx)) + // The listener will be closed by the callback server automatically, but if + // the callback server is not created, we need to close the listener manually. + defer a.Close() + + cfg, err := a.oauth2Config(ctx, arg) if err != nil { return fmt.Errorf("fetching oauth config: %w", err) } @@ -154,6 +204,7 @@ func (a *PersistentAuth) Challenge(ctx context.Context, arg OAuthArgument) error return fmt.Errorf("callback server: %w", err) } defer cb.Close() + state, pkce := a.stateAndPKCE() // make OAuth2 library use our client ctx = a.client.InContextForOAuth2(ctx) @@ -194,16 +245,22 @@ func (a *PersistentAuth) Close() error { return a.ln.Close() } -func (a *PersistentAuth) oauth2Config(ctx context.Context, host string, accountId string) (*oauth2.Config, error) { - // in this iteration of CLI, we're using all scopes by default, - // because tools like CLI and Terraform do use all apis. This - // decision may be reconsidered later, once we have a proper - // taxonomy of all scopes ready and implemented. +func (a *PersistentAuth) oauth2Config(ctx context.Context, arg OAuthArgument) (*oauth2.Config, error) { scopes := []string{ - "offline_access", - "all-apis", + "offline_access", // ensures OAuth token includes refresh token + "all-apis", // ensures OAuth token has access to all control-plane APIs + } + var endpoints *OAuthAuthorizationServer + var err error + switch arg := arg.(type) { + case WorkspaceOAuthArgument: + endpoints, err = GetWorkspaceOAuthEndpoints(ctx, a.client, arg.GetWorkspaceHost(ctx)) + if err != nil { + return nil, fmt.Errorf("workspace oauth endpoints: %w", err) + } + case AccountOAuthArgument: + endpoints, err = GetAccountOAuthEndpoints(ctx, arg.GetAccountHost(ctx), arg.GetAccountId(ctx)) } - endpoints, err := a.client.GetOidcEndpoints(ctx, host, accountId) if err != nil { return nil, fmt.Errorf("oidc: %w", err) } diff --git a/credentials/oauth/persistent_auth_test.go b/credentials/oauth/persistent_auth_test.go index 5c9e22578..36de8805f 100644 --- a/credentials/oauth/persistent_auth_test.go +++ b/credentials/oauth/persistent_auth_test.go @@ -50,10 +50,9 @@ func TestLoad(t *testing.T) { p, err := oauth.NewPersistentAuth(context.Background(), oauth.WithTokenCache(cache)) require.NoError(t, err) defer p.Close() - tok, err := p.Load(context.Background(), oauth.BasicOAuthArgument{ - Host: "https://abc", - AccountID: "xyz", - }) + arg, err := oauth.NewBasicAccountOAuthArgument("https://abc", "xyz") + assert.NoError(t, err) + tok, err := p.Load(context.Background(), arg) assert.NoError(t, err) assert.Equal(t, "bcd", tok.AccessToken) assert.Equal(t, "", tok.RefreshToken) @@ -97,10 +96,9 @@ func TestLoadRefresh(t *testing.T) { p, err := oauth.NewPersistentAuth(context.Background(), oauth.WithTokenCache(cache)) require.NoError(t, err) defer p.Close() - tok, err := p.Load(ctx, oauth.BasicOAuthArgument{ - Host: c.Config.Host, - AccountID: "xyz", - }) + arg, err := oauth.NewBasicAccountOAuthArgument(c.Config.Host, "xyz") + assert.NoError(t, err) + tok, err := p.Load(ctx, arg) assert.NoError(t, err) assert.Equal(t, "refreshed", tok.AccessToken) assert.Equal(t, "", tok.RefreshToken) @@ -140,13 +138,12 @@ func TestChallenge(t *testing.T) { p, err := oauth.NewPersistentAuth(context.Background(), oauth.WithTokenCache(cache), oauth.WithBrowser(browser)) require.NoError(t, err) defer p.Close() + arg, err := oauth.NewBasicAccountOAuthArgument(c.Config.Host, "xyz") + assert.NoError(t, err) errc := make(chan error) go func() { - errc <- p.Challenge(ctx, oauth.BasicOAuthArgument{ - Host: c.Config.Host, - AccountID: "xyz", - }) + errc <- p.Challenge(ctx, arg) }() state := <-browserOpened @@ -179,13 +176,12 @@ func TestChallengeFailed(t *testing.T) { p, err := oauth.NewPersistentAuth(context.Background(), oauth.WithBrowser(browser)) require.NoError(t, err) defer p.Close() + arg, err := oauth.NewBasicAccountOAuthArgument(c.Config.Host, "xyz") + assert.NoError(t, err) errc := make(chan error) go func() { - errc <- p.Challenge(ctx, oauth.BasicOAuthArgument{ - Host: c.Config.Host, - AccountID: "xyz", - }) + errc <- p.Challenge(ctx, arg) }() <-browserOpened diff --git a/httpclient/oidc.go b/httpclient/oidc.go deleted file mode 100644 index b6d9b0a83..000000000 --- a/httpclient/oidc.go +++ /dev/null @@ -1,35 +0,0 @@ -package httpclient - -import ( - "context" - "errors" - "fmt" -) - -var ErrOAuthNotSupported = errors.New("databricks OAuth is not supported for this host") - -func (c *ApiClient) GetOidcEndpoints(ctx context.Context, host, accountId string) (*OAuthAuthorizationServer, error) { - prefix := host - if /* cfg.IsAccountClient() && */ accountId != "" { - // TODO: technically, we could use the same config profile for both workspace - // and account, but we have to add logic for determining accounts host from - // workspace host. - prefix := fmt.Sprintf("%s/oidc/accounts/%s", host, accountId) - return &OAuthAuthorizationServer{ - AuthorizationEndpoint: fmt.Sprintf("%s/v1/authorize", prefix), - TokenEndpoint: fmt.Sprintf("%s/v1/token", prefix), - }, nil - } - oidc := fmt.Sprintf("%s/oidc/.well-known/oauth-authorization-server", prefix) - var oauthEndpoints OAuthAuthorizationServer - err := c.Do(ctx, "GET", oidc, WithResponseUnmarshal(&oauthEndpoints)) - if err != nil { - return nil, ErrOAuthNotSupported - } - return &oauthEndpoints, nil -} - -type OAuthAuthorizationServer struct { - AuthorizationEndpoint string `json:"authorization_endpoint"` // ../v1/authorize - TokenEndpoint string `json:"token_endpoint"` // ../v1/token -} From 47f2fb2219a11c6e15d9224f438a634bb8a6be75 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Mon, 6 Jan 2025 13:57:20 +0100 Subject: [PATCH 05/44] more work --- config/auth_m2m.go | 2 +- config/auth_m2m_test.go | 6 +- config/auth_u2m.go | 16 +++- config/config.go | 8 ++ credentials/cache/cache.go | 5 ++ credentials/cache/file.go | 120 ++++++++++++++++++---------- credentials/cache/file_test.go | 64 ++++----------- credentials/cache/in_memory.go | 26 ------ credentials/cache/in_memory_test.go | 44 ---------- 9 files changed, 122 insertions(+), 169 deletions(-) delete mode 100644 credentials/cache/in_memory.go delete mode 100644 credentials/cache/in_memory_test.go diff --git a/config/auth_m2m.go b/config/auth_m2m.go index 0ae99d6fe..cfa7a59c3 100644 --- a/config/auth_m2m.go +++ b/config/auth_m2m.go @@ -22,7 +22,7 @@ func (c M2mCredentials) Configure(ctx context.Context, cfg *Config) (credentials if cfg.ClientID == "" || cfg.ClientSecret == "" { return nil, nil } - endpoints, err := cfg.refreshClient.GetOidcEndpoints(ctx, cfg.Host, cfg.AccountID) + endpoints, err := cfg.getOidcEndpoints(ctx) if err != nil { return nil, fmt.Errorf("oidc: %w", err) } diff --git a/config/auth_m2m_test.go b/config/auth_m2m_test.go index f05264a47..7181be436 100644 --- a/config/auth_m2m_test.go +++ b/config/auth_m2m_test.go @@ -4,7 +4,7 @@ import ( "net/url" "testing" - "github.com/databricks/databricks-sdk-go/httpclient" + "github.com/databricks/databricks-sdk-go/credentials/oauth" "github.com/databricks/databricks-sdk-go/httpclient/fixtures" "github.com/stretchr/testify/require" "golang.org/x/oauth2" @@ -17,7 +17,7 @@ func TestM2mHappyFlow(t *testing.T) { ClientSecret: "c", HTTPTransport: fixtures.MappingTransport{ "GET /oidc/.well-known/oauth-authorization-server": { - Response: httpclient.OAuthAuthorizationServer{ + Response: oauth.OAuthAuthorizationServer{ AuthorizationEndpoint: "https://localhost:1234/dummy/auth", TokenEndpoint: "https://localhost:1234/dummy/token", }, @@ -81,5 +81,5 @@ func TestM2mNotSupported(t *testing.T) { }, }, }) - require.ErrorIs(t, err, httpclient.ErrOAuthNotSupported) + require.ErrorIs(t, err, oauth.ErrOAuthNotSupported) } diff --git a/config/auth_u2m.go b/config/auth_u2m.go index 6aeea162b..d08f3642c 100644 --- a/config/auth_u2m.go +++ b/config/auth_u2m.go @@ -45,16 +45,17 @@ func (u U2MCredentials) Configure(ctx context.Context, cfg *Config) (credentials } } + arg, err := u.getOAuthArg(cfg) + if err != nil { + return nil, fmt.Errorf("oidc: %w", err) + } + r, err := http.NewRequestWithContext(ctx, http.MethodGet, "", nil) if err != nil { return nil, fmt.Errorf("http request: %w", err) } f := func(r *http.Request) error { - arg := oauth.BasicOAuthArgument{ - Host: cfg.Host, - AccountID: cfg.AccountID, - } token, err := a.Load(r.Context(), arg) if err != nil { return fmt.Errorf("oidc: %w", err) @@ -77,6 +78,13 @@ func (u U2MCredentials) Configure(ctx context.Context, cfg *Config) (credentials return credentials.NewCredentialsProvider(f), nil } +func (u U2MCredentials) getOAuthArg(cfg *Config) (oauth.OAuthArgument, error) { + if cfg.IsAccountClient() { + return oauth.NewBasicAccountOAuthArgument(cfg.Host, cfg.AccountID) + } + return oauth.NewBasicWorkspaceOAuthArgument(cfg.Host) +} + var _ CredentialsStrategy = U2MCredentials{} var databricksCliCredentials = U2MCredentials{ diff --git a/config/config.go b/config/config.go index fcf69d2cb..fca6e77fa 100644 --- a/config/config.go +++ b/config/config.go @@ -14,6 +14,7 @@ import ( "github.com/databricks/databricks-sdk-go/common" "github.com/databricks/databricks-sdk-go/common/environment" "github.com/databricks/databricks-sdk-go/credentials" + "github.com/databricks/databricks-sdk-go/credentials/oauth" "github.com/databricks/databricks-sdk-go/httpclient" "github.com/databricks/databricks-sdk-go/logger" "golang.org/x/oauth2" @@ -434,3 +435,10 @@ func (c *Config) refreshTokenErrorMapper(ctx context.Context, resp common.Respon err: err, } } + +func (c *Config) getOidcEndpoints(ctx context.Context) (*oauth.OAuthAuthorizationServer, error) { + if c.IsAccountClient() { + return oauth.GetAccountOAuthEndpoints(ctx, c.Host, c.AccountID) + } + return oauth.GetWorkspaceOAuthEndpoints(ctx, c.refreshClient, c.Host) +} diff --git a/credentials/cache/cache.go b/credentials/cache/cache.go index e20edd23d..0562b41fb 100644 --- a/credentials/cache/cache.go +++ b/credentials/cache/cache.go @@ -6,8 +6,13 @@ import ( "golang.org/x/oauth2" ) +// TokenCache is an interface for storing and looking up OAuth tokens. type TokenCache interface { + // Store stores the token with the given key, replacing any existing token. Store(key string, t *oauth2.Token) error + + // Lookup looks up the token with the given key. If the token is not found, it + // returns ErrNotConfigured. Lookup(key string) (*oauth2.Token, error) } diff --git a/credentials/cache/file.go b/credentials/cache/file.go index f8ba51f58..bdf09dfd4 100644 --- a/credentials/cache/file.go +++ b/credentials/cache/file.go @@ -2,9 +2,7 @@ package cache import ( "encoding/json" - "errors" "fmt" - "io/fs" "os" "path/filepath" @@ -14,12 +12,12 @@ import ( const ( // tokenCacheFile is the path of the default token cache, relative to the // user's home directory. - tokenCacheFile = ".databricks/token-cache.json" + tokenCacheFilePath = ".databricks/token-cache.json" - // only the owner of the file has full execute, read, and write access + // ownerExecReadWrite is the permission for the .databricks directory. ownerExecReadWrite = 0o700 - // only the owner of the file has full read and write access + // ownerReadWrite is the permission for the token-cache.json file. ownerReadWrite = 0o600 // tokenCacheVersion is the version of the token cache file format. @@ -45,33 +43,36 @@ const ( tokenCacheVersion = 1 ) -// FileTokenCache caches tokens in "~/.databricks/token-cache.json". FileTokenCache -// implements the TokenCache interface. -// this implementation requires the calling code to do a machine-wide lock, -// otherwise the file might get corrupt. -type FileTokenCache struct { +// The format of the token cache file. +type tokenCacheFile struct { Version int `json:"version"` Tokens map[string]*oauth2.Token `json:"tokens"` +} +// FileTokenCache caches tokens in "~/.databricks/token-cache.json". FileTokenCache +// implements the TokenCache interface. +type FileTokenCache struct { fileLocation string } +func NewFileTokenCache() (*FileTokenCache, error) { + c := &FileTokenCache{} + if err := c.init(); err != nil { + return nil, err + } + return c, nil +} + +// Store implements the TokenCache interface. func (c *FileTokenCache) Store(key string, t *oauth2.Token) error { - err := c.load() + f, err := c.load() if err != nil { - if !errors.Is(err, fs.ErrNotExist) { - return fmt.Errorf("load: %w", err) - } - dir := filepath.Dir(c.fileLocation) - if err := os.MkdirAll(dir, ownerExecReadWrite); err != nil { - return fmt.Errorf("mkdir: %w", err) - } + return fmt.Errorf("load: %w", err) } - c.Version = tokenCacheVersion - if c.Tokens == nil { - c.Tokens = map[string]*oauth2.Token{} + if f.Tokens == nil { + f.Tokens = map[string]*oauth2.Token{} } - c.Tokens[key] = t + f.Tokens[key] = t raw, err := json.MarshalIndent(c, "", " ") if err != nil { return fmt.Errorf("marshal: %w", err) @@ -79,46 +80,79 @@ func (c *FileTokenCache) Store(key string, t *oauth2.Token) error { return os.WriteFile(c.fileLocation, raw, ownerReadWrite) } +// Lookup implements the TokenCache interface. func (c *FileTokenCache) Lookup(key string) (*oauth2.Token, error) { - err := c.load() - if errors.Is(err, fs.ErrNotExist) { - return nil, ErrNotConfigured - } else if err != nil { + f, err := c.load() + if err != nil { return nil, fmt.Errorf("load: %w", err) } - t, ok := c.Tokens[key] + t, ok := f.Tokens[key] if !ok { return nil, ErrNotConfigured } return t, nil } -func (c *FileTokenCache) load() error { - home, err := os.UserHomeDir() - if err != nil { - return fmt.Errorf("failed loading home directory: %w", err) +// init initializes the token cache file. It creates the file and directory if +// they do not already exist. +func (c *FileTokenCache) init() error { + // set the default file location + if c.fileLocation == "" { + home, err := os.UserHomeDir() + if err != nil { + return fmt.Errorf("failed loading home directory: %w", err) + } + c.fileLocation = filepath.Join(home, tokenCacheFilePath) } - loc := filepath.Join(home, tokenCacheFile) - if err != nil { - return err + // create the directory if it doesn't already exist + if _, err := os.Stat(filepath.Dir(c.fileLocation)); err != nil { + if !os.IsNotExist(err) { + return fmt.Errorf("stat directory: %w", err) + } + // create the directory + if err := os.MkdirAll(filepath.Dir(c.fileLocation), ownerExecReadWrite); err != nil { + return fmt.Errorf("mkdir: %w", err) + } } - c.fileLocation = loc - raw, err := os.ReadFile(loc) + // create the file if it doesn't already exist + if _, err := os.Stat(c.fileLocation); err != nil { + if !os.IsNotExist(err) { + return fmt.Errorf("stat file: %w", err) + } + f := &tokenCacheFile{ + Version: tokenCacheVersion, + Tokens: map[string]*oauth2.Token{}, + } + raw, err := json.MarshalIndent(f, "", " ") + if err != nil { + return fmt.Errorf("marshal: %w", err) + } + if err := os.WriteFile(c.fileLocation, raw, ownerReadWrite); err != nil { + return fmt.Errorf("write: %w", err) + } + } + return nil +} + +// load loads the token cache file from disk. If the file is corrupt or if its +// version does not match tokenCacheVersion, it returns an error. +func (c *FileTokenCache) load() (*tokenCacheFile, error) { + raw, err := os.ReadFile(c.fileLocation) if err != nil { - return fmt.Errorf("read: %w", err) + return nil, fmt.Errorf("read: %w", err) } - err = json.Unmarshal(raw, c) + f := &tokenCacheFile{} + err = json.Unmarshal(raw, f) if err != nil { - return fmt.Errorf("parse: %w", err) + return nil, fmt.Errorf("parse: %w", err) } - if c.Version != tokenCacheVersion { + if f.Version != tokenCacheVersion { // in the later iterations we could do state upgraders, // so that we transform token cache from v1 to v2 without // losing the tokens and asking the user to re-authenticate. - return fmt.Errorf("needs version %d, got version %d", - tokenCacheVersion, c.Version) + return nil, fmt.Errorf("needs version %d, got version %d", tokenCacheVersion, f.Version) } - return nil + return f, nil } var _ TokenCache = (*FileTokenCache)(nil) diff --git a/credentials/cache/file_test.go b/credentials/cache/file_test.go index 3e4aae36f..9760d581f 100644 --- a/credentials/cache/file_test.go +++ b/credentials/cache/file_test.go @@ -3,7 +3,6 @@ package cache import ( "os" "path/filepath" - "runtime" "testing" "github.com/stretchr/testify/assert" @@ -11,23 +10,16 @@ import ( "golang.org/x/oauth2" ) -var homeEnvVar = "HOME" - -func init() { - if runtime.GOOS == "windows" { - homeEnvVar = "USERPROFILE" - } -} - func setup(t *testing.T) string { tempHomeDir := t.TempDir() - t.Setenv(homeEnvVar, tempHomeDir) - return tempHomeDir + return filepath.Join(tempHomeDir, "token-cache.json") } func TestStoreAndLookup(t *testing.T) { - setup(t) - c := &FileTokenCache{} + c := &FileTokenCache{ + fileLocation: setup(t), + } + assert.NoError(t, c.init()) err := c.Store("x", &oauth2.Token{ AccessToken: "abc", }) @@ -42,64 +34,40 @@ func TestStoreAndLookup(t *testing.T) { tok, err := l.Lookup("x") require.NoError(t, err) assert.Equal(t, "abc", tok.AccessToken) - assert.Equal(t, 2, len(l.Tokens)) _, err = l.Lookup("z") assert.Equal(t, ErrNotConfigured, err) } func TestNoCacheFileReturnsErrNotConfigured(t *testing.T) { - setup(t) - l := &FileTokenCache{} + l := &FileTokenCache{ + fileLocation: setup(t), + } + assert.NoError(t, l.init()) _, err := l.Lookup("x") assert.Equal(t, ErrNotConfigured, err) } func TestLoadCorruptFile(t *testing.T) { - home := setup(t) - f := filepath.Join(home, tokenCacheFile) + f := setup(t) err := os.MkdirAll(filepath.Dir(f), ownerExecReadWrite) require.NoError(t, err) err = os.WriteFile(f, []byte("abc"), ownerExecReadWrite) require.NoError(t, err) - l := &FileTokenCache{} - _, err = l.Lookup("x") - assert.EqualError(t, err, "load: parse: invalid character 'a' looking for beginning of value") + l := &FileTokenCache{ + fileLocation: f, + } + assert.EqualError(t, l.init(), "load: parse: invalid character 'a' looking for beginning of value") } func TestLoadWrongVersion(t *testing.T) { - home := setup(t) - f := filepath.Join(home, tokenCacheFile) + f := setup(t) err := os.MkdirAll(filepath.Dir(f), ownerExecReadWrite) require.NoError(t, err) err = os.WriteFile(f, []byte(`{"version": 823, "things": []}`), ownerExecReadWrite) require.NoError(t, err) l := &FileTokenCache{} - _, err = l.Lookup("x") - assert.EqualError(t, err, "load: needs version 1, got version 823") -} - -func TestDevNull(t *testing.T) { - t.Setenv(homeEnvVar, "/dev/null") - l := &FileTokenCache{} - _, err := l.Lookup("x") - // macOS/Linux: load: read: open /dev/null/.databricks/token-cache.json: - // windows: databricks OAuth is not configured for this host - assert.Error(t, err) -} - -func TestStoreOnDev(t *testing.T) { - if runtime.GOOS == "windows" { - t.SkipNow() - } - t.Setenv(homeEnvVar, "/dev") - c := &FileTokenCache{} - err := c.Store("x", &oauth2.Token{ - AccessToken: "abc", - }) - // Linux: permission denied - // macOS: read-only file system - assert.Error(t, err) + assert.EqualError(t, l.init(), "load: needs version 1, got version 823") } diff --git a/credentials/cache/in_memory.go b/credentials/cache/in_memory.go deleted file mode 100644 index 469d45575..000000000 --- a/credentials/cache/in_memory.go +++ /dev/null @@ -1,26 +0,0 @@ -package cache - -import ( - "golang.org/x/oauth2" -) - -type InMemoryTokenCache struct { - Tokens map[string]*oauth2.Token -} - -// Lookup implements TokenCache. -func (i *InMemoryTokenCache) Lookup(key string) (*oauth2.Token, error) { - token, ok := i.Tokens[key] - if !ok { - return nil, ErrNotConfigured - } - return token, nil -} - -// Store implements TokenCache. -func (i *InMemoryTokenCache) Store(key string, t *oauth2.Token) error { - i.Tokens[key] = t - return nil -} - -var _ TokenCache = (*InMemoryTokenCache)(nil) diff --git a/credentials/cache/in_memory_test.go b/credentials/cache/in_memory_test.go deleted file mode 100644 index d8394d3b2..000000000 --- a/credentials/cache/in_memory_test.go +++ /dev/null @@ -1,44 +0,0 @@ -package cache - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "golang.org/x/oauth2" -) - -func TestInMemoryCacheHit(t *testing.T) { - token := &oauth2.Token{ - AccessToken: "abc", - } - c := &InMemoryTokenCache{ - Tokens: map[string]*oauth2.Token{ - "key": token, - }, - } - res, err := c.Lookup("key") - assert.Equal(t, res, token) - assert.NoError(t, err) -} - -func TestInMemoryCacheMiss(t *testing.T) { - c := &InMemoryTokenCache{ - Tokens: map[string]*oauth2.Token{}, - } - _, err := c.Lookup("key") - assert.ErrorIs(t, err, ErrNotConfigured) -} - -func TestInMemoryCacheStore(t *testing.T) { - token := &oauth2.Token{ - AccessToken: "abc", - } - c := &InMemoryTokenCache{ - Tokens: map[string]*oauth2.Token{}, - } - err := c.Store("key", token) - assert.NoError(t, err) - res, err := c.Lookup("key") - assert.Equal(t, res, token) - assert.NoError(t, err) -} From e8836168f681e242e379c1068e706922dc2c9c67 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Mon, 6 Jan 2025 14:34:46 +0100 Subject: [PATCH 06/44] work --- config/auth_u2m.go | 96 ++++++++++++++++++++++++++-- credentials/oauth/persistent_auth.go | 6 +- 2 files changed, 94 insertions(+), 8 deletions(-) diff --git a/config/auth_u2m.go b/config/auth_u2m.go index d08f3642c..f29b94152 100644 --- a/config/auth_u2m.go +++ b/config/auth_u2m.go @@ -5,6 +5,8 @@ import ( "errors" "fmt" "net/http" + "os/exec" + "strings" "github.com/databricks/databricks-sdk-go/credentials" "github.com/databricks/databricks-sdk-go/credentials/cache" @@ -12,15 +14,32 @@ import ( "github.com/databricks/databricks-sdk-go/logger" ) +// U2MCredentials is a credentials strategy that uses the U2M OAuth flow to +// authenticate with Databricks. +// +// To authenticate with U2M OAuth, the user must already have an existing OAuth +// session. The specific OAuth session is indicated by the OAuth argument +// provided by GetOAuthArg. By default, the OAuth argument is determined by the +// account host and account ID or workspace host in the Config. +// +// Error handling for this strategy is controlled by the ErrorHandler field. If +// ErrorHandler is not specified, any error will cause Configure() to return said +// error. type U2MCredentials struct { // Auth is the persistent auth object to use. If not specified, a new one will // be created, using the default cache and locker. Auth *oauth.PersistentAuth + // GetOAuthArg is a function that returns the OAuth argument to use for + // loading the OAuth session token. If not specified, the OAuth argument is + // determined by the account host and account ID or workspace host in the + // Config. + GetOAuthArg func(context.Context, *Config) (oauth.OAuthArgument, error) + // ErrorHandler controls the behavior of Configure() when loading the OAuth // token fails. If not specified, any error will cause Configure() to return // said error. - ErrorHandler func(context.Context, error) error + ErrorHandler func(context.Context, *Config, oauth.OAuthArgument, error) error name string } @@ -45,7 +64,13 @@ func (u U2MCredentials) Configure(ctx context.Context, cfg *Config) (credentials } } - arg, err := u.getOAuthArg(cfg) + var arg oauth.OAuthArgument + var err error + if u.GetOAuthArg != nil { + arg, err = u.GetOAuthArg(ctx, cfg) + } else { + arg, err = defaultGetOAuthArg(ctx, cfg) + } if err != nil { return nil, fmt.Errorf("oidc: %w", err) } @@ -70,7 +95,7 @@ func (u U2MCredentials) Configure(ctx context.Context, cfg *Config) (credentials // credentials strategy. if err := f(r); err != nil { if u.ErrorHandler != nil { - return nil, u.ErrorHandler(ctx, err) + return nil, u.ErrorHandler(ctx, cfg, arg, err) } return nil, err } @@ -78,7 +103,7 @@ func (u U2MCredentials) Configure(ctx context.Context, cfg *Config) (credentials return credentials.NewCredentialsProvider(f), nil } -func (u U2MCredentials) getOAuthArg(cfg *Config) (oauth.OAuthArgument, error) { +func defaultGetOAuthArg(_ context.Context, cfg *Config) (oauth.OAuthArgument, error) { if cfg.IsAccountClient() { return oauth.NewBasicAccountOAuthArgument(cfg.Host, cfg.AccountID) } @@ -87,16 +112,73 @@ func (u U2MCredentials) getOAuthArg(cfg *Config) (oauth.OAuthArgument, error) { var _ CredentialsStrategy = U2MCredentials{} +// CliInvalidRefreshTokenError is a special error type that is returned when a +// new access token could not be retrieved because the refresh token is invalid. +// It is returned only by the `databricks-cli` credentials strategy. +type CliInvalidRefreshTokenError struct { + loginCommand string + err error +} + +func (e *CliInvalidRefreshTokenError) Error() string { + msg := "a new access token could not be retrieved because the refresh token is invalid." + if e.loginCommand != "" { + msg += fmt.Sprintf(" To reauthenticate, run `%s`", e.loginCommand) + } + return msg +} + +func (e *CliInvalidRefreshTokenError) Unwrap() error { + return e.err +} + +func buildLoginCommand(ctx context.Context, profile string, arg oauth.OAuthArgument) string { + cmd := []string{ + "databricks", + "auth", + "login", + } + if profile != "" { + cmd = append(cmd, "--profile", profile) + } else { + switch arg := arg.(type) { + case oauth.AccountOAuthArgument: + cmd = append(cmd, "--host", arg.GetAccountHost(ctx), "--account", arg.GetAccountId(ctx)) + case oauth.WorkspaceOAuthArgument: + cmd = append(cmd, "--host", arg.GetWorkspaceHost(ctx)) + } + } + return strings.Join(cmd, " ") +} + +// databricksCliCredentials is a credentials strategy that emulates the behavior +// of the earlier `databricks-cli` credentials strategy which invoked the +// `databricks auth token` command. var databricksCliCredentials = U2MCredentials{ - ErrorHandler: func(ctx context.Context, err error) error { + ErrorHandler: func(ctx context.Context, cfg *Config, arg oauth.OAuthArgument, err error) error { + // If the current OAuth argument doesn't have a corresponding session + // token, fall back to the next credentials strategy. if errors.Is(err, cache.ErrNotConfigured) { return nil } + // If there is an existing token but the refresh token is invalid, + // return a special error message for invalid refresh tokens. If the + // `databricks` CLI is on the PATH, include a command that the user can + // run to reauthenticate. if _, ok := err.(*oauth.InvalidRefreshTokenError); ok { - return err + var loginCommand string + if _, execErr := exec.LookPath("databricks"); execErr == nil { + loginCommand = buildLoginCommand(ctx, cfg.Profile, arg) + } + return &CliInvalidRefreshTokenError{ + loginCommand: loginCommand, + err: err, + } } + // Otherwise, log the error and continue to the next credentials strategy. logger.Debugf(ctx, "failed to load token: %v, continuing", err) return nil }, - name: "databricks-cli", + GetOAuthArg: defaultGetOAuthArg, + name: "databricks-cli", } diff --git a/credentials/oauth/persistent_auth.go b/credentials/oauth/persistent_auth.go index 619853d2d..f0c23a0a4 100644 --- a/credentials/oauth/persistent_auth.go +++ b/credentials/oauth/persistent_auth.go @@ -96,7 +96,11 @@ func NewPersistentAuth(ctx context.Context, opts ...PersistentAuthOption) (*Pers p.client = httpclient.NewApiClient(httpclient.ClientConfig{}) } if p.cache == nil { - p.cache = &cache.FileTokenCache{} + var err error + p.cache, err = cache.NewFileTokenCache() + if err != nil { + return nil, fmt.Errorf("cache: %w", err) + } } if p.locker == nil { home, err := os.UserHomeDir() From 181f88b7ed8b3e65e36c594ab67c4fe5eea70572 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Mon, 6 Jan 2025 15:00:35 +0100 Subject: [PATCH 07/44] basic tests --- config/auth_u2m_test.go | 99 ++++++++++++++++++++++++++++++++++++++++ config/in_memory_test.go | 27 +++++++++++ 2 files changed, 126 insertions(+) create mode 100644 config/auth_u2m_test.go create mode 100644 config/in_memory_test.go diff --git a/config/auth_u2m_test.go b/config/auth_u2m_test.go new file mode 100644 index 000000000..2caf4805f --- /dev/null +++ b/config/auth_u2m_test.go @@ -0,0 +1,99 @@ +package config + +import ( + "context" + "net/http" + "testing" + "time" + + "github.com/databricks/databricks-sdk-go/credentials/oauth" + "github.com/databricks/databricks-sdk-go/httpclient" + "github.com/databricks/databricks-sdk-go/httpclient/fixtures" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +func TestU2MCredentials(t *testing.T) { + tests := []struct { + name string + cfg *Config + auth func() (*oauth.PersistentAuth, error) + expectErr bool + expectAuth string + }{ + { + name: "happy path", + cfg: &Config{ + Host: "https://myworkspace.cloud.databricks.com", + }, + auth: func() (*oauth.PersistentAuth, error) { + return oauth.NewPersistentAuth( + context.Background(), + oauth.WithTokenCache(&InMemoryTokenCache{ + Tokens: map[string]*oauth2.Token{ + "https://myworkspace.cloud.databricks.com": { + AccessToken: "dummy_access_token", + Expiry: time.Now().Add(1 * time.Hour), + }, + }, + })) + }, + expectErr: false, + expectAuth: "Bearer dummy_access_token", + }, + { + name: "expired token with invalid refresh token", + cfg: &Config{ + Host: "https://databricks.com", + }, + auth: func() (*oauth.PersistentAuth, error) { + return oauth.NewPersistentAuth( + context.Background(), + oauth.WithTokenCache(&InMemoryTokenCache{ + Tokens: map[string]*oauth2.Token{ + "https://myworkspace.cloud.databricks.com": { + AccessToken: "dummy_access_token", + RefreshToken: "dummy_refresh_token", + Expiry: time.Now().Add(-1 * time.Hour), + }, + }, + }), + oauth.WithApiClient(httpclient.NewApiClient(httpclient.ClientConfig{ + Transport: fixtures.SliceTransport{ + { + MatchAny: true, + Status: 401, + Response: `{"error":"invalid_refresh_token","error_description":"Refresh token is invalid"}`, + }, + }, + })), + ) + }, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + auth, err := tt.auth() + require.NoError(t, err) + strat := U2MCredentials{ + Auth: auth, + } + provider, err := strat.Configure(ctx, tt.cfg) + if tt.expectErr { + require.Error(t, err) + return + } + require.NoError(t, err) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://databricks.com", nil) + require.NoError(t, err) + + err = provider.SetHeaders(req) + require.NoError(t, err) + require.Equal(t, tt.expectAuth, req.Header.Get("Authorization")) + }) + } +} diff --git a/config/in_memory_test.go b/config/in_memory_test.go new file mode 100644 index 000000000..82ce6e2c7 --- /dev/null +++ b/config/in_memory_test.go @@ -0,0 +1,27 @@ +package config + +import ( + "github.com/databricks/databricks-sdk-go/credentials/cache" + "golang.org/x/oauth2" +) + +type InMemoryTokenCache struct { + Tokens map[string]*oauth2.Token +} + +// Lookup implements TokenCache. +func (i *InMemoryTokenCache) Lookup(key string) (*oauth2.Token, error) { + token, ok := i.Tokens[key] + if !ok { + return nil, cache.ErrNotConfigured + } + return token, nil +} + +// Store implements TokenCache. +func (i *InMemoryTokenCache) Store(key string, t *oauth2.Token) error { + i.Tokens[key] = t + return nil +} + +var _ cache.TokenCache = (*InMemoryTokenCache)(nil) From 74592c388f923958ab07f0e6a1b9f319fb6078d5 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Mon, 6 Jan 2025 16:58:16 +0100 Subject: [PATCH 08/44] some tests --- config/auth_default.go | 2 +- config/auth_u2m.go | 66 +++++++++------ config/auth_u2m_test.go | 122 ++++++++++++++++++++++++--- credentials/oauth/oauth_argument.go | 8 +- credentials/oauth/persistent_auth.go | 61 +++++++++++--- httpclient/http.go | 10 +++ 6 files changed, 214 insertions(+), 55 deletions(-) create mode 100644 httpclient/http.go diff --git a/config/auth_default.go b/config/auth_default.go index 6f5889fb7..0d2091583 100644 --- a/config/auth_default.go +++ b/config/auth_default.go @@ -13,7 +13,7 @@ var authProviders = []CredentialsStrategy{ PatCredentials{}, BasicCredentials{}, M2mCredentials{}, - databricksCliCredentials, + makeDatabricksCliCredentials(defaultPathLooker{}), MetadataServiceCredentials{}, // Attempt to configure auth from most specific to most generic (the Azure CLI). diff --git a/config/auth_u2m.go b/config/auth_u2m.go index f29b94152..f1ec7ff0a 100644 --- a/config/auth_u2m.go +++ b/config/auth_u2m.go @@ -143,7 +143,7 @@ func buildLoginCommand(ctx context.Context, profile string, arg oauth.OAuthArgum } else { switch arg := arg.(type) { case oauth.AccountOAuthArgument: - cmd = append(cmd, "--host", arg.GetAccountHost(ctx), "--account", arg.GetAccountId(ctx)) + cmd = append(cmd, "--host", arg.GetAccountHost(ctx), "--account-id", arg.GetAccountId(ctx)) case oauth.WorkspaceOAuthArgument: cmd = append(cmd, "--host", arg.GetWorkspaceHost(ctx)) } @@ -151,34 +151,48 @@ func buildLoginCommand(ctx context.Context, profile string, arg oauth.OAuthArgum return strings.Join(cmd, " ") } +// pathLooker is an interface that abstracts the LookPath function from the +// os/exec package. It is used to facilitate testing. +type pathLooker interface { + LookPath(file string) (string, error) +} + +type defaultPathLooker struct{} + +func (defaultPathLooker) LookPath(file string) (string, error) { + return exec.LookPath(file) +} + // databricksCliCredentials is a credentials strategy that emulates the behavior // of the earlier `databricks-cli` credentials strategy which invoked the // `databricks auth token` command. -var databricksCliCredentials = U2MCredentials{ - ErrorHandler: func(ctx context.Context, cfg *Config, arg oauth.OAuthArgument, err error) error { - // If the current OAuth argument doesn't have a corresponding session - // token, fall back to the next credentials strategy. - if errors.Is(err, cache.ErrNotConfigured) { - return nil - } - // If there is an existing token but the refresh token is invalid, - // return a special error message for invalid refresh tokens. If the - // `databricks` CLI is on the PATH, include a command that the user can - // run to reauthenticate. - if _, ok := err.(*oauth.InvalidRefreshTokenError); ok { - var loginCommand string - if _, execErr := exec.LookPath("databricks"); execErr == nil { - loginCommand = buildLoginCommand(ctx, cfg.Profile, arg) +func makeDatabricksCliCredentials(pathLooker pathLooker) U2MCredentials { + return U2MCredentials{ + ErrorHandler: func(ctx context.Context, cfg *Config, arg oauth.OAuthArgument, err error) error { + // If the current OAuth argument doesn't have a corresponding session + // token, fall back to the next credentials strategy. + if errors.Is(err, cache.ErrNotConfigured) { + return nil } - return &CliInvalidRefreshTokenError{ - loginCommand: loginCommand, - err: err, + // If there is an existing token but the refresh token is invalid, + // return a special error message for invalid refresh tokens. If the + // `databricks` CLI is on the PATH, include a command that the user can + // run to reauthenticate. + if _, ok := err.(*oauth.InvalidRefreshTokenError); ok { + var loginCommand string + if _, execErr := pathLooker.LookPath("databricks"); execErr == nil { + loginCommand = buildLoginCommand(ctx, cfg.Profile, arg) + } + return &CliInvalidRefreshTokenError{ + loginCommand: loginCommand, + err: err, + } } - } - // Otherwise, log the error and continue to the next credentials strategy. - logger.Debugf(ctx, "failed to load token: %v, continuing", err) - return nil - }, - GetOAuthArg: defaultGetOAuthArg, - name: "databricks-cli", + // Otherwise, log the error and continue to the next credentials strategy. + logger.Debugf(ctx, "failed to load token: %v, continuing", err) + return nil + }, + GetOAuthArg: defaultGetOAuthArg, + name: "databricks-cli", + } } diff --git a/config/auth_u2m_test.go b/config/auth_u2m_test.go index 2caf4805f..9a8cc4955 100644 --- a/config/auth_u2m_test.go +++ b/config/auth_u2m_test.go @@ -2,23 +2,44 @@ package config import ( "context" + "errors" "net/http" "testing" "time" + "github.com/databricks/databricks-sdk-go/credentials/cache" "github.com/databricks/databricks-sdk-go/credentials/oauth" - "github.com/databricks/databricks-sdk-go/httpclient" "github.com/databricks/databricks-sdk-go/httpclient/fixtures" "github.com/stretchr/testify/require" "golang.org/x/oauth2" ) +type MockOAuthClient struct { + Transport http.RoundTripper + GetAccountOAuthEndpointsFn func(ctx context.Context, accountHost string, accountId string) (*oauth.OAuthAuthorizationServer, error) + GetWorkspaceOAuthEndpointsFn func(ctx context.Context, workspaceHost string) (*oauth.OAuthAuthorizationServer, error) +} + +func (m MockOAuthClient) GetHttpClient(_ context.Context) *http.Client { + return &http.Client{ + Transport: m.Transport, + } +} + +func (m MockOAuthClient) GetAccountOAuthEndpoints(ctx context.Context, accountHost string, accountId string) (*oauth.OAuthAuthorizationServer, error) { + return m.GetAccountOAuthEndpointsFn(ctx, accountHost, accountId) +} + +func (m MockOAuthClient) GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*oauth.OAuthAuthorizationServer, error) { + return m.GetWorkspaceOAuthEndpointsFn(ctx, workspaceHost) +} + func TestU2MCredentials(t *testing.T) { tests := []struct { name string cfg *Config auth func() (*oauth.PersistentAuth, error) - expectErr bool + expectErr string expectAuth string }{ { @@ -38,13 +59,12 @@ func TestU2MCredentials(t *testing.T) { }, })) }, - expectErr: false, expectAuth: "Bearer dummy_access_token", }, { name: "expired token with invalid refresh token", cfg: &Config{ - Host: "https://databricks.com", + Host: "https://myworkspace.cloud.databricks.com", }, auth: func() (*oauth.PersistentAuth, error) { return oauth.NewPersistentAuth( @@ -58,18 +78,25 @@ func TestU2MCredentials(t *testing.T) { }, }, }), - oauth.WithApiClient(httpclient.NewApiClient(httpclient.ClientConfig{ + oauth.WithOAuthClient(MockOAuthClient{ Transport: fixtures.SliceTransport{ { - MatchAny: true, + Method: "POST", + Resource: "/oidc/v1/token", Status: 401, Response: `{"error":"invalid_refresh_token","error_description":"Refresh token is invalid"}`, }, }, - })), + GetWorkspaceOAuthEndpointsFn: func(ctx context.Context, workspaceHost string) (*oauth.OAuthAuthorizationServer, error) { + return &oauth.OAuthAuthorizationServer{ + TokenEndpoint: "https://myworkspace.cloud.databricks.com/oidc/v1/token", + AuthorizationEndpoint: "https://myworkspace.cloud.databricks.com/oidc/v1/authorize", + }, nil + }, + }), ) }, - expectErr: true, + expectErr: "oidc: token refresh: token refresh: oauth2: \"invalid_refresh_token\" \"Refresh token is invalid\"", }, } @@ -82,8 +109,8 @@ func TestU2MCredentials(t *testing.T) { Auth: auth, } provider, err := strat.Configure(ctx, tt.cfg) - if tt.expectErr { - require.Error(t, err) + if tt.expectErr != "" { + require.ErrorContains(t, err, tt.expectErr) return } require.NoError(t, err) @@ -97,3 +124,78 @@ func TestU2MCredentials(t *testing.T) { }) } } + +type mockPathLooker struct { + found bool +} + +func (m mockPathLooker) LookPath(_ string) (string, error) { + if m.found { + return "databricks", nil + } + return "", errors.New("not found") +} + +func TestDatabricksCli_ErrorHandler(t *testing.T) { + invalidRefreshTokenError := &oauth.InvalidRefreshTokenError{} + workspaceArg := func() (oauth.OAuthArgument, error) { + return oauth.NewBasicWorkspaceOAuthArgument("https://myworkspace.cloud.databricks.com") + } + accountArg := func() (oauth.OAuthArgument, error) { + return oauth.NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "abc") + } + testCases := []struct { + name string + pathLooker pathLooker + cfg *Config + arg func() (oauth.OAuthArgument, error) + err error + want error + }{ + { + name: "not configured is ignored", + arg: workspaceArg, + err: cache.ErrNotConfigured, + want: nil, + }, + { + name: "other error is ignored", + arg: workspaceArg, + err: errors.New("foobar"), + want: nil, + }, + { + name: "invalid refresh token is adapted: profile provided", + pathLooker: mockPathLooker{found: true}, + arg: workspaceArg, + cfg: &Config{Profile: "my-profile"}, + err: invalidRefreshTokenError, + want: &CliInvalidRefreshTokenError{loginCommand: "databricks auth login --profile my-profile", err: invalidRefreshTokenError}, + }, + { + name: "invalid refresh token is adapted: profile not provided for workspace", + pathLooker: mockPathLooker{found: true}, + cfg: &Config{}, + arg: workspaceArg, + err: invalidRefreshTokenError, + want: &CliInvalidRefreshTokenError{loginCommand: "databricks auth login --host https://myworkspace.cloud.databricks.com", err: invalidRefreshTokenError}, + }, + { + name: "invalid refresh token is adapted: profile not provided for account", + pathLooker: mockPathLooker{found: true}, + cfg: &Config{}, + arg: accountArg, + err: invalidRefreshTokenError, + want: &CliInvalidRefreshTokenError{loginCommand: "databricks auth login --host https://accounts.cloud.databricks.com --account-id abc", err: invalidRefreshTokenError}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + databricksCliCredentials := makeDatabricksCliCredentials(tc.pathLooker) + arg, err := tc.arg() + require.NoError(t, err) + got := databricksCliCredentials.ErrorHandler(context.Background(), tc.cfg, arg, tc.err) + require.Equal(t, tc.want, got) + }) + } +} diff --git a/credentials/oauth/oauth_argument.go b/credentials/oauth/oauth_argument.go index 344b069d3..d98b616ab 100644 --- a/credentials/oauth/oauth_argument.go +++ b/credentials/oauth/oauth_argument.go @@ -37,7 +37,7 @@ func NewBasicWorkspaceOAuthArgument(host string) (BasicWorkspaceOAuthArgument, e return BasicWorkspaceOAuthArgument{host: host}, nil } -func (a BasicWorkspaceOAuthArgument) GetHost(ctx context.Context) string { +func (a BasicWorkspaceOAuthArgument) GetWorkspaceHost(ctx context.Context) string { return a.host } @@ -52,7 +52,7 @@ func (a BasicWorkspaceOAuthArgument) GetCacheKey(ctx context.Context) string { return a.host } -var _ OAuthArgument = BasicWorkspaceOAuthArgument{} +var _ WorkspaceOAuthArgument = BasicWorkspaceOAuthArgument{} type AccountOAuthArgument interface { OAuthArgument @@ -69,7 +69,7 @@ type BasicAccountOAuthArgument struct { accountID string } -var _ OAuthArgument = BasicAccountOAuthArgument{} +var _ AccountOAuthArgument = BasicAccountOAuthArgument{} func NewBasicAccountOAuthArgument(accountsHost, accountID string) (BasicAccountOAuthArgument, error) { if !strings.HasPrefix(accountsHost, "https://") { @@ -81,7 +81,7 @@ func NewBasicAccountOAuthArgument(accountsHost, accountID string) (BasicAccountO return BasicAccountOAuthArgument{accountHost: accountsHost, accountID: accountID}, nil } -func (a BasicAccountOAuthArgument) GetHost(ctx context.Context) string { +func (a BasicAccountOAuthArgument) GetAccountHost(ctx context.Context) string { return a.accountHost } diff --git a/credentials/oauth/persistent_auth.go b/credentials/oauth/persistent_auth.go index f0c23a0a4..fecae9f35 100644 --- a/credentials/oauth/persistent_auth.go +++ b/credentials/oauth/persistent_auth.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "net" + "net/http" "os" "path/filepath" "sync" @@ -36,6 +37,35 @@ const ( listenerTimeout = 45 * time.Second ) +// OAuthClient provides the http functionality needed for interacting with the +// Databricks OAuth APIs. +type OAuthClient interface { + // GetHttpClient returns an HTTP client for OAuth2 requests. + GetHttpClient(context.Context) *http.Client + + // GetWorkspaceOAuthEndpoints returns the OAuth2 endpoints for the workspace. + GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*OAuthAuthorizationServer, error) + + // GetAccountOAuthEndpoints returns the OAuth2 endpoints for the account. + GetAccountOAuthEndpoints(ctx context.Context, accountHost string, accountId string) (*OAuthAuthorizationServer, error) +} + +type BasicOAuthClient struct { + client *httpclient.ApiClient +} + +func (c *BasicOAuthClient) GetHttpClient(_ context.Context) *http.Client { + return c.client.ToHttpClient() +} + +func (c *BasicOAuthClient) GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*OAuthAuthorizationServer, error) { + return GetWorkspaceOAuthEndpoints(ctx, c.client, workspaceHost) +} + +func (c *BasicOAuthClient) GetAccountOAuthEndpoints(ctx context.Context, accountHost string, accountId string) (*OAuthAuthorizationServer, error) { + return GetAccountOAuthEndpoints(ctx, accountHost, accountId) +} + // PersistentAuth is an OAuth manager that handles the U2M OAuth flow. Tokens // are stored in and looked up from the provided cache. Tokens include the // refresh token. On load, if the access token is expired, it is refreshed @@ -49,7 +79,7 @@ type PersistentAuth struct { // Locker is the lock to synchronize token cache access. locker sync.Locker // Client is the HTTP client to use for OAuth2 requests. - client *httpclient.ApiClient + client OAuthClient // Browser is the function to open a URL in the default browser. browser func(url string) error // ln is the listener for the OAuth2 callback server. @@ -73,7 +103,7 @@ func WithLocker(l sync.Locker) PersistentAuthOption { } // WithApiClient sets the HTTP client for the PersistentAuth. -func WithApiClient(c *httpclient.ApiClient) PersistentAuthOption { +func WithOAuthClient(c OAuthClient) PersistentAuthOption { return func(a *PersistentAuth) { a.client = c } @@ -93,7 +123,9 @@ func NewPersistentAuth(ctx context.Context, opts ...PersistentAuthOption) (*Pers opt(p) } if p.client == nil { - p.client = httpclient.NewApiClient(httpclient.ClientConfig{}) + p.client = &BasicOAuthClient{ + client: httpclient.NewApiClient(httpclient.ClientConfig{}), + } } if p.cache == nil { var err error @@ -160,7 +192,7 @@ func (a *PersistentAuth) refresh(ctx context.Context, arg OAuthArgument, oldToke return nil, err } // make OAuth2 library use our client - ctx = a.client.InContextForOAuth2(ctx) + ctx = a.setOAuthContext(ctx) // eagerly refresh token t, err := cfg.TokenSource(ctx, oldToken).Token() if err != nil { @@ -211,7 +243,7 @@ func (a *PersistentAuth) Challenge(ctx context.Context, arg OAuthArgument) error state, pkce := a.stateAndPKCE() // make OAuth2 library use our client - ctx = a.client.InContextForOAuth2(ctx) + ctx = a.setOAuthContext(ctx) ts := authhandler.TokenSourceWithPKCE(ctx, cfg, state, cb.Handler, pkce) t, err := ts.Token() if err != nil { @@ -256,17 +288,14 @@ func (a *PersistentAuth) oauth2Config(ctx context.Context, arg OAuthArgument) (* } var endpoints *OAuthAuthorizationServer var err error - switch arg := arg.(type) { - case WorkspaceOAuthArgument: - endpoints, err = GetWorkspaceOAuthEndpoints(ctx, a.client, arg.GetWorkspaceHost(ctx)) - if err != nil { - return nil, fmt.Errorf("workspace oauth endpoints: %w", err) - } - case AccountOAuthArgument: - endpoints, err = GetAccountOAuthEndpoints(ctx, arg.GetAccountHost(ctx), arg.GetAccountId(ctx)) + if workspaceArg, ok := arg.(WorkspaceOAuthArgument); ok { + endpoints, err = a.client.GetWorkspaceOAuthEndpoints(ctx, workspaceArg.GetWorkspaceHost(ctx)) + } else if accountArg, ok := arg.(AccountOAuthArgument); ok { + endpoints, err = a.client.GetAccountOAuthEndpoints( + ctx, accountArg.GetAccountHost(ctx), accountArg.GetAccountId(ctx)) } if err != nil { - return nil, fmt.Errorf("oidc: %w", err) + return nil, fmt.Errorf("fetching OAuth endpoints: %w", err) } return &oauth2.Config{ ClientID: appClientID, @@ -296,3 +325,7 @@ func (a *PersistentAuth) randomString(size int) string { _, _ = rand.Read(raw) return base64.RawURLEncoding.EncodeToString(raw) } + +func (a *PersistentAuth) setOAuthContext(ctx context.Context) context.Context { + return context.WithValue(ctx, oauth2.HTTPClient, a.client.GetHttpClient(ctx)) +} diff --git a/httpclient/http.go b/httpclient/http.go new file mode 100644 index 000000000..b0430b69d --- /dev/null +++ b/httpclient/http.go @@ -0,0 +1,10 @@ +package httpclient + +import "net/http" + +func (a *ApiClient) ToHttpClient() *http.Client { + return &http.Client{ + Transport: a, + Timeout: a.config.HTTPTimeout, + } +} From 9b5913c7af7a124234759483834810f3eade9fdb Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Mon, 6 Jan 2025 17:01:46 +0100 Subject: [PATCH 09/44] one more test case --- config/auth_u2m_test.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/config/auth_u2m_test.go b/config/auth_u2m_test.go index 9a8cc4955..bfcc69cc1 100644 --- a/config/auth_u2m_test.go +++ b/config/auth_u2m_test.go @@ -188,6 +188,14 @@ func TestDatabricksCli_ErrorHandler(t *testing.T) { err: invalidRefreshTokenError, want: &CliInvalidRefreshTokenError{loginCommand: "databricks auth login --host https://accounts.cloud.databricks.com --account-id abc", err: invalidRefreshTokenError}, }, + { + name: "invalid refresh token is adapted: CLI not present", + pathLooker: mockPathLooker{found: false}, + cfg: &Config{}, + arg: accountArg, + err: invalidRefreshTokenError, + want: &CliInvalidRefreshTokenError{loginCommand: "", err: invalidRefreshTokenError}, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { From 73c73afabbdfa8db405bc47ae12e0859361e03f5 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Mon, 6 Jan 2025 17:13:31 +0100 Subject: [PATCH 10/44] more tests --- config/config.go | 2 ++ config/config_test.go | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/config/config.go b/config/config.go index fca6e77fa..8b93804fd 100644 --- a/config/config.go +++ b/config/config.go @@ -437,6 +437,8 @@ func (c *Config) refreshTokenErrorMapper(ctx context.Context, resp common.Respon } func (c *Config) getOidcEndpoints(ctx context.Context) (*oauth.OAuthAuthorizationServer, error) { + // + c.EnsureResolved() if c.IsAccountClient() { return oauth.GetAccountOAuthEndpoints(ctx, c.Host, c.AccountID) } diff --git a/config/config_test.go b/config/config_test.go index 63a59b951..9cebb1a15 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -5,6 +5,8 @@ import ( "net/http" "testing" + "github.com/databricks/databricks-sdk-go/credentials/oauth" + "github.com/databricks/databricks-sdk-go/httpclient/fixtures" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -66,3 +68,36 @@ func TestAuthenticate_InvalidHostSet(t *testing.T) { err = c.Authenticate(req) assert.ErrorIs(t, err, ErrNoHostConfigured) } + +func TestConfig_getOidcEndpoints_account(t *testing.T) { + c := &Config{ + Host: "https://accounts.cloud.databricks.com", + AccountID: "abc", + } + got, err := c.getOidcEndpoints(context.Background()) + assert.NoError(t, err) + assert.Equal(t, &oauth.OAuthAuthorizationServer{ + AuthorizationEndpoint: "https://accounts.cloud.databricks.com/oidc/accounts/abc/v1/authorize", + TokenEndpoint: "https://accounts.cloud.databricks.com/oidc/accounts/abc/v1/token", + }, got) +} + +func TestConfig_getOidcEndpoints_workspace(t *testing.T) { + c := &Config{ + Host: "https://myworkspace.cloud.databricks.com", + HTTPTransport: fixtures.SliceTransport{ + { + Method: "GET", + Resource: "/oidc/.well-known/oauth-authorization-server", + Status: 200, + Response: `{"authorization_endpoint": "https://myworkspace.cloud.databricks.com/oidc/v1/authorize", "token_endpoint": "https://myworkspace.cloud.databricks.com/oidc/v1/token"}`, + }, + }, + } + got, err := c.getOidcEndpoints(context.Background()) + assert.NoError(t, err) + assert.Equal(t, &oauth.OAuthAuthorizationServer{ + AuthorizationEndpoint: "https://myworkspace.cloud.databricks.com/oidc/v1/authorize", + TokenEndpoint: "https://myworkspace.cloud.databricks.com/oidc/v1/token", + }, got) +} From 63bf52192143449ef9cae982f44a9697adff14b0 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Mon, 6 Jan 2025 17:13:54 +0100 Subject: [PATCH 11/44] comment --- config/config.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/config.go b/config/config.go index 8b93804fd..d78e178e5 100644 --- a/config/config.go +++ b/config/config.go @@ -436,8 +436,8 @@ func (c *Config) refreshTokenErrorMapper(ctx context.Context, resp common.Respon } } +// getOidcEndpoints returns the OAuth endpoints for the current configuration. func (c *Config) getOidcEndpoints(ctx context.Context) (*oauth.OAuthAuthorizationServer, error) { - // c.EnsureResolved() if c.IsAccountClient() { return oauth.GetAccountOAuthEndpoints(ctx, c.Host, c.AccountID) From 2622989e860a2fc031b3004f8d2aed6808cad6f5 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Mon, 6 Jan 2025 17:16:12 +0100 Subject: [PATCH 12/44] mutex --- credentials/cache/file.go | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/credentials/cache/file.go b/credentials/cache/file.go index bdf09dfd4..e5f7af0cf 100644 --- a/credentials/cache/file.go +++ b/credentials/cache/file.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "path/filepath" + "sync" "golang.org/x/oauth2" ) @@ -53,10 +54,15 @@ type tokenCacheFile struct { // implements the TokenCache interface. type FileTokenCache struct { fileLocation string + + // mu protects the token cache file from concurrent reads and writes. + mu *sync.Mutex } func NewFileTokenCache() (*FileTokenCache, error) { - c := &FileTokenCache{} + c := &FileTokenCache{ + mu: &sync.Mutex{}, + } if err := c.init(); err != nil { return nil, err } @@ -65,6 +71,8 @@ func NewFileTokenCache() (*FileTokenCache, error) { // Store implements the TokenCache interface. func (c *FileTokenCache) Store(key string, t *oauth2.Token) error { + c.mu.Lock() + defer c.mu.Unlock() f, err := c.load() if err != nil { return fmt.Errorf("load: %w", err) @@ -82,6 +90,8 @@ func (c *FileTokenCache) Store(key string, t *oauth2.Token) error { // Lookup implements the TokenCache interface. func (c *FileTokenCache) Lookup(key string) (*oauth2.Token, error) { + c.mu.Lock() + defer c.mu.Unlock() f, err := c.load() if err != nil { return nil, fmt.Errorf("load: %w", err) From 359a5d015531e071bbc306ce88d6287c44864e8e Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Mon, 6 Jan 2025 17:18:53 +0100 Subject: [PATCH 13/44] documentation --- credentials/oauth/error.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/credentials/oauth/error.go b/credentials/oauth/error.go index f29858f8d..a68c9fdaa 100644 --- a/credentials/oauth/error.go +++ b/credentials/oauth/error.go @@ -1,5 +1,8 @@ package oauth +// InvalidRefreshTokenError is returned from PersistentAuth's Load() method +// if the access token has expired and the refresh token in the token cache +// is invalid. type InvalidRefreshTokenError struct { err error } From 44af89fdbfc836f42b78c4383f7ab20edd3ffa56 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Mon, 6 Jan 2025 17:25:36 +0100 Subject: [PATCH 14/44] docs --- credentials/oauth/oauth_argument.go | 25 ++++++++++++++++++------- credentials/oauth/persistent_auth.go | 24 ++++++++++++++++++++---- 2 files changed, 38 insertions(+), 11 deletions(-) diff --git a/credentials/oauth/oauth_argument.go b/credentials/oauth/oauth_argument.go index d98b616ab..fa40c4232 100644 --- a/credentials/oauth/oauth_argument.go +++ b/credentials/oauth/oauth_argument.go @@ -7,13 +7,17 @@ import ( ) // OAuthArgument is an interface that provides the necessary information to -// authenticate with PersistentAuth. +// authenticate with PersistentAuth. Implementations of this interface must +// implement either the WorkspaceOAuthArgument or AccountOAuthArgument +// interface. type OAuthArgument interface { // GetCacheKey returns a unique key for the OAuthArgument. This key is used // to store and retrieve the token from the token cache. GetCacheKey(ctx context.Context) string } +// WorkspaceOAuthArgument is an interface that provides the necessary information +// to authenticate using OAuth to a specific workspace. type WorkspaceOAuthArgument interface { OAuthArgument @@ -21,12 +25,15 @@ type WorkspaceOAuthArgument interface { GetWorkspaceHost(ctx context.Context) string } +// BasicWorkspaceOAuthArgument is a basic implementation of the WorkspaceOAuthArgument +// interface that links each host with exactly one OAuth token. type BasicWorkspaceOAuthArgument struct { // host is the host of the workspace to authenticate to. This must start // with "https://" and must not have a trailing slash. host string } +// NewBasicWorkspaceOAuthArgument creates a new BasicWorkspaceOAuthArgument. func NewBasicWorkspaceOAuthArgument(host string) (BasicWorkspaceOAuthArgument, error) { if !strings.HasPrefix(host, "https://") { return BasicWorkspaceOAuthArgument{}, fmt.Errorf("host must start with 'https://': %s", host) @@ -37,13 +44,12 @@ func NewBasicWorkspaceOAuthArgument(host string) (BasicWorkspaceOAuthArgument, e return BasicWorkspaceOAuthArgument{host: host}, nil } +// GetWorkspaceHost returns the host of the workspace to authenticate to. func (a BasicWorkspaceOAuthArgument) GetWorkspaceHost(ctx context.Context) string { return a.host } -// key is currently used for two purposes: OIDC URL prefix and token cache key. -// once we decide to start storing scopes in the token cache, we should change -// this approach. +// GetCacheKey returns a unique key for caching the OAuth token for the workspace. func (a BasicWorkspaceOAuthArgument) GetCacheKey(ctx context.Context) string { a.host = strings.TrimSuffix(a.host, "/") if !strings.HasPrefix(a.host, "http") { @@ -54,6 +60,8 @@ func (a BasicWorkspaceOAuthArgument) GetCacheKey(ctx context.Context) string { var _ WorkspaceOAuthArgument = BasicWorkspaceOAuthArgument{} +// AccountOAuthArgument is an interface that provides the necessary information +// to authenticate using OAuth to a specific account. type AccountOAuthArgument interface { OAuthArgument @@ -64,6 +72,8 @@ type AccountOAuthArgument interface { GetAccountId(ctx context.Context) string } +// BasicAccountOAuthArgument is a basic implementation of the AccountOAuthArgument +// interface that links each account with exactly one OAuth token. type BasicAccountOAuthArgument struct { accountHost string accountID string @@ -71,6 +81,7 @@ type BasicAccountOAuthArgument struct { var _ AccountOAuthArgument = BasicAccountOAuthArgument{} +// NewBasicAccountOAuthArgument creates a new BasicAccountOAuthArgument. func NewBasicAccountOAuthArgument(accountsHost, accountID string) (BasicAccountOAuthArgument, error) { if !strings.HasPrefix(accountsHost, "https://") { return BasicAccountOAuthArgument{}, fmt.Errorf("accountsHost must start with 'https://': %s", accountsHost) @@ -81,17 +92,17 @@ func NewBasicAccountOAuthArgument(accountsHost, accountID string) (BasicAccountO return BasicAccountOAuthArgument{accountHost: accountsHost, accountID: accountID}, nil } +// GetAccountHost returns the host of the account to authenticate to. func (a BasicAccountOAuthArgument) GetAccountHost(ctx context.Context) string { return a.accountHost } +// GetAccountId returns the account ID of the account to authenticate to. func (a BasicAccountOAuthArgument) GetAccountId(ctx context.Context) string { return a.accountID } -// key is currently used for two purposes: OIDC URL prefix and token cache key. -// once we decide to start storing scopes in the token cache, we should change -// this approach. +// GetCacheKey returns a unique key for caching the OAuth token for the account. func (a BasicAccountOAuthArgument) GetCacheKey(ctx context.Context) string { return fmt.Sprintf("%s/oidc/accounts/%s", a.accountHost, a.accountID) } diff --git a/credentials/oauth/persistent_auth.go b/credentials/oauth/persistent_auth.go index fecae9f35..c0ab99a9d 100644 --- a/credentials/oauth/persistent_auth.go +++ b/credentials/oauth/persistent_auth.go @@ -160,6 +160,8 @@ func (a *PersistentAuth) Load(ctx context.Context, arg OAuthArgument) (t *oauth2 a.locker.Lock() defer a.locker.Unlock() + a.validateArg(arg) + // TODO: remove this listener after several releases. err = a.startListener(ctx) if err != nil { @@ -223,6 +225,8 @@ func (a *PersistentAuth) refresh(ctx context.Context, arg OAuthArgument, oldToke func (a *PersistentAuth) Challenge(ctx context.Context, arg OAuthArgument) error { a.locker.Lock() defer a.locker.Unlock() + + a.validateArg(arg) err := a.startListener(ctx) if err != nil { return fmt.Errorf("starting listener: %w", err) @@ -281,6 +285,15 @@ func (a *PersistentAuth) Close() error { return a.ln.Close() } +func (a *PersistentAuth) validateArg(arg OAuthArgument) error { + _, isWorkspaceArg := arg.(WorkspaceOAuthArgument) + _, isAccountArg := arg.(AccountOAuthArgument) + if !isWorkspaceArg && !isAccountArg { + return fmt.Errorf("unsupported OAuthArgument type: %T, must implement either WorkspaceOAuthArgument or AccountOAuthArgument interface", arg) + } + return nil +} + func (a *PersistentAuth) oauth2Config(ctx context.Context, arg OAuthArgument) (*oauth2.Config, error) { scopes := []string{ "offline_access", // ensures OAuth token includes refresh token @@ -288,11 +301,14 @@ func (a *PersistentAuth) oauth2Config(ctx context.Context, arg OAuthArgument) (* } var endpoints *OAuthAuthorizationServer var err error - if workspaceArg, ok := arg.(WorkspaceOAuthArgument); ok { - endpoints, err = a.client.GetWorkspaceOAuthEndpoints(ctx, workspaceArg.GetWorkspaceHost(ctx)) - } else if accountArg, ok := arg.(AccountOAuthArgument); ok { + switch argg := arg.(type) { + case WorkspaceOAuthArgument: + endpoints, err = a.client.GetWorkspaceOAuthEndpoints(ctx, argg.GetWorkspaceHost(ctx)) + case AccountOAuthArgument: endpoints, err = a.client.GetAccountOAuthEndpoints( - ctx, accountArg.GetAccountHost(ctx), accountArg.GetAccountId(ctx)) + ctx, argg.GetAccountHost(ctx), argg.GetAccountId(ctx)) + default: + return nil, fmt.Errorf("unsupported OAuthArgument type: %T, must implement either WorkspaceOAuthArgument or AccountOAuthArgument interface", arg) } if err != nil { return nil, fmt.Errorf("fetching OAuth endpoints: %w", err) From bd7303e3df7962079b411dcb8be5d3981a10f870 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Mon, 6 Jan 2025 17:28:06 +0100 Subject: [PATCH 15/44] fix test names --- credentials/oauth/oidc_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/credentials/oauth/oidc_test.go b/credentials/oauth/oidc_test.go index 7d84c9bda..ba199ea32 100644 --- a/credentials/oauth/oidc_test.go +++ b/credentials/oauth/oidc_test.go @@ -9,14 +9,14 @@ import ( "github.com/stretchr/testify/assert" ) -func TestOidcEndpointsForAccounts(t *testing.T) { +func TestGetAccountOAuthEndpoints(t *testing.T) { s, err := GetAccountOAuthEndpoints(context.Background(), "https://abc", "xyz") assert.NoError(t, err) assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/authorize", s.AuthorizationEndpoint) assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/token", s.TokenEndpoint) } -func TestOidcForWorkspace(t *testing.T) { +func TestGetWorkspaceOAuthEndpoints(t *testing.T) { p := httpclient.NewApiClient(httpclient.ClientConfig{ Transport: fixtures.MappingTransport{ "GET /oidc/.well-known/oauth-authorization-server": { From 323fb61606d137383165c470a31ae019f51332cb Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Tue, 7 Jan 2025 11:17:10 +0100 Subject: [PATCH 16/44] work --- config/auth_default.go | 2 +- config/auth_u2m.go | 67 ++++++++++++++--------------------------- config/auth_u2m_test.go | 64 ++++++++++++--------------------------- 3 files changed, 44 insertions(+), 89 deletions(-) diff --git a/config/auth_default.go b/config/auth_default.go index 0d2091583..6f5889fb7 100644 --- a/config/auth_default.go +++ b/config/auth_default.go @@ -13,7 +13,7 @@ var authProviders = []CredentialsStrategy{ PatCredentials{}, BasicCredentials{}, M2mCredentials{}, - makeDatabricksCliCredentials(defaultPathLooker{}), + databricksCliCredentials, MetadataServiceCredentials{}, // Attempt to configure auth from most specific to most generic (the Azure CLI). diff --git a/config/auth_u2m.go b/config/auth_u2m.go index f1ec7ff0a..d52d80e37 100644 --- a/config/auth_u2m.go +++ b/config/auth_u2m.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "net/http" - "os/exec" "strings" "github.com/databricks/databricks-sdk-go/credentials" @@ -122,9 +121,7 @@ type CliInvalidRefreshTokenError struct { func (e *CliInvalidRefreshTokenError) Error() string { msg := "a new access token could not be retrieved because the refresh token is invalid." - if e.loginCommand != "" { - msg += fmt.Sprintf(" To reauthenticate, run `%s`", e.loginCommand) - } + msg += fmt.Sprintf(" To reauthenticate, run `%s`", e.loginCommand) return msg } @@ -151,48 +148,30 @@ func buildLoginCommand(ctx context.Context, profile string, arg oauth.OAuthArgum return strings.Join(cmd, " ") } -// pathLooker is an interface that abstracts the LookPath function from the -// os/exec package. It is used to facilitate testing. -type pathLooker interface { - LookPath(file string) (string, error) -} - -type defaultPathLooker struct{} - -func (defaultPathLooker) LookPath(file string) (string, error) { - return exec.LookPath(file) -} - // databricksCliCredentials is a credentials strategy that emulates the behavior // of the earlier `databricks-cli` credentials strategy which invoked the // `databricks auth token` command. -func makeDatabricksCliCredentials(pathLooker pathLooker) U2MCredentials { - return U2MCredentials{ - ErrorHandler: func(ctx context.Context, cfg *Config, arg oauth.OAuthArgument, err error) error { - // If the current OAuth argument doesn't have a corresponding session - // token, fall back to the next credentials strategy. - if errors.Is(err, cache.ErrNotConfigured) { - return nil - } - // If there is an existing token but the refresh token is invalid, - // return a special error message for invalid refresh tokens. If the - // `databricks` CLI is on the PATH, include a command that the user can - // run to reauthenticate. - if _, ok := err.(*oauth.InvalidRefreshTokenError); ok { - var loginCommand string - if _, execErr := pathLooker.LookPath("databricks"); execErr == nil { - loginCommand = buildLoginCommand(ctx, cfg.Profile, arg) - } - return &CliInvalidRefreshTokenError{ - loginCommand: loginCommand, - err: err, - } - } - // Otherwise, log the error and continue to the next credentials strategy. - logger.Debugf(ctx, "failed to load token: %v, continuing", err) +var databricksCliCredentials = U2MCredentials{ + ErrorHandler: func(ctx context.Context, cfg *Config, arg oauth.OAuthArgument, err error) error { + // If the current OAuth argument doesn't have a corresponding session + // token, fall back to the next credentials strategy. + if errors.Is(err, cache.ErrNotConfigured) { return nil - }, - GetOAuthArg: defaultGetOAuthArg, - name: "databricks-cli", - } + } + // If there is an existing token but the refresh token is invalid, + // return a special error message for invalid refresh tokens. To help + // users easily reauthenticate, include a command that the user can + // run, prepopulating the profile, host and/or account ID. + if _, ok := err.(*oauth.InvalidRefreshTokenError); ok { + return &CliInvalidRefreshTokenError{ + loginCommand: buildLoginCommand(ctx, cfg.Profile, arg), + err: err, + } + } + // Otherwise, log the error and continue to the next credentials strategy. + logger.Debugf(ctx, "failed to load token: %v, continuing", err) + return nil + }, + GetOAuthArg: defaultGetOAuthArg, + name: "databricks-cli", } diff --git a/config/auth_u2m_test.go b/config/auth_u2m_test.go index bfcc69cc1..926775c38 100644 --- a/config/auth_u2m_test.go +++ b/config/auth_u2m_test.go @@ -125,17 +125,6 @@ func TestU2MCredentials(t *testing.T) { } } -type mockPathLooker struct { - found bool -} - -func (m mockPathLooker) LookPath(_ string) (string, error) { - if m.found { - return "databricks", nil - } - return "", errors.New("not found") -} - func TestDatabricksCli_ErrorHandler(t *testing.T) { invalidRefreshTokenError := &oauth.InvalidRefreshTokenError{} workspaceArg := func() (oauth.OAuthArgument, error) { @@ -145,12 +134,11 @@ func TestDatabricksCli_ErrorHandler(t *testing.T) { return oauth.NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "abc") } testCases := []struct { - name string - pathLooker pathLooker - cfg *Config - arg func() (oauth.OAuthArgument, error) - err error - want error + name string + cfg *Config + arg func() (oauth.OAuthArgument, error) + err error + want error }{ { name: "not configured is ignored", @@ -165,41 +153,29 @@ func TestDatabricksCli_ErrorHandler(t *testing.T) { want: nil, }, { - name: "invalid refresh token is adapted: profile provided", - pathLooker: mockPathLooker{found: true}, - arg: workspaceArg, - cfg: &Config{Profile: "my-profile"}, - err: invalidRefreshTokenError, - want: &CliInvalidRefreshTokenError{loginCommand: "databricks auth login --profile my-profile", err: invalidRefreshTokenError}, - }, - { - name: "invalid refresh token is adapted: profile not provided for workspace", - pathLooker: mockPathLooker{found: true}, - cfg: &Config{}, - arg: workspaceArg, - err: invalidRefreshTokenError, - want: &CliInvalidRefreshTokenError{loginCommand: "databricks auth login --host https://myworkspace.cloud.databricks.com", err: invalidRefreshTokenError}, + name: "invalid refresh token is adapted: profile provided", + arg: workspaceArg, + cfg: &Config{Profile: "my-profile"}, + err: invalidRefreshTokenError, + want: &CliInvalidRefreshTokenError{loginCommand: "databricks auth login --profile my-profile", err: invalidRefreshTokenError}, }, { - name: "invalid refresh token is adapted: profile not provided for account", - pathLooker: mockPathLooker{found: true}, - cfg: &Config{}, - arg: accountArg, - err: invalidRefreshTokenError, - want: &CliInvalidRefreshTokenError{loginCommand: "databricks auth login --host https://accounts.cloud.databricks.com --account-id abc", err: invalidRefreshTokenError}, + name: "invalid refresh token is adapted: profile not provided for workspace", + cfg: &Config{}, + arg: workspaceArg, + err: invalidRefreshTokenError, + want: &CliInvalidRefreshTokenError{loginCommand: "databricks auth login --host https://myworkspace.cloud.databricks.com", err: invalidRefreshTokenError}, }, { - name: "invalid refresh token is adapted: CLI not present", - pathLooker: mockPathLooker{found: false}, - cfg: &Config{}, - arg: accountArg, - err: invalidRefreshTokenError, - want: &CliInvalidRefreshTokenError{loginCommand: "", err: invalidRefreshTokenError}, + name: "invalid refresh token is adapted: profile not provided for account", + cfg: &Config{}, + arg: accountArg, + err: invalidRefreshTokenError, + want: &CliInvalidRefreshTokenError{loginCommand: "databricks auth login --host https://accounts.cloud.databricks.com --account-id abc", err: invalidRefreshTokenError}, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - databricksCliCredentials := makeDatabricksCliCredentials(tc.pathLooker) arg, err := tc.arg() require.NoError(t, err) got := databricksCliCredentials.ErrorHandler(context.Background(), tc.cfg, arg, tc.err) From 3905c5d915e79d6b8b41c89ca01e7c5d74f0e9c1 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Tue, 7 Jan 2025 12:24:55 +0100 Subject: [PATCH 17/44] fix tests --- config/auth_u2m.go | 5 +- credentials/cache/file.go | 21 +- credentials/cache/file_test.go | 31 +-- credentials/oauth/persistent_auth_test.go | 274 ++++++++++++---------- httpclient/fixtures/fixture.go | 2 + 5 files changed, 185 insertions(+), 148 deletions(-) diff --git a/config/auth_u2m.go b/config/auth_u2m.go index d52d80e37..f265e584b 100644 --- a/config/auth_u2m.go +++ b/config/auth_u2m.go @@ -53,6 +53,9 @@ func (u U2MCredentials) Name() string { // Configure implements CredentialsStrategy. func (u U2MCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) { + if cfg.Host == "" { + return nil, nil + } a := u.Auth if a == nil { var err error @@ -121,7 +124,7 @@ type CliInvalidRefreshTokenError struct { func (e *CliInvalidRefreshTokenError) Error() string { msg := "a new access token could not be retrieved because the refresh token is invalid." - msg += fmt.Sprintf(" To reauthenticate, run `%s`", e.loginCommand) + msg += fmt.Sprintf(" If using the CLI, run `%s` to reauthenticate", e.loginCommand) return msg } diff --git a/credentials/cache/file.go b/credentials/cache/file.go index e5f7af0cf..1b924ac58 100644 --- a/credentials/cache/file.go +++ b/credentials/cache/file.go @@ -50,6 +50,14 @@ type tokenCacheFile struct { Tokens map[string]*oauth2.Token `json:"tokens"` } +type FileTokenCacheOpt func(*FileTokenCache) + +func WithFileLocation(fileLocation string) FileTokenCacheOpt { + return func(c *FileTokenCache) { + c.fileLocation = fileLocation + } +} + // FileTokenCache caches tokens in "~/.databricks/token-cache.json". FileTokenCache // implements the TokenCache interface. type FileTokenCache struct { @@ -59,13 +67,20 @@ type FileTokenCache struct { mu *sync.Mutex } -func NewFileTokenCache() (*FileTokenCache, error) { +func NewFileTokenCache(opts ...FileTokenCacheOpt) (*FileTokenCache, error) { c := &FileTokenCache{ mu: &sync.Mutex{}, } + for _, opt := range opts { + opt(c) + } if err := c.init(); err != nil { return nil, err } + // verify the cache is working + if _, err := c.load(); err != nil { + return nil, fmt.Errorf("load: %w", err) + } return c, nil } @@ -81,7 +96,7 @@ func (c *FileTokenCache) Store(key string, t *oauth2.Token) error { f.Tokens = map[string]*oauth2.Token{} } f.Tokens[key] = t - raw, err := json.MarshalIndent(c, "", " ") + raw, err := json.MarshalIndent(f, "", " ") if err != nil { return fmt.Errorf("marshal: %w", err) } @@ -152,7 +167,7 @@ func (c *FileTokenCache) load() (*tokenCacheFile, error) { return nil, fmt.Errorf("read: %w", err) } f := &tokenCacheFile{} - err = json.Unmarshal(raw, f) + err = json.Unmarshal(raw, &f) if err != nil { return nil, fmt.Errorf("parse: %w", err) } diff --git a/credentials/cache/file_test.go b/credentials/cache/file_test.go index 9760d581f..ef45820b5 100644 --- a/credentials/cache/file_test.go +++ b/credentials/cache/file_test.go @@ -16,11 +16,9 @@ func setup(t *testing.T) string { } func TestStoreAndLookup(t *testing.T) { - c := &FileTokenCache{ - fileLocation: setup(t), - } - assert.NoError(t, c.init()) - err := c.Store("x", &oauth2.Token{ + c, err := NewFileTokenCache(WithFileLocation(setup(t))) + require.NoError(t, err) + err = c.Store("x", &oauth2.Token{ AccessToken: "abc", }) require.NoError(t, err) @@ -30,21 +28,18 @@ func TestStoreAndLookup(t *testing.T) { }) require.NoError(t, err) - l := &FileTokenCache{} - tok, err := l.Lookup("x") + tok, err := c.Lookup("x") require.NoError(t, err) assert.Equal(t, "abc", tok.AccessToken) - _, err = l.Lookup("z") + _, err = c.Lookup("z") assert.Equal(t, ErrNotConfigured, err) } func TestNoCacheFileReturnsErrNotConfigured(t *testing.T) { - l := &FileTokenCache{ - fileLocation: setup(t), - } - assert.NoError(t, l.init()) - _, err := l.Lookup("x") + l, err := NewFileTokenCache(WithFileLocation(setup(t))) + require.NoError(t, err) + _, err = l.Lookup("x") assert.Equal(t, ErrNotConfigured, err) } @@ -55,10 +50,8 @@ func TestLoadCorruptFile(t *testing.T) { err = os.WriteFile(f, []byte("abc"), ownerExecReadWrite) require.NoError(t, err) - l := &FileTokenCache{ - fileLocation: f, - } - assert.EqualError(t, l.init(), "load: parse: invalid character 'a' looking for beginning of value") + _, err = NewFileTokenCache(WithFileLocation(f)) + assert.EqualError(t, err, "load: parse: invalid character 'a' looking for beginning of value") } func TestLoadWrongVersion(t *testing.T) { @@ -68,6 +61,6 @@ func TestLoadWrongVersion(t *testing.T) { err = os.WriteFile(f, []byte(`{"version": 823, "things": []}`), ownerExecReadWrite) require.NoError(t, err) - l := &FileTokenCache{} - assert.EqualError(t, l.init(), "load: needs version 1, got version 823") + _, err = NewFileTokenCache(WithFileLocation(f)) + assert.EqualError(t, err, "load: needs version 1, got version 823") } diff --git a/credentials/oauth/persistent_auth_test.go b/credentials/oauth/persistent_auth_test.go index 36de8805f..5099528f0 100644 --- a/credentials/oauth/persistent_auth_test.go +++ b/credentials/oauth/persistent_auth_test.go @@ -2,7 +2,6 @@ package oauth_test import ( "context" - "crypto/tls" _ "embed" "fmt" "net/http" @@ -10,9 +9,8 @@ import ( "testing" "time" - "github.com/databricks/databricks-sdk-go/client" "github.com/databricks/databricks-sdk-go/credentials/oauth" - "github.com/databricks/databricks-sdk-go/qa" + "github.com/databricks/databricks-sdk-go/httpclient/fixtures" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/oauth2" @@ -58,140 +56,166 @@ func TestLoad(t *testing.T) { assert.Equal(t, "", tok.RefreshToken) } -func useInsecureOAuthHttpClientForTests(ctx context.Context) context.Context { - return context.WithValue(ctx, oauth2.HTTPClient, &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - }, - }) +type MockOAuthClient struct { + Transport http.RoundTripper +} + +func (m MockOAuthClient) GetHttpClient(_ context.Context) *http.Client { + return &http.Client{ + Transport: m.Transport, + } +} + +func (m MockOAuthClient) GetAccountOAuthEndpoints(ctx context.Context, accountHost string, accountId string) (*oauth.OAuthAuthorizationServer, error) { + return &oauth.OAuthAuthorizationServer{ + AuthorizationEndpoint: fmt.Sprintf("%s/oidc/accounts/%s/v1/authorize", accountHost, accountId), + TokenEndpoint: fmt.Sprintf("%s/oidc/accounts/%s/v1/token", accountHost, accountId), + }, nil +} + +func (m MockOAuthClient) GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*oauth.OAuthAuthorizationServer, error) { + return &oauth.OAuthAuthorizationServer{ + AuthorizationEndpoint: fmt.Sprintf("%s/oidc/v1/authorize", workspaceHost), + TokenEndpoint: fmt.Sprintf("%s/oidc/v1/token", workspaceHost), + }, nil } func TestLoadRefresh(t *testing.T) { - qa.HTTPFixtures{ - { - Method: "POST", - Resource: "/oidc/accounts/xyz/v1/token", - Response: `access_token=refreshed&refresh_token=def`, + ctx := context.Background() + expectedKey := "https://accounts.cloud.databricks.com/oidc/accounts/xyz" + cache := &tokenCacheMock{ + lookup: func(key string) (*oauth2.Token, error) { + assert.Equal(t, expectedKey, key) + return &oauth2.Token{ + AccessToken: "expired", + RefreshToken: "cde", + Expiry: time.Now().Add(-1 * time.Minute), + }, nil }, - }.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) { - ctx = useInsecureOAuthHttpClientForTests(ctx) - expectedKey := fmt.Sprintf("%s/oidc/accounts/xyz", c.Config.Host) - cache := &tokenCacheMock{ - lookup: func(key string) (*oauth2.Token, error) { - assert.Equal(t, expectedKey, key) - return &oauth2.Token{ - AccessToken: "expired", - RefreshToken: "cde", - Expiry: time.Now().Add(-1 * time.Minute), - }, nil - }, - store: func(key string, tok *oauth2.Token) error { - assert.Equal(t, expectedKey, key) - assert.Equal(t, "def", tok.RefreshToken) - return nil + store: func(key string, tok *oauth2.Token) error { + assert.Equal(t, expectedKey, key) + assert.Equal(t, "def", tok.RefreshToken) + return nil + }, + } + p, err := oauth.NewPersistentAuth( + context.Background(), + oauth.WithTokenCache(cache), + oauth.WithOAuthClient(&MockOAuthClient{ + Transport: fixtures.SliceTransport{ + { + Method: "POST", + Resource: "/oidc/accounts/xyz/v1/token", + Response: `access_token=refreshed&refresh_token=def`, + ResponseHeaders: map[string][]string{ + "Content-Type": {"application/x-www-form-urlencoded"}, + }, + }, }, - } - p, err := oauth.NewPersistentAuth(context.Background(), oauth.WithTokenCache(cache)) - require.NoError(t, err) - defer p.Close() - arg, err := oauth.NewBasicAccountOAuthArgument(c.Config.Host, "xyz") - assert.NoError(t, err) - tok, err := p.Load(ctx, arg) - assert.NoError(t, err) - assert.Equal(t, "refreshed", tok.AccessToken) - assert.Equal(t, "", tok.RefreshToken) - }) + }), + ) + require.NoError(t, err) + defer p.Close() + arg, err := oauth.NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "xyz") + assert.NoError(t, err) + tok, err := p.Load(ctx, arg) + assert.NoError(t, err) + assert.Equal(t, "refreshed", tok.AccessToken) + assert.Equal(t, "", tok.RefreshToken) } func TestChallenge(t *testing.T) { - qa.HTTPFixtures{ - { - Method: "POST", - Resource: "/oidc/accounts/xyz/v1/token", - Response: `access_token=__THAT__&refresh_token=__SOMETHING__`, - }, - }.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) { - ctx = useInsecureOAuthHttpClientForTests(ctx) - expectedKey := fmt.Sprintf("%s/oidc/accounts/xyz", c.Config.Host) - - browserOpened := make(chan string) - browser := func(redirect string) error { - u, err := url.ParseRequestURI(redirect) - if err != nil { - return err - } - assert.Equal(t, "/oidc/accounts/xyz/v1/authorize", u.Path) - // for now we're ignoring asserting the fields of the redirect - query := u.Query() - browserOpened <- query.Get("state") - return nil + ctx := context.Background() + expectedKey := "https://accounts.cloud.databricks.com/oidc/accounts/xyz" + + browserOpened := make(chan string) + browser := func(redirect string) error { + u, err := url.ParseRequestURI(redirect) + if err != nil { + return err } - cache := &tokenCacheMock{ - store: func(key string, tok *oauth2.Token) error { - assert.Equal(t, expectedKey, key) - assert.Equal(t, "__SOMETHING__", tok.RefreshToken) - return nil + assert.Equal(t, "/oidc/accounts/xyz/v1/authorize", u.Path) + // for now we're ignoring asserting the fields of the redirect + query := u.Query() + browserOpened <- query.Get("state") + return nil + } + cache := &tokenCacheMock{ + store: func(key string, tok *oauth2.Token) error { + assert.Equal(t, expectedKey, key) + assert.Equal(t, "__SOMETHING__", tok.RefreshToken) + return nil + }, + } + p, err := oauth.NewPersistentAuth( + context.Background(), + oauth.WithTokenCache(cache), + oauth.WithBrowser(browser), + oauth.WithOAuthClient(&MockOAuthClient{ + Transport: fixtures.SliceTransport{ + { + Method: "POST", + Resource: "/oidc/accounts/xyz/v1/token", + Response: `access_token=__THAT__&refresh_token=__SOMETHING__`, + ResponseHeaders: map[string][]string{ + "Content-Type": {"application/x-www-form-urlencoded"}, + }, + }, }, - } - p, err := oauth.NewPersistentAuth(context.Background(), oauth.WithTokenCache(cache), oauth.WithBrowser(browser)) - require.NoError(t, err) - defer p.Close() - arg, err := oauth.NewBasicAccountOAuthArgument(c.Config.Host, "xyz") - assert.NoError(t, err) - - errc := make(chan error) - go func() { - errc <- p.Challenge(ctx, arg) - }() - - state := <-browserOpened - resp, err := http.Get(fmt.Sprintf("http://localhost:8020?code=__THIS__&state=%s", state)) - require.NoError(t, err) - defer resp.Body.Close() - assert.Equal(t, 200, resp.StatusCode) - - err = <-errc - assert.NoError(t, err) - }) + }), + ) + require.NoError(t, err) + defer p.Close() + arg, err := oauth.NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "xyz") + assert.NoError(t, err) + + errc := make(chan error) + go func() { + errc <- p.Challenge(ctx, arg) + }() + + state := <-browserOpened + resp, err := http.Get(fmt.Sprintf("http://localhost:8020?code=__THIS__&state=%s", state)) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, 200, resp.StatusCode) + + err = <-errc + assert.NoError(t, err) } func TestChallengeFailed(t *testing.T) { - qa.HTTPFixtures{}.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) { - ctx = useInsecureOAuthHttpClientForTests(ctx) - - browserOpened := make(chan string) - browser := func(redirect string) error { - u, err := url.ParseRequestURI(redirect) - if err != nil { - return err - } - assert.Equal(t, "/oidc/accounts/xyz/v1/authorize", u.Path) - // for now we're ignoring asserting the fields of the redirect - query := u.Query() - browserOpened <- query.Get("state") - return nil + ctx := context.Background() + browserOpened := make(chan string) + browser := func(redirect string) error { + u, err := url.ParseRequestURI(redirect) + if err != nil { + return err } - p, err := oauth.NewPersistentAuth(context.Background(), oauth.WithBrowser(browser)) - require.NoError(t, err) - defer p.Close() - arg, err := oauth.NewBasicAccountOAuthArgument(c.Config.Host, "xyz") - assert.NoError(t, err) - - errc := make(chan error) - go func() { - errc <- p.Challenge(ctx, arg) - }() - - <-browserOpened - resp, err := http.Get( - "http://localhost:8020?error=access_denied&error_description=Policy%20evaluation%20failed%20for%20this%20request") - require.NoError(t, err) - defer resp.Body.Close() - assert.Equal(t, 400, resp.StatusCode) - - err = <-errc - assert.EqualError(t, err, "authorize: access_denied: Policy evaluation failed for this request") - }) + assert.Equal(t, "/oidc/accounts/xyz/v1/authorize", u.Path) + // for now we're ignoring asserting the fields of the redirect + query := u.Query() + browserOpened <- query.Get("state") + return nil + } + p, err := oauth.NewPersistentAuth(context.Background(), oauth.WithBrowser(browser)) + require.NoError(t, err) + defer p.Close() + arg, err := oauth.NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "xyz") + assert.NoError(t, err) + + errc := make(chan error) + go func() { + errc <- p.Challenge(ctx, arg) + }() + + <-browserOpened + resp, err := http.Get( + "http://localhost:8020?error=access_denied&error_description=Policy%20evaluation%20failed%20for%20this%20request") + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, 400, resp.StatusCode) + + err = <-errc + assert.EqualError(t, err, "authorize: access_denied: Policy evaluation failed for this request") } diff --git a/httpclient/fixtures/fixture.go b/httpclient/fixtures/fixture.go index 32b00fcbd..9a85c3c84 100644 --- a/httpclient/fixtures/fixture.go +++ b/httpclient/fixtures/fixture.go @@ -24,6 +24,7 @@ type HTTPFixture struct { Response any Status int + ResponseHeaders map[string][]string ExpectedRequest any ExpectedHeaders map[string]string PassFile string @@ -106,6 +107,7 @@ func (f HTTPFixture) replyWith(req *http.Request, body string) (*http.Response, StatusCode: f.Status, Status: http.StatusText(f.Status), Body: io.NopCloser(strings.NewReader(body)), + Header: f.ResponseHeaders, }, nil } From f53fb8492c4279043e49dd66a14e5a14708199f3 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Tue, 7 Jan 2025 13:54:41 +0100 Subject: [PATCH 18/44] comments and test fixes --- config/auth_u2m.go | 7 +- credentials/cache/file.go | 13 +-- credentials/oauth/account_oauth_argument.go | 55 +++++++++++ credentials/oauth/callback.go | 49 +++++++--- credentials/oauth/lock.go | 5 +- credentials/oauth/oauth_argument.go | 93 ------------------- credentials/oauth/oidc.go | 15 ++- credentials/oauth/persistent_auth.go | 44 ++++++--- credentials/oauth/persistent_auth_test.go | 17 +++- credentials/oauth/workspace_oauth_argument.go | 52 +++++++++++ 10 files changed, 217 insertions(+), 133 deletions(-) create mode 100644 credentials/oauth/account_oauth_argument.go create mode 100644 credentials/oauth/workspace_oauth_argument.go diff --git a/config/auth_u2m.go b/config/auth_u2m.go index f265e584b..ff2c838fb 100644 --- a/config/auth_u2m.go +++ b/config/auth_u2m.go @@ -123,15 +123,16 @@ type CliInvalidRefreshTokenError struct { } func (e *CliInvalidRefreshTokenError) Error() string { - msg := "a new access token could not be retrieved because the refresh token is invalid." - msg += fmt.Sprintf(" If using the CLI, run `%s` to reauthenticate", e.loginCommand) - return msg + return fmt.Sprintf("a new access token could not be retrieved because the refresh token is invalid. If using the CLI, run `%s` to reauthenticate", e.loginCommand) } func (e *CliInvalidRefreshTokenError) Unwrap() error { return e.err } +// buildLoginCommand returns the `databricks auth login` command that the user +// can run to reauthenticate. The command is prepopulated with the profile, host +// and/or account ID. func buildLoginCommand(ctx context.Context, profile string, arg oauth.OAuthArgument) string { cmd := []string{ "databricks", diff --git a/credentials/cache/file.go b/credentials/cache/file.go index 1b924ac58..ee0200e62 100644 --- a/credentials/cache/file.go +++ b/credentials/cache/file.go @@ -31,16 +31,11 @@ const ( // "": { // "access_token": "", // "token_type": "", - // "refresh_token": "", // "expiry": "" // } // } // } - // - // The format of "" depends on whether the token is account- or - // workspace-scoped: - // - Account-scoped: "https:///oidc/accounts/" - // - Workspace-scoped: "https://" tokenCacheVersion = 1 ) @@ -67,6 +62,12 @@ type FileTokenCache struct { mu *sync.Mutex } +// NewFileTokenCache creates a new FileTokenCache. By default, the cache is +// stored in "~/.databricks/token-cache.json". The cache file is created if it +// does not already exist. The cache file is created with owner permissions +// 0600 and the directory is created with owner permissions 0700. If the cache +// file is corrupt or if its version does not match tokenCacheVersion, an error +// is returned. func NewFileTokenCache(opts ...FileTokenCacheOpt) (*FileTokenCache, error) { c := &FileTokenCache{ mu: &sync.Mutex{}, diff --git a/credentials/oauth/account_oauth_argument.go b/credentials/oauth/account_oauth_argument.go new file mode 100644 index 000000000..a0661182e --- /dev/null +++ b/credentials/oauth/account_oauth_argument.go @@ -0,0 +1,55 @@ +package oauth + +import ( + "context" + "fmt" + "strings" +) + +// AccountOAuthArgument is an interface that provides the necessary information +// to authenticate using OAuth to a specific account. +type AccountOAuthArgument interface { + OAuthArgument + + // GetAccountHost returns the host of the account to authenticate to. + GetAccountHost(ctx context.Context) string + + // GetAccountId returns the account ID of the account to authenticate to. + GetAccountId(ctx context.Context) string +} + +// BasicAccountOAuthArgument is a basic implementation of the AccountOAuthArgument +// interface that links each account with exactly one OAuth token. +type BasicAccountOAuthArgument struct { + accountHost string + accountID string +} + +var _ AccountOAuthArgument = BasicAccountOAuthArgument{} + +// NewBasicAccountOAuthArgument creates a new BasicAccountOAuthArgument. +func NewBasicAccountOAuthArgument(accountsHost, accountID string) (BasicAccountOAuthArgument, error) { + if !strings.HasPrefix(accountsHost, "https://") { + return BasicAccountOAuthArgument{}, fmt.Errorf("accountsHost must start with 'https://': %s", accountsHost) + } + if strings.HasSuffix(accountsHost, "/") { + return BasicAccountOAuthArgument{}, fmt.Errorf("accountsHost must not have a trailing slash: %s", accountsHost) + } + return BasicAccountOAuthArgument{accountHost: accountsHost, accountID: accountID}, nil +} + +// GetAccountHost returns the host of the account to authenticate to. +func (a BasicAccountOAuthArgument) GetAccountHost(ctx context.Context) string { + return a.accountHost +} + +// GetAccountId returns the account ID of the account to authenticate to. +func (a BasicAccountOAuthArgument) GetAccountId(ctx context.Context) string { + return a.accountID +} + +// GetCacheKey returns a unique key for caching the OAuth token for the account. +// The key is in the format "/oidc/accounts/". +func (a BasicAccountOAuthArgument) GetCacheKey(ctx context.Context) string { + return fmt.Sprintf("%s/oidc/accounts/%s", a.accountHost, a.accountID) +} diff --git a/credentials/oauth/callback.go b/credentials/oauth/callback.go index cb7861be7..258a10a82 100644 --- a/credentials/oauth/callback.go +++ b/credentials/oauth/callback.go @@ -5,7 +5,6 @@ import ( _ "embed" "fmt" "html/template" - "net" "net/http" "strings" @@ -24,18 +23,40 @@ type oauthResult struct { Host string } +// callbackServer is a server that listens for the redirect from the Databricks +// identity provider. It renders a page.html template that shows the result of +// the authentication attempt. type callbackServer struct { - ln net.Listener - srv http.Server - ctx context.Context - a *PersistentAuth - arg OAuthArgument + // ctx is the context used when waiting for the redirect from the identity + // provider. This is needed because the Handler() method from the oauth2 + // library does not accept a context. + ctx context.Context + + // srv is the server that listens for the redirect from the identity provider. + srv http.Server + + // browser is a function that opens a browser to the given URL. + browser func(string) error + + // arg is the OAuth argument used to authenticate. + arg OAuthArgument + + // renderErrCh is a channel that receives an error if there is an error + // rendering the page.html template. renderErrCh chan error - feedbackCh chan oauthResult - tmpl *template.Template + + // feedbackCh is a channel that receives the result of the authentication + // attempt. + feedbackCh chan oauthResult + + // tmpl is the template used to render the response page after the user is + // redirected back to the callback server. + tmpl *template.Template } -func (a *PersistentAuth) newCallback(ctx context.Context, arg OAuthArgument) (*callbackServer, error) { +// newCallbackServer creates a new callback server that listens for the redirect +// from the Databricks identity provider. +func (a *PersistentAuth) newCallbackServer(ctx context.Context, arg OAuthArgument) (*callbackServer, error) { tmpl, err := template.New("page").Funcs(template.FuncMap{ "title": func(in string) string { title := cases.Title(language.English) @@ -50,22 +71,22 @@ func (a *PersistentAuth) newCallback(ctx context.Context, arg OAuthArgument) (*c renderErrCh: make(chan error), tmpl: tmpl, ctx: ctx, - ln: a.ln, - a: a, + browser: a.browser, arg: arg, } cb.srv.Handler = cb go func() { - _ = cb.srv.Serve(cb.ln) + _ = cb.srv.Serve(a.ln) }() return cb, nil } +// Close closes the callback server. func (cb *callbackServer) Close() error { return cb.srv.Close() } -// ServeHTTP renders page.html template +// ServeHTTP renders the page.html template. func (cb *callbackServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { res := oauthResult{ Error: r.FormValue("error"), @@ -99,7 +120,7 @@ func (cb *callbackServer) getHost(ctx context.Context) string { // Handler opens up a browser waits for redirect to come back from the identity provider func (cb *callbackServer) Handler(authCodeURL string) (string, string, error) { - err := cb.a.browser(authCodeURL) + err := cb.browser(authCodeURL) if err != nil { fmt.Printf("Please open %s in the browser to continue authentication", authCodeURL) } diff --git a/credentials/oauth/lock.go b/credentials/oauth/lock.go index f2d8efebc..93bc65cb2 100644 --- a/credentials/oauth/lock.go +++ b/credentials/oauth/lock.go @@ -14,6 +14,8 @@ type lockerAdaptor struct { fileMutex *filemutex.FileMutex } +var _ sync.Locker = (*lockerAdaptor)(nil) + // Lock implements sync.Locker. func (l *lockerAdaptor) Lock() { err := l.fileMutex.Lock() @@ -30,7 +32,8 @@ func (l *lockerAdaptor) Unlock() { } } -func newLocker(path string) (sync.Locker, error) { +// newLocker creates a new sync.Locker that uses a file-based mutex. +func newLocker(path string) (*lockerAdaptor, error) { dirName := filepath.Dir(path) if _, err := os.Stat(dirName); err != nil && os.IsNotExist(err) { os.MkdirAll(dirName, 0750) diff --git a/credentials/oauth/oauth_argument.go b/credentials/oauth/oauth_argument.go index fa40c4232..97777876f 100644 --- a/credentials/oauth/oauth_argument.go +++ b/credentials/oauth/oauth_argument.go @@ -2,8 +2,6 @@ package oauth import ( "context" - "fmt" - "strings" ) // OAuthArgument is an interface that provides the necessary information to @@ -15,94 +13,3 @@ type OAuthArgument interface { // to store and retrieve the token from the token cache. GetCacheKey(ctx context.Context) string } - -// WorkspaceOAuthArgument is an interface that provides the necessary information -// to authenticate using OAuth to a specific workspace. -type WorkspaceOAuthArgument interface { - OAuthArgument - - // GetWorkspaceHost returns the host of the workspace to authenticate to. - GetWorkspaceHost(ctx context.Context) string -} - -// BasicWorkspaceOAuthArgument is a basic implementation of the WorkspaceOAuthArgument -// interface that links each host with exactly one OAuth token. -type BasicWorkspaceOAuthArgument struct { - // host is the host of the workspace to authenticate to. This must start - // with "https://" and must not have a trailing slash. - host string -} - -// NewBasicWorkspaceOAuthArgument creates a new BasicWorkspaceOAuthArgument. -func NewBasicWorkspaceOAuthArgument(host string) (BasicWorkspaceOAuthArgument, error) { - if !strings.HasPrefix(host, "https://") { - return BasicWorkspaceOAuthArgument{}, fmt.Errorf("host must start with 'https://': %s", host) - } - if strings.HasSuffix(host, "/") { - return BasicWorkspaceOAuthArgument{}, fmt.Errorf("host must not have a trailing slash: %s", host) - } - return BasicWorkspaceOAuthArgument{host: host}, nil -} - -// GetWorkspaceHost returns the host of the workspace to authenticate to. -func (a BasicWorkspaceOAuthArgument) GetWorkspaceHost(ctx context.Context) string { - return a.host -} - -// GetCacheKey returns a unique key for caching the OAuth token for the workspace. -func (a BasicWorkspaceOAuthArgument) GetCacheKey(ctx context.Context) string { - a.host = strings.TrimSuffix(a.host, "/") - if !strings.HasPrefix(a.host, "http") { - a.host = fmt.Sprintf("https://%s", a.host) - } - return a.host -} - -var _ WorkspaceOAuthArgument = BasicWorkspaceOAuthArgument{} - -// AccountOAuthArgument is an interface that provides the necessary information -// to authenticate using OAuth to a specific account. -type AccountOAuthArgument interface { - OAuthArgument - - // GetAccountHost returns the host of the account to authenticate to. - GetAccountHost(ctx context.Context) string - - // GetAccountId returns the account ID of the account to authenticate to. - GetAccountId(ctx context.Context) string -} - -// BasicAccountOAuthArgument is a basic implementation of the AccountOAuthArgument -// interface that links each account with exactly one OAuth token. -type BasicAccountOAuthArgument struct { - accountHost string - accountID string -} - -var _ AccountOAuthArgument = BasicAccountOAuthArgument{} - -// NewBasicAccountOAuthArgument creates a new BasicAccountOAuthArgument. -func NewBasicAccountOAuthArgument(accountsHost, accountID string) (BasicAccountOAuthArgument, error) { - if !strings.HasPrefix(accountsHost, "https://") { - return BasicAccountOAuthArgument{}, fmt.Errorf("accountsHost must start with 'https://': %s", accountsHost) - } - if strings.HasSuffix(accountsHost, "/") { - return BasicAccountOAuthArgument{}, fmt.Errorf("accountsHost must not have a trailing slash: %s", accountsHost) - } - return BasicAccountOAuthArgument{accountHost: accountsHost, accountID: accountID}, nil -} - -// GetAccountHost returns the host of the account to authenticate to. -func (a BasicAccountOAuthArgument) GetAccountHost(ctx context.Context) string { - return a.accountHost -} - -// GetAccountId returns the account ID of the account to authenticate to. -func (a BasicAccountOAuthArgument) GetAccountId(ctx context.Context) string { - return a.accountID -} - -// GetCacheKey returns a unique key for caching the OAuth token for the account. -func (a BasicAccountOAuthArgument) GetCacheKey(ctx context.Context) string { - return fmt.Sprintf("%s/oidc/accounts/%s", a.accountHost, a.accountID) -} diff --git a/credentials/oauth/oidc.go b/credentials/oauth/oidc.go index 3b4c77fe2..67012d2ab 100644 --- a/credentials/oauth/oidc.go +++ b/credentials/oauth/oidc.go @@ -10,6 +10,7 @@ import ( var ErrOAuthNotSupported = errors.New("databricks OAuth is not supported for this host") +// GetAccountOAuthEndpoints returns the OAuth endpoints for the given account. func GetAccountOAuthEndpoints(ctx context.Context, accountsHost, accountId string) (*OAuthAuthorizationServer, error) { return &OAuthAuthorizationServer{ AuthorizationEndpoint: fmt.Sprintf("%s/oidc/accounts/%s/v1/authorize", accountsHost, accountId), @@ -17,6 +18,9 @@ func GetAccountOAuthEndpoints(ctx context.Context, accountsHost, accountId strin }, nil } +// GetWorkspaceOAuthEndpoints returns the OAuth endpoints for the given workspace, +// It queries the OIDC discovery endpoint to get the OAuth endpoints using the +// provided ApiClient. func GetWorkspaceOAuthEndpoints(ctx context.Context, c *httpclient.ApiClient, host string) (*OAuthAuthorizationServer, error) { oidc := fmt.Sprintf("%s/oidc/.well-known/oauth-authorization-server", host) var oauthEndpoints OAuthAuthorizationServer @@ -26,7 +30,14 @@ func GetWorkspaceOAuthEndpoints(ctx context.Context, c *httpclient.ApiClient, ho return &oauthEndpoints, nil } +// OAuthAuthorizationServer contains the OAuth endpoints for a Databricks account +// or workspace. type OAuthAuthorizationServer struct { - AuthorizationEndpoint string `json:"authorization_endpoint"` // ../v1/authorize - TokenEndpoint string `json:"token_endpoint"` // ../v1/token + // AuthorizationEndpoint is the URL to redirect users to for authorization. + // It typically ends with /v1/authroize. + AuthorizationEndpoint string `json:"authorization_endpoint"` + + // TokenEndpoint is the URL to exchange an authorization code for an access token. + // It typically ends with /v1/token. + TokenEndpoint string `json:"token_endpoint"` } diff --git a/credentials/oauth/persistent_auth.go b/credentials/oauth/persistent_auth.go index c0ab99a9d..59be2d967 100644 --- a/credentials/oauth/persistent_auth.go +++ b/credentials/oauth/persistent_auth.go @@ -17,6 +17,7 @@ import ( "github.com/databricks/databricks-sdk-go/credentials/cache" "github.com/databricks/databricks-sdk-go/httpclient" + "github.com/databricks/databricks-sdk-go/logger" "github.com/databricks/databricks-sdk-go/retries" "github.com/pkg/browser" "golang.org/x/oauth2" @@ -151,17 +152,22 @@ func NewPersistentAuth(ctx context.Context, opts ...PersistentAuthOption) (*Pers return p, nil } +// tokenErrorResponse is the response from the OAuth2 token endpoint when an +// error occurs. type tokenErrorResponse struct { Error string `json:"error"` ErrorDescription string `json:"error_description"` } +// Load loads the OAuth2 token for the given OAuthArgument from the cache. If +// the token is expired, it is refreshed using the refresh token. func (a *PersistentAuth) Load(ctx context.Context, arg OAuthArgument) (t *oauth2.Token, err error) { a.locker.Lock() defer a.locker.Unlock() - a.validateArg(arg) - + if err := a.validateArg(arg); err != nil { + return nil, err + } // TODO: remove this listener after several releases. err = a.startListener(ctx) if err != nil { @@ -186,6 +192,8 @@ func (a *PersistentAuth) Load(ctx context.Context, arg OAuthArgument) (t *oauth2 return t, nil } +// refresh refreshes the token for the given OAuthArgument, storing the new +// token in the cache. func (a *PersistentAuth) refresh(ctx context.Context, arg OAuthArgument, oldToken *oauth2.Token) (*oauth2.Token, error) { // OAuth2 config is invoked only for expired tokens to speed up // the happy path in the token retrieval @@ -222,14 +230,22 @@ func (a *PersistentAuth) refresh(ctx context.Context, arg OAuthArgument, oldToke return t, nil } -func (a *PersistentAuth) Challenge(ctx context.Context, arg OAuthArgument) error { +// Challenge initiates the OAuth2 login flow for the given OAuthArgument. The +// OAuth2 flow is started by opening the browser to the OAuth2 authorization +// URL. The user is redirected to the callback server on appRedirectAddr. The +// callback server listens for the redirect from the identity provider and +// exchanges the authorization code for an access token. It returns the OAuth2 +// token on success. +func (a *PersistentAuth) Challenge(ctx context.Context, arg OAuthArgument) (*oauth2.Token, error) { a.locker.Lock() defer a.locker.Unlock() - a.validateArg(arg) + if err := a.validateArg(arg); err != nil { + return nil, err + } err := a.startListener(ctx) if err != nil { - return fmt.Errorf("starting listener: %w", err) + return nil, fmt.Errorf("starting listener: %w", err) } // The listener will be closed by the callback server automatically, but if // the callback server is not created, we need to close the listener manually. @@ -237,11 +253,11 @@ func (a *PersistentAuth) Challenge(ctx context.Context, arg OAuthArgument) error cfg, err := a.oauth2Config(ctx, arg) if err != nil { - return fmt.Errorf("fetching oauth config: %w", err) + return nil, fmt.Errorf("fetching oauth config: %w", err) } - cb, err := a.newCallback(ctx, arg) + cb, err := a.newCallbackServer(ctx, arg) if err != nil { - return fmt.Errorf("callback server: %w", err) + return nil, fmt.Errorf("callback server: %w", err) } defer cb.Close() @@ -251,22 +267,25 @@ func (a *PersistentAuth) Challenge(ctx context.Context, arg OAuthArgument) error ts := authhandler.TokenSourceWithPKCE(ctx, cfg, state, cb.Handler, pkce) t, err := ts.Token() if err != nil { - return fmt.Errorf("authorize: %w", err) + return nil, fmt.Errorf("authorize: %w", err) } // cache token identified by host (and possibly the account id) err = a.cache.Store(arg.GetCacheKey(ctx), t) if err != nil { - return fmt.Errorf("store: %w", err) + return nil, fmt.Errorf("store: %w", err) } - return nil + return t, nil } +// startListener starts a listener on appRedirectAddr, retrying if the address +// is already in use. func (a *PersistentAuth) startListener(ctx context.Context) error { listener, err := retries.Poll(ctx, listenerTimeout, func() (*net.Listener, *retries.Err) { var lc net.ListenConfig l, err := lc.Listen(ctx, "tcp", appRedirectAddr) if err != nil { + logger.Debugf(ctx, "failed to listen on %s: %v, retrying", appRedirectAddr, err) return nil, retries.Continue(err) } return &l, nil @@ -285,6 +304,8 @@ func (a *PersistentAuth) Close() error { return a.ln.Close() } +// validateArg ensures that the OAuthArgument is either a WorkspaceOAuthArgument +// or an AccountOAuthArgument. func (a *PersistentAuth) validateArg(arg OAuthArgument) error { _, isWorkspaceArg := arg.(WorkspaceOAuthArgument) _, isAccountArg := arg.(AccountOAuthArgument) @@ -294,6 +315,7 @@ func (a *PersistentAuth) validateArg(arg OAuthArgument) error { return nil } +// oauth2Config returns the OAuth2 configuration for the given OAuthArgument. func (a *PersistentAuth) oauth2Config(ctx context.Context, arg OAuthArgument) (*oauth2.Config, error) { scopes := []string{ "offline_access", // ensures OAuth token includes refresh token diff --git a/credentials/oauth/persistent_auth_test.go b/credentials/oauth/persistent_auth_test.go index 5099528f0..5410aab30 100644 --- a/credentials/oauth/persistent_auth_test.go +++ b/credentials/oauth/persistent_auth_test.go @@ -2,7 +2,6 @@ package oauth_test import ( "context" - _ "embed" "fmt" "net/http" "net/url" @@ -169,9 +168,14 @@ func TestChallenge(t *testing.T) { arg, err := oauth.NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "xyz") assert.NoError(t, err) + tokenc := make(chan *oauth2.Token) errc := make(chan error) go func() { - errc <- p.Challenge(ctx, arg) + token, err := p.Challenge(ctx, arg) + errc <- err + close(errc) + tokenc <- token + close(tokenc) }() state := <-browserOpened @@ -182,6 +186,7 @@ func TestChallenge(t *testing.T) { err = <-errc assert.NoError(t, err) + assert.Equal(t, "__THAT__", (<-tokenc).AccessToken) } func TestChallengeFailed(t *testing.T) { @@ -204,9 +209,14 @@ func TestChallengeFailed(t *testing.T) { arg, err := oauth.NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "xyz") assert.NoError(t, err) + tokenc := make(chan *oauth2.Token) errc := make(chan error) go func() { - errc <- p.Challenge(ctx, arg) + token, err := p.Challenge(ctx, arg) + errc <- err + close(errc) + tokenc <- token + close(tokenc) }() <-browserOpened @@ -218,4 +228,5 @@ func TestChallengeFailed(t *testing.T) { err = <-errc assert.EqualError(t, err, "authorize: access_denied: Policy evaluation failed for this request") + assert.Nil(t, <-tokenc) } diff --git a/credentials/oauth/workspace_oauth_argument.go b/credentials/oauth/workspace_oauth_argument.go new file mode 100644 index 000000000..f2552910e --- /dev/null +++ b/credentials/oauth/workspace_oauth_argument.go @@ -0,0 +1,52 @@ +package oauth + +import ( + "context" + "fmt" + "strings" +) + +// WorkspaceOAuthArgument is an interface that provides the necessary information +// to authenticate using OAuth to a specific workspace. +type WorkspaceOAuthArgument interface { + OAuthArgument + + // GetWorkspaceHost returns the host of the workspace to authenticate to. + GetWorkspaceHost(ctx context.Context) string +} + +// BasicWorkspaceOAuthArgument is a basic implementation of the WorkspaceOAuthArgument +// interface that links each host with exactly one OAuth token. +type BasicWorkspaceOAuthArgument struct { + // host is the host of the workspace to authenticate to. This must start + // with "https://" and must not have a trailing slash. + host string +} + +// NewBasicWorkspaceOAuthArgument creates a new BasicWorkspaceOAuthArgument. +func NewBasicWorkspaceOAuthArgument(host string) (BasicWorkspaceOAuthArgument, error) { + if !strings.HasPrefix(host, "https://") { + return BasicWorkspaceOAuthArgument{}, fmt.Errorf("host must start with 'https://': %s", host) + } + if strings.HasSuffix(host, "/") { + return BasicWorkspaceOAuthArgument{}, fmt.Errorf("host must not have a trailing slash: %s", host) + } + return BasicWorkspaceOAuthArgument{host: host}, nil +} + +// GetWorkspaceHost returns the host of the workspace to authenticate to. +func (a BasicWorkspaceOAuthArgument) GetWorkspaceHost(ctx context.Context) string { + return a.host +} + +// GetCacheKey returns a unique key for caching the OAuth token for the workspace. +// The key is in the format "". +func (a BasicWorkspaceOAuthArgument) GetCacheKey(ctx context.Context) string { + a.host = strings.TrimSuffix(a.host, "/") + if !strings.HasPrefix(a.host, "http") { + a.host = fmt.Sprintf("https://%s", a.host) + } + return a.host +} + +var _ WorkspaceOAuthArgument = BasicWorkspaceOAuthArgument{} From ef8c3f356afc8350c395c95dae8b737c4829b2ae Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Tue, 7 Jan 2025 15:42:33 +0100 Subject: [PATCH 19/44] fix --- config/auth_u2m.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config/auth_u2m.go b/config/auth_u2m.go index ff2c838fb..5050a01a6 100644 --- a/config/auth_u2m.go +++ b/config/auth_u2m.go @@ -46,9 +46,9 @@ type U2MCredentials struct { // Name implements CredentialsStrategy. func (u U2MCredentials) Name() string { if u.name != "" { - return "oauth-u2m" + return u.name } - return u.name + return "oauth-u2m" } // Configure implements CredentialsStrategy. From 0ff447d7489e33ce18b3817e78d6b41a20c87b9a Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Tue, 7 Jan 2025 15:56:42 +0100 Subject: [PATCH 20/44] fix --- config/auth_u2m.go | 3 ++- config/auth_u2m_test.go | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/config/auth_u2m.go b/config/auth_u2m.go index 5050a01a6..d6887df68 100644 --- a/config/auth_u2m.go +++ b/config/auth_u2m.go @@ -166,7 +166,8 @@ var databricksCliCredentials = U2MCredentials{ // return a special error message for invalid refresh tokens. To help // users easily reauthenticate, include a command that the user can // run, prepopulating the profile, host and/or account ID. - if _, ok := err.(*oauth.InvalidRefreshTokenError); ok { + target := &oauth.InvalidRefreshTokenError{} + if errors.As(err, &target) { return &CliInvalidRefreshTokenError{ loginCommand: buildLoginCommand(ctx, cfg.Profile, arg), err: err, diff --git a/config/auth_u2m_test.go b/config/auth_u2m_test.go index 926775c38..b0b84e9a5 100644 --- a/config/auth_u2m_test.go +++ b/config/auth_u2m_test.go @@ -3,6 +3,7 @@ package config import ( "context" "errors" + "fmt" "net/http" "testing" "time" @@ -126,7 +127,7 @@ func TestU2MCredentials(t *testing.T) { } func TestDatabricksCli_ErrorHandler(t *testing.T) { - invalidRefreshTokenError := &oauth.InvalidRefreshTokenError{} + invalidRefreshTokenError := fmt.Errorf("refresh: %w", &oauth.InvalidRefreshTokenError{}) workspaceArg := func() (oauth.OAuthArgument, error) { return oauth.NewBasicWorkspaceOAuthArgument("https://myworkspace.cloud.databricks.com") } From e9f3732ef745f82e4d196bb60e085cbaa0b73c3b Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Wed, 8 Jan 2025 12:01:22 +0100 Subject: [PATCH 21/44] better error message --- credentials/oauth/persistent_auth.go | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/credentials/oauth/persistent_auth.go b/credentials/oauth/persistent_auth.go index 59be2d967..1c5b66837 100644 --- a/credentials/oauth/persistent_auth.go +++ b/credentials/oauth/persistent_auth.go @@ -5,7 +5,6 @@ import ( "crypto/rand" "crypto/sha256" "encoding/base64" - "encoding/json" "errors" "fmt" "net" @@ -206,26 +205,20 @@ func (a *PersistentAuth) refresh(ctx context.Context, arg OAuthArgument, oldToke // eagerly refresh token t, err := cfg.TokenSource(ctx, oldToken).Token() if err != nil { - var httpErr *httpclient.HttpError + var httpErr *oauth2.RetrieveError if errors.As(err, &httpErr) { - resp := &tokenErrorResponse{} - err = json.Unmarshal([]byte(httpErr.Message), resp) - if err != nil { - return nil, fmt.Errorf("unexpected parsing token response: %w", err) - } // Invalid refresh tokens get their own error type so they can be // better presented to users. - if resp.ErrorDescription == "Refresh token is invalid" { + if httpErr.ErrorDescription == "Refresh token is invalid" { return nil, &InvalidRefreshTokenError{err} - } else { - return nil, fmt.Errorf("unexpected error refreshing token: %s", resp.ErrorDescription) } + return nil, fmt.Errorf("%s (error code: %s)", httpErr.ErrorDescription, httpErr.ErrorCode) } - return nil, fmt.Errorf("token refresh: %w", err) + return nil, err } err = a.cache.Store(arg.GetCacheKey(ctx), t) if err != nil { - return nil, fmt.Errorf("cache refresh: %w", err) + return nil, fmt.Errorf("cache update: %w", err) } return t, nil } From 19a34ce0b10eba0e995a38b54c99ea3234ae4676 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Wed, 8 Jan 2025 12:04:02 +0100 Subject: [PATCH 22/44] fix test --- config/auth_u2m_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/auth_u2m_test.go b/config/auth_u2m_test.go index b0b84e9a5..cad1d079d 100644 --- a/config/auth_u2m_test.go +++ b/config/auth_u2m_test.go @@ -97,7 +97,7 @@ func TestU2MCredentials(t *testing.T) { }), ) }, - expectErr: "oidc: token refresh: token refresh: oauth2: \"invalid_refresh_token\" \"Refresh token is invalid\"", + expectErr: "oidc: token refresh: oauth2: \"invalid_refresh_token\" \"Refresh token is invalid\"", }, } From c7b515504c9c5f6b450711b1449c7894511af0e3 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Wed, 8 Jan 2025 12:11:54 +0100 Subject: [PATCH 23/44] fix --- credentials/oauth/persistent_auth.go | 7 ------- 1 file changed, 7 deletions(-) diff --git a/credentials/oauth/persistent_auth.go b/credentials/oauth/persistent_auth.go index 1c5b66837..545f61fcb 100644 --- a/credentials/oauth/persistent_auth.go +++ b/credentials/oauth/persistent_auth.go @@ -151,13 +151,6 @@ func NewPersistentAuth(ctx context.Context, opts ...PersistentAuthOption) (*Pers return p, nil } -// tokenErrorResponse is the response from the OAuth2 token endpoint when an -// error occurs. -type tokenErrorResponse struct { - Error string `json:"error"` - ErrorDescription string `json:"error_description"` -} - // Load loads the OAuth2 token for the given OAuthArgument from the cache. If // the token is expired, it is refreshed using the refresh token. func (a *PersistentAuth) Load(ctx context.Context, arg OAuthArgument) (t *oauth2.Token, err error) { From 37b44b1f9ae07600033849b2550b482c68f5a541 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Wed, 8 Jan 2025 13:53:12 +0100 Subject: [PATCH 24/44] tweak --- config/auth_u2m.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/config/auth_u2m.go b/config/auth_u2m.go index 0e989445b..671d7af1a 100644 --- a/config/auth_u2m.go +++ b/config/auth_u2m.go @@ -123,7 +123,8 @@ type CliInvalidRefreshTokenError struct { } func (e *CliInvalidRefreshTokenError) Error() string { - return fmt.Sprintf("a new access token could not be retrieved because the refresh token is invalid. If using the CLI, run `%s` to reauthenticate", e.loginCommand) + return fmt.Sprintf(`a new access token could not be retrieved because the refresh token is invalid. If using the CLI, run the following command to reauthenticate: + $ %s`, e.loginCommand) } func (e *CliInvalidRefreshTokenError) Unwrap() error { From 36ea3dce518f23dc7979e6be0d8597c2772ac2ec Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Thu, 16 Jan 2025 09:33:31 +0100 Subject: [PATCH 25/44] work --- config/auth_default.go | 2 +- config/auth_m2m_test.go | 2 +- config/auth_u2m.go | 45 +++---- config/auth_u2m_test.go | 10 +- config/config.go | 2 +- config/config_test.go | 2 +- config/credentials/credentials.go | 13 -- config/in_memory_test.go | 2 +- httpclient/oauth_token.go | 21 ++-- httpclient/request_test.go | 2 +- .../credentials}/cache/cache.go | 0 .../credentials}/cache/file.go | 38 ++++-- .../credentials}/cache/file_test.go | 0 .../credentials/cache}/lock.go | 2 +- .../oauth/account_oauth_argument.go | 0 .../credentials}/oauth/callback.go | 0 internal/credentials/oauth/client.go | 37 ++++++ .../credentials}/oauth/error.go | 0 .../credentials}/oauth/oauth_argument.go | 0 .../credentials}/oauth/oidc.go | 0 .../credentials}/oauth/oidc_test.go | 0 .../credentials}/oauth/page.tmpl | 0 .../credentials}/oauth/persistent_auth.go | 119 +++++++----------- .../oauth/persistent_auth_test.go | 2 +- .../oauth/workspace_oauth_argument.go | 0 25 files changed, 163 insertions(+), 136 deletions(-) rename {credentials => internal/credentials}/cache/cache.go (100%) rename {credentials => internal/credentials}/cache/file.go (86%) rename {credentials => internal/credentials}/cache/file_test.go (100%) rename {credentials/oauth => internal/credentials/cache}/lock.go (98%) rename {credentials => internal/credentials}/oauth/account_oauth_argument.go (100%) rename {credentials => internal/credentials}/oauth/callback.go (100%) create mode 100644 internal/credentials/oauth/client.go rename {credentials => internal/credentials}/oauth/error.go (100%) rename {credentials => internal/credentials}/oauth/oauth_argument.go (100%) rename {credentials => internal/credentials}/oauth/oidc.go (100%) rename {credentials => internal/credentials}/oauth/oidc_test.go (100%) rename {credentials => internal/credentials}/oauth/page.tmpl (100%) rename {credentials => internal/credentials}/oauth/persistent_auth.go (80%) rename {credentials => internal/credentials}/oauth/persistent_auth_test.go (98%) rename {credentials => internal/credentials}/oauth/workspace_oauth_argument.go (100%) diff --git a/config/auth_default.go b/config/auth_default.go index f20111d03..2170f3d2b 100644 --- a/config/auth_default.go +++ b/config/auth_default.go @@ -13,7 +13,7 @@ var authProviders = []CredentialsStrategy{ PatCredentials{}, BasicCredentials{}, M2mCredentials{}, - databricksCliCredentials, + DatabricksCliCredentials, MetadataServiceCredentials{}, // Attempt to configure auth from most specific to most generic (the Azure CLI). diff --git a/config/auth_m2m_test.go b/config/auth_m2m_test.go index 7181be436..a7ae9b0a6 100644 --- a/config/auth_m2m_test.go +++ b/config/auth_m2m_test.go @@ -4,8 +4,8 @@ import ( "net/url" "testing" - "github.com/databricks/databricks-sdk-go/credentials/oauth" "github.com/databricks/databricks-sdk-go/httpclient/fixtures" + "github.com/databricks/databricks-sdk-go/internal/credentials/oauth" "github.com/stretchr/testify/require" "golang.org/x/oauth2" ) diff --git a/config/auth_u2m.go b/config/auth_u2m.go index 671d7af1a..86875e02e 100644 --- a/config/auth_u2m.go +++ b/config/auth_u2m.go @@ -8,12 +8,12 @@ import ( "strings" "github.com/databricks/databricks-sdk-go/config/credentials" - "github.com/databricks/databricks-sdk-go/credentials/cache" - "github.com/databricks/databricks-sdk-go/credentials/oauth" + "github.com/databricks/databricks-sdk-go/internal/credentials/cache" + "github.com/databricks/databricks-sdk-go/internal/credentials/oauth" "github.com/databricks/databricks-sdk-go/logger" ) -// U2MCredentials is a credentials strategy that uses the U2M OAuth flow to +// u2mCredentials is a credentials strategy that uses the U2M OAuth flow to // authenticate with Databricks. // // To authenticate with U2M OAuth, the user must already have an existing OAuth @@ -24,27 +24,27 @@ import ( // Error handling for this strategy is controlled by the ErrorHandler field. If // ErrorHandler is not specified, any error will cause Configure() to return said // error. -type U2MCredentials struct { - // Auth is the persistent auth object to use. If not specified, a new one will +type u2mCredentials struct { + // auth is the persistent auth object to use. If not specified, a new one will // be created, using the default cache and locker. - Auth *oauth.PersistentAuth + auth *oauth.PersistentAuth - // GetOAuthArg is a function that returns the OAuth argument to use for + // getOAuthArg is a function that returns the OAuth argument to use for // loading the OAuth session token. If not specified, the OAuth argument is // determined by the account host and account ID or workspace host in the // Config. - GetOAuthArg func(context.Context, *Config) (oauth.OAuthArgument, error) + getOAuthArg func(context.Context, *Config) (oauth.OAuthArgument, error) - // ErrorHandler controls the behavior of Configure() when loading the OAuth + // errorHandler controls the behavior of Configure() when loading the OAuth // token fails. If not specified, any error will cause Configure() to return // said error. - ErrorHandler func(context.Context, *Config, oauth.OAuthArgument, error) error + errorHandler func(context.Context, *Config, oauth.OAuthArgument, error) error name string } // Name implements CredentialsStrategy. -func (u U2MCredentials) Name() string { +func (u u2mCredentials) Name() string { if u.name != "" { return u.name } @@ -52,11 +52,11 @@ func (u U2MCredentials) Name() string { } // Configure implements CredentialsStrategy. -func (u U2MCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) { +func (u u2mCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) { if cfg.Host == "" { return nil, nil } - a := u.Auth + a := u.auth if a == nil { var err error a, err = oauth.NewPersistentAuth(ctx) @@ -68,8 +68,8 @@ func (u U2MCredentials) Configure(ctx context.Context, cfg *Config) (credentials var arg oauth.OAuthArgument var err error - if u.GetOAuthArg != nil { - arg, err = u.GetOAuthArg(ctx, cfg) + if u.getOAuthArg != nil { + arg, err = u.getOAuthArg(ctx, cfg) } else { arg, err = defaultGetOAuthArg(ctx, cfg) } @@ -96,8 +96,8 @@ func (u U2MCredentials) Configure(ctx context.Context, cfg *Config) (credentials // (e.g. expired), return an error. Otherwise, fall back to the next // credentials strategy. if err := f(r); err != nil { - if u.ErrorHandler != nil { - return nil, u.ErrorHandler(ctx, cfg, arg, err) + if u.errorHandler != nil { + return nil, u.errorHandler(ctx, cfg, arg, err) } return nil, err } @@ -112,7 +112,7 @@ func defaultGetOAuthArg(_ context.Context, cfg *Config) (oauth.OAuthArgument, er return oauth.NewBasicWorkspaceOAuthArgument(cfg.Host) } -var _ CredentialsStrategy = U2MCredentials{} +var _ CredentialsStrategy = u2mCredentials{} // CliInvalidRefreshTokenError is a special error type that is returned when a // new access token could not be retrieved because the refresh token is invalid. @@ -124,6 +124,7 @@ type CliInvalidRefreshTokenError struct { func (e *CliInvalidRefreshTokenError) Error() string { return fmt.Sprintf(`a new access token could not be retrieved because the refresh token is invalid. If using the CLI, run the following command to reauthenticate: + $ %s`, e.loginCommand) } @@ -153,11 +154,11 @@ func buildLoginCommand(ctx context.Context, profile string, arg oauth.OAuthArgum return strings.Join(cmd, " ") } -// databricksCliCredentials is a credentials strategy that emulates the behavior +// DatabricksCliCredentials is a credentials strategy that emulates the behavior // of the earlier `databricks-cli` credentials strategy which invoked the // `databricks auth token` command. -var databricksCliCredentials = U2MCredentials{ - ErrorHandler: func(ctx context.Context, cfg *Config, arg oauth.OAuthArgument, err error) error { +var DatabricksCliCredentials = u2mCredentials{ + errorHandler: func(ctx context.Context, cfg *Config, arg oauth.OAuthArgument, err error) error { // If the current OAuth argument doesn't have a corresponding session // token, fall back to the next credentials strategy. if errors.Is(err, cache.ErrNotConfigured) { @@ -178,6 +179,6 @@ var databricksCliCredentials = U2MCredentials{ logger.Debugf(ctx, "failed to load token: %v, continuing", err) return nil }, - GetOAuthArg: defaultGetOAuthArg, + getOAuthArg: defaultGetOAuthArg, name: "databricks-cli", } diff --git a/config/auth_u2m_test.go b/config/auth_u2m_test.go index cad1d079d..7901bca82 100644 --- a/config/auth_u2m_test.go +++ b/config/auth_u2m_test.go @@ -8,9 +8,9 @@ import ( "testing" "time" - "github.com/databricks/databricks-sdk-go/credentials/cache" - "github.com/databricks/databricks-sdk-go/credentials/oauth" "github.com/databricks/databricks-sdk-go/httpclient/fixtures" + "github.com/databricks/databricks-sdk-go/internal/credentials/cache" + "github.com/databricks/databricks-sdk-go/internal/credentials/oauth" "github.com/stretchr/testify/require" "golang.org/x/oauth2" ) @@ -106,8 +106,8 @@ func TestU2MCredentials(t *testing.T) { ctx := context.Background() auth, err := tt.auth() require.NoError(t, err) - strat := U2MCredentials{ - Auth: auth, + strat := u2mCredentials{ + auth: auth, } provider, err := strat.Configure(ctx, tt.cfg) if tt.expectErr != "" { @@ -179,7 +179,7 @@ func TestDatabricksCli_ErrorHandler(t *testing.T) { t.Run(tc.name, func(t *testing.T) { arg, err := tc.arg() require.NoError(t, err) - got := databricksCliCredentials.ErrorHandler(context.Background(), tc.cfg, arg, tc.err) + got := DatabricksCliCredentials.errorHandler(context.Background(), tc.cfg, arg, tc.err) require.Equal(t, tc.want, got) }) } diff --git a/config/config.go b/config/config.go index 8469a957e..f2aa02668 100644 --- a/config/config.go +++ b/config/config.go @@ -14,8 +14,8 @@ import ( "github.com/databricks/databricks-sdk-go/common" "github.com/databricks/databricks-sdk-go/common/environment" "github.com/databricks/databricks-sdk-go/config/credentials" - "github.com/databricks/databricks-sdk-go/credentials/oauth" "github.com/databricks/databricks-sdk-go/httpclient" + "github.com/databricks/databricks-sdk-go/internal/credentials/oauth" "github.com/databricks/databricks-sdk-go/logger" "golang.org/x/oauth2" ) diff --git a/config/config_test.go b/config/config_test.go index 9cebb1a15..c282f6143 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -5,8 +5,8 @@ import ( "net/http" "testing" - "github.com/databricks/databricks-sdk-go/credentials/oauth" "github.com/databricks/databricks-sdk-go/httpclient/fixtures" + "github.com/databricks/databricks-sdk-go/internal/credentials/oauth" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/config/credentials/credentials.go b/config/credentials/credentials.go index fe1041f15..ca796e777 100644 --- a/config/credentials/credentials.go +++ b/config/credentials/credentials.go @@ -53,16 +53,3 @@ func NewOAuthCredentialsProvider(visitor func(r *http.Request) error, tokenProvi token: tokenProvider, } } - -// OAuthToken represents an OAuth token as defined by the OAuth 2.0 Authorization Framework. -// https://datatracker.ietf.org/doc/html/rfc6749 -type OAuthToken struct { - // The access token issued by the authorization server. This is the token that will be used to authenticate requests. - AccessToken string `json:"access_token" auth:",sensitive"` - // Time in seconds until the token expires. - ExpiresIn int `json:"expires_in"` - // The scope of the token. This is a space-separated list of strings that represent the permissions granted by the token. - Scope string `json:"scope"` - // The type of token that was issued. - TokenType string `json:"token_type"` -} diff --git a/config/in_memory_test.go b/config/in_memory_test.go index 82ce6e2c7..cf67eb259 100644 --- a/config/in_memory_test.go +++ b/config/in_memory_test.go @@ -1,7 +1,7 @@ package config import ( - "github.com/databricks/databricks-sdk-go/credentials/cache" + "github.com/databricks/databricks-sdk-go/internal/credentials/cache" "golang.org/x/oauth2" ) diff --git a/httpclient/oauth_token.go b/httpclient/oauth_token.go index abbabb123..96f9a539e 100644 --- a/httpclient/oauth_token.go +++ b/httpclient/oauth_token.go @@ -10,9 +10,9 @@ import ( const JWTGrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer" -// GetOAuthTokenRequest is the request to get an OAuth token. It follows the OAuth 2.0 Rich Authorization Requests specification. +// getOAuthTokenRequest is the request to get an OAuth token. It follows the OAuth 2.0 Rich Authorization Requests specification. // https://datatracker.ietf.org/doc/html/rfc9396 -type GetOAuthTokenRequest struct { +type getOAuthTokenRequest struct { // Defines the method used to get the token. GrantType string `url:"grant_type"` // An array of authorization details that the token should be scoped to. This needs to be passed in string format. @@ -21,11 +21,16 @@ type GetOAuthTokenRequest struct { Assertion string `url:"assertion"` } -// OAuthToken represents an OAuth token as defined by the OAuth 2.0 Authorization Framework. -// https://datatracker.ietf.org/doc/html/rfc6749 -type OAuthToken struct { +// oAuthToken represents an OAuth token as defined by the OAuth 2.0 Authorization Framework. +// https://datatracker.ietf.org/doc/html/rfc6749. +// +// The Go SDK maintains its own implementation of OAuth because Go's oauth2 +// library lacks two features that we depend on: +// 1. The ability to use an arbitrary assertion with the JWT grant type. +// 2. The ability to set authorization_details when getting an OAuth token. +type oAuthToken struct { // The access token issued by the authorization server. This is the token that will be used to authenticate requests. - AccessToken string `json:"access_token" auth:",sensitive"` + AccessToken string `json:"access_token"` // Time in seconds until the token expires. ExpiresIn int `json:"expires_in"` // The scope of the token. This is a space-separated list of strings that represent the permissions granted by the token. @@ -41,12 +46,12 @@ type OAuthToken struct { // without warning. func (c *ApiClient) GetOAuthToken(ctx context.Context, authDetails string, token *oauth2.Token) (*oauth2.Token, error) { path := "/oidc/v1/token" - data := GetOAuthTokenRequest{ + data := getOAuthTokenRequest{ GrantType: JWTGrantType, AuthorizationDetails: authDetails, Assertion: token.AccessToken, } - var response OAuthToken + var response oAuthToken opts := []DoOption{ WithUrlEncodedData(data), WithResponseUnmarshal(&response), diff --git a/httpclient/request_test.go b/httpclient/request_test.go index 6b388af53..2346f617f 100644 --- a/httpclient/request_test.go +++ b/httpclient/request_test.go @@ -55,7 +55,7 @@ func TestMakeRequestBodyFromReader(t *testing.T) { } func TestUrlEncoding(t *testing.T) { - data := GetOAuthTokenRequest{ + data := getOAuthTokenRequest{ Assertion: "assertion", AuthorizationDetails: "[{\"a\":\"b\"}]", GrantType: "grant", diff --git a/credentials/cache/cache.go b/internal/credentials/cache/cache.go similarity index 100% rename from credentials/cache/cache.go rename to internal/credentials/cache/cache.go diff --git a/credentials/cache/file.go b/internal/credentials/cache/file.go similarity index 86% rename from credentials/cache/file.go rename to internal/credentials/cache/file.go index ee0200e62..ca8230b53 100644 --- a/credentials/cache/file.go +++ b/internal/credentials/cache/file.go @@ -37,6 +37,10 @@ const ( // } // } tokenCacheVersion = 1 + + // lockFilePath is the path of the lock file used to prevent concurrent + // reads and writes to the token cache file. + lockFilePath = ".databricks/token-cache.lock" ) // The format of the token cache file. @@ -53,13 +57,19 @@ func WithFileLocation(fileLocation string) FileTokenCacheOpt { } } +func WithLocker(locker sync.Locker) FileTokenCacheOpt { + return func(c *FileTokenCache) { + c.locker = locker + } +} + // FileTokenCache caches tokens in "~/.databricks/token-cache.json". FileTokenCache // implements the TokenCache interface. type FileTokenCache struct { fileLocation string - // mu protects the token cache file from concurrent reads and writes. - mu *sync.Mutex + // locker protects the token cache file from concurrent reads and writes. + locker sync.Locker } // NewFileTokenCache creates a new FileTokenCache. By default, the cache is @@ -69,9 +79,7 @@ type FileTokenCache struct { // file is corrupt or if its version does not match tokenCacheVersion, an error // is returned. func NewFileTokenCache(opts ...FileTokenCacheOpt) (*FileTokenCache, error) { - c := &FileTokenCache{ - mu: &sync.Mutex{}, - } + c := &FileTokenCache{} for _, opt := range opts { opt(c) } @@ -87,8 +95,8 @@ func NewFileTokenCache(opts ...FileTokenCacheOpt) (*FileTokenCache, error) { // Store implements the TokenCache interface. func (c *FileTokenCache) Store(key string, t *oauth2.Token) error { - c.mu.Lock() - defer c.mu.Unlock() + c.locker.Lock() + defer c.locker.Unlock() f, err := c.load() if err != nil { return fmt.Errorf("load: %w", err) @@ -106,8 +114,8 @@ func (c *FileTokenCache) Store(key string, t *oauth2.Token) error { // Lookup implements the TokenCache interface. func (c *FileTokenCache) Lookup(key string) (*oauth2.Token, error) { - c.mu.Lock() - defer c.mu.Unlock() + c.locker.Lock() + defer c.locker.Unlock() f, err := c.load() if err != nil { return nil, fmt.Errorf("load: %w", err) @@ -157,6 +165,18 @@ func (c *FileTokenCache) init() error { return fmt.Errorf("write: %w", err) } } + // Initialize the locker if it is not already set. + if c.locker == nil { + home, err := os.UserHomeDir() + if err != nil { + return fmt.Errorf("home: %w", err) + } + + c.locker, err = newLocker(filepath.Join(home, lockFilePath)) + if err != nil { + return fmt.Errorf("locker: %w", err) + } + } return nil } diff --git a/credentials/cache/file_test.go b/internal/credentials/cache/file_test.go similarity index 100% rename from credentials/cache/file_test.go rename to internal/credentials/cache/file_test.go diff --git a/credentials/oauth/lock.go b/internal/credentials/cache/lock.go similarity index 98% rename from credentials/oauth/lock.go rename to internal/credentials/cache/lock.go index 93bc65cb2..dcdf35f9b 100644 --- a/credentials/oauth/lock.go +++ b/internal/credentials/cache/lock.go @@ -1,4 +1,4 @@ -package oauth +package cache import ( "fmt" diff --git a/credentials/oauth/account_oauth_argument.go b/internal/credentials/oauth/account_oauth_argument.go similarity index 100% rename from credentials/oauth/account_oauth_argument.go rename to internal/credentials/oauth/account_oauth_argument.go diff --git a/credentials/oauth/callback.go b/internal/credentials/oauth/callback.go similarity index 100% rename from credentials/oauth/callback.go rename to internal/credentials/oauth/callback.go diff --git a/internal/credentials/oauth/client.go b/internal/credentials/oauth/client.go new file mode 100644 index 000000000..0e233d14d --- /dev/null +++ b/internal/credentials/oauth/client.go @@ -0,0 +1,37 @@ +package oauth + +import ( + "context" + "net/http" + + "github.com/databricks/databricks-sdk-go/httpclient" +) + +// OAuthClient provides the http functionality needed for interacting with the +// Databricks OAuth APIs. +type OAuthClient interface { + // GetHttpClient returns an HTTP client for OAuth2 requests. + GetHttpClient(context.Context) *http.Client + + // GetWorkspaceOAuthEndpoints returns the OAuth2 endpoints for the workspace. + GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*OAuthAuthorizationServer, error) + + // GetAccountOAuthEndpoints returns the OAuth2 endpoints for the account. + GetAccountOAuthEndpoints(ctx context.Context, accountHost string, accountId string) (*OAuthAuthorizationServer, error) +} + +type BasicOAuthClient struct { + client *httpclient.ApiClient +} + +func (c *BasicOAuthClient) GetHttpClient(_ context.Context) *http.Client { + return c.client.ToHttpClient() +} + +func (c *BasicOAuthClient) GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*OAuthAuthorizationServer, error) { + return GetWorkspaceOAuthEndpoints(ctx, c.client, workspaceHost) +} + +func (c *BasicOAuthClient) GetAccountOAuthEndpoints(ctx context.Context, accountHost string, accountId string) (*OAuthAuthorizationServer, error) { + return GetAccountOAuthEndpoints(ctx, accountHost, accountId) +} diff --git a/credentials/oauth/error.go b/internal/credentials/oauth/error.go similarity index 100% rename from credentials/oauth/error.go rename to internal/credentials/oauth/error.go diff --git a/credentials/oauth/oauth_argument.go b/internal/credentials/oauth/oauth_argument.go similarity index 100% rename from credentials/oauth/oauth_argument.go rename to internal/credentials/oauth/oauth_argument.go diff --git a/credentials/oauth/oidc.go b/internal/credentials/oauth/oidc.go similarity index 100% rename from credentials/oauth/oidc.go rename to internal/credentials/oauth/oidc.go diff --git a/credentials/oauth/oidc_test.go b/internal/credentials/oauth/oidc_test.go similarity index 100% rename from credentials/oauth/oidc_test.go rename to internal/credentials/oauth/oidc_test.go diff --git a/credentials/oauth/page.tmpl b/internal/credentials/oauth/page.tmpl similarity index 100% rename from credentials/oauth/page.tmpl rename to internal/credentials/oauth/page.tmpl diff --git a/credentials/oauth/persistent_auth.go b/internal/credentials/oauth/persistent_auth.go similarity index 80% rename from credentials/oauth/persistent_auth.go rename to internal/credentials/oauth/persistent_auth.go index 545f61fcb..62abc31a5 100644 --- a/credentials/oauth/persistent_auth.go +++ b/internal/credentials/oauth/persistent_auth.go @@ -5,17 +5,14 @@ import ( "crypto/rand" "crypto/sha256" "encoding/base64" + "encoding/json" "errors" "fmt" "net" - "net/http" - "os" - "path/filepath" - "sync" "time" - "github.com/databricks/databricks-sdk-go/credentials/cache" "github.com/databricks/databricks-sdk-go/httpclient" + "github.com/databricks/databricks-sdk-go/internal/credentials/cache" "github.com/databricks/databricks-sdk-go/logger" "github.com/databricks/databricks-sdk-go/retries" "github.com/pkg/browser" @@ -30,42 +27,10 @@ const ( appClientID = "databricks-cli" appRedirectAddr = "localhost:8020" - // lockfile location - lockFilePath = ".databricks/token-cache.lock" - // maximum amount of time to acquire listener on appRedirectAddr listenerTimeout = 45 * time.Second ) -// OAuthClient provides the http functionality needed for interacting with the -// Databricks OAuth APIs. -type OAuthClient interface { - // GetHttpClient returns an HTTP client for OAuth2 requests. - GetHttpClient(context.Context) *http.Client - - // GetWorkspaceOAuthEndpoints returns the OAuth2 endpoints for the workspace. - GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*OAuthAuthorizationServer, error) - - // GetAccountOAuthEndpoints returns the OAuth2 endpoints for the account. - GetAccountOAuthEndpoints(ctx context.Context, accountHost string, accountId string) (*OAuthAuthorizationServer, error) -} - -type BasicOAuthClient struct { - client *httpclient.ApiClient -} - -func (c *BasicOAuthClient) GetHttpClient(_ context.Context) *http.Client { - return c.client.ToHttpClient() -} - -func (c *BasicOAuthClient) GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*OAuthAuthorizationServer, error) { - return GetWorkspaceOAuthEndpoints(ctx, c.client, workspaceHost) -} - -func (c *BasicOAuthClient) GetAccountOAuthEndpoints(ctx context.Context, accountHost string, accountId string) (*OAuthAuthorizationServer, error) { - return GetAccountOAuthEndpoints(ctx, accountHost, accountId) -} - // PersistentAuth is an OAuth manager that handles the U2M OAuth flow. Tokens // are stored in and looked up from the provided cache. Tokens include the // refresh token. On load, if the access token is expired, it is refreshed @@ -76,8 +41,6 @@ func (c *BasicOAuthClient) GetAccountOAuthEndpoints(ctx context.Context, account type PersistentAuth struct { // Cache is the token cache to store and lookup tokens. cache cache.TokenCache - // Locker is the lock to synchronize token cache access. - locker sync.Locker // Client is the HTTP client to use for OAuth2 requests. client OAuthClient // Browser is the function to open a URL in the default browser. @@ -95,13 +58,6 @@ func WithTokenCache(c cache.TokenCache) PersistentAuthOption { } } -// WithLocker sets the locker for the PersistentAuth. -func WithLocker(l sync.Locker) PersistentAuthOption { - return func(a *PersistentAuth) { - a.locker = l - } -} - // WithApiClient sets the HTTP client for the PersistentAuth. func WithOAuthClient(c OAuthClient) PersistentAuthOption { return func(a *PersistentAuth) { @@ -134,17 +90,6 @@ func NewPersistentAuth(ctx context.Context, opts ...PersistentAuthOption) (*Pers return nil, fmt.Errorf("cache: %w", err) } } - if p.locker == nil { - home, err := os.UserHomeDir() - if err != nil { - return nil, fmt.Errorf("home: %w", err) - } - - p.locker, err = newLocker(filepath.Join(home, lockFilePath)) - if err != nil { - return nil, fmt.Errorf("locker: %w", err) - } - } if p.browser == nil { p.browser = browser.OpenURL } @@ -154,9 +99,6 @@ func NewPersistentAuth(ctx context.Context, opts ...PersistentAuthOption) (*Pers // Load loads the OAuth2 token for the given OAuthArgument from the cache. If // the token is expired, it is refreshed using the refresh token. func (a *PersistentAuth) Load(ctx context.Context, arg OAuthArgument) (t *oauth2.Token, err error) { - a.locker.Lock() - defer a.locker.Unlock() - if err := a.validateArg(arg); err != nil { return nil, err } @@ -198,6 +140,30 @@ func (a *PersistentAuth) refresh(ctx context.Context, arg OAuthArgument, oldToke // eagerly refresh token t, err := cfg.TokenSource(ctx, oldToken).Token() if err != nil { + // The default RoundTripper of our httpclient.ApiClient returns an error + // if the response status code is not 2xx. This isn't compliant with the + // RoundTripper interface, so this error isn't handled by the oauth2 + // library. We need to handle it here. + var internalHttpError *httpclient.HttpError + if errors.As(err, &internalHttpError) { + // error fields + // https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 + var errResponse struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + } + if unmarshalErr := json.Unmarshal([]byte(internalHttpError.Message), &errResponse); unmarshalErr != nil { + return nil, fmt.Errorf("unmarshal: %w", unmarshalErr) + } + // Invalid refresh tokens get their own error type so they can be + // better presented to users. + if errResponse.ErrorDescription == "Refresh token is invalid" { + return nil, &InvalidRefreshTokenError{err} + } + return nil, fmt.Errorf("%s (error code: %s)", errResponse.ErrorDescription, errResponse.Error) + } + + // Handle responses from well-behaved *http.Client implementations. var httpErr *oauth2.RetrieveError if errors.As(err, &httpErr) { // Invalid refresh tokens get their own error type so they can be @@ -223,9 +189,6 @@ func (a *PersistentAuth) refresh(ctx context.Context, arg OAuthArgument, oldToke // exchanges the authorization code for an access token. It returns the OAuth2 // token on success. func (a *PersistentAuth) Challenge(ctx context.Context, arg OAuthArgument) (*oauth2.Token, error) { - a.locker.Lock() - defer a.locker.Unlock() - if err := a.validateArg(arg); err != nil { return nil, err } @@ -247,7 +210,10 @@ func (a *PersistentAuth) Challenge(ctx context.Context, arg OAuthArgument) (*oau } defer cb.Close() - state, pkce := a.stateAndPKCE() + state, pkce, err := a.stateAndPKCE() + if err != nil { + return nil, fmt.Errorf("state and pkce: %w", err) + } // make OAuth2 library use our client ctx = a.setOAuthContext(ctx) ts := authhandler.TokenSourceWithPKCE(ctx, cfg, state, cb.Handler, pkce) @@ -333,21 +299,32 @@ func (a *PersistentAuth) oauth2Config(ctx context.Context, arg OAuthArgument) (* }, nil } -func (a *PersistentAuth) stateAndPKCE() (string, *authhandler.PKCEParams) { - verifier := a.randomString(64) +func (a *PersistentAuth) stateAndPKCE() (string, *authhandler.PKCEParams, error) { + verifier, err := a.randomString(64) + if err != nil { + return "", nil, fmt.Errorf("verifier: %w", err) + } verifierSha256 := sha256.Sum256([]byte(verifier)) challenge := base64.RawURLEncoding.EncodeToString(verifierSha256[:]) - return a.randomString(16), &authhandler.PKCEParams{ + state, err := a.randomString(16) + if err != nil { + return "", nil, fmt.Errorf("state: %w", err) + } + return state, &authhandler.PKCEParams{ Challenge: challenge, ChallengeMethod: "S256", Verifier: verifier, - } + }, nil } -func (a *PersistentAuth) randomString(size int) string { +func (a *PersistentAuth) randomString(size int) (string, error) { raw := make([]byte, size) - _, _ = rand.Read(raw) - return base64.RawURLEncoding.EncodeToString(raw) + // ignore error as rand.Reader never returns an error + _, err := rand.Read(raw) + if err != nil { + return "", fmt.Errorf("rand.Read: %w", err) + } + return base64.RawURLEncoding.EncodeToString(raw), nil } func (a *PersistentAuth) setOAuthContext(ctx context.Context) context.Context { diff --git a/credentials/oauth/persistent_auth_test.go b/internal/credentials/oauth/persistent_auth_test.go similarity index 98% rename from credentials/oauth/persistent_auth_test.go rename to internal/credentials/oauth/persistent_auth_test.go index 5410aab30..80997c7fe 100644 --- a/credentials/oauth/persistent_auth_test.go +++ b/internal/credentials/oauth/persistent_auth_test.go @@ -8,8 +8,8 @@ import ( "testing" "time" - "github.com/databricks/databricks-sdk-go/credentials/oauth" "github.com/databricks/databricks-sdk-go/httpclient/fixtures" + "github.com/databricks/databricks-sdk-go/internal/credentials/oauth" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/oauth2" diff --git a/credentials/oauth/workspace_oauth_argument.go b/internal/credentials/oauth/workspace_oauth_argument.go similarity index 100% rename from credentials/oauth/workspace_oauth_argument.go rename to internal/credentials/oauth/workspace_oauth_argument.go From fc87393cc3963fab0df1b5c9d60e1be07947c78f Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Fri, 17 Jan 2025 16:42:48 +0100 Subject: [PATCH 26/44] fix test --- openapi/roll/tool.go | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/openapi/roll/tool.go b/openapi/roll/tool.go index 9acf159bb..3ed40f355 100644 --- a/openapi/roll/tool.go +++ b/openapi/roll/tool.go @@ -20,34 +20,32 @@ func NewSuite(dirname string) (*Suite, error) { fset: fset, ServiceToPackage: map[string]string{}, } - err := filepath.WalkDir(dirname, func(path string, info os.DirEntry, err error) error { - if err != nil { - return err - } - if info.IsDir() { - return nil + entries, err := os.ReadDir(dirname) + if err != nil { + return nil, err + } + for _, entry := range entries { + if entry.IsDir() { + continue } + path := filepath.Join(dirname, entry.Name()) if strings.HasSuffix(path, "acceptance_test.go") { // not transpilable - return nil + continue } if strings.HasSuffix(path, "files_test.go") { // not transpilable - return nil + continue } if strings.HasSuffix(path, "workspaceconf_test.go") { // not transpilable - return nil + continue } file, err := parser.ParseFile(fset, path, nil, parser.ParseComments) if err != nil { - return err + return nil, err } s.expectExamples(file) - return nil - }) - if err != nil { - return nil, err } err = s.parsePackages(dirname+"/../workspace_client.go", "WorkspaceClient") if err != nil { From 945151f23c9df85dc0d120ec23e5c941e72a1c22 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Fri, 17 Jan 2025 16:59:56 +0100 Subject: [PATCH 27/44] undo move --- config/auth_m2m_test.go | 2 +- config/auth_u2m.go | 4 +-- config/auth_u2m_test.go | 4 +-- config/config.go | 2 +- config/config_test.go | 2 +- config/in_memory_test.go | 2 +- .../cache/cache.go | 0 .../credentials => credentials}/cache/file.go | 0 .../cache/file_test.go | 0 .../credentials => credentials}/cache/lock.go | 0 .../oauth/account_oauth_argument.go | 0 .../oauth/callback.go | 0 .../oauth/client.go | 0 .../oauth/error.go | 0 .../oauth/oauth_argument.go | 0 .../credentials => credentials}/oauth/oidc.go | 0 .../oauth/oidc_test.go | 0 .../oauth/page.tmpl | 0 .../oauth/persistent_auth.go | 2 +- .../oauth/persistent_auth_test.go | 2 +- .../oauth/workspace_oauth_argument.go | 0 openapi/roll/tool.go | 26 ++++++++++--------- 22 files changed, 24 insertions(+), 22 deletions(-) rename {internal/credentials => credentials}/cache/cache.go (100%) rename {internal/credentials => credentials}/cache/file.go (100%) rename {internal/credentials => credentials}/cache/file_test.go (100%) rename {internal/credentials => credentials}/cache/lock.go (100%) rename {internal/credentials => credentials}/oauth/account_oauth_argument.go (100%) rename {internal/credentials => credentials}/oauth/callback.go (100%) rename {internal/credentials => credentials}/oauth/client.go (100%) rename {internal/credentials => credentials}/oauth/error.go (100%) rename {internal/credentials => credentials}/oauth/oauth_argument.go (100%) rename {internal/credentials => credentials}/oauth/oidc.go (100%) rename {internal/credentials => credentials}/oauth/oidc_test.go (100%) rename {internal/credentials => credentials}/oauth/page.tmpl (100%) rename {internal/credentials => credentials}/oauth/persistent_auth.go (99%) rename {internal/credentials => credentials}/oauth/persistent_auth_test.go (98%) rename {internal/credentials => credentials}/oauth/workspace_oauth_argument.go (100%) diff --git a/config/auth_m2m_test.go b/config/auth_m2m_test.go index a7ae9b0a6..7181be436 100644 --- a/config/auth_m2m_test.go +++ b/config/auth_m2m_test.go @@ -4,8 +4,8 @@ import ( "net/url" "testing" + "github.com/databricks/databricks-sdk-go/credentials/oauth" "github.com/databricks/databricks-sdk-go/httpclient/fixtures" - "github.com/databricks/databricks-sdk-go/internal/credentials/oauth" "github.com/stretchr/testify/require" "golang.org/x/oauth2" ) diff --git a/config/auth_u2m.go b/config/auth_u2m.go index 86875e02e..22c2568f7 100644 --- a/config/auth_u2m.go +++ b/config/auth_u2m.go @@ -8,8 +8,8 @@ import ( "strings" "github.com/databricks/databricks-sdk-go/config/credentials" - "github.com/databricks/databricks-sdk-go/internal/credentials/cache" - "github.com/databricks/databricks-sdk-go/internal/credentials/oauth" + "github.com/databricks/databricks-sdk-go/credentials/cache" + "github.com/databricks/databricks-sdk-go/credentials/oauth" "github.com/databricks/databricks-sdk-go/logger" ) diff --git a/config/auth_u2m_test.go b/config/auth_u2m_test.go index 7901bca82..3d9fe964a 100644 --- a/config/auth_u2m_test.go +++ b/config/auth_u2m_test.go @@ -8,9 +8,9 @@ import ( "testing" "time" + "github.com/databricks/databricks-sdk-go/credentials/cache" + "github.com/databricks/databricks-sdk-go/credentials/oauth" "github.com/databricks/databricks-sdk-go/httpclient/fixtures" - "github.com/databricks/databricks-sdk-go/internal/credentials/cache" - "github.com/databricks/databricks-sdk-go/internal/credentials/oauth" "github.com/stretchr/testify/require" "golang.org/x/oauth2" ) diff --git a/config/config.go b/config/config.go index f2aa02668..8469a957e 100644 --- a/config/config.go +++ b/config/config.go @@ -14,8 +14,8 @@ import ( "github.com/databricks/databricks-sdk-go/common" "github.com/databricks/databricks-sdk-go/common/environment" "github.com/databricks/databricks-sdk-go/config/credentials" + "github.com/databricks/databricks-sdk-go/credentials/oauth" "github.com/databricks/databricks-sdk-go/httpclient" - "github.com/databricks/databricks-sdk-go/internal/credentials/oauth" "github.com/databricks/databricks-sdk-go/logger" "golang.org/x/oauth2" ) diff --git a/config/config_test.go b/config/config_test.go index c282f6143..9cebb1a15 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -5,8 +5,8 @@ import ( "net/http" "testing" + "github.com/databricks/databricks-sdk-go/credentials/oauth" "github.com/databricks/databricks-sdk-go/httpclient/fixtures" - "github.com/databricks/databricks-sdk-go/internal/credentials/oauth" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/config/in_memory_test.go b/config/in_memory_test.go index cf67eb259..82ce6e2c7 100644 --- a/config/in_memory_test.go +++ b/config/in_memory_test.go @@ -1,7 +1,7 @@ package config import ( - "github.com/databricks/databricks-sdk-go/internal/credentials/cache" + "github.com/databricks/databricks-sdk-go/credentials/cache" "golang.org/x/oauth2" ) diff --git a/internal/credentials/cache/cache.go b/credentials/cache/cache.go similarity index 100% rename from internal/credentials/cache/cache.go rename to credentials/cache/cache.go diff --git a/internal/credentials/cache/file.go b/credentials/cache/file.go similarity index 100% rename from internal/credentials/cache/file.go rename to credentials/cache/file.go diff --git a/internal/credentials/cache/file_test.go b/credentials/cache/file_test.go similarity index 100% rename from internal/credentials/cache/file_test.go rename to credentials/cache/file_test.go diff --git a/internal/credentials/cache/lock.go b/credentials/cache/lock.go similarity index 100% rename from internal/credentials/cache/lock.go rename to credentials/cache/lock.go diff --git a/internal/credentials/oauth/account_oauth_argument.go b/credentials/oauth/account_oauth_argument.go similarity index 100% rename from internal/credentials/oauth/account_oauth_argument.go rename to credentials/oauth/account_oauth_argument.go diff --git a/internal/credentials/oauth/callback.go b/credentials/oauth/callback.go similarity index 100% rename from internal/credentials/oauth/callback.go rename to credentials/oauth/callback.go diff --git a/internal/credentials/oauth/client.go b/credentials/oauth/client.go similarity index 100% rename from internal/credentials/oauth/client.go rename to credentials/oauth/client.go diff --git a/internal/credentials/oauth/error.go b/credentials/oauth/error.go similarity index 100% rename from internal/credentials/oauth/error.go rename to credentials/oauth/error.go diff --git a/internal/credentials/oauth/oauth_argument.go b/credentials/oauth/oauth_argument.go similarity index 100% rename from internal/credentials/oauth/oauth_argument.go rename to credentials/oauth/oauth_argument.go diff --git a/internal/credentials/oauth/oidc.go b/credentials/oauth/oidc.go similarity index 100% rename from internal/credentials/oauth/oidc.go rename to credentials/oauth/oidc.go diff --git a/internal/credentials/oauth/oidc_test.go b/credentials/oauth/oidc_test.go similarity index 100% rename from internal/credentials/oauth/oidc_test.go rename to credentials/oauth/oidc_test.go diff --git a/internal/credentials/oauth/page.tmpl b/credentials/oauth/page.tmpl similarity index 100% rename from internal/credentials/oauth/page.tmpl rename to credentials/oauth/page.tmpl diff --git a/internal/credentials/oauth/persistent_auth.go b/credentials/oauth/persistent_auth.go similarity index 99% rename from internal/credentials/oauth/persistent_auth.go rename to credentials/oauth/persistent_auth.go index 62abc31a5..81fd41ae9 100644 --- a/internal/credentials/oauth/persistent_auth.go +++ b/credentials/oauth/persistent_auth.go @@ -11,8 +11,8 @@ import ( "net" "time" + "github.com/databricks/databricks-sdk-go/credentials/cache" "github.com/databricks/databricks-sdk-go/httpclient" - "github.com/databricks/databricks-sdk-go/internal/credentials/cache" "github.com/databricks/databricks-sdk-go/logger" "github.com/databricks/databricks-sdk-go/retries" "github.com/pkg/browser" diff --git a/internal/credentials/oauth/persistent_auth_test.go b/credentials/oauth/persistent_auth_test.go similarity index 98% rename from internal/credentials/oauth/persistent_auth_test.go rename to credentials/oauth/persistent_auth_test.go index 80997c7fe..5410aab30 100644 --- a/internal/credentials/oauth/persistent_auth_test.go +++ b/credentials/oauth/persistent_auth_test.go @@ -8,8 +8,8 @@ import ( "testing" "time" + "github.com/databricks/databricks-sdk-go/credentials/oauth" "github.com/databricks/databricks-sdk-go/httpclient/fixtures" - "github.com/databricks/databricks-sdk-go/internal/credentials/oauth" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/oauth2" diff --git a/internal/credentials/oauth/workspace_oauth_argument.go b/credentials/oauth/workspace_oauth_argument.go similarity index 100% rename from internal/credentials/oauth/workspace_oauth_argument.go rename to credentials/oauth/workspace_oauth_argument.go diff --git a/openapi/roll/tool.go b/openapi/roll/tool.go index 3ed40f355..9acf159bb 100644 --- a/openapi/roll/tool.go +++ b/openapi/roll/tool.go @@ -20,32 +20,34 @@ func NewSuite(dirname string) (*Suite, error) { fset: fset, ServiceToPackage: map[string]string{}, } - entries, err := os.ReadDir(dirname) - if err != nil { - return nil, err - } - for _, entry := range entries { - if entry.IsDir() { - continue + err := filepath.WalkDir(dirname, func(path string, info os.DirEntry, err error) error { + if err != nil { + return err + } + if info.IsDir() { + return nil } - path := filepath.Join(dirname, entry.Name()) if strings.HasSuffix(path, "acceptance_test.go") { // not transpilable - continue + return nil } if strings.HasSuffix(path, "files_test.go") { // not transpilable - continue + return nil } if strings.HasSuffix(path, "workspaceconf_test.go") { // not transpilable - continue + return nil } file, err := parser.ParseFile(fset, path, nil, parser.ParseComments) if err != nil { - return nil, err + return err } s.expectExamples(file) + return nil + }) + if err != nil { + return nil, err } err = s.parsePackages(dirname+"/../workspace_client.go", "WorkspaceClient") if err != nil { From 356a7c9aefe20a753a62b4f4d7fa6654529d4feb Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Mon, 20 Jan 2025 13:39:20 +0100 Subject: [PATCH 28/44] remove extra context --- config/auth_u2m.go | 8 ++++---- credentials/oauth/account_oauth_argument.go | 11 +++++------ credentials/oauth/callback.go | 8 ++++---- credentials/oauth/oauth_argument.go | 6 +----- credentials/oauth/persistent_auth.go | 10 +++++----- credentials/oauth/workspace_oauth_argument.go | 7 +++---- 6 files changed, 22 insertions(+), 28 deletions(-) diff --git a/config/auth_u2m.go b/config/auth_u2m.go index 22c2568f7..f47bab21d 100644 --- a/config/auth_u2m.go +++ b/config/auth_u2m.go @@ -135,7 +135,7 @@ func (e *CliInvalidRefreshTokenError) Unwrap() error { // buildLoginCommand returns the `databricks auth login` command that the user // can run to reauthenticate. The command is prepopulated with the profile, host // and/or account ID. -func buildLoginCommand(ctx context.Context, profile string, arg oauth.OAuthArgument) string { +func buildLoginCommand(profile string, arg oauth.OAuthArgument) string { cmd := []string{ "databricks", "auth", @@ -146,9 +146,9 @@ func buildLoginCommand(ctx context.Context, profile string, arg oauth.OAuthArgum } else { switch arg := arg.(type) { case oauth.AccountOAuthArgument: - cmd = append(cmd, "--host", arg.GetAccountHost(ctx), "--account-id", arg.GetAccountId(ctx)) + cmd = append(cmd, "--host", arg.GetAccountHost(), "--account-id", arg.GetAccountId()) case oauth.WorkspaceOAuthArgument: - cmd = append(cmd, "--host", arg.GetWorkspaceHost(ctx)) + cmd = append(cmd, "--host", arg.GetWorkspaceHost()) } } return strings.Join(cmd, " ") @@ -171,7 +171,7 @@ var DatabricksCliCredentials = u2mCredentials{ target := &oauth.InvalidRefreshTokenError{} if errors.As(err, &target) { return &CliInvalidRefreshTokenError{ - loginCommand: buildLoginCommand(ctx, cfg.Profile, arg), + loginCommand: buildLoginCommand(cfg.Profile, arg), err: err, } } diff --git a/credentials/oauth/account_oauth_argument.go b/credentials/oauth/account_oauth_argument.go index a0661182e..e8ea46ffe 100644 --- a/credentials/oauth/account_oauth_argument.go +++ b/credentials/oauth/account_oauth_argument.go @@ -1,7 +1,6 @@ package oauth import ( - "context" "fmt" "strings" ) @@ -12,10 +11,10 @@ type AccountOAuthArgument interface { OAuthArgument // GetAccountHost returns the host of the account to authenticate to. - GetAccountHost(ctx context.Context) string + GetAccountHost() string // GetAccountId returns the account ID of the account to authenticate to. - GetAccountId(ctx context.Context) string + GetAccountId() string } // BasicAccountOAuthArgument is a basic implementation of the AccountOAuthArgument @@ -39,17 +38,17 @@ func NewBasicAccountOAuthArgument(accountsHost, accountID string) (BasicAccountO } // GetAccountHost returns the host of the account to authenticate to. -func (a BasicAccountOAuthArgument) GetAccountHost(ctx context.Context) string { +func (a BasicAccountOAuthArgument) GetAccountHost() string { return a.accountHost } // GetAccountId returns the account ID of the account to authenticate to. -func (a BasicAccountOAuthArgument) GetAccountId(ctx context.Context) string { +func (a BasicAccountOAuthArgument) GetAccountId() string { return a.accountID } // GetCacheKey returns a unique key for caching the OAuth token for the account. // The key is in the format "/oidc/accounts/". -func (a BasicAccountOAuthArgument) GetCacheKey(ctx context.Context) string { +func (a BasicAccountOAuthArgument) GetCacheKey() string { return fmt.Sprintf("%s/oidc/accounts/%s", a.accountHost, a.accountID) } diff --git a/credentials/oauth/callback.go b/credentials/oauth/callback.go index 258a10a82..c8d1ecc04 100644 --- a/credentials/oauth/callback.go +++ b/credentials/oauth/callback.go @@ -93,7 +93,7 @@ func (cb *callbackServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { ErrorDescription: r.FormValue("error_description"), Code: r.FormValue("code"), State: r.FormValue("state"), - Host: cb.getHost(r.Context()), + Host: cb.getHost(), } if res.Error != "" { w.WriteHeader(http.StatusBadRequest) @@ -107,12 +107,12 @@ func (cb *callbackServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { cb.feedbackCh <- res } -func (cb *callbackServer) getHost(ctx context.Context) string { +func (cb *callbackServer) getHost() string { switch a := cb.arg.(type) { case AccountOAuthArgument: - return a.GetAccountHost(ctx) + return a.GetAccountHost() case WorkspaceOAuthArgument: - return a.GetWorkspaceHost(ctx) + return a.GetWorkspaceHost() default: return "" } diff --git a/credentials/oauth/oauth_argument.go b/credentials/oauth/oauth_argument.go index 97777876f..ce815c67d 100644 --- a/credentials/oauth/oauth_argument.go +++ b/credentials/oauth/oauth_argument.go @@ -1,9 +1,5 @@ package oauth -import ( - "context" -) - // OAuthArgument is an interface that provides the necessary information to // authenticate with PersistentAuth. Implementations of this interface must // implement either the WorkspaceOAuthArgument or AccountOAuthArgument @@ -11,5 +7,5 @@ import ( type OAuthArgument interface { // GetCacheKey returns a unique key for the OAuthArgument. This key is used // to store and retrieve the token from the token cache. - GetCacheKey(ctx context.Context) string + GetCacheKey() string } diff --git a/credentials/oauth/persistent_auth.go b/credentials/oauth/persistent_auth.go index 81fd41ae9..62ffd482f 100644 --- a/credentials/oauth/persistent_auth.go +++ b/credentials/oauth/persistent_auth.go @@ -109,7 +109,7 @@ func (a *PersistentAuth) Load(ctx context.Context, arg OAuthArgument) (t *oauth2 } defer a.Close() - key := arg.GetCacheKey(ctx) + key := arg.GetCacheKey() t, err = a.cache.Lookup(key) if err != nil { return nil, fmt.Errorf("cache: %w", err) @@ -175,7 +175,7 @@ func (a *PersistentAuth) refresh(ctx context.Context, arg OAuthArgument, oldToke } return nil, err } - err = a.cache.Store(arg.GetCacheKey(ctx), t) + err = a.cache.Store(arg.GetCacheKey(), t) if err != nil { return nil, fmt.Errorf("cache update: %w", err) } @@ -222,7 +222,7 @@ func (a *PersistentAuth) Challenge(ctx context.Context, arg OAuthArgument) (*oau return nil, fmt.Errorf("authorize: %w", err) } // cache token identified by host (and possibly the account id) - err = a.cache.Store(arg.GetCacheKey(ctx), t) + err = a.cache.Store(arg.GetCacheKey(), t) if err != nil { return nil, fmt.Errorf("store: %w", err) } @@ -277,10 +277,10 @@ func (a *PersistentAuth) oauth2Config(ctx context.Context, arg OAuthArgument) (* var err error switch argg := arg.(type) { case WorkspaceOAuthArgument: - endpoints, err = a.client.GetWorkspaceOAuthEndpoints(ctx, argg.GetWorkspaceHost(ctx)) + endpoints, err = a.client.GetWorkspaceOAuthEndpoints(ctx, argg.GetWorkspaceHost()) case AccountOAuthArgument: endpoints, err = a.client.GetAccountOAuthEndpoints( - ctx, argg.GetAccountHost(ctx), argg.GetAccountId(ctx)) + ctx, argg.GetAccountHost(), argg.GetAccountId()) default: return nil, fmt.Errorf("unsupported OAuthArgument type: %T, must implement either WorkspaceOAuthArgument or AccountOAuthArgument interface", arg) } diff --git a/credentials/oauth/workspace_oauth_argument.go b/credentials/oauth/workspace_oauth_argument.go index f2552910e..ad632a413 100644 --- a/credentials/oauth/workspace_oauth_argument.go +++ b/credentials/oauth/workspace_oauth_argument.go @@ -1,7 +1,6 @@ package oauth import ( - "context" "fmt" "strings" ) @@ -12,7 +11,7 @@ type WorkspaceOAuthArgument interface { OAuthArgument // GetWorkspaceHost returns the host of the workspace to authenticate to. - GetWorkspaceHost(ctx context.Context) string + GetWorkspaceHost() string } // BasicWorkspaceOAuthArgument is a basic implementation of the WorkspaceOAuthArgument @@ -35,13 +34,13 @@ func NewBasicWorkspaceOAuthArgument(host string) (BasicWorkspaceOAuthArgument, e } // GetWorkspaceHost returns the host of the workspace to authenticate to. -func (a BasicWorkspaceOAuthArgument) GetWorkspaceHost(ctx context.Context) string { +func (a BasicWorkspaceOAuthArgument) GetWorkspaceHost() string { return a.host } // GetCacheKey returns a unique key for caching the OAuth token for the workspace. // The key is in the format "". -func (a BasicWorkspaceOAuthArgument) GetCacheKey(ctx context.Context) string { +func (a BasicWorkspaceOAuthArgument) GetCacheKey() string { a.host = strings.TrimSuffix(a.host, "/") if !strings.HasPrefix(a.host, "http") { a.host = fmt.Sprintf("https://%s", a.host) From 7441bd1a66e53a485201cb0155e80584b42e3543 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Tue, 28 Jan 2025 17:22:44 +0100 Subject: [PATCH 29/44] work --- config/auth_u2m.go | 37 +++++++++++---------- credentials/cache/file.go | 69 +++++++++++++++------------------------ 2 files changed, 46 insertions(+), 60 deletions(-) diff --git a/config/auth_u2m.go b/config/auth_u2m.go index f47bab21d..d4ac17fd1 100644 --- a/config/auth_u2m.go +++ b/config/auth_u2m.go @@ -56,10 +56,9 @@ func (u u2mCredentials) Configure(ctx context.Context, cfg *Config) (credentials if cfg.Host == "" { return nil, nil } - a := u.auth - if a == nil { + if u.auth == nil { var err error - a, err = oauth.NewPersistentAuth(ctx) + u.auth, err = oauth.NewPersistentAuth(ctx) if err != nil { logger.Debugf(ctx, "failed to create persistent auth: %v, continuing", err) return nil, nil @@ -77,32 +76,34 @@ func (u u2mCredentials) Configure(ctx context.Context, cfg *Config) (credentials return nil, fmt.Errorf("oidc: %w", err) } + // Construct the visitor, and try to load the credential from the token + // cache. If absent, fall back to the next credentials strategy. If a token + // is present but cannot be loaded (e.g. expired), return an error. + // Otherwise, fall back to the next credentials strategy. + visitor := u.makeVisitor(arg) r, err := http.NewRequestWithContext(ctx, http.MethodGet, "", nil) if err != nil { return nil, fmt.Errorf("http request: %w", err) } + if err := visitor(r); err != nil { + if u.errorHandler != nil { + return nil, u.errorHandler(ctx, cfg, arg, err) + } + return nil, err + } + + return credentials.NewCredentialsProvider(visitor), nil +} - f := func(r *http.Request) error { - token, err := a.Load(r.Context(), arg) +func (u u2mCredentials) makeVisitor(arg oauth.OAuthArgument) func(*http.Request) error { + return func(r *http.Request) error { + token, err := u.auth.Load(r.Context(), arg) if err != nil { return fmt.Errorf("oidc: %w", err) } r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken)) return nil } - - // Try to load the credential from the token cache. If absent, fall back to - // the next credentials strategy. If a token is present but cannot be loaded - // (e.g. expired), return an error. Otherwise, fall back to the next - // credentials strategy. - if err := f(r); err != nil { - if u.errorHandler != nil { - return nil, u.errorHandler(ctx, cfg, arg, err) - } - return nil, err - } - - return credentials.NewCredentialsProvider(f), nil } func defaultGetOAuthArg(_ context.Context, cfg *Config) (oauth.OAuthArgument, error) { diff --git a/credentials/cache/file.go b/credentials/cache/file.go index ca8230b53..cf585aaac 100644 --- a/credentials/cache/file.go +++ b/credentials/cache/file.go @@ -49,23 +49,17 @@ type tokenCacheFile struct { Tokens map[string]*oauth2.Token `json:"tokens"` } -type FileTokenCacheOpt func(*FileTokenCache) +type FileTokenCacheOption func(*fileTokenCache) -func WithFileLocation(fileLocation string) FileTokenCacheOpt { - return func(c *FileTokenCache) { +func WithFileLocation(fileLocation string) FileTokenCacheOption { + return func(c *fileTokenCache) { c.fileLocation = fileLocation } } -func WithLocker(locker sync.Locker) FileTokenCacheOpt { - return func(c *FileTokenCache) { - c.locker = locker - } -} - -// FileTokenCache caches tokens in "~/.databricks/token-cache.json". FileTokenCache +// fileTokenCache caches tokens in "~/.databricks/token-cache.json". fileTokenCache // implements the TokenCache interface. -type FileTokenCache struct { +type fileTokenCache struct { fileLocation string // locker protects the token cache file from concurrent reads and writes. @@ -78,15 +72,15 @@ type FileTokenCache struct { // 0600 and the directory is created with owner permissions 0700. If the cache // file is corrupt or if its version does not match tokenCacheVersion, an error // is returned. -func NewFileTokenCache(opts ...FileTokenCacheOpt) (*FileTokenCache, error) { - c := &FileTokenCache{} +func NewFileTokenCache(opts ...FileTokenCacheOption) (TokenCache, error) { + c := &fileTokenCache{} for _, opt := range opts { opt(c) } if err := c.init(); err != nil { return nil, err } - // verify the cache is working + // Fail fast if the cache is not working. if _, err := c.load(); err != nil { return nil, fmt.Errorf("load: %w", err) } @@ -94,7 +88,7 @@ func NewFileTokenCache(opts ...FileTokenCacheOpt) (*FileTokenCache, error) { } // Store implements the TokenCache interface. -func (c *FileTokenCache) Store(key string, t *oauth2.Token) error { +func (c *fileTokenCache) Store(key string, t *oauth2.Token) error { c.locker.Lock() defer c.locker.Unlock() f, err := c.load() @@ -113,7 +107,7 @@ func (c *FileTokenCache) Store(key string, t *oauth2.Token) error { } // Lookup implements the TokenCache interface. -func (c *FileTokenCache) Lookup(key string) (*oauth2.Token, error) { +func (c *fileTokenCache) Lookup(key string) (*oauth2.Token, error) { c.locker.Lock() defer c.locker.Unlock() f, err := c.load() @@ -129,7 +123,7 @@ func (c *FileTokenCache) Lookup(key string) (*oauth2.Token, error) { // init initializes the token cache file. It creates the file and directory if // they do not already exist. -func (c *FileTokenCache) init() error { +func (c *fileTokenCache) init() error { // set the default file location if c.fileLocation == "" { home, err := os.UserHomeDir() @@ -138,21 +132,17 @@ func (c *FileTokenCache) init() error { } c.fileLocation = filepath.Join(home, tokenCacheFilePath) } - // create the directory if it doesn't already exist - if _, err := os.Stat(filepath.Dir(c.fileLocation)); err != nil { + // Create the cache file if it does not exist. + if _, err := os.Stat(c.fileLocation); err != nil { if !os.IsNotExist(err) { - return fmt.Errorf("stat directory: %w", err) + return fmt.Errorf("stat file: %w", err) } - // create the directory + // Create the parent directories if needed. if err := os.MkdirAll(filepath.Dir(c.fileLocation), ownerExecReadWrite); err != nil { return fmt.Errorf("mkdir: %w", err) } - } - // create the file if it doesn't already exist - if _, err := os.Stat(c.fileLocation); err != nil { - if !os.IsNotExist(err) { - return fmt.Errorf("stat file: %w", err) - } + + // Create an empty cache file. f := &tokenCacheFile{ Version: tokenCacheVersion, Tokens: map[string]*oauth2.Token{}, @@ -165,31 +155,28 @@ func (c *FileTokenCache) init() error { return fmt.Errorf("write: %w", err) } } - // Initialize the locker if it is not already set. - if c.locker == nil { - home, err := os.UserHomeDir() - if err != nil { - return fmt.Errorf("home: %w", err) - } + // Initialize the locker. + home, err := os.UserHomeDir() + if err != nil { + return fmt.Errorf("home: %w", err) + } - c.locker, err = newLocker(filepath.Join(home, lockFilePath)) - if err != nil { - return fmt.Errorf("locker: %w", err) - } + c.locker, err = newLocker(filepath.Join(home, lockFilePath)) + if err != nil { + return fmt.Errorf("locker: %w", err) } return nil } // load loads the token cache file from disk. If the file is corrupt or if its // version does not match tokenCacheVersion, it returns an error. -func (c *FileTokenCache) load() (*tokenCacheFile, error) { +func (c *fileTokenCache) load() (*tokenCacheFile, error) { raw, err := os.ReadFile(c.fileLocation) if err != nil { return nil, fmt.Errorf("read: %w", err) } f := &tokenCacheFile{} - err = json.Unmarshal(raw, &f) - if err != nil { + if err := json.Unmarshal(raw, &f); err != nil { return nil, fmt.Errorf("parse: %w", err) } if f.Version != tokenCacheVersion { @@ -200,5 +187,3 @@ func (c *FileTokenCache) load() (*tokenCacheFile, error) { } return f, nil } - -var _ TokenCache = (*FileTokenCache)(nil) From 3550bbaf4ddd4422d4ba7662ef7a5e9e6a07b44e Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Thu, 30 Jan 2025 10:30:39 +0100 Subject: [PATCH 30/44] address comment --- credentials/cache/cache.go | 19 ------ credentials/oauth/client.go | 37 ---------- credentials/oauth/oidc.go | 43 ------------ .../{oauth => u2m}/account_oauth_argument.go | 2 +- credentials/u2m/cache/cache.go | 31 +++++++++ credentials/{ => u2m}/cache/file.go | 0 credentials/{ => u2m}/cache/file_test.go | 0 credentials/{ => u2m}/cache/lock.go | 0 credentials/{oauth => u2m}/callback.go | 2 +- credentials/u2m/client.go | 67 +++++++++++++++++++ .../oidc_test.go => u2m/client_test.go} | 10 +-- credentials/u2m/doc.go | 41 ++++++++++++ credentials/{oauth => u2m}/error.go | 2 +- credentials/{oauth => u2m}/oauth_argument.go | 2 +- credentials/{oauth => u2m}/page.tmpl | 0 credentials/{oauth => u2m}/persistent_auth.go | 16 ++--- .../{oauth => u2m}/persistent_auth_test.go | 38 +++++------ .../workspace_oauth_argument.go | 2 +- 18 files changed, 177 insertions(+), 135 deletions(-) delete mode 100644 credentials/cache/cache.go delete mode 100644 credentials/oauth/client.go delete mode 100644 credentials/oauth/oidc.go rename credentials/{oauth => u2m}/account_oauth_argument.go (99%) create mode 100644 credentials/u2m/cache/cache.go rename credentials/{ => u2m}/cache/file.go (100%) rename credentials/{ => u2m}/cache/file_test.go (100%) rename credentials/{ => u2m}/cache/lock.go (100%) rename credentials/{oauth => u2m}/callback.go (99%) create mode 100644 credentials/u2m/client.go rename credentials/{oauth/oidc_test.go => u2m/client_test.go} (73%) create mode 100644 credentials/u2m/doc.go rename credentials/{oauth => u2m}/error.go (96%) rename credentials/{oauth => u2m}/oauth_argument.go (96%) rename credentials/{oauth => u2m}/page.tmpl (100%) rename credentials/{oauth => u2m}/persistent_auth.go (96%) rename credentials/{oauth => u2m}/persistent_auth_test.go (83%) rename credentials/{oauth => u2m}/workspace_oauth_argument.go (99%) diff --git a/credentials/cache/cache.go b/credentials/cache/cache.go deleted file mode 100644 index 0562b41fb..000000000 --- a/credentials/cache/cache.go +++ /dev/null @@ -1,19 +0,0 @@ -package cache - -import ( - "errors" - - "golang.org/x/oauth2" -) - -// TokenCache is an interface for storing and looking up OAuth tokens. -type TokenCache interface { - // Store stores the token with the given key, replacing any existing token. - Store(key string, t *oauth2.Token) error - - // Lookup looks up the token with the given key. If the token is not found, it - // returns ErrNotConfigured. - Lookup(key string) (*oauth2.Token, error) -} - -var ErrNotConfigured = errors.New("databricks OAuth is not configured for this host") diff --git a/credentials/oauth/client.go b/credentials/oauth/client.go deleted file mode 100644 index 0e233d14d..000000000 --- a/credentials/oauth/client.go +++ /dev/null @@ -1,37 +0,0 @@ -package oauth - -import ( - "context" - "net/http" - - "github.com/databricks/databricks-sdk-go/httpclient" -) - -// OAuthClient provides the http functionality needed for interacting with the -// Databricks OAuth APIs. -type OAuthClient interface { - // GetHttpClient returns an HTTP client for OAuth2 requests. - GetHttpClient(context.Context) *http.Client - - // GetWorkspaceOAuthEndpoints returns the OAuth2 endpoints for the workspace. - GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*OAuthAuthorizationServer, error) - - // GetAccountOAuthEndpoints returns the OAuth2 endpoints for the account. - GetAccountOAuthEndpoints(ctx context.Context, accountHost string, accountId string) (*OAuthAuthorizationServer, error) -} - -type BasicOAuthClient struct { - client *httpclient.ApiClient -} - -func (c *BasicOAuthClient) GetHttpClient(_ context.Context) *http.Client { - return c.client.ToHttpClient() -} - -func (c *BasicOAuthClient) GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*OAuthAuthorizationServer, error) { - return GetWorkspaceOAuthEndpoints(ctx, c.client, workspaceHost) -} - -func (c *BasicOAuthClient) GetAccountOAuthEndpoints(ctx context.Context, accountHost string, accountId string) (*OAuthAuthorizationServer, error) { - return GetAccountOAuthEndpoints(ctx, accountHost, accountId) -} diff --git a/credentials/oauth/oidc.go b/credentials/oauth/oidc.go deleted file mode 100644 index 67012d2ab..000000000 --- a/credentials/oauth/oidc.go +++ /dev/null @@ -1,43 +0,0 @@ -package oauth - -import ( - "context" - "errors" - "fmt" - - "github.com/databricks/databricks-sdk-go/httpclient" -) - -var ErrOAuthNotSupported = errors.New("databricks OAuth is not supported for this host") - -// GetAccountOAuthEndpoints returns the OAuth endpoints for the given account. -func GetAccountOAuthEndpoints(ctx context.Context, accountsHost, accountId string) (*OAuthAuthorizationServer, error) { - return &OAuthAuthorizationServer{ - AuthorizationEndpoint: fmt.Sprintf("%s/oidc/accounts/%s/v1/authorize", accountsHost, accountId), - TokenEndpoint: fmt.Sprintf("%s/oidc/accounts/%s/v1/token", accountsHost, accountId), - }, nil -} - -// GetWorkspaceOAuthEndpoints returns the OAuth endpoints for the given workspace, -// It queries the OIDC discovery endpoint to get the OAuth endpoints using the -// provided ApiClient. -func GetWorkspaceOAuthEndpoints(ctx context.Context, c *httpclient.ApiClient, host string) (*OAuthAuthorizationServer, error) { - oidc := fmt.Sprintf("%s/oidc/.well-known/oauth-authorization-server", host) - var oauthEndpoints OAuthAuthorizationServer - if err := c.Do(ctx, "GET", oidc, httpclient.WithResponseUnmarshal(&oauthEndpoints)); err != nil { - return nil, ErrOAuthNotSupported - } - return &oauthEndpoints, nil -} - -// OAuthAuthorizationServer contains the OAuth endpoints for a Databricks account -// or workspace. -type OAuthAuthorizationServer struct { - // AuthorizationEndpoint is the URL to redirect users to for authorization. - // It typically ends with /v1/authroize. - AuthorizationEndpoint string `json:"authorization_endpoint"` - - // TokenEndpoint is the URL to exchange an authorization code for an access token. - // It typically ends with /v1/token. - TokenEndpoint string `json:"token_endpoint"` -} diff --git a/credentials/oauth/account_oauth_argument.go b/credentials/u2m/account_oauth_argument.go similarity index 99% rename from credentials/oauth/account_oauth_argument.go rename to credentials/u2m/account_oauth_argument.go index e8ea46ffe..d28206336 100644 --- a/credentials/oauth/account_oauth_argument.go +++ b/credentials/u2m/account_oauth_argument.go @@ -1,4 +1,4 @@ -package oauth +package u2m import ( "fmt" diff --git a/credentials/u2m/cache/cache.go b/credentials/u2m/cache/cache.go new file mode 100644 index 000000000..c059fa546 --- /dev/null +++ b/credentials/u2m/cache/cache.go @@ -0,0 +1,31 @@ +/* +Package cache provides an interface for storing and looking up OAuth tokens. + +The cache should be primarily used for user-to-machine (U2M) OAuth flows. In U2M +OAuth flows, the application needs to store the token for later use, such as in +a separate process, and the cache provides a way to do so without requiring the +user to follow the OAuth flow again. + +In machine-to-machine (M2M) OAuth flows, the application is configured with a +secret and can fetch a new token on demand without user interaction, so the +token cache is not necessary. +*/ +package cache + +import ( + "errors" + + "golang.org/x/oauth2" +) + +// TokenCache is an interface for storing and looking up OAuth tokens. +type TokenCache interface { + // Store stores the token with the given key, replacing any existing token. + Store(key string, t *oauth2.Token) error + + // Lookup looks up the token with the given key. If the token is not found, it + // returns ErrNotConfigured. + Lookup(key string) (*oauth2.Token, error) +} + +var ErrNotConfigured = errors.New("databricks OAuth is not configured for this host") diff --git a/credentials/cache/file.go b/credentials/u2m/cache/file.go similarity index 100% rename from credentials/cache/file.go rename to credentials/u2m/cache/file.go diff --git a/credentials/cache/file_test.go b/credentials/u2m/cache/file_test.go similarity index 100% rename from credentials/cache/file_test.go rename to credentials/u2m/cache/file_test.go diff --git a/credentials/cache/lock.go b/credentials/u2m/cache/lock.go similarity index 100% rename from credentials/cache/lock.go rename to credentials/u2m/cache/lock.go diff --git a/credentials/oauth/callback.go b/credentials/u2m/callback.go similarity index 99% rename from credentials/oauth/callback.go rename to credentials/u2m/callback.go index c8d1ecc04..c6022300c 100644 --- a/credentials/oauth/callback.go +++ b/credentials/u2m/callback.go @@ -1,4 +1,4 @@ -package oauth +package u2m import ( "context" diff --git a/credentials/u2m/client.go b/credentials/u2m/client.go new file mode 100644 index 000000000..099360a75 --- /dev/null +++ b/credentials/u2m/client.go @@ -0,0 +1,67 @@ +package u2m + +import ( + "context" + "errors" + "fmt" + "net/http" + + "github.com/databricks/databricks-sdk-go/httpclient" +) + +// OAuthClient provides the http functionality needed for interacting with the +// Databricks OAuth APIs. +type OAuthClient interface { + // GetHttpClient returns an HTTP client for OAuth2 requests. + GetHttpClient(context.Context) *http.Client + + // GetWorkspaceOAuthEndpoints returns the OAuth2 endpoints for the workspace. + GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*OAuthAuthorizationServer, error) + + // GetAccountOAuthEndpoints returns the OAuth2 endpoints for the account. + GetAccountOAuthEndpoints(ctx context.Context, accountHost string, accountId string) (*OAuthAuthorizationServer, error) +} + +// BasicOAuthClient is an implementation of the OAuthClient interface. +type BasicOAuthClient struct { + client *httpclient.ApiClient +} + +func (c *BasicOAuthClient) GetHttpClient(_ context.Context) *http.Client { + return c.client.ToHttpClient() +} + +// GetWorkspaceOAuthEndpoints returns the OAuth endpoints for the given workspace. +// It queries the OIDC discovery endpoint to get the OAuth endpoints using the +// provided ApiClient. +func (c *BasicOAuthClient) GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*OAuthAuthorizationServer, error) { + oidc := fmt.Sprintf("%s/oidc/.well-known/oauth-authorization-server", workspaceHost) + var oauthEndpoints OAuthAuthorizationServer + if err := c.client.Do(ctx, "GET", oidc, httpclient.WithResponseUnmarshal(&oauthEndpoints)); err != nil { + return nil, ErrOAuthNotSupported + } + return &oauthEndpoints, nil +} + +// GetAccountOAuthEndpoints returns the OAuth2 endpoints for the account. The +// account-level OAuth endpoints are fixed based on the account ID and host. +func (c *BasicOAuthClient) GetAccountOAuthEndpoints(ctx context.Context, accountHost string, accountId string) (*OAuthAuthorizationServer, error) { + return &OAuthAuthorizationServer{ + AuthorizationEndpoint: fmt.Sprintf("%s/oidc/accounts/%s/v1/authorize", accountHost, accountId), + TokenEndpoint: fmt.Sprintf("%s/oidc/accounts/%s/v1/token", accountHost, accountId), + }, nil +} + +var ErrOAuthNotSupported = errors.New("databricks OAuth is not supported for this host") + +// OAuthAuthorizationServer contains the OAuth endpoints for a Databricks account +// or workspace. +type OAuthAuthorizationServer struct { + // AuthorizationEndpoint is the URL to redirect users to for authorization. + // It typically ends with /v1/authroize. + AuthorizationEndpoint string `json:"authorization_endpoint"` + + // TokenEndpoint is the URL to exchange an authorization code for an access token. + // It typically ends with /v1/token. + TokenEndpoint string `json:"token_endpoint"` +} diff --git a/credentials/oauth/oidc_test.go b/credentials/u2m/client_test.go similarity index 73% rename from credentials/oauth/oidc_test.go rename to credentials/u2m/client_test.go index ba199ea32..b9316b160 100644 --- a/credentials/oauth/oidc_test.go +++ b/credentials/u2m/client_test.go @@ -1,4 +1,4 @@ -package oauth +package u2m import ( "context" @@ -9,8 +9,9 @@ import ( "github.com/stretchr/testify/assert" ) -func TestGetAccountOAuthEndpoints(t *testing.T) { - s, err := GetAccountOAuthEndpoints(context.Background(), "https://abc", "xyz") +func TestBasicOAuthClient_GetAccountOAuthEndpoints(t *testing.T) { + c := &BasicOAuthClient{} + s, err := c.GetAccountOAuthEndpoints(context.Background(), "https://abc", "xyz") assert.NoError(t, err) assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/authorize", s.AuthorizationEndpoint) assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/token", s.TokenEndpoint) @@ -28,7 +29,8 @@ func TestGetWorkspaceOAuthEndpoints(t *testing.T) { }, }, }) - endpoints, err := GetWorkspaceOAuthEndpoints(context.Background(), p, "https://abc") + c := &BasicOAuthClient{client: p} + endpoints, err := c.GetWorkspaceOAuthEndpoints(context.Background(), "https://abc") assert.NoError(t, err) assert.Equal(t, "a", endpoints.AuthorizationEndpoint) assert.Equal(t, "b", endpoints.TokenEndpoint) diff --git a/credentials/u2m/doc.go b/credentials/u2m/doc.go new file mode 100644 index 000000000..85835347f --- /dev/null +++ b/credentials/u2m/doc.go @@ -0,0 +1,41 @@ +/* +Package u2m supports the user-to-machine (U2M) OAuth flow for authenticating with Databricks. + +Databricks uses the authorization code flow from OAuth 2.0 to authenticate users. This flow +consists of four steps: + 1. Retrieve an authorization code for a user by opening a browser and directing them to the + Databricks authorization URL. + 2. Exchange the authorization code for an access token. + 3. Use the access token to authenticate with Databricks. + 4. When the access token expires, use the refresh token to get a new access token. + +The token and authorization endpoints for Databricks vary depending on whether the host is +an account- or workspace-level host. Account-level endpoints are fixed based on the account +ID and host, while workspace-level endpoints are discovered using the OIDC discovery endpoint +at /oidc/.well-known/oauth-authorization-server. + +To trigger the authorization flow, construct a PersistentAuth object and call the +Challenge() method: + + auth, err := oauth.NewPersistentAuth(ctx) + if err != nil { + log.Fatalf("failed to create persistent auth: %v", err) + } + token, err := auth.Challenge(ctx, oauth.BasicAccountOAuthArgument{ + AccountHost: "https://accounts.cloud.databricks.com", + AccountID: "xyz", + }) + +Because the U2M flow requires user interaction, the resulting access token and refresh token +can be stored in a persistent cache to avoid prompting the user for credentials on every +authentication attempt. By default, the cache is stored in ~/.databricks/token-cache.json. +Retrieve the cached token by calling the Load() method: + + token, err := auth.Load(ctx, oauth.BasicAccountOAuthArgument{ + AccountHost: "https://accounts.cloud.databricks.com", + AccountID: "xyz", + }) + +See the cache package for more information on customizing the cache. +*/ +package u2m diff --git a/credentials/oauth/error.go b/credentials/u2m/error.go similarity index 96% rename from credentials/oauth/error.go rename to credentials/u2m/error.go index a68c9fdaa..6c243fc78 100644 --- a/credentials/oauth/error.go +++ b/credentials/u2m/error.go @@ -1,4 +1,4 @@ -package oauth +package u2m // InvalidRefreshTokenError is returned from PersistentAuth's Load() method // if the access token has expired and the refresh token in the token cache diff --git a/credentials/oauth/oauth_argument.go b/credentials/u2m/oauth_argument.go similarity index 96% rename from credentials/oauth/oauth_argument.go rename to credentials/u2m/oauth_argument.go index ce815c67d..f2d2ebc5d 100644 --- a/credentials/oauth/oauth_argument.go +++ b/credentials/u2m/oauth_argument.go @@ -1,4 +1,4 @@ -package oauth +package u2m // OAuthArgument is an interface that provides the necessary information to // authenticate with PersistentAuth. Implementations of this interface must diff --git a/credentials/oauth/page.tmpl b/credentials/u2m/page.tmpl similarity index 100% rename from credentials/oauth/page.tmpl rename to credentials/u2m/page.tmpl diff --git a/credentials/oauth/persistent_auth.go b/credentials/u2m/persistent_auth.go similarity index 96% rename from credentials/oauth/persistent_auth.go rename to credentials/u2m/persistent_auth.go index 62ffd482f..9b0aa73a1 100644 --- a/credentials/oauth/persistent_auth.go +++ b/credentials/u2m/persistent_auth.go @@ -1,4 +1,4 @@ -package oauth +package u2m import ( "context" @@ -11,7 +11,7 @@ import ( "net" "time" - "github.com/databricks/databricks-sdk-go/credentials/cache" + cache "github.com/databricks/databricks-sdk-go/credentials/u2m/cache" "github.com/databricks/databricks-sdk-go/httpclient" "github.com/databricks/databricks-sdk-go/logger" "github.com/databricks/databricks-sdk-go/retries" @@ -21,13 +21,14 @@ import ( ) const ( - // these values are predefined by Databricks as a public client - // and is specific to this application only. Using these values - // for other applications is not allowed. - appClientID = "databricks-cli" + // appClientId is the default client ID used by the SDK for U2M OAuth. + appClientID = "databricks-cli" + + // appRedirectAddr is the default address for the OAuth2 callback server. appRedirectAddr = "localhost:8020" - // maximum amount of time to acquire listener on appRedirectAddr + // listenerTimeout is the maximum amount of time to acquire listener on + // appRedirectAddr. listenerTimeout = 45 * time.Second ) @@ -102,7 +103,6 @@ func (a *PersistentAuth) Load(ctx context.Context, arg OAuthArgument) (t *oauth2 if err := a.validateArg(arg); err != nil { return nil, err } - // TODO: remove this listener after several releases. err = a.startListener(ctx) if err != nil { return nil, fmt.Errorf("starting listener: %w", err) diff --git a/credentials/oauth/persistent_auth_test.go b/credentials/u2m/persistent_auth_test.go similarity index 83% rename from credentials/oauth/persistent_auth_test.go rename to credentials/u2m/persistent_auth_test.go index 5410aab30..82a57c6b8 100644 --- a/credentials/oauth/persistent_auth_test.go +++ b/credentials/u2m/persistent_auth_test.go @@ -1,4 +1,4 @@ -package oauth_test +package u2m_test import ( "context" @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/databricks/databricks-sdk-go/credentials/oauth" + "github.com/databricks/databricks-sdk-go/credentials/u2m" "github.com/databricks/databricks-sdk-go/httpclient/fixtures" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -44,10 +44,10 @@ func TestLoad(t *testing.T) { }, nil }, } - p, err := oauth.NewPersistentAuth(context.Background(), oauth.WithTokenCache(cache)) + p, err := u2m.NewPersistentAuth(context.Background(), u2m.WithTokenCache(cache)) require.NoError(t, err) defer p.Close() - arg, err := oauth.NewBasicAccountOAuthArgument("https://abc", "xyz") + arg, err := u2m.NewBasicAccountOAuthArgument("https://abc", "xyz") assert.NoError(t, err) tok, err := p.Load(context.Background(), arg) assert.NoError(t, err) @@ -65,15 +65,15 @@ func (m MockOAuthClient) GetHttpClient(_ context.Context) *http.Client { } } -func (m MockOAuthClient) GetAccountOAuthEndpoints(ctx context.Context, accountHost string, accountId string) (*oauth.OAuthAuthorizationServer, error) { - return &oauth.OAuthAuthorizationServer{ +func (m MockOAuthClient) GetAccountOAuthEndpoints(ctx context.Context, accountHost string, accountId string) (*u2m.OAuthAuthorizationServer, error) { + return &u2m.OAuthAuthorizationServer{ AuthorizationEndpoint: fmt.Sprintf("%s/oidc/accounts/%s/v1/authorize", accountHost, accountId), TokenEndpoint: fmt.Sprintf("%s/oidc/accounts/%s/v1/token", accountHost, accountId), }, nil } -func (m MockOAuthClient) GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*oauth.OAuthAuthorizationServer, error) { - return &oauth.OAuthAuthorizationServer{ +func (m MockOAuthClient) GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*u2m.OAuthAuthorizationServer, error) { + return &u2m.OAuthAuthorizationServer{ AuthorizationEndpoint: fmt.Sprintf("%s/oidc/v1/authorize", workspaceHost), TokenEndpoint: fmt.Sprintf("%s/oidc/v1/token", workspaceHost), }, nil @@ -97,10 +97,10 @@ func TestLoadRefresh(t *testing.T) { return nil }, } - p, err := oauth.NewPersistentAuth( + p, err := u2m.NewPersistentAuth( context.Background(), - oauth.WithTokenCache(cache), - oauth.WithOAuthClient(&MockOAuthClient{ + u2m.WithTokenCache(cache), + u2m.WithOAuthClient(&MockOAuthClient{ Transport: fixtures.SliceTransport{ { Method: "POST", @@ -115,7 +115,7 @@ func TestLoadRefresh(t *testing.T) { ) require.NoError(t, err) defer p.Close() - arg, err := oauth.NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "xyz") + arg, err := u2m.NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "xyz") assert.NoError(t, err) tok, err := p.Load(ctx, arg) assert.NoError(t, err) @@ -146,11 +146,11 @@ func TestChallenge(t *testing.T) { return nil }, } - p, err := oauth.NewPersistentAuth( + p, err := u2m.NewPersistentAuth( context.Background(), - oauth.WithTokenCache(cache), - oauth.WithBrowser(browser), - oauth.WithOAuthClient(&MockOAuthClient{ + u2m.WithTokenCache(cache), + u2m.WithBrowser(browser), + u2m.WithOAuthClient(&MockOAuthClient{ Transport: fixtures.SliceTransport{ { Method: "POST", @@ -165,7 +165,7 @@ func TestChallenge(t *testing.T) { ) require.NoError(t, err) defer p.Close() - arg, err := oauth.NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "xyz") + arg, err := u2m.NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "xyz") assert.NoError(t, err) tokenc := make(chan *oauth2.Token) @@ -203,10 +203,10 @@ func TestChallengeFailed(t *testing.T) { browserOpened <- query.Get("state") return nil } - p, err := oauth.NewPersistentAuth(context.Background(), oauth.WithBrowser(browser)) + p, err := u2m.NewPersistentAuth(context.Background(), u2m.WithBrowser(browser)) require.NoError(t, err) defer p.Close() - arg, err := oauth.NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "xyz") + arg, err := u2m.NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "xyz") assert.NoError(t, err) tokenc := make(chan *oauth2.Token) diff --git a/credentials/oauth/workspace_oauth_argument.go b/credentials/u2m/workspace_oauth_argument.go similarity index 99% rename from credentials/oauth/workspace_oauth_argument.go rename to credentials/u2m/workspace_oauth_argument.go index ad632a413..99a8a21e5 100644 --- a/credentials/oauth/workspace_oauth_argument.go +++ b/credentials/u2m/workspace_oauth_argument.go @@ -1,4 +1,4 @@ -package oauth +package u2m import ( "fmt" From fcb031fe40256251e5a896b96e7fbaad7ac323dd Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Thu, 30 Jan 2025 10:40:19 +0100 Subject: [PATCH 31/44] fix --- config/auth_m2m_test.go | 6 ++-- config/auth_u2m.go | 32 ++++++++++----------- config/auth_u2m_test.go | 44 ++++++++++++++-------------- config/config.go | 13 +++++---- config/config_test.go | 6 ++-- config/in_memory_test.go | 2 +- credentials/u2m/cache/file.go | 14 +-------- credentials/u2m/cache/lock.go | 46 ------------------------------ credentials/u2m/client.go | 7 +++-- credentials/u2m/client_test.go | 2 +- credentials/u2m/persistent_auth.go | 2 +- go.mod | 1 - go.sum | 5 ---- 13 files changed, 60 insertions(+), 120 deletions(-) delete mode 100644 credentials/u2m/cache/lock.go diff --git a/config/auth_m2m_test.go b/config/auth_m2m_test.go index 7181be436..1d77a126a 100644 --- a/config/auth_m2m_test.go +++ b/config/auth_m2m_test.go @@ -4,7 +4,7 @@ import ( "net/url" "testing" - "github.com/databricks/databricks-sdk-go/credentials/oauth" + "github.com/databricks/databricks-sdk-go/credentials/u2m" "github.com/databricks/databricks-sdk-go/httpclient/fixtures" "github.com/stretchr/testify/require" "golang.org/x/oauth2" @@ -17,7 +17,7 @@ func TestM2mHappyFlow(t *testing.T) { ClientSecret: "c", HTTPTransport: fixtures.MappingTransport{ "GET /oidc/.well-known/oauth-authorization-server": { - Response: oauth.OAuthAuthorizationServer{ + Response: u2m.OAuthAuthorizationServer{ AuthorizationEndpoint: "https://localhost:1234/dummy/auth", TokenEndpoint: "https://localhost:1234/dummy/token", }, @@ -81,5 +81,5 @@ func TestM2mNotSupported(t *testing.T) { }, }, }) - require.ErrorIs(t, err, oauth.ErrOAuthNotSupported) + require.ErrorIs(t, err, u2m.ErrOAuthNotSupported) } diff --git a/config/auth_u2m.go b/config/auth_u2m.go index d4ac17fd1..91dc36ca6 100644 --- a/config/auth_u2m.go +++ b/config/auth_u2m.go @@ -8,8 +8,8 @@ import ( "strings" "github.com/databricks/databricks-sdk-go/config/credentials" - "github.com/databricks/databricks-sdk-go/credentials/cache" - "github.com/databricks/databricks-sdk-go/credentials/oauth" + "github.com/databricks/databricks-sdk-go/credentials/u2m" + "github.com/databricks/databricks-sdk-go/credentials/u2m/cache" "github.com/databricks/databricks-sdk-go/logger" ) @@ -27,18 +27,18 @@ import ( type u2mCredentials struct { // auth is the persistent auth object to use. If not specified, a new one will // be created, using the default cache and locker. - auth *oauth.PersistentAuth + auth *u2m.PersistentAuth // getOAuthArg is a function that returns the OAuth argument to use for // loading the OAuth session token. If not specified, the OAuth argument is // determined by the account host and account ID or workspace host in the // Config. - getOAuthArg func(context.Context, *Config) (oauth.OAuthArgument, error) + getOAuthArg func(context.Context, *Config) (u2m.OAuthArgument, error) // errorHandler controls the behavior of Configure() when loading the OAuth // token fails. If not specified, any error will cause Configure() to return // said error. - errorHandler func(context.Context, *Config, oauth.OAuthArgument, error) error + errorHandler func(context.Context, *Config, u2m.OAuthArgument, error) error name string } @@ -58,14 +58,14 @@ func (u u2mCredentials) Configure(ctx context.Context, cfg *Config) (credentials } if u.auth == nil { var err error - u.auth, err = oauth.NewPersistentAuth(ctx) + u.auth, err = u2m.NewPersistentAuth(ctx) if err != nil { logger.Debugf(ctx, "failed to create persistent auth: %v, continuing", err) return nil, nil } } - var arg oauth.OAuthArgument + var arg u2m.OAuthArgument var err error if u.getOAuthArg != nil { arg, err = u.getOAuthArg(ctx, cfg) @@ -95,7 +95,7 @@ func (u u2mCredentials) Configure(ctx context.Context, cfg *Config) (credentials return credentials.NewCredentialsProvider(visitor), nil } -func (u u2mCredentials) makeVisitor(arg oauth.OAuthArgument) func(*http.Request) error { +func (u u2mCredentials) makeVisitor(arg u2m.OAuthArgument) func(*http.Request) error { return func(r *http.Request) error { token, err := u.auth.Load(r.Context(), arg) if err != nil { @@ -106,11 +106,11 @@ func (u u2mCredentials) makeVisitor(arg oauth.OAuthArgument) func(*http.Request) } } -func defaultGetOAuthArg(_ context.Context, cfg *Config) (oauth.OAuthArgument, error) { +func defaultGetOAuthArg(_ context.Context, cfg *Config) (u2m.OAuthArgument, error) { if cfg.IsAccountClient() { - return oauth.NewBasicAccountOAuthArgument(cfg.Host, cfg.AccountID) + return u2m.NewBasicAccountOAuthArgument(cfg.Host, cfg.AccountID) } - return oauth.NewBasicWorkspaceOAuthArgument(cfg.Host) + return u2m.NewBasicWorkspaceOAuthArgument(cfg.Host) } var _ CredentialsStrategy = u2mCredentials{} @@ -136,7 +136,7 @@ func (e *CliInvalidRefreshTokenError) Unwrap() error { // buildLoginCommand returns the `databricks auth login` command that the user // can run to reauthenticate. The command is prepopulated with the profile, host // and/or account ID. -func buildLoginCommand(profile string, arg oauth.OAuthArgument) string { +func buildLoginCommand(profile string, arg u2m.OAuthArgument) string { cmd := []string{ "databricks", "auth", @@ -146,9 +146,9 @@ func buildLoginCommand(profile string, arg oauth.OAuthArgument) string { cmd = append(cmd, "--profile", profile) } else { switch arg := arg.(type) { - case oauth.AccountOAuthArgument: + case u2m.AccountOAuthArgument: cmd = append(cmd, "--host", arg.GetAccountHost(), "--account-id", arg.GetAccountId()) - case oauth.WorkspaceOAuthArgument: + case u2m.WorkspaceOAuthArgument: cmd = append(cmd, "--host", arg.GetWorkspaceHost()) } } @@ -159,7 +159,7 @@ func buildLoginCommand(profile string, arg oauth.OAuthArgument) string { // of the earlier `databricks-cli` credentials strategy which invoked the // `databricks auth token` command. var DatabricksCliCredentials = u2mCredentials{ - errorHandler: func(ctx context.Context, cfg *Config, arg oauth.OAuthArgument, err error) error { + errorHandler: func(ctx context.Context, cfg *Config, arg u2m.OAuthArgument, err error) error { // If the current OAuth argument doesn't have a corresponding session // token, fall back to the next credentials strategy. if errors.Is(err, cache.ErrNotConfigured) { @@ -169,7 +169,7 @@ var DatabricksCliCredentials = u2mCredentials{ // return a special error message for invalid refresh tokens. To help // users easily reauthenticate, include a command that the user can // run, prepopulating the profile, host and/or account ID. - target := &oauth.InvalidRefreshTokenError{} + target := &u2m.InvalidRefreshTokenError{} if errors.As(err, &target) { return &CliInvalidRefreshTokenError{ loginCommand: buildLoginCommand(cfg.Profile, arg), diff --git a/config/auth_u2m_test.go b/config/auth_u2m_test.go index 3d9fe964a..1ead12e0a 100644 --- a/config/auth_u2m_test.go +++ b/config/auth_u2m_test.go @@ -8,8 +8,8 @@ import ( "testing" "time" - "github.com/databricks/databricks-sdk-go/credentials/cache" - "github.com/databricks/databricks-sdk-go/credentials/oauth" + "github.com/databricks/databricks-sdk-go/credentials/u2m/cache" + "github.com/databricks/databricks-sdk-go/credentials/u2m" "github.com/databricks/databricks-sdk-go/httpclient/fixtures" "github.com/stretchr/testify/require" "golang.org/x/oauth2" @@ -17,8 +17,8 @@ import ( type MockOAuthClient struct { Transport http.RoundTripper - GetAccountOAuthEndpointsFn func(ctx context.Context, accountHost string, accountId string) (*oauth.OAuthAuthorizationServer, error) - GetWorkspaceOAuthEndpointsFn func(ctx context.Context, workspaceHost string) (*oauth.OAuthAuthorizationServer, error) + GetAccountOAuthEndpointsFn func(ctx context.Context, accountHost string, accountId string) (*u2m.OAuthAuthorizationServer, error) + GetWorkspaceOAuthEndpointsFn func(ctx context.Context, workspaceHost string) (*u2m.OAuthAuthorizationServer, error) } func (m MockOAuthClient) GetHttpClient(_ context.Context) *http.Client { @@ -27,11 +27,11 @@ func (m MockOAuthClient) GetHttpClient(_ context.Context) *http.Client { } } -func (m MockOAuthClient) GetAccountOAuthEndpoints(ctx context.Context, accountHost string, accountId string) (*oauth.OAuthAuthorizationServer, error) { +func (m MockOAuthClient) GetAccountOAuthEndpoints(ctx context.Context, accountHost string, accountId string) (*u2m.OAuthAuthorizationServer, error) { return m.GetAccountOAuthEndpointsFn(ctx, accountHost, accountId) } -func (m MockOAuthClient) GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*oauth.OAuthAuthorizationServer, error) { +func (m MockOAuthClient) GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*u2m.OAuthAuthorizationServer, error) { return m.GetWorkspaceOAuthEndpointsFn(ctx, workspaceHost) } @@ -39,7 +39,7 @@ func TestU2MCredentials(t *testing.T) { tests := []struct { name string cfg *Config - auth func() (*oauth.PersistentAuth, error) + auth func() (*u2m.PersistentAuth, error) expectErr string expectAuth string }{ @@ -48,10 +48,10 @@ func TestU2MCredentials(t *testing.T) { cfg: &Config{ Host: "https://myworkspace.cloud.databricks.com", }, - auth: func() (*oauth.PersistentAuth, error) { - return oauth.NewPersistentAuth( + auth: func() (*u2m.PersistentAuth, error) { + return u2m.NewPersistentAuth( context.Background(), - oauth.WithTokenCache(&InMemoryTokenCache{ + u2m.WithTokenCache(&InMemoryTokenCache{ Tokens: map[string]*oauth2.Token{ "https://myworkspace.cloud.databricks.com": { AccessToken: "dummy_access_token", @@ -67,10 +67,10 @@ func TestU2MCredentials(t *testing.T) { cfg: &Config{ Host: "https://myworkspace.cloud.databricks.com", }, - auth: func() (*oauth.PersistentAuth, error) { - return oauth.NewPersistentAuth( + auth: func() (*u2m.PersistentAuth, error) { + return u2m.NewPersistentAuth( context.Background(), - oauth.WithTokenCache(&InMemoryTokenCache{ + u2m.WithTokenCache(&InMemoryTokenCache{ Tokens: map[string]*oauth2.Token{ "https://myworkspace.cloud.databricks.com": { AccessToken: "dummy_access_token", @@ -79,7 +79,7 @@ func TestU2MCredentials(t *testing.T) { }, }, }), - oauth.WithOAuthClient(MockOAuthClient{ + u2m.WithOAuthClient(MockOAuthClient{ Transport: fixtures.SliceTransport{ { Method: "POST", @@ -88,8 +88,8 @@ func TestU2MCredentials(t *testing.T) { Response: `{"error":"invalid_refresh_token","error_description":"Refresh token is invalid"}`, }, }, - GetWorkspaceOAuthEndpointsFn: func(ctx context.Context, workspaceHost string) (*oauth.OAuthAuthorizationServer, error) { - return &oauth.OAuthAuthorizationServer{ + GetWorkspaceOAuthEndpointsFn: func(ctx context.Context, workspaceHost string) (*u2m.OAuthAuthorizationServer, error) { + return &u2m.OAuthAuthorizationServer{ TokenEndpoint: "https://myworkspace.cloud.databricks.com/oidc/v1/token", AuthorizationEndpoint: "https://myworkspace.cloud.databricks.com/oidc/v1/authorize", }, nil @@ -127,17 +127,17 @@ func TestU2MCredentials(t *testing.T) { } func TestDatabricksCli_ErrorHandler(t *testing.T) { - invalidRefreshTokenError := fmt.Errorf("refresh: %w", &oauth.InvalidRefreshTokenError{}) - workspaceArg := func() (oauth.OAuthArgument, error) { - return oauth.NewBasicWorkspaceOAuthArgument("https://myworkspace.cloud.databricks.com") + invalidRefreshTokenError := fmt.Errorf("refresh: %w", &u2m.InvalidRefreshTokenError{}) + workspaceArg := func() (u2m.OAuthArgument, error) { + return u2m.NewBasicWorkspaceOAuthArgument("https://myworkspace.cloud.databricks.com") } - accountArg := func() (oauth.OAuthArgument, error) { - return oauth.NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "abc") + accountArg := func() (u2m.OAuthArgument, error) { + return u2m.NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "abc") } testCases := []struct { name string cfg *Config - arg func() (oauth.OAuthArgument, error) + arg func() (u2m.OAuthArgument, error) err error want error }{ diff --git a/config/config.go b/config/config.go index 8469a957e..f2faa8bc1 100644 --- a/config/config.go +++ b/config/config.go @@ -14,7 +14,7 @@ import ( "github.com/databricks/databricks-sdk-go/common" "github.com/databricks/databricks-sdk-go/common/environment" "github.com/databricks/databricks-sdk-go/config/credentials" - "github.com/databricks/databricks-sdk-go/credentials/oauth" + "github.com/databricks/databricks-sdk-go/credentials/u2m" "github.com/databricks/databricks-sdk-go/httpclient" "github.com/databricks/databricks-sdk-go/logger" "golang.org/x/oauth2" @@ -204,7 +204,7 @@ func (c *Config) NewWithWorkspaceHost(host string) (*Config, error) { // vice-versa. // // In the future, when unified login is widely available, we may be able to - // reuse the authentication visitor specifically for in-house OAuth. + // reuse the authentication visitor specifically for in-house u2m. return res, nil } @@ -437,10 +437,13 @@ func (c *Config) refreshTokenErrorMapper(ctx context.Context, resp common.Respon } // getOidcEndpoints returns the OAuth endpoints for the current configuration. -func (c *Config) getOidcEndpoints(ctx context.Context) (*oauth.OAuthAuthorizationServer, error) { +func (c *Config) getOidcEndpoints(ctx context.Context) (*u2m.OAuthAuthorizationServer, error) { c.EnsureResolved() + oauthClient := &u2m.BasicOAuthClient{ + Client: c.refreshClient, + } if c.IsAccountClient() { - return oauth.GetAccountOAuthEndpoints(ctx, c.Host, c.AccountID) + return oauthClient.GetAccountOAuthEndpoints(ctx, c.Host, c.AccountID) } - return oauth.GetWorkspaceOAuthEndpoints(ctx, c.refreshClient, c.Host) + return oauthClient.GetWorkspaceOAuthEndpoints(ctx, c.Host) } diff --git a/config/config_test.go b/config/config_test.go index 9cebb1a15..5a5ac7041 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -5,7 +5,7 @@ import ( "net/http" "testing" - "github.com/databricks/databricks-sdk-go/credentials/oauth" + "github.com/databricks/databricks-sdk-go/credentials/u2m" "github.com/databricks/databricks-sdk-go/httpclient/fixtures" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -76,7 +76,7 @@ func TestConfig_getOidcEndpoints_account(t *testing.T) { } got, err := c.getOidcEndpoints(context.Background()) assert.NoError(t, err) - assert.Equal(t, &oauth.OAuthAuthorizationServer{ + assert.Equal(t, &u2m.OAuthAuthorizationServer{ AuthorizationEndpoint: "https://accounts.cloud.databricks.com/oidc/accounts/abc/v1/authorize", TokenEndpoint: "https://accounts.cloud.databricks.com/oidc/accounts/abc/v1/token", }, got) @@ -96,7 +96,7 @@ func TestConfig_getOidcEndpoints_workspace(t *testing.T) { } got, err := c.getOidcEndpoints(context.Background()) assert.NoError(t, err) - assert.Equal(t, &oauth.OAuthAuthorizationServer{ + assert.Equal(t, &u2m.OAuthAuthorizationServer{ AuthorizationEndpoint: "https://myworkspace.cloud.databricks.com/oidc/v1/authorize", TokenEndpoint: "https://myworkspace.cloud.databricks.com/oidc/v1/token", }, got) diff --git a/config/in_memory_test.go b/config/in_memory_test.go index 82ce6e2c7..6853dfd73 100644 --- a/config/in_memory_test.go +++ b/config/in_memory_test.go @@ -1,7 +1,7 @@ package config import ( - "github.com/databricks/databricks-sdk-go/credentials/cache" + "github.com/databricks/databricks-sdk-go/credentials/u2m/cache" "golang.org/x/oauth2" ) diff --git a/credentials/u2m/cache/file.go b/credentials/u2m/cache/file.go index cf585aaac..9dc61e3a4 100644 --- a/credentials/u2m/cache/file.go +++ b/credentials/u2m/cache/file.go @@ -37,10 +37,6 @@ const ( // } // } tokenCacheVersion = 1 - - // lockFilePath is the path of the lock file used to prevent concurrent - // reads and writes to the token cache file. - lockFilePath = ".databricks/token-cache.lock" ) // The format of the token cache file. @@ -156,15 +152,7 @@ func (c *fileTokenCache) init() error { } } // Initialize the locker. - home, err := os.UserHomeDir() - if err != nil { - return fmt.Errorf("home: %w", err) - } - - c.locker, err = newLocker(filepath.Join(home, lockFilePath)) - if err != nil { - return fmt.Errorf("locker: %w", err) - } + c.locker = &sync.Mutex{} return nil } diff --git a/credentials/u2m/cache/lock.go b/credentials/u2m/cache/lock.go deleted file mode 100644 index dcdf35f9b..000000000 --- a/credentials/u2m/cache/lock.go +++ /dev/null @@ -1,46 +0,0 @@ -package cache - -import ( - "fmt" - "os" - "path/filepath" - "sync" - - "github.com/alexflint/go-filemutex" -) - -// Adapts a filemutex.FileMutex to sync.Locker. -type lockerAdaptor struct { - fileMutex *filemutex.FileMutex -} - -var _ sync.Locker = (*lockerAdaptor)(nil) - -// Lock implements sync.Locker. -func (l *lockerAdaptor) Lock() { - err := l.fileMutex.Lock() - if err != nil { - panic(fmt.Errorf("lock: %w", err)) - } -} - -// Unlock implements sync.Locker. -func (l *lockerAdaptor) Unlock() { - err := l.fileMutex.Unlock() - if err != nil { - panic(fmt.Errorf("unlock: %w", err)) - } -} - -// newLocker creates a new sync.Locker that uses a file-based mutex. -func newLocker(path string) (*lockerAdaptor, error) { - dirName := filepath.Dir(path) - if _, err := os.Stat(dirName); err != nil && os.IsNotExist(err) { - os.MkdirAll(dirName, 0750) - } - m, err := filemutex.New(path) - if err != nil { - return nil, err - } - return &lockerAdaptor{fileMutex: m}, nil -} diff --git a/credentials/u2m/client.go b/credentials/u2m/client.go index 099360a75..812a13132 100644 --- a/credentials/u2m/client.go +++ b/credentials/u2m/client.go @@ -24,11 +24,12 @@ type OAuthClient interface { // BasicOAuthClient is an implementation of the OAuthClient interface. type BasicOAuthClient struct { - client *httpclient.ApiClient + // Client is the ApiClient to use for making HTTP requests. + Client *httpclient.ApiClient } func (c *BasicOAuthClient) GetHttpClient(_ context.Context) *http.Client { - return c.client.ToHttpClient() + return c.Client.ToHttpClient() } // GetWorkspaceOAuthEndpoints returns the OAuth endpoints for the given workspace. @@ -37,7 +38,7 @@ func (c *BasicOAuthClient) GetHttpClient(_ context.Context) *http.Client { func (c *BasicOAuthClient) GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*OAuthAuthorizationServer, error) { oidc := fmt.Sprintf("%s/oidc/.well-known/oauth-authorization-server", workspaceHost) var oauthEndpoints OAuthAuthorizationServer - if err := c.client.Do(ctx, "GET", oidc, httpclient.WithResponseUnmarshal(&oauthEndpoints)); err != nil { + if err := c.Client.Do(ctx, "GET", oidc, httpclient.WithResponseUnmarshal(&oauthEndpoints)); err != nil { return nil, ErrOAuthNotSupported } return &oauthEndpoints, nil diff --git a/credentials/u2m/client_test.go b/credentials/u2m/client_test.go index b9316b160..5c656d68e 100644 --- a/credentials/u2m/client_test.go +++ b/credentials/u2m/client_test.go @@ -29,7 +29,7 @@ func TestGetWorkspaceOAuthEndpoints(t *testing.T) { }, }, }) - c := &BasicOAuthClient{client: p} + c := &BasicOAuthClient{Client: p} endpoints, err := c.GetWorkspaceOAuthEndpoints(context.Background(), "https://abc") assert.NoError(t, err) assert.Equal(t, "a", endpoints.AuthorizationEndpoint) diff --git a/credentials/u2m/persistent_auth.go b/credentials/u2m/persistent_auth.go index 9b0aa73a1..10d240b06 100644 --- a/credentials/u2m/persistent_auth.go +++ b/credentials/u2m/persistent_auth.go @@ -81,7 +81,7 @@ func NewPersistentAuth(ctx context.Context, opts ...PersistentAuthOption) (*Pers } if p.client == nil { p.client = &BasicOAuthClient{ - client: httpclient.NewApiClient(httpclient.ClientConfig{}), + Client: httpclient.NewApiClient(httpclient.ClientConfig{}), } } if p.cache == nil { diff --git a/go.mod b/go.mod index 15499f9be..9dabdc5ba 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,6 @@ module github.com/databricks/databricks-sdk-go go 1.18 require ( - github.com/alexflint/go-filemutex v1.3.0 github.com/google/go-cmp v0.6.0 github.com/google/go-querystring v1.1.0 github.com/google/uuid v1.6.0 diff --git a/go.sum b/go.sum index d16bc8f26..44b258fcc 100644 --- a/go.sum +++ b/go.sum @@ -6,8 +6,6 @@ cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRk cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/alexflint/go-filemutex v1.3.0 h1:LgE+nTUWnQCyRKbpoceKZsPQbs84LivvgwUymZXdOcM= -github.com/alexflint/go-filemutex v1.3.0/go.mod h1:U0+VA/i30mGBlLCrFPGtTe9y6wGQfNAWPBTekHQ+c8A= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= @@ -70,7 +68,6 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= @@ -119,7 +116,6 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -165,7 +161,6 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= From 83e4141ea595f28ecab27f974a81fff8b330ad6f Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Thu, 30 Jan 2025 10:57:21 +0100 Subject: [PATCH 32/44] more tweaks --- config/auth_u2m.go | 26 ++++---- credentials/u2m/callback.go | 6 +- credentials/u2m/persistent_auth.go | 79 ++++++++++++++----------- credentials/u2m/persistent_auth_test.go | 34 ++++++----- 4 files changed, 81 insertions(+), 64 deletions(-) diff --git a/config/auth_u2m.go b/config/auth_u2m.go index 91dc36ca6..db07ef719 100644 --- a/config/auth_u2m.go +++ b/config/auth_u2m.go @@ -11,6 +11,7 @@ import ( "github.com/databricks/databricks-sdk-go/credentials/u2m" "github.com/databricks/databricks-sdk-go/credentials/u2m/cache" "github.com/databricks/databricks-sdk-go/logger" + "golang.org/x/oauth2" ) // u2mCredentials is a credentials strategy that uses the U2M OAuth flow to @@ -56,14 +57,6 @@ func (u u2mCredentials) Configure(ctx context.Context, cfg *Config) (credentials if cfg.Host == "" { return nil, nil } - if u.auth == nil { - var err error - u.auth, err = u2m.NewPersistentAuth(ctx) - if err != nil { - logger.Debugf(ctx, "failed to create persistent auth: %v, continuing", err) - return nil, nil - } - } var arg u2m.OAuthArgument var err error @@ -76,11 +69,20 @@ func (u u2mCredentials) Configure(ctx context.Context, cfg *Config) (credentials return nil, fmt.Errorf("oidc: %w", err) } + if u.auth == nil { + var err error + u.auth, err = u2m.NewPersistentAuth(ctx, u2m.WithOAuthArgument(arg)) + if err != nil { + logger.Debugf(ctx, "failed to create persistent auth: %v, continuing", err) + return nil, nil + } + } + // Construct the visitor, and try to load the credential from the token // cache. If absent, fall back to the next credentials strategy. If a token // is present but cannot be loaded (e.g. expired), return an error. // Otherwise, fall back to the next credentials strategy. - visitor := u.makeVisitor(arg) + visitor := u.makeVisitor() r, err := http.NewRequestWithContext(ctx, http.MethodGet, "", nil) if err != nil { return nil, fmt.Errorf("http request: %w", err) @@ -92,12 +94,12 @@ func (u u2mCredentials) Configure(ctx context.Context, cfg *Config) (credentials return nil, err } - return credentials.NewCredentialsProvider(visitor), nil + return credentials.NewOAuthCredentialsProvider(visitor, func() (*oauth2.Token, error) { return u.auth.Token() }), nil } -func (u u2mCredentials) makeVisitor(arg u2m.OAuthArgument) func(*http.Request) error { +func (u u2mCredentials) makeVisitor() func(*http.Request) error { return func(r *http.Request) error { - token, err := u.auth.Load(r.Context(), arg) + token, err := u.auth.Token() if err != nil { return fmt.Errorf("oidc: %w", err) } diff --git a/credentials/u2m/callback.go b/credentials/u2m/callback.go index c6022300c..aab576a04 100644 --- a/credentials/u2m/callback.go +++ b/credentials/u2m/callback.go @@ -56,7 +56,7 @@ type callbackServer struct { // newCallbackServer creates a new callback server that listens for the redirect // from the Databricks identity provider. -func (a *PersistentAuth) newCallbackServer(ctx context.Context, arg OAuthArgument) (*callbackServer, error) { +func (a *PersistentAuth) newCallbackServer() (*callbackServer, error) { tmpl, err := template.New("page").Funcs(template.FuncMap{ "title": func(in string) string { title := cases.Title(language.English) @@ -70,9 +70,9 @@ func (a *PersistentAuth) newCallbackServer(ctx context.Context, arg OAuthArgumen feedbackCh: make(chan oauthResult), renderErrCh: make(chan error), tmpl: tmpl, - ctx: ctx, + ctx: a.ctx, browser: a.browser, - arg: arg, + arg: a.oAuthArgument, } cb.srv.Handler = cb go func() { diff --git a/credentials/u2m/persistent_auth.go b/credentials/u2m/persistent_auth.go index 10d240b06..89834b4cf 100644 --- a/credentials/u2m/persistent_auth.go +++ b/credentials/u2m/persistent_auth.go @@ -40,14 +40,20 @@ const ( // The PersistentAuth is safe for concurrent use. The token cache is locked // during token retrieval, refresh and storage. type PersistentAuth struct { - // Cache is the token cache to store and lookup tokens. + // cache is the token cache to store and lookup tokens. cache cache.TokenCache - // Client is the HTTP client to use for OAuth2 requests. + // client is the HTTP client to use for OAuth2 requests. client OAuthClient - // Browser is the function to open a URL in the default browser. + // oAuthArgument defines the workspace or account to authenticate to and the + // cache key for the token. + oAuthArgument OAuthArgument + // browser is the function to open a URL in the default browser. browser func(url string) error // ln is the listener for the OAuth2 callback server. ln net.Listener + // ctx is the context to use for underlying operations. This is needed in + // order to implement the oauth2.TokenSource interface. + ctx context.Context } type PersistentAuthOption func(*PersistentAuth) @@ -66,6 +72,12 @@ func WithOAuthClient(c OAuthClient) PersistentAuthOption { } } +func WithOAuthArgument(arg OAuthArgument) PersistentAuthOption { + return func(a *PersistentAuth) { + a.oAuthArgument = arg + } +} + // WithBrowser sets the browser function for the PersistentAuth. func WithBrowser(b func(url string) error) PersistentAuthOption { return func(a *PersistentAuth) { @@ -91,32 +103,36 @@ func NewPersistentAuth(ctx context.Context, opts ...PersistentAuthOption) (*Pers return nil, fmt.Errorf("cache: %w", err) } } + if p.oAuthArgument == nil { + return nil, errors.New("missing OAuthArgument") + } + if err := p.validateArg(); err != nil { + return nil, err + } if p.browser == nil { p.browser = browser.OpenURL } + p.ctx = ctx return p, nil } -// Load loads the OAuth2 token for the given OAuthArgument from the cache. If +// Token loads the OAuth2 token for the given OAuthArgument from the cache. If // the token is expired, it is refreshed using the refresh token. -func (a *PersistentAuth) Load(ctx context.Context, arg OAuthArgument) (t *oauth2.Token, err error) { - if err := a.validateArg(arg); err != nil { - return nil, err - } - err = a.startListener(ctx) +func (a *PersistentAuth) Token() (t *oauth2.Token, err error) { + err = a.startListener(a.ctx) if err != nil { return nil, fmt.Errorf("starting listener: %w", err) } defer a.Close() - key := arg.GetCacheKey() + key := a.oAuthArgument.GetCacheKey() t, err = a.cache.Lookup(key) if err != nil { return nil, fmt.Errorf("cache: %w", err) } // refresh if invalid if !t.Valid() { - t, err = a.refresh(ctx, arg, t) + t, err = a.refresh(t) if err != nil { return nil, fmt.Errorf("token refresh: %w", err) } @@ -128,15 +144,15 @@ func (a *PersistentAuth) Load(ctx context.Context, arg OAuthArgument) (t *oauth2 // refresh refreshes the token for the given OAuthArgument, storing the new // token in the cache. -func (a *PersistentAuth) refresh(ctx context.Context, arg OAuthArgument, oldToken *oauth2.Token) (*oauth2.Token, error) { +func (a *PersistentAuth) refresh(oldToken *oauth2.Token) (*oauth2.Token, error) { // OAuth2 config is invoked only for expired tokens to speed up // the happy path in the token retrieval - cfg, err := a.oauth2Config(ctx, arg) + cfg, err := a.oauth2Config() if err != nil { return nil, err } // make OAuth2 library use our client - ctx = a.setOAuthContext(ctx) + ctx := a.setOAuthContext(a.ctx) // eagerly refresh token t, err := cfg.TokenSource(ctx, oldToken).Token() if err != nil { @@ -175,7 +191,7 @@ func (a *PersistentAuth) refresh(ctx context.Context, arg OAuthArgument, oldToke } return nil, err } - err = a.cache.Store(arg.GetCacheKey(), t) + err = a.cache.Store(a.oAuthArgument.GetCacheKey(), t) if err != nil { return nil, fmt.Errorf("cache update: %w", err) } @@ -188,11 +204,8 @@ func (a *PersistentAuth) refresh(ctx context.Context, arg OAuthArgument, oldToke // callback server listens for the redirect from the identity provider and // exchanges the authorization code for an access token. It returns the OAuth2 // token on success. -func (a *PersistentAuth) Challenge(ctx context.Context, arg OAuthArgument) (*oauth2.Token, error) { - if err := a.validateArg(arg); err != nil { - return nil, err - } - err := a.startListener(ctx) +func (a *PersistentAuth) Challenge() (*oauth2.Token, error) { + err := a.startListener(a.ctx) if err != nil { return nil, fmt.Errorf("starting listener: %w", err) } @@ -200,11 +213,11 @@ func (a *PersistentAuth) Challenge(ctx context.Context, arg OAuthArgument) (*oau // the callback server is not created, we need to close the listener manually. defer a.Close() - cfg, err := a.oauth2Config(ctx, arg) + cfg, err := a.oauth2Config() if err != nil { return nil, fmt.Errorf("fetching oauth config: %w", err) } - cb, err := a.newCallbackServer(ctx, arg) + cb, err := a.newCallbackServer() if err != nil { return nil, fmt.Errorf("callback server: %w", err) } @@ -215,14 +228,14 @@ func (a *PersistentAuth) Challenge(ctx context.Context, arg OAuthArgument) (*oau return nil, fmt.Errorf("state and pkce: %w", err) } // make OAuth2 library use our client - ctx = a.setOAuthContext(ctx) + ctx := a.setOAuthContext(a.ctx) ts := authhandler.TokenSourceWithPKCE(ctx, cfg, state, cb.Handler, pkce) t, err := ts.Token() if err != nil { return nil, fmt.Errorf("authorize: %w", err) } // cache token identified by host (and possibly the account id) - err = a.cache.Store(arg.GetCacheKey(), t) + err = a.cache.Store(a.oAuthArgument.GetCacheKey(), t) if err != nil { return nil, fmt.Errorf("store: %w", err) } @@ -258,31 +271,31 @@ func (a *PersistentAuth) Close() error { // validateArg ensures that the OAuthArgument is either a WorkspaceOAuthArgument // or an AccountOAuthArgument. -func (a *PersistentAuth) validateArg(arg OAuthArgument) error { - _, isWorkspaceArg := arg.(WorkspaceOAuthArgument) - _, isAccountArg := arg.(AccountOAuthArgument) +func (a *PersistentAuth) validateArg() error { + _, isWorkspaceArg := a.oAuthArgument.(WorkspaceOAuthArgument) + _, isAccountArg := a.oAuthArgument.(AccountOAuthArgument) if !isWorkspaceArg && !isAccountArg { - return fmt.Errorf("unsupported OAuthArgument type: %T, must implement either WorkspaceOAuthArgument or AccountOAuthArgument interface", arg) + return fmt.Errorf("unsupported OAuthArgument type: %T, must implement either WorkspaceOAuthArgument or AccountOAuthArgument interface", a.oAuthArgument) } return nil } // oauth2Config returns the OAuth2 configuration for the given OAuthArgument. -func (a *PersistentAuth) oauth2Config(ctx context.Context, arg OAuthArgument) (*oauth2.Config, error) { +func (a *PersistentAuth) oauth2Config() (*oauth2.Config, error) { scopes := []string{ "offline_access", // ensures OAuth token includes refresh token "all-apis", // ensures OAuth token has access to all control-plane APIs } var endpoints *OAuthAuthorizationServer var err error - switch argg := arg.(type) { + switch argg := a.oAuthArgument.(type) { case WorkspaceOAuthArgument: - endpoints, err = a.client.GetWorkspaceOAuthEndpoints(ctx, argg.GetWorkspaceHost()) + endpoints, err = a.client.GetWorkspaceOAuthEndpoints(a.ctx, argg.GetWorkspaceHost()) case AccountOAuthArgument: endpoints, err = a.client.GetAccountOAuthEndpoints( - ctx, argg.GetAccountHost(), argg.GetAccountId()) + a.ctx, argg.GetAccountHost(), argg.GetAccountId()) default: - return nil, fmt.Errorf("unsupported OAuthArgument type: %T, must implement either WorkspaceOAuthArgument or AccountOAuthArgument interface", arg) + return nil, fmt.Errorf("unsupported OAuthArgument type: %T, must implement either WorkspaceOAuthArgument or AccountOAuthArgument interface", a.oAuthArgument) } if err != nil { return nil, fmt.Errorf("fetching OAuth endpoints: %w", err) diff --git a/credentials/u2m/persistent_auth_test.go b/credentials/u2m/persistent_auth_test.go index 82a57c6b8..7ded254e9 100644 --- a/credentials/u2m/persistent_auth_test.go +++ b/credentials/u2m/persistent_auth_test.go @@ -44,12 +44,12 @@ func TestLoad(t *testing.T) { }, nil }, } - p, err := u2m.NewPersistentAuth(context.Background(), u2m.WithTokenCache(cache)) - require.NoError(t, err) - defer p.Close() arg, err := u2m.NewBasicAccountOAuthArgument("https://abc", "xyz") assert.NoError(t, err) - tok, err := p.Load(context.Background(), arg) + p, err := u2m.NewPersistentAuth(context.Background(), u2m.WithTokenCache(cache), u2m.WithOAuthArgument(arg)) + require.NoError(t, err) + defer p.Close() + tok, err := p.Token() assert.NoError(t, err) assert.Equal(t, "bcd", tok.AccessToken) assert.Equal(t, "", tok.RefreshToken) @@ -97,8 +97,10 @@ func TestLoadRefresh(t *testing.T) { return nil }, } + arg, err := u2m.NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "xyz") + assert.NoError(t, err) p, err := u2m.NewPersistentAuth( - context.Background(), + ctx, u2m.WithTokenCache(cache), u2m.WithOAuthClient(&MockOAuthClient{ Transport: fixtures.SliceTransport{ @@ -112,12 +114,11 @@ func TestLoadRefresh(t *testing.T) { }, }, }), + u2m.WithOAuthArgument(arg), ) require.NoError(t, err) defer p.Close() - arg, err := u2m.NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "xyz") - assert.NoError(t, err) - tok, err := p.Load(ctx, arg) + tok, err := p.Token() assert.NoError(t, err) assert.Equal(t, "refreshed", tok.AccessToken) assert.Equal(t, "", tok.RefreshToken) @@ -146,8 +147,10 @@ func TestChallenge(t *testing.T) { return nil }, } + arg, err := u2m.NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "xyz") + assert.NoError(t, err) p, err := u2m.NewPersistentAuth( - context.Background(), + ctx, u2m.WithTokenCache(cache), u2m.WithBrowser(browser), u2m.WithOAuthClient(&MockOAuthClient{ @@ -162,16 +165,15 @@ func TestChallenge(t *testing.T) { }, }, }), + u2m.WithOAuthArgument(arg), ) require.NoError(t, err) defer p.Close() - arg, err := u2m.NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "xyz") - assert.NoError(t, err) tokenc := make(chan *oauth2.Token) errc := make(chan error) go func() { - token, err := p.Challenge(ctx, arg) + token, err := p.Challenge() errc <- err close(errc) tokenc <- token @@ -203,16 +205,16 @@ func TestChallengeFailed(t *testing.T) { browserOpened <- query.Get("state") return nil } - p, err := u2m.NewPersistentAuth(context.Background(), u2m.WithBrowser(browser)) - require.NoError(t, err) - defer p.Close() arg, err := u2m.NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "xyz") assert.NoError(t, err) + p, err := u2m.NewPersistentAuth(ctx, u2m.WithBrowser(browser), u2m.WithOAuthArgument(arg)) + require.NoError(t, err) + defer p.Close() tokenc := make(chan *oauth2.Token) errc := make(chan error) go func() { - token, err := p.Challenge(ctx, arg) + token, err := p.Challenge() errc <- err close(errc) tokenc <- token From 51a5b08c4f41030282572a1b533bc0f5fc06aae1 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Thu, 30 Jan 2025 11:01:49 +0100 Subject: [PATCH 33/44] work --- config/auth_u2m_test.go | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/config/auth_u2m_test.go b/config/auth_u2m_test.go index 1ead12e0a..fc9934da1 100644 --- a/config/auth_u2m_test.go +++ b/config/auth_u2m_test.go @@ -8,8 +8,8 @@ import ( "testing" "time" - "github.com/databricks/databricks-sdk-go/credentials/u2m/cache" "github.com/databricks/databricks-sdk-go/credentials/u2m" + "github.com/databricks/databricks-sdk-go/credentials/u2m/cache" "github.com/databricks/databricks-sdk-go/httpclient/fixtures" "github.com/stretchr/testify/require" "golang.org/x/oauth2" @@ -35,11 +35,18 @@ func (m MockOAuthClient) GetWorkspaceOAuthEndpoints(ctx context.Context, workspa return m.GetWorkspaceOAuthEndpointsFn(ctx, workspaceHost) } +func must[T any](c T, err error) T { + if err != nil { + panic(err) + } + return c +} + func TestU2MCredentials(t *testing.T) { tests := []struct { name string cfg *Config - auth func() (*u2m.PersistentAuth, error) + auth *u2m.PersistentAuth expectErr string expectAuth string }{ @@ -48,8 +55,8 @@ func TestU2MCredentials(t *testing.T) { cfg: &Config{ Host: "https://myworkspace.cloud.databricks.com", }, - auth: func() (*u2m.PersistentAuth, error) { - return u2m.NewPersistentAuth( + auth: must( + u2m.NewPersistentAuth( context.Background(), u2m.WithTokenCache(&InMemoryTokenCache{ Tokens: map[string]*oauth2.Token{ @@ -58,8 +65,10 @@ func TestU2MCredentials(t *testing.T) { Expiry: time.Now().Add(1 * time.Hour), }, }, - })) - }, + }), + u2m.WithOAuthArgument(must(u2m.NewBasicWorkspaceOAuthArgument("https://myworkspace.cloud.databricks.com"))), + ), + ), expectAuth: "Bearer dummy_access_token", }, { @@ -67,8 +76,8 @@ func TestU2MCredentials(t *testing.T) { cfg: &Config{ Host: "https://myworkspace.cloud.databricks.com", }, - auth: func() (*u2m.PersistentAuth, error) { - return u2m.NewPersistentAuth( + auth: must( + u2m.NewPersistentAuth( context.Background(), u2m.WithTokenCache(&InMemoryTokenCache{ Tokens: map[string]*oauth2.Token{ @@ -95,8 +104,9 @@ func TestU2MCredentials(t *testing.T) { }, nil }, }), - ) - }, + u2m.WithOAuthArgument(must(u2m.NewBasicWorkspaceOAuthArgument("https://myworkspace.cloud.databricks.com"))), + ), + ), expectErr: "oidc: token refresh: oauth2: \"invalid_refresh_token\" \"Refresh token is invalid\"", }, } @@ -104,10 +114,8 @@ func TestU2MCredentials(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := context.Background() - auth, err := tt.auth() - require.NoError(t, err) strat := u2mCredentials{ - auth: auth, + auth: tt.auth, } provider, err := strat.Configure(ctx, tt.cfg) if tt.expectErr != "" { From 62acd68dddfa19e8f7b2597b07480c44f85407f3 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Thu, 30 Jan 2025 12:07:19 +0100 Subject: [PATCH 34/44] simpler --- config/auth_u2m.go | 115 +++++++++++++++------------------------------ config/config.go | 7 +++ 2 files changed, 44 insertions(+), 78 deletions(-) diff --git a/config/auth_u2m.go b/config/auth_u2m.go index db07ef719..c994be66c 100644 --- a/config/auth_u2m.go +++ b/config/auth_u2m.go @@ -15,41 +15,21 @@ import ( ) // u2mCredentials is a credentials strategy that uses the U2M OAuth flow to -// authenticate with Databricks. -// -// To authenticate with U2M OAuth, the user must already have an existing OAuth -// session. The specific OAuth session is indicated by the OAuth argument -// provided by GetOAuthArg. By default, the OAuth argument is determined by the -// account host and account ID or workspace host in the Config. -// -// Error handling for this strategy is controlled by the ErrorHandler field. If -// ErrorHandler is not specified, any error will cause Configure() to return said -// error. +// authenticate with Databricks. It loads a token from the token cache for the +// given workspace or account, refreshing it if needed. If the user has not +// authenticated with OAuth U2M, it falls back to the next credentials strategy. +// If they have but their access and refresh tokens are both invalid, it returns +// a special error message that instructs the user how to reauthenticate. type u2mCredentials struct { - // auth is the persistent auth object to use. If not specified, a new one will - // be created, using the default cache and locker. + // auth is the persistent auth object that manages the token cache. auth *u2m.PersistentAuth - - // getOAuthArg is a function that returns the OAuth argument to use for - // loading the OAuth session token. If not specified, the OAuth argument is - // determined by the account host and account ID or workspace host in the - // Config. - getOAuthArg func(context.Context, *Config) (u2m.OAuthArgument, error) - - // errorHandler controls the behavior of Configure() when loading the OAuth - // token fails. If not specified, any error will cause Configure() to return - // said error. - errorHandler func(context.Context, *Config, u2m.OAuthArgument, error) error - - name string } // Name implements CredentialsStrategy. func (u u2mCredentials) Name() string { - if u.name != "" { - return u.name - } - return "oauth-u2m" + // When we support allowing users to configure a custom U2M strategy, we + // should use a different name here. + return "databricks-cli" } // Configure implements CredentialsStrategy. @@ -58,48 +38,39 @@ func (u u2mCredentials) Configure(ctx context.Context, cfg *Config) (credentials return nil, nil } - var arg u2m.OAuthArgument - var err error - if u.getOAuthArg != nil { - arg, err = u.getOAuthArg(ctx, cfg) - } else { - arg, err = defaultGetOAuthArg(ctx, cfg) - } + arg, err := cfg.getOAuthArgument() if err != nil { return nil, fmt.Errorf("oidc: %w", err) } if u.auth == nil { - var err error - u.auth, err = u2m.NewPersistentAuth(ctx, u2m.WithOAuthArgument(arg)) + auth, err := u2m.NewPersistentAuth(ctx, u2m.WithOAuthArgument(arg)) if err != nil { logger.Debugf(ctx, "failed to create persistent auth: %v, continuing", err) return nil, nil } + u.auth = auth } // Construct the visitor, and try to load the credential from the token // cache. If absent, fall back to the next credentials strategy. If a token // is present but cannot be loaded (e.g. expired), return an error. // Otherwise, fall back to the next credentials strategy. - visitor := u.makeVisitor() + visitor := u.makeVisitor(u.auth) r, err := http.NewRequestWithContext(ctx, http.MethodGet, "", nil) if err != nil { return nil, fmt.Errorf("http request: %w", err) } if err := visitor(r); err != nil { - if u.errorHandler != nil { - return nil, u.errorHandler(ctx, cfg, arg, err) - } - return nil, err + return nil, u.errorHandler(ctx, cfg, arg, err) } - return credentials.NewOAuthCredentialsProvider(visitor, func() (*oauth2.Token, error) { return u.auth.Token() }), nil + return credentials.NewOAuthCredentialsProvider(visitor, u.auth.Token), nil } -func (u u2mCredentials) makeVisitor() func(*http.Request) error { +func (u u2mCredentials) makeVisitor(auth oauth2.TokenSource) func(*http.Request) error { return func(r *http.Request) error { - token, err := u.auth.Token() + token, err := auth.Token() if err != nil { return fmt.Errorf("oidc: %w", err) } @@ -108,11 +79,26 @@ func (u u2mCredentials) makeVisitor() func(*http.Request) error { } } -func defaultGetOAuthArg(_ context.Context, cfg *Config) (u2m.OAuthArgument, error) { - if cfg.IsAccountClient() { - return u2m.NewBasicAccountOAuthArgument(cfg.Host, cfg.AccountID) +func (u u2mCredentials) errorHandler(ctx context.Context, cfg *Config, arg u2m.OAuthArgument, err error) error { + // If the current OAuth argument doesn't have a corresponding session + // token, fall back to the next credentials strategy. + if errors.Is(err, cache.ErrNotConfigured) { + return nil + } + // If there is an existing token but the refresh token is invalid, + // return a special error message for invalid refresh tokens. To help + // users easily reauthenticate, include a command that the user can + // run, prepopulating the profile, host and/or account ID. + target := &u2m.InvalidRefreshTokenError{} + if errors.As(err, &target) { + return &CliInvalidRefreshTokenError{ + loginCommand: buildLoginCommand(cfg.Profile, arg), + err: err, + } } - return u2m.NewBasicWorkspaceOAuthArgument(cfg.Host) + // Otherwise, log the error and continue to the next credentials strategy. + logger.Debugf(ctx, "failed to load token: %v, continuing", err) + return nil } var _ CredentialsStrategy = u2mCredentials{} @@ -157,31 +143,4 @@ func buildLoginCommand(profile string, arg u2m.OAuthArgument) string { return strings.Join(cmd, " ") } -// DatabricksCliCredentials is a credentials strategy that emulates the behavior -// of the earlier `databricks-cli` credentials strategy which invoked the -// `databricks auth token` command. -var DatabricksCliCredentials = u2mCredentials{ - errorHandler: func(ctx context.Context, cfg *Config, arg u2m.OAuthArgument, err error) error { - // If the current OAuth argument doesn't have a corresponding session - // token, fall back to the next credentials strategy. - if errors.Is(err, cache.ErrNotConfigured) { - return nil - } - // If there is an existing token but the refresh token is invalid, - // return a special error message for invalid refresh tokens. To help - // users easily reauthenticate, include a command that the user can - // run, prepopulating the profile, host and/or account ID. - target := &u2m.InvalidRefreshTokenError{} - if errors.As(err, &target) { - return &CliInvalidRefreshTokenError{ - loginCommand: buildLoginCommand(cfg.Profile, arg), - err: err, - } - } - // Otherwise, log the error and continue to the next credentials strategy. - logger.Debugf(ctx, "failed to load token: %v, continuing", err) - return nil - }, - getOAuthArg: defaultGetOAuthArg, - name: "databricks-cli", -} +var DatabricksCliCredentials = u2mCredentials{} diff --git a/config/config.go b/config/config.go index f2faa8bc1..c4f9260a9 100644 --- a/config/config.go +++ b/config/config.go @@ -447,3 +447,10 @@ func (c *Config) getOidcEndpoints(ctx context.Context) (*u2m.OAuthAuthorizationS } return oauthClient.GetWorkspaceOAuthEndpoints(ctx, c.Host) } + +func (c *Config) getOAuthArgument() (u2m.OAuthArgument, error) { + if c.IsAccountClient() { + return u2m.NewBasicAccountOAuthArgument(c.Host, c.AccountID) + } + return u2m.NewBasicWorkspaceOAuthArgument(c.Host) +} From f506510f59dea7aad47f9a37072dcc35f9d5e239 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Thu, 30 Jan 2025 12:11:24 +0100 Subject: [PATCH 35/44] tweaks --- config/config.go | 2 +- credentials/u2m/cache/file.go | 2 +- credentials/u2m/persistent_auth.go | 2 ++ 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/config/config.go b/config/config.go index c4f9260a9..1fd41f864 100644 --- a/config/config.go +++ b/config/config.go @@ -204,7 +204,7 @@ func (c *Config) NewWithWorkspaceHost(host string) (*Config, error) { // vice-versa. // // In the future, when unified login is widely available, we may be able to - // reuse the authentication visitor specifically for in-house u2m. + // reuse the authentication visitor specifically for in-house OAuth. return res, nil } diff --git a/credentials/u2m/cache/file.go b/credentials/u2m/cache/file.go index 9dc61e3a4..678bb6ec4 100644 --- a/credentials/u2m/cache/file.go +++ b/credentials/u2m/cache/file.go @@ -39,7 +39,7 @@ const ( tokenCacheVersion = 1 ) -// The format of the token cache file. +// tokenCacheFile is the format of the token cache file. type tokenCacheFile struct { Version int `json:"version"` Tokens map[string]*oauth2.Token `json:"tokens"` diff --git a/credentials/u2m/persistent_auth.go b/credentials/u2m/persistent_auth.go index 89834b4cf..ad7651343 100644 --- a/credentials/u2m/persistent_auth.go +++ b/credentials/u2m/persistent_auth.go @@ -343,3 +343,5 @@ func (a *PersistentAuth) randomString(size int) (string, error) { func (a *PersistentAuth) setOAuthContext(ctx context.Context) context.Context { return context.WithValue(ctx, oauth2.HTTPClient, a.client.GetHttpClient(ctx)) } + +var _ oauth2.TokenSource = (*PersistentAuth)(nil) From d70dac45c7c8e38175f4950669b70886918543a5 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Thu, 30 Jan 2025 12:14:32 +0100 Subject: [PATCH 36/44] fix test --- config/auth_u2m_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/config/auth_u2m_test.go b/config/auth_u2m_test.go index fc9934da1..b2be8bae7 100644 --- a/config/auth_u2m_test.go +++ b/config/auth_u2m_test.go @@ -107,7 +107,9 @@ func TestU2MCredentials(t *testing.T) { u2m.WithOAuthArgument(must(u2m.NewBasicWorkspaceOAuthArgument("https://myworkspace.cloud.databricks.com"))), ), ), - expectErr: "oidc: token refresh: oauth2: \"invalid_refresh_token\" \"Refresh token is invalid\"", + expectErr: `a new access token could not be retrieved because the refresh token is invalid. If using the CLI, run the following command to reauthenticate: + + $ databricks auth login --host https://myworkspace.cloud.databricks.com`, }, } From 2f0ebbb9153a2c180fe9f5c034e61639f398ce72 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Fri, 14 Feb 2025 16:58:55 +0100 Subject: [PATCH 37/44] work --- credentials/u2m/cache/cache.go | 1 + credentials/u2m/cache/file.go | 10 ++++++---- credentials/u2m/error.go | 12 +----------- 3 files changed, 8 insertions(+), 15 deletions(-) diff --git a/credentials/u2m/cache/cache.go b/credentials/u2m/cache/cache.go index c059fa546..6e4137211 100644 --- a/credentials/u2m/cache/cache.go +++ b/credentials/u2m/cache/cache.go @@ -21,6 +21,7 @@ import ( // TokenCache is an interface for storing and looking up OAuth tokens. type TokenCache interface { // Store stores the token with the given key, replacing any existing token. + // If t is nil, it deletes the token. Store(key string, t *oauth2.Token) error // Lookup looks up the token with the given key. If the token is not found, it diff --git a/credentials/u2m/cache/file.go b/credentials/u2m/cache/file.go index 678bb6ec4..5652f4833 100644 --- a/credentials/u2m/cache/file.go +++ b/credentials/u2m/cache/file.go @@ -59,7 +59,7 @@ type fileTokenCache struct { fileLocation string // locker protects the token cache file from concurrent reads and writes. - locker sync.Locker + locker sync.Mutex } // NewFileTokenCache creates a new FileTokenCache. By default, the cache is @@ -94,7 +94,11 @@ func (c *fileTokenCache) Store(key string, t *oauth2.Token) error { if f.Tokens == nil { f.Tokens = map[string]*oauth2.Token{} } - f.Tokens[key] = t + if t == nil { + delete(f.Tokens, key) + } else { + f.Tokens[key] = t + } raw, err := json.MarshalIndent(f, "", " ") if err != nil { return fmt.Errorf("marshal: %w", err) @@ -151,8 +155,6 @@ func (c *fileTokenCache) init() error { return fmt.Errorf("write: %w", err) } } - // Initialize the locker. - c.locker = &sync.Mutex{} return nil } diff --git a/credentials/u2m/error.go b/credentials/u2m/error.go index 6c243fc78..be953b0c1 100644 --- a/credentials/u2m/error.go +++ b/credentials/u2m/error.go @@ -4,15 +4,5 @@ package u2m // if the access token has expired and the refresh token in the token cache // is invalid. type InvalidRefreshTokenError struct { - err error + error } - -func (e *InvalidRefreshTokenError) Error() string { - return e.err.Error() -} - -func (e *InvalidRefreshTokenError) Unwrap() error { - return e.err -} - -var _ error = &InvalidRefreshTokenError{} From 6ff1f3d5690fbee5cf7a5afa93c1464ea1f367b7 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Mon, 17 Feb 2025 16:53:11 +0100 Subject: [PATCH 38/44] fmt --- config/config.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/config.go b/config/config.go index 68342de27..b23c4a8ef 100644 --- a/config/config.go +++ b/config/config.go @@ -14,9 +14,9 @@ import ( "github.com/databricks/databricks-sdk-go/common" "github.com/databricks/databricks-sdk-go/common/environment" "github.com/databricks/databricks-sdk-go/config/credentials" - "github.com/databricks/databricks-sdk-go/credentials/u2m" "github.com/databricks/databricks-sdk-go/config/experimental/auth" "github.com/databricks/databricks-sdk-go/config/experimental/auth/authconv" + "github.com/databricks/databricks-sdk-go/credentials/u2m" "github.com/databricks/databricks-sdk-go/httpclient" "github.com/databricks/databricks-sdk-go/logger" "golang.org/x/oauth2" From f08c2b933ac42d284f2711d906886394e8a2e923 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Mon, 17 Feb 2025 17:18:21 +0100 Subject: [PATCH 39/44] work --- config/auth_u2m.go | 12 ++-- config/auth_u2m_test.go | 4 +- config/config.go | 2 +- .../u2m/{client.go => endpoint_supplier.go} | 20 ++----- ...ient_test.go => endpoint_supplier_test.go} | 4 +- credentials/u2m/persistent_auth.go | 60 +++++++++++++------ credentials/u2m/persistent_auth_test.go | 20 ++----- 7 files changed, 65 insertions(+), 57 deletions(-) rename credentials/u2m/{client.go => endpoint_supplier.go} (74%) rename credentials/u2m/{client_test.go => endpoint_supplier_test.go} (93%) diff --git a/config/auth_u2m.go b/config/auth_u2m.go index c994be66c..538cb4242 100644 --- a/config/auth_u2m.go +++ b/config/auth_u2m.go @@ -21,8 +21,8 @@ import ( // If they have but their access and refresh tokens are both invalid, it returns // a special error message that instructs the user how to reauthenticate. type u2mCredentials struct { - // auth is the persistent auth object that manages the token cache. - auth *u2m.PersistentAuth + // ts supplies the token source for the U2M OAuth flow. + ts oauth2.TokenSource } // Name implements CredentialsStrategy. @@ -43,20 +43,20 @@ func (u u2mCredentials) Configure(ctx context.Context, cfg *Config) (credentials return nil, fmt.Errorf("oidc: %w", err) } - if u.auth == nil { + if u.ts == nil { auth, err := u2m.NewPersistentAuth(ctx, u2m.WithOAuthArgument(arg)) if err != nil { logger.Debugf(ctx, "failed to create persistent auth: %v, continuing", err) return nil, nil } - u.auth = auth + u.ts = auth } // Construct the visitor, and try to load the credential from the token // cache. If absent, fall back to the next credentials strategy. If a token // is present but cannot be loaded (e.g. expired), return an error. // Otherwise, fall back to the next credentials strategy. - visitor := u.makeVisitor(u.auth) + visitor := u.makeVisitor(u.ts) r, err := http.NewRequestWithContext(ctx, http.MethodGet, "", nil) if err != nil { return nil, fmt.Errorf("http request: %w", err) @@ -65,7 +65,7 @@ func (u u2mCredentials) Configure(ctx context.Context, cfg *Config) (credentials return nil, u.errorHandler(ctx, cfg, arg, err) } - return credentials.NewOAuthCredentialsProvider(visitor, u.auth.Token), nil + return credentials.NewOAuthCredentialsProvider(visitor, u.ts.Token), nil } func (u u2mCredentials) makeVisitor(auth oauth2.TokenSource) func(*http.Request) error { diff --git a/config/auth_u2m_test.go b/config/auth_u2m_test.go index b2be8bae7..d8b65b426 100644 --- a/config/auth_u2m_test.go +++ b/config/auth_u2m_test.go @@ -88,7 +88,7 @@ func TestU2MCredentials(t *testing.T) { }, }, }), - u2m.WithOAuthClient(MockOAuthClient{ + u2m.WithOAuthEndpointSupplier(MockOAuthClient{ Transport: fixtures.SliceTransport{ { Method: "POST", @@ -117,7 +117,7 @@ func TestU2MCredentials(t *testing.T) { t.Run(tt.name, func(t *testing.T) { ctx := context.Background() strat := u2mCredentials{ - auth: tt.auth, + ts: tt.auth, } provider, err := strat.Configure(ctx, tt.cfg) if tt.expectErr != "" { diff --git a/config/config.go b/config/config.go index b23c4a8ef..6c14f0a27 100644 --- a/config/config.go +++ b/config/config.go @@ -455,7 +455,7 @@ func (c *Config) refreshTokenErrorMapper(ctx context.Context, resp common.Respon // getOidcEndpoints returns the OAuth endpoints for the current configuration. func (c *Config) getOidcEndpoints(ctx context.Context) (*u2m.OAuthAuthorizationServer, error) { c.EnsureResolved() - oauthClient := &u2m.BasicOAuthClient{ + oauthClient := &u2m.BasicOAuthEndpointSupplier{ Client: c.refreshClient, } if c.IsAccountClient() { diff --git a/credentials/u2m/client.go b/credentials/u2m/endpoint_supplier.go similarity index 74% rename from credentials/u2m/client.go rename to credentials/u2m/endpoint_supplier.go index 812a13132..fb5c48c24 100644 --- a/credentials/u2m/client.go +++ b/credentials/u2m/endpoint_supplier.go @@ -4,17 +4,13 @@ import ( "context" "errors" "fmt" - "net/http" "github.com/databricks/databricks-sdk-go/httpclient" ) -// OAuthClient provides the http functionality needed for interacting with the +// OAuthEndpointSupplier provides the http functionality needed for interacting with the // Databricks OAuth APIs. -type OAuthClient interface { - // GetHttpClient returns an HTTP client for OAuth2 requests. - GetHttpClient(context.Context) *http.Client - +type OAuthEndpointSupplier interface { // GetWorkspaceOAuthEndpoints returns the OAuth2 endpoints for the workspace. GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*OAuthAuthorizationServer, error) @@ -22,20 +18,16 @@ type OAuthClient interface { GetAccountOAuthEndpoints(ctx context.Context, accountHost string, accountId string) (*OAuthAuthorizationServer, error) } -// BasicOAuthClient is an implementation of the OAuthClient interface. -type BasicOAuthClient struct { +// BasicOAuthEndpointSupplier is an implementation of the OAuthEndpointSupplier interface. +type BasicOAuthEndpointSupplier struct { // Client is the ApiClient to use for making HTTP requests. Client *httpclient.ApiClient } -func (c *BasicOAuthClient) GetHttpClient(_ context.Context) *http.Client { - return c.Client.ToHttpClient() -} - // GetWorkspaceOAuthEndpoints returns the OAuth endpoints for the given workspace. // It queries the OIDC discovery endpoint to get the OAuth endpoints using the // provided ApiClient. -func (c *BasicOAuthClient) GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*OAuthAuthorizationServer, error) { +func (c *BasicOAuthEndpointSupplier) GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*OAuthAuthorizationServer, error) { oidc := fmt.Sprintf("%s/oidc/.well-known/oauth-authorization-server", workspaceHost) var oauthEndpoints OAuthAuthorizationServer if err := c.Client.Do(ctx, "GET", oidc, httpclient.WithResponseUnmarshal(&oauthEndpoints)); err != nil { @@ -46,7 +38,7 @@ func (c *BasicOAuthClient) GetWorkspaceOAuthEndpoints(ctx context.Context, works // GetAccountOAuthEndpoints returns the OAuth2 endpoints for the account. The // account-level OAuth endpoints are fixed based on the account ID and host. -func (c *BasicOAuthClient) GetAccountOAuthEndpoints(ctx context.Context, accountHost string, accountId string) (*OAuthAuthorizationServer, error) { +func (c *BasicOAuthEndpointSupplier) GetAccountOAuthEndpoints(ctx context.Context, accountHost string, accountId string) (*OAuthAuthorizationServer, error) { return &OAuthAuthorizationServer{ AuthorizationEndpoint: fmt.Sprintf("%s/oidc/accounts/%s/v1/authorize", accountHost, accountId), TokenEndpoint: fmt.Sprintf("%s/oidc/accounts/%s/v1/token", accountHost, accountId), diff --git a/credentials/u2m/client_test.go b/credentials/u2m/endpoint_supplier_test.go similarity index 93% rename from credentials/u2m/client_test.go rename to credentials/u2m/endpoint_supplier_test.go index 5c656d68e..72106e91b 100644 --- a/credentials/u2m/client_test.go +++ b/credentials/u2m/endpoint_supplier_test.go @@ -10,7 +10,7 @@ import ( ) func TestBasicOAuthClient_GetAccountOAuthEndpoints(t *testing.T) { - c := &BasicOAuthClient{} + c := &BasicOAuthEndpointSupplier{} s, err := c.GetAccountOAuthEndpoints(context.Background(), "https://abc", "xyz") assert.NoError(t, err) assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/authorize", s.AuthorizationEndpoint) @@ -29,7 +29,7 @@ func TestGetWorkspaceOAuthEndpoints(t *testing.T) { }, }, }) - c := &BasicOAuthClient{Client: p} + c := &BasicOAuthEndpointSupplier{Client: p} endpoints, err := c.GetWorkspaceOAuthEndpoints(context.Background(), "https://abc") assert.NoError(t, err) assert.Equal(t, "a", endpoints.AuthorizationEndpoint) diff --git a/credentials/u2m/persistent_auth.go b/credentials/u2m/persistent_auth.go index ad7651343..d6befe57a 100644 --- a/credentials/u2m/persistent_auth.go +++ b/credentials/u2m/persistent_auth.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "net" + "net/http" "time" cache "github.com/databricks/databricks-sdk-go/credentials/u2m/cache" @@ -43,7 +44,9 @@ type PersistentAuth struct { // cache is the token cache to store and lookup tokens. cache cache.TokenCache // client is the HTTP client to use for OAuth2 requests. - client OAuthClient + client *http.Client + // endpointSupplier is the HTTP endpointSupplier to use for OAuth2 requests. + endpointSupplier OAuthEndpointSupplier // oAuthArgument defines the workspace or account to authenticate to and the // cache key for the token. oAuthArgument OAuthArgument @@ -65,13 +68,22 @@ func WithTokenCache(c cache.TokenCache) PersistentAuthOption { } } -// WithApiClient sets the HTTP client for the PersistentAuth. -func WithOAuthClient(c OAuthClient) PersistentAuthOption { +// WithHttpClient sets the HTTP client for the PersistentAuth. +func WithHttpClient(c *http.Client) PersistentAuthOption { return func(a *PersistentAuth) { a.client = c } } +// WithOAuthEndpointSupplier sets the OAuth endpoint supplier for the +// PersistentAuth. +func WithOAuthEndpointSupplier(c OAuthEndpointSupplier) PersistentAuthOption { + return func(a *PersistentAuth) { + a.endpointSupplier = c + } +} + +// WithOAuthArgument sets the OAuthArgument for the PersistentAuth. func WithOAuthArgument(arg OAuthArgument) PersistentAuthOption { return func(a *PersistentAuth) { a.oAuthArgument = arg @@ -91,9 +103,22 @@ func NewPersistentAuth(ctx context.Context, opts ...PersistentAuthOption) (*Pers for _, opt := range opts { opt(p) } + // By default, PersistentAuth uses the default ApiClient to make HTTP + // requests. Furthermore, if the endpointSupplier is not provided, it uses + // this same client to fetch the OAuth endpoints. If the HTTP client is + // provided but the endpointSupplier is not, we construct a default + // ApiClient for use with BasicOAuthClient. + var apiClient *httpclient.ApiClient if p.client == nil { - p.client = &BasicOAuthClient{ - Client: httpclient.NewApiClient(httpclient.ClientConfig{}), + apiClient = httpclient.NewApiClient(httpclient.ClientConfig{}) + p.client = apiClient.ToHttpClient() + } + if p.endpointSupplier == nil { + if apiClient == nil { + apiClient = httpclient.NewApiClient(httpclient.ClientConfig{}) + } + p.endpointSupplier = &BasicOAuthEndpointSupplier{ + Client: apiClient, } } if p.cache == nil { @@ -202,12 +227,11 @@ func (a *PersistentAuth) refresh(oldToken *oauth2.Token) (*oauth2.Token, error) // OAuth2 flow is started by opening the browser to the OAuth2 authorization // URL. The user is redirected to the callback server on appRedirectAddr. The // callback server listens for the redirect from the identity provider and -// exchanges the authorization code for an access token. It returns the OAuth2 -// token on success. -func (a *PersistentAuth) Challenge() (*oauth2.Token, error) { +// exchanges the authorization code for an access token. +func (a *PersistentAuth) Challenge() error { err := a.startListener(a.ctx) if err != nil { - return nil, fmt.Errorf("starting listener: %w", err) + return fmt.Errorf("starting listener: %w", err) } // The listener will be closed by the callback server automatically, but if // the callback server is not created, we need to close the listener manually. @@ -215,31 +239,31 @@ func (a *PersistentAuth) Challenge() (*oauth2.Token, error) { cfg, err := a.oauth2Config() if err != nil { - return nil, fmt.Errorf("fetching oauth config: %w", err) + return fmt.Errorf("fetching oauth config: %w", err) } cb, err := a.newCallbackServer() if err != nil { - return nil, fmt.Errorf("callback server: %w", err) + return fmt.Errorf("callback server: %w", err) } defer cb.Close() state, pkce, err := a.stateAndPKCE() if err != nil { - return nil, fmt.Errorf("state and pkce: %w", err) + return fmt.Errorf("state and pkce: %w", err) } // make OAuth2 library use our client ctx := a.setOAuthContext(a.ctx) ts := authhandler.TokenSourceWithPKCE(ctx, cfg, state, cb.Handler, pkce) t, err := ts.Token() if err != nil { - return nil, fmt.Errorf("authorize: %w", err) + return fmt.Errorf("authorize: %w", err) } // cache token identified by host (and possibly the account id) err = a.cache.Store(a.oAuthArgument.GetCacheKey(), t) if err != nil { - return nil, fmt.Errorf("store: %w", err) + return fmt.Errorf("store: %w", err) } - return t, nil + return nil } // startListener starts a listener on appRedirectAddr, retrying if the address @@ -290,9 +314,9 @@ func (a *PersistentAuth) oauth2Config() (*oauth2.Config, error) { var err error switch argg := a.oAuthArgument.(type) { case WorkspaceOAuthArgument: - endpoints, err = a.client.GetWorkspaceOAuthEndpoints(a.ctx, argg.GetWorkspaceHost()) + endpoints, err = a.endpointSupplier.GetWorkspaceOAuthEndpoints(a.ctx, argg.GetWorkspaceHost()) case AccountOAuthArgument: - endpoints, err = a.client.GetAccountOAuthEndpoints( + endpoints, err = a.endpointSupplier.GetAccountOAuthEndpoints( a.ctx, argg.GetAccountHost(), argg.GetAccountId()) default: return nil, fmt.Errorf("unsupported OAuthArgument type: %T, must implement either WorkspaceOAuthArgument or AccountOAuthArgument interface", a.oAuthArgument) @@ -341,7 +365,7 @@ func (a *PersistentAuth) randomString(size int) (string, error) { } func (a *PersistentAuth) setOAuthContext(ctx context.Context) context.Context { - return context.WithValue(ctx, oauth2.HTTPClient, a.client.GetHttpClient(ctx)) + return context.WithValue(ctx, oauth2.HTTPClient, a.client) } var _ oauth2.TokenSource = (*PersistentAuth)(nil) diff --git a/credentials/u2m/persistent_auth_test.go b/credentials/u2m/persistent_auth_test.go index 7ded254e9..a32ecb31c 100644 --- a/credentials/u2m/persistent_auth_test.go +++ b/credentials/u2m/persistent_auth_test.go @@ -102,7 +102,7 @@ func TestLoadRefresh(t *testing.T) { p, err := u2m.NewPersistentAuth( ctx, u2m.WithTokenCache(cache), - u2m.WithOAuthClient(&MockOAuthClient{ + u2m.WithOAuthEndpointSupplier(&MockOAuthClient{ Transport: fixtures.SliceTransport{ { Method: "POST", @@ -126,7 +126,6 @@ func TestLoadRefresh(t *testing.T) { func TestChallenge(t *testing.T) { ctx := context.Background() - expectedKey := "https://accounts.cloud.databricks.com/oidc/accounts/xyz" browserOpened := make(chan string) browser := func(redirect string) error { @@ -142,7 +141,8 @@ func TestChallenge(t *testing.T) { } cache := &tokenCacheMock{ store: func(key string, tok *oauth2.Token) error { - assert.Equal(t, expectedKey, key) + assert.Equal(t, "https://accounts.cloud.databricks.com/oidc/accounts/xyz", key) + assert.Equal(t, "__THAT__", tok.AccessToken) assert.Equal(t, "__SOMETHING__", tok.RefreshToken) return nil }, @@ -153,7 +153,7 @@ func TestChallenge(t *testing.T) { ctx, u2m.WithTokenCache(cache), u2m.WithBrowser(browser), - u2m.WithOAuthClient(&MockOAuthClient{ + u2m.WithOAuthEndpointSupplier(&MockOAuthClient{ Transport: fixtures.SliceTransport{ { Method: "POST", @@ -170,14 +170,11 @@ func TestChallenge(t *testing.T) { require.NoError(t, err) defer p.Close() - tokenc := make(chan *oauth2.Token) errc := make(chan error) go func() { - token, err := p.Challenge() + err := p.Challenge() errc <- err close(errc) - tokenc <- token - close(tokenc) }() state := <-browserOpened @@ -188,7 +185,6 @@ func TestChallenge(t *testing.T) { err = <-errc assert.NoError(t, err) - assert.Equal(t, "__THAT__", (<-tokenc).AccessToken) } func TestChallengeFailed(t *testing.T) { @@ -211,14 +207,11 @@ func TestChallengeFailed(t *testing.T) { require.NoError(t, err) defer p.Close() - tokenc := make(chan *oauth2.Token) errc := make(chan error) go func() { - token, err := p.Challenge() + err := p.Challenge() errc <- err close(errc) - tokenc <- token - close(tokenc) }() <-browserOpened @@ -230,5 +223,4 @@ func TestChallengeFailed(t *testing.T) { err = <-errc assert.EqualError(t, err, "authorize: access_denied: Policy evaluation failed for this request") - assert.Nil(t, <-tokenc) } From de60fcf5562eddf087f6f8ac94c66d5d27c11e29 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Mon, 17 Feb 2025 17:30:14 +0100 Subject: [PATCH 40/44] fix --- config/.azure/az.json | 1 + config/.azure/az.sess | 1 + config/.azure/azureProfile.json | 1 + config/.azure/commandIndex.json | 1 + config/.azure/config | 3 +++ config/.azure/versionCheck.json | 1 + config/auth_u2m_test.go | 17 ++++++---------- config/testdata/corrupt/.azure/az.json | 1 + config/testdata/corrupt/.azure/az.sess | 1 + .../testdata/corrupt/.azure/azureProfile.json | 1 + .../testdata/corrupt/.azure/commandIndex.json | 1 + config/testdata/corrupt/.azure/config | 3 +++ .../testdata/corrupt/.azure/versionCheck.json | 1 + credentials/u2m/persistent_auth_test.go | 20 +++++++------------ 14 files changed, 29 insertions(+), 24 deletions(-) create mode 100644 config/.azure/az.json create mode 100644 config/.azure/az.sess create mode 100644 config/.azure/azureProfile.json create mode 100644 config/.azure/commandIndex.json create mode 100644 config/.azure/config create mode 100644 config/.azure/versionCheck.json create mode 100644 config/testdata/corrupt/.azure/az.json create mode 100644 config/testdata/corrupt/.azure/az.sess create mode 100644 config/testdata/corrupt/.azure/azureProfile.json create mode 100644 config/testdata/corrupt/.azure/commandIndex.json create mode 100644 config/testdata/corrupt/.azure/config create mode 100644 config/testdata/corrupt/.azure/versionCheck.json diff --git a/config/.azure/az.json b/config/.azure/az.json new file mode 100644 index 000000000..22fdca1b2 --- /dev/null +++ b/config/.azure/az.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/config/.azure/az.sess b/config/.azure/az.sess new file mode 100644 index 000000000..22fdca1b2 --- /dev/null +++ b/config/.azure/az.sess @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/config/.azure/azureProfile.json b/config/.azure/azureProfile.json new file mode 100644 index 000000000..5072f902c --- /dev/null +++ b/config/.azure/azureProfile.json @@ -0,0 +1 @@ +{"installationId": "d87faa92-ed4a-11ef-a8ec-8203a245c339"} \ No newline at end of file diff --git a/config/.azure/commandIndex.json b/config/.azure/commandIndex.json new file mode 100644 index 000000000..0870156ec --- /dev/null +++ b/config/.azure/commandIndex.json @@ -0,0 +1 @@ +{"version": "2.69.0", "cloudProfile": "latest", "commandIndex": {"acr": ["azure.cli.command_modules.acr"], "aks": ["azure.cli.command_modules.acs", "azure.cli.command_modules.serviceconnector"], "advisor": ["azure.cli.command_modules.advisor"], "ams": ["azure.cli.command_modules.ams"], "apim": ["azure.cli.command_modules.apim"], "appconfig": ["azure.cli.command_modules.appconfig"], "webapp": ["azure.cli.command_modules.appservice", "azure.cli.command_modules.serviceconnector"], "functionapp": ["azure.cli.command_modules.appservice", "azure.cli.command_modules.serviceconnector"], "appservice": ["azure.cli.command_modules.appservice"], "staticwebapp": ["azure.cli.command_modules.appservice"], "logicapp": ["azure.cli.command_modules.appservice"], "aro": ["azure.cli.command_modules.aro"], "backup": ["azure.cli.command_modules.backup"], "batch": ["azure.cli.command_modules.batch"], "batchai": ["azure.cli.command_modules.batchai"], "billing": ["azure.cli.command_modules.billing"], "bot": ["azure.cli.command_modules.botservice"], "afd": ["azure.cli.command_modules.cdn"], "cdn": ["azure.cli.command_modules.cdn"], "cloud": ["azure.cli.command_modules.cloud"], "cognitiveservices": ["azure.cli.command_modules.cognitiveservices"], "compute-recommender": ["azure.cli.command_modules.compute_recommender"], "compute-fleet": ["azure.cli.command_modules.computefleet"], "config": ["azure.cli.command_modules.config"], "configure": ["azure.cli.command_modules.configure"], "cache": ["azure.cli.command_modules.configure"], "consumption": ["azure.cli.command_modules.consumption"], "container": ["azure.cli.command_modules.container"], "containerapp": ["azure.cli.command_modules.containerapp", "azure.cli.command_modules.serviceconnector"], "cosmosdb": ["azure.cli.command_modules.cosmosdb"], "managed-cassandra": ["azure.cli.command_modules.cosmosdb"], "databoxedge": ["azure.cli.command_modules.databoxedge"], "dls": ["azure.cli.command_modules.dls"], "dms": ["azure.cli.command_modules.dms"], "eventgrid": ["azure.cli.command_modules.eventgrid"], "eventhubs": ["azure.cli.command_modules.eventhubs"], "extension": ["azure.cli.command_modules.extension"], "feedback": ["azure.cli.command_modules.feedback"], "survey": ["azure.cli.command_modules.feedback"], "find": ["azure.cli.command_modules.find"], "hdinsight": ["azure.cli.command_modules.hdinsight"], "identity": ["azure.cli.command_modules.identity"], "interactive": ["azure.cli.command_modules.interactive"], "iot": ["azure.cli.command_modules.iot"], "keyvault": ["azure.cli.command_modules.keyvault"], "lab": ["azure.cli.command_modules.lab"], "managedservices": ["azure.cli.command_modules.managedservices"], "maps": ["azure.cli.command_modules.maps"], "term": ["azure.cli.command_modules.marketplaceordering"], "monitor": ["azure.cli.command_modules.monitor"], "mysql": ["azure.cli.command_modules.mysql", "azure.cli.command_modules.rdbms"], "netappfiles": ["azure.cli.command_modules.netappfiles"], "network": ["azure.cli.command_modules.network", "azure.cli.command_modules.privatedns"], "policy": ["azure.cli.command_modules.policyinsights", "azure.cli.command_modules.resource"], "login": ["azure.cli.command_modules.profile"], "logout": ["azure.cli.command_modules.profile"], "self-test": ["azure.cli.command_modules.profile"], "account": ["azure.cli.command_modules.profile", "azure.cli.command_modules.resource"], "mariadb": ["azure.cli.command_modules.rdbms"], "postgres": ["azure.cli.command_modules.rdbms"], "redis": ["azure.cli.command_modules.redis"], "relay": ["azure.cli.command_modules.relay"], "data-boundary": ["azure.cli.command_modules.resource"], "group": ["azure.cli.command_modules.resource"], "resource": ["azure.cli.command_modules.resource"], "provider": ["azure.cli.command_modules.resource"], "feature": ["azure.cli.command_modules.resource"], "tag": ["azure.cli.command_modules.resource"], "deployment": ["azure.cli.command_modules.resource"], "deployment-scripts": ["azure.cli.command_modules.resource"], "ts": ["azure.cli.command_modules.resource"], "stack": ["azure.cli.command_modules.resource"], "lock": ["azure.cli.command_modules.resource"], "managedapp": ["azure.cli.command_modules.resource"], "bicep": ["azure.cli.command_modules.resource"], "resourcemanagement": ["azure.cli.command_modules.resource"], "private-link": ["azure.cli.command_modules.resource"], "role": ["azure.cli.command_modules.role"], "ad": ["azure.cli.command_modules.role"], "search": ["azure.cli.command_modules.search"], "security": ["azure.cli.command_modules.security"], "servicebus": ["azure.cli.command_modules.servicebus"], "connection": ["azure.cli.command_modules.serviceconnector"], "sf": ["azure.cli.command_modules.servicefabric"], "signalr": ["azure.cli.command_modules.signalr"], "sql": ["azure.cli.command_modules.sql", "azure.cli.command_modules.sqlvm"], "storage": ["azure.cli.command_modules.storage"], "synapse": ["azure.cli.command_modules.synapse"], "rest": ["azure.cli.command_modules.util"], "version": ["azure.cli.command_modules.util"], "upgrade": ["azure.cli.command_modules.util"], "demo": ["azure.cli.command_modules.util"], "snapshot": ["azure.cli.command_modules.vm"], "disk-access": ["azure.cli.command_modules.vm"], "sig": ["azure.cli.command_modules.vm"], "vmss": ["azure.cli.command_modules.vm"], "restore-point": ["azure.cli.command_modules.vm"], "image": ["azure.cli.command_modules.vm"], "capacity": ["azure.cli.command_modules.vm"], "vm": ["azure.cli.command_modules.vm"], "disk": ["azure.cli.command_modules.vm"], "ppg": ["azure.cli.command_modules.vm"], "disk-encryption-set": ["azure.cli.command_modules.vm"], "sshkey": ["azure.cli.command_modules.vm"]}} \ No newline at end of file diff --git a/config/.azure/config b/config/.azure/config new file mode 100644 index 000000000..0ed7f34d6 --- /dev/null +++ b/config/.azure/config @@ -0,0 +1,3 @@ +[cloud] +name = AzureCloud + diff --git a/config/.azure/versionCheck.json b/config/.azure/versionCheck.json new file mode 100644 index 000000000..2899d2982 --- /dev/null +++ b/config/.azure/versionCheck.json @@ -0,0 +1 @@ +{"versions": {"azure-cli": {"local": "2.69.0", "pypi": "2.69.0"}, "core": {"local": "2.69.0", "pypi": "2.69.0"}, "telemetry": {"local": "1.1.0", "pypi": "1.1.0"}}, "update_time": "2025-02-17 17:18:34.270836"} \ No newline at end of file diff --git a/config/auth_u2m_test.go b/config/auth_u2m_test.go index d8b65b426..9aef3e4ce 100644 --- a/config/auth_u2m_test.go +++ b/config/auth_u2m_test.go @@ -15,23 +15,16 @@ import ( "golang.org/x/oauth2" ) -type MockOAuthClient struct { - Transport http.RoundTripper +type MockOAuthEndpointSupplier struct { GetAccountOAuthEndpointsFn func(ctx context.Context, accountHost string, accountId string) (*u2m.OAuthAuthorizationServer, error) GetWorkspaceOAuthEndpointsFn func(ctx context.Context, workspaceHost string) (*u2m.OAuthAuthorizationServer, error) } -func (m MockOAuthClient) GetHttpClient(_ context.Context) *http.Client { - return &http.Client{ - Transport: m.Transport, - } -} - -func (m MockOAuthClient) GetAccountOAuthEndpoints(ctx context.Context, accountHost string, accountId string) (*u2m.OAuthAuthorizationServer, error) { +func (m MockOAuthEndpointSupplier) GetAccountOAuthEndpoints(ctx context.Context, accountHost string, accountId string) (*u2m.OAuthAuthorizationServer, error) { return m.GetAccountOAuthEndpointsFn(ctx, accountHost, accountId) } -func (m MockOAuthClient) GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*u2m.OAuthAuthorizationServer, error) { +func (m MockOAuthEndpointSupplier) GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*u2m.OAuthAuthorizationServer, error) { return m.GetWorkspaceOAuthEndpointsFn(ctx, workspaceHost) } @@ -88,7 +81,7 @@ func TestU2MCredentials(t *testing.T) { }, }, }), - u2m.WithOAuthEndpointSupplier(MockOAuthClient{ + u2m.WithHttpClient(&http.Client{ Transport: fixtures.SliceTransport{ { Method: "POST", @@ -97,6 +90,8 @@ func TestU2MCredentials(t *testing.T) { Response: `{"error":"invalid_refresh_token","error_description":"Refresh token is invalid"}`, }, }, + }), + u2m.WithOAuthEndpointSupplier(MockOAuthEndpointSupplier{ GetWorkspaceOAuthEndpointsFn: func(ctx context.Context, workspaceHost string) (*u2m.OAuthAuthorizationServer, error) { return &u2m.OAuthAuthorizationServer{ TokenEndpoint: "https://myworkspace.cloud.databricks.com/oidc/v1/token", diff --git a/config/testdata/corrupt/.azure/az.json b/config/testdata/corrupt/.azure/az.json new file mode 100644 index 000000000..22fdca1b2 --- /dev/null +++ b/config/testdata/corrupt/.azure/az.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/config/testdata/corrupt/.azure/az.sess b/config/testdata/corrupt/.azure/az.sess new file mode 100644 index 000000000..22fdca1b2 --- /dev/null +++ b/config/testdata/corrupt/.azure/az.sess @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/config/testdata/corrupt/.azure/azureProfile.json b/config/testdata/corrupt/.azure/azureProfile.json new file mode 100644 index 000000000..6cd55e606 --- /dev/null +++ b/config/testdata/corrupt/.azure/azureProfile.json @@ -0,0 +1 @@ +{"installationId": "d901c22a-ed4a-11ef-be5e-8203a245c339"} \ No newline at end of file diff --git a/config/testdata/corrupt/.azure/commandIndex.json b/config/testdata/corrupt/.azure/commandIndex.json new file mode 100644 index 000000000..0870156ec --- /dev/null +++ b/config/testdata/corrupt/.azure/commandIndex.json @@ -0,0 +1 @@ +{"version": "2.69.0", "cloudProfile": "latest", "commandIndex": {"acr": ["azure.cli.command_modules.acr"], "aks": ["azure.cli.command_modules.acs", "azure.cli.command_modules.serviceconnector"], "advisor": ["azure.cli.command_modules.advisor"], "ams": ["azure.cli.command_modules.ams"], "apim": ["azure.cli.command_modules.apim"], "appconfig": ["azure.cli.command_modules.appconfig"], "webapp": ["azure.cli.command_modules.appservice", "azure.cli.command_modules.serviceconnector"], "functionapp": ["azure.cli.command_modules.appservice", "azure.cli.command_modules.serviceconnector"], "appservice": ["azure.cli.command_modules.appservice"], "staticwebapp": ["azure.cli.command_modules.appservice"], "logicapp": ["azure.cli.command_modules.appservice"], "aro": ["azure.cli.command_modules.aro"], "backup": ["azure.cli.command_modules.backup"], "batch": ["azure.cli.command_modules.batch"], "batchai": ["azure.cli.command_modules.batchai"], "billing": ["azure.cli.command_modules.billing"], "bot": ["azure.cli.command_modules.botservice"], "afd": ["azure.cli.command_modules.cdn"], "cdn": ["azure.cli.command_modules.cdn"], "cloud": ["azure.cli.command_modules.cloud"], "cognitiveservices": ["azure.cli.command_modules.cognitiveservices"], "compute-recommender": ["azure.cli.command_modules.compute_recommender"], "compute-fleet": ["azure.cli.command_modules.computefleet"], "config": ["azure.cli.command_modules.config"], "configure": ["azure.cli.command_modules.configure"], "cache": ["azure.cli.command_modules.configure"], "consumption": ["azure.cli.command_modules.consumption"], "container": ["azure.cli.command_modules.container"], "containerapp": ["azure.cli.command_modules.containerapp", "azure.cli.command_modules.serviceconnector"], "cosmosdb": ["azure.cli.command_modules.cosmosdb"], "managed-cassandra": ["azure.cli.command_modules.cosmosdb"], "databoxedge": ["azure.cli.command_modules.databoxedge"], "dls": ["azure.cli.command_modules.dls"], "dms": ["azure.cli.command_modules.dms"], "eventgrid": ["azure.cli.command_modules.eventgrid"], "eventhubs": ["azure.cli.command_modules.eventhubs"], "extension": ["azure.cli.command_modules.extension"], "feedback": ["azure.cli.command_modules.feedback"], "survey": ["azure.cli.command_modules.feedback"], "find": ["azure.cli.command_modules.find"], "hdinsight": ["azure.cli.command_modules.hdinsight"], "identity": ["azure.cli.command_modules.identity"], "interactive": ["azure.cli.command_modules.interactive"], "iot": ["azure.cli.command_modules.iot"], "keyvault": ["azure.cli.command_modules.keyvault"], "lab": ["azure.cli.command_modules.lab"], "managedservices": ["azure.cli.command_modules.managedservices"], "maps": ["azure.cli.command_modules.maps"], "term": ["azure.cli.command_modules.marketplaceordering"], "monitor": ["azure.cli.command_modules.monitor"], "mysql": ["azure.cli.command_modules.mysql", "azure.cli.command_modules.rdbms"], "netappfiles": ["azure.cli.command_modules.netappfiles"], "network": ["azure.cli.command_modules.network", "azure.cli.command_modules.privatedns"], "policy": ["azure.cli.command_modules.policyinsights", "azure.cli.command_modules.resource"], "login": ["azure.cli.command_modules.profile"], "logout": ["azure.cli.command_modules.profile"], "self-test": ["azure.cli.command_modules.profile"], "account": ["azure.cli.command_modules.profile", "azure.cli.command_modules.resource"], "mariadb": ["azure.cli.command_modules.rdbms"], "postgres": ["azure.cli.command_modules.rdbms"], "redis": ["azure.cli.command_modules.redis"], "relay": ["azure.cli.command_modules.relay"], "data-boundary": ["azure.cli.command_modules.resource"], "group": ["azure.cli.command_modules.resource"], "resource": ["azure.cli.command_modules.resource"], "provider": ["azure.cli.command_modules.resource"], "feature": ["azure.cli.command_modules.resource"], "tag": ["azure.cli.command_modules.resource"], "deployment": ["azure.cli.command_modules.resource"], "deployment-scripts": ["azure.cli.command_modules.resource"], "ts": ["azure.cli.command_modules.resource"], "stack": ["azure.cli.command_modules.resource"], "lock": ["azure.cli.command_modules.resource"], "managedapp": ["azure.cli.command_modules.resource"], "bicep": ["azure.cli.command_modules.resource"], "resourcemanagement": ["azure.cli.command_modules.resource"], "private-link": ["azure.cli.command_modules.resource"], "role": ["azure.cli.command_modules.role"], "ad": ["azure.cli.command_modules.role"], "search": ["azure.cli.command_modules.search"], "security": ["azure.cli.command_modules.security"], "servicebus": ["azure.cli.command_modules.servicebus"], "connection": ["azure.cli.command_modules.serviceconnector"], "sf": ["azure.cli.command_modules.servicefabric"], "signalr": ["azure.cli.command_modules.signalr"], "sql": ["azure.cli.command_modules.sql", "azure.cli.command_modules.sqlvm"], "storage": ["azure.cli.command_modules.storage"], "synapse": ["azure.cli.command_modules.synapse"], "rest": ["azure.cli.command_modules.util"], "version": ["azure.cli.command_modules.util"], "upgrade": ["azure.cli.command_modules.util"], "demo": ["azure.cli.command_modules.util"], "snapshot": ["azure.cli.command_modules.vm"], "disk-access": ["azure.cli.command_modules.vm"], "sig": ["azure.cli.command_modules.vm"], "vmss": ["azure.cli.command_modules.vm"], "restore-point": ["azure.cli.command_modules.vm"], "image": ["azure.cli.command_modules.vm"], "capacity": ["azure.cli.command_modules.vm"], "vm": ["azure.cli.command_modules.vm"], "disk": ["azure.cli.command_modules.vm"], "ppg": ["azure.cli.command_modules.vm"], "disk-encryption-set": ["azure.cli.command_modules.vm"], "sshkey": ["azure.cli.command_modules.vm"]}} \ No newline at end of file diff --git a/config/testdata/corrupt/.azure/config b/config/testdata/corrupt/.azure/config new file mode 100644 index 000000000..0ed7f34d6 --- /dev/null +++ b/config/testdata/corrupt/.azure/config @@ -0,0 +1,3 @@ +[cloud] +name = AzureCloud + diff --git a/config/testdata/corrupt/.azure/versionCheck.json b/config/testdata/corrupt/.azure/versionCheck.json new file mode 100644 index 000000000..3345fa560 --- /dev/null +++ b/config/testdata/corrupt/.azure/versionCheck.json @@ -0,0 +1 @@ +{"versions": {"azure-cli": {"local": "2.69.0", "pypi": "2.69.0"}, "core": {"local": "2.69.0", "pypi": "2.69.0"}, "telemetry": {"local": "1.1.0", "pypi": "1.1.0"}}, "update_time": "2025-02-17 17:18:36.135883"} \ No newline at end of file diff --git a/credentials/u2m/persistent_auth_test.go b/credentials/u2m/persistent_auth_test.go index a32ecb31c..50835b575 100644 --- a/credentials/u2m/persistent_auth_test.go +++ b/credentials/u2m/persistent_auth_test.go @@ -55,24 +55,16 @@ func TestLoad(t *testing.T) { assert.Equal(t, "", tok.RefreshToken) } -type MockOAuthClient struct { - Transport http.RoundTripper -} - -func (m MockOAuthClient) GetHttpClient(_ context.Context) *http.Client { - return &http.Client{ - Transport: m.Transport, - } -} +type MockOAuthEndpointSupplier struct{} -func (m MockOAuthClient) GetAccountOAuthEndpoints(ctx context.Context, accountHost string, accountId string) (*u2m.OAuthAuthorizationServer, error) { +func (m MockOAuthEndpointSupplier) GetAccountOAuthEndpoints(ctx context.Context, accountHost string, accountId string) (*u2m.OAuthAuthorizationServer, error) { return &u2m.OAuthAuthorizationServer{ AuthorizationEndpoint: fmt.Sprintf("%s/oidc/accounts/%s/v1/authorize", accountHost, accountId), TokenEndpoint: fmt.Sprintf("%s/oidc/accounts/%s/v1/token", accountHost, accountId), }, nil } -func (m MockOAuthClient) GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*u2m.OAuthAuthorizationServer, error) { +func (m MockOAuthEndpointSupplier) GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*u2m.OAuthAuthorizationServer, error) { return &u2m.OAuthAuthorizationServer{ AuthorizationEndpoint: fmt.Sprintf("%s/oidc/v1/authorize", workspaceHost), TokenEndpoint: fmt.Sprintf("%s/oidc/v1/token", workspaceHost), @@ -102,7 +94,7 @@ func TestLoadRefresh(t *testing.T) { p, err := u2m.NewPersistentAuth( ctx, u2m.WithTokenCache(cache), - u2m.WithOAuthEndpointSupplier(&MockOAuthClient{ + u2m.WithHttpClient(&http.Client{ Transport: fixtures.SliceTransport{ { Method: "POST", @@ -114,6 +106,7 @@ func TestLoadRefresh(t *testing.T) { }, }, }), + u2m.WithOAuthEndpointSupplier(MockOAuthEndpointSupplier{}), u2m.WithOAuthArgument(arg), ) require.NoError(t, err) @@ -153,7 +146,7 @@ func TestChallenge(t *testing.T) { ctx, u2m.WithTokenCache(cache), u2m.WithBrowser(browser), - u2m.WithOAuthEndpointSupplier(&MockOAuthClient{ + u2m.WithHttpClient(&http.Client{ Transport: fixtures.SliceTransport{ { Method: "POST", @@ -165,6 +158,7 @@ func TestChallenge(t *testing.T) { }, }, }), + u2m.WithOAuthEndpointSupplier(MockOAuthEndpointSupplier{}), u2m.WithOAuthArgument(arg), ) require.NoError(t, err) From 11380f017145cff6d44d0b4e117ab1d16f9cf647 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Mon, 17 Feb 2025 17:30:36 +0100 Subject: [PATCH 41/44] remove --- config/.azure/az.json | 1 - config/.azure/az.sess | 1 - config/.azure/azureProfile.json | 1 - config/.azure/commandIndex.json | 1 - config/.azure/config | 3 --- config/.azure/versionCheck.json | 1 - config/testdata/corrupt/.azure/az.json | 1 - config/testdata/corrupt/.azure/az.sess | 1 - config/testdata/corrupt/.azure/azureProfile.json | 1 - config/testdata/corrupt/.azure/commandIndex.json | 1 - config/testdata/corrupt/.azure/config | 3 --- config/testdata/corrupt/.azure/versionCheck.json | 1 - 12 files changed, 16 deletions(-) delete mode 100644 config/.azure/az.json delete mode 100644 config/.azure/az.sess delete mode 100644 config/.azure/azureProfile.json delete mode 100644 config/.azure/commandIndex.json delete mode 100644 config/.azure/config delete mode 100644 config/.azure/versionCheck.json delete mode 100644 config/testdata/corrupt/.azure/az.json delete mode 100644 config/testdata/corrupt/.azure/az.sess delete mode 100644 config/testdata/corrupt/.azure/azureProfile.json delete mode 100644 config/testdata/corrupt/.azure/commandIndex.json delete mode 100644 config/testdata/corrupt/.azure/config delete mode 100644 config/testdata/corrupt/.azure/versionCheck.json diff --git a/config/.azure/az.json b/config/.azure/az.json deleted file mode 100644 index 22fdca1b2..000000000 --- a/config/.azure/az.json +++ /dev/null @@ -1 +0,0 @@ -{} \ No newline at end of file diff --git a/config/.azure/az.sess b/config/.azure/az.sess deleted file mode 100644 index 22fdca1b2..000000000 --- a/config/.azure/az.sess +++ /dev/null @@ -1 +0,0 @@ -{} \ No newline at end of file diff --git a/config/.azure/azureProfile.json b/config/.azure/azureProfile.json deleted file mode 100644 index 5072f902c..000000000 --- a/config/.azure/azureProfile.json +++ /dev/null @@ -1 +0,0 @@ -{"installationId": "d87faa92-ed4a-11ef-a8ec-8203a245c339"} \ No newline at end of file diff --git a/config/.azure/commandIndex.json b/config/.azure/commandIndex.json deleted file mode 100644 index 0870156ec..000000000 --- a/config/.azure/commandIndex.json +++ /dev/null @@ -1 +0,0 @@ -{"version": "2.69.0", "cloudProfile": "latest", "commandIndex": {"acr": ["azure.cli.command_modules.acr"], "aks": ["azure.cli.command_modules.acs", "azure.cli.command_modules.serviceconnector"], "advisor": ["azure.cli.command_modules.advisor"], "ams": ["azure.cli.command_modules.ams"], "apim": ["azure.cli.command_modules.apim"], "appconfig": ["azure.cli.command_modules.appconfig"], "webapp": ["azure.cli.command_modules.appservice", "azure.cli.command_modules.serviceconnector"], "functionapp": ["azure.cli.command_modules.appservice", "azure.cli.command_modules.serviceconnector"], "appservice": ["azure.cli.command_modules.appservice"], "staticwebapp": ["azure.cli.command_modules.appservice"], "logicapp": ["azure.cli.command_modules.appservice"], "aro": ["azure.cli.command_modules.aro"], "backup": ["azure.cli.command_modules.backup"], "batch": ["azure.cli.command_modules.batch"], "batchai": ["azure.cli.command_modules.batchai"], "billing": ["azure.cli.command_modules.billing"], "bot": ["azure.cli.command_modules.botservice"], "afd": ["azure.cli.command_modules.cdn"], "cdn": ["azure.cli.command_modules.cdn"], "cloud": ["azure.cli.command_modules.cloud"], "cognitiveservices": ["azure.cli.command_modules.cognitiveservices"], "compute-recommender": ["azure.cli.command_modules.compute_recommender"], "compute-fleet": ["azure.cli.command_modules.computefleet"], "config": ["azure.cli.command_modules.config"], "configure": ["azure.cli.command_modules.configure"], "cache": ["azure.cli.command_modules.configure"], "consumption": ["azure.cli.command_modules.consumption"], "container": ["azure.cli.command_modules.container"], "containerapp": ["azure.cli.command_modules.containerapp", "azure.cli.command_modules.serviceconnector"], "cosmosdb": ["azure.cli.command_modules.cosmosdb"], "managed-cassandra": ["azure.cli.command_modules.cosmosdb"], "databoxedge": ["azure.cli.command_modules.databoxedge"], "dls": ["azure.cli.command_modules.dls"], "dms": ["azure.cli.command_modules.dms"], "eventgrid": ["azure.cli.command_modules.eventgrid"], "eventhubs": ["azure.cli.command_modules.eventhubs"], "extension": ["azure.cli.command_modules.extension"], "feedback": ["azure.cli.command_modules.feedback"], "survey": ["azure.cli.command_modules.feedback"], "find": ["azure.cli.command_modules.find"], "hdinsight": ["azure.cli.command_modules.hdinsight"], "identity": ["azure.cli.command_modules.identity"], "interactive": ["azure.cli.command_modules.interactive"], "iot": ["azure.cli.command_modules.iot"], "keyvault": ["azure.cli.command_modules.keyvault"], "lab": ["azure.cli.command_modules.lab"], "managedservices": ["azure.cli.command_modules.managedservices"], "maps": ["azure.cli.command_modules.maps"], "term": ["azure.cli.command_modules.marketplaceordering"], "monitor": ["azure.cli.command_modules.monitor"], "mysql": ["azure.cli.command_modules.mysql", "azure.cli.command_modules.rdbms"], "netappfiles": ["azure.cli.command_modules.netappfiles"], "network": ["azure.cli.command_modules.network", "azure.cli.command_modules.privatedns"], "policy": ["azure.cli.command_modules.policyinsights", "azure.cli.command_modules.resource"], "login": ["azure.cli.command_modules.profile"], "logout": ["azure.cli.command_modules.profile"], "self-test": ["azure.cli.command_modules.profile"], "account": ["azure.cli.command_modules.profile", "azure.cli.command_modules.resource"], "mariadb": ["azure.cli.command_modules.rdbms"], "postgres": ["azure.cli.command_modules.rdbms"], "redis": ["azure.cli.command_modules.redis"], "relay": ["azure.cli.command_modules.relay"], "data-boundary": ["azure.cli.command_modules.resource"], "group": ["azure.cli.command_modules.resource"], "resource": ["azure.cli.command_modules.resource"], "provider": ["azure.cli.command_modules.resource"], "feature": ["azure.cli.command_modules.resource"], "tag": ["azure.cli.command_modules.resource"], "deployment": ["azure.cli.command_modules.resource"], "deployment-scripts": ["azure.cli.command_modules.resource"], "ts": ["azure.cli.command_modules.resource"], "stack": ["azure.cli.command_modules.resource"], "lock": ["azure.cli.command_modules.resource"], "managedapp": ["azure.cli.command_modules.resource"], "bicep": ["azure.cli.command_modules.resource"], "resourcemanagement": ["azure.cli.command_modules.resource"], "private-link": ["azure.cli.command_modules.resource"], "role": ["azure.cli.command_modules.role"], "ad": ["azure.cli.command_modules.role"], "search": ["azure.cli.command_modules.search"], "security": ["azure.cli.command_modules.security"], "servicebus": ["azure.cli.command_modules.servicebus"], "connection": ["azure.cli.command_modules.serviceconnector"], "sf": ["azure.cli.command_modules.servicefabric"], "signalr": ["azure.cli.command_modules.signalr"], "sql": ["azure.cli.command_modules.sql", "azure.cli.command_modules.sqlvm"], "storage": ["azure.cli.command_modules.storage"], "synapse": ["azure.cli.command_modules.synapse"], "rest": ["azure.cli.command_modules.util"], "version": ["azure.cli.command_modules.util"], "upgrade": ["azure.cli.command_modules.util"], "demo": ["azure.cli.command_modules.util"], "snapshot": ["azure.cli.command_modules.vm"], "disk-access": ["azure.cli.command_modules.vm"], "sig": ["azure.cli.command_modules.vm"], "vmss": ["azure.cli.command_modules.vm"], "restore-point": ["azure.cli.command_modules.vm"], "image": ["azure.cli.command_modules.vm"], "capacity": ["azure.cli.command_modules.vm"], "vm": ["azure.cli.command_modules.vm"], "disk": ["azure.cli.command_modules.vm"], "ppg": ["azure.cli.command_modules.vm"], "disk-encryption-set": ["azure.cli.command_modules.vm"], "sshkey": ["azure.cli.command_modules.vm"]}} \ No newline at end of file diff --git a/config/.azure/config b/config/.azure/config deleted file mode 100644 index 0ed7f34d6..000000000 --- a/config/.azure/config +++ /dev/null @@ -1,3 +0,0 @@ -[cloud] -name = AzureCloud - diff --git a/config/.azure/versionCheck.json b/config/.azure/versionCheck.json deleted file mode 100644 index 2899d2982..000000000 --- a/config/.azure/versionCheck.json +++ /dev/null @@ -1 +0,0 @@ -{"versions": {"azure-cli": {"local": "2.69.0", "pypi": "2.69.0"}, "core": {"local": "2.69.0", "pypi": "2.69.0"}, "telemetry": {"local": "1.1.0", "pypi": "1.1.0"}}, "update_time": "2025-02-17 17:18:34.270836"} \ No newline at end of file diff --git a/config/testdata/corrupt/.azure/az.json b/config/testdata/corrupt/.azure/az.json deleted file mode 100644 index 22fdca1b2..000000000 --- a/config/testdata/corrupt/.azure/az.json +++ /dev/null @@ -1 +0,0 @@ -{} \ No newline at end of file diff --git a/config/testdata/corrupt/.azure/az.sess b/config/testdata/corrupt/.azure/az.sess deleted file mode 100644 index 22fdca1b2..000000000 --- a/config/testdata/corrupt/.azure/az.sess +++ /dev/null @@ -1 +0,0 @@ -{} \ No newline at end of file diff --git a/config/testdata/corrupt/.azure/azureProfile.json b/config/testdata/corrupt/.azure/azureProfile.json deleted file mode 100644 index 6cd55e606..000000000 --- a/config/testdata/corrupt/.azure/azureProfile.json +++ /dev/null @@ -1 +0,0 @@ -{"installationId": "d901c22a-ed4a-11ef-be5e-8203a245c339"} \ No newline at end of file diff --git a/config/testdata/corrupt/.azure/commandIndex.json b/config/testdata/corrupt/.azure/commandIndex.json deleted file mode 100644 index 0870156ec..000000000 --- a/config/testdata/corrupt/.azure/commandIndex.json +++ /dev/null @@ -1 +0,0 @@ -{"version": "2.69.0", "cloudProfile": "latest", "commandIndex": {"acr": ["azure.cli.command_modules.acr"], "aks": ["azure.cli.command_modules.acs", "azure.cli.command_modules.serviceconnector"], "advisor": ["azure.cli.command_modules.advisor"], "ams": ["azure.cli.command_modules.ams"], "apim": ["azure.cli.command_modules.apim"], "appconfig": ["azure.cli.command_modules.appconfig"], "webapp": ["azure.cli.command_modules.appservice", "azure.cli.command_modules.serviceconnector"], "functionapp": ["azure.cli.command_modules.appservice", "azure.cli.command_modules.serviceconnector"], "appservice": ["azure.cli.command_modules.appservice"], "staticwebapp": ["azure.cli.command_modules.appservice"], "logicapp": ["azure.cli.command_modules.appservice"], "aro": ["azure.cli.command_modules.aro"], "backup": ["azure.cli.command_modules.backup"], "batch": ["azure.cli.command_modules.batch"], "batchai": ["azure.cli.command_modules.batchai"], "billing": ["azure.cli.command_modules.billing"], "bot": ["azure.cli.command_modules.botservice"], "afd": ["azure.cli.command_modules.cdn"], "cdn": ["azure.cli.command_modules.cdn"], "cloud": ["azure.cli.command_modules.cloud"], "cognitiveservices": ["azure.cli.command_modules.cognitiveservices"], "compute-recommender": ["azure.cli.command_modules.compute_recommender"], "compute-fleet": ["azure.cli.command_modules.computefleet"], "config": ["azure.cli.command_modules.config"], "configure": ["azure.cli.command_modules.configure"], "cache": ["azure.cli.command_modules.configure"], "consumption": ["azure.cli.command_modules.consumption"], "container": ["azure.cli.command_modules.container"], "containerapp": ["azure.cli.command_modules.containerapp", "azure.cli.command_modules.serviceconnector"], "cosmosdb": ["azure.cli.command_modules.cosmosdb"], "managed-cassandra": ["azure.cli.command_modules.cosmosdb"], "databoxedge": ["azure.cli.command_modules.databoxedge"], "dls": ["azure.cli.command_modules.dls"], "dms": ["azure.cli.command_modules.dms"], "eventgrid": ["azure.cli.command_modules.eventgrid"], "eventhubs": ["azure.cli.command_modules.eventhubs"], "extension": ["azure.cli.command_modules.extension"], "feedback": ["azure.cli.command_modules.feedback"], "survey": ["azure.cli.command_modules.feedback"], "find": ["azure.cli.command_modules.find"], "hdinsight": ["azure.cli.command_modules.hdinsight"], "identity": ["azure.cli.command_modules.identity"], "interactive": ["azure.cli.command_modules.interactive"], "iot": ["azure.cli.command_modules.iot"], "keyvault": ["azure.cli.command_modules.keyvault"], "lab": ["azure.cli.command_modules.lab"], "managedservices": ["azure.cli.command_modules.managedservices"], "maps": ["azure.cli.command_modules.maps"], "term": ["azure.cli.command_modules.marketplaceordering"], "monitor": ["azure.cli.command_modules.monitor"], "mysql": ["azure.cli.command_modules.mysql", "azure.cli.command_modules.rdbms"], "netappfiles": ["azure.cli.command_modules.netappfiles"], "network": ["azure.cli.command_modules.network", "azure.cli.command_modules.privatedns"], "policy": ["azure.cli.command_modules.policyinsights", "azure.cli.command_modules.resource"], "login": ["azure.cli.command_modules.profile"], "logout": ["azure.cli.command_modules.profile"], "self-test": ["azure.cli.command_modules.profile"], "account": ["azure.cli.command_modules.profile", "azure.cli.command_modules.resource"], "mariadb": ["azure.cli.command_modules.rdbms"], "postgres": ["azure.cli.command_modules.rdbms"], "redis": ["azure.cli.command_modules.redis"], "relay": ["azure.cli.command_modules.relay"], "data-boundary": ["azure.cli.command_modules.resource"], "group": ["azure.cli.command_modules.resource"], "resource": ["azure.cli.command_modules.resource"], "provider": ["azure.cli.command_modules.resource"], "feature": ["azure.cli.command_modules.resource"], "tag": ["azure.cli.command_modules.resource"], "deployment": ["azure.cli.command_modules.resource"], "deployment-scripts": ["azure.cli.command_modules.resource"], "ts": ["azure.cli.command_modules.resource"], "stack": ["azure.cli.command_modules.resource"], "lock": ["azure.cli.command_modules.resource"], "managedapp": ["azure.cli.command_modules.resource"], "bicep": ["azure.cli.command_modules.resource"], "resourcemanagement": ["azure.cli.command_modules.resource"], "private-link": ["azure.cli.command_modules.resource"], "role": ["azure.cli.command_modules.role"], "ad": ["azure.cli.command_modules.role"], "search": ["azure.cli.command_modules.search"], "security": ["azure.cli.command_modules.security"], "servicebus": ["azure.cli.command_modules.servicebus"], "connection": ["azure.cli.command_modules.serviceconnector"], "sf": ["azure.cli.command_modules.servicefabric"], "signalr": ["azure.cli.command_modules.signalr"], "sql": ["azure.cli.command_modules.sql", "azure.cli.command_modules.sqlvm"], "storage": ["azure.cli.command_modules.storage"], "synapse": ["azure.cli.command_modules.synapse"], "rest": ["azure.cli.command_modules.util"], "version": ["azure.cli.command_modules.util"], "upgrade": ["azure.cli.command_modules.util"], "demo": ["azure.cli.command_modules.util"], "snapshot": ["azure.cli.command_modules.vm"], "disk-access": ["azure.cli.command_modules.vm"], "sig": ["azure.cli.command_modules.vm"], "vmss": ["azure.cli.command_modules.vm"], "restore-point": ["azure.cli.command_modules.vm"], "image": ["azure.cli.command_modules.vm"], "capacity": ["azure.cli.command_modules.vm"], "vm": ["azure.cli.command_modules.vm"], "disk": ["azure.cli.command_modules.vm"], "ppg": ["azure.cli.command_modules.vm"], "disk-encryption-set": ["azure.cli.command_modules.vm"], "sshkey": ["azure.cli.command_modules.vm"]}} \ No newline at end of file diff --git a/config/testdata/corrupt/.azure/config b/config/testdata/corrupt/.azure/config deleted file mode 100644 index 0ed7f34d6..000000000 --- a/config/testdata/corrupt/.azure/config +++ /dev/null @@ -1,3 +0,0 @@ -[cloud] -name = AzureCloud - diff --git a/config/testdata/corrupt/.azure/versionCheck.json b/config/testdata/corrupt/.azure/versionCheck.json deleted file mode 100644 index 3345fa560..000000000 --- a/config/testdata/corrupt/.azure/versionCheck.json +++ /dev/null @@ -1 +0,0 @@ -{"versions": {"azure-cli": {"local": "2.69.0", "pypi": "2.69.0"}, "core": {"local": "2.69.0", "pypi": "2.69.0"}, "telemetry": {"local": "1.1.0", "pypi": "1.1.0"}}, "update_time": "2025-02-17 17:18:36.135883"} \ No newline at end of file From 44c2ca7acea354e7c986538b5202fe566d8182d0 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Mon, 17 Feb 2025 17:46:25 +0100 Subject: [PATCH 42/44] fix --- config/auth_u2m_test.go | 80 +++++------------------- credentials/u2m/persistent_auth_test.go | 82 ++++++++++++++++++++++++- 2 files changed, 94 insertions(+), 68 deletions(-) diff --git a/config/auth_u2m_test.go b/config/auth_u2m_test.go index 9aef3e4ce..597e5421b 100644 --- a/config/auth_u2m_test.go +++ b/config/auth_u2m_test.go @@ -10,36 +10,24 @@ import ( "github.com/databricks/databricks-sdk-go/credentials/u2m" "github.com/databricks/databricks-sdk-go/credentials/u2m/cache" - "github.com/databricks/databricks-sdk-go/httpclient/fixtures" "github.com/stretchr/testify/require" "golang.org/x/oauth2" ) -type MockOAuthEndpointSupplier struct { - GetAccountOAuthEndpointsFn func(ctx context.Context, accountHost string, accountId string) (*u2m.OAuthAuthorizationServer, error) - GetWorkspaceOAuthEndpointsFn func(ctx context.Context, workspaceHost string) (*u2m.OAuthAuthorizationServer, error) +type mockU2mTokenSource struct { + token *oauth2.Token + err error } -func (m MockOAuthEndpointSupplier) GetAccountOAuthEndpoints(ctx context.Context, accountHost string, accountId string) (*u2m.OAuthAuthorizationServer, error) { - return m.GetAccountOAuthEndpointsFn(ctx, accountHost, accountId) -} - -func (m MockOAuthEndpointSupplier) GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*u2m.OAuthAuthorizationServer, error) { - return m.GetWorkspaceOAuthEndpointsFn(ctx, workspaceHost) -} - -func must[T any](c T, err error) T { - if err != nil { - panic(err) - } - return c +func (m mockU2mTokenSource) Token() (*oauth2.Token, error) { + return m.token, m.err } func TestU2MCredentials(t *testing.T) { tests := []struct { name string cfg *Config - auth *u2m.PersistentAuth + auth oauth2.TokenSource expectErr string expectAuth string }{ @@ -48,20 +36,12 @@ func TestU2MCredentials(t *testing.T) { cfg: &Config{ Host: "https://myworkspace.cloud.databricks.com", }, - auth: must( - u2m.NewPersistentAuth( - context.Background(), - u2m.WithTokenCache(&InMemoryTokenCache{ - Tokens: map[string]*oauth2.Token{ - "https://myworkspace.cloud.databricks.com": { - AccessToken: "dummy_access_token", - Expiry: time.Now().Add(1 * time.Hour), - }, - }, - }), - u2m.WithOAuthArgument(must(u2m.NewBasicWorkspaceOAuthArgument("https://myworkspace.cloud.databricks.com"))), - ), - ), + auth: mockU2mTokenSource{ + token: &oauth2.Token{ + AccessToken: "dummy_access_token", + Expiry: time.Now().Add(1 * time.Hour), + }, + }, expectAuth: "Bearer dummy_access_token", }, { @@ -69,39 +49,9 @@ func TestU2MCredentials(t *testing.T) { cfg: &Config{ Host: "https://myworkspace.cloud.databricks.com", }, - auth: must( - u2m.NewPersistentAuth( - context.Background(), - u2m.WithTokenCache(&InMemoryTokenCache{ - Tokens: map[string]*oauth2.Token{ - "https://myworkspace.cloud.databricks.com": { - AccessToken: "dummy_access_token", - RefreshToken: "dummy_refresh_token", - Expiry: time.Now().Add(-1 * time.Hour), - }, - }, - }), - u2m.WithHttpClient(&http.Client{ - Transport: fixtures.SliceTransport{ - { - Method: "POST", - Resource: "/oidc/v1/token", - Status: 401, - Response: `{"error":"invalid_refresh_token","error_description":"Refresh token is invalid"}`, - }, - }, - }), - u2m.WithOAuthEndpointSupplier(MockOAuthEndpointSupplier{ - GetWorkspaceOAuthEndpointsFn: func(ctx context.Context, workspaceHost string) (*u2m.OAuthAuthorizationServer, error) { - return &u2m.OAuthAuthorizationServer{ - TokenEndpoint: "https://myworkspace.cloud.databricks.com/oidc/v1/token", - AuthorizationEndpoint: "https://myworkspace.cloud.databricks.com/oidc/v1/authorize", - }, nil - }, - }), - u2m.WithOAuthArgument(must(u2m.NewBasicWorkspaceOAuthArgument("https://myworkspace.cloud.databricks.com"))), - ), - ), + auth: mockU2mTokenSource{ + err: &u2m.InvalidRefreshTokenError{}, + }, expectErr: `a new access token could not be retrieved because the refresh token is invalid. If using the CLI, run the following command to reauthenticate: $ databricks auth login --host https://myworkspace.cloud.databricks.com`, diff --git a/credentials/u2m/persistent_auth_test.go b/credentials/u2m/persistent_auth_test.go index 50835b575..5aac4ecda 100644 --- a/credentials/u2m/persistent_auth_test.go +++ b/credentials/u2m/persistent_auth_test.go @@ -2,6 +2,7 @@ package u2m_test import ( "context" + "errors" "fmt" "net/http" "net/url" @@ -34,7 +35,7 @@ func (m *tokenCacheMock) Lookup(key string) (*oauth2.Token, error) { return m.lookup(key) } -func TestLoad(t *testing.T) { +func TestToken(t *testing.T) { cache := &tokenCacheMock{ lookup: func(key string) (*oauth2.Token, error) { assert.Equal(t, "https://abc/oidc/accounts/xyz", key) @@ -71,7 +72,7 @@ func (m MockOAuthEndpointSupplier) GetWorkspaceOAuthEndpoints(ctx context.Contex }, nil } -func TestLoadRefresh(t *testing.T) { +func TestToken_RefreshesExpiredAccessToken(t *testing.T) { ctx := context.Background() expectedKey := "https://accounts.cloud.databricks.com/oidc/accounts/xyz" cache := &tokenCacheMock{ @@ -117,6 +118,81 @@ func TestLoadRefresh(t *testing.T) { assert.Equal(t, "", tok.RefreshToken) } +func TestToken_ReturnsError(t *testing.T) { + ctx := context.Background() + cache := &tokenCacheMock{ + lookup: func(key string) (*oauth2.Token, error) { + assert.Equal(t, "https://accounts.cloud.databricks.com/oidc/accounts/xyz", key) + return &oauth2.Token{ + AccessToken: "expired", + RefreshToken: "cde", + Expiry: time.Now().Add(-1 * time.Minute), + }, nil + }, + } + arg, err := u2m.NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "xyz") + assert.NoError(t, err) + p, err := u2m.NewPersistentAuth( + ctx, + u2m.WithTokenCache(cache), + u2m.WithHttpClient(&http.Client{ + Transport: fixtures.SliceTransport{ + { + Method: "POST", + Resource: "/oidc/accounts/xyz/v1/token", + Response: `{"error": "invalid_grant", "error_description": "Invalid Client"}`, + Status: 401, + }, + }, + }), + u2m.WithOAuthEndpointSupplier(MockOAuthEndpointSupplier{}), + u2m.WithOAuthArgument(arg), + ) + require.NoError(t, err) + defer p.Close() + tok, err := p.Token() + assert.Nil(t, tok) + assert.ErrorContains(t, err, "Invalid Client (error code: invalid_grant)") +} + +func TestToken_ReturnsInvalidRefreshTokenError(t *testing.T) { + ctx := context.Background() + cache := &tokenCacheMock{ + lookup: func(key string) (*oauth2.Token, error) { + assert.Equal(t, "https://accounts.cloud.databricks.com/oidc/accounts/xyz", key) + return &oauth2.Token{ + AccessToken: "expired", + RefreshToken: "cde", + Expiry: time.Now().Add(-1 * time.Minute), + }, nil + }, + } + arg, err := u2m.NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "xyz") + assert.NoError(t, err) + p, err := u2m.NewPersistentAuth( + ctx, + u2m.WithTokenCache(cache), + u2m.WithHttpClient(&http.Client{ + Transport: fixtures.SliceTransport{ + { + Method: "POST", + Resource: "/oidc/accounts/xyz/v1/token", + Response: `{"error": "invalid_grant", "error_description": "Refresh token is invalid"}`, + Status: 401, + }, + }, + }), + u2m.WithOAuthEndpointSupplier(MockOAuthEndpointSupplier{}), + u2m.WithOAuthArgument(arg), + ) + require.NoError(t, err) + defer p.Close() + tok, err := p.Token() + assert.Nil(t, tok) + target := &u2m.InvalidRefreshTokenError{} + assert.True(t, errors.As(err, &target)) +} + func TestChallenge(t *testing.T) { ctx := context.Background() @@ -181,7 +257,7 @@ func TestChallenge(t *testing.T) { assert.NoError(t, err) } -func TestChallengeFailed(t *testing.T) { +func TestChallenge_ReturnsErrorOnFailure(t *testing.T) { ctx := context.Background() browserOpened := make(chan string) browser := func(redirect string) error { From 45b7e5568eb516c91df97b18892c96833b0f6f1d Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Tue, 18 Feb 2025 11:05:35 +0100 Subject: [PATCH 43/44] changelog --- NEXT_CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 9bb04b137..a484f2b95 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -4,6 +4,8 @@ ### New Features and Improvements +* Support user-to-machine authentication in the SDK ([#1108](https://github.com/databricks/databricks-sdk-go/pull/1108)). + ### Bug Fixes ### Documentation From e37d457b6b9257a191de5aed31a2d34b90b33e2b Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Fri, 21 Mar 2025 17:12:44 +0100 Subject: [PATCH 44/44] clean up --- credentials/u2m/persistent_auth.go | 12 ++++++------ httpclient/http.go | 10 ---------- 2 files changed, 6 insertions(+), 16 deletions(-) delete mode 100644 httpclient/http.go diff --git a/credentials/u2m/persistent_auth.go b/credentials/u2m/persistent_auth.go index d6befe57a..7e28a25f5 100644 --- a/credentials/u2m/persistent_auth.go +++ b/credentials/u2m/persistent_auth.go @@ -108,15 +108,15 @@ func NewPersistentAuth(ctx context.Context, opts ...PersistentAuthOption) (*Pers // this same client to fetch the OAuth endpoints. If the HTTP client is // provided but the endpointSupplier is not, we construct a default // ApiClient for use with BasicOAuthClient. - var apiClient *httpclient.ApiClient + apiClient := httpclient.NewApiClient(httpclient.ClientConfig{}) if p.client == nil { - apiClient = httpclient.NewApiClient(httpclient.ClientConfig{}) - p.client = apiClient.ToHttpClient() + p.client = &http.Client{ + Transport: apiClient, + // 30 seconds matches the default timeout of the ApiClient + Timeout: 30 * time.Second, + } } if p.endpointSupplier == nil { - if apiClient == nil { - apiClient = httpclient.NewApiClient(httpclient.ClientConfig{}) - } p.endpointSupplier = &BasicOAuthEndpointSupplier{ Client: apiClient, } diff --git a/httpclient/http.go b/httpclient/http.go deleted file mode 100644 index b0430b69d..000000000 --- a/httpclient/http.go +++ /dev/null @@ -1,10 +0,0 @@ -package httpclient - -import "net/http" - -func (a *ApiClient) ToHttpClient() *http.Client { - return &http.Client{ - Transport: a, - Timeout: a.config.HTTPTimeout, - } -}