From ac57af32deed57fd74edc124c2b09d301efc1846 Mon Sep 17 00:00:00 2001 From: Renaud Hartert Date: Wed, 22 Jan 2025 21:56:57 +0100 Subject: [PATCH 01/10] Add async token cache --- config/auth/auth.go | 177 ++++++++++++++++++++++++++ config/auth/auth_test.go | 261 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 438 insertions(+) create mode 100644 config/auth/auth.go create mode 100644 config/auth/auth_test.go diff --git a/config/auth/auth.go b/config/auth/auth.go new file mode 100644 index 000000000..9e11b8d07 --- /dev/null +++ b/config/auth/auth.go @@ -0,0 +1,177 @@ +package auth + +import ( + "sync" + "time" + + "golang.org/x/oauth2" +) + +const ( + // Default duration for the stale period. The number as been set arbitrarily + // and might be changed in the future. + defaultStaleDuration = 3 * time.Minute + + // Disable the asynchronous token refresh by default. This is meant to + // change in the future once the feature is stable. + defaultDisableAsyncRefresh = true +) + +type CachedTokenSourceOptions struct { + // DisableAsyncRefresh disables the asynchronous token refresh. + DisableAsyncRefresh bool + + // StaleDuration is the duration before the token expires. If unset, the + // default duration of 3 minutes is used. + StaleDuration time.Duration +} + +func (ctso *CachedTokenSourceOptions) disableAsyncRefresh() bool { + if ctso == nil { + return defaultDisableAsyncRefresh + } + return ctso.DisableAsyncRefresh +} + +func (ctso *CachedTokenSourceOptions) staleDuration() time.Duration { + if ctso == nil || ctso.StaleDuration == 0 { + return defaultStaleDuration + } + return ctso.StaleDuration +} + +// NewCachedTokenProvider returns a new token source that caches the token. +// +// The TokenSource is expected to take care of potential retries on its own. +// +// If the TokenSource is already a cached token source, it is returned as is. +func NewCachedTokenSource(ts oauth2.TokenSource, opts *CachedTokenSourceOptions) oauth2.TokenSource { + if cts, ok := ts.(*cachedTokenSource); ok { + return cts + } + + return &cachedTokenSource{ + tokenSource: ts, + staleDuration: opts.staleDuration(), + disableAsync: opts.disableAsyncRefresh(), + timeNow: time.Now, + } +} + +type cachedTokenSource struct { + tokenSource oauth2.TokenSource + staleDuration time.Duration + disableAsync bool + + mu sync.Mutex + cachedToken *oauth2.Token + isRefreshing bool + refreshErr error + + // timeNow is a function that returns the current time. It is used to + // determine the current time in tests. + timeNow func() time.Time +} + +// Token returns a token from the cache or fetches a new one if the current +// token is expired. +func (cts *cachedTokenSource) Token() (*oauth2.Token, error) { + if cts.disableAsync { + return cts.blockingToken() + } + return cts.asyncToken() +} + +// tokenState represents the state of the token. +type tokenState int + +const ( + fresh tokenState = iota // The token is valid. + stale // The token is valid but will expire soon. + expired // The token has expired and cannot be used. +) + +// tokenState returns the state of the token. The function is not thread-safe +// and should be called with the lock held. +func (c *cachedTokenSource) tokenState() tokenState { + if c.cachedToken == nil { + return expired + } + switch lifeSpan := c.cachedToken.Expiry.Sub(c.timeNow()); { + case lifeSpan <= 0: + return expired + case lifeSpan <= c.staleDuration: + return stale + default: + return fresh + } +} + +func (cts *cachedTokenSource) asyncToken() (*oauth2.Token, error) { + cts.mu.Lock() + ts := cts.tokenState() + cts.mu.Unlock() + + switch ts { + case fresh: + cts.mu.Lock() + defer cts.mu.Unlock() + return cts.cachedToken, nil + case stale: + cts.triggerAsyncRefresh() + cts.mu.Lock() + defer cts.mu.Unlock() + return cts.cachedToken, nil + default: // expired + return cts.blockingToken() + } +} + +// blockingToken returns a token from the cache or fetches a new one if the +// current token is expired. The function guarantees that only one refresh call +// we be made if several goroutines are calling it concurrently. +func (cts *cachedTokenSource) blockingToken() (*oauth2.Token, error) { + cts.mu.Lock() + + // The lock is kept for the entire operation to ensure that only one + // blockingToken operation is running at a time. + defer cts.mu.Unlock() + + cts.isRefreshing = false + if ts := cts.tokenState(); ts != expired { // fresh or stale + return cts.cachedToken, nil + } + + t, err := cts.tokenSource.Token() + if err != nil { + return nil, err + } + cts.cachedToken = t + return t, nil +} + +// triggerAsyncRefresh +func (cts *cachedTokenSource) triggerAsyncRefresh() { + cts.mu.Lock() + defer cts.mu.Unlock() + if !cts.isRefreshing && cts.refreshErr == nil { + go cts.asyncRefresh() + } +} + +func (cts *cachedTokenSource) asyncRefresh() { + cts.mu.Lock() + cts.isRefreshing = true + cts.mu.Unlock() + + t, err := cts.tokenSource.Token() + + cts.mu.Lock() + defer cts.mu.Unlock() + cts.isRefreshing = false + if err != nil { + cts.refreshErr = err + return + } + cts.cachedToken = t +} diff --git a/config/auth/auth_test.go b/config/auth/auth_test.go new file mode 100644 index 000000000..a9957f6d3 --- /dev/null +++ b/config/auth/auth_test.go @@ -0,0 +1,261 @@ +package auth + +import ( + "fmt" + "reflect" + "sync" + "sync/atomic" + "testing" + "time" + + "golang.org/x/oauth2" +) + +type mockTokenSource func() (*oauth2.Token, error) + +func (m mockTokenSource) Token() (*oauth2.Token, error) { + return m() +} + +func TestNewCachedTokenSource_noCaching(t *testing.T) { + want := &cachedTokenSource{} + got := NewCachedTokenSource(want, nil) + if got != want { + t.Errorf("NewCachedTokenSource() = %v, want %v", got, want) + } +} + +func TestNewCachedTokenSource(t *testing.T) { + ts := mockTokenSource(func() (*oauth2.Token, error) { + return nil, nil + }) + + testCases := []struct { + options *CachedTokenSourceOptions + want *cachedTokenSource + }{ + { + options: nil, + want: &cachedTokenSource{ + tokenSource: ts, + staleDuration: defaultStaleDuration, + disableAsync: defaultDisableAsyncRefresh, + }, + }, + { + options: &CachedTokenSourceOptions{}, + want: &cachedTokenSource{ + tokenSource: ts, + staleDuration: defaultStaleDuration, + disableAsync: false, + }, + }, + { + options: &CachedTokenSourceOptions{ + DisableAsyncRefresh: true, + }, + want: &cachedTokenSource{ + tokenSource: ts, + staleDuration: defaultStaleDuration, + disableAsync: true, + }, + }, + { + options: &CachedTokenSourceOptions{ + StaleDuration: 5 * time.Minute, + }, + want: &cachedTokenSource{ + tokenSource: ts, + staleDuration: 5 * time.Minute, + disableAsync: false, + }, + }, + } + + for _, tc := range testCases { + got, ok := NewCachedTokenSource(ts, tc.options).(*cachedTokenSource) + if !ok { + t.Fatalf("NewCachedTokenSource() = %T, want *cachedTokenSource", got) + } + + if got.staleDuration != tc.want.staleDuration { + t.Errorf("NewCachedTokenSource() staleDuration = %v, want %v", got.staleDuration, tc.want.staleDuration) + } + if got.disableAsync != tc.want.disableAsync { + t.Errorf("NewCachedTokenSource() disableAsync = %v, want %v", got.disableAsync, tc.want.disableAsync) + } + } +} + +func TestCachedTokenSource_tokenState(t *testing.T) { + now := time.Unix(1337, 0) // mock value for time.Now() + + testCases := []struct { + token *oauth2.Token + staleDuration time.Duration + want tokenState + }{ + { + token: nil, + staleDuration: 10 * time.Minute, + want: expired, + }, + { + token: &oauth2.Token{ + Expiry: now.Add(-1 * time.Second), + }, + staleDuration: 10 * time.Minute, + want: expired, + }, + { + token: &oauth2.Token{ + Expiry: now.Add(1 * time.Hour), + }, + staleDuration: 10 * time.Minute, + want: fresh, + }, + { + token: &oauth2.Token{ + Expiry: now.Add(5 * time.Minute), + }, + staleDuration: 10 * time.Minute, + want: stale, + }, + } + + for _, tc := range testCases { + cts := &cachedTokenSource{ + cachedToken: tc.token, + staleDuration: tc.staleDuration, + disableAsync: false, + timeNow: func() time.Time { return now }, + } + + got := cts.tokenState() + + if got != tc.want { + t.Errorf("tokenState() = %v, want %v", got, tc.want) + } + } +} + +func TestCachedTokenSource_Token(t *testing.T) { + now := time.Unix(1337, 0) // mock value for time.Now() + nTokenCalls := 10 // number of goroutines calling Token() + testCases := []struct { + desc string // description of the test case + cachedToken *oauth2.Token // token cached before calling Token() + disableAsync bool // whether are disabled or not + + returnedToken *oauth2.Token // token returned by the token source + returnedError error // error returned by the token source + + wantCalls int // expected number of calls to the token source + wantToken *oauth2.Token // expected token in the cache + }{ + { + desc: "[Blocking] no cached token", + disableAsync: true, + returnedToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, + wantCalls: 1, + }, + { + desc: "[Blocking] expired cached token", + disableAsync: true, + cachedToken: &oauth2.Token{Expiry: now.Add(-1 * time.Second)}, + returnedToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, + wantCalls: 1, + }, + { + desc: "[Blocking] fresh cached token", + disableAsync: true, + cachedToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, + wantCalls: 0, + }, + { + desc: "[Blocking] stale cached token", + disableAsync: true, + cachedToken: &oauth2.Token{Expiry: now.Add(1 * time.Minute)}, + wantCalls: 0, + }, + { + desc: "[Blocking] refresh error", + disableAsync: true, + returnedError: fmt.Errorf("test error"), + wantCalls: 10, + }, + { + desc: "[Async] no cached token", + returnedToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, + wantCalls: 1, + }, + { + desc: "[Async] no cached token", + returnedToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, + wantCalls: 1, + }, + { + desc: "[Async] expired cached token", + cachedToken: &oauth2.Token{Expiry: now.Add(-1 * time.Second)}, + returnedToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, + wantCalls: 1, + }, + { + desc: "[Async] fresh cached token", + cachedToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, + wantCalls: 0, + }, + { + desc: "[Async] stale cached token", + cachedToken: &oauth2.Token{Expiry: now.Add(1 * time.Minute)}, + returnedToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, + wantCalls: 1, + }, + { + desc: "[Async] refresh error", + cachedToken: &oauth2.Token{Expiry: now.Add(1 * time.Minute)}, + returnedError: fmt.Errorf("test error"), + wantCalls: 1, + }, + { + desc: "[Async] stale cached token, expired token returned", + cachedToken: &oauth2.Token{Expiry: now.Add(1 * time.Minute)}, + returnedToken: &oauth2.Token{Expiry: now.Add(-1 * time.Second)}, + wantCalls: 10, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + gotCalls := int32(0) + cts := &cachedTokenSource{ + disableAsync: tc.disableAsync, + staleDuration: 10 * time.Minute, + cachedToken: tc.cachedToken, + timeNow: func() time.Time { return now }, + tokenSource: mockTokenSource(func() (*oauth2.Token, error) { + atomic.AddInt32(&gotCalls, 1) + return tc.returnedToken, tc.returnedError + }), + } + + wg := sync.WaitGroup{} + for i := 0; i < nTokenCalls; i++ { + wg.Add(1) + go func() { + defer wg.Done() + cts.Token() + }() + } + wg.Wait() + + if int(gotCalls) != tc.wantCalls { + t.Errorf("want %d calls to cts.tokenSource.Token(), got %d", tc.wantCalls, gotCalls) + } + if !reflect.DeepEqual(tc.wantToken, cts.cachedToken) { + t.Errorf("want cached token %v, got %v", tc.wantToken, cts.cachedToken) + } + }) + } + +} From 356b4335b5ceb0915dd7ba6e964f250bfeb4d2c2 Mon Sep 17 00:00:00 2001 From: Renaud Hartert Date: Thu, 23 Jan 2025 08:22:01 +0100 Subject: [PATCH 02/10] Fix issue in trigger async --- config/auth/auth.go | 25 +++++++++++-------------- config/auth/auth_test.go | 11 +++++++++++ 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/config/auth/auth.go b/config/auth/auth.go index 9e11b8d07..9f755f955 100644 --- a/config/auth/auth.go +++ b/config/auth/auth.go @@ -63,14 +63,18 @@ type cachedTokenSource struct { staleDuration time.Duration disableAsync bool - mu sync.Mutex - cachedToken *oauth2.Token + mu sync.Mutex + cachedToken *oauth2.Token + + // Indicates that an async refresh is in progress. This is used to prevent + // multiple async refreshes from being triggered at the same time. isRefreshing bool - refreshErr error - // timeNow is a function that returns the current time. It is used to - // determine the current time in tests. - timeNow func() time.Time + // Error returned by the last async refresh. This is used to prevent + // multiple async refreshes from being triggered. + refreshErr error + + timeNow func() time.Time // for testing } // Token returns a token from the cache or fetches a new one if the current @@ -127,9 +131,6 @@ func (cts *cachedTokenSource) asyncToken() (*oauth2.Token, error) { } } -// blockingToken returns a token from the cache or fetches a new one if the -// current token is expired. The function guarantees that only one refresh call -// we be made if several goroutines are calling it concurrently. func (cts *cachedTokenSource) blockingToken() (*oauth2.Token, error) { cts.mu.Lock() @@ -150,20 +151,16 @@ func (cts *cachedTokenSource) blockingToken() (*oauth2.Token, error) { return t, nil } -// triggerAsyncRefresh func (cts *cachedTokenSource) triggerAsyncRefresh() { cts.mu.Lock() defer cts.mu.Unlock() if !cts.isRefreshing && cts.refreshErr == nil { + cts.isRefreshing = true go cts.asyncRefresh() } } func (cts *cachedTokenSource) asyncRefresh() { - cts.mu.Lock() - cts.isRefreshing = true - cts.mu.Unlock() - t, err := cts.tokenSource.Token() cts.mu.Lock() diff --git a/config/auth/auth_test.go b/config/auth/auth_test.go index a9957f6d3..c4627d487 100644 --- a/config/auth/auth_test.go +++ b/config/auth/auth_test.go @@ -158,6 +158,7 @@ func TestCachedTokenSource_Token(t *testing.T) { disableAsync: true, returnedToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, wantCalls: 1, + wantToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, }, { desc: "[Blocking] expired cached token", @@ -165,18 +166,21 @@ func TestCachedTokenSource_Token(t *testing.T) { cachedToken: &oauth2.Token{Expiry: now.Add(-1 * time.Second)}, returnedToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, wantCalls: 1, + wantToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, }, { desc: "[Blocking] fresh cached token", disableAsync: true, cachedToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, wantCalls: 0, + wantToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, }, { desc: "[Blocking] stale cached token", disableAsync: true, cachedToken: &oauth2.Token{Expiry: now.Add(1 * time.Minute)}, wantCalls: 0, + wantToken: &oauth2.Token{Expiry: now.Add(1 * time.Minute)}, }, { desc: "[Blocking] refresh error", @@ -188,40 +192,47 @@ func TestCachedTokenSource_Token(t *testing.T) { desc: "[Async] no cached token", returnedToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, wantCalls: 1, + wantToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, }, { desc: "[Async] no cached token", returnedToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, wantCalls: 1, + wantToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, }, { desc: "[Async] expired cached token", cachedToken: &oauth2.Token{Expiry: now.Add(-1 * time.Second)}, returnedToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, wantCalls: 1, + wantToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, }, { desc: "[Async] fresh cached token", cachedToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, wantCalls: 0, + wantToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, }, { desc: "[Async] stale cached token", cachedToken: &oauth2.Token{Expiry: now.Add(1 * time.Minute)}, returnedToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, wantCalls: 1, + wantToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, }, { desc: "[Async] refresh error", cachedToken: &oauth2.Token{Expiry: now.Add(1 * time.Minute)}, returnedError: fmt.Errorf("test error"), wantCalls: 1, + wantToken: &oauth2.Token{Expiry: now.Add(1 * time.Minute)}, }, { desc: "[Async] stale cached token, expired token returned", cachedToken: &oauth2.Token{Expiry: now.Add(1 * time.Minute)}, returnedToken: &oauth2.Token{Expiry: now.Add(-1 * time.Second)}, wantCalls: 10, + wantToken: &oauth2.Token{Expiry: now.Add(-1 * time.Second)}, }, } From 799aae7f927e4a6c3d456fd79ca0e2f84aeb5bc9 Mon Sep 17 00:00:00 2001 From: Renaud Hartert Date: Thu, 23 Jan 2025 09:03:37 +0100 Subject: [PATCH 03/10] Use CTS --- config/auth/auth.go | 53 +++++++++++++--------- config/auth/auth_test.go | 94 ++++++++++++++++++---------------------- config/oauth_visitors.go | 48 +++++++++++++++----- 3 files changed, 111 insertions(+), 84 deletions(-) diff --git a/config/auth/auth.go b/config/auth/auth.go index 9f755f955..dc7d892d3 100644 --- a/config/auth/auth.go +++ b/config/auth/auth.go @@ -17,45 +17,58 @@ const ( defaultDisableAsyncRefresh = true ) -type CachedTokenSourceOptions struct { - // DisableAsyncRefresh disables the asynchronous token refresh. - DisableAsyncRefresh bool +type Option func(*cachedTokenSource) - // StaleDuration is the duration before the token expires. If unset, the - // default duration of 3 minutes is used. - StaleDuration time.Duration +// WithCachedToken sets the initial token to be used by a cached token source. +func WithCachedToken(t *oauth2.Token) Option { + return func(cts *cachedTokenSource) { + cts.cachedToken = t + } } -func (ctso *CachedTokenSourceOptions) disableAsyncRefresh() bool { - if ctso == nil { - return defaultDisableAsyncRefresh +// WithStaleDuration sets the duration for which a token is considered stale. +// Stale tokens are still valid but will trigger an asynchronous refresh if +// async refresh is enabled. The default value is 3 minutes. +func WithStaleDuration(d time.Duration) Option { + return func(cts *cachedTokenSource) { + cts.staleDuration = d } - return ctso.DisableAsyncRefresh } -func (ctso *CachedTokenSourceOptions) staleDuration() time.Duration { - if ctso == nil || ctso.StaleDuration == 0 { - return defaultStaleDuration +// WithAsyncRefresh enables or disables the asynchronous token refresh. +func WithAsyncRefresh(b bool) Option { + return func(cts *cachedTokenSource) { + cts.disableAsync = !b } - return ctso.StaleDuration } -// NewCachedTokenProvider returns a new token source that caches the token. +// NewCachedTokenProvider returns a new token source that caches the token. The +// token is refreshed when it is expired or about to expire. The token is +// refreshed asynchronously if the async refresh is enabled. // -// The TokenSource is expected to take care of potential retries on its own. +// The token cache does not take care of retries in case the token source +// returns and error; it is the responsibility of the provided token source to +// handle retries appropriately. // // If the TokenSource is already a cached token source, it is returned as is. -func NewCachedTokenSource(ts oauth2.TokenSource, opts *CachedTokenSourceOptions) oauth2.TokenSource { +func NewCachedTokenSource(ts oauth2.TokenSource, opts ...Option) oauth2.TokenSource { if cts, ok := ts.(*cachedTokenSource); ok { return cts } - return &cachedTokenSource{ + cts := &cachedTokenSource{ tokenSource: ts, - staleDuration: opts.staleDuration(), - disableAsync: opts.disableAsyncRefresh(), + staleDuration: defaultStaleDuration, + disableAsync: defaultDisableAsyncRefresh, + cachedToken: nil, timeNow: time.Now, } + + for _, opt := range opts { + opt(cts) + } + + return cts } type cachedTokenSource struct { diff --git a/config/auth/auth_test.go b/config/auth/auth_test.go index c4627d487..a7d5d71d2 100644 --- a/config/auth/auth_test.go +++ b/config/auth/auth_test.go @@ -25,65 +25,55 @@ func TestNewCachedTokenSource_noCaching(t *testing.T) { } } -func TestNewCachedTokenSource(t *testing.T) { +func TestNewCachedTokenSource_default(t *testing.T) { ts := mockTokenSource(func() (*oauth2.Token, error) { return nil, nil }) - testCases := []struct { - options *CachedTokenSourceOptions - want *cachedTokenSource - }{ - { - options: nil, - want: &cachedTokenSource{ - tokenSource: ts, - staleDuration: defaultStaleDuration, - disableAsync: defaultDisableAsyncRefresh, - }, - }, - { - options: &CachedTokenSourceOptions{}, - want: &cachedTokenSource{ - tokenSource: ts, - staleDuration: defaultStaleDuration, - disableAsync: false, - }, - }, - { - options: &CachedTokenSourceOptions{ - DisableAsyncRefresh: true, - }, - want: &cachedTokenSource{ - tokenSource: ts, - staleDuration: defaultStaleDuration, - disableAsync: true, - }, - }, - { - options: &CachedTokenSourceOptions{ - StaleDuration: 5 * time.Minute, - }, - want: &cachedTokenSource{ - tokenSource: ts, - staleDuration: 5 * time.Minute, - disableAsync: false, - }, - }, + got, ok := NewCachedTokenSource(ts).(*cachedTokenSource) + if !ok { + t.Fatalf("NewCachedTokenSource() = %T, want *cachedTokenSource", got) } - for _, tc := range testCases { - got, ok := NewCachedTokenSource(ts, tc.options).(*cachedTokenSource) - if !ok { - t.Fatalf("NewCachedTokenSource() = %T, want *cachedTokenSource", got) - } + if got.staleDuration != defaultStaleDuration { + t.Errorf("NewCachedTokenSource() staleDuration = %v, want %v", got.staleDuration, defaultStaleDuration) + } + if got.disableAsync != defaultDisableAsyncRefresh { + t.Errorf("NewCachedTokenSource() disableAsync = %v, want %v", got.disableAsync, defaultDisableAsyncRefresh) + } + if got.cachedToken != nil { + t.Errorf("NewCachedTokenSource() cachedToken = %v, want nil", got.cachedToken) + } +} - if got.staleDuration != tc.want.staleDuration { - t.Errorf("NewCachedTokenSource() staleDuration = %v, want %v", got.staleDuration, tc.want.staleDuration) - } - if got.disableAsync != tc.want.disableAsync { - t.Errorf("NewCachedTokenSource() disableAsync = %v, want %v", got.disableAsync, tc.want.disableAsync) - } +func TestNewCachedTokenSource_options(t *testing.T) { + ts := mockTokenSource(func() (*oauth2.Token, error) { + return nil, nil + }) + + wantStaleDuration := 10 * time.Minute + wantDisableAsync := false + wantCachedToken := &oauth2.Token{Expiry: time.Unix(42, 0)} + + opts := []Option{ + WithStaleDuration(wantStaleDuration), + WithAsyncRefresh(!wantDisableAsync), + WithCachedToken(wantCachedToken), + } + + got, ok := NewCachedTokenSource(ts, opts...).(*cachedTokenSource) + if !ok { + t.Fatalf("NewCachedTokenSource() = %T, want *cachedTokenSource", got) + } + + if got.staleDuration != wantStaleDuration { + t.Errorf("NewCachedTokenSource(): staleDuration = %v, want %v", got.staleDuration, wantStaleDuration) + } + if got.disableAsync != wantDisableAsync { + t.Errorf("NewCachedTokenSource(): disableAsync = %v, want %v", got.disableAsync, wantDisableAsync) + } + if got.cachedToken != wantCachedToken { + t.Errorf("NewCachedTokenSource(): cachedToken = %v, want %v", got.cachedToken, wantCachedToken) } } diff --git a/config/oauth_visitors.go b/config/oauth_visitors.go index 2b172bf1e..6ae7620bb 100644 --- a/config/oauth_visitors.go +++ b/config/oauth_visitors.go @@ -5,14 +5,16 @@ import ( "net/http" "time" + "github.com/databricks/databricks-sdk-go/config/auth" "golang.org/x/oauth2" ) -// serviceToServiceVisitor returns a visitor that sets the Authorization header to the token from the auth token source -// and the provided secondary header to the token from the secondary token source. -func serviceToServiceVisitor(auth, secondary oauth2.TokenSource, secondaryHeader string) func(r *http.Request) error { - refreshableAuth := oauth2.ReuseTokenSource(nil, auth) - refreshableSecondary := oauth2.ReuseTokenSource(nil, secondary) +// serviceToServiceVisitor returns a visitor that sets the Authorization header +// to the token from the auth token sourcevand the provided secondary header to +// the token from the secondary token source. +func serviceToServiceVisitor(primary, secondary oauth2.TokenSource, secondaryHeader string) func(r *http.Request) error { + refreshableAuth := auth.NewCachedTokenSource(primary) + refreshableSecondary := auth.NewCachedTokenSource(secondary) return func(r *http.Request) error { inner, err := refreshableAuth.Token() if err != nil { @@ -31,9 +33,9 @@ func serviceToServiceVisitor(auth, secondary oauth2.TokenSource, secondaryHeader // The same as serviceToServiceVisitor, but without a secondary token source. func refreshableVisitor(inner oauth2.TokenSource) func(r *http.Request) error { - refreshableAuth := oauth2.ReuseTokenSource(nil, inner) + cts := auth.NewCachedTokenSource(inner) return func(r *http.Request) error { - inner, err := refreshableAuth.Token() + inner, err := cts.Token() if err != nil { return fmt.Errorf("inner token: %w", err) } @@ -51,10 +53,32 @@ func azureVisitor(cfg *Config, inner func(*http.Request) error) func(*http.Reque } } -// azureReuseTokenSource calls into oauth2.ReuseTokenSourceWithExpiry with a 40 second expiry window. -// By default, the oauth2 library refreshes a token 10 seconds before it expires. -// Azure Databricks rejects tokens that expire in 30 seconds or less. -// We combine these and refresh the token 40 seconds before it expires. +// azureReuseTokenSource returns a cached token source that refreshes token 40 +// seconds before they expire. The reason for this is that Azure Databricks +// rejects tokens that expire in 30 seconds or less and we want to give a 10 +// second buffer. func azureReuseTokenSource(t *oauth2.Token, ts oauth2.TokenSource) oauth2.TokenSource { - return oauth2.ReuseTokenSourceWithExpiry(t, ts, 40*time.Second) + early := wrap(ts, func(t *oauth2.Token) *oauth2.Token { + t.Expiry = t.Expiry.Add(-40 * time.Second) + return t + }) + + return auth.NewCachedTokenSource(early, auth.WithCachedToken(t)) +} + +func wrap(ts oauth2.TokenSource, fn func(*oauth2.Token) *oauth2.Token) oauth2.TokenSource { + return &tokenSourceWrapper{fn: fn, inner: ts} +} + +type tokenSourceWrapper struct { + fn func(*oauth2.Token) *oauth2.Token + inner oauth2.TokenSource +} + +func (w *tokenSourceWrapper) Token() (*oauth2.Token, error) { + t, err := w.inner.Token() + if err != nil { + return nil, err + } + return w.fn(t), nil } From dc40cbb3b1c7fb810cb6aa7473654a8ae3432a44 Mon Sep 17 00:00:00 2001 From: Renaud Hartert Date: Thu, 23 Jan 2025 10:04:47 +0100 Subject: [PATCH 04/10] Add more comments --- config/auth/auth.go | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/config/auth/auth.go b/config/auth/auth.go index dc7d892d3..c8203df82 100644 --- a/config/auth/auth.go +++ b/config/auth/auth.go @@ -42,15 +42,19 @@ func WithAsyncRefresh(b bool) Option { } } -// NewCachedTokenProvider returns a new token source that caches the token. The -// token is refreshed when it is expired or about to expire. The token is -// refreshed asynchronously if the async refresh is enabled. +// NewCachedTokenProvider wraps a [oauth2.TokenSource] to cache the tokens +// it returns. By default, the cache will refresh tokens asynchronously a few +// minutes before they expire. +// +// The token cache is safe for concurrent use by multiple goroutines and will +// guarantee that only one token refresh is triggered at a time. // // The token cache does not take care of retries in case the token source // returns and error; it is the responsibility of the provided token source to // handle retries appropriately. // -// If the TokenSource is already a cached token source, it is returned as is. +// If the TokenSource is already a cached token source (obtained by calling this +// function), it is returned as is. func NewCachedTokenSource(ts oauth2.TokenSource, opts ...Option) oauth2.TokenSource { if cts, ok := ts.(*cachedTokenSource); ok { return cts @@ -83,8 +87,11 @@ type cachedTokenSource struct { // multiple async refreshes from being triggered at the same time. isRefreshing bool - // Error returned by the last async refresh. This is used to prevent - // multiple async refreshes from being triggered. + // Error returned by the last refresh. Async refreshes are disabled if this + // value is not nil so that the cache does not continue sending request to + // a potentially failing server. The next blocking call will re-enable async + // refreshes by setting this value to nil if it succeeds, or return the + // error if it fails. refreshErr error timeNow func() time.Time // for testing @@ -151,7 +158,11 @@ func (cts *cachedTokenSource) blockingToken() (*oauth2.Token, error) { // blockingToken operation is running at a time. defer cts.mu.Unlock() + // This is important to recover from potential previous failed attempts + // to refresh the token asynchronously, see declaration of refreshErr for + // more information. cts.isRefreshing = false + if ts := cts.tokenState(); ts != expired { // fresh or stale return cts.cachedToken, nil } From 1f30f68afae29d4b99346723c507e73ff77209f7 Mon Sep 17 00:00:00 2001 From: Renaud Hartert Date: Sun, 26 Jan 2025 17:09:16 +0100 Subject: [PATCH 05/10] Clarify experimental status --- config/auth/auth.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/config/auth/auth.go b/config/auth/auth.go index c8203df82..496bd22a7 100644 --- a/config/auth/auth.go +++ b/config/auth/auth.go @@ -1,3 +1,7 @@ +// Package auth is an internal package that provides authentication utilities. +// +// IMPORTANT: This package is not meant to be used directly by consumers of the +// SDK and is subject to change without notice. package auth import ( From 03cb267f5598e17a13fb7ac8c9cb2baa46e0db9d Mon Sep 17 00:00:00 2001 From: Renaud Hartert Date: Sun, 26 Jan 2025 18:52:33 +0100 Subject: [PATCH 06/10] Clarify experimental status --- config/{ => experimental}/auth/auth.go | 0 config/{ => experimental}/auth/auth_test.go | 0 config/oauth_visitors.go | 2 +- 3 files changed, 1 insertion(+), 1 deletion(-) rename config/{ => experimental}/auth/auth.go (100%) rename config/{ => experimental}/auth/auth_test.go (100%) diff --git a/config/auth/auth.go b/config/experimental/auth/auth.go similarity index 100% rename from config/auth/auth.go rename to config/experimental/auth/auth.go diff --git a/config/auth/auth_test.go b/config/experimental/auth/auth_test.go similarity index 100% rename from config/auth/auth_test.go rename to config/experimental/auth/auth_test.go diff --git a/config/oauth_visitors.go b/config/oauth_visitors.go index 6ae7620bb..e9d3277c2 100644 --- a/config/oauth_visitors.go +++ b/config/oauth_visitors.go @@ -5,7 +5,7 @@ import ( "net/http" "time" - "github.com/databricks/databricks-sdk-go/config/auth" + "github.com/databricks/databricks-sdk-go/config/experimental/auth" "golang.org/x/oauth2" ) From 89b972d967464f1b423b643747e96d18091d4667 Mon Sep 17 00:00:00 2001 From: Renaud Hartert Date: Mon, 27 Jan 2025 13:06:41 +0100 Subject: [PATCH 07/10] Add comments --- config/experimental/auth/auth.go | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/config/experimental/auth/auth.go b/config/experimental/auth/auth.go index 496bd22a7..a48dd465b 100644 --- a/config/experimental/auth/auth.go +++ b/config/experimental/auth/auth.go @@ -80,9 +80,14 @@ func NewCachedTokenSource(ts oauth2.TokenSource, opts ...Option) oauth2.TokenSou } type cachedTokenSource struct { - tokenSource oauth2.TokenSource + // The token source to obtain tokens from. + tokenSource oauth2.TokenSource + + // If true, only refresh the token with a blocking call when it is expired. + disableAsync bool + + // Duration during which a token is considered stale, see tokenState. staleDuration time.Duration - disableAsync bool mu sync.Mutex cachedToken *oauth2.Token @@ -110,7 +115,18 @@ func (cts *cachedTokenSource) Token() (*oauth2.Token, error) { return cts.asyncToken() } -// tokenState represents the state of the token. +// tokenState represents the state of the token. Each token can be in one of +// the following three states: +// - fresh: The token is valid. +// - stale: The token is valid but will expire soon. +// - expired: The token has expired and cannot be used. +// +// Token state through time: +// +// issue time expiry time +// v v +// | fresh | stale | expired -> time +// | valid | type tokenState int const ( From 7280fd535c939ad74e342b00ffb83301f92da90f Mon Sep 17 00:00:00 2001 From: Renaud Hartert Date: Mon, 3 Feb 2025 09:52:01 +0100 Subject: [PATCH 08/10] Address review comments --- config/experimental/auth/auth.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/config/experimental/auth/auth.go b/config/experimental/auth/auth.go index a48dd465b..108d65d6d 100644 --- a/config/experimental/auth/auth.go +++ b/config/experimental/auth/auth.go @@ -154,18 +154,15 @@ func (c *cachedTokenSource) tokenState() tokenState { func (cts *cachedTokenSource) asyncToken() (*oauth2.Token, error) { cts.mu.Lock() ts := cts.tokenState() + t := cts.cachedToken cts.mu.Unlock() switch ts { case fresh: - cts.mu.Lock() - defer cts.mu.Unlock() - return cts.cachedToken, nil + return t, nil case stale: cts.triggerAsyncRefresh() - cts.mu.Lock() - defer cts.mu.Unlock() - return cts.cachedToken, nil + return t, nil default: // expired return cts.blockingToken() } @@ -183,6 +180,9 @@ func (cts *cachedTokenSource) blockingToken() (*oauth2.Token, error) { // more information. cts.isRefreshing = false + // It's possible that the token got refreshed (either by a blockingToken or + // an asyncRefresh call) while this particular call was waiting to acquire + // the mutex. This check avoids refreshing the token again in such cases. if ts := cts.tokenState(); ts != expired { // fresh or stale return cts.cachedToken, nil } From e287bdf7b8c015f10560b99602df2d78bb6555ac Mon Sep 17 00:00:00 2001 From: Renaud Hartert Date: Mon, 3 Feb 2025 09:55:25 +0100 Subject: [PATCH 09/10] Address review comments --- config/experimental/auth/auth.go | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/config/experimental/auth/auth.go b/config/experimental/auth/auth.go index 108d65d6d..21000083d 100644 --- a/config/experimental/auth/auth.go +++ b/config/experimental/auth/auth.go @@ -179,6 +179,7 @@ func (cts *cachedTokenSource) blockingToken() (*oauth2.Token, error) { // to refresh the token asynchronously, see declaration of refreshErr for // more information. cts.isRefreshing = false + cts.refreshErr = nil // It's possible that the token got refreshed (either by a blockingToken or // an asyncRefresh call) while this particular call was waiting to acquire @@ -200,19 +201,18 @@ func (cts *cachedTokenSource) triggerAsyncRefresh() { defer cts.mu.Unlock() if !cts.isRefreshing && cts.refreshErr == nil { cts.isRefreshing = true - go cts.asyncRefresh() - } -} -func (cts *cachedTokenSource) asyncRefresh() { - t, err := cts.tokenSource.Token() - - cts.mu.Lock() - defer cts.mu.Unlock() - cts.isRefreshing = false - if err != nil { - cts.refreshErr = err - return + go func() { + t, err := cts.tokenSource.Token() + + cts.mu.Lock() + defer cts.mu.Unlock() + cts.isRefreshing = false + if err != nil { + cts.refreshErr = err + return + } + cts.cachedToken = t + }() } - cts.cachedToken = t } From f1159eaab87010aa4ab2d534f7ac47f034970691 Mon Sep 17 00:00:00 2001 From: Renaud Hartert Date: Mon, 3 Feb 2025 12:54:02 +0100 Subject: [PATCH 10/10] Address review comments --- config/experimental/auth/auth.go | 12 +++-------- config/experimental/auth/auth_test.go | 30 +++++++++++++++++++++------ 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/config/experimental/auth/auth.go b/config/experimental/auth/auth.go index 21000083d..2f560498b 100644 --- a/config/experimental/auth/auth.go +++ b/config/experimental/auth/auth.go @@ -30,15 +30,6 @@ func WithCachedToken(t *oauth2.Token) Option { } } -// WithStaleDuration sets the duration for which a token is considered stale. -// Stale tokens are still valid but will trigger an asynchronous refresh if -// async refresh is enabled. The default value is 3 minutes. -func WithStaleDuration(d time.Duration) Option { - return func(cts *cachedTokenSource) { - cts.staleDuration = d - } -} - // WithAsyncRefresh enables or disables the asynchronous token refresh. func WithAsyncRefresh(b bool) Option { return func(cts *cachedTokenSource) { @@ -60,6 +51,9 @@ func WithAsyncRefresh(b bool) Option { // If the TokenSource is already a cached token source (obtained by calling this // function), it is returned as is. func NewCachedTokenSource(ts oauth2.TokenSource, opts ...Option) oauth2.TokenSource { + // This is meant as a niche optimization to avoid double caching of the + // token source in situations where the user calls needs caching guarantees + // but does not know if the token source is already cached. if cts, ok := ts.(*cachedTokenSource); ok { return cts } diff --git a/config/experimental/auth/auth_test.go b/config/experimental/auth/auth_test.go index a7d5d71d2..035ebe42d 100644 --- a/config/experimental/auth/auth_test.go +++ b/config/experimental/auth/auth_test.go @@ -51,12 +51,10 @@ func TestNewCachedTokenSource_options(t *testing.T) { return nil, nil }) - wantStaleDuration := 10 * time.Minute wantDisableAsync := false wantCachedToken := &oauth2.Token{Expiry: time.Unix(42, 0)} opts := []Option{ - WithStaleDuration(wantStaleDuration), WithAsyncRefresh(!wantDisableAsync), WithCachedToken(wantCachedToken), } @@ -66,9 +64,6 @@ func TestNewCachedTokenSource_options(t *testing.T) { t.Fatalf("NewCachedTokenSource() = %T, want *cachedTokenSource", got) } - if got.staleDuration != wantStaleDuration { - t.Errorf("NewCachedTokenSource(): staleDuration = %v, want %v", got.staleDuration, wantStaleDuration) - } if got.disableAsync != wantDisableAsync { t.Errorf("NewCachedTokenSource(): disableAsync = %v, want %v", got.disableAsync, wantDisableAsync) } @@ -136,6 +131,7 @@ func TestCachedTokenSource_Token(t *testing.T) { desc string // description of the test case cachedToken *oauth2.Token // token cached before calling Token() disableAsync bool // whether are disabled or not + refreshErr error // whether the cache was in error state returnedToken *oauth2.Token // token returned by the token source returnedError error // error returned by the token source @@ -178,6 +174,15 @@ func TestCachedTokenSource_Token(t *testing.T) { returnedError: fmt.Errorf("test error"), wantCalls: 10, }, + { + desc: "[Blocking] recover from error", + disableAsync: true, + refreshErr: fmt.Errorf("refresh error"), + cachedToken: &oauth2.Token{Expiry: now.Add(-1 * time.Minute)}, + returnedToken: &oauth2.Token{Expiry: now.Add(-1 * time.Hour)}, + wantCalls: 10, + wantToken: &oauth2.Token{Expiry: now.Add(-1 * time.Hour)}, + }, { desc: "[Async] no cached token", returnedToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, @@ -224,6 +229,14 @@ func TestCachedTokenSource_Token(t *testing.T) { wantCalls: 10, wantToken: &oauth2.Token{Expiry: now.Add(-1 * time.Second)}, }, + { + desc: "[Async] recover from error", + refreshErr: fmt.Errorf("refresh error"), + cachedToken: &oauth2.Token{Expiry: now.Add(-1 * time.Minute)}, + returnedToken: &oauth2.Token{Expiry: now.Add(-1 * time.Hour)}, + wantCalls: 10, + wantToken: &oauth2.Token{Expiry: now.Add(-1 * time.Hour)}, + }, } for _, tc := range testCases { @@ -248,8 +261,14 @@ func TestCachedTokenSource_Token(t *testing.T) { cts.Token() }() } + wg.Wait() + // Wait for async refreshes to finish. This part is a little brittle + // but necessary to ensure that the async refresh is done before + // checking the results. + time.Sleep(10 * time.Millisecond) + if int(gotCalls) != tc.wantCalls { t.Errorf("want %d calls to cts.tokenSource.Token(), got %d", tc.wantCalls, gotCalls) } @@ -258,5 +277,4 @@ func TestCachedTokenSource_Token(t *testing.T) { } }) } - }