diff --git a/cmd/auth/auth.go b/cmd/auth/auth.go index 4261e93e7a..9096573a0c 100644 --- a/cmd/auth/auth.go +++ b/cmd/auth/auth.go @@ -22,14 +22,14 @@ Azure: https://learn.microsoft.com/azure/databricks/dev-tools/auth GCP: https://docs.gcp.databricks.com/dev-tools/auth/index.html`, } - var perisistentAuth auth.PersistentAuth - cmd.PersistentFlags().StringVar(&perisistentAuth.Host, "host", perisistentAuth.Host, "Databricks Host") - cmd.PersistentFlags().StringVar(&perisistentAuth.AccountID, "account-id", perisistentAuth.AccountID, "Databricks Account ID") + var authArguments auth.AuthArguments + cmd.PersistentFlags().StringVar(&authArguments.Host, "host", "", "Databricks Host") + cmd.PersistentFlags().StringVar(&authArguments.AccountID, "account-id", "", "Databricks Account ID") cmd.AddCommand(newEnvCommand()) - cmd.AddCommand(newLoginCommand(&perisistentAuth)) + cmd.AddCommand(newLoginCommand(&authArguments)) cmd.AddCommand(newProfilesCommand()) - cmd.AddCommand(newTokenCommand(&perisistentAuth)) + cmd.AddCommand(newTokenCommand(&authArguments)) cmd.AddCommand(newDescribeCommand()) return cmd } diff --git a/cmd/auth/in_memory_test.go b/cmd/auth/in_memory_test.go new file mode 100644 index 0000000000..212b2ed91f --- /dev/null +++ b/cmd/auth/in_memory_test.go @@ -0,0 +1,27 @@ +package auth + +import ( + "github.com/databricks/databricks-sdk-go/credentials/u2m/cache" + "golang.org/x/oauth2" +) + +type inMemoryTokenCache struct { + Tokens map[string]*oauth2.Token +} + +// Lookup implements TokenCache. +func (i *inMemoryTokenCache) Lookup(key string) (*oauth2.Token, error) { + token, ok := i.Tokens[key] + if !ok { + return nil, cache.ErrNotConfigured + } + return token, nil +} + +// Store implements TokenCache. +func (i *inMemoryTokenCache) Store(key string, t *oauth2.Token) error { + i.Tokens[key] = t + return nil +} + +var _ cache.TokenCache = (*inMemoryTokenCache)(nil) diff --git a/cmd/auth/login.go b/cmd/auth/login.go index a6d0bf4cc7..9fd1ee519e 100644 --- a/cmd/auth/login.go +++ b/cmd/auth/login.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "runtime" + "strings" "time" "github.com/databricks/cli/libs/auth" @@ -14,6 +15,7 @@ import ( "github.com/databricks/cli/libs/databrickscfg/profile" "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/config" + "github.com/databricks/databricks-sdk-go/credentials/u2m" "github.com/spf13/cobra" ) @@ -34,7 +36,7 @@ const ( defaultTimeout = 1 * time.Hour ) -func newLoginCommand(persistentAuth *auth.PersistentAuth) *cobra.Command { +func newLoginCommand(authArguments *auth.AuthArguments) *cobra.Command { defaultConfigPath := "~/.databrickscfg" if runtime.GOOS == "windows" { defaultConfigPath = "%USERPROFILE%\\.databrickscfg" @@ -98,14 +100,22 @@ depends on the existing profiles you have set in your configuration file // If the user has not specified a profile name, prompt for one. if profileName == "" { var err error - profileName, err = promptForProfile(ctx, persistentAuth.ProfileName()) + profileName, err = promptForProfile(ctx, getProfileName(authArguments)) if err != nil { return err } } // Set the host and account-id based on the provided arguments and flags. - err := setHostAndAccountId(ctx, profileName, persistentAuth, args) + err := setHostAndAccountId(ctx, profile.DefaultProfiler, profileName, authArguments, args) + if err != nil { + return err + } + oauthArgument, err := authArguments.ToOAuthArgument() + if err != nil { + return err + } + persistentAuth, err := u2m.NewPersistentAuth(ctx, u2m.WithOAuthArgument(oauthArgument)) if err != nil { return err } @@ -114,16 +124,15 @@ depends on the existing profiles you have set in your configuration file // We need the config without the profile before it's used to initialise new workspace client below. // Otherwise it will complain about non existing profile because it was not yet saved. cfg := config.Config{ - Host: persistentAuth.Host, - AccountID: persistentAuth.AccountID, + Host: authArguments.Host, + AccountID: authArguments.AccountID, AuthType: "databricks-cli", } ctx, cancel := context.WithTimeout(ctx, loginTimeout) defer cancel() - err = persistentAuth.Challenge(ctx) - if err != nil { + if err = persistentAuth.Challenge(); err != nil { return err } @@ -173,13 +182,13 @@ depends on the existing profiles you have set in your configuration file // 1. --account-id flag. // 2. account-id from the specified profile, if available. // 3. Prompt the user for the account-id. -func setHostAndAccountId(ctx context.Context, profileName string, persistentAuth *auth.PersistentAuth, args []string) error { +func setHostAndAccountId(ctx context.Context, profiler profile.Profiler, profileName string, authArguments *auth.AuthArguments, args []string) error { // If both [HOST] and --host are provided, return an error. - if len(args) > 0 && persistentAuth.Host != "" { + host := authArguments.Host + if len(args) > 0 && host != "" { return errors.New("please only provide a host as an argument or a flag, not both") } - profiler := profile.GetProfiler(ctx) // If the chosen profile has a hostname and the user hasn't specified a host, infer the host from the profile. profiles, err := profiler.LoadProfiles(ctx, profile.WithName(profileName)) // Tolerate ErrNoConfiguration here, as we will write out a configuration as part of the login flow. @@ -187,13 +196,13 @@ func setHostAndAccountId(ctx context.Context, profileName string, persistentAuth return err } - if persistentAuth.Host == "" { + if host == "" { if len(args) > 0 { // If [HOST] is provided, set the host to the provided positional argument. - persistentAuth.Host = args[0] + authArguments.Host = args[0] } else if len(profiles) > 0 && profiles[0].Host != "" { // If neither [HOST] nor --host are provided, and the profile has a host, use it. - persistentAuth.Host = profiles[0].Host + authArguments.Host = profiles[0].Host } else { // If neither [HOST] nor --host are provided, and the profile does not have a host, // then prompt the user for a host. @@ -201,16 +210,17 @@ func setHostAndAccountId(ctx context.Context, profileName string, persistentAuth if err != nil { return err } - persistentAuth.Host = hostName + authArguments.Host = hostName } } // If the account-id was not provided as a cmd line flag, try to read it from // the specified profile. - isAccountClient := (&config.Config{Host: persistentAuth.Host}).IsAccountClient() - if isAccountClient && persistentAuth.AccountID == "" { + isAccountClient := (&config.Config{Host: authArguments.Host}).IsAccountClient() + accountID := authArguments.AccountID + if isAccountClient && accountID == "" { if len(profiles) > 0 && profiles[0].AccountID != "" { - persistentAuth.AccountID = profiles[0].AccountID + authArguments.AccountID = profiles[0].AccountID } else { // Prompt user for the account-id if it we could not get it from a // profile. @@ -218,8 +228,20 @@ func setHostAndAccountId(ctx context.Context, profileName string, persistentAuth if err != nil { return err } - persistentAuth.AccountID = accountId + authArguments.AccountID = accountId } } return nil } + +// getProfileName returns the default profile name for a given host/account ID. +// If the account ID is provided, the profile name is "ACCOUNT-". +// Otherwise, the profile name is the first part of the host URL. +func getProfileName(authArguments *auth.AuthArguments) string { + if authArguments.AccountID != "" { + return "ACCOUNT-" + authArguments.AccountID + } + host := strings.TrimPrefix(authArguments.Host, "https://") + split := strings.Split(host, ".") + return split[0] +} diff --git a/cmd/auth/login_test.go b/cmd/auth/login_test.go index d0fa5a16b8..8412dab2ce 100644 --- a/cmd/auth/login_test.go +++ b/cmd/auth/login_test.go @@ -6,6 +6,7 @@ import ( "github.com/databricks/cli/libs/auth" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/databrickscfg/profile" "github.com/databricks/cli/libs/env" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -14,72 +15,72 @@ import ( func TestSetHostDoesNotFailWithNoDatabrickscfg(t *testing.T) { ctx := context.Background() ctx = env.Set(ctx, "DATABRICKS_CONFIG_FILE", "./imaginary-file/databrickscfg") - err := setHostAndAccountId(ctx, "foo", &auth.PersistentAuth{Host: "test"}, []string{}) + err := setHostAndAccountId(ctx, profile.DefaultProfiler, "foo", &auth.AuthArguments{Host: "test"}, []string{}) assert.NoError(t, err) } func TestSetHost(t *testing.T) { - var persistentAuth auth.PersistentAuth + authArguments := auth.AuthArguments{} t.Setenv("DATABRICKS_CONFIG_FILE", "./testdata/.databrickscfg") ctx, _ := cmdio.SetupTest(context.Background()) // Test error when both flag and argument are provided - persistentAuth.Host = "val from --host" - err := setHostAndAccountId(ctx, "profile-1", &persistentAuth, []string{"val from [HOST]"}) + authArguments.Host = "val from --host" + err := setHostAndAccountId(ctx, profile.DefaultProfiler, "profile-1", &authArguments, []string{"val from [HOST]"}) assert.EqualError(t, err, "please only provide a host as an argument or a flag, not both") // Test setting host from flag - persistentAuth.Host = "val from --host" - err = setHostAndAccountId(ctx, "profile-1", &persistentAuth, []string{}) + authArguments.Host = "val from --host" + err = setHostAndAccountId(ctx, profile.DefaultProfiler, "profile-1", &authArguments, []string{}) assert.NoError(t, err) - assert.Equal(t, "val from --host", persistentAuth.Host) + assert.Equal(t, "val from --host", authArguments.Host) // Test setting host from argument - persistentAuth.Host = "" - err = setHostAndAccountId(ctx, "profile-1", &persistentAuth, []string{"val from [HOST]"}) + authArguments.Host = "" + err = setHostAndAccountId(ctx, profile.DefaultProfiler, "profile-1", &authArguments, []string{"val from [HOST]"}) assert.NoError(t, err) - assert.Equal(t, "val from [HOST]", persistentAuth.Host) + assert.Equal(t, "val from [HOST]", authArguments.Host) // Test setting host from profile - persistentAuth.Host = "" - err = setHostAndAccountId(ctx, "profile-1", &persistentAuth, []string{}) + authArguments.Host = "" + err = setHostAndAccountId(ctx, profile.DefaultProfiler, "profile-1", &authArguments, []string{}) assert.NoError(t, err) - assert.Equal(t, "https://www.host1.com", persistentAuth.Host) + assert.Equal(t, "https://www.host1.com", authArguments.Host) // Test setting host from profile - persistentAuth.Host = "" - err = setHostAndAccountId(ctx, "profile-2", &persistentAuth, []string{}) + authArguments.Host = "" + err = setHostAndAccountId(ctx, profile.DefaultProfiler, "profile-2", &authArguments, []string{}) assert.NoError(t, err) - assert.Equal(t, "https://www.host2.com", persistentAuth.Host) + assert.Equal(t, "https://www.host2.com", authArguments.Host) // Test host is not set. Should prompt. - persistentAuth.Host = "" - err = setHostAndAccountId(ctx, "", &persistentAuth, []string{}) + authArguments.Host = "" + err = setHostAndAccountId(ctx, profile.DefaultProfiler, "", &authArguments, []string{}) assert.EqualError(t, err, "the command is being run in a non-interactive environment, please specify a host using --host") } func TestSetAccountId(t *testing.T) { - var persistentAuth auth.PersistentAuth + var authArguments auth.AuthArguments t.Setenv("DATABRICKS_CONFIG_FILE", "./testdata/.databrickscfg") ctx, _ := cmdio.SetupTest(context.Background()) // Test setting account-id from flag - persistentAuth.AccountID = "val from --account-id" - err := setHostAndAccountId(ctx, "account-profile", &persistentAuth, []string{}) + authArguments.AccountID = "val from --account-id" + err := setHostAndAccountId(ctx, profile.DefaultProfiler, "account-profile", &authArguments, []string{}) assert.NoError(t, err) - assert.Equal(t, "https://accounts.cloud.databricks.com", persistentAuth.Host) - assert.Equal(t, "val from --account-id", persistentAuth.AccountID) + assert.Equal(t, "https://accounts.cloud.databricks.com", authArguments.Host) + assert.Equal(t, "val from --account-id", authArguments.AccountID) // Test setting account_id from profile - persistentAuth.AccountID = "" - err = setHostAndAccountId(ctx, "account-profile", &persistentAuth, []string{}) + authArguments.AccountID = "" + err = setHostAndAccountId(ctx, profile.DefaultProfiler, "account-profile", &authArguments, []string{}) require.NoError(t, err) - assert.Equal(t, "https://accounts.cloud.databricks.com", persistentAuth.Host) - assert.Equal(t, "id-from-profile", persistentAuth.AccountID) + assert.Equal(t, "https://accounts.cloud.databricks.com", authArguments.Host) + assert.Equal(t, "id-from-profile", authArguments.AccountID) // Neither flag nor profile account-id is set, should prompt - persistentAuth.AccountID = "" - persistentAuth.Host = "https://accounts.cloud.databricks.com" - err = setHostAndAccountId(ctx, "", &persistentAuth, []string{}) + authArguments.AccountID = "" + authArguments.Host = "https://accounts.cloud.databricks.com" + err = setHostAndAccountId(ctx, profile.DefaultProfiler, "", &authArguments, []string{}) assert.EqualError(t, err, "the command is being run in a non-interactive environment, please specify an account ID using --account-id") } diff --git a/cmd/auth/token.go b/cmd/auth/token.go index f3468df402..dd0f95732e 100644 --- a/cmd/auth/token.go +++ b/cmd/auth/token.go @@ -5,44 +5,21 @@ import ( "encoding/json" "errors" "fmt" - "os" - "strings" "time" "github.com/databricks/cli/libs/auth" - "github.com/databricks/databricks-sdk-go/httpclient" + "github.com/databricks/cli/libs/databrickscfg/profile" + "github.com/databricks/databricks-sdk-go/credentials/u2m" "github.com/spf13/cobra" + "golang.org/x/oauth2" ) -type tokenErrorResponse struct { - Error string `json:"error"` - ErrorDescription string `json:"error_description"` -} - -func buildLoginCommand(profile string, persistentAuth *auth.PersistentAuth) string { - executable := os.Args[0] - cmd := []string{ - executable, - "auth", - "login", - } - if profile != "" { - cmd = append(cmd, "--profile", profile) - } else { - cmd = append(cmd, "--host", persistentAuth.Host) - if persistentAuth.AccountID != "" { - cmd = append(cmd, "--account-id", persistentAuth.AccountID) - } - } - return strings.Join(cmd, " ") -} - -func helpfulError(profile string, persistentAuth *auth.PersistentAuth) string { - loginMsg := buildLoginCommand(profile, persistentAuth) +func helpfulError(ctx context.Context, profile string, persistentAuth u2m.OAuthArgument) string { + loginMsg := auth.BuildLoginCommand(ctx, profile, persistentAuth) return fmt.Sprintf("Try logging in again with `%s` before retrying. If this fails, please report this issue to the Databricks CLI maintainers at https://github.com/databricks/cli/issues/new", loginMsg) } -func newTokenCommand(persistentAuth *auth.PersistentAuth) *cobra.Command { +func newTokenCommand(authArguments *auth.AuthArguments) *cobra.Command { cmd := &cobra.Command{ Use: "token [HOST]", Short: "Get authentication token", @@ -58,42 +35,23 @@ using a client ID and secret is not supported.`, cmd.RunE = func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() - - var profileName string + profileName := "" profileFlag := cmd.Flag("profile") if profileFlag != nil { profileName = profileFlag.Value.String() - // If a profile is provided we read the host from the .databrickscfg file - if profileName != "" && len(args) > 0 { - return errors.New("providing both a profile and host is not supported") - } } - err := setHostAndAccountId(ctx, profileName, persistentAuth, args) + t, err := loadToken(ctx, loadTokenArgs{ + authArguments: authArguments, + profileName: profileName, + args: args, + tokenTimeout: tokenTimeout, + profiler: profile.DefaultProfiler, + persistentAuthOpts: nil, + }) if err != nil { return err } - defer persistentAuth.Close() - - ctx, cancel := context.WithTimeout(ctx, tokenTimeout) - defer cancel() - t, err := persistentAuth.Load(ctx) - var httpErr *httpclient.HttpError - if errors.As(err, &httpErr) { - helpMsg := helpfulError(profileName, persistentAuth) - t := &tokenErrorResponse{} - err = json.Unmarshal([]byte(httpErr.Message), t) - if err != nil { - return fmt.Errorf("unexpected parsing token response: %w. %s", err, helpMsg) - } - if t.ErrorDescription == "Refresh token is invalid" { - return fmt.Errorf("a new access token could not be retrieved because the refresh token is invalid. To reauthenticate, run `%s`", buildLoginCommand(profileName, persistentAuth)) - } else { - return fmt.Errorf("unexpected error refreshing token: %s. %s", t.ErrorDescription, helpMsg) - } - } else if err != nil { - return fmt.Errorf("unexpected error refreshing token: %w. %s", err, helpfulError(profileName, persistentAuth)) - } raw, err := json.MarshalIndent(t, "", " ") if err != nil { return err @@ -104,3 +62,60 @@ using a client ID and secret is not supported.`, return cmd } + +type loadTokenArgs struct { + // authArguments is the parsed auth arguments, including the host and optionally the account ID. + authArguments *auth.AuthArguments + + // profileName is the name of the specified profile. If no profile is specified, this is an empty string. + profileName string + + // args is the list of arguments passed to the command. + args []string + + // tokenTimeout is the timeout for retrieving (and potentially refreshing) an OAuth token. + tokenTimeout time.Duration + + // profiler is the profiler to use for reading the host and account ID from the .databrickscfg file. + profiler profile.Profiler + + // persistentAuthOpts are the options to pass to the persistent auth client. + persistentAuthOpts []u2m.PersistentAuthOption +} + +// loadToken loads an OAuth token from the persistent auth store. The host and account ID are read from +// the provided profiler if not explicitly provided. If the token cannot be refreshed, a helpful error message +// is printed to the user with steps to reauthenticate. +func loadToken(ctx context.Context, args loadTokenArgs) (*oauth2.Token, error) { + // If a profile is provided we read the host from the .databrickscfg file + if args.profileName != "" && len(args.args) > 0 { + return nil, errors.New("providing both a profile and host is not supported") + } + + err := setHostAndAccountId(ctx, args.profiler, args.profileName, args.authArguments, args.args) + if err != nil { + return nil, err + } + + ctx, cancel := context.WithTimeout(ctx, args.tokenTimeout) + defer cancel() + oauthArgument, err := args.authArguments.ToOAuthArgument() + if err != nil { + return nil, err + } + allArgs := append(args.persistentAuthOpts, u2m.WithOAuthArgument(oauthArgument)) + persistentAuth, err := u2m.NewPersistentAuth(ctx, allArgs...) + if err != nil { + helpMsg := helpfulError(ctx, args.profileName, oauthArgument) + return nil, fmt.Errorf("%w. %s", err, helpMsg) + } + t, err := persistentAuth.Token() + if err != nil { + if err, ok := auth.RewriteAuthError(ctx, args.authArguments.Host, args.authArguments.AccountID, args.profileName, err); ok { + return nil, err + } + helpMsg := helpfulError(ctx, args.profileName, oauthArgument) + return nil, fmt.Errorf("%w. %s", err, helpMsg) + } + return t, nil +} diff --git a/cmd/auth/token_test.go b/cmd/auth/token_test.go index f47b419905..feb0b0ae59 100644 --- a/cmd/auth/token_test.go +++ b/cmd/auth/token_test.go @@ -1,20 +1,15 @@ -package auth_test +package auth import ( - "bytes" "context" - "encoding/json" + "net/http" "testing" "time" - "github.com/databricks/cli/cmd" - "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/libs/auth" - "github.com/databricks/cli/libs/auth/cache" "github.com/databricks/cli/libs/databrickscfg/profile" - "github.com/databricks/databricks-sdk-go/httpclient" + "github.com/databricks/databricks-sdk-go/credentials/u2m" "github.com/databricks/databricks-sdk-go/httpclient/fixtures" - "github.com/spf13/cobra" "github.com/stretchr/testify/assert" "golang.org/x/oauth2" ) @@ -30,7 +25,7 @@ var refreshFailureTokenResponse = fixtures.HTTPFixture{ var refreshFailureInvalidResponse = fixtures.HTTPFixture{ MatchAny: true, - Status: 401, + Status: 200, Response: "Not json", } @@ -53,15 +48,27 @@ var refreshSuccessTokenResponse = fixtures.HTTPFixture{ }, } -func validateToken(t *testing.T, resp string) { - res := map[string]string{} - err := json.Unmarshal([]byte(resp), &res) - assert.NoError(t, err) - assert.Equal(t, "new-access-token", res["access_token"]) - assert.Equal(t, "Bearer", res["token_type"]) +type MockApiClient struct{} + +// GetAccountOAuthEndpoints implements u2m.OAuthEndpointSupplier. +func (m *MockApiClient) GetAccountOAuthEndpoints(ctx context.Context, accountHost, accountId string) (*u2m.OAuthAuthorizationServer, error) { + return &u2m.OAuthAuthorizationServer{ + TokenEndpoint: accountHost + "/token", + AuthorizationEndpoint: accountHost + "/authorize", + }, nil +} + +// GetWorkspaceOAuthEndpoints implements u2m.OAuthEndpointSupplier. +func (m *MockApiClient) GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*u2m.OAuthAuthorizationServer, error) { + return &u2m.OAuthAuthorizationServer{ + TokenEndpoint: workspaceHost + "/token", + AuthorizationEndpoint: workspaceHost + "/authorize", + }, nil } -func getContextForTest(f fixtures.HTTPFixture) context.Context { +var _ u2m.OAuthEndpointSupplier = (*MockApiClient)(nil) + +func TestToken_loadToken(t *testing.T) { profiler := profile.InMemoryProfiler{ Profiles: profile.Profiles{ { @@ -76,7 +83,7 @@ func getContextForTest(f fixtures.HTTPFixture) context.Context { }, }, } - tokenCache := &cache.InMemoryTokenCache{ + tokenCache := &inMemoryTokenCache{ Tokens: map[string]*oauth2.Token{ "https://accounts.cloud.databricks.com/oidc/accounts/expired": { RefreshToken: "expired", @@ -87,83 +94,130 @@ func getContextForTest(f fixtures.HTTPFixture) context.Context { }, }, } - client := httpclient.NewApiClient(httpclient.ClientConfig{ - Transport: fixtures.SliceTransport{f}, - }) - ctx := profile.WithProfiler(context.Background(), profiler) - ctx = cache.WithTokenCache(ctx, tokenCache) - ctx = auth.WithApiClientForOAuth(ctx, client) - return ctx -} - -func getCobraCmdForTest(f fixtures.HTTPFixture) (*cobra.Command, *bytes.Buffer) { - ctx := getContextForTest(f) - c := cmd.New(ctx) - output := &bytes.Buffer{} - c.SetOut(output) - return c, output -} - -func TestTokenCmdWithProfilePrintsHelpfulLoginMessageOnRefreshFailure(t *testing.T) { - cmd, output := getCobraCmdForTest(refreshFailureTokenResponse) - cmd.SetArgs([]string{"auth", "token", "--profile", "expired"}) - err := root.Execute(cmd.Context(), cmd) - - out := output.String() - assert.Empty(t, out) - assert.ErrorContains(t, err, "a new access token could not be retrieved because the refresh token is invalid. To reauthenticate, run ") - assert.ErrorContains(t, err, "auth login --profile expired") -} - -func TestTokenCmdWithHostPrintsHelpfulLoginMessageOnRefreshFailure(t *testing.T) { - cmd, output := getCobraCmdForTest(refreshFailureTokenResponse) - cmd.SetArgs([]string{"auth", "token", "--host", "https://accounts.cloud.databricks.com", "--account-id", "expired"}) - err := root.Execute(cmd.Context(), cmd) - - out := output.String() - assert.Empty(t, out) - assert.ErrorContains(t, err, "a new access token could not be retrieved because the refresh token is invalid. To reauthenticate, run ") - assert.ErrorContains(t, err, "auth login --host https://accounts.cloud.databricks.com --account-id expired") -} - -func TestTokenCmdInvalidResponse(t *testing.T) { - cmd, output := getCobraCmdForTest(refreshFailureInvalidResponse) - cmd.SetArgs([]string{"auth", "token", "--profile", "active"}) - err := root.Execute(cmd.Context(), cmd) - - out := output.String() - assert.Empty(t, out) - assert.ErrorContains(t, err, "unexpected parsing token response: invalid character 'N' looking for beginning of value. Try logging in again with ") - assert.ErrorContains(t, err, "auth login --profile active` before retrying. If this fails, please report this issue to the Databricks CLI maintainers at https://github.com/databricks/cli/issues/new") -} - -func TestTokenCmdOtherErrorResponse(t *testing.T) { - cmd, output := getCobraCmdForTest(refreshFailureOtherError) - cmd.SetArgs([]string{"auth", "token", "--profile", "active"}) - err := root.Execute(cmd.Context(), cmd) - - out := output.String() - assert.Empty(t, out) - assert.ErrorContains(t, err, "unexpected error refreshing token: Databricks is down. Try logging in again with ") - assert.ErrorContains(t, err, "auth login --profile active` before retrying. If this fails, please report this issue to the Databricks CLI maintainers at https://github.com/databricks/cli/issues/new") -} - -func TestTokenCmdWithProfileSuccess(t *testing.T) { - cmd, output := getCobraCmdForTest(refreshSuccessTokenResponse) - cmd.SetArgs([]string{"auth", "token", "--profile", "active"}) - err := root.Execute(cmd.Context(), cmd) - - out := output.String() - validateToken(t, out) - assert.NoError(t, err) -} - -func TestTokenCmdWithHostSuccess(t *testing.T) { - cmd, output := getCobraCmdForTest(refreshSuccessTokenResponse) - cmd.SetArgs([]string{"auth", "token", "--host", "https://accounts.cloud.databricks.com", "--account-id", "expired"}) - err := root.Execute(cmd.Context(), cmd) + validateToken := func(resp *oauth2.Token) { + assert.Equal(t, "new-access-token", resp.AccessToken) + assert.Equal(t, "Bearer", resp.TokenType) + } - out := output.String() - validateToken(t, out) - assert.NoError(t, err) + cases := []struct { + name string + args loadTokenArgs + validateToken func(*oauth2.Token) + wantErr string + }{ + { + name: "prints helpful login message on refresh failure when profile is specified", + args: loadTokenArgs{ + authArguments: &auth.AuthArguments{}, + profileName: "expired", + args: []string{}, + tokenTimeout: 1 * time.Hour, + profiler: profiler, + persistentAuthOpts: []u2m.PersistentAuthOption{ + u2m.WithTokenCache(tokenCache), + u2m.WithOAuthEndpointSupplier(&MockApiClient{}), + u2m.WithHttpClient(&http.Client{Transport: fixtures.SliceTransport{refreshFailureTokenResponse}}), + }, + }, + wantErr: `A new access token could not be retrieved because the refresh token is invalid. To reauthenticate, run the following command: + $ databricks auth login --profile expired`, + }, + { + name: "prints helpful login message on refresh failure when host is specified", + args: loadTokenArgs{ + authArguments: &auth.AuthArguments{ + Host: "https://accounts.cloud.databricks.com", + AccountID: "expired", + }, + profileName: "", + args: []string{}, + tokenTimeout: 1 * time.Hour, + profiler: profiler, + persistentAuthOpts: []u2m.PersistentAuthOption{ + u2m.WithTokenCache(tokenCache), + u2m.WithOAuthEndpointSupplier(&MockApiClient{}), + u2m.WithHttpClient(&http.Client{Transport: fixtures.SliceTransport{refreshFailureTokenResponse}}), + }, + }, + wantErr: `A new access token could not be retrieved because the refresh token is invalid. To reauthenticate, run the following command: + $ databricks auth login --host https://accounts.cloud.databricks.com --account-id expired`, + }, + { + name: "prints helpful login message on invalid response", + args: loadTokenArgs{ + authArguments: &auth.AuthArguments{}, + profileName: "active", + args: []string{}, + tokenTimeout: 1 * time.Hour, + profiler: profiler, + persistentAuthOpts: []u2m.PersistentAuthOption{ + u2m.WithTokenCache(tokenCache), + u2m.WithOAuthEndpointSupplier(&MockApiClient{}), + u2m.WithHttpClient(&http.Client{Transport: fixtures.SliceTransport{refreshFailureInvalidResponse}}), + }, + }, + wantErr: "token refresh: oauth2: cannot parse json: invalid character 'N' looking for beginning of value. Try logging in again with " + + "`databricks auth login --profile active` before retrying. If this fails, please report this issue to the Databricks CLI maintainers at https://github.com/databricks/cli/issues/new", + }, + { + name: "prints helpful login message on other error response", + args: loadTokenArgs{ + authArguments: &auth.AuthArguments{}, + profileName: "active", + args: []string{}, + tokenTimeout: 1 * time.Hour, + profiler: profiler, + persistentAuthOpts: []u2m.PersistentAuthOption{ + u2m.WithTokenCache(tokenCache), + u2m.WithOAuthEndpointSupplier(&MockApiClient{}), + u2m.WithHttpClient(&http.Client{Transport: fixtures.SliceTransport{refreshFailureOtherError}}), + }, + }, + wantErr: "token refresh: Databricks is down (error code: other_error). Try logging in again with " + + "`databricks auth login --profile active` before retrying. If this fails, please report this issue to the Databricks CLI maintainers at https://github.com/databricks/cli/issues/new", + }, + { + name: "succeeds with profile", + args: loadTokenArgs{ + authArguments: &auth.AuthArguments{}, + profileName: "active", + args: []string{}, + tokenTimeout: 1 * time.Hour, + profiler: profiler, + persistentAuthOpts: []u2m.PersistentAuthOption{ + u2m.WithTokenCache(tokenCache), + u2m.WithOAuthEndpointSupplier(&MockApiClient{}), + u2m.WithHttpClient(&http.Client{Transport: fixtures.SliceTransport{refreshSuccessTokenResponse}}), + }, + }, + validateToken: validateToken, + }, + { + name: "succeeds with host", + args: loadTokenArgs{ + authArguments: &auth.AuthArguments{Host: "https://accounts.cloud.databricks.com", AccountID: "active"}, + profileName: "", + args: []string{}, + tokenTimeout: 1 * time.Hour, + profiler: profiler, + persistentAuthOpts: []u2m.PersistentAuthOption{ + u2m.WithTokenCache(tokenCache), + u2m.WithOAuthEndpointSupplier(&MockApiClient{}), + u2m.WithHttpClient(&http.Client{Transport: fixtures.SliceTransport{refreshSuccessTokenResponse}}), + }, + }, + validateToken: validateToken, + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + got, err := loadToken(context.Background(), c.args) + if c.wantErr != "" { + assert.Equal(t, c.wantErr, err.Error()) + } else { + assert.NoError(t, err) + c.validateToken(got) + } + }) + } } diff --git a/cmd/root/auth.go b/cmd/root/auth.go index 15e32b9cd4..406631d24b 100644 --- a/cmd/root/auth.go +++ b/cmd/root/auth.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" + "github.com/databricks/cli/libs/auth" "github.com/databricks/cli/libs/cmdctx" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/databrickscfg/profile" @@ -137,7 +138,7 @@ func MustAccountClient(cmd *cobra.Command, args []string) error { allowPrompt := !hasProfileFlag && !shouldSkipPrompt(cmd.Context()) a, err := accountClientOrPrompt(cmd.Context(), cfg, allowPrompt) if err != nil { - return err + return renderError(ctx, cfg, err) } ctx = cmdctx.SetAccountClient(ctx, a) @@ -220,7 +221,7 @@ func MustWorkspaceClient(cmd *cobra.Command, args []string) error { allowPrompt := !hasProfileFlag && !shouldSkipPrompt(cmd.Context()) w, err := workspaceClientOrPrompt(cmd.Context(), cfg, allowPrompt) if err != nil { - return err + return renderError(ctx, cfg, err) } ctx = cmdctx.SetWorkspaceClient(ctx, w) @@ -306,3 +307,8 @@ func emptyHttpRequest(ctx context.Context) *http.Request { } return req } + +func renderError(ctx context.Context, cfg *config.Config, err error) error { + err, _ = auth.RewriteAuthError(ctx, cfg.Host, cfg.AccountID, cfg.Profile, err) + return err +} diff --git a/libs/auth/arguments.go b/libs/auth/arguments.go new file mode 100644 index 0000000000..17957ec511 --- /dev/null +++ b/libs/auth/arguments.go @@ -0,0 +1,25 @@ +package auth + +import ( + "github.com/databricks/databricks-sdk-go/config" + "github.com/databricks/databricks-sdk-go/credentials/u2m" +) + +// AuthArguments is a struct that contains the common arguments passed to +// `databricks auth` commands. +type AuthArguments struct { + Host string + AccountID string +} + +// ToOAuthArgument converts the AuthArguments to an OAuthArgument from the Go SDK. +func (a AuthArguments) ToOAuthArgument() (u2m.OAuthArgument, error) { + cfg := &config.Config{ + Host: a.Host, + AccountID: a.AccountID, + } + if cfg.IsAccountClient() { + return u2m.NewBasicAccountOAuthArgument(cfg.Host, cfg.AccountID) + } + return u2m.NewBasicWorkspaceOAuthArgument(cfg.Host) +} diff --git a/libs/auth/cache/cache.go b/libs/auth/cache/cache.go deleted file mode 100644 index 097353e74c..0000000000 --- a/libs/auth/cache/cache.go +++ /dev/null @@ -1,26 +0,0 @@ -package cache - -import ( - "context" - - "golang.org/x/oauth2" -) - -type TokenCache interface { - Store(key string, t *oauth2.Token) error - Lookup(key string) (*oauth2.Token, error) -} - -var tokenCache int - -func WithTokenCache(ctx context.Context, c TokenCache) context.Context { - return context.WithValue(ctx, &tokenCache, c) -} - -func GetTokenCache(ctx context.Context) TokenCache { - c, ok := ctx.Value(&tokenCache).(TokenCache) - if !ok { - return &FileTokenCache{} - } - return c -} diff --git a/libs/auth/cache/file.go b/libs/auth/cache/file.go deleted file mode 100644 index 38dfea9f2c..0000000000 --- a/libs/auth/cache/file.go +++ /dev/null @@ -1,108 +0,0 @@ -package cache - -import ( - "encoding/json" - "errors" - "fmt" - "io/fs" - "os" - "path/filepath" - - "golang.org/x/oauth2" -) - -const ( - // where the token cache is stored - tokenCacheFile = ".databricks/token-cache.json" - - // only the owner of the file has full execute, read, and write access - ownerExecReadWrite = 0o700 - - // only the owner of the file has full read and write access - ownerReadWrite = 0o600 - - // format versioning leaves some room for format improvement - tokenCacheVersion = 1 -) - -var ErrNotConfigured = errors.New("databricks OAuth is not configured for this host") - -// this implementation requires the calling code to do a machine-wide lock, -// otherwise the file might get corrupt. -type FileTokenCache struct { - Version int `json:"version"` - Tokens map[string]*oauth2.Token `json:"tokens"` - - fileLocation string -} - -func (c *FileTokenCache) Store(key string, t *oauth2.Token) error { - err := c.load() - if errors.Is(err, fs.ErrNotExist) { - dir := filepath.Dir(c.fileLocation) - err = os.MkdirAll(dir, ownerExecReadWrite) - if err != nil { - return fmt.Errorf("mkdir: %w", err) - } - } else if err != nil { - return fmt.Errorf("load: %w", err) - } - c.Version = tokenCacheVersion - if c.Tokens == nil { - c.Tokens = map[string]*oauth2.Token{} - } - c.Tokens[key] = t - raw, err := json.MarshalIndent(c, "", " ") - if err != nil { - return fmt.Errorf("marshal: %w", err) - } - return os.WriteFile(c.fileLocation, raw, ownerReadWrite) -} - -func (c *FileTokenCache) Lookup(key string) (*oauth2.Token, error) { - err := c.load() - if errors.Is(err, fs.ErrNotExist) { - return nil, ErrNotConfigured - } else if err != nil { - return nil, fmt.Errorf("load: %w", err) - } - t, ok := c.Tokens[key] - if !ok { - return nil, ErrNotConfigured - } - return t, nil -} - -func (c *FileTokenCache) location() (string, error) { - home, err := os.UserHomeDir() - if err != nil { - return "", fmt.Errorf("home: %w", err) - } - return filepath.Join(home, tokenCacheFile), nil -} - -func (c *FileTokenCache) load() error { - loc, err := c.location() - if err != nil { - return err - } - c.fileLocation = loc - raw, err := os.ReadFile(loc) - if err != nil { - return fmt.Errorf("read: %w", err) - } - err = json.Unmarshal(raw, c) - if err != nil { - return fmt.Errorf("parse: %w", err) - } - if c.Version != tokenCacheVersion { - // in the later iterations we could do state upgraders, - // so that we transform token cache from v1 to v2 without - // losing the tokens and asking the user to re-authenticate. - return fmt.Errorf("needs version %d, got version %d", - tokenCacheVersion, c.Version) - } - return nil -} - -var _ TokenCache = (*FileTokenCache)(nil) diff --git a/libs/auth/cache/file_test.go b/libs/auth/cache/file_test.go deleted file mode 100644 index 54964bed3c..0000000000 --- a/libs/auth/cache/file_test.go +++ /dev/null @@ -1,105 +0,0 @@ -package cache - -import ( - "os" - "path/filepath" - "runtime" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.org/x/oauth2" -) - -var homeEnvVar = "HOME" - -func init() { - if runtime.GOOS == "windows" { - homeEnvVar = "USERPROFILE" - } -} - -func setup(t *testing.T) string { - tempHomeDir := t.TempDir() - t.Setenv(homeEnvVar, tempHomeDir) - return tempHomeDir -} - -func TestStoreAndLookup(t *testing.T) { - setup(t) - c := &FileTokenCache{} - err := c.Store("x", &oauth2.Token{ - AccessToken: "abc", - }) - require.NoError(t, err) - - err = c.Store("y", &oauth2.Token{ - AccessToken: "bcd", - }) - require.NoError(t, err) - - l := &FileTokenCache{} - tok, err := l.Lookup("x") - require.NoError(t, err) - assert.Equal(t, "abc", tok.AccessToken) - assert.Len(t, l.Tokens, 2) - - _, err = l.Lookup("z") - assert.Equal(t, ErrNotConfigured, err) -} - -func TestNoCacheFileReturnsErrNotConfigured(t *testing.T) { - setup(t) - l := &FileTokenCache{} - _, err := l.Lookup("x") - assert.Equal(t, ErrNotConfigured, err) -} - -func TestLoadCorruptFile(t *testing.T) { - home := setup(t) - f := filepath.Join(home, tokenCacheFile) - err := os.MkdirAll(filepath.Dir(f), ownerExecReadWrite) - require.NoError(t, err) - err = os.WriteFile(f, []byte("abc"), ownerExecReadWrite) - require.NoError(t, err) - - l := &FileTokenCache{} - _, err = l.Lookup("x") - assert.EqualError(t, err, "load: parse: invalid character 'a' looking for beginning of value") -} - -func TestLoadWrongVersion(t *testing.T) { - home := setup(t) - f := filepath.Join(home, tokenCacheFile) - err := os.MkdirAll(filepath.Dir(f), ownerExecReadWrite) - require.NoError(t, err) - err = os.WriteFile(f, []byte(`{"version": 823, "things": []}`), ownerExecReadWrite) - require.NoError(t, err) - - l := &FileTokenCache{} - _, err = l.Lookup("x") - assert.EqualError(t, err, "load: needs version 1, got version 823") -} - -func TestDevNull(t *testing.T) { - t.Setenv(homeEnvVar, "/dev/null") - l := &FileTokenCache{} - _, err := l.Lookup("x") - // macOS/Linux: load: read: open /dev/null/.databricks/token-cache.json: - // windows: databricks OAuth is not configured for this host - assert.Error(t, err) -} - -func TestStoreOnDev(t *testing.T) { - if runtime.GOOS == "windows" { - t.SkipNow() - } - t.Setenv(homeEnvVar, "/dev") - c := &FileTokenCache{} - err := c.Store("x", &oauth2.Token{ - AccessToken: "abc", - }) - // Linux: permission denied - // macOS: read-only file system - assert.Error(t, err) -} diff --git a/libs/auth/cache/in_memory.go b/libs/auth/cache/in_memory.go deleted file mode 100644 index 469d45575a..0000000000 --- a/libs/auth/cache/in_memory.go +++ /dev/null @@ -1,26 +0,0 @@ -package cache - -import ( - "golang.org/x/oauth2" -) - -type InMemoryTokenCache struct { - Tokens map[string]*oauth2.Token -} - -// Lookup implements TokenCache. -func (i *InMemoryTokenCache) Lookup(key string) (*oauth2.Token, error) { - token, ok := i.Tokens[key] - if !ok { - return nil, ErrNotConfigured - } - return token, nil -} - -// Store implements TokenCache. -func (i *InMemoryTokenCache) Store(key string, t *oauth2.Token) error { - i.Tokens[key] = t - return nil -} - -var _ TokenCache = (*InMemoryTokenCache)(nil) diff --git a/libs/auth/cache/in_memory_test.go b/libs/auth/cache/in_memory_test.go deleted file mode 100644 index d8394d3b26..0000000000 --- a/libs/auth/cache/in_memory_test.go +++ /dev/null @@ -1,44 +0,0 @@ -package cache - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "golang.org/x/oauth2" -) - -func TestInMemoryCacheHit(t *testing.T) { - token := &oauth2.Token{ - AccessToken: "abc", - } - c := &InMemoryTokenCache{ - Tokens: map[string]*oauth2.Token{ - "key": token, - }, - } - res, err := c.Lookup("key") - assert.Equal(t, res, token) - assert.NoError(t, err) -} - -func TestInMemoryCacheMiss(t *testing.T) { - c := &InMemoryTokenCache{ - Tokens: map[string]*oauth2.Token{}, - } - _, err := c.Lookup("key") - assert.ErrorIs(t, err, ErrNotConfigured) -} - -func TestInMemoryCacheStore(t *testing.T) { - token := &oauth2.Token{ - AccessToken: "abc", - } - c := &InMemoryTokenCache{ - Tokens: map[string]*oauth2.Token{}, - } - err := c.Store("key", token) - assert.NoError(t, err) - res, err := c.Lookup("key") - assert.Equal(t, res, token) - assert.NoError(t, err) -} diff --git a/libs/auth/callback.go b/libs/auth/callback.go deleted file mode 100644 index 3893a5041a..0000000000 --- a/libs/auth/callback.go +++ /dev/null @@ -1,104 +0,0 @@ -package auth - -import ( - "context" - _ "embed" - "fmt" - "html/template" - "net" - "net/http" - "strings" - - "golang.org/x/text/cases" - "golang.org/x/text/language" -) - -//go:embed page.tmpl -var pageTmpl string - -type oauthResult struct { - Error string - ErrorDescription string - Host string - State string - Code string -} - -type callbackServer struct { - ln net.Listener - srv http.Server - ctx context.Context - a *PersistentAuth - renderErrCh chan error - feedbackCh chan oauthResult - tmpl *template.Template -} - -func newCallback(ctx context.Context, a *PersistentAuth) (*callbackServer, error) { - tmpl, err := template.New("page").Funcs(template.FuncMap{ - "title": func(in string) string { - title := cases.Title(language.English) - return title.String(strings.ReplaceAll(in, "_", " ")) - }, - }).Parse(pageTmpl) - if err != nil { - return nil, err - } - cb := &callbackServer{ - feedbackCh: make(chan oauthResult), - renderErrCh: make(chan error), - tmpl: tmpl, - ctx: ctx, - ln: a.ln, - a: a, - } - cb.srv.Handler = cb - go func() { - _ = cb.srv.Serve(cb.ln) - }() - return cb, nil -} - -func (cb *callbackServer) Close() error { - return cb.srv.Close() -} - -// ServeHTTP renders page.html template -func (cb *callbackServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { - res := oauthResult{ - Error: r.FormValue("error"), - ErrorDescription: r.FormValue("error_description"), - Code: r.FormValue("code"), - State: r.FormValue("state"), - Host: cb.a.Host, - } - if res.Error != "" { - w.WriteHeader(http.StatusBadRequest) - } else { - w.WriteHeader(http.StatusOK) - } - err := cb.tmpl.Execute(w, res) - if err != nil { - cb.renderErrCh <- err - } - cb.feedbackCh <- res -} - -// Handler opens up a browser waits for redirect to come back from the identity provider -func (cb *callbackServer) Handler(authCodeURL string) (string, string, error) { - err := cb.a.browser(authCodeURL) - if err != nil { - fmt.Printf("Please open %s in the browser to continue authentication", authCodeURL) - } - select { - case <-cb.ctx.Done(): - return "", "", cb.ctx.Err() - case renderErr := <-cb.renderErrCh: - return "", "", renderErr - case res := <-cb.feedbackCh: - if res.Error != "" { - return "", "", fmt.Errorf("%s: %s", res.Error, res.ErrorDescription) - } - return res.Code, res.State, nil - } -} diff --git a/libs/auth/error.go b/libs/auth/error.go new file mode 100644 index 0000000000..1bf0e9b519 --- /dev/null +++ b/libs/auth/error.go @@ -0,0 +1,45 @@ +package auth + +import ( + "context" + "errors" + "strings" + + "github.com/databricks/databricks-sdk-go/credentials/u2m" +) + +// RewriteAuthError rewrites the error message for invalid refresh token error. +// It returns the rewritten error and a boolean indicating whether the error was rewritten. +func RewriteAuthError(ctx context.Context, host, accountId, profile string, err error) (error, bool) { + target := &u2m.InvalidRefreshTokenError{} + if errors.As(err, &target) { + oauthArgument, err := AuthArguments{host, accountId}.ToOAuthArgument() + if err != nil { + return err, false + } + msg := `A new access token could not be retrieved because the refresh token is invalid. To reauthenticate, run the following command: + $ ` + BuildLoginCommand(ctx, profile, oauthArgument) + return errors.New(msg), true + } + return err, false +} + +// BuildLoginCommand builds the login command for the given OAuth argument or profile. +func BuildLoginCommand(ctx context.Context, profile string, arg u2m.OAuthArgument) string { + cmd := []string{ + "databricks", + "auth", + "login", + } + if profile != "" { + cmd = append(cmd, "--profile", profile) + } else { + switch arg := arg.(type) { + case u2m.AccountOAuthArgument: + cmd = append(cmd, "--host", arg.GetAccountHost(), "--account-id", arg.GetAccountId()) + case u2m.WorkspaceOAuthArgument: + cmd = append(cmd, "--host", arg.GetWorkspaceHost()) + } + } + return strings.Join(cmd, " ") +} diff --git a/libs/auth/oauth.go b/libs/auth/oauth.go deleted file mode 100644 index 1037a5a852..0000000000 --- a/libs/auth/oauth.go +++ /dev/null @@ -1,289 +0,0 @@ -package auth - -import ( - "context" - "crypto/rand" - "crypto/sha256" - _ "embed" - "encoding/base64" - "errors" - "fmt" - "net" - "net/url" - "strings" - "time" - - "github.com/databricks/cli/libs/auth/cache" - "github.com/databricks/databricks-sdk-go/httpclient" - "github.com/databricks/databricks-sdk-go/retries" - "github.com/pkg/browser" - "golang.org/x/oauth2" - "golang.org/x/oauth2/authhandler" -) - -var apiClientForOauth int - -func WithApiClientForOAuth(ctx context.Context, c *httpclient.ApiClient) context.Context { - return context.WithValue(ctx, &apiClientForOauth, c) -} - -func GetApiClientForOAuth(ctx context.Context) *httpclient.ApiClient { - c, ok := ctx.Value(&apiClientForOauth).(*httpclient.ApiClient) - if !ok { - return httpclient.NewApiClient(httpclient.ClientConfig{}) - } - return c -} - -const ( - // these values are predefined by Databricks as a public client - // and is specific to this application only. Using these values - // for other applications is not allowed. - appClientID = "databricks-cli" - appRedirectAddr = "localhost:8020" - - // maximum amount of time to acquire listener on appRedirectAddr - listenerTimeout = 45 * time.Second -) - -var ( // Databricks SDK API: `databricks OAuth is not` will be checked for presence - ErrOAuthNotSupported = errors.New("databricks OAuth is not supported for this host") - ErrNotConfigured = errors.New("databricks OAuth is not configured for this host") - ErrFetchCredentials = errors.New("cannot fetch credentials") -) - -type PersistentAuth struct { - Host string - AccountID string - - http *httpclient.ApiClient - cache cache.TokenCache - ln net.Listener - browser func(string) error -} - -func (a *PersistentAuth) SetApiClient(h *httpclient.ApiClient) { - a.http = h -} - -func (a *PersistentAuth) Load(ctx context.Context) (*oauth2.Token, error) { - err := a.init(ctx) - if err != nil { - return nil, fmt.Errorf("init: %w", err) - } - // lookup token identified by host (and possibly the account id) - key := a.key() - t, err := a.cache.Lookup(key) - if err != nil { - return nil, fmt.Errorf("cache: %w", err) - } - // early return for valid tokens - if t.Valid() { - // do not print refresh token to end-user - t.RefreshToken = "" - return t, nil - } - // OAuth2 config is invoked only for expired tokens to speed up - // the happy path in the token retrieval - cfg, err := a.oauth2Config(ctx) - if err != nil { - return nil, err - } - // make OAuth2 library use our client - ctx = a.http.InContextForOAuth2(ctx) - // eagerly refresh token - refreshed, err := cfg.TokenSource(ctx, t).Token() - if err != nil { - return nil, fmt.Errorf("token refresh: %w", err) - } - err = a.cache.Store(key, refreshed) - if err != nil { - return nil, fmt.Errorf("cache refresh: %w", err) - } - // do not print refresh token to end-user - refreshed.RefreshToken = "" - return refreshed, nil -} - -func (a *PersistentAuth) ProfileName() string { - if a.AccountID != "" { - return "ACCOUNT-" + a.AccountID - } - host := strings.TrimPrefix(a.Host, "https://") - split := strings.Split(host, ".") - return split[0] -} - -func (a *PersistentAuth) Challenge(ctx context.Context) error { - err := a.init(ctx) - if err != nil { - return fmt.Errorf("init: %w", err) - } - cfg, err := a.oauth2Config(ctx) - if err != nil { - return err - } - cb, err := newCallback(ctx, a) - if err != nil { - return fmt.Errorf("callback server: %w", err) - } - defer cb.Close() - state, pkce := a.stateAndPKCE() - // make OAuth2 library use our client - ctx = a.http.InContextForOAuth2(ctx) - ts := authhandler.TokenSourceWithPKCE(ctx, cfg, state, cb.Handler, pkce) - t, err := ts.Token() - if err != nil { - return fmt.Errorf("authorize: %w", err) - } - // cache token identified by host (and possibly the account id) - err = a.cache.Store(a.key(), t) - if err != nil { - return fmt.Errorf("store: %w", err) - } - return nil -} - -// This function cleans up the host URL by only retaining the scheme and the host. -// This function thus removes any path, query arguments, or fragments from the URL. -func (a *PersistentAuth) cleanHost() { - parsedHost, err := url.Parse(a.Host) - if err != nil { - return - } - // when either host or scheme is empty, we don't want to clean it. This is because - // the Go url library parses a raw "abc" string as the path of a URL and cleaning - // it will return thus return an empty string. - if parsedHost.Host == "" || parsedHost.Scheme == "" { - return - } - host := url.URL{ - Scheme: parsedHost.Scheme, - Host: parsedHost.Host, - } - a.Host = host.String() -} - -func (a *PersistentAuth) init(ctx context.Context) error { - if a.Host == "" && a.AccountID == "" { - return ErrFetchCredentials - } - if a.http == nil { - a.http = GetApiClientForOAuth(ctx) - } - if a.cache == nil { - a.cache = cache.GetTokenCache(ctx) - } - if a.browser == nil { - a.browser = browser.OpenURL - } - - a.cleanHost() - - // try acquire listener, which we also use as a machine-local - // exclusive lock to prevent token cache corruption in the scope - // of developer machine, where this command runs. - listener, err := retries.Poll(ctx, listenerTimeout, - func() (*net.Listener, *retries.Err) { - var lc net.ListenConfig - l, err := lc.Listen(ctx, "tcp", appRedirectAddr) - if err != nil { - return nil, retries.Continue(err) - } - return &l, nil - }) - if err != nil { - return fmt.Errorf("listener: %w", err) - } - a.ln = *listener - return nil -} - -func (a *PersistentAuth) Close() error { - if a.ln == nil { - return nil - } - return a.ln.Close() -} - -func (a *PersistentAuth) oidcEndpoints(ctx context.Context) (*oauthAuthorizationServer, error) { - prefix := a.key() - if a.AccountID != "" { - return &oauthAuthorizationServer{ - AuthorizationEndpoint: prefix + "/v1/authorize", - TokenEndpoint: prefix + "/v1/token", - }, nil - } - var oauthEndpoints oauthAuthorizationServer - oidc := prefix + "/oidc/.well-known/oauth-authorization-server" - err := a.http.Do(ctx, "GET", oidc, httpclient.WithResponseUnmarshal(&oauthEndpoints)) - if err != nil { - return nil, fmt.Errorf("fetch .well-known: %w", err) - } - var httpErr *httpclient.HttpError - if errors.As(err, &httpErr) && httpErr.StatusCode != 200 { - return nil, ErrOAuthNotSupported - } - return &oauthEndpoints, nil -} - -func (a *PersistentAuth) oauth2Config(ctx context.Context) (*oauth2.Config, error) { - // in this iteration of CLI, we're using all scopes by default, - // because tools like CLI and Terraform do use all apis. This - // decision may be reconsidered later, once we have a proper - // taxonomy of all scopes ready and implemented. - scopes := []string{ - "offline_access", - "all-apis", - } - endpoints, err := a.oidcEndpoints(ctx) - if err != nil { - return nil, fmt.Errorf("oidc: %w", err) - } - return &oauth2.Config{ - ClientID: appClientID, - Endpoint: oauth2.Endpoint{ - AuthURL: endpoints.AuthorizationEndpoint, - TokenURL: endpoints.TokenEndpoint, - AuthStyle: oauth2.AuthStyleInParams, - }, - RedirectURL: "http://" + appRedirectAddr, - Scopes: scopes, - }, nil -} - -// key is currently used for two purposes: OIDC URL prefix and token cache key. -// once we decide to start storing scopes in the token cache, we should change -// this approach. -func (a *PersistentAuth) key() string { - a.Host = strings.TrimSuffix(a.Host, "/") - if !strings.HasPrefix(a.Host, "http") { - a.Host = "https://" + a.Host - } - if a.AccountID != "" { - return fmt.Sprintf("%s/oidc/accounts/%s", a.Host, a.AccountID) - } - return a.Host -} - -func (a *PersistentAuth) stateAndPKCE() (string, *authhandler.PKCEParams) { - verifier := a.randomString(64) - verifierSha256 := sha256.Sum256([]byte(verifier)) - challenge := base64.RawURLEncoding.EncodeToString(verifierSha256[:]) - return a.randomString(16), &authhandler.PKCEParams{ - Challenge: challenge, - ChallengeMethod: "S256", - Verifier: verifier, - } -} - -func (a *PersistentAuth) randomString(size int) string { - raw := make([]byte, size) - _, _ = rand.Read(raw) - return base64.RawURLEncoding.EncodeToString(raw) -} - -type oauthAuthorizationServer struct { - AuthorizationEndpoint string `json:"authorization_endpoint"` // ../v1/authorize - TokenEndpoint string `json:"token_endpoint"` // ../v1/token -} diff --git a/libs/auth/oauth_test.go b/libs/auth/oauth_test.go deleted file mode 100644 index 6c3b9bf477..0000000000 --- a/libs/auth/oauth_test.go +++ /dev/null @@ -1,267 +0,0 @@ -package auth - -import ( - "context" - "crypto/tls" - _ "embed" - "fmt" - "net/http" - "net/url" - "testing" - "time" - - "github.com/databricks/databricks-sdk-go/client" - "github.com/databricks/databricks-sdk-go/httpclient" - "github.com/databricks/databricks-sdk-go/httpclient/fixtures" - "github.com/databricks/databricks-sdk-go/qa" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.org/x/oauth2" -) - -func TestOidcEndpointsForAccounts(t *testing.T) { - p := &PersistentAuth{ - Host: "abc", - AccountID: "xyz", - } - defer p.Close() - s, err := p.oidcEndpoints(context.Background()) - assert.NoError(t, err) - assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/authorize", s.AuthorizationEndpoint) - assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/token", s.TokenEndpoint) -} - -func TestOidcForWorkspace(t *testing.T) { - p := &PersistentAuth{ - Host: "abc", - http: httpclient.NewApiClient(httpclient.ClientConfig{ - Transport: fixtures.MappingTransport{ - "GET /oidc/.well-known/oauth-authorization-server": { - Status: 200, - Response: map[string]string{ - "authorization_endpoint": "a", - "token_endpoint": "b", - }, - }, - }, - }), - } - defer p.Close() - endpoints, err := p.oidcEndpoints(context.Background()) - assert.NoError(t, err) - assert.Equal(t, "a", endpoints.AuthorizationEndpoint) - assert.Equal(t, "b", endpoints.TokenEndpoint) -} - -type tokenCacheMock struct { - store func(key string, t *oauth2.Token) error - lookup func(key string) (*oauth2.Token, error) -} - -func (m *tokenCacheMock) Store(key string, t *oauth2.Token) error { - if m.store == nil { - panic("no store mock") - } - return m.store(key, t) -} - -func (m *tokenCacheMock) Lookup(key string) (*oauth2.Token, error) { - if m.lookup == nil { - panic("no lookup mock") - } - return m.lookup(key) -} - -func TestLoad(t *testing.T) { - p := &PersistentAuth{ - Host: "abc", - AccountID: "xyz", - cache: &tokenCacheMock{ - lookup: func(key string) (*oauth2.Token, error) { - assert.Equal(t, "https://abc/oidc/accounts/xyz", key) - return &oauth2.Token{ - AccessToken: "bcd", - Expiry: time.Now().Add(1 * time.Minute), - }, nil - }, - }, - } - defer p.Close() - tok, err := p.Load(context.Background()) - assert.NoError(t, err) - assert.Equal(t, "bcd", tok.AccessToken) - assert.Equal(t, "", tok.RefreshToken) -} - -func useInsecureOAuthHttpClientForTests(ctx context.Context) context.Context { - return context.WithValue(ctx, oauth2.HTTPClient, &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - }, - }) -} - -func TestLoadRefresh(t *testing.T) { - qa.HTTPFixtures{ - { - Method: "POST", - Resource: "/oidc/accounts/xyz/v1/token", - Response: `access_token=refreshed&refresh_token=def`, - }, - }.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) { - ctx = useInsecureOAuthHttpClientForTests(ctx) - expectedKey := c.Config.Host + "/oidc/accounts/xyz" - p := &PersistentAuth{ - Host: c.Config.Host, - AccountID: "xyz", - cache: &tokenCacheMock{ - lookup: func(key string) (*oauth2.Token, error) { - assert.Equal(t, expectedKey, key) - return &oauth2.Token{ - AccessToken: "expired", - RefreshToken: "cde", - Expiry: time.Now().Add(-1 * time.Minute), - }, nil - }, - store: func(key string, tok *oauth2.Token) error { - assert.Equal(t, expectedKey, key) - assert.Equal(t, "def", tok.RefreshToken) - return nil - }, - }, - } - defer p.Close() - tok, err := p.Load(ctx) - assert.NoError(t, err) - assert.Equal(t, "refreshed", tok.AccessToken) - assert.Equal(t, "", tok.RefreshToken) - }) -} - -func TestChallenge(t *testing.T) { - qa.HTTPFixtures{ - { - Method: "POST", - Resource: "/oidc/accounts/xyz/v1/token", - Response: `access_token=__THAT__&refresh_token=__SOMETHING__`, - }, - }.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) { - ctx = useInsecureOAuthHttpClientForTests(ctx) - expectedKey := c.Config.Host + "/oidc/accounts/xyz" - - browserOpened := make(chan string) - p := &PersistentAuth{ - Host: c.Config.Host, - AccountID: "xyz", - browser: func(redirect string) error { - u, err := url.ParseRequestURI(redirect) - if err != nil { - return err - } - assert.Equal(t, "/oidc/accounts/xyz/v1/authorize", u.Path) - // for now we're ignoring asserting the fields of the redirect - query := u.Query() - browserOpened <- query.Get("state") - return nil - }, - cache: &tokenCacheMock{ - store: func(key string, tok *oauth2.Token) error { - assert.Equal(t, expectedKey, key) - assert.Equal(t, "__SOMETHING__", tok.RefreshToken) - return nil - }, - }, - } - defer p.Close() - - errc := make(chan error) - go func() { - errc <- p.Challenge(ctx) - }() - - state := <-browserOpened - resp, err := http.Get(fmt.Sprintf("http://%s?code=__THIS__&state=%s", appRedirectAddr, state)) - require.NoError(t, err) - defer resp.Body.Close() - assert.Equal(t, 200, resp.StatusCode) - - err = <-errc - assert.NoError(t, err) - }) -} - -func TestChallengeFailed(t *testing.T) { - qa.HTTPFixtures{}.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) { - ctx = useInsecureOAuthHttpClientForTests(ctx) - - browserOpened := make(chan string) - p := &PersistentAuth{ - Host: c.Config.Host, - AccountID: "xyz", - browser: func(redirect string) error { - u, err := url.ParseRequestURI(redirect) - if err != nil { - return err - } - assert.Equal(t, "/oidc/accounts/xyz/v1/authorize", u.Path) - // for now we're ignoring asserting the fields of the redirect - query := u.Query() - browserOpened <- query.Get("state") - return nil - }, - } - defer p.Close() - - errc := make(chan error) - go func() { - errc <- p.Challenge(ctx) - }() - - <-browserOpened - resp, err := http.Get(fmt.Sprintf( - "http://%s?error=access_denied&error_description=Policy%%20evaluation%%20failed%%20for%%20this%%20request", - appRedirectAddr)) - require.NoError(t, err) - defer resp.Body.Close() - assert.Equal(t, 400, resp.StatusCode) - - err = <-errc - assert.EqualError(t, err, "authorize: access_denied: Policy evaluation failed for this request") - }) -} - -func TestPersistentAuthCleanHost(t *testing.T) { - for _, tcases := range []struct { - in string - out string - }{ - {"https://example.com", "https://example.com"}, - {"https://example.com/", "https://example.com"}, - {"https://example.com/path", "https://example.com"}, - {"https://example.com/path/subpath", "https://example.com"}, - {"https://example.com/path?query=1", "https://example.com"}, - {"https://example.com/path?query=1&other=2", "https://example.com"}, - {"https://example.com/path#fragment", "https://example.com"}, - {"https://example.com/path?query=1#fragment", "https://example.com"}, - {"https://example.com/path?query=1&other=2#fragment", "https://example.com"}, - {"https://example.com/path/subpath?query=1", "https://example.com"}, - {"https://example.com/path/subpath?query=1&other=2", "https://example.com"}, - {"https://example.com/path/subpath#fragment", "https://example.com"}, - {"https://example.com/path/subpath?query=1#fragment", "https://example.com"}, - {"https://example.com/path/subpath?query=1&other=2#fragment", "https://example.com"}, - {"https://example.com/path?query=1%20value&other=2%20value", "https://example.com"}, - {"http://example.com/path/subpath?query=1%20value&other=2%20value", "http://example.com"}, - - // URLs without scheme should be left as is - {"abc", "abc"}, - {"abc.com/def", "abc.com/def"}, - } { - p := &PersistentAuth{ - Host: tcases.in, - } - p.cleanHost() - assert.Equal(t, tcases.out, p.Host) - } -} diff --git a/libs/auth/page.tmpl b/libs/auth/page.tmpl deleted file mode 100644 index 4642bb3d47..0000000000 --- a/libs/auth/page.tmpl +++ /dev/null @@ -1,102 +0,0 @@ - - - - - {{if .Error }}{{ .Error | title }}{{ else }}Success{{end}} - - - - - - - -
-
- - -
{{ .Error | title }}
-
{{ .ErrorDescription }}
- -
Authenticated
-
Go to {{.Host}}
- -
- You can close this tab. Or go to documentation -
-
-
- -