Skip to content

Commit dd25e6e

Browse files
authored
Merge pull request #9361 from starius/optimize-context-guard
fn: optimize context guard
2 parents 70e7b56 + 07c4668 commit dd25e6e

File tree

2 files changed

+65
-26
lines changed

2 files changed

+65
-26
lines changed

fn/context_guard.go

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ func (g *ContextGuard) Quit() {
5151
cancel()
5252
}
5353

54+
// Clear cancelFns. It is safe to use nil, because no write
55+
// operations to it can happen after g.quit is closed.
56+
g.cancelFns = nil
57+
5458
close(g.quit)
5559
})
5660
}
@@ -149,7 +153,7 @@ func (g *ContextGuard) Create(ctx context.Context,
149153
}
150154

151155
if opts.blocking {
152-
g.ctxBlocking(ctx, cancel)
156+
g.ctxBlocking(ctx)
153157

154158
return ctx, cancel
155159
}
@@ -169,9 +173,10 @@ func (g *ContextGuard) Create(ctx context.Context,
169173
return ctx, cancel
170174
}
171175

172-
// ctxQuitUnsafe spins off a goroutine that will block until the passed context
173-
// is cancelled or until the quit channel has been signaled after which it will
174-
// call the passed cancel function and decrement the wait group.
176+
// ctxQuitUnsafe increases the wait group counter, waits until the context is
177+
// cancelled and decreases the wait group counter. It stores the passed cancel
178+
// function and returns a wrapped version, which removed the stored one and
179+
// calls it. The Quit method calls all the stored cancel functions.
175180
//
176181
// NOTE: the caller must hold the ContextGuard's mutex before calling this
177182
// function.
@@ -181,35 +186,27 @@ func (g *ContextGuard) ctxQuitUnsafe(ctx context.Context,
181186
cancel = g.addCancelFnUnsafe(cancel)
182187

183188
g.wg.Add(1)
184-
go func() {
185-
defer cancel()
186-
defer g.wg.Done()
187-
188-
select {
189-
case <-g.quit:
190189

191-
case <-ctx.Done():
192-
}
193-
}()
190+
// We don't have to wait on g.quit here: g.quit can be closed only in
191+
// the Quit method, which also closes the context we are waiting for.
192+
context.AfterFunc(ctx, func() {
193+
g.wg.Done()
194+
})
194195

195196
return cancel
196197
}
197198

198-
// ctxBlocking spins off a goroutine that will block until the passed context
199-
// is cancelled after which it will call the passed cancel function and
200-
// decrement the wait group.
201-
func (g *ContextGuard) ctxBlocking(ctx context.Context,
202-
cancel context.CancelFunc) {
203-
199+
// ctxBlocking increases the wait group counter, waits until the context is
200+
// cancelled and decreases the wait group counter.
201+
//
202+
// NOTE: the caller must hold the ContextGuard's mutex before calling this
203+
// function.
204+
func (g *ContextGuard) ctxBlocking(ctx context.Context) {
204205
g.wg.Add(1)
205-
go func() {
206-
defer cancel()
207-
defer g.wg.Done()
208206

209-
select {
210-
case <-ctx.Done():
211-
}
212-
}()
207+
context.AfterFunc(ctx, func() {
208+
g.wg.Done()
209+
})
213210
}
214211

215212
// addCancelFnUnsafe adds a context cancel function to the manager and returns a

fn/context_guard_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@ package fn
22

33
import (
44
"context"
5+
"runtime"
56
"testing"
67
"time"
8+
9+
"github.com/stretchr/testify/require"
710
)
811

912
// TestContextGuard tests the behaviour of the ContextGuard.
@@ -298,6 +301,12 @@ func TestContextGuard(t *testing.T) {
298301
case <-time.After(time.Second):
299302
t.Fatalf("timeout")
300303
}
304+
305+
// Cancel the context.
306+
cancel()
307+
308+
// Make sure wg's counter gets to 0 eventually.
309+
g.WgWait()
301310
})
302311

303312
// Test that if we add the CustomTimeoutCGOpt option, then the context
@@ -433,3 +442,36 @@ func TestContextGuard(t *testing.T) {
433442
}
434443
})
435444
}
445+
446+
// TestContextGuardCountGoroutines makes sure that ContextGuard doesn't create
447+
// any goroutines while waiting for contexts.
448+
func TestContextGuardCountGoroutines(t *testing.T) {
449+
// NOTE: t.Parallel() is not called in this test because it relies on an
450+
// accurate count of active goroutines. Running other tests in parallel
451+
// would introduce additional goroutines, leading to unreliable results.
452+
453+
g := NewContextGuard()
454+
455+
ctx, cancel := context.WithCancel(context.Background())
456+
457+
// Count goroutines before contexts are created.
458+
count1 := runtime.NumGoroutine()
459+
460+
// Create 1000 contexts of each type.
461+
for i := 0; i < 1000; i++ {
462+
_, _ = g.Create(ctx)
463+
_, _ = g.Create(ctx, WithBlockingCG())
464+
_, _ = g.Create(ctx, WithTimeoutCG())
465+
_, _ = g.Create(ctx, WithBlockingCG(), WithTimeoutCG())
466+
}
467+
468+
// Make sure no new goroutine was launched.
469+
count2 := runtime.NumGoroutine()
470+
require.LessOrEqual(t, count2, count1)
471+
472+
// Cancel root context.
473+
cancel()
474+
475+
// Make sure wg's counter gets to 0 eventually.
476+
g.WgWait()
477+
}

0 commit comments

Comments
 (0)