Skip to content

Commit 33d1369

Browse files
authored
fix: prevent repeated context expired errors (#228)
This is a port of GoogleCloudPlatform/cloud-sql-go-connector#458.
1 parent f8c584a commit 33d1369

File tree

6 files changed

+65
-71
lines changed

6 files changed

+65
-71
lines changed

dialer.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ type Dialer struct {
9494
// RSA keypair is generated will be faster.
9595
func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
9696
cfg := &dialerConfig{
97-
refreshTimeout: 30 * time.Second,
97+
refreshTimeout: alloydb.RefreshTimeout,
9898
dialFunc: proxy.Dial,
9999
useragents: []string{userAgent},
100100
}

internal/alloydb/instance.go

Lines changed: 54 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,26 @@ import (
2525

2626
"cloud.google.com/go/alloydbconn/errtype"
2727
"cloud.google.com/go/alloydbconn/internal/alloydbapi"
28+
"golang.org/x/time/rate"
2829
)
2930

30-
// the refresh buffer is the amount of time before a refresh's result expires
31-
// that a new refresh operation begins.
32-
const refreshBuffer = 4 * time.Minute
31+
const (
32+
// the refresh buffer is the amount of time before a refresh's result
33+
// expires that a new refresh operation begins.
34+
refreshBuffer = 4 * time.Minute
35+
36+
// refreshInterval is the amount of time between refresh attempts as
37+
// enforced by the rate limiter.
38+
refreshInterval = 30 * time.Second
39+
40+
// RefreshTimeout is the maximum amount of time to wait for a refresh
41+
// cycle to complete. This value should be greater than the
42+
// refreshInterval.
43+
RefreshTimeout = 60 * time.Second
44+
45+
// refreshBurst is the initial burst allowed by the rate limiter.
46+
refreshBurst = 2
47+
)
3348

3449
var (
3550
// Instance URI is in the format:
@@ -117,7 +132,12 @@ type Instance struct {
117132

118133
instanceURI
119134
key *rsa.PrivateKey
120-
r refresher
135+
// refreshTimeout sets the maximum duration a refresh cycle can run
136+
// for.
137+
refreshTimeout time.Duration
138+
// l controls the rate at which refresh cycles are run.
139+
l *rate.Limiter
140+
r refresher
121141

122142
resultGuard sync.RWMutex
123143
// cur represents the current refreshOperation that will be used to
@@ -148,17 +168,13 @@ func NewInstance(
148168
}
149169
ctx, cancel := context.WithCancel(context.Background())
150170
i := &Instance{
151-
instanceURI: cn,
152-
key: key,
153-
r: newRefresher(
154-
client,
155-
refreshTimeout,
156-
30*time.Second,
157-
2,
158-
dialerID,
159-
),
160-
ctx: ctx,
161-
cancel: cancel,
171+
instanceURI: cn,
172+
key: key,
173+
l: rate.NewLimiter(rate.Every(refreshInterval), refreshBurst),
174+
r: newRefresher(client, dialerID),
175+
refreshTimeout: refreshTimeout,
176+
ctx: ctx,
177+
cancel: cancel,
162178
}
163179
// For the initial refresh operation, set cur = next so that connection
164180
// requests block until the first refresh is complete.
@@ -234,20 +250,33 @@ func refreshDuration(now, certExpiry time.Time) time.Duration {
234250

235251
// scheduleRefresh schedules a refresh operation to be triggered after a given
236252
// duration. The returned refreshOperation can be used to either Cancel or Wait
237-
// for the operations result.
253+
// for the operation's result.
238254
func (i *Instance) scheduleRefresh(d time.Duration) *refreshOperation {
239-
res := &refreshOperation{}
240-
res.ready = make(chan struct{})
241-
res.timer = time.AfterFunc(d, func() {
242-
res.result, res.err = i.r.performRefresh(i.ctx, i.instanceURI, i.key)
243-
close(res.ready)
255+
r := &refreshOperation{}
256+
r.ready = make(chan struct{})
257+
r.timer = time.AfterFunc(d, func() {
258+
ctx, cancel := context.WithTimeout(i.ctx, i.refreshTimeout)
259+
defer cancel()
260+
261+
err := i.l.Wait(ctx)
262+
if err != nil {
263+
r.err = errtype.NewDialError(
264+
"context was canceled or expired before refresh completed",
265+
i.instanceURI.String(),
266+
nil,
267+
)
268+
} else {
269+
r.result, r.err = i.r.performRefresh(i.ctx, i.instanceURI, i.key)
270+
}
271+
272+
close(r.ready)
244273

245274
// Once the refresh is complete, update "current" with working
246275
// result and schedule a new refresh
247276
i.resultGuard.Lock()
248277
defer i.resultGuard.Unlock()
249278
// if failed, scheduled the next refresh immediately
250-
if res.err != nil {
279+
if r.err != nil {
251280
i.next = i.scheduleRefresh(0)
252281
// If the latest result is bad, avoid replacing the
253282
// used result while it's still valid and potentially
@@ -256,13 +285,13 @@ func (i *Instance) scheduleRefresh(d time.Duration) *refreshOperation {
256285
// valid are surpressed. We should try to surface
257286
// errors in a more meaningful way.
258287
if !i.cur.isValid() {
259-
i.cur = res
288+
i.cur = r
260289
}
261290
return
262291
}
263292
// Update the current results, and schedule the next refresh in
264293
// the future
265-
i.cur = res
294+
i.cur = r
266295
select {
267296
case <-i.ctx.Done():
268297
// instance has been closed, don't schedule anything
@@ -272,7 +301,7 @@ func (i *Instance) scheduleRefresh(d time.Duration) *refreshOperation {
272301
t := refreshDuration(time.Now(), i.cur.result.expiry)
273302
i.next = i.scheduleRefresh(t)
274303
})
275-
return res
304+
return r
276305
}
277306

278307
// String returns the instance's URI.

internal/alloydb/instance_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"crypto/rand"
2020
"crypto/rsa"
2121
"errors"
22+
"strings"
2223
"testing"
2324
"time"
2425

@@ -210,7 +211,7 @@ func TestClose(t *testing.T) {
210211
im.Close()
211212

212213
_, _, err = im.ConnectInfo(ctx)
213-
if !errors.Is(err, context.Canceled) {
214+
if !strings.Contains(err.Error(), "context was canceled or expired") {
214215
t.Fatalf("failed to retrieve connect info: %v", err)
215216
}
216217
}

internal/alloydb/refresh.go

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ import (
3030
"cloud.google.com/go/alloydbconn/errtype"
3131
"cloud.google.com/go/alloydbconn/internal/alloydbapi"
3232
"cloud.google.com/go/alloydbconn/internal/trace"
33-
"golang.org/x/time/rate"
3433
)
3534

3635
type connectInfo struct {
@@ -196,16 +195,11 @@ func createTLSConfig(inst instanceURI, cc certChain, info connectInfo, k *rsa.Pr
196195
// newRefresher creates a Refresher.
197196
func newRefresher(
198197
client *alloydbapi.Client,
199-
timeout time.Duration,
200-
interval time.Duration,
201-
burst int,
202198
dialerID string,
203199
) refresher {
204200
return refresher{
205-
client: client,
206-
timeout: timeout,
207-
clientLimiter: rate.NewLimiter(rate.Every(interval), burst),
208-
dialerID: dialerID,
201+
client: client,
202+
dialerID: dialerID,
209203
}
210204
}
211205

@@ -215,14 +209,8 @@ type refresher struct {
215209
// client provides access to the AlloyDB Admin API
216210
client *alloydbapi.Client
217211

218-
// timeout is the maximum amount of time a refresh operation should be allowed to take.
219-
timeout time.Duration
220-
221212
// dialerID is the unique ID of the associated dialer.
222213
dialerID string
223-
224-
// clientLimiter limits the number of refreshes.
225-
clientLimiter *rate.Limiter
226214
}
227215

228216
type refreshResult struct {
@@ -247,22 +235,6 @@ func (r refresher) performRefresh(ctx context.Context, cn instanceURI, k *rsa.Pr
247235
refreshEnd(err)
248236
}()
249237

250-
ctx, cancel := context.WithTimeout(ctx, r.timeout)
251-
defer cancel()
252-
if ctx.Err() == context.Canceled {
253-
return refreshResult{}, ctx.Err()
254-
}
255-
256-
// avoid refreshing too often to try not to tax the AlloyDB Admin API quotas
257-
err = r.clientLimiter.Wait(ctx)
258-
if err != nil {
259-
return refreshResult{}, errtype.NewDialError(
260-
"refresh was throttled until context expired",
261-
cn.String(),
262-
nil,
263-
)
264-
}
265-
266238
type mdRes struct {
267239
info connectInfo
268240
err error

internal/alloydb/refresh_test.go

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,13 @@ import (
2020
"testing"
2121
"time"
2222

23-
"cloud.google.com/go/alloydbconn/errtype"
2423
"cloud.google.com/go/alloydbconn/internal/alloydbapi"
2524
"cloud.google.com/go/alloydbconn/internal/mock"
2625
"google.golang.org/api/option"
2726
)
2827

28+
const testDialerID = "some-dialer-id"
29+
2930
func TestRefresh(t *testing.T) {
3031
wantIP := "10.0.0.1"
3132
wantExpiry := time.Now().Add(time.Hour).UTC().Round(time.Second)
@@ -57,7 +58,7 @@ func TestRefresh(t *testing.T) {
5758
if err != nil {
5859
t.Fatalf("admin API client error: %v", err)
5960
}
60-
r := newRefresher(cl, time.Hour, 30*time.Second, 2, "some-id")
61+
r := newRefresher(cl, testDialerID)
6162
res, err := r.performRefresh(context.Background(), cn, RSAKey)
6263
if err != nil {
6364
t.Fatalf("performRefresh unexpectedly failed with error: %v", err)
@@ -98,7 +99,7 @@ func TestRefreshFailsFast(t *testing.T) {
9899
if err != nil {
99100
t.Fatalf("admin API client error: %v", err)
100101
}
101-
r := newRefresher(cl, time.Hour, 30*time.Second, 1, "some-id")
102+
r := newRefresher(cl, testDialerID)
102103

103104
_, err = r.performRefresh(context.Background(), cn, RSAKey)
104105
if err != nil {
@@ -112,14 +113,4 @@ func TestRefreshFailsFast(t *testing.T) {
112113
if !errors.Is(err, context.Canceled) {
113114
t.Fatalf("expected context.Canceled error, got = %v", err)
114115
}
115-
116-
// force the rate limiter to throttle with a timed out context
117-
ctx, cancel = context.WithTimeout(context.Background(), time.Millisecond)
118-
defer cancel()
119-
_, err = r.performRefresh(ctx, cn, RSAKey)
120-
121-
var wantErr *errtype.DialError
122-
if !errors.As(err, &wantErr) {
123-
t.Fatalf("when refresh is throttled, want = %T, got = %v", wantErr, err)
124-
}
125116
}

options.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,8 @@ func WithRSAKey(k *rsa.PrivateKey) Option {
116116
}
117117
}
118118

119-
// WithRefreshTimeout returns an Option that sets a timeout on refresh operations. Defaults to 30s.
119+
// WithRefreshTimeout returns an Option that sets a timeout on refresh
120+
// operations. Defaults to 60s.
120121
func WithRefreshTimeout(t time.Duration) Option {
121122
return func(d *dialerConfig) {
122123
d.refreshTimeout = t

0 commit comments

Comments
 (0)