Skip to content

Commit a7094e4

Browse files
authored
waitall fix write on closed channel (#26)
* try fixing * fix * giveup on close channel * drop variable * minify diff * minify changes * minify changes * try go leak * tweaks on tests * more tweaks on unittest * merge master and tidy * codecov improve
1 parent e936ab2 commit a7094e4

File tree

5 files changed

+101
-25
lines changed

5 files changed

+101
-25
lines changed

go.mod

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ require github.com/stretchr/testify v1.8.2
66

77
require (
88
github.com/davecgh/go-spew v1.1.1 // indirect
9+
github.com/kr/pretty v0.1.0 // indirect
10+
github.com/kr/text v0.2.0 // indirect
911
github.com/pmezard/go-difflib v1.0.0 // indirect
12+
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
1013
gopkg.in/yaml.v3 v3.0.1 // indirect
1114
)

go.sum

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
1+
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
12
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
23
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
34
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
5+
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
6+
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
7+
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
8+
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
9+
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
10+
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
411
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
512
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
613
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@@ -10,8 +17,9 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
1017
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
1118
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
1219
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
13-
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
1420
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
21+
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
22+
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
1523
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
1624
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
1725
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

task.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ import (
1212
type AsyncFunc[T any] func(context.Context) (*T, error)
1313

1414
// ActionToFunc convert a Action to Func (C# term), to satisfy the AsyncFunc interface.
15-
// Action is function that runs without return anything
16-
// Func is function that runs and return something
15+
// - Action is function that runs without return anything
16+
// - Func is function that runs and return something
1717
func ActionToFunc(action func(context.Context) error) func(context.Context) (*interface{}, error) {
1818
return func(ctx context.Context) (*interface{}, error) {
1919
return nil, action(ctx)
@@ -134,7 +134,7 @@ func runAndTrackGenericTask[T any](ctx context.Context, record *Task[T], task fu
134134
defer record.waitGroup.Done()
135135
defer func() {
136136
if r := recover(); r != nil {
137-
err := fmt.Errorf("Panic cought: %v, StackTrace: %s, %w", r, debug.Stack(), ErrPanic)
137+
err := fmt.Errorf("panic cought: %v, stackTrace: %s, %w", r, debug.Stack(), ErrPanic)
138138
record.finish(StateFailed, nil, err)
139139
}
140140
}()

wait_all.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,11 @@ func WaitAll(ctx context.Context, options *WaitAllOptions, tasks ...Waitable) er
2727
options = &WaitAllOptions{}
2828
}
2929

30+
// tried to close channel before exit this func,
31+
// but it's complicated with routines, and we don't want to delay the return.
32+
// per https://stackoverflow.com/questions/8593645/is-it-ok-to-leave-a-channel-open, its ok to leave channel open, eventually it will be garbage collected.
33+
// this assumes the tasks eventually finish, otherwise we will have a routine leak.
3034
errorCh := make(chan error, tasksCount)
31-
// when failFast enabled, we return on first error we see, while other task may still post error in this channel.
32-
if !options.FailFast {
33-
defer close(errorCh)
34-
}
3535

3636
for _, tsk := range tasks {
3737
go waitOne(ctx, tsk, errorCh)

wait_all_test.go

Lines changed: 82 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@ import (
1313

1414
func TestWaitAll(t *testing.T) {
1515
t.Parallel()
16-
ctx, cancelFunc := newTestContextWithTimeout(t, 2*time.Second)
17-
defer cancelFunc()
16+
ctx, cancelTaskExecution := newTestContextWithTimeout(t, 2*time.Second)
1817

1918
start := time.Now()
2019
countingTsk1 := asynctask.Start(ctx, getCountingTask(10, "countingPer40ms", 40*time.Millisecond))
@@ -26,13 +25,15 @@ func TestWaitAll(t *testing.T) {
2625
err := asynctask.WaitAll(ctx, &asynctask.WaitAllOptions{FailFast: true}, countingTsk1, countingTsk2, countingTsk3, completedTsk)
2726
elapsed := time.Since(start)
2827
assert.NoError(t, err)
28+
cancelTaskExecution()
29+
2930
// should only finish after longest task.
3031
assert.True(t, elapsed > 10*40*time.Millisecond, fmt.Sprintf("actually elapsed: %v", elapsed))
3132
}
3233

3334
func TestWaitAllFailFastCase(t *testing.T) {
3435
t.Parallel()
35-
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
36+
ctx, cancelTaskExecution := newTestContextWithTimeout(t, 3*time.Second)
3637

3738
start := time.Now()
3839
countingTsk := asynctask.Start(ctx, getCountingTask(10, "countingPer40ms", 40*time.Millisecond))
@@ -44,9 +45,10 @@ func TestWaitAllFailFastCase(t *testing.T) {
4445
err := asynctask.WaitAll(ctx, &asynctask.WaitAllOptions{FailFast: true}, countingTsk, errorTsk, panicTsk, completedTsk)
4546
countingTskState := countingTsk.State()
4647
panicTskState := countingTsk.State()
48+
errTskState := errorTsk.State()
4749
elapsed := time.Since(start)
4850

49-
cancelFunc() // all assertion variable captured, cancel counting task
51+
cancelTaskExecution() // all assertion variable captured, cancel counting task
5052

5153
assert.Error(t, err)
5254
assert.Equal(t, "expected error", err.Error())
@@ -56,6 +58,7 @@ func TestWaitAllFailFastCase(t *testing.T) {
5658
// since we pass FailFast, countingTsk and panicTsk should be still running
5759
assert.Equal(t, asynctask.StateRunning, countingTskState)
5860
assert.Equal(t, asynctask.StateRunning, panicTskState)
61+
assert.Equal(t, asynctask.StateFailed, errTskState, "error task should the one failed the waitAll.")
5962

6063
// counting task do testing.Logf in another go routine
6164
// while testing.Logf would cause DataRace error when test is already finished: https://github.com/golang/go/issues/40343
@@ -65,8 +68,7 @@ func TestWaitAllFailFastCase(t *testing.T) {
6568

6669
func TestWaitAllErrorCase(t *testing.T) {
6770
t.Parallel()
68-
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
69-
defer cancelFunc()
71+
ctx, cancelTaskExecution := newTestContextWithTimeout(t, 3*time.Second)
7072

7173
start := time.Now()
7274
countingTsk := asynctask.Start(ctx, getCountingTask(10, "countingPer40ms", 40*time.Millisecond))
@@ -75,31 +77,34 @@ func TestWaitAllErrorCase(t *testing.T) {
7577
result := "something"
7678
completedTsk := asynctask.NewCompletedTask(&result)
7779

78-
err := asynctask.WaitAll(ctx, &asynctask.WaitAllOptions{FailFast: false}, countingTsk, errorTsk, panicTsk, completedTsk)
80+
err := asynctask.WaitAll(ctx, nil, countingTsk, errorTsk, panicTsk, completedTsk)
7981
countingTskState := countingTsk.State()
8082
panicTskState := panicTsk.State()
83+
errTskState := errorTsk.State()
84+
completedTskState := completedTsk.State()
8185
elapsed := time.Since(start)
8286

83-
cancelFunc() // all assertion variable captured, cancel counting task
87+
cancelTaskExecution() // all assertion variable captured, cancel counting task
8488

8589
assert.Error(t, err)
86-
assert.Equal(t, "expected error", err.Error())
90+
assert.Equal(t, "expected error", err.Error(), "expecting first error")
8791
// should only finish after longest task.
8892
assert.True(t, elapsed > 10*40*time.Millisecond, fmt.Sprintf("actually elapsed: %v", elapsed))
8993

9094
assert.Equal(t, asynctask.StateCompleted, countingTskState, "countingTask should finished")
95+
assert.Equal(t, asynctask.StateFailed, errTskState, "error task should failed")
9196
assert.Equal(t, asynctask.StateFailed, panicTskState, "panic task should failed")
97+
assert.Equal(t, asynctask.StateCompleted, completedTskState, "completed task should finished")
9298

9399
// counting task do testing.Logf in another go routine
94100
// while testing.Logf would cause DataRace error when test is already finished: https://github.com/golang/go/issues/40343
95101
// wait minor time for the go routine to finish.
96102
time.Sleep(1 * time.Millisecond)
97103
}
98104

99-
func TestWaitAllCanceled(t *testing.T) {
105+
func TestWaitAllFailFastCancelingWait(t *testing.T) {
100106
t.Parallel()
101-
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
102-
defer cancelFunc()
107+
ctx, cancelTaskExecution := newTestContextWithTimeout(t, 3*time.Second)
103108

104109
start := time.Now()
105110
countingTsk1 := asynctask.Start(ctx, getCountingTask(10, "countingPer40ms", 40*time.Millisecond))
@@ -108,25 +113,68 @@ func TestWaitAllCanceled(t *testing.T) {
108113
result := "something"
109114
completedTsk := asynctask.NewCompletedTask(&result)
110115

111-
waitCtx, cancelFunc1 := context.WithTimeout(ctx, 5*time.Millisecond)
112-
defer cancelFunc1()
116+
waitCtx, cancelWait := context.WithTimeout(ctx, 5*time.Millisecond)
117+
defer cancelWait()
113118

114-
elapsed := time.Since(start)
115119
err := asynctask.WaitAll(waitCtx, &asynctask.WaitAllOptions{FailFast: true}, countingTsk1, countingTsk2, countingTsk3, completedTsk)
116-
117-
cancelFunc() // all assertion variable captured, cancel counting task
120+
elapsed := time.Since(start)
121+
countingTsk1State := countingTsk1.State()
122+
countingTsk2State := countingTsk2.State()
123+
countingTsk3State := countingTsk3.State()
124+
completedTskState := completedTsk.State()
125+
cancelTaskExecution() // all assertion variable captured, cancel task execution
118126

119127
assert.Error(t, err)
120128
assert.True(t, errors.Is(err, context.DeadlineExceeded))
121129
// should return before first task
122130
assert.True(t, elapsed < 10*2*time.Millisecond)
131+
assert.Equal(t, countingTsk1State, asynctask.StateRunning)
132+
assert.Equal(t, countingTsk2State, asynctask.StateRunning)
133+
assert.Equal(t, countingTsk3State, asynctask.StateRunning)
134+
assert.Equal(t, completedTskState, asynctask.StateCompleted)
123135

124136
// counting task do testing.Logf in another go routine
125137
// while testing.Logf would cause DataRace error when test is already finished: https://github.com/golang/go/issues/40343
126138
// wait minor time for the go routine to finish.
127139
time.Sleep(1 * time.Millisecond)
128140
}
129141

142+
func TestWaitAllCancelingWait(t *testing.T) {
143+
t.Parallel()
144+
145+
ctx, cancelTaskExecution := newTestContextWithTimeout(t, 4*time.Millisecond)
146+
147+
start := time.Now()
148+
rcCtx, rcCancel := context.WithCancel(context.Background())
149+
uncontrollableTask := asynctask.Start(ctx, getUncontrollableTask(rcCtx, t))
150+
countingTsk1 := asynctask.Start(ctx, getCountingTask(10, "countingPer40ms", 40*time.Millisecond))
151+
countingTsk2 := asynctask.Start(ctx, getCountingTask(10, "countingPer20ms", 20*time.Millisecond))
152+
countingTsk3 := asynctask.Start(ctx, getCountingTask(10, "countingPer2ms", 2*time.Millisecond))
153+
result := "something"
154+
completedTsk := asynctask.NewCompletedTask(&result)
155+
156+
waitCtx, cancelWait := context.WithTimeout(ctx, 5*time.Millisecond)
157+
defer cancelWait()
158+
159+
err := asynctask.WaitAll(waitCtx, nil, countingTsk1, countingTsk2, countingTsk3, completedTsk, uncontrollableTask)
160+
elapsed := time.Since(start)
161+
t.Logf("WaitAll finished, elapsed: %v", elapsed)
162+
cancelTaskExecution() // all assertion variable captured, cancel counting task
163+
164+
assert.Error(t, err)
165+
assert.True(t, errors.Is(err, context.DeadlineExceeded))
166+
// should return before first task
167+
assert.True(t, elapsed < 10*2*time.Millisecond)
168+
169+
// cancel the remote control context to stop the uncontrollable task, or goleak.VerifyNone will fail.
170+
rcCancel()
171+
172+
// counting task do testing.Logf in another go routine
173+
// while testing.Logf would cause DataRace error when test is already finished: https://github.com/golang/go/issues/40343
174+
// wait minor time for the go routine to finish.
175+
time.Sleep(50 * time.Millisecond)
176+
}
177+
130178
func TestWaitAllWithNoTasks(t *testing.T) {
131179
t.Parallel()
132180
ctx, cancelFunc := newTestContextWithTimeout(t, 1*time.Millisecond)
@@ -135,3 +183,20 @@ func TestWaitAllWithNoTasks(t *testing.T) {
135183
err := asynctask.WaitAll(ctx, nil)
136184
assert.NoError(t, err)
137185
}
186+
187+
// getUncontrollableTask return a task that is not honor context, it only hornor the remoteControl context.
188+
func getUncontrollableTask(rcCtx context.Context, t *testing.T) asynctask.AsyncFunc[int] {
189+
return func(ctx context.Context) (*int, error) {
190+
for {
191+
select {
192+
case <-time.After(1 * time.Millisecond):
193+
if err := ctx.Err(); err != nil {
194+
t.Logf("[UncontrollableTask]: context %s, but not honoring it.", err)
195+
}
196+
case <-rcCtx.Done():
197+
t.Logf("[UncontrollableTask]: cancelled by remote control")
198+
return nil, rcCtx.Err()
199+
}
200+
}
201+
}
202+
}

0 commit comments

Comments
 (0)