Skip to content

Commit 442a02a

Browse files
authored
remove restriction on return pointer value of TypeParameter (#33)
* remove pointer bind * Change sync.Mutex to sync.RWMutex in Task struct * Update go.mod to use go 1.20 * tweaks
1 parent 4c78f32 commit 442a02a

10 files changed

+83
-73
lines changed

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,14 @@
77

88
Simple mimik of async/await for those come from C# world, so you don't need to dealing with waitGroup/channel in golang.
99

10+
also the result is strongTyped with go generics, no type assertion is needed.
11+
12+
few chaining method provided:
13+
- ContinueWith: send task1's output to task2 as input, return reference to task2.
14+
- AfterBoth : send output of taskA, taskB to taskC as input, return reference to taskC.
15+
- WaitAll: all of the task have to finish to end the wait (with an option to fail early if any task failed)
16+
- WaitAny: any of the task finish would end the wait
17+
1018
```golang
1119
// start task
1220
task := asynctask.Start(ctx, countingTask)

after_both.go

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,31 +3,33 @@ package asynctask
33
import "context"
44

55
// AfterBothFunc is a function that has 2 input.
6-
type AfterBothFunc[T, S, R any] func(context.Context, *T, *S) (*R, error)
6+
type AfterBothFunc[T, S, R any] func(context.Context, T, S) (R, error)
77

88
// AfterBoth runs the function after both 2 input task finished, and will be fed with result from 2 input task.
9-
// if one of the input task failed, the AfterBoth task will be failed and returned, even other one are still running.
9+
//
10+
// if one of the input task failed, the AfterBoth task will be failed and returned, even other one are still running.
1011
func AfterBoth[T, S, R any](ctx context.Context, tskT *Task[T], tskS *Task[S], next AfterBothFunc[T, S, R]) *Task[R] {
11-
return Start(ctx, func(fCtx context.Context) (*R, error) {
12+
return Start(ctx, func(fCtx context.Context) (R, error) {
1213
t, err := tskT.Result(fCtx)
1314
if err != nil {
14-
return nil, err
15+
return *new(R), err
1516
}
1617

1718
s, err := tskS.Result(fCtx)
1819
if err != nil {
19-
return nil, err
20+
return *new(R), err
2021
}
2122

2223
return next(fCtx, t, s)
2324
})
2425
}
2526

2627
// AfterBothActionToFunc convert a Action to Func (C# term), to satisfy the AfterBothFunc interface.
27-
// Action is function that runs without return anything
28-
// Func is function that runs and return something
29-
func AfterBothActionToFunc[T, S any](action func(context.Context, *T, *S) error) func(context.Context, *T, *S) (*interface{}, error) {
30-
return func(ctx context.Context, t *T, s *S) (*interface{}, error) {
28+
//
29+
// Action is function that runs without return anything
30+
// Func is function that runs and return something
31+
func AfterBothActionToFunc[T, S any](action func(context.Context, T, S) error) func(context.Context, T, S) (interface{}, error) {
32+
return func(ctx context.Context, t T, s S) (interface{}, error) {
3133
return nil, action(ctx, t, s)
3234
}
3335
}

after_both_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@ import (
99
"github.com/stretchr/testify/assert"
1010
)
1111

12-
func summarize2CountingTask(ctx context.Context, result1, result2 *int) (*int, error) {
12+
func summarize2CountingTask(ctx context.Context, result1, result2 int) (int, error) {
1313
t := ctx.Value(testContextKey).(*testing.T)
1414
t.Logf("result1: %d", result1)
1515
t.Logf("result2: %d", result2)
16-
sum := *result1 + *result2
16+
sum := result1 + result2
1717
t.Logf("sum: %d", sum)
18-
return &sum, nil
18+
return sum, nil
1919
}
2020

2121
func TestAfterBoth(t *testing.T) {
@@ -28,7 +28,7 @@ func TestAfterBoth(t *testing.T) {
2828
sum, err := t3.Result(ctx)
2929
assert.NoError(t, err)
3030
assert.Equal(t, asynctask.StateCompleted, t3.State(), "Task should complete with no error")
31-
assert.Equal(t, *sum, 18, "Sum should be 18")
31+
assert.Equal(t, sum, 18, "Sum should be 18")
3232
}
3333

3434
func TestAfterBothFailureCase(t *testing.T) {
@@ -56,11 +56,11 @@ func TestAfterBothActionToFunc(t *testing.T) {
5656

5757
countingTask1 := asynctask.Start(ctx, getCountingTask(10, "afterboth.P1", 20*time.Millisecond))
5858
countingTask2 := asynctask.Start(ctx, getCountingTask(10, "afterboth.P2", 20*time.Millisecond))
59-
t2 := asynctask.AfterBoth(ctx, countingTask1, countingTask2, asynctask.AfterBothActionToFunc(func(ctx context.Context, result1, result2 *int) error {
59+
t2 := asynctask.AfterBoth(ctx, countingTask1, countingTask2, asynctask.AfterBothActionToFunc(func(ctx context.Context, result1, result2 int) error {
6060
t := ctx.Value(testContextKey).(*testing.T)
6161
t.Logf("result1: %d", result1)
6262
t.Logf("result2: %d", result2)
63-
sum := *result1 + *result2
63+
sum := result1 + result2
6464
t.Logf("sum: %d", sum)
6565
return nil
6666
}))

continue_with.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,24 @@ package asynctask
33
import "context"
44

55
// ContinueFunc is a function that can be connected to previous task with ContinueWith
6-
type ContinueFunc[T any, S any] func(context.Context, *T) (*S, error)
6+
type ContinueFunc[T any, S any] func(context.Context, T) (S, error)
77

88
func ContinueWith[T any, S any](ctx context.Context, tsk *Task[T], next ContinueFunc[T, S]) *Task[S] {
9-
return Start(ctx, func(fCtx context.Context) (*S, error) {
9+
return Start(ctx, func(fCtx context.Context) (S, error) {
1010
result, err := tsk.Result(fCtx)
1111
if err != nil {
12-
return nil, err
12+
return *new(S), err
1313
}
1414
return next(fCtx, result)
1515
})
1616
}
1717

1818
// ContinueActionToFunc convert a Action to Func (C# term), to satisfy the AsyncFunc interface.
19-
// Action is function that runs without return anything
20-
// Func is function that runs and return something
21-
func ContinueActionToFunc[T any](action func(context.Context, *T) error) func(context.Context, *T) (*interface{}, error) {
22-
return func(ctx context.Context, t *T) (*interface{}, error) {
19+
//
20+
// Action is function that runs without return anything
21+
// Func is function that runs and return something
22+
func ContinueActionToFunc[T any](action func(context.Context, T) error) func(context.Context, T) (interface{}, error) {
23+
return func(ctx context.Context, t T) (interface{}, error) {
2324
return nil, action(ctx, t)
2425
}
2526
}

continue_with_test.go

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import (
1111
)
1212

1313
func getAdvancedCountingTask(countFrom int, step int, sleepInterval time.Duration) asynctask.AsyncFunc[int] {
14-
return func(ctx context.Context) (*int, error) {
14+
return func(ctx context.Context) (int, error) {
1515
t := ctx.Value(testContextKey).(*testing.T)
1616

1717
result := countFrom
@@ -22,47 +22,47 @@ func getAdvancedCountingTask(countFrom int, step int, sleepInterval time.Duratio
2222
result++
2323
case <-ctx.Done():
2424
t.Log("work canceled")
25-
return &result, nil
25+
return result, nil
2626
}
2727
}
28-
return &result, nil
28+
return result, nil
2929
}
3030
}
3131

3232
func TestContinueWith(t *testing.T) {
3333
t.Parallel()
3434
ctx := newTestContext(t)
3535
t1 := asynctask.Start(ctx, getAdvancedCountingTask(0, 10, 20*time.Millisecond))
36-
t2 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input *int) (*int, error) {
37-
fromPrevTsk := *input
36+
t2 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input int) (int, error) {
37+
fromPrevTsk := input
3838
return getAdvancedCountingTask(fromPrevTsk, 10, 20*time.Millisecond)(fCtx)
3939
})
40-
t3 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input *int) (*int, error) {
41-
fromPrevTsk := *input
40+
t3 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input int) (int, error) {
41+
fromPrevTsk := input
4242
return getAdvancedCountingTask(fromPrevTsk, 12, 20*time.Millisecond)(fCtx)
4343
})
4444

4545
result, err := t2.Result(ctx)
4646
assert.NoError(t, err)
4747
assert.Equal(t, asynctask.StateCompleted, t2.State(), "Task should complete with no error")
48-
assert.Equal(t, *result, 20)
48+
assert.Equal(t, result, 20)
4949

5050
result, err = t3.Result(ctx)
5151
assert.NoError(t, err)
5252
assert.Equal(t, asynctask.StateCompleted, t3.State(), "Task should complete with no error")
53-
assert.Equal(t, *result, 22)
53+
assert.Equal(t, result, 22)
5454
}
5555

5656
func TestContinueWithFailureCase(t *testing.T) {
5757
t.Parallel()
5858
ctx := newTestContext(t)
5959
t1 := asynctask.Start(ctx, getErrorTask("devide by 0", 10*time.Millisecond))
60-
t2 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input *int) (*int, error) {
61-
fromPrevTsk := *input
60+
t2 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input int) (int, error) {
61+
fromPrevTsk := input
6262
return getAdvancedCountingTask(fromPrevTsk, 10, 20*time.Millisecond)(fCtx)
6363
})
64-
t3 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input *int) (*int, error) {
65-
fromPrevTsk := *input
64+
t3 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input int) (int, error) {
65+
fromPrevTsk := input
6666
return getAdvancedCountingTask(fromPrevTsk, 12, 20*time.Millisecond)(fCtx)
6767
})
6868

error_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,16 @@ import (
1111
)
1212

1313
func getPanicTask(sleepDuration time.Duration) asynctask.AsyncFunc[string] {
14-
return func(ctx context.Context) (*string, error) {
14+
return func(ctx context.Context) (string, error) {
1515
time.Sleep(sleepDuration)
1616
panic("yo")
1717
}
1818
}
1919

2020
func getErrorTask(errorString string, sleepDuration time.Duration) asynctask.AsyncFunc[int] {
21-
return func(ctx context.Context) (*int, error) {
21+
return func(ctx context.Context) (int, error) {
2222
time.Sleep(sleepDuration)
23-
return nil, errors.New(errorString)
23+
return 0, errors.New(errorString)
2424
}
2525
}
2626

@@ -37,12 +37,12 @@ func TestTimeoutCase(t *testing.T) {
3737
// I can continue wait with longer time
3838
rawResult, err := tsk.WaitWithTimeout(ctx, 2*time.Second)
3939
assert.NoError(t, err)
40-
assert.Equal(t, 9, *rawResult)
40+
assert.Equal(t, 9, rawResult)
4141

4242
// any following Wait should complete immediately
4343
rawResult, err = tsk.WaitWithTimeout(ctx, 2*time.Nanosecond)
4444
assert.NoError(t, err)
45-
assert.Equal(t, 9, *rawResult)
45+
assert.Equal(t, 9, rawResult)
4646
}
4747

4848
func TestPanicCase(t *testing.T) {

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module github.com/Azure/go-asynctask
22

3-
go 1.19
3+
go 1.20
44

55
require github.com/stretchr/testify v1.8.4
66

task.go

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@ import (
99
)
1010

1111
// AsyncFunc is a function interface this asyncTask accepts.
12-
type AsyncFunc[T any] func(context.Context) (*T, error)
12+
type AsyncFunc[T any] func(context.Context) (T, error)
1313

1414
// ActionToFunc convert a Action to Func (C# term), to satisfy the AsyncFunc interface.
1515
// - Action is function that runs without return anything
1616
// - Func is function that runs and return something
17-
func ActionToFunc(action func(context.Context) error) func(context.Context) (*interface{}, error) {
18-
return func(ctx context.Context) (*interface{}, error) {
17+
func ActionToFunc(action func(context.Context) error) func(context.Context) (interface{}, error) {
18+
return func(ctx context.Context) (interface{}, error) {
1919
return nil, action(ctx)
2020
}
2121
}
@@ -24,25 +24,25 @@ func ActionToFunc(action func(context.Context) error) func(context.Context) (*in
2424
// which you can use to wait, cancel, get the result.
2525
type Task[T any] struct {
2626
state State
27-
result *T
27+
result T
2828
err error
2929
cancelFunc context.CancelFunc
3030
waitGroup *sync.WaitGroup
31-
mutex *sync.Mutex
31+
mutex *sync.RWMutex
3232
}
3333

3434
// State return state of the task.
3535
func (t *Task[T]) State() State {
36-
t.mutex.Lock()
37-
defer t.mutex.Unlock()
36+
t.mutex.RLock()
37+
defer t.mutex.RUnlock()
3838
return t.state
3939
}
4040

4141
// Cancel the task by cancel the context.
4242
// !! this rely on the task function to check context cancellation and proper context handling.
4343
func (t *Task[T]) Cancel() bool {
4444
if !t.finished() {
45-
t.finish(StateCanceled, nil, ErrCanceled)
45+
t.finish(StateCanceled, *new(T), ErrCanceled)
4646
return true
4747
}
4848

@@ -74,7 +74,7 @@ func (t *Task[T]) Wait(ctx context.Context) error {
7474

7575
// WaitWithTimeout block current thread/routine until task finished or failed, or exceed the duration specified.
7676
// timeout only stop waiting, taks will remain running.
77-
func (t *Task[T]) WaitWithTimeout(ctx context.Context, timeout time.Duration) (*T, error) {
77+
func (t *Task[T]) WaitWithTimeout(ctx context.Context, timeout time.Duration) (T, error) {
7878
// return immediately if task already in terminal state.
7979
if t.finished() {
8080
return t.result, t.err
@@ -86,11 +86,10 @@ func (t *Task[T]) WaitWithTimeout(ctx context.Context, timeout time.Duration) (*
8686
return t.Result(ctx)
8787
}
8888

89-
func (t *Task[T]) Result(ctx context.Context) (*T, error) {
89+
func (t *Task[T]) Result(ctx context.Context) (T, error) {
9090
err := t.Wait(ctx)
9191
if err != nil {
92-
var result T
93-
return &result, err
92+
return *new(T), err
9493
}
9594

9695
return t.result, t.err
@@ -102,11 +101,11 @@ func Start[T any](ctx context.Context, task AsyncFunc[T]) *Task[T] {
102101
ctx, cancel := context.WithCancel(ctx)
103102
wg := &sync.WaitGroup{}
104103
wg.Add(1)
105-
mutex := &sync.Mutex{}
104+
mutex := &sync.RWMutex{}
106105

107106
record := &Task[T]{
108107
state: StateRunning,
109-
result: nil,
108+
result: *new(T),
110109
cancelFunc: cancel,
111110
waitGroup: wg,
112111
mutex: mutex,
@@ -118,24 +117,24 @@ func Start[T any](ctx context.Context, task AsyncFunc[T]) *Task[T] {
118117
}
119118

120119
// NewCompletedTask returns a Completed task, with result=nil, error=nil
121-
func NewCompletedTask[T any](value *T) *Task[T] {
120+
func NewCompletedTask[T any](value T) *Task[T] {
122121
return &Task[T]{
123122
state: StateCompleted,
124123
result: value,
125124
err: nil,
126125
// nil cancelFunc and waitGroup should be protected with IsTerminalState()
127126
cancelFunc: nil,
128127
waitGroup: nil,
129-
mutex: &sync.Mutex{},
128+
mutex: &sync.RWMutex{},
130129
}
131130
}
132131

133-
func runAndTrackGenericTask[T any](ctx context.Context, record *Task[T], task func(ctx context.Context) (*T, error)) {
132+
func runAndTrackGenericTask[T any](ctx context.Context, record *Task[T], task func(ctx context.Context) (T, error)) {
134133
defer record.waitGroup.Done()
135134
defer func() {
136135
if r := recover(); r != nil {
137136
err := fmt.Errorf("panic cought: %v, stackTrace: %s, %w", r, debug.Stack(), ErrPanic)
138-
record.finish(StateFailed, nil, err)
137+
record.finish(StateFailed, *new(T), err)
139138
}
140139
}()
141140

@@ -150,7 +149,7 @@ func runAndTrackGenericTask[T any](ctx context.Context, record *Task[T], task fu
150149
record.finish(StateFailed, result, err)
151150
}
152151

153-
func (t *Task[T]) finish(state State, result *T, err error) {
152+
func (t *Task[T]) finish(state State, result T, err error) {
154153
// only update state and result if not yet canceled
155154
t.mutex.Lock()
156155
defer t.mutex.Unlock()
@@ -163,7 +162,7 @@ func (t *Task[T]) finish(state State, result *T, err error) {
163162
}
164163

165164
func (t *Task[T]) finished() bool {
166-
t.mutex.Lock()
167-
defer t.mutex.Unlock()
165+
t.mutex.RLock()
166+
defer t.mutex.RUnlock()
168167
return t.state.IsTerminalState()
169168
}

0 commit comments

Comments
 (0)