@@ -53,13 +53,15 @@ var (
53
53
)
54
54
55
55
// AddAction serialises and adds an Action to the DB under the given sessionID.
56
- func (db * BoltDB ) AddAction (action * Action ) (uint64 , error ) {
56
+ func (db * BoltDB ) AddAction (_ context.Context , action * Action ) (ActionLocator ,
57
+ error ) {
58
+
57
59
var buf bytes.Buffer
58
60
if err := SerializeAction (& buf , action ); err != nil {
59
- return 0 , err
61
+ return nil , err
60
62
}
61
63
62
- var id uint64
64
+ var locator kvdbActionLocator
63
65
err := db .DB .Update (func (tx * bbolt.Tx ) error {
64
66
mainActionsBucket , err := getBucket (tx , actionsBucketKey )
65
67
if err != nil {
@@ -82,7 +84,6 @@ func (db *BoltDB) AddAction(action *Action) (uint64, error) {
82
84
if err != nil {
83
85
return err
84
86
}
85
- id = nextActionIndex
86
87
87
88
var actionIndex [8 ]byte
88
89
byteOrder .PutUint64 (actionIndex [:], nextActionIndex )
@@ -101,9 +102,9 @@ func (db *BoltDB) AddAction(action *Action) (uint64, error) {
101
102
return err
102
103
}
103
104
104
- locator := ActionLocator {
105
- SessionID : action .SessionID ,
106
- ActionID : nextActionIndex ,
105
+ locator = kvdbActionLocator {
106
+ sessionID : action .SessionID ,
107
+ actionID : nextActionIndex ,
107
108
}
108
109
109
110
var buf bytes.Buffer
@@ -117,13 +118,25 @@ func (db *BoltDB) AddAction(action *Action) (uint64, error) {
117
118
return actionsIndexBucket .Put (seqNoBytes [:], buf .Bytes ())
118
119
})
119
120
if err != nil {
120
- return 0 , err
121
+ return nil , err
121
122
}
122
123
123
- return id , nil
124
+ return & locator , nil
125
+ }
126
+
127
+ // kvdbActionLocator helps us find an action in a KVDB database.
128
+ type kvdbActionLocator struct {
129
+ sessionID session.ID
130
+ actionID uint64
124
131
}
125
132
126
- func putAction (tx * bbolt.Tx , al * ActionLocator , a * Action ) error {
133
+ // A compile-time check to ensure kvdbActionLocator implements the ActionLocator
134
+ // interface.
135
+ var _ ActionLocator = (* kvdbActionLocator )(nil )
136
+
137
+ func (al * kvdbActionLocator ) isActionLocator () {}
138
+
139
+ func putAction (tx * bbolt.Tx , al * kvdbActionLocator , a * Action ) error {
127
140
var buf bytes.Buffer
128
141
if err := SerializeAction (& buf , a ); err != nil {
129
142
return err
@@ -139,42 +152,49 @@ func putAction(tx *bbolt.Tx, al *ActionLocator, a *Action) error {
139
152
return ErrNoSuchKeyFound
140
153
}
141
154
142
- sessBucket := actionsBucket .Bucket (al .SessionID [:])
155
+ sessBucket := actionsBucket .Bucket (al .sessionID [:])
143
156
if sessBucket == nil {
144
157
return fmt .Errorf ("session bucket for session ID %x does not " +
145
- "exist" , al .SessionID )
158
+ "exist" , al .sessionID )
146
159
}
147
160
148
161
var id [8 ]byte
149
- binary .BigEndian .PutUint64 (id [:], al .ActionID )
162
+ binary .BigEndian .PutUint64 (id [:], al .actionID )
150
163
151
164
return sessBucket .Put (id [:], buf .Bytes ())
152
165
}
153
166
154
- func getAction (actionsBkt * bbolt.Bucket , al * ActionLocator ) (* Action , error ) {
155
- sessBucket := actionsBkt .Bucket (al .SessionID [:])
167
+ func getAction (actionsBkt * bbolt.Bucket , al * kvdbActionLocator ) (* Action ,
168
+ error ) {
169
+
170
+ sessBucket := actionsBkt .Bucket (al .sessionID [:])
156
171
if sessBucket == nil {
157
172
return nil , fmt .Errorf ("session bucket for session ID " +
158
- "%x does not exist" , al .SessionID )
173
+ "%x does not exist" , al .sessionID )
159
174
}
160
175
161
176
var id [8 ]byte
162
- binary .BigEndian .PutUint64 (id [:], al .ActionID )
177
+ binary .BigEndian .PutUint64 (id [:], al .actionID )
163
178
164
179
actionBytes := sessBucket .Get (id [:])
165
- return DeserializeAction (bytes .NewReader (actionBytes ), al .SessionID )
180
+ return DeserializeAction (bytes .NewReader (actionBytes ), al .sessionID )
166
181
}
167
182
168
183
// SetActionState finds the action specified by the ActionLocator and sets its
169
184
// state to the given state.
170
- func (db * BoltDB ) SetActionState (al * ActionLocator , state ActionState ,
171
- errorReason string ) error {
185
+ func (db * BoltDB ) SetActionState (_ context. Context , al ActionLocator ,
186
+ state ActionState , errorReason string ) error {
172
187
173
188
if errorReason != "" && state != ActionStateError {
174
189
return fmt .Errorf ("error reason should only be set for " +
175
190
"ActionStateError" )
176
191
}
177
192
193
+ locator , ok := al .(* kvdbActionLocator )
194
+ if ! ok {
195
+ return fmt .Errorf ("expected kvdbActionLocator, got %T" , al )
196
+ }
197
+
178
198
return db .DB .Update (func (tx * bbolt.Tx ) error {
179
199
mainActionsBucket , err := getBucket (tx , actionsBucketKey )
180
200
if err != nil {
@@ -186,15 +206,15 @@ func (db *BoltDB) SetActionState(al *ActionLocator, state ActionState,
186
206
return ErrNoSuchKeyFound
187
207
}
188
208
189
- action , err := getAction (actionsBucket , al )
209
+ action , err := getAction (actionsBucket , locator )
190
210
if err != nil {
191
211
return err
192
212
}
193
213
194
214
action .State = state
195
215
action .ErrorReason = errorReason
196
216
197
- return putAction (tx , al , action )
217
+ return putAction (tx , locator , action )
198
218
})
199
219
}
200
220
@@ -540,14 +560,14 @@ func DeserializeAction(r io.Reader, sessionID session.ID) (*Action, error) {
540
560
541
561
// serializeActionLocator binary serializes the given ActionLocator to the
542
562
// writer using the tlv format.
543
- func serializeActionLocator (w io.Writer , al * ActionLocator ) error {
563
+ func serializeActionLocator (w io.Writer , al * kvdbActionLocator ) error {
544
564
if al == nil {
545
565
return fmt .Errorf ("action locator cannot be nil" )
546
566
}
547
567
548
568
var (
549
- sessionID = al .SessionID [:]
550
- actionID = al .ActionID
569
+ sessionID = al .sessionID [:]
570
+ actionID = al .actionID
551
571
)
552
572
553
573
tlvRecords := []tlv.Record {
@@ -565,7 +585,7 @@ func serializeActionLocator(w io.Writer, al *ActionLocator) error {
565
585
566
586
// deserializeActionLocator deserializes an ActionLocator from the given reader,
567
587
// expecting the data to be encoded in the tlv format.
568
- func deserializeActionLocator (r io.Reader ) (* ActionLocator , error ) {
588
+ func deserializeActionLocator (r io.Reader ) (* kvdbActionLocator , error ) {
569
589
var (
570
590
sessionID []byte
571
591
actionID uint64
@@ -588,8 +608,8 @@ func deserializeActionLocator(r io.Reader) (*ActionLocator, error) {
588
608
return nil , err
589
609
}
590
610
591
- return & ActionLocator {
592
- SessionID : id ,
593
- ActionID : actionID ,
611
+ return & kvdbActionLocator {
612
+ sessionID : id ,
613
+ actionID : actionID ,
594
614
}, nil
595
615
}
0 commit comments