Skip to content

Commit 78fafba

Browse files
authored
Merge pull request #839 from sputn1ck/fsm_ctx
FSM: add ctx to SendEvent and Actions
2 parents 1e8ae31 + f2fb722 commit 78fafba

File tree

13 files changed

+188
-154
lines changed

13 files changed

+188
-154
lines changed

fsm/example_fsm.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package fsm
22

33
import (
4+
"context"
45
"fmt"
56
)
67

@@ -90,7 +91,9 @@ type InitStuffRequest struct {
9091
}
9192

9293
// initFSM is the action for the InitFSM state.
93-
func (e *ExampleFSM) initFSM(eventCtx EventContext) EventType {
94+
func (e *ExampleFSM) initFSM(_ context.Context, eventCtx EventContext,
95+
) EventType {
96+
9497
req, ok := eventCtx.(*InitStuffRequest)
9598
if !ok {
9699
return e.HandleError(
@@ -109,15 +112,17 @@ func (e *ExampleFSM) initFSM(eventCtx EventContext) EventType {
109112
}
110113

111114
// waitForStuff is an action that waits for stuff to happen.
112-
func (e *ExampleFSM) waitForStuff(eventCtx EventContext) EventType {
115+
func (e *ExampleFSM) waitForStuff(ctx context.Context, eventCtx EventContext,
116+
) EventType {
117+
113118
waitChan, err := e.service.WaitForStuffHappening()
114119
if err != nil {
115120
return e.HandleError(err)
116121
}
117122

118123
go func() {
119124
<-waitChan
120-
err := e.SendEvent(OnStuffSuccess, nil)
125+
err := e.SendEvent(ctx, OnStuffSuccess, nil)
121126
if err != nil {
122127
log.Errorf("unable to send event: %v", err)
123128
}

fsm/example_fsm_test.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ func TestExampleFSM(t *testing.T) {
7979
tc := tc
8080

8181
t.Run(tc.name, func(t *testing.T) {
82+
ctxb := context.Background()
8283
respondChan := make(chan string, 1)
8384
if req, ok := tc.eventCtx.(*InitStuffRequest); ok {
8485
req.respondChan = respondChan
@@ -102,7 +103,7 @@ func TestExampleFSM(t *testing.T) {
102103
exampleContext.RegisterObserver(cachedObserver)
103104

104105
err := exampleContext.SendEvent(
105-
tc.sendEvent, tc.eventCtx,
106+
ctxb, tc.sendEvent, tc.eventCtx,
106107
)
107108
require.Equal(t, tc.sendEventErr, err)
108109

@@ -195,6 +196,7 @@ func TestExampleFSMFlow(t *testing.T) {
195196

196197
t.Run(tc.name, func(t *testing.T) {
197198
exampleContext, cachedObserver := getTestContext()
199+
ctxb := context.Background()
198200

199201
if tc.storeError != nil {
200202
exampleContext.store.(*mockStore).
@@ -208,8 +210,7 @@ func TestExampleFSMFlow(t *testing.T) {
208210

209211
go func() {
210212
err := exampleContext.SendEvent(
211-
OnRequestStuff,
212-
newInitStuffRequest(),
213+
ctxb, OnRequestStuff, newInitStuffRequest(),
213214
)
214215

215216
require.NoError(t, err)
@@ -273,6 +274,7 @@ func TestObserverAsyncWait(t *testing.T) {
273274
service := &mockService{
274275
respondChan: make(chan bool),
275276
}
277+
ctxb := context.Background()
276278

277279
store := &mockStore{}
278280

@@ -282,7 +284,7 @@ func TestObserverAsyncWait(t *testing.T) {
282284

283285
t0 := time.Now()
284286
timeoutCtx, cancel := context.WithTimeout(
285-
context.Background(), tc.waitTime,
287+
ctxb, tc.waitTime,
286288
)
287289
defer cancel()
288290

@@ -293,8 +295,7 @@ func TestObserverAsyncWait(t *testing.T) {
293295

294296
go func() {
295297
err := exampleContext.SendEvent(
296-
OnRequestStuff,
297-
newInitStuffRequest(),
298+
ctxb, OnRequestStuff, newInitStuffRequest(),
298299
)
299300

300301
require.NoError(t, err)

fsm/fsm.go

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package fsm
22

33
import (
4+
"context"
45
"errors"
56
"fmt"
67
"sync"
@@ -45,7 +46,7 @@ type EventType string
4546
type EventContext interface{}
4647

4748
// Action represents the action to be executed in a given state.
48-
type Action func(eventCtx EventContext) EventType
49+
type Action func(ctx context.Context, eventCtx EventContext) EventType
4950

5051
// Transitions represents a mapping of events and states.
5152
type Transitions map[EventType]StateType
@@ -95,11 +96,11 @@ type StateMachine struct {
9596

9697
// ActionEntryFunc is a function that is called before an action is
9798
// executed.
98-
ActionEntryFunc func(Notification)
99+
ActionEntryFunc func(context.Context, Notification)
99100

100101
// ActionExitFunc is a function that is called after an action is
101102
// executed, it is called with the EventType returned by the action.
102-
ActionExitFunc func(NextEvent EventType)
103+
ActionExitFunc func(ctx context.Context, NextEvent EventType)
103104

104105
// LastActionError is an error set by the last action executed.
105106
LastActionError error
@@ -200,7 +201,9 @@ func (s *StateMachine) getNextState(event EventType) (State, error) {
200201
// SendEvent sends an event to the state machine. It returns an error if the
201202
// event cannot be processed in the current state. Otherwise, it only returns
202203
// nil if the event for the last action is a no-op.
203-
func (s *StateMachine) SendEvent(event EventType, eventCtx EventContext) error {
204+
func (s *StateMachine) SendEvent(ctx context.Context, event EventType,
205+
eventCtx EventContext) error {
206+
204207
s.mutex.Lock()
205208
defer s.mutex.Unlock()
206209

@@ -235,7 +238,7 @@ func (s *StateMachine) SendEvent(event EventType, eventCtx EventContext) error {
235238

236239
// Execute the state machines ActionEntryFunc.
237240
if s.ActionEntryFunc != nil {
238-
s.ActionEntryFunc(notification)
241+
s.ActionEntryFunc(ctx, notification)
239242
}
240243

241244
// Execute the current state's entry function
@@ -245,7 +248,7 @@ func (s *StateMachine) SendEvent(event EventType, eventCtx EventContext) error {
245248

246249
// Execute the next state's action and loop over again if the
247250
// event returned is not a no-op.
248-
nextEvent := state.Action(eventCtx)
251+
nextEvent := state.Action(ctx, eventCtx)
249252

250253
// Execute the current state's exit function
251254
if state.ExitFunc != nil {
@@ -254,7 +257,7 @@ func (s *StateMachine) SendEvent(event EventType, eventCtx EventContext) error {
254257

255258
// Execute the state machines ActionExitFunc.
256259
if s.ActionExitFunc != nil {
257-
s.ActionExitFunc(nextEvent)
260+
s.ActionExitFunc(ctx, nextEvent)
258261
}
259262

260263
// If the next event is a no-op, we're done.
@@ -304,7 +307,7 @@ func (s *StateMachine) HandleError(err error) EventType {
304307

305308
// NoOpAction is a no-op action that can be used by states that don't need to
306309
// execute any action.
307-
func NoOpAction(_ EventContext) EventType {
310+
func NoOpAction(_ context.Context, _ EventContext) EventType {
308311
return NoOp
309312
}
310313

fsm/fsm_test.go

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package fsm
22

33
import (
4+
"context"
45
"errors"
56
"testing"
67

@@ -22,15 +23,15 @@ type TestStateMachineContext struct {
2223
func (c *TestStateMachineContext) GetStates() States {
2324
return States{
2425
"State1": State{
25-
Action: func(ctx EventContext) EventType {
26+
Action: func(_ context.Context, ctx EventContext) EventType {
2627
return "Event1"
2728
},
2829
Transitions: Transitions{
2930
"Event1": "State2",
3031
},
3132
},
3233
"State2": State{
33-
Action: func(ctx EventContext) EventType {
34+
Action: func(_ context.Context, ctx EventContext) EventType {
3435
return "NoOp"
3536
},
3637
Transitions: Transitions{},
@@ -39,7 +40,9 @@ func (c *TestStateMachineContext) GetStates() States {
3940
}
4041

4142
// errorAction returns an error.
42-
func (c *TestStateMachineContext) errorAction(eventCtx EventContext) EventType {
43+
func (c *TestStateMachineContext) errorAction(ctx context.Context,
44+
eventCtx EventContext) EventType {
45+
4346
return c.StateMachine.HandleError(errAction)
4447
}
4548

@@ -58,9 +61,9 @@ func setupTestStateMachineContext() *TestStateMachineContext {
5861
// TestStateMachine_Success tests the state machine with a successful event.
5962
func TestStateMachine_Success(t *testing.T) {
6063
ctx := setupTestStateMachineContext()
61-
64+
ctxb := context.Background()
6265
// Send an event to the state machine.
63-
err := ctx.SendEvent("Event1", nil)
66+
err := ctx.SendEvent(ctxb, "Event1", nil)
6467
require.NoError(t, err)
6568

6669
// Check that the state machine has transitioned to the next state.
@@ -72,8 +75,9 @@ func TestStateMachine_Success(t *testing.T) {
7275
func TestStateMachine_ConfigurationError(t *testing.T) {
7376
ctx := setupTestStateMachineContext()
7477
ctx.StateMachine.States = nil
78+
ctxb := context.Background()
7579

76-
err := ctx.SendEvent("Event1", nil)
80+
err := ctx.SendEvent(ctxb, "Event1", nil)
7781
require.EqualError(
7882
t, err,
7983
NewErrConfigError("state machine config is nil").Error(),
@@ -83,6 +87,7 @@ func TestStateMachine_ConfigurationError(t *testing.T) {
8387
// TestStateMachine_ActionError tests the state machine with an action error.
8488
func TestStateMachine_ActionError(t *testing.T) {
8589
ctx := setupTestStateMachineContext()
90+
ctxb := context.Background()
8691

8792
states := ctx.StateMachine.States
8893

@@ -99,13 +104,13 @@ func TestStateMachine_ActionError(t *testing.T) {
99104
}
100105

101106
states["ErrorState"] = State{
102-
Action: func(ctx EventContext) EventType {
107+
Action: func(_ context.Context, ctx EventContext) EventType {
103108
return "NoOp"
104109
},
105110
Transitions: Transitions{},
106111
}
107112

108-
err := ctx.SendEvent("Event1", nil)
113+
err := ctx.SendEvent(ctxb, "Event1", nil)
109114

110115
// Sending an event to the state machine should not return an error.
111116
require.NoError(t, err)

0 commit comments

Comments
 (0)