Skip to content

Commit 2ecae96

Browse files
committed
Fix #238 - Add support to fork task
Signed-off-by: Ricardo Zanini <ricardozanini@gmail.com>
1 parent 46481f6 commit 2ecae96

16 files changed

+373
-101
lines changed

impl/ctx/context.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"encoding/json"
2020
"errors"
2121
"fmt"
22+
"github.com/serverlessworkflow/sdk-go/v3/impl/utils"
2223
"sync"
2324
"time"
2425

@@ -71,6 +72,7 @@ type WorkflowContext interface {
7172
SetLocalExprVars(vars map[string]interface{})
7273
AddLocalExprVars(vars map[string]interface{})
7374
RemoveLocalExprVars(keys ...string)
75+
Clone() WorkflowContext
7476
}
7577

7678
// workflowContext holds the necessary data for the workflow execution within the instance.
@@ -118,6 +120,38 @@ func GetWorkflowContext(ctx context.Context) (WorkflowContext, error) {
118120
return wfCtx, nil
119121
}
120122

123+
func (ctx *workflowContext) Clone() WorkflowContext {
124+
ctx.mu.Lock()
125+
defer ctx.mu.Unlock()
126+
127+
newInput := utils.DeepCloneValue(ctx.input)
128+
newOutput := utils.DeepCloneValue(ctx.output)
129+
130+
// deep clone each of the maps
131+
newContextMap := utils.DeepClone(ctx.context)
132+
newWorkflowDesc := utils.DeepClone(ctx.workflowDescriptor)
133+
newTaskDesc := utils.DeepClone(ctx.taskDescriptor)
134+
newLocalExprVars := utils.DeepClone(ctx.localExprVars)
135+
136+
newStatusPhase := append([]StatusPhaseLog(nil), ctx.StatusPhase...)
137+
138+
newTasksStatusPhase := make(map[string][]StatusPhaseLog, len(ctx.TasksStatusPhase))
139+
for taskName, logs := range ctx.TasksStatusPhase {
140+
newTasksStatusPhase[taskName] = append([]StatusPhaseLog(nil), logs...)
141+
}
142+
143+
return &workflowContext{
144+
input: newInput,
145+
output: newOutput,
146+
context: newContextMap,
147+
workflowDescriptor: newWorkflowDesc,
148+
taskDescriptor: newTaskDesc,
149+
localExprVars: newLocalExprVars,
150+
StatusPhase: newStatusPhase,
151+
TasksStatusPhase: newTasksStatusPhase,
152+
}
153+
}
154+
121155
func (ctx *workflowContext) SetStartedAt(t time.Time) {
122156
ctx.mu.Lock()
123157
defer ctx.mu.Unlock()

impl/expr/expr.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,28 @@ func mergeContextInVars(nodeCtx context.Context, variables map[string]interface{
132132

133133
return nil
134134
}
135+
136+
func TraverseAndEvaluateObj(runtimeExpr *model.ObjectOrRuntimeExpr, input interface{}, taskName string, wfCtx context.Context) (output interface{}, err error) {
137+
if runtimeExpr == nil {
138+
return input, nil
139+
}
140+
output, err = TraverseAndEvaluate(runtimeExpr.AsStringOrMap(), input, wfCtx)
141+
if err != nil {
142+
return nil, model.NewErrExpression(err, taskName)
143+
}
144+
return output, nil
145+
}
146+
147+
func TraverseAndEvaluateBool(runtimeExpr string, input interface{}, wfCtx context.Context) (bool, error) {
148+
if len(runtimeExpr) == 0 {
149+
return false, nil
150+
}
151+
output, err := TraverseAndEvaluate(runtimeExpr, input, wfCtx)
152+
if err != nil {
153+
return false, nil
154+
}
155+
if result, ok := output.(bool); ok {
156+
return result, nil
157+
}
158+
return false, nil
159+
}

impl/runner.go

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ package impl
1717
import (
1818
"context"
1919
"fmt"
20+
"github.com/serverlessworkflow/sdk-go/v3/impl/expr"
21+
"github.com/serverlessworkflow/sdk-go/v3/impl/utils"
2022
"time"
2123

2224
"github.com/serverlessworkflow/sdk-go/v3/impl/ctx"
@@ -53,6 +55,18 @@ type workflowRunnerImpl struct {
5355
RunnerCtx ctx.WorkflowContext
5456
}
5557

58+
func (wr *workflowRunnerImpl) CloneWithContext(newCtx context.Context) TaskSupport {
59+
clonedWfCtx := wr.RunnerCtx.Clone()
60+
61+
ctxWithWf := ctx.WithWorkflowContext(newCtx, clonedWfCtx)
62+
63+
return &workflowRunnerImpl{
64+
Workflow: wr.Workflow,
65+
Context: ctxWithWf,
66+
RunnerCtx: clonedWfCtx,
67+
}
68+
}
69+
5670
func (wr *workflowRunnerImpl) RemoveLocalExprVars(keys ...string) {
5771
wr.RunnerCtx.RemoveLocalExprVars(keys...)
5872
}
@@ -175,13 +189,13 @@ func (wr *workflowRunnerImpl) wrapWorkflowError(err error) error {
175189
func (wr *workflowRunnerImpl) processInput(input interface{}) (output interface{}, err error) {
176190
if wr.Workflow.Input != nil {
177191
if wr.Workflow.Input.Schema != nil {
178-
if err = validateSchema(input, wr.Workflow.Input.Schema, "/"); err != nil {
192+
if err = utils.ValidateSchema(input, wr.Workflow.Input.Schema, "/"); err != nil {
179193
return nil, err
180194
}
181195
}
182196

183197
if wr.Workflow.Input.From != nil {
184-
output, err = traverseAndEvaluate(wr.Workflow.Input.From, input, "/", wr.Context)
198+
output, err = expr.TraverseAndEvaluateObj(wr.Workflow.Input.From, input, "/", wr.Context)
185199
if err != nil {
186200
return nil, err
187201
}
@@ -196,13 +210,13 @@ func (wr *workflowRunnerImpl) processOutput(output interface{}) (interface{}, er
196210
if wr.Workflow.Output != nil {
197211
if wr.Workflow.Output.As != nil {
198212
var err error
199-
output, err = traverseAndEvaluate(wr.Workflow.Output.As, output, "/", wr.Context)
213+
output, err = expr.TraverseAndEvaluateObj(wr.Workflow.Output.As, output, "/", wr.Context)
200214
if err != nil {
201215
return nil, err
202216
}
203217
}
204218
if wr.Workflow.Output.Schema != nil {
205-
if err := validateSchema(output, wr.Workflow.Output.Schema, "/"); err != nil {
219+
if err := utils.ValidateSchema(output, wr.Workflow.Output.Schema, "/"); err != nil {
206220
return nil, err
207221
}
208222
}

impl/runner_test.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,3 +456,14 @@ func TestSwitchTaskRunner_DefaultCase(t *testing.T) {
456456
runWorkflowTest(t, workflowPath, input, expectedOutput)
457457
})
458458
}
459+
460+
func TestForkSimple_NoCompete(t *testing.T) {
461+
t.Run("Create a color array", func(t *testing.T) {
462+
workflowPath := "./testdata/fork_simple.yaml"
463+
input := map[string]interface{}{}
464+
expectedOutput := map[string]interface{}{
465+
"colors": []interface{}{"red", "blue"},
466+
}
467+
runWorkflowTest(t, workflowPath, input, expectedOutput)
468+
})
469+
}

impl/task_runner.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,7 @@ type TaskSupport interface {
5353
AddLocalExprVars(vars map[string]interface{})
5454
// RemoveLocalExprVars removes local variables added in AddLocalExprVars or SetLocalExprVars
5555
RemoveLocalExprVars(keys ...string)
56+
// CloneWithContext returns a full clone of this TaskSupport, but using
57+
// the provided context.Context (so deadlines/cancellations propagate).
58+
CloneWithContext(ctx context.Context) TaskSupport
5659
}

impl/task_runner_do.go

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ package impl
1616

1717
import (
1818
"fmt"
19+
"github.com/serverlessworkflow/sdk-go/v3/impl/expr"
20+
"github.com/serverlessworkflow/sdk-go/v3/impl/utils"
1921
"time"
2022

2123
"github.com/serverlessworkflow/sdk-go/v3/impl/ctx"
@@ -35,6 +37,8 @@ func NewTaskRunner(taskName string, task model.Task, workflowDef *model.Workflow
3537
return NewForTaskRunner(taskName, t)
3638
case *model.CallHTTP:
3739
return NewCallHttpRunner(taskName, t)
40+
case *model.ForkTask:
41+
return NewForkTaskRunner(taskName, t, workflowDef)
3842
default:
3943
return nil, fmt.Errorf("unsupported task type '%T' for task '%s'", t, taskName)
4044
}
@@ -117,7 +121,7 @@ func (d *DoTaskRunner) runTasks(input interface{}, taskSupport TaskSupport) (out
117121
}
118122

119123
taskSupport.SetTaskStatus(currentTask.Key, ctx.CompletedStatus)
120-
input = deepCloneValue(output)
124+
input = utils.DeepCloneValue(output)
121125
idx, currentTask = d.TaskList.Next(idx)
122126
}
123127

@@ -126,7 +130,7 @@ func (d *DoTaskRunner) runTasks(input interface{}, taskSupport TaskSupport) (out
126130

127131
func (d *DoTaskRunner) shouldRunTask(input interface{}, taskSupport TaskSupport, task *model.TaskItem) (bool, error) {
128132
if task.GetBase().If != nil {
129-
output, err := traverseAndEvaluateBool(task.GetBase().If.String(), input, taskSupport.GetContext())
133+
output, err := expr.TraverseAndEvaluateBool(task.GetBase().If.String(), input, taskSupport.GetContext())
130134
if err != nil {
131135
return false, model.NewErrExpression(err, task.Key)
132136
}
@@ -143,7 +147,7 @@ func (d *DoTaskRunner) evaluateSwitchTask(input interface{}, taskSupport TaskSup
143147
defaultThen = switchCase.Then
144148
continue
145149
}
146-
result, err := traverseAndEvaluateBool(model.NormalizeExpr(switchCase.When.String()), input, taskSupport.GetContext())
150+
result, err := expr.TraverseAndEvaluateBool(model.NormalizeExpr(switchCase.When.String()), input, taskSupport.GetContext())
147151
if err != nil {
148152
return nil, model.NewErrExpression(err, taskKey)
149153
}
@@ -199,11 +203,11 @@ func (d *DoTaskRunner) processTaskInput(task *model.TaskBase, taskInput interfac
199203
return taskInput, nil
200204
}
201205

202-
if err = validateSchema(taskInput, task.Input.Schema, taskName); err != nil {
206+
if err = utils.ValidateSchema(taskInput, task.Input.Schema, taskName); err != nil {
203207
return nil, err
204208
}
205209

206-
if output, err = traverseAndEvaluate(task.Input.From, taskInput, taskName, taskSupport.GetContext()); err != nil {
210+
if output, err = expr.TraverseAndEvaluateObj(task.Input.From, taskInput, taskName, taskSupport.GetContext()); err != nil {
207211
return nil, err
208212
}
209213

@@ -216,11 +220,11 @@ func (d *DoTaskRunner) processTaskOutput(task *model.TaskBase, taskOutput interf
216220
return taskOutput, nil
217221
}
218222

219-
if output, err = traverseAndEvaluate(task.Output.As, taskOutput, taskName, taskSupport.GetContext()); err != nil {
223+
if output, err = expr.TraverseAndEvaluateObj(task.Output.As, taskOutput, taskName, taskSupport.GetContext()); err != nil {
220224
return nil, err
221225
}
222226

223-
if err = validateSchema(output, task.Output.Schema, taskName); err != nil {
227+
if err = utils.ValidateSchema(output, task.Output.Schema, taskName); err != nil {
224228
return nil, err
225229
}
226230

@@ -232,12 +236,12 @@ func (d *DoTaskRunner) processTaskExport(task *model.TaskBase, taskOutput interf
232236
return nil
233237
}
234238

235-
output, err := traverseAndEvaluate(task.Export.As, taskOutput, taskName, taskSupport.GetContext())
239+
output, err := expr.TraverseAndEvaluateObj(task.Export.As, taskOutput, taskName, taskSupport.GetContext())
236240
if err != nil {
237241
return err
238242
}
239243

240-
if err = validateSchema(output, task.Export.Schema, taskName); err != nil {
244+
if err = utils.ValidateSchema(output, task.Export.Schema, taskName); err != nil {
241245
return nil
242246
}
243247

impl/task_runner_for.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ func (f *ForTaskRunner) Run(input interface{}, taskSupport TaskSupport) (interfa
7373
return nil, err
7474
}
7575
if f.Task.While != "" {
76-
whileIsTrue, err := traverseAndEvaluateBool(f.Task.While, forOutput, taskSupport.GetContext())
76+
whileIsTrue, err := expr.TraverseAndEvaluateBool(f.Task.While, forOutput, taskSupport.GetContext())
7777
if err != nil {
7878
return nil, err
7979
}

impl/task_runner_fork.go

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
package impl
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"github.com/serverlessworkflow/sdk-go/v3/model"
7+
"sync"
8+
)
9+
10+
func NewForkTaskRunner(taskName string, task *model.ForkTask, workflowDef *model.Workflow) (*ForkTaskRunner, error) {
11+
if task == nil || task.Fork.Branches == nil {
12+
return nil, model.NewErrValidation(fmt.Errorf("invalid Fork task %s", taskName), taskName)
13+
}
14+
15+
var runners []TaskRunner
16+
for _, branchItem := range *task.Fork.Branches {
17+
r, err := NewTaskRunner(branchItem.Key, branchItem.Task, workflowDef)
18+
if err != nil {
19+
return nil, err
20+
}
21+
runners = append(runners, r)
22+
}
23+
24+
return &ForkTaskRunner{
25+
Task: task,
26+
TaskName: taskName,
27+
BranchRunners: runners,
28+
}, nil
29+
}
30+
31+
type ForkTaskRunner struct {
32+
Task *model.ForkTask
33+
TaskName string
34+
BranchRunners []TaskRunner
35+
}
36+
37+
func (f ForkTaskRunner) GetTaskName() string {
38+
return f.TaskName
39+
}
40+
41+
func (f ForkTaskRunner) Run(input interface{}, parentSupport TaskSupport) (interface{}, error) {
42+
cancelCtx, cancel := context.WithCancel(parentSupport.GetContext())
43+
defer cancel()
44+
45+
n := len(f.BranchRunners)
46+
results := make([]interface{}, n)
47+
errs := make(chan error, n)
48+
done := make(chan struct{})
49+
resultCh := make(chan interface{}, 1)
50+
51+
var (
52+
wg sync.WaitGroup
53+
once sync.Once // <-- declare a Once
54+
)
55+
56+
for i, runner := range f.BranchRunners {
57+
wg.Add(1)
58+
go func(i int, runner TaskRunner) {
59+
defer wg.Done()
60+
// **Isolate context** for each branch!
61+
branchSupport := parentSupport.CloneWithContext(cancelCtx)
62+
63+
select {
64+
case <-cancelCtx.Done():
65+
return
66+
default:
67+
}
68+
69+
out, err := runner.Run(input, branchSupport)
70+
if err != nil {
71+
errs <- err
72+
return
73+
}
74+
results[i] = out
75+
76+
if f.Task.Fork.Compete {
77+
select {
78+
case resultCh <- out:
79+
once.Do(func() {
80+
cancel() // **signal cancellation** to all other branches
81+
close(done) // signal we have a winner
82+
})
83+
default:
84+
}
85+
}
86+
}(i, runner)
87+
}
88+
89+
if f.Task.Fork.Compete {
90+
select {
91+
case <-done:
92+
return <-resultCh, nil
93+
case err := <-errs:
94+
return nil, err
95+
}
96+
}
97+
98+
wg.Wait()
99+
select {
100+
case err := <-errs:
101+
return nil, err
102+
default:
103+
}
104+
return results, nil
105+
}

0 commit comments

Comments
 (0)