Skip to content

Commit 4125874

Browse files
authored
Merge pull request #201 from scalarion/feature/token-refresh-callback
Feature/token refresh callback
2 parents 35ce209 + 6b433f4 commit 4125874

File tree

4 files changed

+275
-0
lines changed

4 files changed

+275
-0
lines changed

client.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,17 @@ func newClient(ctx context.Context, config *oauth2.Config, token *oauth2.Token)
2828
}
2929
}
3030

31+
// newClientWithCallback creates a new tado° client with token refresh callback.
32+
// The callback will be invoked whenever the OAuth2 token is automatically refreshed.
33+
func newClientWithCallback(ctx context.Context, config *oauth2.Config, token *oauth2.Token, callback TokenRefreshCallback) *client {
34+
tokenSrc := config.TokenSource(ctx, token)
35+
callbackTokenSrc := NewCallbackTokenSource(tokenSrc, callback)
36+
37+
return &client{
38+
http: oauth2.NewClient(ctx, callbackTokenSrc),
39+
}
40+
}
41+
3142
// WithHTTPClient configures the http client to use for tado° API interactions
3243
func (c *client) WithHTTPClient(httpClient *http.Client) *client {
3344
c.http = httpClient

tado.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,34 @@ func New(ctx context.Context, config *oauth2.Config, token *oauth2.Token) *Tado
3333
}
3434
}
3535

