Skip to content

Commit 6b433f4

Browse files
committed
Enhance token handling by adding copyToken function and comprehensive tests for token copying behavior
1 parent 58162b6 commit 6b433f4

File tree

2 files changed

+140
-20
lines changed

2 files changed

+140
-20
lines changed

tokensource.go

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ import (
99
// TokenRefreshCallback is called whenever a token is automatically refreshed.
1010
// The callback receives the new token and should persist it to storage.
1111
//
12+
// The token passed to the callback contains the essential fields needed for
13+
// persistence (AccessToken, RefreshToken, TokenType, Expiry, ExpiresIn).
14+
// Extra fields stored via WithExtra() are intentionally not included to avoid
15+
// potential race conditions from sharing references to the original token's
16+
// internal data structures.
17+
//
1218
// IMPORTANT: The callback is called synchronously and should return quickly.
1319
// If heavy processing is needed, consider sending the token to a channel
1420
// or queue for asynchronous processing.
@@ -56,31 +62,13 @@ func (cts *callbackTokenSource) Token() (*oauth2.Token, error) {
5662

5763
// Update lastToken if token changed
5864
if tokenChanged {
59-
cts.lastToken = &oauth2.Token{
60-
AccessToken: newToken.AccessToken,
61-
TokenType: newToken.TokenType,
62-
RefreshToken: newToken.RefreshToken,
63-
Expiry: newToken.Expiry,
64-
}
65+
cts.lastToken = copyToken(newToken)
6566

6667
// Invoke callback if provided
6768
if cts.callback != nil {
6869
// Make a copy of the token to pass to the callback
6970
// This prevents the callback from modifying the token
70-
tokenCopy := &oauth2.Token{
71-
AccessToken: newToken.AccessToken,
72-
TokenType: newToken.TokenType,
73-
RefreshToken: newToken.RefreshToken,
74-
Expiry: newToken.Expiry,
75-
}
76-
// Copy Extra field if present
77-
if newToken.Extra != nil {
78-
extraCopy := make(map[string]interface{}, len(newToken.Extra))
79-
for k, v := range newToken.Extra {
80-
extraCopy[k] = v
81-
}
82-
tokenCopy.Extra = extraCopy
83-
}
71+
tokenCopy := copyToken(newToken)
8472
cts.callback(tokenCopy)
8573
}
8674
}
@@ -114,3 +102,25 @@ func NewCallbackTokenSource(src oauth2.TokenSource, callback TokenRefreshCallbac
114102
callback: callback,
115103
}
116104
}
105+
106+
// copyToken creates a copy of an oauth2.Token.
107+
// This creates a new token with the same field values as the source token.
108+
//
109+
// Note: Extra fields stored via WithExtra() are intentionally not copied.
110+
// The purpose of the callback is to persist the access and refresh tokens
111+
// for later use. Extra fields are not needed for token storage and excluding
112+
// them avoids potential race conditions from sharing references to the
113+
// original token's internal data structures.
114+
func copyToken(src *oauth2.Token) *oauth2.Token {
115+
if src == nil {
116+
return nil
117+
}
118+
119+
return &oauth2.Token{
120+
AccessToken: src.AccessToken,
121+
TokenType: src.TokenType,
122+
RefreshToken: src.RefreshToken,
123+
Expiry: src.Expiry,
124+
ExpiresIn: src.ExpiresIn,
125+
}
126+
}

tokensource_test.go

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
package gotado
2+
3+
import (
4+
"testing"
5+
"time"
6+
7+
"golang.org/x/oauth2"
8+
)
9+
10+
func TestCopyToken(t *testing.T) {
11+
t.Run("NilToken", func(t *testing.T) {
12+
copied := copyToken(nil)
13+
if copied != nil {
14+
t.Error("Expected nil for nil input")
15+
}
16+
})
17+
18+
t.Run("BasicFields", func(t *testing.T) {
19+
expiry := time.Now().Add(1 * time.Hour)
20+
original := &oauth2.Token{
21+
AccessToken: "test_access_token",
22+
TokenType: "Bearer",
23+
RefreshToken: "test_refresh_token",
24+
Expiry: expiry,
25+
}
26+
27+
copied := copyToken(original)
28+
29+
// Verify it's a different object
30+
if original == copied {
31+
t.Error("Expected different pointer, got same object")
32+
}
33+
34+
// Verify all fields are copied
35+
if copied.AccessToken != original.AccessToken {
36+
t.Errorf("AccessToken mismatch: got %v, want %v", copied.AccessToken, original.AccessToken)
37+
}
38+
if copied.TokenType != original.TokenType {
39+
t.Errorf("TokenType mismatch: got %v, want %v", copied.TokenType, original.TokenType)
40+
}
41+
if copied.RefreshToken != original.RefreshToken {
42+
t.Errorf("RefreshToken mismatch: got %v, want %v", copied.RefreshToken, original.RefreshToken)
43+
}
44+
if !copied.Expiry.Equal(original.Expiry) {
45+
t.Errorf("Expiry mismatch: got %v, want %v", copied.Expiry, original.Expiry)
46+
}
47+
})
48+
49+
t.Run("ExpiresInField", func(t *testing.T) {
50+
original := &oauth2.Token{
51+
AccessToken: "test_access_token",
52+
ExpiresIn: 600,
53+
}
54+
55+
copied := copyToken(original)
56+
57+
if copied.ExpiresIn != original.ExpiresIn {
58+
t.Errorf("ExpiresIn mismatch: got %v, want %v", copied.ExpiresIn, original.ExpiresIn)
59+
}
60+
})
61+
62+
t.Run("WithExtraFields", func(t *testing.T) {
63+
extraMap := map[string]interface{}{
64+
"scope": "read write",
65+
"custom_field": "custom_value",
66+
"number": float64(123),
67+
}
68+
69+
original := (&oauth2.Token{
70+
AccessToken: "test_access_token",
71+
TokenType: "Bearer",
72+
RefreshToken: "test_refresh_token",
73+
}).WithExtra(extraMap)
74+
75+
copied := copyToken(original)
76+
77+
// Verify extra fields are NOT copied (limitation of simple copy)
78+
if copied.Extra("scope") != nil {
79+
t.Errorf("Extra 'scope' should be nil (not copied), got %v", copied.Extra("scope"))
80+
}
81+
if copied.Extra("custom_field") != nil {
82+
t.Errorf("Extra 'custom_field' should be nil (not copied), got %v", copied.Extra("custom_field"))
83+
}
84+
if copied.Extra("number") != nil {
85+
t.Errorf("Extra 'number' should be nil (not copied), got %v", copied.Extra("number"))
86+
}
87+
})
88+
89+
t.Run("IndependentCopy", func(t *testing.T) {
90+
original := &oauth2.Token{
91+
AccessToken: "original_access",
92+
TokenType: "Bearer",
93+
RefreshToken: "original_refresh",
94+
}
95+
96+
copied := copyToken(original)
97+
98+
// Modify the original
99+
original.AccessToken = "modified_access"
100+
original.RefreshToken = "modified_refresh"
101+
102+
// Verify the copy is not affected
103+
if copied.AccessToken != "original_access" {
104+
t.Errorf("Copy was affected by original modification: got %v, want %v", copied.AccessToken, "original_access")
105+
}
106+
if copied.RefreshToken != "original_refresh" {
107+
t.Errorf("Copy was affected by original modification: got %v, want %v", copied.RefreshToken, "original_refresh")
108+
}
109+
})
110+
}

0 commit comments

Comments
 (0)