Skip to content

Commit 07c4668

Browse files
committed
fn/ContextGuard: use context.AfterFunc to wait
Simplifies context cancellation handling by using context.AfterFunc instead of a goroutine to wait for context cancellation. This approach avoids the overhead of a goroutine during the waiting period. For ctxQuitUnsafe, since g.quit is closed only in the Quit method (which also cancels all associated contexts), waiting on context cancellation ensures the same behavior without unnecessary dependency on g.quit. Added a test to ensure that the Create method does not launch any goroutines.
1 parent e9ab603 commit 07c4668

File tree

2 files changed

+53
-20
lines changed

2 files changed

+53
-20
lines changed

fn/context_guard.go

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,10 @@ func (g *ContextGuard) Create(ctx context.Context,
173173
return ctx, cancel
174174
}
175175

176-
// ctxQuitUnsafe spins off a goroutine that will block until the passed context
177-
// is cancelled or until the quit channel has been signaled after which it will
178-
// call the passed cancel function and decrement the wait group.
176+
// ctxQuitUnsafe increases the wait group counter, waits until the context is
177+
// cancelled and decreases the wait group counter. It stores the passed cancel
178+
// function and returns a wrapped version, which removed the stored one and
179+
// calls it. The Quit method calls all the stored cancel functions.
179180
//
180181
// NOTE: the caller must hold the ContextGuard's mutex before calling this
181182
// function.
@@ -185,31 +186,27 @@ func (g *ContextGuard) ctxQuitUnsafe(ctx context.Context,
185186
cancel = g.addCancelFnUnsafe(cancel)
186187

187188
g.wg.Add(1)
188-
go func() {
189-
defer cancel()
190-
defer g.wg.Done()
191189

192-
select {
193-
case <-g.quit:
194-
195-
case <-ctx.Done():
196-
}
197-
}()
190+
// We don't have to wait on g.quit here: g.quit can be closed only in
191+
// the Quit method, which also closes the context we are waiting for.
192+
context.AfterFunc(ctx, func() {
193+
g.wg.Done()
194+
})
198195

199196
return cancel
200197
}
201198

202-
// ctxBlocking spins off a goroutine that will block until the passed context
203-
// is cancelled after which it will decrement the wait group.
199+
// ctxBlocking increases the wait group counter, waits until the context is
200+
// cancelled and decreases the wait group counter.
201+
//
202+
// NOTE: the caller must hold the ContextGuard's mutex before calling this
203+
// function.
204204
func (g *ContextGuard) ctxBlocking(ctx context.Context) {
205205
g.wg.Add(1)
206-
go func() {
207-
defer g.wg.Done()
208206

209-
select {
210-
case <-ctx.Done():
211-
}
212-
}()
207+
context.AfterFunc(ctx, func() {
208+
g.wg.Done()
209+
})
213210
}
214211

215212
// addCancelFnUnsafe adds a context cancel function to the manager and returns a

fn/context_guard_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@ package fn
22

33
import (
44
"context"
5+
"runtime"
56
"testing"
67
"time"
8+
9+
"github.com/stretchr/testify/require"
710
)
811

912
// TestContextGuard tests the behaviour of the ContextGuard.
@@ -439,3 +442,36 @@ func TestContextGuard(t *testing.T) {
439442
}
440443
})
441444
}
445+
446+
// TestContextGuardCountGoroutines makes sure that ContextGuard doesn't create
447+
// any goroutines while waiting for contexts.
448+
func TestContextGuardCountGoroutines(t *testing.T) {
449+
// NOTE: t.Parallel() is not called in this test because it relies on an
450+
// accurate count of active goroutines. Running other tests in parallel
451+
// would introduce additional goroutines, leading to unreliable results.
452+
453+
g := NewContextGuard()
454+
455+
ctx, cancel := context.WithCancel(context.Background())
456+
457+
// Count goroutines before contexts are created.
458+
count1 := runtime.NumGoroutine()
459+
460+
// Create 1000 contexts of each type.
461+
for i := 0; i < 1000; i++ {
462+
_, _ = g.Create(ctx)
463+
_, _ = g.Create(ctx, WithBlockingCG())
464+
_, _ = g.Create(ctx, WithTimeoutCG())
465+
_, _ = g.Create(ctx, WithBlockingCG(), WithTimeoutCG())
466+
}
467+
468+
// Make sure no new goroutine was launched.
469+
count2 := runtime.NumGoroutine()
470+
require.LessOrEqual(t, count2, count1)
471+
472+
// Cancel root context.
473+
cancel()
474+
475+
// Make sure wg's counter gets to 0 eventually.
476+
g.WgWait()
477+
}

0 commit comments

Comments
 (0)