36+
// NewWithTokenRefreshCallback creates a new tado client with a callback
37+
// that is invoked whenever OAuth2 tokens are automatically refreshed.
38+
// This allows applications to persist refreshed tokens to storage.
39+
//
40+
// The tado° API uses refresh token rotation, meaning the old refresh token
41+
// is invalidated when a new one is issued. This makes it critical to save
42+
// refreshed tokens to prevent re-authentication.
43+
//
44+
// Example:
45+
//
46+
// config := gotado.AuthConfig(clientID, "offline_access")
47+
// token, _ := config.DeviceAccessToken(ctx, deviceAuth)
48+
//
49+
// callback := func(newToken *oauth2.Token) {
50+
// log.Println("Token refreshed, saving to disk")
51+
// }
52+
//
53+
// tado := gotado.NewWithTokenRefreshCallback(ctx, config, token, callback)
54+
//
55+
// Note: The callback is called synchronously. If you need to perform
56+
// heavy processing, consider sending the token to a channel for
57+
// asynchronous handling.
58+
func NewWithTokenRefreshCallback(ctx context.Context, config *oauth2.Config, token *oauth2.Token, callback TokenRefreshCallback) *Tado {
59+
return &Tado{
60+
client: newClientWithCallback(ctx, config, token, callback),
61+
}
62+
}
63+
3664
// Me returns information about the authenticated user.
3765
func (t *Tado) Me(ctx context.Context) (*User, error) {
3866
me := &User{client: t.client}

tokensource.go

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
package gotado
2+
3+
import (
4+
"sync"
5+
6+
"golang.org/x/oauth2"
7+
)
8+
9+
// TokenRefreshCallback is called whenever a token is automatically refreshed.
10+
// The callback receives the new token and should persist it to storage.
11+
//
12+
// The token passed to the callback contains the essential fields needed for
13+
// persistence (AccessToken, RefreshToken, TokenType, Expiry, ExpiresIn).
14+
// Extra fields stored via WithExtra() are intentionally not included to avoid
15+
// potential race conditions from sharing references to the original token's
16+
// internal data structures.
17+
//
18+
// IMPORTANT: The callback is called synchronously and should return quickly.
19+
// If heavy processing is needed, consider sending the token to a channel
20+
// or queue for asynchronous processing.
21+
type TokenRefreshCallback func(token *oauth2.Token)
22+
23+
// callbackTokenSource wraps an oauth2.TokenSource and calls a callback
24+
// whenever Token() returns a different token than the previous call.
25+
// This is useful for persisting refreshed OAuth2 tokens to disk or other storage.
26+
type callbackTokenSource struct {
27+
src oauth2.TokenSource
28+
callback TokenRefreshCallback
29+
mu sync.Mutex
30+
lastToken *oauth2.Token
31+
}
32+
33+
// Token implements the oauth2.TokenSource interface.
34+
// It retrieves a token from the underlying source and invokes the callback
35+
// if the token has changed (either access token or refresh token is different).
36+
func (cts *callbackTokenSource) Token() (*oauth2.Token, error) {
37+
cts.mu.Lock()
38+
defer cts.mu.Unlock()
39+
40+
newToken, err := cts.src.Token()
41+
if err != nil {
42+
return nil, err
43+
}
44+
45+
// Check if token has changed (different access token or refresh token)
46+
// We check both because:
47+
// - Access tokens expire frequently (every 10 minutes for tado°)
48+
// - Refresh tokens may rotate (tado° uses refresh token rotation)
49+
tokenChanged := false
50+
if cts.lastToken == nil {
51+
tokenChanged = true
52+
} else {
53+
// Compare access tokens
54+
if cts.lastToken.AccessToken != newToken.AccessToken {
55+
tokenChanged = true
56+
}
57+
// Compare refresh tokens (important for token rotation)
58+
if cts.lastToken.RefreshToken != newToken.RefreshToken {
59+
tokenChanged = true
60+
}
61+
}
62+
63+
// Update lastToken if token changed
64+
if tokenChanged {
65+
cts.lastToken = copyToken(newToken)
66+
67+
// Invoke callback if provided
68+
if cts.callback != nil {
69+
// Make a copy of the token to pass to the callback
70+
// This prevents the callback from modifying the token
71+
tokenCopy := copyToken(newToken)
72+
cts.callback(tokenCopy)
73+
}
74+
}
75+
76+
return newToken, nil
77+
}
78+
79+
// NewCallbackTokenSource creates a TokenSource that invokes the provided
80+
// callback whenever the underlying token is refreshed.
81+
//
82+
// This is particularly useful for the tado° API which:
83+
// - Has short-lived access tokens (10 minutes)
84+
// - Uses refresh token rotation (old refresh token is invalidated when new one is issued)
85+
// - Requires offline_access scope for refresh tokens
86+
//
87+
// Example usage:
88+
//
89+
// config := gotado.AuthConfig(clientID, "offline_access")
90+
// token, _ := config.DeviceAccessToken(ctx, deviceAuth)
91+
//
92+
// callback := func(newToken *oauth2.Token) {
93+
// // Save token to encrypted storage
94+
// log.Println("Token refreshed, saving to disk")
95+
// SaveTokenToFile(newToken)
96+
// }
97+
//
98+
// tado := gotado.NewWithTokenRefreshCallback(ctx, config, token, callback)
99+
func NewCallbackTokenSource(src oauth2.TokenSource, callback TokenRefreshCallback) oauth2.TokenSource {
100+
return &callbackTokenSource{
101+
src: src,
102+
callback: callback,
103+
}
104+
}
105+
106+
// copyToken creates a copy of an oauth2.Token.
107+
// This creates a new token with the same field values as the source token.
108+
//
109+
// Note: Extra fields stored via WithExtra() are intentionally not copied.
110+
// The purpose of the callback is to persist the access and refresh tokens
111+
// for later use. Extra fields are not needed for token storage and excluding
112+
// them avoids potential race conditions from sharing references to the
113+
// original token's internal data structures.
114+
func copyToken(src *oauth2.Token) *oauth2.Token {
115+
if src == nil {
116+
return nil
117+
}
118+
119+
return &oauth2.Token{
120+
AccessToken: src.AccessToken,
121+
TokenType: src.TokenType,
122+
RefreshToken: src.RefreshToken,
123+
Expiry: src.Expiry,
124+
ExpiresIn: src.ExpiresIn,
125+
}
126+
}

tokensource_test.go

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
package gotado
2+
3+
import (
4+
"testing"
5+
"time"
6+
7+
"golang.org/x/oauth2"
8+
)
9+
10+
func TestCopyToken(t *testing.T) {
11+
t.Run("NilToken", func(t *testing.T) {
12+
copied := copyToken(nil)
13+
if copied != nil {
14+
t.Error("Expected nil for nil input")
15+
}
16+
})
17+
18+
t.Run("BasicFields", func(t *testing.T) {
19+
expiry := time.Now().Add(1 * time.Hour)
20+
original := &oauth2.Token{
21+
AccessToken: "test_access_token",
22+
TokenType: "Bearer",
23+
RefreshToken: "test_refresh_token",
24+
Expiry: expiry,
25+
}
26+
27+
copied := copyToken(original)
28+
29+
// Verify it's a different object
30+
if original == copied {
31+
t.Error("Expected different pointer, got same object")
32+
}
33+
34+
// Verify all fields are copied
35+
if copied.AccessToken != original.AccessToken {
36+
t.Errorf("AccessToken mismatch: got %v, want %v", copied.AccessToken, original.AccessToken)
37+
}
38+
if copied.TokenType != original.TokenType {
39+
t.Errorf("TokenType mismatch: got %v, want %v", copied.TokenType, original.TokenType)
40+
}
41+
if copied.RefreshToken != original.RefreshToken {
42+
t.Errorf("RefreshToken mismatch: got %v, want %v", copied.RefreshToken, original.RefreshToken)
43+
}
44+
if !copied.Expiry.Equal(original.Expiry) {
45+
t.Errorf("Expiry mismatch: got %v, want %v", copied.Expiry, original.Expiry)
46+
}
47+
})
48+
49+
t.Run("ExpiresInField", func(t *testing.T) {
50+
original := &oauth2.Token{
51+
AccessToken: "test_access_token",
52+
ExpiresIn: 600,
53+
}
54+
55+
copied := copyToken(original)
56+
57+
if copied.ExpiresIn != original.ExpiresIn {
58+
t.Errorf("ExpiresIn mismatch: got %v, want %v", copied.ExpiresIn, original.ExpiresIn)
59+
}
60+
})
61+
62+
t.Run("WithExtraFields", func(t *testing.T) {
63+
extraMap := map[string]interface{}{
64+
"scope": "read write",
65+
"custom_field": "custom_value",
66+
"number": float64(123),
67+
}
68+
69+
original := (&oauth2.Token{
70+
AccessToken: "test_access_token",
71+
TokenType: "Bearer",
72+
RefreshToken: "test_refresh_token",
73+
}).WithExtra(extraMap)
74+
75+
copied := copyToken(original)
76+
77+
// Verify extra fields are NOT copied (limitation of simple copy)
78+
if copied.Extra("scope") != nil {
79+
t.Errorf("Extra 'scope' should be nil (not copied), got %v", copied.Extra("scope"))
80+
}
81+
if copied.Extra("custom_field") != nil {
82+
t.Errorf("Extra 'custom_field' should be nil (not copied), got %v", copied.Extra("custom_field"))
83+
}
84+
if copied.Extra("number") != nil {
85+
t.Errorf("Extra 'number' should be nil (not copied), got %v", copied.Extra("number"))
86+
}
87+
})
88+
89+
t.Run("IndependentCopy", func(t *testing.T) {
90+
original := &oauth2.Token{
91+
AccessToken: "original_access",
92+
TokenType: "Bearer",
93+
RefreshToken: "original_refresh",
94+
}
95+
96+
copied := copyToken(original)
97+
98+
// Modify the original
99+
original.AccessToken = "modified_access"
100+
original.RefreshToken = "modified_refresh"
101+
102+
// Verify the copy is not affected
103+
if copied.AccessToken != "original_access" {
104+
t.Errorf("Copy was affected by original modification: got %v, want %v", copied.AccessToken, "original_access")
105+
}
106+
if copied.RefreshToken != "original_refresh" {
107+
t.Errorf("Copy was affected by original modification: got %v, want %v", copied.RefreshToken, "original_refresh")
108+
}
109+
})
110+
}

0 commit comments

Comments
 (0)