From e07b32b91b1ea1fb7fb1f84ee749ce7f6cc4df8f Mon Sep 17 00:00:00 2001 From: Renaud Hartert Date: Thu, 6 Feb 2025 14:48:48 +0100 Subject: [PATCH] TokenSource --- config/experimental/auth/auth.go | 32 ++++++++++------- config/experimental/auth/auth_test.go | 5 +-- config/experimental/auth/authconv/authconv.go | 34 +++++++++++++++++++ .../auth/authconv/authconv_test.go | 31 +++++++++++++++++ config/oauth_visitors.go | 19 +++++++---- 5 files changed, 100 insertions(+), 21 deletions(-) create mode 100644 config/experimental/auth/authconv/authconv.go create mode 100644 config/experimental/auth/authconv/authconv_test.go diff --git a/config/experimental/auth/auth.go b/config/experimental/auth/auth.go index 2f560498b..abea1aefd 100644 --- a/config/experimental/auth/auth.go +++ b/config/experimental/auth/auth.go @@ -5,6 +5,7 @@ package auth import ( + "context" "sync" "time" @@ -21,6 +22,13 @@ const ( defaultDisableAsyncRefresh = true ) +// A TokenSource is anything that can return a token. +type TokenSource interface { + // Token returns a token or an error. Token must be safe for concurrent use + // by multiple goroutines. The returned Token must not be modified. + Token(context.Context) (*oauth2.Token, error) +} + type Option func(*cachedTokenSource) // WithCachedToken sets the initial token to be used by a cached token source. @@ -50,7 +58,7 @@ func WithAsyncRefresh(b bool) Option { // // If the TokenSource is already a cached token source (obtained by calling this // function), it is returned as is. -func NewCachedTokenSource(ts oauth2.TokenSource, opts ...Option) oauth2.TokenSource { +func NewCachedTokenSource(ts TokenSource, opts ...Option) TokenSource { // This is meant as a niche optimization to avoid double caching of the // token source in situations where the user calls needs caching guarantees // but does not know if the token source is already cached. @@ -75,7 +83,7 @@ func NewCachedTokenSource(ts oauth2.TokenSource, opts ...Option) oauth2.TokenSou type cachedTokenSource struct { // The token source to obtain tokens from. - tokenSource oauth2.TokenSource + tokenSource TokenSource // If true, only refresh the token with a blocking call when it is expired. disableAsync bool @@ -102,11 +110,11 @@ type cachedTokenSource struct { // Token returns a token from the cache or fetches a new one if the current // token is expired. -func (cts *cachedTokenSource) Token() (*oauth2.Token, error) { +func (cts *cachedTokenSource) Token(ctx context.Context) (*oauth2.Token, error) { if cts.disableAsync { - return cts.blockingToken() + return cts.blockingToken(ctx) } - return cts.asyncToken() + return cts.asyncToken(ctx) } // tokenState represents the state of the token. Each token can be in one of @@ -145,7 +153,7 @@ func (c *cachedTokenSource) tokenState() tokenState { } } -func (cts *cachedTokenSource) asyncToken() (*oauth2.Token, error) { +func (cts *cachedTokenSource) asyncToken(ctx context.Context) (*oauth2.Token, error) { cts.mu.Lock() ts := cts.tokenState() t := cts.cachedToken @@ -155,14 +163,14 @@ func (cts *cachedTokenSource) asyncToken() (*oauth2.Token, error) { case fresh: return t, nil case stale: - cts.triggerAsyncRefresh() + cts.triggerAsyncRefresh(ctx) return t, nil default: // expired - return cts.blockingToken() + return cts.blockingToken(ctx) } } -func (cts *cachedTokenSource) blockingToken() (*oauth2.Token, error) { +func (cts *cachedTokenSource) blockingToken(ctx context.Context) (*oauth2.Token, error) { cts.mu.Lock() // The lock is kept for the entire operation to ensure that only one @@ -182,7 +190,7 @@ func (cts *cachedTokenSource) blockingToken() (*oauth2.Token, error) { return cts.cachedToken, nil } - t, err := cts.tokenSource.Token() + t, err := cts.tokenSource.Token(ctx) if err != nil { return nil, err } @@ -190,14 +198,14 @@ func (cts *cachedTokenSource) blockingToken() (*oauth2.Token, error) { return t, nil } -func (cts *cachedTokenSource) triggerAsyncRefresh() { +func (cts *cachedTokenSource) triggerAsyncRefresh(ctx context.Context) { cts.mu.Lock() defer cts.mu.Unlock() if !cts.isRefreshing && cts.refreshErr == nil { cts.isRefreshing = true go func() { - t, err := cts.tokenSource.Token() + t, err := cts.tokenSource.Token(ctx) cts.mu.Lock() defer cts.mu.Unlock() diff --git a/config/experimental/auth/auth_test.go b/config/experimental/auth/auth_test.go index 035ebe42d..24a0d13bf 100644 --- a/config/experimental/auth/auth_test.go +++ b/config/experimental/auth/auth_test.go @@ -1,6 +1,7 @@ package auth import ( + "context" "fmt" "reflect" "sync" @@ -13,7 +14,7 @@ import ( type mockTokenSource func() (*oauth2.Token, error) -func (m mockTokenSource) Token() (*oauth2.Token, error) { +func (m mockTokenSource) Token(_ context.Context) (*oauth2.Token, error) { return m() } @@ -258,7 +259,7 @@ func TestCachedTokenSource_Token(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - cts.Token() + cts.Token(context.Background()) }() } diff --git a/config/experimental/auth/authconv/authconv.go b/config/experimental/auth/authconv/authconv.go new file mode 100644 index 000000000..0f8a1b5d3 --- /dev/null +++ b/config/experimental/auth/authconv/authconv.go @@ -0,0 +1,34 @@ +package authconv + +import ( + "context" + + "github.com/databricks/databricks-sdk-go/config/experimental/auth" + "golang.org/x/oauth2" +) + +// AuthTokenSource converts an oauth2.TokenSource to an auth.TokenSource. +func AuthTokenSource(ts oauth2.TokenSource) auth.TokenSource { + return &authTokenSource{ts: ts} +} + +type authTokenSource struct { + ts oauth2.TokenSource +} + +func (t *authTokenSource) Token(_ context.Context) (*oauth2.Token, error) { + return t.ts.Token() +} + +// OAuth2TokenSource converts an auth.TokenSource to an oauth2.TokenSource. +func OAuth2TokenSource(ts auth.TokenSource) oauth2.TokenSource { + return &oauth2TokenSource{ts: ts} +} + +type oauth2TokenSource struct { + ts auth.TokenSource +} + +func (t *oauth2TokenSource) Token() (*oauth2.Token, error) { + return t.ts.Token(context.Background()) +} diff --git a/config/experimental/auth/authconv/authconv_test.go b/config/experimental/auth/authconv/authconv_test.go new file mode 100644 index 000000000..5ef223b38 --- /dev/null +++ b/config/experimental/auth/authconv/authconv_test.go @@ -0,0 +1,31 @@ +package authconv + +import ( + "fmt" + "testing" + + "golang.org/x/oauth2" +) + +type mockOauth2TokenSource func() (*oauth2.Token, error) + +func (t mockOauth2TokenSource) Token() (*oauth2.Token, error) { + return t() +} + +func TestIndepotency(t *testing.T) { + wantErr := fmt.Errorf("test error") + wantToken := &oauth2.Token{AccessToken: "test token"} + ts := mockOauth2TokenSource(func() (*oauth2.Token, error) { + return wantToken, wantErr + }) + + gotToken, gotErr := OAuth2TokenSource(AuthTokenSource(ts)).Token() + + if gotErr != wantErr { + t.Errorf("Token() = %v, want %v", gotErr, wantErr) + } + if gotToken != wantToken { + t.Errorf("Token() = %v, want %v", gotToken, wantToken) + } +} diff --git a/config/oauth_visitors.go b/config/oauth_visitors.go index e9d3277c2..fc7a3d153 100644 --- a/config/oauth_visitors.go +++ b/config/oauth_visitors.go @@ -1,11 +1,13 @@ package config import ( + "context" "fmt" "net/http" "time" "github.com/databricks/databricks-sdk-go/config/experimental/auth" + "github.com/databricks/databricks-sdk-go/config/experimental/auth/authconv" "golang.org/x/oauth2" ) @@ -13,16 +15,16 @@ import ( // to the token from the auth token sourcevand the provided secondary header to // the token from the secondary token source. func serviceToServiceVisitor(primary, secondary oauth2.TokenSource, secondaryHeader string) func(r *http.Request) error { - refreshableAuth := auth.NewCachedTokenSource(primary) - refreshableSecondary := auth.NewCachedTokenSource(secondary) + refreshableAuth := auth.NewCachedTokenSource(authconv.AuthTokenSource(primary)) + refreshableSecondary := auth.NewCachedTokenSource(authconv.AuthTokenSource(secondary)) return func(r *http.Request) error { - inner, err := refreshableAuth.Token() + inner, err := refreshableAuth.Token(context.Background()) if err != nil { return fmt.Errorf("inner token: %w", err) } inner.SetAuthHeader(r) - cloud, err := refreshableSecondary.Token() + cloud, err := refreshableSecondary.Token(context.Background()) if err != nil { return fmt.Errorf("cloud token: %w", err) } @@ -33,9 +35,9 @@ func serviceToServiceVisitor(primary, secondary oauth2.TokenSource, secondaryHea // The same as serviceToServiceVisitor, but without a secondary token source. func refreshableVisitor(inner oauth2.TokenSource) func(r *http.Request) error { - cts := auth.NewCachedTokenSource(inner) + cts := auth.NewCachedTokenSource(authconv.AuthTokenSource(inner)) return func(r *http.Request) error { - inner, err := cts.Token() + inner, err := cts.Token(context.Background()) if err != nil { return fmt.Errorf("inner token: %w", err) } @@ -63,7 +65,10 @@ func azureReuseTokenSource(t *oauth2.Token, ts oauth2.TokenSource) oauth2.TokenS return t }) - return auth.NewCachedTokenSource(early, auth.WithCachedToken(t)) + return authconv.OAuth2TokenSource(auth.NewCachedTokenSource( + authconv.AuthTokenSource(early), + auth.WithCachedToken(t), + )) } func wrap(ts oauth2.TokenSource, fn func(*oauth2.Token) *oauth2.Token) oauth2.TokenSource {