Skip to content

Commit 1500ff5

Browse files
authored
implement waitAll (#5)
* attempt to implement waitAll * WaitAll with options * cover context cancel case in WaitAll * update test comments * waitAll * prevent unittest from stuck * limit all test to finish in 3 seconds * tweaks * tweaks * tweak * - distinguash softClose and hardClose - waitOne as a function
1 parent fc5fa4c commit 1500ff5

File tree

4 files changed

+233
-25
lines changed

4 files changed

+233
-25
lines changed

async_task_test.go

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,18 @@ func newTestContext(t *testing.T) context.Context {
1717
return context.WithValue(context.TODO(), testContextKey, t)
1818
}
1919

20-
func getCountingTask(sleepDuration time.Duration) asynctask.AsyncFunc {
20+
func newTestContextWithTimeout(t *testing.T, timeout time.Duration) (context.Context, context.CancelFunc) {
21+
return context.WithTimeout(context.WithValue(context.TODO(), testContextKey, t), timeout)
22+
}
23+
24+
func getCountingTask(countTo int, sleepInterval time.Duration) asynctask.AsyncFunc {
2125
return func(ctx context.Context) (interface{}, error) {
2226
t := ctx.Value(testContextKey).(*testing.T)
2327

2428
result := 0
25-
for i := 0; i < 10; i++ {
29+
for i := 0; i < countTo; i++ {
2630
select {
27-
case <-time.After(sleepDuration):
31+
case <-time.After(sleepInterval):
2832
t.Logf(" working %d", i)
2933
result = i
3034
case <-ctx.Done():
@@ -38,9 +42,10 @@ func getCountingTask(sleepDuration time.Duration) asynctask.AsyncFunc {
3842

3943
func TestEasyCase(t *testing.T) {
4044
t.Parallel()
41-
ctx := newTestContext(t)
42-
t1 := asynctask.Start(ctx, getCountingTask(200*time.Millisecond))
45+
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
46+
defer cancelFunc()
4347

48+
t1 := asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond))
4449
assert.Equal(t, asynctask.StateRunning, t1.State(), "Task should queued to Running")
4550

4651
rawResult, err := t1.Wait(ctx)
@@ -61,14 +66,15 @@ func TestEasyCase(t *testing.T) {
6166
result = rawResult.(int)
6267
assert.Equal(t, result, 9)
6368

64-
assert.True(t, elapsed.Microseconds() < 2, "Second wait should return immediately")
69+
assert.True(t, elapsed.Microseconds() < 3, "Second wait should return immediately")
6570
}
6671

6772
func TestCancelFunc(t *testing.T) {
6873
t.Parallel()
69-
ctx := newTestContext(t)
70-
t1 := asynctask.Start(ctx, getCountingTask(200*time.Millisecond))
74+
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
75+
defer cancelFunc()
7176

77+
t1 := asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond))
7278
assert.Equal(t, asynctask.StateRunning, t1.State(), "Task should queued to Running")
7379

7480
time.Sleep(time.Second * 1)
@@ -98,10 +104,11 @@ func TestCancelFunc(t *testing.T) {
98104

99105
func TestConsistentResultAfterCancel(t *testing.T) {
100106
t.Parallel()
101-
ctx := newTestContext(t)
102-
t1 := asynctask.Start(ctx, getCountingTask(200*time.Millisecond))
103-
t2 := asynctask.Start(ctx, getCountingTask(200*time.Millisecond))
107+
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
108+
defer cancelFunc()
104109

110+
t1 := asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond))
111+
t2 := asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond))
105112
assert.Equal(t, asynctask.StateRunning, t1.State(), "Task should queued to Running")
106113

107114
time.Sleep(time.Second * 1)
@@ -126,6 +133,8 @@ func TestConsistentResultAfterCancel(t *testing.T) {
126133

127134
func TestCompletedTask(t *testing.T) {
128135
t.Parallel()
136+
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
137+
defer cancelFunc()
129138

130139
tsk := asynctask.NewCompletedTask()
131140
assert.Equal(t, asynctask.StateCompleted, tsk.State(), "Task should in CompletedState")
@@ -135,19 +144,21 @@ func TestCompletedTask(t *testing.T) {
135144
assert.Equal(t, asynctask.StateCompleted, tsk.State(), "Task should still in CompletedState")
136145

137146
// you get nil result and nil error
138-
result, err := tsk.Wait(context.TODO())
147+
result, err := tsk.Wait(ctx)
139148
assert.Equal(t, asynctask.StateCompleted, tsk.State(), "Task should still in CompletedState")
140149
assert.NoError(t, err)
141150
assert.Nil(t, result)
142151
}
143152

144153
func TestCrazyCase(t *testing.T) {
145154
t.Parallel()
146-
ctx := newTestContext(t)
147-
numOfTasks := 10000
155+
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
156+
defer cancelFunc()
157+
158+
numOfTasks := 8000 // if you have --race switch on: limit on 8128 simultaneously alive goroutines is exceeded, dying
148159
tasks := map[int]*asynctask.TaskStatus{}
149160
for i := 0; i < numOfTasks; i++ {
150-
tasks[i] = asynctask.Start(ctx, getCountingTask(200*time.Millisecond))
161+
tasks[i] = asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond))
151162
}
152163

153164
time.Sleep(200 * time.Millisecond)

error_test.go

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,19 @@ func getPanicTask(sleepDuration time.Duration) asynctask.AsyncFunc {
2929
}
3030
}
3131

32-
func getErrorTask(sleepDuration time.Duration) asynctask.AsyncFunc {
32+
func getErrorTask(errorString string, sleepDuration time.Duration) asynctask.AsyncFunc {
3333
return func(ctx context.Context) (interface{}, error) {
3434
time.Sleep(sleepDuration)
35-
return nil, errors.New("not found")
35+
return nil, errors.New(errorString)
3636
}
3737
}
3838

3939
func TestTimeoutCase(t *testing.T) {
4040
t.Parallel()
41-
ctx := newTestContext(t)
42-
tsk := asynctask.Start(ctx, getCountingTask(200*time.Millisecond))
41+
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
42+
defer cancelFunc()
43+
44+
tsk := asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond))
4345
_, err := tsk.WaitWithTimeout(ctx, 300*time.Millisecond)
4446
assert.True(t, errors.Is(err, context.DeadlineExceeded), "expecting DeadlineExceeded")
4547

@@ -57,25 +59,31 @@ func TestTimeoutCase(t *testing.T) {
5759

5860
func TestPanicCase(t *testing.T) {
5961
t.Parallel()
60-
ctx := newTestContext(t)
62+
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
63+
defer cancelFunc()
64+
6165
tsk := asynctask.Start(ctx, getPanicTask(200*time.Millisecond))
6266
_, err := tsk.WaitWithTimeout(ctx, 300*time.Millisecond)
6367
assert.True(t, errors.Is(err, asynctask.ErrPanic), "expecting ErrPanic")
6468
}
6569

6670
func TestErrorCase(t *testing.T) {
6771
t.Parallel()
68-
ctx := newTestContext(t)
69-
tsk := asynctask.Start(ctx, getErrorTask(200*time.Millisecond))
72+
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
73+
defer cancelFunc()
74+
75+
tsk := asynctask.Start(ctx, getErrorTask("dummy error", 200*time.Millisecond))
7076
_, err := tsk.WaitWithTimeout(ctx, 300*time.Millisecond)
7177
assert.Error(t, err)
7278
assert.False(t, errors.Is(err, asynctask.ErrPanic), "not expecting ErrPanic")
7379
assert.False(t, errors.Is(err, context.DeadlineExceeded), "not expecting DeadlineExceeded")
74-
assert.Equal(t, "not found", err.Error())
80+
assert.Equal(t, "dummy error", err.Error())
7581
}
7682

7783
func TestPointerErrorCase(t *testing.T) {
7884
t.Parallel()
85+
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
86+
defer cancelFunc()
7987

8088
// nil point of a type that implement error
8189
var pe *pointerError = nil
@@ -84,7 +92,6 @@ func TestPointerErrorCase(t *testing.T) {
8492
// now you get a non-nil error
8593
assert.False(t, err == nil, "reason this test is needed")
8694

87-
ctx := newTestContext(t)
8895
tsk := asynctask.Start(ctx, func(ctx context.Context) (interface{}, error) {
8996
time.Sleep(100 * time.Millisecond)
9097
var pe *pointerError = nil
@@ -98,6 +105,8 @@ func TestPointerErrorCase(t *testing.T) {
98105

99106
func TestStructErrorCase(t *testing.T) {
100107
t.Parallel()
108+
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
109+
defer cancelFunc()
101110

102111
// nil point of a type that implement error
103112
var se structError
@@ -106,7 +115,6 @@ func TestStructErrorCase(t *testing.T) {
106115
// now you get a non-nil error
107116
assert.False(t, err == nil, "reason this test is needed")
108117

109-
ctx := newTestContext(t)
110118
tsk := asynctask.Start(ctx, func(ctx context.Context) (interface{}, error) {
111119
time.Sleep(100 * time.Millisecond)
112120
var se structError

wait_all.go

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
package asynctask
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"sync"
7+
)
8+
9+
// WaitAllOptions defines options for WaitAll function
10+
type WaitAllOptions struct {
11+
// FailFast set to true will indicate WaitAll to return on first error it sees.
12+
FailFast bool
13+
}
14+
15+
// WaitAll block current thread til all task finished.
16+
// first error from any tasks passed in will be returned.
17+
func WaitAll(ctx context.Context, options *WaitAllOptions, tasks ...*TaskStatus) error {
18+
tasksCount := len(tasks)
19+
20+
mutex := sync.Mutex{}
21+
errorChClosed := false
22+
errorCh := make(chan error, tasksCount)
23+
// hard close channel
24+
defer close(errorCh)
25+
26+
for _, tsk := range tasks {
27+
go waitOne(ctx, tsk, errorCh, &errorChClosed, &mutex)
28+
}
29+
30+
runningTasks := tasksCount
31+
var errList []error
32+
for {
33+
select {
34+
case err := <-errorCh:
35+
runningTasks--
36+
if err != nil {
37+
// return immediately after receive first error.
38+
if options.FailFast {
39+
softCloseChannel(&mutex, &errorChClosed)
40+
return err
41+
}
42+
43+
errList = append(errList, err)
44+
}
45+
case <-ctx.Done():
46+
softCloseChannel(&mutex, &errorChClosed)
47+
return fmt.Errorf("WaitAll context canceled: %w", ctx.Err())
48+
}
49+
50+
// are we finished yet?
51+
if runningTasks == 0 {
52+
softCloseChannel(&mutex, &errorChClosed)
53+
break
54+
}
55+
}
56+
57+
// we have at least 1 error, return first one.
58+
// caller can get error for individual task by using Wait(),
59+
// it would return immediately after this WaitAll()
60+
if len(errList) > 0 {
61+
return errList[0]
62+
}
63+
64+
// no error at all.
65+
return nil
66+
}
67+
68+
func waitOne(ctx context.Context, tsk *TaskStatus, errorCh chan<- error, errorChClosed *bool, mutex *sync.Mutex) {
69+
_, err := tsk.Wait(ctx)
70+
71+
// why mutex?
72+
// if all tasks start using same context (unittest is good example)
73+
// and that context got canceled, all task fail at same time.
74+
// first one went in and close the channel, while another one already went through gate check.
75+
// raise a panic with send to closed channel.
76+
mutex.Lock()
77+
defer mutex.Unlock()
78+
if !*errorChClosed {
79+
errorCh <- err
80+
}
81+
}
82+
83+
func softCloseChannel(mutex *sync.Mutex, closed *bool) {
84+
mutex.Lock()
85+
defer mutex.Unlock()
86+
*closed = true
87+
}

wait_all_test.go

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
package asynctask_test
2+
3+
import (
4+
"context"
5+
"errors"
6+
"testing"
7+
"time"
8+
9+
"github.com/Azure/go-asynctask"
10+
"github.com/stretchr/testify/assert"
11+
)
12+
13+
func TestWaitAll(t *testing.T) {
14+
t.Parallel()
15+
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
16+
defer cancelFunc()
17+
18+
countingTsk1 := asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond))
19+
countingTsk2 := asynctask.Start(ctx, getCountingTask(10, 20*time.Millisecond))
20+
countingTsk3 := asynctask.Start(ctx, getCountingTask(10, 2*time.Millisecond))
21+
completedTsk := asynctask.NewCompletedTask()
22+
23+
start := time.Now()
24+
err := asynctask.WaitAll(ctx, &asynctask.WaitAllOptions{FailFast: true}, countingTsk1, countingTsk2, countingTsk3, completedTsk)
25+
elapsed := time.Since(start)
26+
assert.NoError(t, err)
27+
// should only finish after longest task.
28+
assert.True(t, elapsed > 10*200*time.Millisecond)
29+
}
30+
31+
func TestWaitAllFailFastCase(t *testing.T) {
32+
t.Parallel()
33+
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
34+
defer cancelFunc()
35+
36+
countingTsk := asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond))
37+
errorTsk := asynctask.Start(ctx, getErrorTask("expected error", 10*time.Millisecond))
38+
panicTsk := asynctask.Start(ctx, getPanicTask(20*time.Millisecond))
39+
completedTsk := asynctask.NewCompletedTask()
40+
41+
start := time.Now()
42+
err := asynctask.WaitAll(ctx, &asynctask.WaitAllOptions{FailFast: true}, countingTsk, errorTsk, panicTsk, completedTsk)
43+
countingTskState := countingTsk.State()
44+
panicTskState := countingTsk.State()
45+
elapsed := time.Since(start)
46+
assert.Error(t, err)
47+
assert.Equal(t, "expected error", err.Error())
48+
// should fail before we finish panic task
49+
assert.True(t, elapsed.Milliseconds() < 15)
50+
51+
// since we pass FailFast, countingTsk and panicTsk should be still running
52+
assert.Equal(t, asynctask.StateRunning, countingTskState)
53+
assert.Equal(t, asynctask.StateRunning, panicTskState)
54+
}
55+
56+
func TestWaitAllErrorCase(t *testing.T) {
57+
t.Parallel()
58+
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
59+
defer cancelFunc()
60+
61+
countingTsk := asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond))
62+
errorTsk := asynctask.Start(ctx, getErrorTask("expected error", 10*time.Millisecond))
63+
panicTsk := asynctask.Start(ctx, getPanicTask(20*time.Millisecond))
64+
completedTsk := asynctask.NewCompletedTask()
65+
66+
start := time.Now()
67+
err := asynctask.WaitAll(ctx, &asynctask.WaitAllOptions{FailFast: false}, countingTsk, errorTsk, panicTsk, completedTsk)
68+
countingTskState := countingTsk.State()
69+
panicTskState := panicTsk.State()
70+
elapsed := time.Since(start)
71+
assert.Error(t, err)
72+
assert.Equal(t, "expected error", err.Error())
73+
// should only finish after longest task.
74+
assert.True(t, elapsed > 10*200*time.Millisecond)
75+
76+
// since we pass FailFast, countingTsk and panicTsk should be still running
77+
assert.Equal(t, asynctask.StateCompleted, countingTskState, "countingTask should finished")
78+
assert.Equal(t, asynctask.StateFailed, panicTskState, "panic task should failed")
79+
}
80+
81+
func TestWaitAllCanceled(t *testing.T) {
82+
t.Parallel()
83+
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
84+
defer cancelFunc()
85+
86+
countingTsk1 := asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond))
87+
countingTsk2 := asynctask.Start(ctx, getCountingTask(10, 20*time.Millisecond))
88+
countingTsk3 := asynctask.Start(ctx, getCountingTask(10, 2*time.Millisecond))
89+
completedTsk := asynctask.NewCompletedTask()
90+
91+
waitCtx, cancelFunc1 := context.WithTimeout(ctx, 5*time.Millisecond)
92+
defer cancelFunc1()
93+
94+
start := time.Now()
95+
err := asynctask.WaitAll(waitCtx, &asynctask.WaitAllOptions{FailFast: true}, countingTsk1, countingTsk2, countingTsk3, completedTsk)
96+
elapsed := time.Since(start)
97+
assert.Error(t, err)
98+
t.Log(err.Error())
99+
assert.True(t, errors.Is(err, context.DeadlineExceeded))
100+
// should return before first task
101+
assert.True(t, elapsed < 10*2*time.Millisecond)
102+
}

0 commit comments

Comments
 (0)