From 456edd15b7dd1eb11396538675328e1564427355 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Fri, 3 Jan 2025 14:45:07 +0100 Subject: [PATCH 01/15] Migrate U2M to SDK --- cmd/auth/auth.go | 12 +- cmd/auth/login.go | 68 ++++--- cmd/auth/login_test.go | 43 ++--- cmd/auth/token.go | 102 +++++++---- cmd/auth/token_test.go | 237 ++++++++++++++---------- go.mod | 7 +- go.sum | 19 +- libs/auth/cache/cache.go | 26 --- libs/auth/cache/file.go | 108 ----------- libs/auth/cache/file_test.go | 105 ----------- libs/auth/cache/in_memory.go | 26 --- libs/auth/cache/in_memory_test.go | 44 ----- libs/auth/callback.go | 104 ----------- libs/auth/oauth.go | 289 ------------------------------ libs/auth/oauth_test.go | 267 --------------------------- libs/auth/page.tmpl | 102 ----------- 16 files changed, 296 insertions(+), 1263 deletions(-) delete mode 100644 libs/auth/cache/cache.go delete mode 100644 libs/auth/cache/file.go delete mode 100644 libs/auth/cache/file_test.go delete mode 100644 libs/auth/cache/in_memory.go delete mode 100644 libs/auth/cache/in_memory_test.go delete mode 100644 libs/auth/callback.go delete mode 100644 libs/auth/oauth.go delete mode 100644 libs/auth/oauth_test.go delete mode 100644 libs/auth/page.tmpl diff --git a/cmd/auth/auth.go b/cmd/auth/auth.go index ceceae25c5..7493e16a30 100644 --- a/cmd/auth/auth.go +++ b/cmd/auth/auth.go @@ -4,8 +4,8 @@ import ( "context" "fmt" - "github.com/databricks/cli/libs/auth" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/databricks-sdk-go/credentials/oauth" "github.com/spf13/cobra" ) @@ -22,14 +22,14 @@ Azure: https://learn.microsoft.com/azure/databricks/dev-tools/auth GCP: https://docs.gcp.databricks.com/dev-tools/auth/index.html`, } - var perisistentAuth auth.PersistentAuth - cmd.PersistentFlags().StringVar(&perisistentAuth.Host, "host", perisistentAuth.Host, "Databricks Host") - cmd.PersistentFlags().StringVar(&perisistentAuth.AccountID, "account-id", perisistentAuth.AccountID, "Databricks Account ID") + var oauthArgument oauth.BasicOAuthArgument + cmd.PersistentFlags().StringVar(&oauthArgument.Host, "host", oauthArgument.Host, "Databricks Host") + cmd.PersistentFlags().StringVar(&oauthArgument.AccountID, "account-id", oauthArgument.AccountID, "Databricks Account ID") cmd.AddCommand(newEnvCommand()) - cmd.AddCommand(newLoginCommand(&perisistentAuth)) + cmd.AddCommand(newLoginCommand(&oauthArgument)) cmd.AddCommand(newProfilesCommand()) - cmd.AddCommand(newTokenCommand(&perisistentAuth)) + cmd.AddCommand(newTokenCommand(&oauthArgument)) cmd.AddCommand(newDescribeCommand()) return cmd } diff --git a/cmd/auth/login.go b/cmd/auth/login.go index c986765994..6cfee2a765 100644 --- a/cmd/auth/login.go +++ b/cmd/auth/login.go @@ -5,26 +5,27 @@ import ( "errors" "fmt" "runtime" + "strings" "time" - "github.com/databricks/cli/libs/auth" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/databrickscfg" "github.com/databricks/cli/libs/databrickscfg/cfgpickers" "github.com/databricks/cli/libs/databrickscfg/profile" "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/config" + "github.com/databricks/databricks-sdk-go/credentials/oauth" "github.com/spf13/cobra" ) -func promptForProfile(ctx context.Context, defaultValue string) (string, error) { +func promptForProfile(ctx context.Context, oauthArgument oauth.OAuthArgument) (string, error) { if !cmdio.IsInTTY(ctx) { return "", nil } prompt := cmdio.Prompt(ctx) prompt.Label = "Databricks profile name" - prompt.Default = defaultValue + prompt.Default = getProfileName(ctx, oauthArgument) prompt.AllowEdit = true return prompt.Run() } @@ -34,7 +35,7 @@ const ( defaultTimeout = 1 * time.Hour ) -func newLoginCommand(persistentAuth *auth.PersistentAuth) *cobra.Command { +func newLoginCommand(oauthArgument oauth.OAuthArgument) *cobra.Command { defaultConfigPath := "~/.databrickscfg" if runtime.GOOS == "windows" { defaultConfigPath = "%USERPROFILE%\\.databrickscfg" @@ -98,14 +99,18 @@ depends on the existing profiles you have set in your configuration file // If the user has not specified a profile name, prompt for one. if profileName == "" { var err error - profileName, err = promptForProfile(ctx, persistentAuth.ProfileName()) + profileName, err = promptForProfile(ctx, oauthArgument) if err != nil { return err } } // Set the host and account-id based on the provided arguments and flags. - err := setHostAndAccountId(ctx, profileName, persistentAuth, args) + oauthArgument, err := setHostAndAccountId(ctx, profile.DefaultProfiler, profileName, oauthArgument, args) + if err != nil { + return err + } + persistentAuth, err := oauth.NewPersistentAuth(ctx) if err != nil { return err } @@ -114,15 +119,15 @@ depends on the existing profiles you have set in your configuration file // We need the config without the profile before it's used to initialise new workspace client below. // Otherwise it will complain about non existing profile because it was not yet saved. cfg := config.Config{ - Host: persistentAuth.Host, - AccountID: persistentAuth.AccountID, + Host: oauthArgument.GetHost(ctx), + AccountID: oauthArgument.GetAccountId(ctx), AuthType: "databricks-cli", } ctx, cancel := context.WithTimeout(ctx, loginTimeout) defer cancel() - err = persistentAuth.Challenge(ctx) + err = persistentAuth.Challenge(ctx, oauthArgument) if err != nil { return err } @@ -173,53 +178,66 @@ depends on the existing profiles you have set in your configuration file // 1. --account-id flag. // 2. account-id from the specified profile, if available. // 3. Prompt the user for the account-id. -func setHostAndAccountId(ctx context.Context, profileName string, persistentAuth *auth.PersistentAuth, args []string) error { +func setHostAndAccountId(ctx context.Context, profiler profile.Profiler, profileName string, oauthArgument oauth.OAuthArgument, args []string) (oauth.OAuthArgument, error) { + res := oauth.BasicOAuthArgument{} // If both [HOST] and --host are provided, return an error. - if len(args) > 0 && persistentAuth.Host != "" { - return fmt.Errorf("please only provide a host as an argument or a flag, not both") + host := oauthArgument.GetHost(ctx) + if len(args) > 0 && host != "" { + return nil, fmt.Errorf("please only provide a host as an argument or a flag, not both") } - profiler := profile.GetProfiler(ctx) // If the chosen profile has a hostname and the user hasn't specified a host, infer the host from the profile. profiles, err := profiler.LoadProfiles(ctx, profile.WithName(profileName)) // Tolerate ErrNoConfiguration here, as we will write out a configuration as part of the login flow. if err != nil && !errors.Is(err, profile.ErrNoConfiguration) { - return err + return nil, err } - if persistentAuth.Host == "" { + if host == "" { if len(args) > 0 { // If [HOST] is provided, set the host to the provided positional argument. - persistentAuth.Host = args[0] + res.Host = args[0] } else if len(profiles) > 0 && profiles[0].Host != "" { // If neither [HOST] nor --host are provided, and the profile has a host, use it. - persistentAuth.Host = profiles[0].Host + res.Host = profiles[0].Host } else { // If neither [HOST] nor --host are provided, and the profile does not have a host, // then prompt the user for a host. hostName, err := promptForHost(ctx) if err != nil { - return err + return nil, err } - persistentAuth.Host = hostName + res.Host = hostName } } // If the account-id was not provided as a cmd line flag, try to read it from // the specified profile. - isAccountClient := (&config.Config{Host: persistentAuth.Host}).IsAccountClient() - if isAccountClient && persistentAuth.AccountID == "" { + isAccountClient := (&config.Config{Host: res.Host}).IsAccountClient() + accountID := oauthArgument.GetAccountId(ctx) + if isAccountClient && accountID == "" { if len(profiles) > 0 && profiles[0].AccountID != "" { - persistentAuth.AccountID = profiles[0].AccountID + res.AccountID = profiles[0].AccountID } else { // Prompt user for the account-id if it we could not get it from a // profile. accountId, err := promptForAccountID(ctx) if err != nil { - return err + return nil, err } - persistentAuth.AccountID = accountId + res.AccountID = accountId } } - return nil + return res, nil +} + +func getProfileName(ctx context.Context, oauthArgument oauth.OAuthArgument) string { + host := oauthArgument.GetHost(ctx) + accountId := oauthArgument.GetAccountId(ctx) + if accountId != "" { + return fmt.Sprintf("ACCOUNT-%s", accountId) + } + host = strings.TrimPrefix(host, "https://") + split := strings.Split(host, ".") + return split[0] } diff --git a/cmd/auth/login_test.go b/cmd/auth/login_test.go index d0fa5a16b8..3d41550411 100644 --- a/cmd/auth/login_test.go +++ b/cmd/auth/login_test.go @@ -4,9 +4,10 @@ import ( "context" "testing" - "github.com/databricks/cli/libs/auth" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/databrickscfg/profile" "github.com/databricks/cli/libs/env" + "github.com/databricks/databricks-sdk-go/credentials/oauth" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -14,72 +15,72 @@ import ( func TestSetHostDoesNotFailWithNoDatabrickscfg(t *testing.T) { ctx := context.Background() ctx = env.Set(ctx, "DATABRICKS_CONFIG_FILE", "./imaginary-file/databrickscfg") - err := setHostAndAccountId(ctx, "foo", &auth.PersistentAuth{Host: "test"}, []string{}) + _, err := setHostAndAccountId(ctx, profile.DefaultProfiler, "foo", oauth.BasicOAuthArgument{Host: "test"}, []string{}) assert.NoError(t, err) } func TestSetHost(t *testing.T) { - var persistentAuth auth.PersistentAuth + var persistentAuth oauth.BasicOAuthArgument t.Setenv("DATABRICKS_CONFIG_FILE", "./testdata/.databrickscfg") ctx, _ := cmdio.SetupTest(context.Background()) // Test error when both flag and argument are provided persistentAuth.Host = "val from --host" - err := setHostAndAccountId(ctx, "profile-1", &persistentAuth, []string{"val from [HOST]"}) + _, err := setHostAndAccountId(ctx, profile.DefaultProfiler, "profile-1", &persistentAuth, []string{"val from [HOST]"}) assert.EqualError(t, err, "please only provide a host as an argument or a flag, not both") // Test setting host from flag persistentAuth.Host = "val from --host" - err = setHostAndAccountId(ctx, "profile-1", &persistentAuth, []string{}) + res, err := setHostAndAccountId(ctx, profile.DefaultProfiler, "profile-1", &persistentAuth, []string{}) assert.NoError(t, err) - assert.Equal(t, "val from --host", persistentAuth.Host) + assert.Equal(t, "val from --host", res.GetHost(ctx)) // Test setting host from argument persistentAuth.Host = "" - err = setHostAndAccountId(ctx, "profile-1", &persistentAuth, []string{"val from [HOST]"}) + res, err = setHostAndAccountId(ctx, profile.DefaultProfiler, "profile-1", &persistentAuth, []string{"val from [HOST]"}) assert.NoError(t, err) - assert.Equal(t, "val from [HOST]", persistentAuth.Host) + assert.Equal(t, "val from [HOST]", res.GetHost(ctx)) // Test setting host from profile persistentAuth.Host = "" - err = setHostAndAccountId(ctx, "profile-1", &persistentAuth, []string{}) + res, err = setHostAndAccountId(ctx, profile.DefaultProfiler, "profile-1", &persistentAuth, []string{}) assert.NoError(t, err) - assert.Equal(t, "https://www.host1.com", persistentAuth.Host) + assert.Equal(t, "https://www.host1.com", res.GetHost(ctx)) // Test setting host from profile persistentAuth.Host = "" - err = setHostAndAccountId(ctx, "profile-2", &persistentAuth, []string{}) + res, err = setHostAndAccountId(ctx, profile.DefaultProfiler, "profile-2", &persistentAuth, []string{}) assert.NoError(t, err) - assert.Equal(t, "https://www.host2.com", persistentAuth.Host) + assert.Equal(t, "https://www.host2.com", res.GetHost(ctx)) // Test host is not set. Should prompt. persistentAuth.Host = "" - err = setHostAndAccountId(ctx, "", &persistentAuth, []string{}) + _, err = setHostAndAccountId(ctx, profile.DefaultProfiler, "", &persistentAuth, []string{}) assert.EqualError(t, err, "the command is being run in a non-interactive environment, please specify a host using --host") } func TestSetAccountId(t *testing.T) { - var persistentAuth auth.PersistentAuth + var persistentAuth oauth.BasicOAuthArgument t.Setenv("DATABRICKS_CONFIG_FILE", "./testdata/.databrickscfg") ctx, _ := cmdio.SetupTest(context.Background()) // Test setting account-id from flag persistentAuth.AccountID = "val from --account-id" - err := setHostAndAccountId(ctx, "account-profile", &persistentAuth, []string{}) + res, err := setHostAndAccountId(ctx, profile.DefaultProfiler, "account-profile", &persistentAuth, []string{}) assert.NoError(t, err) - assert.Equal(t, "https://accounts.cloud.databricks.com", persistentAuth.Host) - assert.Equal(t, "val from --account-id", persistentAuth.AccountID) + assert.Equal(t, "https://accounts.cloud.databricks.com", res.GetHost(ctx)) + assert.Equal(t, "val from --account-id", res.GetAccountId(ctx)) // Test setting account_id from profile persistentAuth.AccountID = "" - err = setHostAndAccountId(ctx, "account-profile", &persistentAuth, []string{}) + res, err = setHostAndAccountId(ctx, profile.DefaultProfiler, "account-profile", &persistentAuth, []string{}) require.NoError(t, err) - assert.Equal(t, "https://accounts.cloud.databricks.com", persistentAuth.Host) - assert.Equal(t, "id-from-profile", persistentAuth.AccountID) + assert.Equal(t, "https://accounts.cloud.databricks.com", res.GetHost(ctx)) + assert.Equal(t, "id-from-profile", res.GetAccountId(ctx)) // Neither flag nor profile account-id is set, should prompt persistentAuth.AccountID = "" persistentAuth.Host = "https://accounts.cloud.databricks.com" - err = setHostAndAccountId(ctx, "", &persistentAuth, []string{}) + _, err = setHostAndAccountId(ctx, profile.DefaultProfiler, "", &persistentAuth, []string{}) assert.EqualError(t, err, "the command is being run in a non-interactive environment, please specify an account ID using --account-id") } diff --git a/cmd/auth/token.go b/cmd/auth/token.go index fbf8b68f6e..12a6237d72 100644 --- a/cmd/auth/token.go +++ b/cmd/auth/token.go @@ -9,9 +9,11 @@ import ( "strings" "time" - "github.com/databricks/cli/libs/auth" + "github.com/databricks/cli/libs/databrickscfg/profile" + "github.com/databricks/databricks-sdk-go/credentials/oauth" "github.com/databricks/databricks-sdk-go/httpclient" "github.com/spf13/cobra" + "golang.org/x/oauth2" ) type tokenErrorResponse struct { @@ -19,7 +21,7 @@ type tokenErrorResponse struct { ErrorDescription string `json:"error_description"` } -func buildLoginCommand(profile string, persistentAuth *auth.PersistentAuth) string { +func buildLoginCommand(ctx context.Context, profile string, persistentAuth oauth.OAuthArgument) string { executable := os.Args[0] cmd := []string{ executable, @@ -29,20 +31,20 @@ func buildLoginCommand(profile string, persistentAuth *auth.PersistentAuth) stri if profile != "" { cmd = append(cmd, "--profile", profile) } else { - cmd = append(cmd, "--host", persistentAuth.Host) - if persistentAuth.AccountID != "" { - cmd = append(cmd, "--account-id", persistentAuth.AccountID) + cmd = append(cmd, "--host", persistentAuth.GetHost(ctx)) + if accountId := persistentAuth.GetAccountId(ctx); accountId != "" { + cmd = append(cmd, "--account-id", accountId) } } return strings.Join(cmd, " ") } -func helpfulError(profile string, persistentAuth *auth.PersistentAuth) string { - loginMsg := buildLoginCommand(profile, persistentAuth) +func helpfulError(ctx context.Context, profile string, persistentAuth oauth.OAuthArgument) string { + loginMsg := buildLoginCommand(ctx, profile, persistentAuth) return fmt.Sprintf("Try logging in again with `%s` before retrying. If this fails, please report this issue to the Databricks CLI maintainers at https://github.com/databricks/cli/issues/new", loginMsg) } -func newTokenCommand(persistentAuth *auth.PersistentAuth) *cobra.Command { +func newTokenCommand(oauthArgument oauth.OAuthArgument) *cobra.Command { cmd := &cobra.Command{ Use: "token [HOST]", Short: "Get authentication token", @@ -54,42 +56,23 @@ func newTokenCommand(persistentAuth *auth.PersistentAuth) *cobra.Command { cmd.RunE = func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() - - var profileName string + profileName := "" profileFlag := cmd.Flag("profile") if profileFlag != nil { profileName = profileFlag.Value.String() - // If a profile is provided we read the host from the .databrickscfg file - if profileName != "" && len(args) > 0 { - return errors.New("providing both a profile and host is not supported") - } } - err := setHostAndAccountId(ctx, profileName, persistentAuth, args) + t, err := loadToken(ctx, loadTokenArgs{ + oauthArgument: oauthArgument, + profileName: profileName, + args: args, + tokenTimeout: tokenTimeout, + profiler: profile.DefaultProfiler, + persistentAuthOpts: nil, + }) if err != nil { return err } - defer persistentAuth.Close() - - ctx, cancel := context.WithTimeout(ctx, tokenTimeout) - defer cancel() - t, err := persistentAuth.Load(ctx) - var httpErr *httpclient.HttpError - if errors.As(err, &httpErr) { - helpMsg := helpfulError(profileName, persistentAuth) - t := &tokenErrorResponse{} - err = json.Unmarshal([]byte(httpErr.Message), t) - if err != nil { - return fmt.Errorf("unexpected parsing token response: %w. %s", err, helpMsg) - } - if t.ErrorDescription == "Refresh token is invalid" { - return fmt.Errorf("a new access token could not be retrieved because the refresh token is invalid. To reauthenticate, run `%s`", buildLoginCommand(profileName, persistentAuth)) - } else { - return fmt.Errorf("unexpected error refreshing token: %s. %s", t.ErrorDescription, helpMsg) - } - } else if err != nil { - return fmt.Errorf("unexpected error refreshing token: %w. %s", err, helpfulError(profileName, persistentAuth)) - } raw, err := json.MarshalIndent(t, "", " ") if err != nil { return err @@ -100,3 +83,50 @@ func newTokenCommand(persistentAuth *auth.PersistentAuth) *cobra.Command { return cmd } + +type loadTokenArgs struct { + oauthArgument oauth.OAuthArgument + profileName string + args []string + tokenTimeout time.Duration + profiler profile.Profiler + persistentAuthOpts []oauth.PersistentAuthOption +} + +func loadToken(ctx context.Context, args loadTokenArgs) (*oauth2.Token, error) { + // If a profile is provided we read the host from the .databrickscfg file + if args.profileName != "" && len(args.args) > 0 { + return nil, errors.New("providing both a profile and host is not supported") + } + + oauthArgument, err := setHostAndAccountId(ctx, args.profiler, args.profileName, args.oauthArgument, args.args) + if err != nil { + return nil, err + } + + ctx, cancel := context.WithTimeout(ctx, args.tokenTimeout) + defer cancel() + persistentAuth, err := oauth.NewPersistentAuth(ctx) + if err != nil { + helpMsg := helpfulError(ctx, args.profileName, oauthArgument) + return nil, fmt.Errorf("unexpected error creating persistent auth: %w. %s", err, helpMsg) + } + t, err := persistentAuth.Load(ctx, oauthArgument) + var httpErr *httpclient.HttpError + if errors.As(err, &httpErr) { + helpMsg := helpfulError(ctx, args.profileName, oauthArgument) + t := &tokenErrorResponse{} + err = json.Unmarshal([]byte(httpErr.Message), t) + if err != nil { + return nil, fmt.Errorf("unexpected parsing token response: %w. %s", err, helpMsg) + } + if t.ErrorDescription == "Refresh token is invalid" { + return nil, fmt.Errorf("a new access token could not be retrieved because the refresh token is invalid. To reauthenticate, run `%s`", buildLoginCommand(ctx, args.profileName, oauthArgument)) + } else { + return nil, fmt.Errorf("unexpected error refreshing token: %s. %s", t.ErrorDescription, helpMsg) + } + } else if err != nil { + return nil, fmt.Errorf("unexpected error refreshing token: %w. %s", err, helpfulError(ctx, args.profileName, oauthArgument)) + } + return t, nil +} diff --git a/cmd/auth/token_test.go b/cmd/auth/token_test.go index df98cc151e..d62c18cf1f 100644 --- a/cmd/auth/token_test.go +++ b/cmd/auth/token_test.go @@ -1,19 +1,15 @@ -package auth_test +package auth import ( - "bytes" "context" - "encoding/json" "testing" "time" - "github.com/databricks/cli/cmd" - "github.com/databricks/cli/libs/auth" - "github.com/databricks/cli/libs/auth/cache" "github.com/databricks/cli/libs/databrickscfg/profile" + "github.com/databricks/databricks-sdk-go/credentials/cache" + "github.com/databricks/databricks-sdk-go/credentials/oauth" "github.com/databricks/databricks-sdk-go/httpclient" "github.com/databricks/databricks-sdk-go/httpclient/fixtures" - "github.com/spf13/cobra" "github.com/stretchr/testify/assert" "golang.org/x/oauth2" ) @@ -52,15 +48,7 @@ var refreshSuccessTokenResponse = fixtures.HTTPFixture{ }, } -func validateToken(t *testing.T, resp string) { - res := map[string]string{} - err := json.Unmarshal([]byte(resp), &res) - assert.NoError(t, err) - assert.Equal(t, "new-access-token", res["access_token"]) - assert.Equal(t, "Bearer", res["token_type"]) -} - -func getContextForTest(f fixtures.HTTPFixture) context.Context { +func TestToken_loadToken(t *testing.T) { profiler := profile.InMemoryProfiler{ Profiles: profile.Profiles{ { @@ -86,83 +74,144 @@ func getContextForTest(f fixtures.HTTPFixture) context.Context { }, }, } - client := httpclient.NewApiClient(httpclient.ClientConfig{ - Transport: fixtures.SliceTransport{f}, - }) - ctx := profile.WithProfiler(context.Background(), profiler) - ctx = cache.WithTokenCache(ctx, tokenCache) - ctx = auth.WithApiClientForOAuth(ctx, client) - return ctx -} - -func getCobraCmdForTest(f fixtures.HTTPFixture) (*cobra.Command, *bytes.Buffer) { - ctx := getContextForTest(f) - c := cmd.New(ctx) - output := &bytes.Buffer{} - c.SetOut(output) - return c, output -} - -func TestTokenCmdWithProfilePrintsHelpfulLoginMessageOnRefreshFailure(t *testing.T) { - cmd, output := getCobraCmdForTest(refreshFailureTokenResponse) - cmd.SetArgs([]string{"auth", "token", "--profile", "expired"}) - err := cmd.Execute() - - out := output.String() - assert.Empty(t, out) - assert.ErrorContains(t, err, "a new access token could not be retrieved because the refresh token is invalid. To reauthenticate, run ") - assert.ErrorContains(t, err, "auth login --profile expired") -} - -func TestTokenCmdWithHostPrintsHelpfulLoginMessageOnRefreshFailure(t *testing.T) { - cmd, output := getCobraCmdForTest(refreshFailureTokenResponse) - cmd.SetArgs([]string{"auth", "token", "--host", "https://accounts.cloud.databricks.com", "--account-id", "expired"}) - err := cmd.Execute() - - out := output.String() - assert.Empty(t, out) - assert.ErrorContains(t, err, "a new access token could not be retrieved because the refresh token is invalid. To reauthenticate, run ") - assert.ErrorContains(t, err, "auth login --host https://accounts.cloud.databricks.com --account-id expired") -} - -func TestTokenCmdInvalidResponse(t *testing.T) { - cmd, output := getCobraCmdForTest(refreshFailureInvalidResponse) - cmd.SetArgs([]string{"auth", "token", "--profile", "active"}) - err := cmd.Execute() - - out := output.String() - assert.Empty(t, out) - assert.ErrorContains(t, err, "unexpected parsing token response: invalid character 'N' looking for beginning of value. Try logging in again with ") - assert.ErrorContains(t, err, "auth login --profile active` before retrying. If this fails, please report this issue to the Databricks CLI maintainers at https://github.com/databricks/cli/issues/new") -} - -func TestTokenCmdOtherErrorResponse(t *testing.T) { - cmd, output := getCobraCmdForTest(refreshFailureOtherError) - cmd.SetArgs([]string{"auth", "token", "--profile", "active"}) - err := cmd.Execute() - - out := output.String() - assert.Empty(t, out) - assert.ErrorContains(t, err, "unexpected error refreshing token: Databricks is down. Try logging in again with ") - assert.ErrorContains(t, err, "auth login --profile active` before retrying. If this fails, please report this issue to the Databricks CLI maintainers at https://github.com/databricks/cli/issues/new") -} - -func TestTokenCmdWithProfileSuccess(t *testing.T) { - cmd, output := getCobraCmdForTest(refreshSuccessTokenResponse) - cmd.SetArgs([]string{"auth", "token", "--profile", "active"}) - err := cmd.Execute() - - out := output.String() - validateToken(t, out) - assert.NoError(t, err) -} - -func TestTokenCmdWithHostSuccess(t *testing.T) { - cmd, output := getCobraCmdForTest(refreshSuccessTokenResponse) - cmd.SetArgs([]string{"auth", "token", "--host", "https://accounts.cloud.databricks.com", "--account-id", "expired"}) - err := cmd.Execute() + makeApiClient := func(f fixtures.HTTPFixture) *httpclient.ApiClient { + return httpclient.NewApiClient(httpclient.ClientConfig{ + Transport: fixtures.SliceTransport{f}, + }) + } + wantErrors := func(substrings ...string) func(error) { + return func(err error) { + for _, s := range substrings { + assert.ErrorContains(t, err, s) + } + } + } + validateToken := func(resp *oauth2.Token) { + assert.Equal(t, "new-access-token", resp.AccessToken) + assert.Equal(t, "Bearer", resp.TokenType) + } - out := output.String() - validateToken(t, out) - assert.NoError(t, err) + cases := []struct { + name string + args loadTokenArgs + want func(*oauth2.Token) + wantErr func(error) + }{ + { + name: "prints helpful login message on refresh failure when profile is specified", + args: loadTokenArgs{ + oauthArgument: oauth.BasicOAuthArgument{}, + profileName: "expired", + args: []string{}, + tokenTimeout: 1 * time.Hour, + profiler: profiler, + persistentAuthOpts: []oauth.PersistentAuthOption{ + oauth.WithTokenCache(tokenCache), + oauth.WithApiClient(makeApiClient(refreshFailureTokenResponse)), + }, + }, + wantErr: wantErrors( + "a new access token could not be retrieved because the refresh token is invalid. To reauthenticate, run ", + "auth login --host https://accounts.cloud.databricks.com --account-id expired", + ), + }, + { + name: "prints helpful login message on refresh failure when host is specified", + args: loadTokenArgs{ + oauthArgument: oauth.BasicOAuthArgument{ + Host: "https://accounts.cloud.databricks.com", + AccountID: "expired", + }, + profileName: "", + args: []string{}, + tokenTimeout: 1 * time.Hour, + profiler: profiler, + persistentAuthOpts: []oauth.PersistentAuthOption{ + oauth.WithTokenCache(tokenCache), + oauth.WithApiClient(makeApiClient(refreshFailureTokenResponse)), + }, + }, + wantErr: wantErrors( + "a new access token could not be retrieved because the refresh token is invalid. To reauthenticate, run ", + "auth login --host https://accounts.cloud.databricks.com --account-id expired", + ), + }, + { + name: "prints helpful login message on invalid response", + args: loadTokenArgs{ + oauthArgument: oauth.BasicOAuthArgument{}, + profileName: "active", + args: []string{}, + tokenTimeout: 1 * time.Hour, + profiler: profiler, + persistentAuthOpts: []oauth.PersistentAuthOption{ + oauth.WithTokenCache(tokenCache), + oauth.WithApiClient(makeApiClient(refreshFailureInvalidResponse)), + }, + }, + wantErr: wantErrors( + "unexpected parsing token response: invalid character 'N' looking for beginning of value. Try logging in again with ", + "auth login --profile active` before retrying. If this fails, please report this issue to the Databricks CLI maintainers at https://github.com/databricks/cli/issues/new", + ), + }, + { + name: "prints helpful login message on other error response", + args: loadTokenArgs{ + oauthArgument: oauth.BasicOAuthArgument{}, + profileName: "active", + args: []string{}, + tokenTimeout: 1 * time.Hour, + profiler: profiler, + persistentAuthOpts: []oauth.PersistentAuthOption{ + oauth.WithTokenCache(tokenCache), + oauth.WithApiClient(makeApiClient(refreshFailureOtherError)), + }, + }, + wantErr: wantErrors( + "unexpected error refreshing token: Databricks is down. Try logging in again with ", + "auth login --profile active` before retrying. If this fails, please report this issue to the Databricks CLI maintainers at https://github.com/databricks/cli/issues/new", + ), + }, + { + name: "succeeds with profile", + args: loadTokenArgs{ + oauthArgument: oauth.BasicOAuthArgument{}, + profileName: "active", + args: []string{}, + tokenTimeout: 1 * time.Hour, + profiler: profiler, + persistentAuthOpts: []oauth.PersistentAuthOption{ + oauth.WithTokenCache(tokenCache), + oauth.WithApiClient(makeApiClient(refreshSuccessTokenResponse)), + }, + }, + want: validateToken, + }, + { + name: "succeeds with host", + args: loadTokenArgs{ + oauthArgument: oauth.BasicOAuthArgument{Host: "https://accounts.cloud.databricks.com", AccountID: "active"}, + profileName: "", + args: []string{}, + tokenTimeout: 1 * time.Hour, + profiler: profiler, + persistentAuthOpts: []oauth.PersistentAuthOption{ + oauth.WithTokenCache(tokenCache), + oauth.WithApiClient(makeApiClient(refreshSuccessTokenResponse)), + }, + }, + want: validateToken, + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + got, err := loadToken(context.Background(), c.args) + if c.wantErr != nil { + c.wantErr(err) + } else { + assert.NoError(t, err) + assert.Equal(t, c.want, got) + } + }) + } } diff --git a/go.mod b/go.mod index 2dda0cd609..baebd80a1a 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ toolchain go1.23.4 require ( github.com/Masterminds/semver/v3 v3.3.1 // MIT github.com/briandowns/spinner v1.23.1 // Apache 2.0 - github.com/databricks/databricks-sdk-go v0.54.0 // Apache 2.0 + github.com/databricks/databricks-sdk-go v0.54.1-0.20250103133740-0688c3b8afac // Apache 2.0 github.com/fatih/color v1.18.0 // MIT github.com/google/uuid v1.6.0 // BSD-3-Clause github.com/hashicorp/go-version v1.7.0 // MPL 2.0 @@ -18,7 +18,7 @@ require ( github.com/manifoldco/promptui v0.9.0 // BSD-3-Clause github.com/mattn/go-isatty v0.0.20 // MIT github.com/nwidger/jsoncolor v0.3.2 // MIT - github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 // BSD-2-Clause + github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // BSD-2-Clause github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06 // MIT github.com/spf13/cobra v1.8.1 // Apache 2.0 github.com/spf13/pflag v1.0.5 // BSD-3-Clause @@ -39,6 +39,7 @@ require ( cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect cloud.google.com/go/compute/metadata v0.3.0 // indirect github.com/ProtonMail/go-crypto v1.1.0-alpha.2 // indirect + github.com/alexflint/go-filemutex v1.3.0 // indirect github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e // indirect github.com/cloudflare/circl v1.3.7 // indirect @@ -68,7 +69,7 @@ require ( go.opentelemetry.io/otel/metric v1.24.0 // indirect go.opentelemetry.io/otel/trace v1.24.0 // indirect golang.org/x/crypto v0.31.0 // indirect - golang.org/x/net v0.26.0 // indirect + golang.org/x/net v0.33.0 // indirect golang.org/x/sys v0.28.0 // indirect golang.org/x/time v0.5.0 // indirect google.golang.org/api v0.182.0 // indirect diff --git a/go.sum b/go.sum index 1e806ea036..186fd1caa4 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,8 @@ github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migc github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM= github.com/ProtonMail/go-crypto v1.1.0-alpha.2 h1:bkyFVUP+ROOARdgCiJzNQo2V2kiB97LyUpzH9P6Hrlg= github.com/ProtonMail/go-crypto v1.1.0-alpha.2/go.mod h1:rA3QumHc/FZ8pAHreoekgiAbzpNsfQAosU5td4SnOrE= +github.com/alexflint/go-filemutex v1.3.0 h1:LgE+nTUWnQCyRKbpoceKZsPQbs84LivvgwUymZXdOcM= +github.com/alexflint/go-filemutex v1.3.0/go.mod h1:U0+VA/i30mGBlLCrFPGtTe9y6wGQfNAWPBTekHQ+c8A= github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY= github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4= github.com/briandowns/spinner v1.23.1 h1:t5fDPmScwUjozhDj4FA46p5acZWIPXYE30qW2Ptu650= @@ -32,8 +34,8 @@ github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGX github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/cyphar/filepath-securejoin v0.2.4 h1:Ugdm7cg7i6ZK6x3xDF1oEu1nfkyfH53EtKeQYTC3kyg= github.com/cyphar/filepath-securejoin v0.2.4/go.mod h1:aPGpWjXOXUn2NCNjFvBE6aRxGGx79pTxQpKOJNYHHl4= -github.com/databricks/databricks-sdk-go v0.54.0 h1:L8gsA3NXs+uYU3QtW/OUgjxMQxOH24k0MT9JhB3zLlM= -github.com/databricks/databricks-sdk-go v0.54.0/go.mod h1:ds+zbv5mlQG7nFEU5ojLtgN/u0/9YzZmKQES/CfedzU= +github.com/databricks/databricks-sdk-go v0.54.1-0.20250103133740-0688c3b8afac h1:HzS3/zoFUu6bWPEBaeeidfXIjg6Vu8qDphwr8Rg5EVU= +github.com/databricks/databricks-sdk-go v0.54.1-0.20250103133740-0688c3b8afac/go.mod h1:IMSyEl8eEBwXQNljpYxs5cp31i8nyAx6KPj8U81X5Zo= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -131,8 +133,8 @@ github.com/nwidger/jsoncolor v0.3.2 h1:rVJJlwAWDJShnbTYOQ5RM7yTA20INyKXlJ/fg4JMh github.com/nwidger/jsoncolor v0.3.2/go.mod h1:Cs34umxLbJvgBMnVNVqhji9BhoT/N/KinHqZptQ7cf4= github.com/pjbgf/sha1cd v0.3.0 h1:4D5XXmUUBUl/xQ6IjCkEAbqXskkq/4O7LmGn0AqMDs4= github.com/pjbgf/sha1cd v0.3.0/go.mod h1:nZ1rrWOcGJ5uZgEEVL1VUM9iRQiZvWdbZjkKyFzPPsI= -github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU= -github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= @@ -152,6 +154,7 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= @@ -204,8 +207,8 @@ golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73r golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= -golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= +golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= +golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE= golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= @@ -221,9 +224,10 @@ golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q= @@ -275,6 +279,7 @@ gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/warnings.v0 v0.1.2 h1:wFXVbFY8DY5/xOe1ECiWdKCzZlxgshcYVNkBHstARME= gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/libs/auth/cache/cache.go b/libs/auth/cache/cache.go deleted file mode 100644 index 097353e74c..0000000000 --- a/libs/auth/cache/cache.go +++ /dev/null @@ -1,26 +0,0 @@ -package cache - -import ( - "context" - - "golang.org/x/oauth2" -) - -type TokenCache interface { - Store(key string, t *oauth2.Token) error - Lookup(key string) (*oauth2.Token, error) -} - -var tokenCache int - -func WithTokenCache(ctx context.Context, c TokenCache) context.Context { - return context.WithValue(ctx, &tokenCache, c) -} - -func GetTokenCache(ctx context.Context) TokenCache { - c, ok := ctx.Value(&tokenCache).(TokenCache) - if !ok { - return &FileTokenCache{} - } - return c -} diff --git a/libs/auth/cache/file.go b/libs/auth/cache/file.go deleted file mode 100644 index 38dfea9f2c..0000000000 --- a/libs/auth/cache/file.go +++ /dev/null @@ -1,108 +0,0 @@ -package cache - -import ( - "encoding/json" - "errors" - "fmt" - "io/fs" - "os" - "path/filepath" - - "golang.org/x/oauth2" -) - -const ( - // where the token cache is stored - tokenCacheFile = ".databricks/token-cache.json" - - // only the owner of the file has full execute, read, and write access - ownerExecReadWrite = 0o700 - - // only the owner of the file has full read and write access - ownerReadWrite = 0o600 - - // format versioning leaves some room for format improvement - tokenCacheVersion = 1 -) - -var ErrNotConfigured = errors.New("databricks OAuth is not configured for this host") - -// this implementation requires the calling code to do a machine-wide lock, -// otherwise the file might get corrupt. -type FileTokenCache struct { - Version int `json:"version"` - Tokens map[string]*oauth2.Token `json:"tokens"` - - fileLocation string -} - -func (c *FileTokenCache) Store(key string, t *oauth2.Token) error { - err := c.load() - if errors.Is(err, fs.ErrNotExist) { - dir := filepath.Dir(c.fileLocation) - err = os.MkdirAll(dir, ownerExecReadWrite) - if err != nil { - return fmt.Errorf("mkdir: %w", err) - } - } else if err != nil { - return fmt.Errorf("load: %w", err) - } - c.Version = tokenCacheVersion - if c.Tokens == nil { - c.Tokens = map[string]*oauth2.Token{} - } - c.Tokens[key] = t - raw, err := json.MarshalIndent(c, "", " ") - if err != nil { - return fmt.Errorf("marshal: %w", err) - } - return os.WriteFile(c.fileLocation, raw, ownerReadWrite) -} - -func (c *FileTokenCache) Lookup(key string) (*oauth2.Token, error) { - err := c.load() - if errors.Is(err, fs.ErrNotExist) { - return nil, ErrNotConfigured - } else if err != nil { - return nil, fmt.Errorf("load: %w", err) - } - t, ok := c.Tokens[key] - if !ok { - return nil, ErrNotConfigured - } - return t, nil -} - -func (c *FileTokenCache) location() (string, error) { - home, err := os.UserHomeDir() - if err != nil { - return "", fmt.Errorf("home: %w", err) - } - return filepath.Join(home, tokenCacheFile), nil -} - -func (c *FileTokenCache) load() error { - loc, err := c.location() - if err != nil { - return err - } - c.fileLocation = loc - raw, err := os.ReadFile(loc) - if err != nil { - return fmt.Errorf("read: %w", err) - } - err = json.Unmarshal(raw, c) - if err != nil { - return fmt.Errorf("parse: %w", err) - } - if c.Version != tokenCacheVersion { - // in the later iterations we could do state upgraders, - // so that we transform token cache from v1 to v2 without - // losing the tokens and asking the user to re-authenticate. - return fmt.Errorf("needs version %d, got version %d", - tokenCacheVersion, c.Version) - } - return nil -} - -var _ TokenCache = (*FileTokenCache)(nil) diff --git a/libs/auth/cache/file_test.go b/libs/auth/cache/file_test.go deleted file mode 100644 index 3e4aae36f4..0000000000 --- a/libs/auth/cache/file_test.go +++ /dev/null @@ -1,105 +0,0 @@ -package cache - -import ( - "os" - "path/filepath" - "runtime" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.org/x/oauth2" -) - -var homeEnvVar = "HOME" - -func init() { - if runtime.GOOS == "windows" { - homeEnvVar = "USERPROFILE" - } -} - -func setup(t *testing.T) string { - tempHomeDir := t.TempDir() - t.Setenv(homeEnvVar, tempHomeDir) - return tempHomeDir -} - -func TestStoreAndLookup(t *testing.T) { - setup(t) - c := &FileTokenCache{} - err := c.Store("x", &oauth2.Token{ - AccessToken: "abc", - }) - require.NoError(t, err) - - err = c.Store("y", &oauth2.Token{ - AccessToken: "bcd", - }) - require.NoError(t, err) - - l := &FileTokenCache{} - tok, err := l.Lookup("x") - require.NoError(t, err) - assert.Equal(t, "abc", tok.AccessToken) - assert.Equal(t, 2, len(l.Tokens)) - - _, err = l.Lookup("z") - assert.Equal(t, ErrNotConfigured, err) -} - -func TestNoCacheFileReturnsErrNotConfigured(t *testing.T) { - setup(t) - l := &FileTokenCache{} - _, err := l.Lookup("x") - assert.Equal(t, ErrNotConfigured, err) -} - -func TestLoadCorruptFile(t *testing.T) { - home := setup(t) - f := filepath.Join(home, tokenCacheFile) - err := os.MkdirAll(filepath.Dir(f), ownerExecReadWrite) - require.NoError(t, err) - err = os.WriteFile(f, []byte("abc"), ownerExecReadWrite) - require.NoError(t, err) - - l := &FileTokenCache{} - _, err = l.Lookup("x") - assert.EqualError(t, err, "load: parse: invalid character 'a' looking for beginning of value") -} - -func TestLoadWrongVersion(t *testing.T) { - home := setup(t) - f := filepath.Join(home, tokenCacheFile) - err := os.MkdirAll(filepath.Dir(f), ownerExecReadWrite) - require.NoError(t, err) - err = os.WriteFile(f, []byte(`{"version": 823, "things": []}`), ownerExecReadWrite) - require.NoError(t, err) - - l := &FileTokenCache{} - _, err = l.Lookup("x") - assert.EqualError(t, err, "load: needs version 1, got version 823") -} - -func TestDevNull(t *testing.T) { - t.Setenv(homeEnvVar, "/dev/null") - l := &FileTokenCache{} - _, err := l.Lookup("x") - // macOS/Linux: load: read: open /dev/null/.databricks/token-cache.json: - // windows: databricks OAuth is not configured for this host - assert.Error(t, err) -} - -func TestStoreOnDev(t *testing.T) { - if runtime.GOOS == "windows" { - t.SkipNow() - } - t.Setenv(homeEnvVar, "/dev") - c := &FileTokenCache{} - err := c.Store("x", &oauth2.Token{ - AccessToken: "abc", - }) - // Linux: permission denied - // macOS: read-only file system - assert.Error(t, err) -} diff --git a/libs/auth/cache/in_memory.go b/libs/auth/cache/in_memory.go deleted file mode 100644 index 469d45575a..0000000000 --- a/libs/auth/cache/in_memory.go +++ /dev/null @@ -1,26 +0,0 @@ -package cache - -import ( - "golang.org/x/oauth2" -) - -type InMemoryTokenCache struct { - Tokens map[string]*oauth2.Token -} - -// Lookup implements TokenCache. -func (i *InMemoryTokenCache) Lookup(key string) (*oauth2.Token, error) { - token, ok := i.Tokens[key] - if !ok { - return nil, ErrNotConfigured - } - return token, nil -} - -// Store implements TokenCache. -func (i *InMemoryTokenCache) Store(key string, t *oauth2.Token) error { - i.Tokens[key] = t - return nil -} - -var _ TokenCache = (*InMemoryTokenCache)(nil) diff --git a/libs/auth/cache/in_memory_test.go b/libs/auth/cache/in_memory_test.go deleted file mode 100644 index d8394d3b26..0000000000 --- a/libs/auth/cache/in_memory_test.go +++ /dev/null @@ -1,44 +0,0 @@ -package cache - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "golang.org/x/oauth2" -) - -func TestInMemoryCacheHit(t *testing.T) { - token := &oauth2.Token{ - AccessToken: "abc", - } - c := &InMemoryTokenCache{ - Tokens: map[string]*oauth2.Token{ - "key": token, - }, - } - res, err := c.Lookup("key") - assert.Equal(t, res, token) - assert.NoError(t, err) -} - -func TestInMemoryCacheMiss(t *testing.T) { - c := &InMemoryTokenCache{ - Tokens: map[string]*oauth2.Token{}, - } - _, err := c.Lookup("key") - assert.ErrorIs(t, err, ErrNotConfigured) -} - -func TestInMemoryCacheStore(t *testing.T) { - token := &oauth2.Token{ - AccessToken: "abc", - } - c := &InMemoryTokenCache{ - Tokens: map[string]*oauth2.Token{}, - } - err := c.Store("key", token) - assert.NoError(t, err) - res, err := c.Lookup("key") - assert.Equal(t, res, token) - assert.NoError(t, err) -} diff --git a/libs/auth/callback.go b/libs/auth/callback.go deleted file mode 100644 index 3893a5041a..0000000000 --- a/libs/auth/callback.go +++ /dev/null @@ -1,104 +0,0 @@ -package auth - -import ( - "context" - _ "embed" - "fmt" - "html/template" - "net" - "net/http" - "strings" - - "golang.org/x/text/cases" - "golang.org/x/text/language" -) - -//go:embed page.tmpl -var pageTmpl string - -type oauthResult struct { - Error string - ErrorDescription string - Host string - State string - Code string -} - -type callbackServer struct { - ln net.Listener - srv http.Server - ctx context.Context - a *PersistentAuth - renderErrCh chan error - feedbackCh chan oauthResult - tmpl *template.Template -} - -func newCallback(ctx context.Context, a *PersistentAuth) (*callbackServer, error) { - tmpl, err := template.New("page").Funcs(template.FuncMap{ - "title": func(in string) string { - title := cases.Title(language.English) - return title.String(strings.ReplaceAll(in, "_", " ")) - }, - }).Parse(pageTmpl) - if err != nil { - return nil, err - } - cb := &callbackServer{ - feedbackCh: make(chan oauthResult), - renderErrCh: make(chan error), - tmpl: tmpl, - ctx: ctx, - ln: a.ln, - a: a, - } - cb.srv.Handler = cb - go func() { - _ = cb.srv.Serve(cb.ln) - }() - return cb, nil -} - -func (cb *callbackServer) Close() error { - return cb.srv.Close() -} - -// ServeHTTP renders page.html template -func (cb *callbackServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { - res := oauthResult{ - Error: r.FormValue("error"), - ErrorDescription: r.FormValue("error_description"), - Code: r.FormValue("code"), - State: r.FormValue("state"), - Host: cb.a.Host, - } - if res.Error != "" { - w.WriteHeader(http.StatusBadRequest) - } else { - w.WriteHeader(http.StatusOK) - } - err := cb.tmpl.Execute(w, res) - if err != nil { - cb.renderErrCh <- err - } - cb.feedbackCh <- res -} - -// Handler opens up a browser waits for redirect to come back from the identity provider -func (cb *callbackServer) Handler(authCodeURL string) (string, string, error) { - err := cb.a.browser(authCodeURL) - if err != nil { - fmt.Printf("Please open %s in the browser to continue authentication", authCodeURL) - } - select { - case <-cb.ctx.Done(): - return "", "", cb.ctx.Err() - case renderErr := <-cb.renderErrCh: - return "", "", renderErr - case res := <-cb.feedbackCh: - if res.Error != "" { - return "", "", fmt.Errorf("%s: %s", res.Error, res.ErrorDescription) - } - return res.Code, res.State, nil - } -} diff --git a/libs/auth/oauth.go b/libs/auth/oauth.go deleted file mode 100644 index 026c454682..0000000000 --- a/libs/auth/oauth.go +++ /dev/null @@ -1,289 +0,0 @@ -package auth - -import ( - "context" - "crypto/rand" - "crypto/sha256" - _ "embed" - "encoding/base64" - "errors" - "fmt" - "net" - "net/url" - "strings" - "time" - - "github.com/databricks/cli/libs/auth/cache" - "github.com/databricks/databricks-sdk-go/httpclient" - "github.com/databricks/databricks-sdk-go/retries" - "github.com/pkg/browser" - "golang.org/x/oauth2" - "golang.org/x/oauth2/authhandler" -) - -var apiClientForOauth int - -func WithApiClientForOAuth(ctx context.Context, c *httpclient.ApiClient) context.Context { - return context.WithValue(ctx, &apiClientForOauth, c) -} - -func GetApiClientForOAuth(ctx context.Context) *httpclient.ApiClient { - c, ok := ctx.Value(&apiClientForOauth).(*httpclient.ApiClient) - if !ok { - return httpclient.NewApiClient(httpclient.ClientConfig{}) - } - return c -} - -const ( - // these values are predefined by Databricks as a public client - // and is specific to this application only. Using these values - // for other applications is not allowed. - appClientID = "databricks-cli" - appRedirectAddr = "localhost:8020" - - // maximum amount of time to acquire listener on appRedirectAddr - listenerTimeout = 45 * time.Second -) - -var ( // Databricks SDK API: `databricks OAuth is not` will be checked for presence - ErrOAuthNotSupported = errors.New("databricks OAuth is not supported for this host") - ErrNotConfigured = errors.New("databricks OAuth is not configured for this host") - ErrFetchCredentials = errors.New("cannot fetch credentials") -) - -type PersistentAuth struct { - Host string - AccountID string - - http *httpclient.ApiClient - cache cache.TokenCache - ln net.Listener - browser func(string) error -} - -func (a *PersistentAuth) SetApiClient(h *httpclient.ApiClient) { - a.http = h -} - -func (a *PersistentAuth) Load(ctx context.Context) (*oauth2.Token, error) { - err := a.init(ctx) - if err != nil { - return nil, fmt.Errorf("init: %w", err) - } - // lookup token identified by host (and possibly the account id) - key := a.key() - t, err := a.cache.Lookup(key) - if err != nil { - return nil, fmt.Errorf("cache: %w", err) - } - // early return for valid tokens - if t.Valid() { - // do not print refresh token to end-user - t.RefreshToken = "" - return t, nil - } - // OAuth2 config is invoked only for expired tokens to speed up - // the happy path in the token retrieval - cfg, err := a.oauth2Config(ctx) - if err != nil { - return nil, err - } - // make OAuth2 library use our client - ctx = a.http.InContextForOAuth2(ctx) - // eagerly refresh token - refreshed, err := cfg.TokenSource(ctx, t).Token() - if err != nil { - return nil, fmt.Errorf("token refresh: %w", err) - } - err = a.cache.Store(key, refreshed) - if err != nil { - return nil, fmt.Errorf("cache refresh: %w", err) - } - // do not print refresh token to end-user - refreshed.RefreshToken = "" - return refreshed, nil -} - -func (a *PersistentAuth) ProfileName() string { - if a.AccountID != "" { - return fmt.Sprintf("ACCOUNT-%s", a.AccountID) - } - host := strings.TrimPrefix(a.Host, "https://") - split := strings.Split(host, ".") - return split[0] -} - -func (a *PersistentAuth) Challenge(ctx context.Context) error { - err := a.init(ctx) - if err != nil { - return fmt.Errorf("init: %w", err) - } - cfg, err := a.oauth2Config(ctx) - if err != nil { - return err - } - cb, err := newCallback(ctx, a) - if err != nil { - return fmt.Errorf("callback server: %w", err) - } - defer cb.Close() - state, pkce := a.stateAndPKCE() - // make OAuth2 library use our client - ctx = a.http.InContextForOAuth2(ctx) - ts := authhandler.TokenSourceWithPKCE(ctx, cfg, state, cb.Handler, pkce) - t, err := ts.Token() - if err != nil { - return fmt.Errorf("authorize: %w", err) - } - // cache token identified by host (and possibly the account id) - err = a.cache.Store(a.key(), t) - if err != nil { - return fmt.Errorf("store: %w", err) - } - return nil -} - -// This function cleans up the host URL by only retaining the scheme and the host. -// This function thus removes any path, query arguments, or fragments from the URL. -func (a *PersistentAuth) cleanHost() { - parsedHost, err := url.Parse(a.Host) - if err != nil { - return - } - // when either host or scheme is empty, we don't want to clean it. This is because - // the Go url library parses a raw "abc" string as the path of a URL and cleaning - // it will return thus return an empty string. - if parsedHost.Host == "" || parsedHost.Scheme == "" { - return - } - host := url.URL{ - Scheme: parsedHost.Scheme, - Host: parsedHost.Host, - } - a.Host = host.String() -} - -func (a *PersistentAuth) init(ctx context.Context) error { - if a.Host == "" && a.AccountID == "" { - return ErrFetchCredentials - } - if a.http == nil { - a.http = GetApiClientForOAuth(ctx) - } - if a.cache == nil { - a.cache = cache.GetTokenCache(ctx) - } - if a.browser == nil { - a.browser = browser.OpenURL - } - - a.cleanHost() - - // try acquire listener, which we also use as a machine-local - // exclusive lock to prevent token cache corruption in the scope - // of developer machine, where this command runs. - listener, err := retries.Poll(ctx, listenerTimeout, - func() (*net.Listener, *retries.Err) { - var lc net.ListenConfig - l, err := lc.Listen(ctx, "tcp", appRedirectAddr) - if err != nil { - return nil, retries.Continue(err) - } - return &l, nil - }) - if err != nil { - return fmt.Errorf("listener: %w", err) - } - a.ln = *listener - return nil -} - -func (a *PersistentAuth) Close() error { - if a.ln == nil { - return nil - } - return a.ln.Close() -} - -func (a *PersistentAuth) oidcEndpoints(ctx context.Context) (*oauthAuthorizationServer, error) { - prefix := a.key() - if a.AccountID != "" { - return &oauthAuthorizationServer{ - AuthorizationEndpoint: fmt.Sprintf("%s/v1/authorize", prefix), - TokenEndpoint: fmt.Sprintf("%s/v1/token", prefix), - }, nil - } - var oauthEndpoints oauthAuthorizationServer - oidc := fmt.Sprintf("%s/oidc/.well-known/oauth-authorization-server", prefix) - err := a.http.Do(ctx, "GET", oidc, httpclient.WithResponseUnmarshal(&oauthEndpoints)) - if err != nil { - return nil, fmt.Errorf("fetch .well-known: %w", err) - } - var httpErr *httpclient.HttpError - if errors.As(err, &httpErr) && httpErr.StatusCode != 200 { - return nil, ErrOAuthNotSupported - } - return &oauthEndpoints, nil -} - -func (a *PersistentAuth) oauth2Config(ctx context.Context) (*oauth2.Config, error) { - // in this iteration of CLI, we're using all scopes by default, - // because tools like CLI and Terraform do use all apis. This - // decision may be reconsidered later, once we have a proper - // taxonomy of all scopes ready and implemented. - scopes := []string{ - "offline_access", - "all-apis", - } - endpoints, err := a.oidcEndpoints(ctx) - if err != nil { - return nil, fmt.Errorf("oidc: %w", err) - } - return &oauth2.Config{ - ClientID: appClientID, - Endpoint: oauth2.Endpoint{ - AuthURL: endpoints.AuthorizationEndpoint, - TokenURL: endpoints.TokenEndpoint, - AuthStyle: oauth2.AuthStyleInParams, - }, - RedirectURL: fmt.Sprintf("http://%s", appRedirectAddr), - Scopes: scopes, - }, nil -} - -// key is currently used for two purposes: OIDC URL prefix and token cache key. -// once we decide to start storing scopes in the token cache, we should change -// this approach. -func (a *PersistentAuth) key() string { - a.Host = strings.TrimSuffix(a.Host, "/") - if !strings.HasPrefix(a.Host, "http") { - a.Host = fmt.Sprintf("https://%s", a.Host) - } - if a.AccountID != "" { - return fmt.Sprintf("%s/oidc/accounts/%s", a.Host, a.AccountID) - } - return a.Host -} - -func (a *PersistentAuth) stateAndPKCE() (string, *authhandler.PKCEParams) { - verifier := a.randomString(64) - verifierSha256 := sha256.Sum256([]byte(verifier)) - challenge := base64.RawURLEncoding.EncodeToString(verifierSha256[:]) - return a.randomString(16), &authhandler.PKCEParams{ - Challenge: challenge, - ChallengeMethod: "S256", - Verifier: verifier, - } -} - -func (a *PersistentAuth) randomString(size int) string { - raw := make([]byte, size) - _, _ = rand.Read(raw) - return base64.RawURLEncoding.EncodeToString(raw) -} - -type oauthAuthorizationServer struct { - AuthorizationEndpoint string `json:"authorization_endpoint"` // ../v1/authorize - TokenEndpoint string `json:"token_endpoint"` // ../v1/token -} diff --git a/libs/auth/oauth_test.go b/libs/auth/oauth_test.go deleted file mode 100644 index 837ff4fee9..0000000000 --- a/libs/auth/oauth_test.go +++ /dev/null @@ -1,267 +0,0 @@ -package auth - -import ( - "context" - "crypto/tls" - _ "embed" - "fmt" - "net/http" - "net/url" - "testing" - "time" - - "github.com/databricks/databricks-sdk-go/client" - "github.com/databricks/databricks-sdk-go/httpclient" - "github.com/databricks/databricks-sdk-go/httpclient/fixtures" - "github.com/databricks/databricks-sdk-go/qa" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.org/x/oauth2" -) - -func TestOidcEndpointsForAccounts(t *testing.T) { - p := &PersistentAuth{ - Host: "abc", - AccountID: "xyz", - } - defer p.Close() - s, err := p.oidcEndpoints(context.Background()) - assert.NoError(t, err) - assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/authorize", s.AuthorizationEndpoint) - assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/token", s.TokenEndpoint) -} - -func TestOidcForWorkspace(t *testing.T) { - p := &PersistentAuth{ - Host: "abc", - http: httpclient.NewApiClient(httpclient.ClientConfig{ - Transport: fixtures.MappingTransport{ - "GET /oidc/.well-known/oauth-authorization-server": { - Status: 200, - Response: map[string]string{ - "authorization_endpoint": "a", - "token_endpoint": "b", - }, - }, - }, - }), - } - defer p.Close() - endpoints, err := p.oidcEndpoints(context.Background()) - assert.NoError(t, err) - assert.Equal(t, "a", endpoints.AuthorizationEndpoint) - assert.Equal(t, "b", endpoints.TokenEndpoint) -} - -type tokenCacheMock struct { - store func(key string, t *oauth2.Token) error - lookup func(key string) (*oauth2.Token, error) -} - -func (m *tokenCacheMock) Store(key string, t *oauth2.Token) error { - if m.store == nil { - panic("no store mock") - } - return m.store(key, t) -} - -func (m *tokenCacheMock) Lookup(key string) (*oauth2.Token, error) { - if m.lookup == nil { - panic("no lookup mock") - } - return m.lookup(key) -} - -func TestLoad(t *testing.T) { - p := &PersistentAuth{ - Host: "abc", - AccountID: "xyz", - cache: &tokenCacheMock{ - lookup: func(key string) (*oauth2.Token, error) { - assert.Equal(t, "https://abc/oidc/accounts/xyz", key) - return &oauth2.Token{ - AccessToken: "bcd", - Expiry: time.Now().Add(1 * time.Minute), - }, nil - }, - }, - } - defer p.Close() - tok, err := p.Load(context.Background()) - assert.NoError(t, err) - assert.Equal(t, "bcd", tok.AccessToken) - assert.Equal(t, "", tok.RefreshToken) -} - -func useInsecureOAuthHttpClientForTests(ctx context.Context) context.Context { - return context.WithValue(ctx, oauth2.HTTPClient, &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - }, - }) -} - -func TestLoadRefresh(t *testing.T) { - qa.HTTPFixtures{ - { - Method: "POST", - Resource: "/oidc/accounts/xyz/v1/token", - Response: `access_token=refreshed&refresh_token=def`, - }, - }.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) { - ctx = useInsecureOAuthHttpClientForTests(ctx) - expectedKey := fmt.Sprintf("%s/oidc/accounts/xyz", c.Config.Host) - p := &PersistentAuth{ - Host: c.Config.Host, - AccountID: "xyz", - cache: &tokenCacheMock{ - lookup: func(key string) (*oauth2.Token, error) { - assert.Equal(t, expectedKey, key) - return &oauth2.Token{ - AccessToken: "expired", - RefreshToken: "cde", - Expiry: time.Now().Add(-1 * time.Minute), - }, nil - }, - store: func(key string, tok *oauth2.Token) error { - assert.Equal(t, expectedKey, key) - assert.Equal(t, "def", tok.RefreshToken) - return nil - }, - }, - } - defer p.Close() - tok, err := p.Load(ctx) - assert.NoError(t, err) - assert.Equal(t, "refreshed", tok.AccessToken) - assert.Equal(t, "", tok.RefreshToken) - }) -} - -func TestChallenge(t *testing.T) { - qa.HTTPFixtures{ - { - Method: "POST", - Resource: "/oidc/accounts/xyz/v1/token", - Response: `access_token=__THAT__&refresh_token=__SOMETHING__`, - }, - }.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) { - ctx = useInsecureOAuthHttpClientForTests(ctx) - expectedKey := fmt.Sprintf("%s/oidc/accounts/xyz", c.Config.Host) - - browserOpened := make(chan string) - p := &PersistentAuth{ - Host: c.Config.Host, - AccountID: "xyz", - browser: func(redirect string) error { - u, err := url.ParseRequestURI(redirect) - if err != nil { - return err - } - assert.Equal(t, "/oidc/accounts/xyz/v1/authorize", u.Path) - // for now we're ignoring asserting the fields of the redirect - query := u.Query() - browserOpened <- query.Get("state") - return nil - }, - cache: &tokenCacheMock{ - store: func(key string, tok *oauth2.Token) error { - assert.Equal(t, expectedKey, key) - assert.Equal(t, "__SOMETHING__", tok.RefreshToken) - return nil - }, - }, - } - defer p.Close() - - errc := make(chan error) - go func() { - errc <- p.Challenge(ctx) - }() - - state := <-browserOpened - resp, err := http.Get(fmt.Sprintf("http://%s?code=__THIS__&state=%s", appRedirectAddr, state)) - require.NoError(t, err) - defer resp.Body.Close() - assert.Equal(t, 200, resp.StatusCode) - - err = <-errc - assert.NoError(t, err) - }) -} - -func TestChallengeFailed(t *testing.T) { - qa.HTTPFixtures{}.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) { - ctx = useInsecureOAuthHttpClientForTests(ctx) - - browserOpened := make(chan string) - p := &PersistentAuth{ - Host: c.Config.Host, - AccountID: "xyz", - browser: func(redirect string) error { - u, err := url.ParseRequestURI(redirect) - if err != nil { - return err - } - assert.Equal(t, "/oidc/accounts/xyz/v1/authorize", u.Path) - // for now we're ignoring asserting the fields of the redirect - query := u.Query() - browserOpened <- query.Get("state") - return nil - }, - } - defer p.Close() - - errc := make(chan error) - go func() { - errc <- p.Challenge(ctx) - }() - - <-browserOpened - resp, err := http.Get(fmt.Sprintf( - "http://%s?error=access_denied&error_description=Policy%%20evaluation%%20failed%%20for%%20this%%20request", - appRedirectAddr)) - require.NoError(t, err) - defer resp.Body.Close() - assert.Equal(t, 400, resp.StatusCode) - - err = <-errc - assert.EqualError(t, err, "authorize: access_denied: Policy evaluation failed for this request") - }) -} - -func TestPersistentAuthCleanHost(t *testing.T) { - for _, tcases := range []struct { - in string - out string - }{ - {"https://example.com", "https://example.com"}, - {"https://example.com/", "https://example.com"}, - {"https://example.com/path", "https://example.com"}, - {"https://example.com/path/subpath", "https://example.com"}, - {"https://example.com/path?query=1", "https://example.com"}, - {"https://example.com/path?query=1&other=2", "https://example.com"}, - {"https://example.com/path#fragment", "https://example.com"}, - {"https://example.com/path?query=1#fragment", "https://example.com"}, - {"https://example.com/path?query=1&other=2#fragment", "https://example.com"}, - {"https://example.com/path/subpath?query=1", "https://example.com"}, - {"https://example.com/path/subpath?query=1&other=2", "https://example.com"}, - {"https://example.com/path/subpath#fragment", "https://example.com"}, - {"https://example.com/path/subpath?query=1#fragment", "https://example.com"}, - {"https://example.com/path/subpath?query=1&other=2#fragment", "https://example.com"}, - {"https://example.com/path?query=1%20value&other=2%20value", "https://example.com"}, - {"http://example.com/path/subpath?query=1%20value&other=2%20value", "http://example.com"}, - - // URLs without scheme should be left as is - {"abc", "abc"}, - {"abc.com/def", "abc.com/def"}, - } { - p := &PersistentAuth{ - Host: tcases.in, - } - p.cleanHost() - assert.Equal(t, tcases.out, p.Host) - } -} diff --git a/libs/auth/page.tmpl b/libs/auth/page.tmpl deleted file mode 100644 index 4642bb3d47..0000000000 --- a/libs/auth/page.tmpl +++ /dev/null @@ -1,102 +0,0 @@ - - - - - {{if .Error }}{{ .Error | title }}{{ else }}Success{{end}} - - - - - - - -
-
- - -
{{ .Error | title }}
-
{{ .ErrorDescription }}
- -
Authenticated
-
Go to {{.Host}}
- -
- You can close this tab. Or go to documentation -
-
-
- - From e8ae285da2ca8733f32f54cc05191785286dad9a Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Tue, 7 Jan 2025 11:13:16 +0100 Subject: [PATCH 02/15] work --- cmd/auth/auth.go | 27 ++++++++++++--- cmd/auth/in_memory_test.go | 27 +++++++++++++++ cmd/auth/login.go | 58 +++++++++++++++---------------- cmd/auth/login_test.go | 61 ++++++++++++++++----------------- cmd/auth/token.go | 49 +++++++++++--------------- cmd/auth/token_test.go | 70 ++++++++++++++++++++++++++------------ go.mod | 2 +- go.sum | 4 +++ 8 files changed, 181 insertions(+), 117 deletions(-) create mode 100644 cmd/auth/in_memory_test.go diff --git a/cmd/auth/auth.go b/cmd/auth/auth.go index 7493e16a30..92320d9b48 100644 --- a/cmd/auth/auth.go +++ b/cmd/auth/auth.go @@ -5,10 +5,27 @@ import ( "fmt" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/databricks-sdk-go/config" "github.com/databricks/databricks-sdk-go/credentials/oauth" "github.com/spf13/cobra" ) +type authArguments struct { + host string + accountId string +} + +func (a authArguments) toOAuthArgument() (oauth.OAuthArgument, error) { + cfg := &config.Config{ + Host: a.host, + AccountID: a.accountId, + } + if cfg.IsAccountClient() { + return oauth.NewBasicAccountOAuthArgument(cfg.Host, cfg.AccountID) + } + return oauth.NewBasicWorkspaceOAuthArgument(cfg.Host) +} + func New() *cobra.Command { cmd := &cobra.Command{ Use: "auth", @@ -22,14 +39,14 @@ Azure: https://learn.microsoft.com/azure/databricks/dev-tools/auth GCP: https://docs.gcp.databricks.com/dev-tools/auth/index.html`, } - var oauthArgument oauth.BasicOAuthArgument - cmd.PersistentFlags().StringVar(&oauthArgument.Host, "host", oauthArgument.Host, "Databricks Host") - cmd.PersistentFlags().StringVar(&oauthArgument.AccountID, "account-id", oauthArgument.AccountID, "Databricks Account ID") + var authArguments authArguments + cmd.PersistentFlags().StringVar(&authArguments.host, "host", "", "Databricks Host") + cmd.PersistentFlags().StringVar(&authArguments.accountId, "account-id", "", "Databricks Account ID") cmd.AddCommand(newEnvCommand()) - cmd.AddCommand(newLoginCommand(&oauthArgument)) + cmd.AddCommand(newLoginCommand(&authArguments)) cmd.AddCommand(newProfilesCommand()) - cmd.AddCommand(newTokenCommand(&oauthArgument)) + cmd.AddCommand(newTokenCommand(&authArguments)) cmd.AddCommand(newDescribeCommand()) return cmd } diff --git a/cmd/auth/in_memory_test.go b/cmd/auth/in_memory_test.go new file mode 100644 index 0000000000..a1714dd749 --- /dev/null +++ b/cmd/auth/in_memory_test.go @@ -0,0 +1,27 @@ +package auth + +import ( + "github.com/databricks/databricks-sdk-go/credentials/cache" + "golang.org/x/oauth2" +) + +type InMemoryTokenCache struct { + Tokens map[string]*oauth2.Token +} + +// Lookup implements TokenCache. +func (i *InMemoryTokenCache) Lookup(key string) (*oauth2.Token, error) { + token, ok := i.Tokens[key] + if !ok { + return nil, cache.ErrNotConfigured + } + return token, nil +} + +// Store implements TokenCache. +func (i *InMemoryTokenCache) Store(key string, t *oauth2.Token) error { + i.Tokens[key] = t + return nil +} + +var _ cache.TokenCache = (*InMemoryTokenCache)(nil) diff --git a/cmd/auth/login.go b/cmd/auth/login.go index 6cfee2a765..cc3a398a8b 100644 --- a/cmd/auth/login.go +++ b/cmd/auth/login.go @@ -18,14 +18,14 @@ import ( "github.com/spf13/cobra" ) -func promptForProfile(ctx context.Context, oauthArgument oauth.OAuthArgument) (string, error) { +func promptForProfile(ctx context.Context, authArguments *authArguments) (string, error) { if !cmdio.IsInTTY(ctx) { return "", nil } prompt := cmdio.Prompt(ctx) prompt.Label = "Databricks profile name" - prompt.Default = getProfileName(ctx, oauthArgument) + prompt.Default = getProfileName(authArguments) prompt.AllowEdit = true return prompt.Run() } @@ -35,7 +35,7 @@ const ( defaultTimeout = 1 * time.Hour ) -func newLoginCommand(oauthArgument oauth.OAuthArgument) *cobra.Command { +func newLoginCommand(authArguments *authArguments) *cobra.Command { defaultConfigPath := "~/.databrickscfg" if runtime.GOOS == "windows" { defaultConfigPath = "%USERPROFILE%\\.databrickscfg" @@ -99,14 +99,14 @@ depends on the existing profiles you have set in your configuration file // If the user has not specified a profile name, prompt for one. if profileName == "" { var err error - profileName, err = promptForProfile(ctx, oauthArgument) + profileName, err = promptForProfile(ctx, authArguments) if err != nil { return err } } // Set the host and account-id based on the provided arguments and flags. - oauthArgument, err := setHostAndAccountId(ctx, profile.DefaultProfiler, profileName, oauthArgument, args) + err := setHostAndAccountId(ctx, profile.DefaultProfiler, profileName, authArguments, args) if err != nil { return err } @@ -119,18 +119,21 @@ depends on the existing profiles you have set in your configuration file // We need the config without the profile before it's used to initialise new workspace client below. // Otherwise it will complain about non existing profile because it was not yet saved. cfg := config.Config{ - Host: oauthArgument.GetHost(ctx), - AccountID: oauthArgument.GetAccountId(ctx), + Host: authArguments.host, + AccountID: authArguments.accountId, AuthType: "databricks-cli", } ctx, cancel := context.WithTimeout(ctx, loginTimeout) defer cancel() - err = persistentAuth.Challenge(ctx, oauthArgument) + oauthArgument, err := authArguments.toOAuthArgument() if err != nil { return err } + if err = persistentAuth.Challenge(ctx, oauthArgument); err != nil { + return err + } if configureCluster { w, err := databricks.NewWorkspaceClient((*databricks.Config)(&cfg)) @@ -178,66 +181,63 @@ depends on the existing profiles you have set in your configuration file // 1. --account-id flag. // 2. account-id from the specified profile, if available. // 3. Prompt the user for the account-id. -func setHostAndAccountId(ctx context.Context, profiler profile.Profiler, profileName string, oauthArgument oauth.OAuthArgument, args []string) (oauth.OAuthArgument, error) { - res := oauth.BasicOAuthArgument{} +func setHostAndAccountId(ctx context.Context, profiler profile.Profiler, profileName string, authArguments *authArguments, args []string) error { // If both [HOST] and --host are provided, return an error. - host := oauthArgument.GetHost(ctx) + host := authArguments.host if len(args) > 0 && host != "" { - return nil, fmt.Errorf("please only provide a host as an argument or a flag, not both") + return fmt.Errorf("please only provide a host as an argument or a flag, not both") } // If the chosen profile has a hostname and the user hasn't specified a host, infer the host from the profile. profiles, err := profiler.LoadProfiles(ctx, profile.WithName(profileName)) // Tolerate ErrNoConfiguration here, as we will write out a configuration as part of the login flow. if err != nil && !errors.Is(err, profile.ErrNoConfiguration) { - return nil, err + return err } if host == "" { if len(args) > 0 { // If [HOST] is provided, set the host to the provided positional argument. - res.Host = args[0] + authArguments.host = args[0] } else if len(profiles) > 0 && profiles[0].Host != "" { // If neither [HOST] nor --host are provided, and the profile has a host, use it. - res.Host = profiles[0].Host + authArguments.host = profiles[0].Host } else { // If neither [HOST] nor --host are provided, and the profile does not have a host, // then prompt the user for a host. hostName, err := promptForHost(ctx) if err != nil { - return nil, err + return err } - res.Host = hostName + authArguments.host = hostName } } // If the account-id was not provided as a cmd line flag, try to read it from // the specified profile. - isAccountClient := (&config.Config{Host: res.Host}).IsAccountClient() - accountID := oauthArgument.GetAccountId(ctx) + isAccountClient := (&config.Config{Host: authArguments.host}).IsAccountClient() + accountID := authArguments.accountId if isAccountClient && accountID == "" { if len(profiles) > 0 && profiles[0].AccountID != "" { - res.AccountID = profiles[0].AccountID + authArguments.accountId = profiles[0].AccountID } else { // Prompt user for the account-id if it we could not get it from a // profile. accountId, err := promptForAccountID(ctx) if err != nil { - return nil, err + return err } - res.AccountID = accountId + authArguments.accountId = accountId } } - return res, nil + return nil } -func getProfileName(ctx context.Context, oauthArgument oauth.OAuthArgument) string { - host := oauthArgument.GetHost(ctx) - accountId := oauthArgument.GetAccountId(ctx) - if accountId != "" { - return fmt.Sprintf("ACCOUNT-%s", accountId) +func getProfileName(authArguments *authArguments) string { + if authArguments.accountId != "" { + return fmt.Sprintf("ACCOUNT-%s", authArguments.accountId) } - host = strings.TrimPrefix(host, "https://") + host := strings.TrimPrefix(authArguments.host, "https://") split := strings.Split(host, ".") return split[0] } diff --git a/cmd/auth/login_test.go b/cmd/auth/login_test.go index 3d41550411..b0a4f9c822 100644 --- a/cmd/auth/login_test.go +++ b/cmd/auth/login_test.go @@ -7,7 +7,6 @@ import ( "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/databrickscfg/profile" "github.com/databricks/cli/libs/env" - "github.com/databricks/databricks-sdk-go/credentials/oauth" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -15,72 +14,72 @@ import ( func TestSetHostDoesNotFailWithNoDatabrickscfg(t *testing.T) { ctx := context.Background() ctx = env.Set(ctx, "DATABRICKS_CONFIG_FILE", "./imaginary-file/databrickscfg") - _, err := setHostAndAccountId(ctx, profile.DefaultProfiler, "foo", oauth.BasicOAuthArgument{Host: "test"}, []string{}) + err := setHostAndAccountId(ctx, profile.DefaultProfiler, "foo", &authArguments{host: "test"}, []string{}) assert.NoError(t, err) } func TestSetHost(t *testing.T) { - var persistentAuth oauth.BasicOAuthArgument + authArguments := authArguments{} t.Setenv("DATABRICKS_CONFIG_FILE", "./testdata/.databrickscfg") ctx, _ := cmdio.SetupTest(context.Background()) // Test error when both flag and argument are provided - persistentAuth.Host = "val from --host" - _, err := setHostAndAccountId(ctx, profile.DefaultProfiler, "profile-1", &persistentAuth, []string{"val from [HOST]"}) + authArguments.host = "val from --host" + err := setHostAndAccountId(ctx, profile.DefaultProfiler, "profile-1", &authArguments, []string{"val from [HOST]"}) assert.EqualError(t, err, "please only provide a host as an argument or a flag, not both") // Test setting host from flag - persistentAuth.Host = "val from --host" - res, err := setHostAndAccountId(ctx, profile.DefaultProfiler, "profile-1", &persistentAuth, []string{}) + authArguments.host = "val from --host" + err = setHostAndAccountId(ctx, profile.DefaultProfiler, "profile-1", &authArguments, []string{}) assert.NoError(t, err) - assert.Equal(t, "val from --host", res.GetHost(ctx)) + assert.Equal(t, "val from --host", authArguments.host) // Test setting host from argument - persistentAuth.Host = "" - res, err = setHostAndAccountId(ctx, profile.DefaultProfiler, "profile-1", &persistentAuth, []string{"val from [HOST]"}) + authArguments.host = "" + err = setHostAndAccountId(ctx, profile.DefaultProfiler, "profile-1", &authArguments, []string{"val from [HOST]"}) assert.NoError(t, err) - assert.Equal(t, "val from [HOST]", res.GetHost(ctx)) + assert.Equal(t, "val from [HOST]", authArguments.host) // Test setting host from profile - persistentAuth.Host = "" - res, err = setHostAndAccountId(ctx, profile.DefaultProfiler, "profile-1", &persistentAuth, []string{}) + authArguments.host = "" + err = setHostAndAccountId(ctx, profile.DefaultProfiler, "profile-1", &authArguments, []string{}) assert.NoError(t, err) - assert.Equal(t, "https://www.host1.com", res.GetHost(ctx)) + assert.Equal(t, "https://www.host1.com", authArguments.host) // Test setting host from profile - persistentAuth.Host = "" - res, err = setHostAndAccountId(ctx, profile.DefaultProfiler, "profile-2", &persistentAuth, []string{}) + authArguments.host = "" + err = setHostAndAccountId(ctx, profile.DefaultProfiler, "profile-2", &authArguments, []string{}) assert.NoError(t, err) - assert.Equal(t, "https://www.host2.com", res.GetHost(ctx)) + assert.Equal(t, "https://www.host2.com", authArguments.host) // Test host is not set. Should prompt. - persistentAuth.Host = "" - _, err = setHostAndAccountId(ctx, profile.DefaultProfiler, "", &persistentAuth, []string{}) + authArguments.host = "" + err = setHostAndAccountId(ctx, profile.DefaultProfiler, "", &authArguments, []string{}) assert.EqualError(t, err, "the command is being run in a non-interactive environment, please specify a host using --host") } func TestSetAccountId(t *testing.T) { - var persistentAuth oauth.BasicOAuthArgument + var authArguments authArguments t.Setenv("DATABRICKS_CONFIG_FILE", "./testdata/.databrickscfg") ctx, _ := cmdio.SetupTest(context.Background()) // Test setting account-id from flag - persistentAuth.AccountID = "val from --account-id" - res, err := setHostAndAccountId(ctx, profile.DefaultProfiler, "account-profile", &persistentAuth, []string{}) + authArguments.accountId = "val from --account-id" + err := setHostAndAccountId(ctx, profile.DefaultProfiler, "account-profile", &authArguments, []string{}) assert.NoError(t, err) - assert.Equal(t, "https://accounts.cloud.databricks.com", res.GetHost(ctx)) - assert.Equal(t, "val from --account-id", res.GetAccountId(ctx)) + assert.Equal(t, "https://accounts.cloud.databricks.com", authArguments.host) + assert.Equal(t, "val from --account-id", authArguments.accountId) // Test setting account_id from profile - persistentAuth.AccountID = "" - res, err = setHostAndAccountId(ctx, profile.DefaultProfiler, "account-profile", &persistentAuth, []string{}) + authArguments.accountId = "" + err = setHostAndAccountId(ctx, profile.DefaultProfiler, "account-profile", &authArguments, []string{}) require.NoError(t, err) - assert.Equal(t, "https://accounts.cloud.databricks.com", res.GetHost(ctx)) - assert.Equal(t, "id-from-profile", res.GetAccountId(ctx)) + assert.Equal(t, "https://accounts.cloud.databricks.com", authArguments.host) + assert.Equal(t, "id-from-profile", authArguments.accountId) // Neither flag nor profile account-id is set, should prompt - persistentAuth.AccountID = "" - persistentAuth.Host = "https://accounts.cloud.databricks.com" - _, err = setHostAndAccountId(ctx, profile.DefaultProfiler, "", &persistentAuth, []string{}) + authArguments.accountId = "" + authArguments.host = "https://accounts.cloud.databricks.com" + err = setHostAndAccountId(ctx, profile.DefaultProfiler, "", &authArguments, []string{}) assert.EqualError(t, err, "the command is being run in a non-interactive environment, please specify an account ID using --account-id") } diff --git a/cmd/auth/token.go b/cmd/auth/token.go index 12a6237d72..2a3d289367 100644 --- a/cmd/auth/token.go +++ b/cmd/auth/token.go @@ -5,35 +5,29 @@ import ( "encoding/json" "errors" "fmt" - "os" "strings" "time" "github.com/databricks/cli/libs/databrickscfg/profile" "github.com/databricks/databricks-sdk-go/credentials/oauth" - "github.com/databricks/databricks-sdk-go/httpclient" "github.com/spf13/cobra" "golang.org/x/oauth2" ) -type tokenErrorResponse struct { - Error string `json:"error"` - ErrorDescription string `json:"error_description"` -} - -func buildLoginCommand(ctx context.Context, profile string, persistentAuth oauth.OAuthArgument) string { - executable := os.Args[0] +func buildLoginCommand(ctx context.Context, profile string, arg oauth.OAuthArgument) string { cmd := []string{ - executable, + "databricks", "auth", "login", } if profile != "" { cmd = append(cmd, "--profile", profile) } else { - cmd = append(cmd, "--host", persistentAuth.GetHost(ctx)) - if accountId := persistentAuth.GetAccountId(ctx); accountId != "" { - cmd = append(cmd, "--account-id", accountId) + switch arg := arg.(type) { + case oauth.AccountOAuthArgument: + cmd = append(cmd, "--host", arg.GetAccountHost(ctx), "--account-id", arg.GetAccountId(ctx)) + case oauth.WorkspaceOAuthArgument: + cmd = append(cmd, "--host", arg.GetWorkspaceHost(ctx)) } } return strings.Join(cmd, " ") @@ -44,7 +38,7 @@ func helpfulError(ctx context.Context, profile string, persistentAuth oauth.OAut return fmt.Sprintf("Try logging in again with `%s` before retrying. If this fails, please report this issue to the Databricks CLI maintainers at https://github.com/databricks/cli/issues/new", loginMsg) } -func newTokenCommand(oauthArgument oauth.OAuthArgument) *cobra.Command { +func newTokenCommand(authArguments *authArguments) *cobra.Command { cmd := &cobra.Command{ Use: "token [HOST]", Short: "Get authentication token", @@ -63,7 +57,7 @@ func newTokenCommand(oauthArgument oauth.OAuthArgument) *cobra.Command { } t, err := loadToken(ctx, loadTokenArgs{ - oauthArgument: oauthArgument, + authArguments: authArguments, profileName: profileName, args: args, tokenTimeout: tokenTimeout, @@ -85,7 +79,7 @@ func newTokenCommand(oauthArgument oauth.OAuthArgument) *cobra.Command { } type loadTokenArgs struct { - oauthArgument oauth.OAuthArgument + authArguments *authArguments profileName string args []string tokenTimeout time.Duration @@ -99,34 +93,29 @@ func loadToken(ctx context.Context, args loadTokenArgs) (*oauth2.Token, error) { return nil, errors.New("providing both a profile and host is not supported") } - oauthArgument, err := setHostAndAccountId(ctx, args.profiler, args.profileName, args.oauthArgument, args.args) + err := setHostAndAccountId(ctx, args.profiler, args.profileName, args.authArguments, args.args) if err != nil { return nil, err } ctx, cancel := context.WithTimeout(ctx, args.tokenTimeout) defer cancel() + oauthArgument, err := args.authArguments.toOAuthArgument() + if err != nil { + return nil, err + } persistentAuth, err := oauth.NewPersistentAuth(ctx) if err != nil { helpMsg := helpfulError(ctx, args.profileName, oauthArgument) return nil, fmt.Errorf("unexpected error creating persistent auth: %w. %s", err, helpMsg) } t, err := persistentAuth.Load(ctx, oauthArgument) - var httpErr *httpclient.HttpError - if errors.As(err, &httpErr) { + if err != nil { helpMsg := helpfulError(ctx, args.profileName, oauthArgument) - t := &tokenErrorResponse{} - err = json.Unmarshal([]byte(httpErr.Message), t) - if err != nil { - return nil, fmt.Errorf("unexpected parsing token response: %w. %s", err, helpMsg) - } - if t.ErrorDescription == "Refresh token is invalid" { - return nil, fmt.Errorf("a new access token could not be retrieved because the refresh token is invalid. To reauthenticate, run `%s`", buildLoginCommand(ctx, args.profileName, oauthArgument)) - } else { - return nil, fmt.Errorf("unexpected error refreshing token: %s. %s", t.ErrorDescription, helpMsg) + if errors.Is(err, &oauth.InvalidRefreshTokenError{}) { + return nil, err } - } else if err != nil { - return nil, fmt.Errorf("unexpected error refreshing token: %w. %s", err, helpfulError(ctx, args.profileName, oauthArgument)) + return nil, fmt.Errorf("unexpected error loading token: %w. %s", err, helpMsg) } return t, nil } diff --git a/cmd/auth/token_test.go b/cmd/auth/token_test.go index d62c18cf1f..804e55e582 100644 --- a/cmd/auth/token_test.go +++ b/cmd/auth/token_test.go @@ -2,13 +2,12 @@ package auth import ( "context" + "net/http" "testing" "time" "github.com/databricks/cli/libs/databrickscfg/profile" - "github.com/databricks/databricks-sdk-go/credentials/cache" "github.com/databricks/databricks-sdk-go/credentials/oauth" - "github.com/databricks/databricks-sdk-go/httpclient" "github.com/databricks/databricks-sdk-go/httpclient/fixtures" "github.com/stretchr/testify/assert" "golang.org/x/oauth2" @@ -48,6 +47,35 @@ var refreshSuccessTokenResponse = fixtures.HTTPFixture{ }, } +type MockApiClient struct { + RefreshTokenResponse http.RoundTripper +} + +// GetAccountOAuthEndpoints implements oauth.OAuthClient. +func (m *MockApiClient) GetAccountOAuthEndpoints(ctx context.Context, accountHost string, accountId string) (*oauth.OAuthAuthorizationServer, error) { + return &oauth.OAuthAuthorizationServer{ + TokenEndpoint: accountHost + "/token", + AuthorizationEndpoint: accountHost + "/authorize", + }, nil +} + +// GetHttpClient implements oauth.OAuthClient. +func (m *MockApiClient) GetHttpClient(context.Context) *http.Client { + return &http.Client{ + Transport: m.RefreshTokenResponse, + } +} + +// GetWorkspaceOAuthEndpoints implements oauth.OAuthClient. +func (m *MockApiClient) GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*oauth.OAuthAuthorizationServer, error) { + return &oauth.OAuthAuthorizationServer{ + TokenEndpoint: workspaceHost + "/token", + AuthorizationEndpoint: workspaceHost + "/authorize", + }, nil +} + +var _ oauth.OAuthClient = (*MockApiClient)(nil) + func TestToken_loadToken(t *testing.T) { profiler := profile.InMemoryProfiler{ Profiles: profile.Profiles{ @@ -63,7 +91,7 @@ func TestToken_loadToken(t *testing.T) { }, }, } - tokenCache := &cache.InMemoryTokenCache{ + tokenCache := &InMemoryTokenCache{ Tokens: map[string]*oauth2.Token{ "https://accounts.cloud.databricks.com/oidc/accounts/expired": { RefreshToken: "expired", @@ -74,10 +102,10 @@ func TestToken_loadToken(t *testing.T) { }, }, } - makeApiClient := func(f fixtures.HTTPFixture) *httpclient.ApiClient { - return httpclient.NewApiClient(httpclient.ClientConfig{ - Transport: fixtures.SliceTransport{f}, - }) + makeApiClient := func(f fixtures.HTTPFixture) *MockApiClient { + return &MockApiClient{ + RefreshTokenResponse: fixtures.SliceTransport{f}, + } } wantErrors := func(substrings ...string) func(error) { return func(err error) { @@ -100,14 +128,14 @@ func TestToken_loadToken(t *testing.T) { { name: "prints helpful login message on refresh failure when profile is specified", args: loadTokenArgs{ - oauthArgument: oauth.BasicOAuthArgument{}, + authArguments: &authArguments{}, profileName: "expired", args: []string{}, tokenTimeout: 1 * time.Hour, profiler: profiler, persistentAuthOpts: []oauth.PersistentAuthOption{ oauth.WithTokenCache(tokenCache), - oauth.WithApiClient(makeApiClient(refreshFailureTokenResponse)), + oauth.WithOAuthClient(makeApiClient(refreshFailureTokenResponse)), }, }, wantErr: wantErrors( @@ -118,9 +146,9 @@ func TestToken_loadToken(t *testing.T) { { name: "prints helpful login message on refresh failure when host is specified", args: loadTokenArgs{ - oauthArgument: oauth.BasicOAuthArgument{ - Host: "https://accounts.cloud.databricks.com", - AccountID: "expired", + authArguments: &authArguments{ + host: "https://accounts.cloud.databricks.com", + accountId: "expired", }, profileName: "", args: []string{}, @@ -128,7 +156,7 @@ func TestToken_loadToken(t *testing.T) { profiler: profiler, persistentAuthOpts: []oauth.PersistentAuthOption{ oauth.WithTokenCache(tokenCache), - oauth.WithApiClient(makeApiClient(refreshFailureTokenResponse)), + oauth.WithOAuthClient(makeApiClient(refreshFailureTokenResponse)), }, }, wantErr: wantErrors( @@ -139,14 +167,14 @@ func TestToken_loadToken(t *testing.T) { { name: "prints helpful login message on invalid response", args: loadTokenArgs{ - oauthArgument: oauth.BasicOAuthArgument{}, + authArguments: &authArguments{}, profileName: "active", args: []string{}, tokenTimeout: 1 * time.Hour, profiler: profiler, persistentAuthOpts: []oauth.PersistentAuthOption{ oauth.WithTokenCache(tokenCache), - oauth.WithApiClient(makeApiClient(refreshFailureInvalidResponse)), + oauth.WithOAuthClient(makeApiClient(refreshFailureInvalidResponse)), }, }, wantErr: wantErrors( @@ -157,14 +185,14 @@ func TestToken_loadToken(t *testing.T) { { name: "prints helpful login message on other error response", args: loadTokenArgs{ - oauthArgument: oauth.BasicOAuthArgument{}, + authArguments: &authArguments{}, profileName: "active", args: []string{}, tokenTimeout: 1 * time.Hour, profiler: profiler, persistentAuthOpts: []oauth.PersistentAuthOption{ oauth.WithTokenCache(tokenCache), - oauth.WithApiClient(makeApiClient(refreshFailureOtherError)), + oauth.WithOAuthClient(makeApiClient(refreshFailureOtherError)), }, }, wantErr: wantErrors( @@ -175,14 +203,14 @@ func TestToken_loadToken(t *testing.T) { { name: "succeeds with profile", args: loadTokenArgs{ - oauthArgument: oauth.BasicOAuthArgument{}, + authArguments: &authArguments{}, profileName: "active", args: []string{}, tokenTimeout: 1 * time.Hour, profiler: profiler, persistentAuthOpts: []oauth.PersistentAuthOption{ oauth.WithTokenCache(tokenCache), - oauth.WithApiClient(makeApiClient(refreshSuccessTokenResponse)), + oauth.WithOAuthClient(makeApiClient(refreshSuccessTokenResponse)), }, }, want: validateToken, @@ -190,14 +218,14 @@ func TestToken_loadToken(t *testing.T) { { name: "succeeds with host", args: loadTokenArgs{ - oauthArgument: oauth.BasicOAuthArgument{Host: "https://accounts.cloud.databricks.com", AccountID: "active"}, + authArguments: &authArguments{host: "https://accounts.cloud.databricks.com", accountId: "active"}, profileName: "", args: []string{}, tokenTimeout: 1 * time.Hour, profiler: profiler, persistentAuthOpts: []oauth.PersistentAuthOption{ oauth.WithTokenCache(tokenCache), - oauth.WithApiClient(makeApiClient(refreshSuccessTokenResponse)), + oauth.WithOAuthClient(makeApiClient(refreshSuccessTokenResponse)), }, }, want: validateToken, diff --git a/go.mod b/go.mod index baebd80a1a..cf0897c92c 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ toolchain go1.23.4 require ( github.com/Masterminds/semver/v3 v3.3.1 // MIT github.com/briandowns/spinner v1.23.1 // Apache 2.0 - github.com/databricks/databricks-sdk-go v0.54.1-0.20250103133740-0688c3b8afac // Apache 2.0 + github.com/databricks/databricks-sdk-go v0.54.1-0.20250106162806-bd7303e3df79 // Apache 2.0 github.com/fatih/color v1.18.0 // MIT github.com/google/uuid v1.6.0 // BSD-3-Clause github.com/hashicorp/go-version v1.7.0 // MPL 2.0 diff --git a/go.sum b/go.sum index 186fd1caa4..0b39199dce 100644 --- a/go.sum +++ b/go.sum @@ -36,6 +36,10 @@ github.com/cyphar/filepath-securejoin v0.2.4 h1:Ugdm7cg7i6ZK6x3xDF1oEu1nfkyfH53E github.com/cyphar/filepath-securejoin v0.2.4/go.mod h1:aPGpWjXOXUn2NCNjFvBE6aRxGGx79pTxQpKOJNYHHl4= github.com/databricks/databricks-sdk-go v0.54.1-0.20250103133740-0688c3b8afac h1:HzS3/zoFUu6bWPEBaeeidfXIjg6Vu8qDphwr8Rg5EVU= github.com/databricks/databricks-sdk-go v0.54.1-0.20250103133740-0688c3b8afac/go.mod h1:IMSyEl8eEBwXQNljpYxs5cp31i8nyAx6KPj8U81X5Zo= +github.com/databricks/databricks-sdk-go v0.54.1-0.20250106160146-9b5913c7af7a h1:Fzw+C/uhqPgqsndOdbhFYm00uggxy+dmyX9dILU1USQ= +github.com/databricks/databricks-sdk-go v0.54.1-0.20250106160146-9b5913c7af7a/go.mod h1:IMSyEl8eEBwXQNljpYxs5cp31i8nyAx6KPj8U81X5Zo= +github.com/databricks/databricks-sdk-go v0.54.1-0.20250106162806-bd7303e3df79 h1:8MEDmBCvAMq7A6APGBu/pBC5xxVetr+7MZPYq55upSg= +github.com/databricks/databricks-sdk-go v0.54.1-0.20250106162806-bd7303e3df79/go.mod h1:IMSyEl8eEBwXQNljpYxs5cp31i8nyAx6KPj8U81X5Zo= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= From 1149df4287fe2eaa77e8355bae5e5d08d4732b43 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Tue, 7 Jan 2025 11:48:18 +0100 Subject: [PATCH 03/15] work --- cmd/auth/auth.go | 25 +++--------------- cmd/auth/login.go | 37 +++++++++++++------------- cmd/auth/login_test.go | 43 +++++++++++++++--------------- cmd/auth/token.go | 33 ++++++----------------- cmd/auth/token_test.go | 17 ++++++------ cmd/root/auth.go | 9 +++++-- go.mod | 2 +- go.sum | 2 ++ libs/auth/error.go | 59 ++++++++++++++++++++++++++++++++++++++++++ 9 files changed, 131 insertions(+), 96 deletions(-) create mode 100644 libs/auth/error.go diff --git a/cmd/auth/auth.go b/cmd/auth/auth.go index 92320d9b48..1b4125657c 100644 --- a/cmd/auth/auth.go +++ b/cmd/auth/auth.go @@ -4,28 +4,11 @@ import ( "context" "fmt" + "github.com/databricks/cli/libs/auth" "github.com/databricks/cli/libs/cmdio" - "github.com/databricks/databricks-sdk-go/config" - "github.com/databricks/databricks-sdk-go/credentials/oauth" "github.com/spf13/cobra" ) -type authArguments struct { - host string - accountId string -} - -func (a authArguments) toOAuthArgument() (oauth.OAuthArgument, error) { - cfg := &config.Config{ - Host: a.host, - AccountID: a.accountId, - } - if cfg.IsAccountClient() { - return oauth.NewBasicAccountOAuthArgument(cfg.Host, cfg.AccountID) - } - return oauth.NewBasicWorkspaceOAuthArgument(cfg.Host) -} - func New() *cobra.Command { cmd := &cobra.Command{ Use: "auth", @@ -39,9 +22,9 @@ Azure: https://learn.microsoft.com/azure/databricks/dev-tools/auth GCP: https://docs.gcp.databricks.com/dev-tools/auth/index.html`, } - var authArguments authArguments - cmd.PersistentFlags().StringVar(&authArguments.host, "host", "", "Databricks Host") - cmd.PersistentFlags().StringVar(&authArguments.accountId, "account-id", "", "Databricks Account ID") + var authArguments auth.AuthArguments + cmd.PersistentFlags().StringVar(&authArguments.Host, "host", "", "Databricks Host") + cmd.PersistentFlags().StringVar(&authArguments.AccountId, "account-id", "", "Databricks Account ID") cmd.AddCommand(newEnvCommand()) cmd.AddCommand(newLoginCommand(&authArguments)) diff --git a/cmd/auth/login.go b/cmd/auth/login.go index cc3a398a8b..0cb4123527 100644 --- a/cmd/auth/login.go +++ b/cmd/auth/login.go @@ -8,6 +8,7 @@ import ( "strings" "time" + "github.com/databricks/cli/libs/auth" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/databrickscfg" "github.com/databricks/cli/libs/databrickscfg/cfgpickers" @@ -18,7 +19,7 @@ import ( "github.com/spf13/cobra" ) -func promptForProfile(ctx context.Context, authArguments *authArguments) (string, error) { +func promptForProfile(ctx context.Context, authArguments *auth.AuthArguments) (string, error) { if !cmdio.IsInTTY(ctx) { return "", nil } @@ -35,7 +36,7 @@ const ( defaultTimeout = 1 * time.Hour ) -func newLoginCommand(authArguments *authArguments) *cobra.Command { +func newLoginCommand(authArguments *auth.AuthArguments) *cobra.Command { defaultConfigPath := "~/.databrickscfg" if runtime.GOOS == "windows" { defaultConfigPath = "%USERPROFILE%\\.databrickscfg" @@ -119,15 +120,15 @@ depends on the existing profiles you have set in your configuration file // We need the config without the profile before it's used to initialise new workspace client below. // Otherwise it will complain about non existing profile because it was not yet saved. cfg := config.Config{ - Host: authArguments.host, - AccountID: authArguments.accountId, + Host: authArguments.Host, + AccountID: authArguments.AccountId, AuthType: "databricks-cli", } ctx, cancel := context.WithTimeout(ctx, loginTimeout) defer cancel() - oauthArgument, err := authArguments.toOAuthArgument() + oauthArgument, err := authArguments.ToOAuthArgument() if err != nil { return err } @@ -181,9 +182,9 @@ depends on the existing profiles you have set in your configuration file // 1. --account-id flag. // 2. account-id from the specified profile, if available. // 3. Prompt the user for the account-id. -func setHostAndAccountId(ctx context.Context, profiler profile.Profiler, profileName string, authArguments *authArguments, args []string) error { +func setHostAndAccountId(ctx context.Context, profiler profile.Profiler, profileName string, authArguments *auth.AuthArguments, args []string) error { // If both [HOST] and --host are provided, return an error. - host := authArguments.host + host := authArguments.Host if len(args) > 0 && host != "" { return fmt.Errorf("please only provide a host as an argument or a flag, not both") } @@ -198,10 +199,10 @@ func setHostAndAccountId(ctx context.Context, profiler profile.Profiler, profile if host == "" { if len(args) > 0 { // If [HOST] is provided, set the host to the provided positional argument. - authArguments.host = args[0] + authArguments.Host = args[0] } else if len(profiles) > 0 && profiles[0].Host != "" { // If neither [HOST] nor --host are provided, and the profile has a host, use it. - authArguments.host = profiles[0].Host + authArguments.Host = profiles[0].Host } else { // If neither [HOST] nor --host are provided, and the profile does not have a host, // then prompt the user for a host. @@ -209,17 +210,17 @@ func setHostAndAccountId(ctx context.Context, profiler profile.Profiler, profile if err != nil { return err } - authArguments.host = hostName + authArguments.Host = hostName } } // If the account-id was not provided as a cmd line flag, try to read it from // the specified profile. - isAccountClient := (&config.Config{Host: authArguments.host}).IsAccountClient() - accountID := authArguments.accountId + isAccountClient := (&config.Config{Host: authArguments.Host}).IsAccountClient() + accountID := authArguments.AccountId if isAccountClient && accountID == "" { if len(profiles) > 0 && profiles[0].AccountID != "" { - authArguments.accountId = profiles[0].AccountID + authArguments.AccountId = profiles[0].AccountID } else { // Prompt user for the account-id if it we could not get it from a // profile. @@ -227,17 +228,17 @@ func setHostAndAccountId(ctx context.Context, profiler profile.Profiler, profile if err != nil { return err } - authArguments.accountId = accountId + authArguments.AccountId = accountId } } return nil } -func getProfileName(authArguments *authArguments) string { - if authArguments.accountId != "" { - return fmt.Sprintf("ACCOUNT-%s", authArguments.accountId) +func getProfileName(authArguments *auth.AuthArguments) string { + if authArguments.AccountId != "" { + return fmt.Sprintf("ACCOUNT-%s", authArguments.AccountId) } - host := strings.TrimPrefix(authArguments.host, "https://") + host := strings.TrimPrefix(authArguments.Host, "https://") split := strings.Split(host, ".") return split[0] } diff --git a/cmd/auth/login_test.go b/cmd/auth/login_test.go index b0a4f9c822..0b02b587b6 100644 --- a/cmd/auth/login_test.go +++ b/cmd/auth/login_test.go @@ -4,6 +4,7 @@ import ( "context" "testing" + "github.com/databricks/cli/libs/auth" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/databrickscfg/profile" "github.com/databricks/cli/libs/env" @@ -14,72 +15,72 @@ import ( func TestSetHostDoesNotFailWithNoDatabrickscfg(t *testing.T) { ctx := context.Background() ctx = env.Set(ctx, "DATABRICKS_CONFIG_FILE", "./imaginary-file/databrickscfg") - err := setHostAndAccountId(ctx, profile.DefaultProfiler, "foo", &authArguments{host: "test"}, []string{}) + err := setHostAndAccountId(ctx, profile.DefaultProfiler, "foo", &auth.AuthArguments{Host: "test"}, []string{}) assert.NoError(t, err) } func TestSetHost(t *testing.T) { - authArguments := authArguments{} + authArguments := auth.AuthArguments{} t.Setenv("DATABRICKS_CONFIG_FILE", "./testdata/.databrickscfg") ctx, _ := cmdio.SetupTest(context.Background()) // Test error when both flag and argument are provided - authArguments.host = "val from --host" + authArguments.Host = "val from --host" err := setHostAndAccountId(ctx, profile.DefaultProfiler, "profile-1", &authArguments, []string{"val from [HOST]"}) assert.EqualError(t, err, "please only provide a host as an argument or a flag, not both") // Test setting host from flag - authArguments.host = "val from --host" + authArguments.Host = "val from --host" err = setHostAndAccountId(ctx, profile.DefaultProfiler, "profile-1", &authArguments, []string{}) assert.NoError(t, err) - assert.Equal(t, "val from --host", authArguments.host) + assert.Equal(t, "val from --host", authArguments.Host) // Test setting host from argument - authArguments.host = "" + authArguments.Host = "" err = setHostAndAccountId(ctx, profile.DefaultProfiler, "profile-1", &authArguments, []string{"val from [HOST]"}) assert.NoError(t, err) - assert.Equal(t, "val from [HOST]", authArguments.host) + assert.Equal(t, "val from [HOST]", authArguments.Host) // Test setting host from profile - authArguments.host = "" + authArguments.Host = "" err = setHostAndAccountId(ctx, profile.DefaultProfiler, "profile-1", &authArguments, []string{}) assert.NoError(t, err) - assert.Equal(t, "https://www.host1.com", authArguments.host) + assert.Equal(t, "https://www.Host1.com", authArguments.Host) // Test setting host from profile - authArguments.host = "" + authArguments.Host = "" err = setHostAndAccountId(ctx, profile.DefaultProfiler, "profile-2", &authArguments, []string{}) assert.NoError(t, err) - assert.Equal(t, "https://www.host2.com", authArguments.host) + assert.Equal(t, "https://www.Host2.com", authArguments.Host) // Test host is not set. Should prompt. - authArguments.host = "" + authArguments.Host = "" err = setHostAndAccountId(ctx, profile.DefaultProfiler, "", &authArguments, []string{}) assert.EqualError(t, err, "the command is being run in a non-interactive environment, please specify a host using --host") } func TestSetAccountId(t *testing.T) { - var authArguments authArguments + var authArguments auth.AuthArguments t.Setenv("DATABRICKS_CONFIG_FILE", "./testdata/.databrickscfg") ctx, _ := cmdio.SetupTest(context.Background()) // Test setting account-id from flag - authArguments.accountId = "val from --account-id" + authArguments.AccountId = "val from --account-id" err := setHostAndAccountId(ctx, profile.DefaultProfiler, "account-profile", &authArguments, []string{}) assert.NoError(t, err) - assert.Equal(t, "https://accounts.cloud.databricks.com", authArguments.host) - assert.Equal(t, "val from --account-id", authArguments.accountId) + assert.Equal(t, "https://accounts.cloud.databricks.com", authArguments.Host) + assert.Equal(t, "val from --account-id", authArguments.AccountId) // Test setting account_id from profile - authArguments.accountId = "" + authArguments.AccountId = "" err = setHostAndAccountId(ctx, profile.DefaultProfiler, "account-profile", &authArguments, []string{}) require.NoError(t, err) - assert.Equal(t, "https://accounts.cloud.databricks.com", authArguments.host) - assert.Equal(t, "id-from-profile", authArguments.accountId) + assert.Equal(t, "https://accounts.cloud.databricks.com", authArguments.Host) + assert.Equal(t, "id-from-profile", authArguments.AccountId) // Neither flag nor profile account-id is set, should prompt - authArguments.accountId = "" - authArguments.host = "https://accounts.cloud.databricks.com" + authArguments.AccountId = "" + authArguments.Host = "https://accounts.cloud.databricks.com" err = setHostAndAccountId(ctx, profile.DefaultProfiler, "", &authArguments, []string{}) assert.EqualError(t, err, "the command is being run in a non-interactive environment, please specify an account ID using --account-id") } diff --git a/cmd/auth/token.go b/cmd/auth/token.go index 2a3d289367..fd3af2b8ac 100644 --- a/cmd/auth/token.go +++ b/cmd/auth/token.go @@ -5,40 +5,21 @@ import ( "encoding/json" "errors" "fmt" - "strings" "time" + "github.com/databricks/cli/libs/auth" "github.com/databricks/cli/libs/databrickscfg/profile" "github.com/databricks/databricks-sdk-go/credentials/oauth" "github.com/spf13/cobra" "golang.org/x/oauth2" ) -func buildLoginCommand(ctx context.Context, profile string, arg oauth.OAuthArgument) string { - cmd := []string{ - "databricks", - "auth", - "login", - } - if profile != "" { - cmd = append(cmd, "--profile", profile) - } else { - switch arg := arg.(type) { - case oauth.AccountOAuthArgument: - cmd = append(cmd, "--host", arg.GetAccountHost(ctx), "--account-id", arg.GetAccountId(ctx)) - case oauth.WorkspaceOAuthArgument: - cmd = append(cmd, "--host", arg.GetWorkspaceHost(ctx)) - } - } - return strings.Join(cmd, " ") -} - func helpfulError(ctx context.Context, profile string, persistentAuth oauth.OAuthArgument) string { - loginMsg := buildLoginCommand(ctx, profile, persistentAuth) + loginMsg := auth.BuildLoginCommand(ctx, profile, persistentAuth) return fmt.Sprintf("Try logging in again with `%s` before retrying. If this fails, please report this issue to the Databricks CLI maintainers at https://github.com/databricks/cli/issues/new", loginMsg) } -func newTokenCommand(authArguments *authArguments) *cobra.Command { +func newTokenCommand(authArguments *auth.AuthArguments) *cobra.Command { cmd := &cobra.Command{ Use: "token [HOST]", Short: "Get authentication token", @@ -79,7 +60,7 @@ func newTokenCommand(authArguments *authArguments) *cobra.Command { } type loadTokenArgs struct { - authArguments *authArguments + authArguments *auth.AuthArguments profileName string args []string tokenTimeout time.Duration @@ -100,7 +81,7 @@ func loadToken(ctx context.Context, args loadTokenArgs) (*oauth2.Token, error) { ctx, cancel := context.WithTimeout(ctx, args.tokenTimeout) defer cancel() - oauthArgument, err := args.authArguments.toOAuthArgument() + oauthArgument, err := args.authArguments.ToOAuthArgument() if err != nil { return nil, err } @@ -113,7 +94,9 @@ func loadToken(ctx context.Context, args loadTokenArgs) (*oauth2.Token, error) { if err != nil { helpMsg := helpfulError(ctx, args.profileName, oauthArgument) if errors.Is(err, &oauth.InvalidRefreshTokenError{}) { - return nil, err + msg := "a new access token could not be retrieved because the refresh token is invalid." + msg += fmt.Sprintf(" To reauthenticate, run `%s`", auth.BuildLoginCommand(ctx, args.profileName, oauthArgument)) + return nil, errors.New(msg) } return nil, fmt.Errorf("unexpected error loading token: %w. %s", err, helpMsg) } diff --git a/cmd/auth/token_test.go b/cmd/auth/token_test.go index 804e55e582..63d5f3b51b 100644 --- a/cmd/auth/token_test.go +++ b/cmd/auth/token_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + "github.com/databricks/cli/libs/auth" "github.com/databricks/cli/libs/databrickscfg/profile" "github.com/databricks/databricks-sdk-go/credentials/oauth" "github.com/databricks/databricks-sdk-go/httpclient/fixtures" @@ -128,7 +129,7 @@ func TestToken_loadToken(t *testing.T) { { name: "prints helpful login message on refresh failure when profile is specified", args: loadTokenArgs{ - authArguments: &authArguments{}, + authArguments: &auth.AuthArguments{}, profileName: "expired", args: []string{}, tokenTimeout: 1 * time.Hour, @@ -146,9 +147,9 @@ func TestToken_loadToken(t *testing.T) { { name: "prints helpful login message on refresh failure when host is specified", args: loadTokenArgs{ - authArguments: &authArguments{ - host: "https://accounts.cloud.databricks.com", - accountId: "expired", + authArguments: &auth.AuthArguments{ + Host: "https://accounts.cloud.databricks.com", + AccountId: "expired", }, profileName: "", args: []string{}, @@ -167,7 +168,7 @@ func TestToken_loadToken(t *testing.T) { { name: "prints helpful login message on invalid response", args: loadTokenArgs{ - authArguments: &authArguments{}, + authArguments: &auth.AuthArguments{}, profileName: "active", args: []string{}, tokenTimeout: 1 * time.Hour, @@ -185,7 +186,7 @@ func TestToken_loadToken(t *testing.T) { { name: "prints helpful login message on other error response", args: loadTokenArgs{ - authArguments: &authArguments{}, + authArguments: &auth.AuthArguments{}, profileName: "active", args: []string{}, tokenTimeout: 1 * time.Hour, @@ -203,7 +204,7 @@ func TestToken_loadToken(t *testing.T) { { name: "succeeds with profile", args: loadTokenArgs{ - authArguments: &authArguments{}, + authArguments: &auth.AuthArguments{}, profileName: "active", args: []string{}, tokenTimeout: 1 * time.Hour, @@ -218,7 +219,7 @@ func TestToken_loadToken(t *testing.T) { { name: "succeeds with host", args: loadTokenArgs{ - authArguments: &authArguments{host: "https://accounts.cloud.databricks.com", accountId: "active"}, + authArguments: &auth.AuthArguments{Host: "https://accounts.cloud.databricks.com", AccountId: "active"}, profileName: "", args: []string{}, tokenTimeout: 1 * time.Hour, diff --git a/cmd/root/auth.go b/cmd/root/auth.go index 07ab483990..4dc78926b0 100644 --- a/cmd/root/auth.go +++ b/cmd/root/auth.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" + "github.com/databricks/cli/libs/auth" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/databrickscfg/profile" "github.com/databricks/databricks-sdk-go" @@ -143,7 +144,7 @@ func MustAccountClient(cmd *cobra.Command, args []string) error { allowPrompt := !hasProfileFlag && !shouldSkipPrompt(cmd.Context()) a, err := accountClientOrPrompt(cmd.Context(), cfg, allowPrompt) if err != nil { - return err + return renderError(ctx, cfg, err) } ctx = context.WithValue(ctx, &accountClient, a) @@ -220,7 +221,7 @@ func MustWorkspaceClient(cmd *cobra.Command, args []string) error { allowPrompt := !hasProfileFlag && !shouldSkipPrompt(cmd.Context()) w, err := workspaceClientOrPrompt(cmd.Context(), cfg, allowPrompt) if err != nil { - return err + return renderError(ctx, cfg, err) } ctx = context.WithValue(ctx, &workspaceClient, w) @@ -338,3 +339,7 @@ func ConfigUsed(ctx context.Context) *config.Config { } return cfg } + +func renderError(ctx context.Context, cfg *config.Config, err error) error { + return auth.RewriteAuthError(ctx, cfg, err) +} diff --git a/go.mod b/go.mod index cf0897c92c..11e1b59163 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ toolchain go1.23.4 require ( github.com/Masterminds/semver/v3 v3.3.1 // MIT github.com/briandowns/spinner v1.23.1 // Apache 2.0 - github.com/databricks/databricks-sdk-go v0.54.1-0.20250106162806-bd7303e3df79 // Apache 2.0 + github.com/databricks/databricks-sdk-go v0.54.1-0.20250107101710-323fb61606d1 // Apache 2.0 github.com/fatih/color v1.18.0 // MIT github.com/google/uuid v1.6.0 // BSD-3-Clause github.com/hashicorp/go-version v1.7.0 // MPL 2.0 diff --git a/go.sum b/go.sum index 0b39199dce..ac7d7096da 100644 --- a/go.sum +++ b/go.sum @@ -40,6 +40,8 @@ github.com/databricks/databricks-sdk-go v0.54.1-0.20250106160146-9b5913c7af7a h1 github.com/databricks/databricks-sdk-go v0.54.1-0.20250106160146-9b5913c7af7a/go.mod h1:IMSyEl8eEBwXQNljpYxs5cp31i8nyAx6KPj8U81X5Zo= github.com/databricks/databricks-sdk-go v0.54.1-0.20250106162806-bd7303e3df79 h1:8MEDmBCvAMq7A6APGBu/pBC5xxVetr+7MZPYq55upSg= github.com/databricks/databricks-sdk-go v0.54.1-0.20250106162806-bd7303e3df79/go.mod h1:IMSyEl8eEBwXQNljpYxs5cp31i8nyAx6KPj8U81X5Zo= +github.com/databricks/databricks-sdk-go v0.54.1-0.20250107101710-323fb61606d1 h1:XxAliYM4PNoiTp9hEr82G6zGgTSttTkIXF8asJQu64M= +github.com/databricks/databricks-sdk-go v0.54.1-0.20250107101710-323fb61606d1/go.mod h1:IMSyEl8eEBwXQNljpYxs5cp31i8nyAx6KPj8U81X5Zo= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/libs/auth/error.go b/libs/auth/error.go new file mode 100644 index 0000000000..d659a345be --- /dev/null +++ b/libs/auth/error.go @@ -0,0 +1,59 @@ +package auth + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/databricks/databricks-sdk-go/config" + "github.com/databricks/databricks-sdk-go/credentials/oauth" +) + +func RewriteAuthError(ctx context.Context, cfg *config.Config, err error) error { + if errors.Is(err, &oauth.InvalidRefreshTokenError{}) { + oauthArgument, err := AuthArguments{cfg.Host, cfg.AccountID}.ToOAuthArgument() + if err != nil { + return err + } + msg := "a new access token could not be retrieved because the refresh token is invalid." + msg += fmt.Sprintf(" To reauthenticate, run `%s`", BuildLoginCommand(ctx, cfg.Profile, oauthArgument)) + return errors.New(msg) + } + return err +} + +func BuildLoginCommand(ctx context.Context, profile string, arg oauth.OAuthArgument) string { + cmd := []string{ + "databricks", + "auth", + "login", + } + if profile != "" { + cmd = append(cmd, "--profile", profile) + } else { + switch arg := arg.(type) { + case oauth.AccountOAuthArgument: + cmd = append(cmd, "--host", arg.GetAccountHost(ctx), "--account-id", arg.GetAccountId(ctx)) + case oauth.WorkspaceOAuthArgument: + cmd = append(cmd, "--host", arg.GetWorkspaceHost(ctx)) + } + } + return strings.Join(cmd, " ") +} + +type AuthArguments struct { + Host string + AccountId string +} + +func (a AuthArguments) ToOAuthArgument() (oauth.OAuthArgument, error) { + cfg := &config.Config{ + Host: a.Host, + AccountID: a.AccountId, + } + if cfg.IsAccountClient() { + return oauth.NewBasicAccountOAuthArgument(cfg.Host, cfg.AccountID) + } + return oauth.NewBasicWorkspaceOAuthArgument(cfg.Host) +} From 7205ab2cbc3e4264bb8b9e4deda68dcaa3579753 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Tue, 7 Jan 2025 13:57:48 +0100 Subject: [PATCH 04/15] small fix --- cmd/auth/login.go | 2 +- go.mod | 2 +- go.sum | 2 ++ 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/cmd/auth/login.go b/cmd/auth/login.go index 0cb4123527..3952942d3d 100644 --- a/cmd/auth/login.go +++ b/cmd/auth/login.go @@ -132,7 +132,7 @@ depends on the existing profiles you have set in your configuration file if err != nil { return err } - if err = persistentAuth.Challenge(ctx, oauthArgument); err != nil { + if _, err = persistentAuth.Challenge(ctx, oauthArgument); err != nil { return err } diff --git a/go.mod b/go.mod index 11e1b59163..ee891cdd0a 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ toolchain go1.23.4 require ( github.com/Masterminds/semver/v3 v3.3.1 // MIT github.com/briandowns/spinner v1.23.1 // Apache 2.0 - github.com/databricks/databricks-sdk-go v0.54.1-0.20250107101710-323fb61606d1 // Apache 2.0 + github.com/databricks/databricks-sdk-go v0.54.1-0.20250107125441-f53fb8492c42 // Apache 2.0 github.com/fatih/color v1.18.0 // MIT github.com/google/uuid v1.6.0 // BSD-3-Clause github.com/hashicorp/go-version v1.7.0 // MPL 2.0 diff --git a/go.sum b/go.sum index ac7d7096da..0e716ab6fb 100644 --- a/go.sum +++ b/go.sum @@ -42,6 +42,8 @@ github.com/databricks/databricks-sdk-go v0.54.1-0.20250106162806-bd7303e3df79 h1 github.com/databricks/databricks-sdk-go v0.54.1-0.20250106162806-bd7303e3df79/go.mod h1:IMSyEl8eEBwXQNljpYxs5cp31i8nyAx6KPj8U81X5Zo= github.com/databricks/databricks-sdk-go v0.54.1-0.20250107101710-323fb61606d1 h1:XxAliYM4PNoiTp9hEr82G6zGgTSttTkIXF8asJQu64M= github.com/databricks/databricks-sdk-go v0.54.1-0.20250107101710-323fb61606d1/go.mod h1:IMSyEl8eEBwXQNljpYxs5cp31i8nyAx6KPj8U81X5Zo= +github.com/databricks/databricks-sdk-go v0.54.1-0.20250107125441-f53fb8492c42 h1:kj6GEDBiYRWRDoCm0SW0N7qagmS9JuG5rgDUeBMuR2M= +github.com/databricks/databricks-sdk-go v0.54.1-0.20250107125441-f53fb8492c42/go.mod h1:IMSyEl8eEBwXQNljpYxs5cp31i8nyAx6KPj8U81X5Zo= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= From 22e02e25c228ce0afd11c017b2077b9af64a14fd Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Tue, 7 Jan 2025 15:53:00 +0100 Subject: [PATCH 05/15] fix --- go.mod | 2 +- go.sum | 2 ++ libs/auth/error.go | 3 ++- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/go.mod b/go.mod index ee891cdd0a..706ade9cff 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ toolchain go1.23.4 require ( github.com/Masterminds/semver/v3 v3.3.1 // MIT github.com/briandowns/spinner v1.23.1 // Apache 2.0 - github.com/databricks/databricks-sdk-go v0.54.1-0.20250107125441-f53fb8492c42 // Apache 2.0 + github.com/databricks/databricks-sdk-go v0.54.1-0.20250107144233-ef8c3f356afc // Apache 2.0 github.com/fatih/color v1.18.0 // MIT github.com/google/uuid v1.6.0 // BSD-3-Clause github.com/hashicorp/go-version v1.7.0 // MPL 2.0 diff --git a/go.sum b/go.sum index 0e716ab6fb..76d74b5044 100644 --- a/go.sum +++ b/go.sum @@ -44,6 +44,8 @@ github.com/databricks/databricks-sdk-go v0.54.1-0.20250107101710-323fb61606d1 h1 github.com/databricks/databricks-sdk-go v0.54.1-0.20250107101710-323fb61606d1/go.mod h1:IMSyEl8eEBwXQNljpYxs5cp31i8nyAx6KPj8U81X5Zo= github.com/databricks/databricks-sdk-go v0.54.1-0.20250107125441-f53fb8492c42 h1:kj6GEDBiYRWRDoCm0SW0N7qagmS9JuG5rgDUeBMuR2M= github.com/databricks/databricks-sdk-go v0.54.1-0.20250107125441-f53fb8492c42/go.mod h1:IMSyEl8eEBwXQNljpYxs5cp31i8nyAx6KPj8U81X5Zo= +github.com/databricks/databricks-sdk-go v0.54.1-0.20250107144233-ef8c3f356afc h1:fVNcDO+hDt0oc9jZ1oFiF+ycjdBmp+hXdt699qHU+hQ= +github.com/databricks/databricks-sdk-go v0.54.1-0.20250107144233-ef8c3f356afc/go.mod h1:IMSyEl8eEBwXQNljpYxs5cp31i8nyAx6KPj8U81X5Zo= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/libs/auth/error.go b/libs/auth/error.go index d659a345be..4952d1309c 100644 --- a/libs/auth/error.go +++ b/libs/auth/error.go @@ -11,7 +11,8 @@ import ( ) func RewriteAuthError(ctx context.Context, cfg *config.Config, err error) error { - if errors.Is(err, &oauth.InvalidRefreshTokenError{}) { + target := &oauth.InvalidRefreshTokenError{} + if errors.As(err, &target) { oauthArgument, err := AuthArguments{cfg.Host, cfg.AccountID}.ToOAuthArgument() if err != nil { return err From c9a2c130412c00a294443e40d98de3b44e6550b7 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Tue, 7 Jan 2025 18:03:49 +0100 Subject: [PATCH 06/15] lint --- cmd/auth/login.go | 2 +- cmd/auth/token_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/auth/login.go b/cmd/auth/login.go index 44fcd81cfa..407a56abda 100644 --- a/cmd/auth/login.go +++ b/cmd/auth/login.go @@ -236,7 +236,7 @@ func setHostAndAccountId(ctx context.Context, profiler profile.Profiler, profile func getProfileName(authArguments *auth.AuthArguments) string { if authArguments.AccountId != "" { - return fmt.Sprintf("ACCOUNT-%s", authArguments.AccountId) + return "ACCOUNT-" + authArguments.AccountId } host := strings.TrimPrefix(authArguments.Host, "https://") split := strings.Split(host, ".") diff --git a/cmd/auth/token_test.go b/cmd/auth/token_test.go index 63d5f3b51b..a90d602be7 100644 --- a/cmd/auth/token_test.go +++ b/cmd/auth/token_test.go @@ -53,7 +53,7 @@ type MockApiClient struct { } // GetAccountOAuthEndpoints implements oauth.OAuthClient. -func (m *MockApiClient) GetAccountOAuthEndpoints(ctx context.Context, accountHost string, accountId string) (*oauth.OAuthAuthorizationServer, error) { +func (m *MockApiClient) GetAccountOAuthEndpoints(ctx context.Context, accountHost, accountId string) (*oauth.OAuthAuthorizationServer, error) { return &oauth.OAuthAuthorizationServer{ TokenEndpoint: accountHost + "/token", AuthorizationEndpoint: accountHost + "/authorize", From b4c5625adbef1493b99e4799ee79d7a0f3756fbd Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Tue, 7 Jan 2025 18:08:34 +0100 Subject: [PATCH 07/15] work --- cmd/auth/token.go | 4 +--- cmd/root/auth.go | 2 +- libs/auth/error.go | 6 +++--- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/cmd/auth/token.go b/cmd/auth/token.go index fd3af2b8ac..e71939eefa 100644 --- a/cmd/auth/token.go +++ b/cmd/auth/token.go @@ -94,9 +94,7 @@ func loadToken(ctx context.Context, args loadTokenArgs) (*oauth2.Token, error) { if err != nil { helpMsg := helpfulError(ctx, args.profileName, oauthArgument) if errors.Is(err, &oauth.InvalidRefreshTokenError{}) { - msg := "a new access token could not be retrieved because the refresh token is invalid." - msg += fmt.Sprintf(" To reauthenticate, run `%s`", auth.BuildLoginCommand(ctx, args.profileName, oauthArgument)) - return nil, errors.New(msg) + return nil, auth.RewriteAuthError(ctx, args.authArguments.Host, args.authArguments.AccountId, args.profileName, err) } return nil, fmt.Errorf("unexpected error loading token: %w. %s", err, helpMsg) } diff --git a/cmd/root/auth.go b/cmd/root/auth.go index b36ca2667c..b604b4c95b 100644 --- a/cmd/root/auth.go +++ b/cmd/root/auth.go @@ -341,5 +341,5 @@ func ConfigUsed(ctx context.Context) *config.Config { } func renderError(ctx context.Context, cfg *config.Config, err error) error { - return auth.RewriteAuthError(ctx, cfg, err) + return auth.RewriteAuthError(ctx, cfg.Host, cfg.AccountID, cfg.Profile, err) } diff --git a/libs/auth/error.go b/libs/auth/error.go index 4952d1309c..f07a4993a8 100644 --- a/libs/auth/error.go +++ b/libs/auth/error.go @@ -10,15 +10,15 @@ import ( "github.com/databricks/databricks-sdk-go/credentials/oauth" ) -func RewriteAuthError(ctx context.Context, cfg *config.Config, err error) error { +func RewriteAuthError(ctx context.Context, host, accountId, profile string, err error) error { target := &oauth.InvalidRefreshTokenError{} if errors.As(err, &target) { - oauthArgument, err := AuthArguments{cfg.Host, cfg.AccountID}.ToOAuthArgument() + oauthArgument, err := AuthArguments{host, accountId}.ToOAuthArgument() if err != nil { return err } msg := "a new access token could not be retrieved because the refresh token is invalid." - msg += fmt.Sprintf(" To reauthenticate, run `%s`", BuildLoginCommand(ctx, cfg.Profile, oauthArgument)) + msg += fmt.Sprintf(" To reauthenticate, run `%s`", BuildLoginCommand(ctx, profile, oauthArgument)) return errors.New(msg) } return err From 61c748573759ecdb234dd5334692039f41b6000a Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Wed, 8 Jan 2025 12:02:31 +0100 Subject: [PATCH 08/15] work --- cmd/auth/login_test.go | 4 ++-- cmd/auth/token.go | 5 +++-- cmd/auth/token_test.go | 51 +++++++++++++++--------------------------- go.mod | 2 +- go.sum | 14 ++---------- 5 files changed, 26 insertions(+), 50 deletions(-) diff --git a/cmd/auth/login_test.go b/cmd/auth/login_test.go index 0b02b587b6..1db9602697 100644 --- a/cmd/auth/login_test.go +++ b/cmd/auth/login_test.go @@ -45,13 +45,13 @@ func TestSetHost(t *testing.T) { authArguments.Host = "" err = setHostAndAccountId(ctx, profile.DefaultProfiler, "profile-1", &authArguments, []string{}) assert.NoError(t, err) - assert.Equal(t, "https://www.Host1.com", authArguments.Host) + assert.Equal(t, "https://www.host1.com", authArguments.Host) // Test setting host from profile authArguments.Host = "" err = setHostAndAccountId(ctx, profile.DefaultProfiler, "profile-2", &authArguments, []string{}) assert.NoError(t, err) - assert.Equal(t, "https://www.Host2.com", authArguments.Host) + assert.Equal(t, "https://www.host2.com", authArguments.Host) // Test host is not set. Should prompt. authArguments.Host = "" diff --git a/cmd/auth/token.go b/cmd/auth/token.go index e71939eefa..e3e770e81a 100644 --- a/cmd/auth/token.go +++ b/cmd/auth/token.go @@ -85,7 +85,7 @@ func loadToken(ctx context.Context, args loadTokenArgs) (*oauth2.Token, error) { if err != nil { return nil, err } - persistentAuth, err := oauth.NewPersistentAuth(ctx) + persistentAuth, err := oauth.NewPersistentAuth(ctx, args.persistentAuthOpts...) if err != nil { helpMsg := helpfulError(ctx, args.profileName, oauthArgument) return nil, fmt.Errorf("unexpected error creating persistent auth: %w. %s", err, helpMsg) @@ -93,7 +93,8 @@ func loadToken(ctx context.Context, args loadTokenArgs) (*oauth2.Token, error) { t, err := persistentAuth.Load(ctx, oauthArgument) if err != nil { helpMsg := helpfulError(ctx, args.profileName, oauthArgument) - if errors.Is(err, &oauth.InvalidRefreshTokenError{}) { + target := &oauth.InvalidRefreshTokenError{} + if errors.As(err, &target) { return nil, auth.RewriteAuthError(ctx, args.authArguments.Host, args.authArguments.AccountId, args.profileName, err) } return nil, fmt.Errorf("unexpected error loading token: %w. %s", err, helpMsg) diff --git a/cmd/auth/token_test.go b/cmd/auth/token_test.go index a90d602be7..e635511a90 100644 --- a/cmd/auth/token_test.go +++ b/cmd/auth/token_test.go @@ -25,7 +25,7 @@ var refreshFailureTokenResponse = fixtures.HTTPFixture{ var refreshFailureInvalidResponse = fixtures.HTTPFixture{ MatchAny: true, - Status: 401, + Status: 200, Response: "Not json", } @@ -108,23 +108,16 @@ func TestToken_loadToken(t *testing.T) { RefreshTokenResponse: fixtures.SliceTransport{f}, } } - wantErrors := func(substrings ...string) func(error) { - return func(err error) { - for _, s := range substrings { - assert.ErrorContains(t, err, s) - } - } - } validateToken := func(resp *oauth2.Token) { assert.Equal(t, "new-access-token", resp.AccessToken) assert.Equal(t, "Bearer", resp.TokenType) } cases := []struct { - name string - args loadTokenArgs - want func(*oauth2.Token) - wantErr func(error) + name string + args loadTokenArgs + validateToken func(*oauth2.Token) + wantErr string }{ { name: "prints helpful login message on refresh failure when profile is specified", @@ -139,10 +132,8 @@ func TestToken_loadToken(t *testing.T) { oauth.WithOAuthClient(makeApiClient(refreshFailureTokenResponse)), }, }, - wantErr: wantErrors( - "a new access token could not be retrieved because the refresh token is invalid. To reauthenticate, run ", - "auth login --host https://accounts.cloud.databricks.com --account-id expired", - ), + wantErr: "a new access token could not be retrieved because the refresh token is invalid. To reauthenticate, run " + + "`databricks auth login --profile expired`", }, { name: "prints helpful login message on refresh failure when host is specified", @@ -160,10 +151,8 @@ func TestToken_loadToken(t *testing.T) { oauth.WithOAuthClient(makeApiClient(refreshFailureTokenResponse)), }, }, - wantErr: wantErrors( - "a new access token could not be retrieved because the refresh token is invalid. To reauthenticate, run ", - "auth login --host https://accounts.cloud.databricks.com --account-id expired", - ), + wantErr: "a new access token could not be retrieved because the refresh token is invalid. To reauthenticate, run " + + "`databricks auth login --host https://accounts.cloud.databricks.com --account-id expired`", }, { name: "prints helpful login message on invalid response", @@ -178,10 +167,8 @@ func TestToken_loadToken(t *testing.T) { oauth.WithOAuthClient(makeApiClient(refreshFailureInvalidResponse)), }, }, - wantErr: wantErrors( - "unexpected parsing token response: invalid character 'N' looking for beginning of value. Try logging in again with ", - "auth login --profile active` before retrying. If this fails, please report this issue to the Databricks CLI maintainers at https://github.com/databricks/cli/issues/new", - ), + wantErr: "unexpected error loading token: token refresh: oauth2: cannot parse json: invalid character 'N' looking for beginning of value. Try logging in again with " + + "`databricks auth login --profile active` before retrying. If this fails, please report this issue to the Databricks CLI maintainers at https://github.com/databricks/cli/issues/new", }, { name: "prints helpful login message on other error response", @@ -196,10 +183,8 @@ func TestToken_loadToken(t *testing.T) { oauth.WithOAuthClient(makeApiClient(refreshFailureOtherError)), }, }, - wantErr: wantErrors( - "unexpected error refreshing token: Databricks is down. Try logging in again with ", - "auth login --profile active` before retrying. If this fails, please report this issue to the Databricks CLI maintainers at https://github.com/databricks/cli/issues/new", - ), + wantErr: "unexpected error loading token: token refresh: Databricks is down (error code: other_error). Try logging in again with " + + "`databricks auth login --profile active` before retrying. If this fails, please report this issue to the Databricks CLI maintainers at https://github.com/databricks/cli/issues/new", }, { name: "succeeds with profile", @@ -214,7 +199,7 @@ func TestToken_loadToken(t *testing.T) { oauth.WithOAuthClient(makeApiClient(refreshSuccessTokenResponse)), }, }, - want: validateToken, + validateToken: validateToken, }, { name: "succeeds with host", @@ -229,17 +214,17 @@ func TestToken_loadToken(t *testing.T) { oauth.WithOAuthClient(makeApiClient(refreshSuccessTokenResponse)), }, }, - want: validateToken, + validateToken: validateToken, }, } for _, c := range cases { t.Run(c.name, func(t *testing.T) { got, err := loadToken(context.Background(), c.args) - if c.wantErr != nil { - c.wantErr(err) + if c.wantErr != "" { + assert.Equal(t, c.wantErr, err.Error()) } else { assert.NoError(t, err) - assert.Equal(t, c.want, got) + c.validateToken(got) } }) } diff --git a/go.mod b/go.mod index 706ade9cff..746794ada1 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ toolchain go1.23.4 require ( github.com/Masterminds/semver/v3 v3.3.1 // MIT github.com/briandowns/spinner v1.23.1 // Apache 2.0 - github.com/databricks/databricks-sdk-go v0.54.1-0.20250107144233-ef8c3f356afc // Apache 2.0 + github.com/databricks/databricks-sdk-go v0.54.1-0.20250108110122-e9f3732ef745 // Apache 2.0 github.com/fatih/color v1.18.0 // MIT github.com/google/uuid v1.6.0 // BSD-3-Clause github.com/hashicorp/go-version v1.7.0 // MPL 2.0 diff --git a/go.sum b/go.sum index 76d74b5044..5e44d50cb7 100644 --- a/go.sum +++ b/go.sum @@ -34,18 +34,8 @@ github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGX github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/cyphar/filepath-securejoin v0.2.4 h1:Ugdm7cg7i6ZK6x3xDF1oEu1nfkyfH53EtKeQYTC3kyg= github.com/cyphar/filepath-securejoin v0.2.4/go.mod h1:aPGpWjXOXUn2NCNjFvBE6aRxGGx79pTxQpKOJNYHHl4= -github.com/databricks/databricks-sdk-go v0.54.1-0.20250103133740-0688c3b8afac h1:HzS3/zoFUu6bWPEBaeeidfXIjg6Vu8qDphwr8Rg5EVU= -github.com/databricks/databricks-sdk-go v0.54.1-0.20250103133740-0688c3b8afac/go.mod h1:IMSyEl8eEBwXQNljpYxs5cp31i8nyAx6KPj8U81X5Zo= -github.com/databricks/databricks-sdk-go v0.54.1-0.20250106160146-9b5913c7af7a h1:Fzw+C/uhqPgqsndOdbhFYm00uggxy+dmyX9dILU1USQ= -github.com/databricks/databricks-sdk-go v0.54.1-0.20250106160146-9b5913c7af7a/go.mod h1:IMSyEl8eEBwXQNljpYxs5cp31i8nyAx6KPj8U81X5Zo= -github.com/databricks/databricks-sdk-go v0.54.1-0.20250106162806-bd7303e3df79 h1:8MEDmBCvAMq7A6APGBu/pBC5xxVetr+7MZPYq55upSg= -github.com/databricks/databricks-sdk-go v0.54.1-0.20250106162806-bd7303e3df79/go.mod h1:IMSyEl8eEBwXQNljpYxs5cp31i8nyAx6KPj8U81X5Zo= -github.com/databricks/databricks-sdk-go v0.54.1-0.20250107101710-323fb61606d1 h1:XxAliYM4PNoiTp9hEr82G6zGgTSttTkIXF8asJQu64M= -github.com/databricks/databricks-sdk-go v0.54.1-0.20250107101710-323fb61606d1/go.mod h1:IMSyEl8eEBwXQNljpYxs5cp31i8nyAx6KPj8U81X5Zo= -github.com/databricks/databricks-sdk-go v0.54.1-0.20250107125441-f53fb8492c42 h1:kj6GEDBiYRWRDoCm0SW0N7qagmS9JuG5rgDUeBMuR2M= -github.com/databricks/databricks-sdk-go v0.54.1-0.20250107125441-f53fb8492c42/go.mod h1:IMSyEl8eEBwXQNljpYxs5cp31i8nyAx6KPj8U81X5Zo= -github.com/databricks/databricks-sdk-go v0.54.1-0.20250107144233-ef8c3f356afc h1:fVNcDO+hDt0oc9jZ1oFiF+ycjdBmp+hXdt699qHU+hQ= -github.com/databricks/databricks-sdk-go v0.54.1-0.20250107144233-ef8c3f356afc/go.mod h1:IMSyEl8eEBwXQNljpYxs5cp31i8nyAx6KPj8U81X5Zo= +github.com/databricks/databricks-sdk-go v0.54.1-0.20250108110122-e9f3732ef745 h1:sxVHnM0Wluxwz5PO/qNt/OE/g7iEorlH8FfQoFznZ/I= +github.com/databricks/databricks-sdk-go v0.54.1-0.20250108110122-e9f3732ef745/go.mod h1:IMSyEl8eEBwXQNljpYxs5cp31i8nyAx6KPj8U81X5Zo= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= From da5ef6122fac4b6d4d719319a4e10c89401adb14 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Wed, 8 Jan 2025 12:10:54 +0100 Subject: [PATCH 09/15] fixes --- cmd/auth/token.go | 11 +++++------ cmd/auth/token_test.go | 4 ++-- cmd/root/auth.go | 3 ++- libs/auth/error.go | 10 ++++++---- 4 files changed, 15 insertions(+), 13 deletions(-) diff --git a/cmd/auth/token.go b/cmd/auth/token.go index e3e770e81a..add2390c66 100644 --- a/cmd/auth/token.go +++ b/cmd/auth/token.go @@ -88,16 +88,15 @@ func loadToken(ctx context.Context, args loadTokenArgs) (*oauth2.Token, error) { persistentAuth, err := oauth.NewPersistentAuth(ctx, args.persistentAuthOpts...) if err != nil { helpMsg := helpfulError(ctx, args.profileName, oauthArgument) - return nil, fmt.Errorf("unexpected error creating persistent auth: %w. %s", err, helpMsg) + return nil, fmt.Errorf("%w. %s", err, helpMsg) } t, err := persistentAuth.Load(ctx, oauthArgument) if err != nil { - helpMsg := helpfulError(ctx, args.profileName, oauthArgument) - target := &oauth.InvalidRefreshTokenError{} - if errors.As(err, &target) { - return nil, auth.RewriteAuthError(ctx, args.authArguments.Host, args.authArguments.AccountId, args.profileName, err) + if err, ok := auth.RewriteAuthError(ctx, args.authArguments.Host, args.authArguments.AccountId, args.profileName, err); ok { + return nil, err } - return nil, fmt.Errorf("unexpected error loading token: %w. %s", err, helpMsg) + helpMsg := helpfulError(ctx, args.profileName, oauthArgument) + return nil, fmt.Errorf("%w. %s", err, helpMsg) } return t, nil } diff --git a/cmd/auth/token_test.go b/cmd/auth/token_test.go index e635511a90..fb01494883 100644 --- a/cmd/auth/token_test.go +++ b/cmd/auth/token_test.go @@ -167,7 +167,7 @@ func TestToken_loadToken(t *testing.T) { oauth.WithOAuthClient(makeApiClient(refreshFailureInvalidResponse)), }, }, - wantErr: "unexpected error loading token: token refresh: oauth2: cannot parse json: invalid character 'N' looking for beginning of value. Try logging in again with " + + wantErr: "token refresh: oauth2: cannot parse json: invalid character 'N' looking for beginning of value. Try logging in again with " + "`databricks auth login --profile active` before retrying. If this fails, please report this issue to the Databricks CLI maintainers at https://github.com/databricks/cli/issues/new", }, { @@ -183,7 +183,7 @@ func TestToken_loadToken(t *testing.T) { oauth.WithOAuthClient(makeApiClient(refreshFailureOtherError)), }, }, - wantErr: "unexpected error loading token: token refresh: Databricks is down (error code: other_error). Try logging in again with " + + wantErr: "token refresh: Databricks is down (error code: other_error). Try logging in again with " + "`databricks auth login --profile active` before retrying. If this fails, please report this issue to the Databricks CLI maintainers at https://github.com/databricks/cli/issues/new", }, { diff --git a/cmd/root/auth.go b/cmd/root/auth.go index b604b4c95b..efc24f8dc4 100644 --- a/cmd/root/auth.go +++ b/cmd/root/auth.go @@ -341,5 +341,6 @@ func ConfigUsed(ctx context.Context) *config.Config { } func renderError(ctx context.Context, cfg *config.Config, err error) error { - return auth.RewriteAuthError(ctx, cfg.Host, cfg.AccountID, cfg.Profile, err) + err, _ = auth.RewriteAuthError(ctx, cfg.Host, cfg.AccountID, cfg.Profile, err) + return err } diff --git a/libs/auth/error.go b/libs/auth/error.go index f07a4993a8..2ad18781e4 100644 --- a/libs/auth/error.go +++ b/libs/auth/error.go @@ -10,18 +10,20 @@ import ( "github.com/databricks/databricks-sdk-go/credentials/oauth" ) -func RewriteAuthError(ctx context.Context, host, accountId, profile string, err error) error { +// RewriteAuthError rewrites the error message for invalid refresh token error. +// It returns the rewritten error and a boolean indicating whether the error was rewritten. +func RewriteAuthError(ctx context.Context, host, accountId, profile string, err error) (error, bool) { target := &oauth.InvalidRefreshTokenError{} if errors.As(err, &target) { oauthArgument, err := AuthArguments{host, accountId}.ToOAuthArgument() if err != nil { - return err + return err, false } msg := "a new access token could not be retrieved because the refresh token is invalid." msg += fmt.Sprintf(" To reauthenticate, run `%s`", BuildLoginCommand(ctx, profile, oauthArgument)) - return errors.New(msg) + return errors.New(msg), true } - return err + return err, false } func BuildLoginCommand(ctx context.Context, profile string, arg oauth.OAuthArgument) string { From 4465ef5c7beca8b188048ee3ad0fa9dc5ce85b1e Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Mon, 20 Jan 2025 13:24:10 +0100 Subject: [PATCH 10/15] work --- go.mod | 2 ++ libs/auth/error.go | 5 ++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 746794ada1..33109e6ba4 100644 --- a/go.mod +++ b/go.mod @@ -77,3 +77,5 @@ require ( google.golang.org/grpc v1.64.1 // indirect google.golang.org/protobuf v1.34.1 // indirect ) + +replace github.com/databricks/databricks-sdk-go => /Users/miles/databricks-sdk-go diff --git a/libs/auth/error.go b/libs/auth/error.go index 2ad18781e4..36c4f34a52 100644 --- a/libs/auth/error.go +++ b/libs/auth/error.go @@ -3,7 +3,6 @@ package auth import ( "context" "errors" - "fmt" "strings" "github.com/databricks/databricks-sdk-go/config" @@ -19,8 +18,8 @@ func RewriteAuthError(ctx context.Context, host, accountId, profile string, err if err != nil { return err, false } - msg := "a new access token could not be retrieved because the refresh token is invalid." - msg += fmt.Sprintf(" To reauthenticate, run `%s`", BuildLoginCommand(ctx, profile, oauthArgument)) + msg := `A new access token could not be retrieved because the refresh token is invalid. To reauthenticate, run the following command: + $ ` + BuildLoginCommand(ctx, profile, oauthArgument) return errors.New(msg), true } return err, false From 11ec4dbbd4fe9fdaafd59daa7ab1d1c828b5c493 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Mon, 20 Jan 2025 13:35:26 +0100 Subject: [PATCH 11/15] workg --- cmd/auth/auth.go | 2 +- cmd/auth/login.go | 21 ++++++++++++--------- cmd/auth/login_test.go | 10 +++++----- cmd/auth/token.go | 26 ++++++++++++++++++++------ cmd/auth/token_test.go | 4 ++-- libs/auth/arguments.go | 25 +++++++++++++++++++++++++ libs/auth/error.go | 18 +----------------- 7 files changed, 66 insertions(+), 40 deletions(-) create mode 100644 libs/auth/arguments.go diff --git a/cmd/auth/auth.go b/cmd/auth/auth.go index 35f838cb34..9096573a0c 100644 --- a/cmd/auth/auth.go +++ b/cmd/auth/auth.go @@ -24,7 +24,7 @@ GCP: https://docs.gcp.databricks.com/dev-tools/auth/index.html`, var authArguments auth.AuthArguments cmd.PersistentFlags().StringVar(&authArguments.Host, "host", "", "Databricks Host") - cmd.PersistentFlags().StringVar(&authArguments.AccountId, "account-id", "", "Databricks Account ID") + cmd.PersistentFlags().StringVar(&authArguments.AccountID, "account-id", "", "Databricks Account ID") cmd.AddCommand(newEnvCommand()) cmd.AddCommand(newLoginCommand(&authArguments)) diff --git a/cmd/auth/login.go b/cmd/auth/login.go index 407a56abda..e0609f2b9a 100644 --- a/cmd/auth/login.go +++ b/cmd/auth/login.go @@ -19,14 +19,14 @@ import ( "github.com/spf13/cobra" ) -func promptForProfile(ctx context.Context, authArguments *auth.AuthArguments) (string, error) { +func promptForProfile(ctx context.Context, defaultValue string) (string, error) { if !cmdio.IsInTTY(ctx) { return "", nil } prompt := cmdio.Prompt(ctx) prompt.Label = "Databricks profile name" - prompt.Default = getProfileName(authArguments) + prompt.Default = defaultValue prompt.AllowEdit = true return prompt.Run() } @@ -100,7 +100,7 @@ depends on the existing profiles you have set in your configuration file // If the user has not specified a profile name, prompt for one. if profileName == "" { var err error - profileName, err = promptForProfile(ctx, authArguments) + profileName, err = promptForProfile(ctx, getProfileName(authArguments)) if err != nil { return err } @@ -121,7 +121,7 @@ depends on the existing profiles you have set in your configuration file // Otherwise it will complain about non existing profile because it was not yet saved. cfg := config.Config{ Host: authArguments.Host, - AccountID: authArguments.AccountId, + AccountID: authArguments.AccountID, AuthType: "databricks-cli", } @@ -217,10 +217,10 @@ func setHostAndAccountId(ctx context.Context, profiler profile.Profiler, profile // If the account-id was not provided as a cmd line flag, try to read it from // the specified profile. isAccountClient := (&config.Config{Host: authArguments.Host}).IsAccountClient() - accountID := authArguments.AccountId + accountID := authArguments.AccountID if isAccountClient && accountID == "" { if len(profiles) > 0 && profiles[0].AccountID != "" { - authArguments.AccountId = profiles[0].AccountID + authArguments.AccountID = profiles[0].AccountID } else { // Prompt user for the account-id if it we could not get it from a // profile. @@ -228,15 +228,18 @@ func setHostAndAccountId(ctx context.Context, profiler profile.Profiler, profile if err != nil { return err } - authArguments.AccountId = accountId + authArguments.AccountID = accountId } } return nil } +// getProfileName returns the default profile name for a given host/account ID. +// If the account ID is provided, the profile name is "ACCOUNT-". +// Otherwise, the profile name is the first part of the host URL. func getProfileName(authArguments *auth.AuthArguments) string { - if authArguments.AccountId != "" { - return "ACCOUNT-" + authArguments.AccountId + if authArguments.AccountID != "" { + return "ACCOUNT-" + authArguments.AccountID } host := strings.TrimPrefix(authArguments.Host, "https://") split := strings.Split(host, ".") diff --git a/cmd/auth/login_test.go b/cmd/auth/login_test.go index 1db9602697..8412dab2ce 100644 --- a/cmd/auth/login_test.go +++ b/cmd/auth/login_test.go @@ -65,21 +65,21 @@ func TestSetAccountId(t *testing.T) { ctx, _ := cmdio.SetupTest(context.Background()) // Test setting account-id from flag - authArguments.AccountId = "val from --account-id" + authArguments.AccountID = "val from --account-id" err := setHostAndAccountId(ctx, profile.DefaultProfiler, "account-profile", &authArguments, []string{}) assert.NoError(t, err) assert.Equal(t, "https://accounts.cloud.databricks.com", authArguments.Host) - assert.Equal(t, "val from --account-id", authArguments.AccountId) + assert.Equal(t, "val from --account-id", authArguments.AccountID) // Test setting account_id from profile - authArguments.AccountId = "" + authArguments.AccountID = "" err = setHostAndAccountId(ctx, profile.DefaultProfiler, "account-profile", &authArguments, []string{}) require.NoError(t, err) assert.Equal(t, "https://accounts.cloud.databricks.com", authArguments.Host) - assert.Equal(t, "id-from-profile", authArguments.AccountId) + assert.Equal(t, "id-from-profile", authArguments.AccountID) // Neither flag nor profile account-id is set, should prompt - authArguments.AccountId = "" + authArguments.AccountID = "" authArguments.Host = "https://accounts.cloud.databricks.com" err = setHostAndAccountId(ctx, profile.DefaultProfiler, "", &authArguments, []string{}) assert.EqualError(t, err, "the command is being run in a non-interactive environment, please specify an account ID using --account-id") diff --git a/cmd/auth/token.go b/cmd/auth/token.go index add2390c66..8d4df2c29e 100644 --- a/cmd/auth/token.go +++ b/cmd/auth/token.go @@ -60,14 +60,28 @@ func newTokenCommand(authArguments *auth.AuthArguments) *cobra.Command { } type loadTokenArgs struct { - authArguments *auth.AuthArguments - profileName string - args []string - tokenTimeout time.Duration - profiler profile.Profiler + // authArguments is the parsed auth arguments, including the host and optionally the account ID. + authArguments *auth.AuthArguments + + // profileName is the name of the specified profile. If no profile is specified, this is an empty string. + profileName string + + // args is the list of arguments passed to the command. + args []string + + // tokenTimeout is the timeout for retrieving (and potentially refreshing) an OAuth token. + tokenTimeout time.Duration + + // profiler is the profiler to use for reading the host and account ID from the .databrickscfg file. + profiler profile.Profiler + + // persistentAuthOpts are the options to pass to the persistent auth client. persistentAuthOpts []oauth.PersistentAuthOption } +// loadToken loads an OAuth token from the persistent auth store. The host and account ID are read from +// the provided profiler if not explicitly provided. If the token cannot be refreshed, a helpful error message +// is printed to the user with steps to reauthenticate. func loadToken(ctx context.Context, args loadTokenArgs) (*oauth2.Token, error) { // If a profile is provided we read the host from the .databrickscfg file if args.profileName != "" && len(args.args) > 0 { @@ -92,7 +106,7 @@ func loadToken(ctx context.Context, args loadTokenArgs) (*oauth2.Token, error) { } t, err := persistentAuth.Load(ctx, oauthArgument) if err != nil { - if err, ok := auth.RewriteAuthError(ctx, args.authArguments.Host, args.authArguments.AccountId, args.profileName, err); ok { + if err, ok := auth.RewriteAuthError(ctx, args.authArguments.Host, args.authArguments.AccountID, args.profileName, err); ok { return nil, err } helpMsg := helpfulError(ctx, args.profileName, oauthArgument) diff --git a/cmd/auth/token_test.go b/cmd/auth/token_test.go index fb01494883..04abb0717d 100644 --- a/cmd/auth/token_test.go +++ b/cmd/auth/token_test.go @@ -140,7 +140,7 @@ func TestToken_loadToken(t *testing.T) { args: loadTokenArgs{ authArguments: &auth.AuthArguments{ Host: "https://accounts.cloud.databricks.com", - AccountId: "expired", + AccountID: "expired", }, profileName: "", args: []string{}, @@ -204,7 +204,7 @@ func TestToken_loadToken(t *testing.T) { { name: "succeeds with host", args: loadTokenArgs{ - authArguments: &auth.AuthArguments{Host: "https://accounts.cloud.databricks.com", AccountId: "active"}, + authArguments: &auth.AuthArguments{Host: "https://accounts.cloud.databricks.com", AccountID: "active"}, profileName: "", args: []string{}, tokenTimeout: 1 * time.Hour, diff --git a/libs/auth/arguments.go b/libs/auth/arguments.go new file mode 100644 index 0000000000..90f8948402 --- /dev/null +++ b/libs/auth/arguments.go @@ -0,0 +1,25 @@ +package auth + +import ( + "github.com/databricks/databricks-sdk-go/config" + "github.com/databricks/databricks-sdk-go/credentials/oauth" +) + +// AuthArguments is a struct that contains the common arguments passed to +// `databricks auth` commands. +type AuthArguments struct { + Host string + AccountID string +} + +// ToOAuthArgument converts the AuthArguments to an OAuthArgument from the Go SDK. +func (a AuthArguments) ToOAuthArgument() (oauth.OAuthArgument, error) { + cfg := &config.Config{ + Host: a.Host, + AccountID: a.AccountID, + } + if cfg.IsAccountClient() { + return oauth.NewBasicAccountOAuthArgument(cfg.Host, cfg.AccountID) + } + return oauth.NewBasicWorkspaceOAuthArgument(cfg.Host) +} diff --git a/libs/auth/error.go b/libs/auth/error.go index 36c4f34a52..f432fa0ddb 100644 --- a/libs/auth/error.go +++ b/libs/auth/error.go @@ -5,7 +5,6 @@ import ( "errors" "strings" - "github.com/databricks/databricks-sdk-go/config" "github.com/databricks/databricks-sdk-go/credentials/oauth" ) @@ -25,6 +24,7 @@ func RewriteAuthError(ctx context.Context, host, accountId, profile string, err return err, false } +// BuildLoginCommand builds the login command for the given OAuth argument or profile. func BuildLoginCommand(ctx context.Context, profile string, arg oauth.OAuthArgument) string { cmd := []string{ "databricks", @@ -43,19 +43,3 @@ func BuildLoginCommand(ctx context.Context, profile string, arg oauth.OAuthArgum } return strings.Join(cmd, " ") } - -type AuthArguments struct { - Host string - AccountId string -} - -func (a AuthArguments) ToOAuthArgument() (oauth.OAuthArgument, error) { - cfg := &config.Config{ - Host: a.Host, - AccountID: a.AccountId, - } - if cfg.IsAccountClient() { - return oauth.NewBasicAccountOAuthArgument(cfg.Host, cfg.AccountID) - } - return oauth.NewBasicWorkspaceOAuthArgument(cfg.Host) -} From cd64e338fcb5e697e3aa10644cf3515cf4446a37 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Mon, 24 Mar 2025 13:48:19 +0100 Subject: [PATCH 12/15] merge --- cmd/auth/in_memory_test.go | 2 +- cmd/auth/login.go | 14 ++++----- cmd/auth/token.go | 10 +++---- cmd/auth/token_test.go | 61 ++++++++++++++++---------------------- cmd/root/auth.go | 27 ----------------- libs/auth/arguments.go | 8 ++--- libs/auth/error.go | 14 ++++----- 7 files changed, 50 insertions(+), 86 deletions(-) diff --git a/cmd/auth/in_memory_test.go b/cmd/auth/in_memory_test.go index a1714dd749..3733c6fe41 100644 --- a/cmd/auth/in_memory_test.go +++ b/cmd/auth/in_memory_test.go @@ -1,7 +1,7 @@ package auth import ( - "github.com/databricks/databricks-sdk-go/credentials/cache" + "github.com/databricks/databricks-sdk-go/credentials/u2m/cache" "golang.org/x/oauth2" ) diff --git a/cmd/auth/login.go b/cmd/auth/login.go index e0609f2b9a..9fd1ee519e 100644 --- a/cmd/auth/login.go +++ b/cmd/auth/login.go @@ -15,7 +15,7 @@ import ( "github.com/databricks/cli/libs/databrickscfg/profile" "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/config" - "github.com/databricks/databricks-sdk-go/credentials/oauth" + "github.com/databricks/databricks-sdk-go/credentials/u2m" "github.com/spf13/cobra" ) @@ -111,7 +111,11 @@ depends on the existing profiles you have set in your configuration file if err != nil { return err } - persistentAuth, err := oauth.NewPersistentAuth(ctx) + oauthArgument, err := authArguments.ToOAuthArgument() + if err != nil { + return err + } + persistentAuth, err := u2m.NewPersistentAuth(ctx, u2m.WithOAuthArgument(oauthArgument)) if err != nil { return err } @@ -128,11 +132,7 @@ depends on the existing profiles you have set in your configuration file ctx, cancel := context.WithTimeout(ctx, loginTimeout) defer cancel() - oauthArgument, err := authArguments.ToOAuthArgument() - if err != nil { - return err - } - if _, err = persistentAuth.Challenge(ctx, oauthArgument); err != nil { + if err = persistentAuth.Challenge(); err != nil { return err } diff --git a/cmd/auth/token.go b/cmd/auth/token.go index fd02a23750..aafda37221 100644 --- a/cmd/auth/token.go +++ b/cmd/auth/token.go @@ -9,12 +9,12 @@ import ( "github.com/databricks/cli/libs/auth" "github.com/databricks/cli/libs/databrickscfg/profile" - "github.com/databricks/databricks-sdk-go/credentials/oauth" + "github.com/databricks/databricks-sdk-go/credentials/u2m" "github.com/spf13/cobra" "golang.org/x/oauth2" ) -func helpfulError(ctx context.Context, profile string, persistentAuth oauth.OAuthArgument) string { +func helpfulError(ctx context.Context, profile string, persistentAuth u2m.OAuthArgument) string { loginMsg := auth.BuildLoginCommand(ctx, profile, persistentAuth) return fmt.Sprintf("Try logging in again with `%s` before retrying. If this fails, please report this issue to the Databricks CLI maintainers at https://github.com/databricks/cli/issues/new", loginMsg) } @@ -80,7 +80,7 @@ type loadTokenArgs struct { profiler profile.Profiler // persistentAuthOpts are the options to pass to the persistent auth client. - persistentAuthOpts []oauth.PersistentAuthOption + persistentAuthOpts []u2m.PersistentAuthOption } // loadToken loads an OAuth token from the persistent auth store. The host and account ID are read from @@ -103,12 +103,12 @@ func loadToken(ctx context.Context, args loadTokenArgs) (*oauth2.Token, error) { if err != nil { return nil, err } - persistentAuth, err := oauth.NewPersistentAuth(ctx, args.persistentAuthOpts...) + persistentAuth, err := u2m.NewPersistentAuth(ctx, args.persistentAuthOpts...) if err != nil { helpMsg := helpfulError(ctx, args.profileName, oauthArgument) return nil, fmt.Errorf("%w. %s", err, helpMsg) } - t, err := persistentAuth.Load(ctx, oauthArgument) + t, err := persistentAuth.Token() if err != nil { if err, ok := auth.RewriteAuthError(ctx, args.authArguments.Host, args.authArguments.AccountID, args.profileName, err); ok { return nil, err diff --git a/cmd/auth/token_test.go b/cmd/auth/token_test.go index 77686a81b4..c476229de5 100644 --- a/cmd/auth/token_test.go +++ b/cmd/auth/token_test.go @@ -6,11 +6,9 @@ import ( "testing" "time" - "github.com/databricks/cli/cmd" - "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/libs/auth" "github.com/databricks/cli/libs/databrickscfg/profile" - "github.com/databricks/databricks-sdk-go/credentials/oauth" + "github.com/databricks/databricks-sdk-go/credentials/u2m" "github.com/databricks/databricks-sdk-go/httpclient/fixtures" "github.com/stretchr/testify/assert" "golang.org/x/oauth2" @@ -54,30 +52,23 @@ type MockApiClient struct { RefreshTokenResponse http.RoundTripper } -// GetAccountOAuthEndpoints implements oauth.OAuthClient. -func (m *MockApiClient) GetAccountOAuthEndpoints(ctx context.Context, accountHost, accountId string) (*oauth.OAuthAuthorizationServer, error) { - return &oauth.OAuthAuthorizationServer{ +// GetAccountOAuthEndpoints implements u2m.OAuthEndpointSupplier. +func (m *MockApiClient) GetAccountOAuthEndpoints(ctx context.Context, accountHost, accountId string) (*u2m.OAuthAuthorizationServer, error) { + return &u2m.OAuthAuthorizationServer{ TokenEndpoint: accountHost + "/token", AuthorizationEndpoint: accountHost + "/authorize", }, nil } -// GetHttpClient implements oauth.OAuthClient. -func (m *MockApiClient) GetHttpClient(context.Context) *http.Client { - return &http.Client{ - Transport: m.RefreshTokenResponse, - } -} - -// GetWorkspaceOAuthEndpoints implements oauth.OAuthClient. -func (m *MockApiClient) GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*oauth.OAuthAuthorizationServer, error) { - return &oauth.OAuthAuthorizationServer{ +// GetWorkspaceOAuthEndpoints implements u2m.OAuthEndpointSupplier. +func (m *MockApiClient) GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*u2m.OAuthAuthorizationServer, error) { + return &u2m.OAuthAuthorizationServer{ TokenEndpoint: workspaceHost + "/token", AuthorizationEndpoint: workspaceHost + "/authorize", }, nil } -var _ oauth.OAuthClient = (*MockApiClient)(nil) +var _ u2m.OAuthEndpointSupplier = (*MockApiClient)(nil) func TestToken_loadToken(t *testing.T) { profiler := profile.InMemoryProfiler{ @@ -129,9 +120,9 @@ func TestToken_loadToken(t *testing.T) { args: []string{}, tokenTimeout: 1 * time.Hour, profiler: profiler, - persistentAuthOpts: []oauth.PersistentAuthOption{ - oauth.WithTokenCache(tokenCache), - oauth.WithOAuthClient(makeApiClient(refreshFailureTokenResponse)), + persistentAuthOpts: []u2m.PersistentAuthOption{ + u2m.WithTokenCache(tokenCache), + u2m.WithOAuthEndpointSupplier(makeApiClient(refreshFailureTokenResponse)), }, }, wantErr: "a new access token could not be retrieved because the refresh token is invalid. To reauthenticate, run " + @@ -148,9 +139,9 @@ func TestToken_loadToken(t *testing.T) { args: []string{}, tokenTimeout: 1 * time.Hour, profiler: profiler, - persistentAuthOpts: []oauth.PersistentAuthOption{ - oauth.WithTokenCache(tokenCache), - oauth.WithOAuthClient(makeApiClient(refreshFailureTokenResponse)), + persistentAuthOpts: []u2m.PersistentAuthOption{ + u2m.WithTokenCache(tokenCache), + u2m.WithOAuthEndpointSupplier(makeApiClient(refreshFailureTokenResponse)), }, }, wantErr: "a new access token could not be retrieved because the refresh token is invalid. To reauthenticate, run " + @@ -164,9 +155,9 @@ func TestToken_loadToken(t *testing.T) { args: []string{}, tokenTimeout: 1 * time.Hour, profiler: profiler, - persistentAuthOpts: []oauth.PersistentAuthOption{ - oauth.WithTokenCache(tokenCache), - oauth.WithOAuthClient(makeApiClient(refreshFailureInvalidResponse)), + persistentAuthOpts: []u2m.PersistentAuthOption{ + u2m.WithTokenCache(tokenCache), + u2m.WithOAuthEndpointSupplier(makeApiClient(refreshFailureInvalidResponse)), }, }, wantErr: "token refresh: oauth2: cannot parse json: invalid character 'N' looking for beginning of value. Try logging in again with " + @@ -180,9 +171,9 @@ func TestToken_loadToken(t *testing.T) { args: []string{}, tokenTimeout: 1 * time.Hour, profiler: profiler, - persistentAuthOpts: []oauth.PersistentAuthOption{ - oauth.WithTokenCache(tokenCache), - oauth.WithOAuthClient(makeApiClient(refreshFailureOtherError)), + persistentAuthOpts: []u2m.PersistentAuthOption{ + u2m.WithTokenCache(tokenCache), + u2m.WithOAuthEndpointSupplier(makeApiClient(refreshFailureOtherError)), }, }, wantErr: "token refresh: Databricks is down (error code: other_error). Try logging in again with " + @@ -196,9 +187,9 @@ func TestToken_loadToken(t *testing.T) { args: []string{}, tokenTimeout: 1 * time.Hour, profiler: profiler, - persistentAuthOpts: []oauth.PersistentAuthOption{ - oauth.WithTokenCache(tokenCache), - oauth.WithOAuthClient(makeApiClient(refreshSuccessTokenResponse)), + persistentAuthOpts: []u2m.PersistentAuthOption{ + u2m.WithTokenCache(tokenCache), + u2m.WithOAuthEndpointSupplier(makeApiClient(refreshSuccessTokenResponse)), }, }, validateToken: validateToken, @@ -211,9 +202,9 @@ func TestToken_loadToken(t *testing.T) { args: []string{}, tokenTimeout: 1 * time.Hour, profiler: profiler, - persistentAuthOpts: []oauth.PersistentAuthOption{ - oauth.WithTokenCache(tokenCache), - oauth.WithOAuthClient(makeApiClient(refreshSuccessTokenResponse)), + persistentAuthOpts: []u2m.PersistentAuthOption{ + u2m.WithTokenCache(tokenCache), + u2m.WithOAuthEndpointSupplier(makeApiClient(refreshSuccessTokenResponse)), }, }, validateToken: validateToken, diff --git a/cmd/root/auth.go b/cmd/root/auth.go index 4f943dd88d..406631d24b 100644 --- a/cmd/root/auth.go +++ b/cmd/root/auth.go @@ -6,11 +6,8 @@ import ( "fmt" "net/http" -<<<<<<< HEAD "github.com/databricks/cli/libs/auth" -======= "github.com/databricks/cli/libs/cmdctx" ->>>>>>> main "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/databrickscfg/profile" "github.com/databricks/databricks-sdk-go" @@ -311,30 +308,6 @@ func emptyHttpRequest(ctx context.Context) *http.Request { return req } -func WorkspaceClient(ctx context.Context) *databricks.WorkspaceClient { - w, ok := ctx.Value(&workspaceClient).(*databricks.WorkspaceClient) - if !ok { - panic("cannot get *databricks.WorkspaceClient. Please report it as a bug") - } - return w -} - -func AccountClient(ctx context.Context) *databricks.AccountClient { - a, ok := ctx.Value(&accountClient).(*databricks.AccountClient) - if !ok { - panic("cannot get *databricks.AccountClient. Please report it as a bug") - } - return a -} - -func ConfigUsed(ctx context.Context) *config.Config { - cfg, ok := ctx.Value(&configUsed).(*config.Config) - if !ok { - panic("cannot get *config.Config. Please report it as a bug") - } - return cfg -} - func renderError(ctx context.Context, cfg *config.Config, err error) error { err, _ = auth.RewriteAuthError(ctx, cfg.Host, cfg.AccountID, cfg.Profile, err) return err diff --git a/libs/auth/arguments.go b/libs/auth/arguments.go index 90f8948402..17957ec511 100644 --- a/libs/auth/arguments.go +++ b/libs/auth/arguments.go @@ -2,7 +2,7 @@ package auth import ( "github.com/databricks/databricks-sdk-go/config" - "github.com/databricks/databricks-sdk-go/credentials/oauth" + "github.com/databricks/databricks-sdk-go/credentials/u2m" ) // AuthArguments is a struct that contains the common arguments passed to @@ -13,13 +13,13 @@ type AuthArguments struct { } // ToOAuthArgument converts the AuthArguments to an OAuthArgument from the Go SDK. -func (a AuthArguments) ToOAuthArgument() (oauth.OAuthArgument, error) { +func (a AuthArguments) ToOAuthArgument() (u2m.OAuthArgument, error) { cfg := &config.Config{ Host: a.Host, AccountID: a.AccountID, } if cfg.IsAccountClient() { - return oauth.NewBasicAccountOAuthArgument(cfg.Host, cfg.AccountID) + return u2m.NewBasicAccountOAuthArgument(cfg.Host, cfg.AccountID) } - return oauth.NewBasicWorkspaceOAuthArgument(cfg.Host) + return u2m.NewBasicWorkspaceOAuthArgument(cfg.Host) } diff --git a/libs/auth/error.go b/libs/auth/error.go index f432fa0ddb..1bf0e9b519 100644 --- a/libs/auth/error.go +++ b/libs/auth/error.go @@ -5,13 +5,13 @@ import ( "errors" "strings" - "github.com/databricks/databricks-sdk-go/credentials/oauth" + "github.com/databricks/databricks-sdk-go/credentials/u2m" ) // RewriteAuthError rewrites the error message for invalid refresh token error. // It returns the rewritten error and a boolean indicating whether the error was rewritten. func RewriteAuthError(ctx context.Context, host, accountId, profile string, err error) (error, bool) { - target := &oauth.InvalidRefreshTokenError{} + target := &u2m.InvalidRefreshTokenError{} if errors.As(err, &target) { oauthArgument, err := AuthArguments{host, accountId}.ToOAuthArgument() if err != nil { @@ -25,7 +25,7 @@ func RewriteAuthError(ctx context.Context, host, accountId, profile string, err } // BuildLoginCommand builds the login command for the given OAuth argument or profile. -func BuildLoginCommand(ctx context.Context, profile string, arg oauth.OAuthArgument) string { +func BuildLoginCommand(ctx context.Context, profile string, arg u2m.OAuthArgument) string { cmd := []string{ "databricks", "auth", @@ -35,10 +35,10 @@ func BuildLoginCommand(ctx context.Context, profile string, arg oauth.OAuthArgum cmd = append(cmd, "--profile", profile) } else { switch arg := arg.(type) { - case oauth.AccountOAuthArgument: - cmd = append(cmd, "--host", arg.GetAccountHost(ctx), "--account-id", arg.GetAccountId(ctx)) - case oauth.WorkspaceOAuthArgument: - cmd = append(cmd, "--host", arg.GetWorkspaceHost(ctx)) + case u2m.AccountOAuthArgument: + cmd = append(cmd, "--host", arg.GetAccountHost(), "--account-id", arg.GetAccountId()) + case u2m.WorkspaceOAuthArgument: + cmd = append(cmd, "--host", arg.GetWorkspaceHost()) } } return strings.Join(cmd, " ") From cac81a1824174adec49dd87f5721f62727d96ed3 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Wed, 26 Mar 2025 11:55:21 +0100 Subject: [PATCH 13/15] work --- cmd/auth/token.go | 3 ++- cmd/auth/token_test.go | 35 +++++++++++++++++------------------ 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/cmd/auth/token.go b/cmd/auth/token.go index aafda37221..dd0f95732e 100644 --- a/cmd/auth/token.go +++ b/cmd/auth/token.go @@ -103,7 +103,8 @@ func loadToken(ctx context.Context, args loadTokenArgs) (*oauth2.Token, error) { if err != nil { return nil, err } - persistentAuth, err := u2m.NewPersistentAuth(ctx, args.persistentAuthOpts...) + allArgs := append(args.persistentAuthOpts, u2m.WithOAuthArgument(oauthArgument)) + persistentAuth, err := u2m.NewPersistentAuth(ctx, allArgs...) if err != nil { helpMsg := helpfulError(ctx, args.profileName, oauthArgument) return nil, fmt.Errorf("%w. %s", err, helpMsg) diff --git a/cmd/auth/token_test.go b/cmd/auth/token_test.go index c476229de5..63d56b2b28 100644 --- a/cmd/auth/token_test.go +++ b/cmd/auth/token_test.go @@ -48,9 +48,7 @@ var refreshSuccessTokenResponse = fixtures.HTTPFixture{ }, } -type MockApiClient struct { - RefreshTokenResponse http.RoundTripper -} +type MockApiClient struct{} // GetAccountOAuthEndpoints implements u2m.OAuthEndpointSupplier. func (m *MockApiClient) GetAccountOAuthEndpoints(ctx context.Context, accountHost, accountId string) (*u2m.OAuthAuthorizationServer, error) { @@ -96,11 +94,6 @@ func TestToken_loadToken(t *testing.T) { }, }, } - makeApiClient := func(f fixtures.HTTPFixture) *MockApiClient { - return &MockApiClient{ - RefreshTokenResponse: fixtures.SliceTransport{f}, - } - } validateToken := func(resp *oauth2.Token) { assert.Equal(t, "new-access-token", resp.AccessToken) assert.Equal(t, "Bearer", resp.TokenType) @@ -122,11 +115,12 @@ func TestToken_loadToken(t *testing.T) { profiler: profiler, persistentAuthOpts: []u2m.PersistentAuthOption{ u2m.WithTokenCache(tokenCache), - u2m.WithOAuthEndpointSupplier(makeApiClient(refreshFailureTokenResponse)), + u2m.WithOAuthEndpointSupplier(&MockApiClient{}), + u2m.WithHttpClient(&http.Client{Transport: fixtures.SliceTransport{refreshFailureTokenResponse}}), }, }, - wantErr: "a new access token could not be retrieved because the refresh token is invalid. To reauthenticate, run " + - "`databricks auth login --profile expired`", + wantErr: `A new access token could not be retrieved because the refresh token is invalid. To reauthenticate, run the following command: + $ databricks auth login --profile expired`, }, { name: "prints helpful login message on refresh failure when host is specified", @@ -141,11 +135,12 @@ func TestToken_loadToken(t *testing.T) { profiler: profiler, persistentAuthOpts: []u2m.PersistentAuthOption{ u2m.WithTokenCache(tokenCache), - u2m.WithOAuthEndpointSupplier(makeApiClient(refreshFailureTokenResponse)), + u2m.WithOAuthEndpointSupplier(&MockApiClient{}), + u2m.WithHttpClient(&http.Client{Transport: fixtures.SliceTransport{refreshFailureTokenResponse}}), }, }, - wantErr: "a new access token could not be retrieved because the refresh token is invalid. To reauthenticate, run " + - "`databricks auth login --host https://accounts.cloud.databricks.com --account-id expired`", + wantErr: `A new access token could not be retrieved because the refresh token is invalid. To reauthenticate, run the following command: + $ databricks auth login --host https://accounts.cloud.databricks.com --account-id expired`, }, { name: "prints helpful login message on invalid response", @@ -157,7 +152,8 @@ func TestToken_loadToken(t *testing.T) { profiler: profiler, persistentAuthOpts: []u2m.PersistentAuthOption{ u2m.WithTokenCache(tokenCache), - u2m.WithOAuthEndpointSupplier(makeApiClient(refreshFailureInvalidResponse)), + u2m.WithOAuthEndpointSupplier(&MockApiClient{}), + u2m.WithHttpClient(&http.Client{Transport: fixtures.SliceTransport{refreshFailureInvalidResponse}}), }, }, wantErr: "token refresh: oauth2: cannot parse json: invalid character 'N' looking for beginning of value. Try logging in again with " + @@ -173,7 +169,8 @@ func TestToken_loadToken(t *testing.T) { profiler: profiler, persistentAuthOpts: []u2m.PersistentAuthOption{ u2m.WithTokenCache(tokenCache), - u2m.WithOAuthEndpointSupplier(makeApiClient(refreshFailureOtherError)), + u2m.WithOAuthEndpointSupplier(&MockApiClient{}), + u2m.WithHttpClient(&http.Client{Transport: fixtures.SliceTransport{refreshFailureOtherError}}), }, }, wantErr: "token refresh: Databricks is down (error code: other_error). Try logging in again with " + @@ -189,7 +186,8 @@ func TestToken_loadToken(t *testing.T) { profiler: profiler, persistentAuthOpts: []u2m.PersistentAuthOption{ u2m.WithTokenCache(tokenCache), - u2m.WithOAuthEndpointSupplier(makeApiClient(refreshSuccessTokenResponse)), + u2m.WithOAuthEndpointSupplier(&MockApiClient{}), + u2m.WithHttpClient(&http.Client{Transport: fixtures.SliceTransport{refreshSuccessTokenResponse}}), }, }, validateToken: validateToken, @@ -204,7 +202,8 @@ func TestToken_loadToken(t *testing.T) { profiler: profiler, persistentAuthOpts: []u2m.PersistentAuthOption{ u2m.WithTokenCache(tokenCache), - u2m.WithOAuthEndpointSupplier(makeApiClient(refreshSuccessTokenResponse)), + u2m.WithOAuthEndpointSupplier(&MockApiClient{}), + u2m.WithHttpClient(&http.Client{Transport: fixtures.SliceTransport{refreshSuccessTokenResponse}}), }, }, validateToken: validateToken, From e7eab57cadbec58a93f89906071b10732edbf031 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Thu, 3 Apr 2025 09:46:40 +0200 Subject: [PATCH 14/15] work --- go.mod | 6 ++---- go.sum | 10 ++++------ 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/go.mod b/go.mod index eb5aa27c9d..8324f96ac5 100644 --- a/go.mod +++ b/go.mod @@ -6,10 +6,10 @@ toolchain go1.23.7 require ( dario.cat/mergo v1.0.1 // BSD 3-Clause - github.com/BurntSushi/toml v1.4.0 // MIT + github.com/BurntSushi/toml v1.5.0 // MIT github.com/Masterminds/semver/v3 v3.3.1 // MIT github.com/briandowns/spinner v1.23.1 // Apache 2.0 - github.com/databricks/databricks-sdk-go v0.60.0 // Apache 2.0 + github.com/databricks/databricks-sdk-go v0.61.0 // Apache 2.0 github.com/fatih/color v1.18.0 // MIT github.com/google/uuid v1.6.0 // BSD-3-Clause github.com/gorilla/mux v1.8.1 // BSD 3-Clause @@ -80,5 +80,3 @@ require ( google.golang.org/grpc v1.64.1 // indirect google.golang.org/protobuf v1.34.1 // indirect ) - -replace github.com/databricks/databricks-sdk-go => /Users/miles/databricks-sdk-go diff --git a/go.sum b/go.sum index 597c138aa6..12ad43650f 100644 --- a/go.sum +++ b/go.sum @@ -8,8 +8,8 @@ cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1h dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s= dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/BurntSushi/toml v1.4.0 h1:kuoIxZQy2WRRk1pttg9asf+WVv6tWQuBNVmK8+nqPr0= -github.com/BurntSushi/toml v1.4.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= +github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg= +github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/Masterminds/semver/v3 v3.3.1 h1:QtNSWtVZ3nBfk8mAOu/B6v7FMJ+NHTIgUPi7rj+4nv4= github.com/Masterminds/semver/v3 v3.3.1/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM= github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow= @@ -34,8 +34,8 @@ github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGX github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/cyphar/filepath-securejoin v0.2.5 h1:6iR5tXJ/e6tJZzzdMc1km3Sa7RRIVBKAK32O2s7AYfo= github.com/cyphar/filepath-securejoin v0.2.5/go.mod h1:aPGpWjXOXUn2NCNjFvBE6aRxGGx79pTxQpKOJNYHHl4= -github.com/databricks/databricks-sdk-go v0.60.0 h1:mCnPsK7gLxF6ps9WihQkh3OwOTTLq/JEzsBzDq1yYbc= -github.com/databricks/databricks-sdk-go v0.60.0/go.mod h1:JpLizplEs+up9/Z4Xf2x++o3sM9eTTWFGzIXAptKJzI= +github.com/databricks/databricks-sdk-go v0.61.0 h1:rRshNJxGoTOyRf4783YZLcd5JTH3hhaZyxHNRmAcVwU= +github.com/databricks/databricks-sdk-go v0.61.0/go.mod h1:xBtjeP9nq+6MgTewZW1EcbRkD7aDY9gZvcRPcwPhZjw= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -158,7 +158,6 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= @@ -282,7 +281,6 @@ gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/warnings.v0 v0.1.2 h1:wFXVbFY8DY5/xOe1ECiWdKCzZlxgshcYVNkBHstARME= gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= From 229d35c71e58a2dc28a1cb42fc0fe7623744fe39 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Fri, 11 Apr 2025 10:25:26 +0200 Subject: [PATCH 15/15] fix --- cmd/auth/in_memory_test.go | 8 ++++---- cmd/auth/token_test.go | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cmd/auth/in_memory_test.go b/cmd/auth/in_memory_test.go index 3733c6fe41..212b2ed91f 100644 --- a/cmd/auth/in_memory_test.go +++ b/cmd/auth/in_memory_test.go @@ -5,12 +5,12 @@ import ( "golang.org/x/oauth2" ) -type InMemoryTokenCache struct { +type inMemoryTokenCache struct { Tokens map[string]*oauth2.Token } // Lookup implements TokenCache. -func (i *InMemoryTokenCache) Lookup(key string) (*oauth2.Token, error) { +func (i *inMemoryTokenCache) Lookup(key string) (*oauth2.Token, error) { token, ok := i.Tokens[key] if !ok { return nil, cache.ErrNotConfigured @@ -19,9 +19,9 @@ func (i *InMemoryTokenCache) Lookup(key string) (*oauth2.Token, error) { } // Store implements TokenCache. -func (i *InMemoryTokenCache) Store(key string, t *oauth2.Token) error { +func (i *inMemoryTokenCache) Store(key string, t *oauth2.Token) error { i.Tokens[key] = t return nil } -var _ cache.TokenCache = (*InMemoryTokenCache)(nil) +var _ cache.TokenCache = (*inMemoryTokenCache)(nil) diff --git a/cmd/auth/token_test.go b/cmd/auth/token_test.go index 63d56b2b28..feb0b0ae59 100644 --- a/cmd/auth/token_test.go +++ b/cmd/auth/token_test.go @@ -83,7 +83,7 @@ func TestToken_loadToken(t *testing.T) { }, }, } - tokenCache := &InMemoryTokenCache{ + tokenCache := &inMemoryTokenCache{ Tokens: map[string]*oauth2.Token{ "https://accounts.cloud.databricks.com/oidc/accounts/expired": { RefreshToken: "expired",