Skip to content

Commit 57789b0

Browse files
authored
Merge pull request #1069 from ellemouton/sql36
[sql-36] actions: prepare ActionsDB interface
2 parents e31a093 + 53fc1d2 commit 57789b0

File tree

8 files changed

+128
-86
lines changed

8 files changed

+128
-86
lines changed

firewall/request_logger.go

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ type RequestLogger struct {
5353
// be used to find the corresponding action. This is used so that
5454
// requests and responses can be easily linked. The mu mutex must be
5555
// used when accessing this map.
56-
reqIDToAction map[uint64]*firewalldb.ActionLocator
56+
reqIDToAction map[uint64]firewalldb.ActionLocator
5757
mu sync.Mutex
5858
}
5959

@@ -105,7 +105,7 @@ func NewRequestLogger(cfg *RequestLoggerConfig,
105105
return &RequestLogger{
106106
shouldLogAction: shouldLogAction,
107107
actionsDB: actionsDB,
108-
reqIDToAction: make(map[uint64]*firewalldb.ActionLocator),
108+
reqIDToAction: make(map[uint64]firewalldb.ActionLocator),
109109
}, nil
110110
}
111111

@@ -128,7 +128,7 @@ func (r *RequestLogger) CustomCaveatName() string {
128128

129129
// Intercept processes an RPC middleware interception request and returns the
130130
// interception result which either accepts or rejects the intercepted message.
131-
func (r *RequestLogger) Intercept(_ context.Context,
131+
func (r *RequestLogger) Intercept(ctx context.Context,
132132
req *lnrpc.RPCMiddlewareRequest) (*lnrpc.RPCMiddlewareResponse, error) {
133133

134134
ri, err := NewInfoFromRequest(req)
@@ -156,7 +156,7 @@ func (r *RequestLogger) Intercept(_ context.Context,
156156

157157
// Parse incoming requests and act on them.
158158
case MWRequestTypeRequest:
159-
return mid.RPCErr(req, r.addNewAction(ri, withPayloadData))
159+
return mid.RPCErr(req, r.addNewAction(ctx, ri, withPayloadData))
160160

161161
// Parse and possibly manipulate outgoing responses.
162162
case MWRequestTypeResponse:
@@ -170,7 +170,7 @@ func (r *RequestLogger) Intercept(_ context.Context,
170170
}
171171

172172
return mid.RPCErr(
173-
req, r.MarkAction(ri.RequestID, state, errReason),
173+
req, r.MarkAction(ctx, ri.RequestID, state, errReason),
174174
)
175175

176176
default:
@@ -179,7 +179,7 @@ func (r *RequestLogger) Intercept(_ context.Context,
179179
}
180180

181181
// addNewAction persists the new action to the db.
182-
func (r *RequestLogger) addNewAction(ri *RequestInfo,
182+
func (r *RequestLogger) addNewAction(ctx context.Context, ri *RequestInfo,
183183
withPayloadData bool) error {
184184

185185
// If no macaroon is provided, then an empty 4-byte array is used as the
@@ -223,24 +223,21 @@ func (r *RequestLogger) addNewAction(ri *RequestInfo,
223223
}
224224
}
225225

226-
id, err := r.actionsDB.AddAction(action)
226+
locator, err := r.actionsDB.AddAction(ctx, action)
227227
if err != nil {
228228
return err
229229
}
230230

231231
r.mu.Lock()
232-
r.reqIDToAction[ri.RequestID] = &firewalldb.ActionLocator{
233-
SessionID: sessionID,
234-
ActionID: id,
235-
}
232+
r.reqIDToAction[ri.RequestID] = locator
236233
r.mu.Unlock()
237234

238235
return nil
239236
}
240237

241238
// MarkAction can be used to set the state of an action identified by the given
242239
// requestID.
243-
func (r *RequestLogger) MarkAction(reqID uint64,
240+
func (r *RequestLogger) MarkAction(ctx context.Context, reqID uint64,
244241
state firewalldb.ActionState, errReason string) error {
245242

246243
r.mu.Lock()
@@ -252,5 +249,5 @@ func (r *RequestLogger) MarkAction(reqID uint64,
252249
}
253250
delete(r.reqIDToAction, reqID)
254251

255-
return r.actionsDB.SetActionState(actionLocator, state, errReason)
252+
return r.actionsDB.SetActionState(ctx, actionLocator, state, errReason)
256253
}

firewall/rule_enforcer.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,9 @@ type RuleEnforcer struct {
3232
ruleDB firewalldb.RulesDB
3333
actionsDB firewalldb.ActionReadDBGetter
3434
sessionDB firewalldb.SessionDB
35-
markActionErrored func(reqID uint64, reason string) error
36-
privMapDB firewalldb.PrivacyMapper
35+
markActionErrored func(ctx context.Context, reqID uint64,
36+
reason string) error
37+
privMapDB firewalldb.PrivacyMapper
3738

3839
permsMgr *perms.Manager
3940
getFeaturePerms featurePerms
@@ -63,7 +64,8 @@ func NewRuleEnforcer(ruleDB firewalldb.RulesDB,
6364
routerClient lndclient.RouterClient,
6465
lndClient lndclient.LightningClient, lndConnID string,
6566
ruleMgrs rules.ManagerSet,
66-
markActionErrored func(reqID uint64, reason string) error,
67+
markActionErrored func(ctx context.Context, reqID uint64,
68+
reason string) error,
6769
privMap firewalldb.PrivacyMapper) *RuleEnforcer {
6870

6971
return &RuleEnforcer{
@@ -164,7 +166,9 @@ func (r *RuleEnforcer) Intercept(ctx context.Context,
164166

165167
replacement, err := r.handleRequest(ctx, ri)
166168
if err != nil {
167-
dbErr := r.markActionErrored(ri.RequestID, err.Error())
169+
dbErr := r.markActionErrored(
170+
ctx, ri.RequestID, err.Error(),
171+
)
168172
if dbErr != nil {
169173
log.Error("could not mark action for "+
170174
"request ID %d as Errored: %v",

firewalldb/actions.go

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -181,9 +181,9 @@ func WithActionState(state ActionState) ListActionOption {
181181
// ActionsWriteDB is an abstraction over the Actions DB that will allow a
182182
// caller to add new actions as well as change the values of an existing action.
183183
type ActionsWriteDB interface {
184-
AddAction(action *Action) (uint64, error)
185-
SetActionState(al *ActionLocator, state ActionState,
186-
errReason string) error
184+
AddAction(ctx context.Context, action *Action) (ActionLocator, error)
185+
SetActionState(ctx context.Context, al ActionLocator,
186+
state ActionState, errReason string) error
187187
}
188188

189189
// RuleAction represents a method call that was performed at a certain time at
@@ -230,7 +230,7 @@ func (db *BoltDB) GetActionsReadDB(groupID session.ID,
230230

231231
// allActionsReadDb is an implementation of the ActionsReadDB.
232232
type allActionsReadDB struct {
233-
db *BoltDB
233+
db ActionDB
234234
groupID session.ID
235235
featureName string
236236
}
@@ -318,7 +318,6 @@ func actionToRulesAction(a *Action) *RuleAction {
318318
}
319319

320320
// ActionLocator helps us find an action in the database.
321-
type ActionLocator struct {
322-
SessionID session.ID
323-
ActionID uint64
321+
type ActionLocator interface {
322+
isActionLocator()
324323
}

firewalldb/actions_kvdb.go

Lines changed: 49 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,15 @@ var (
5353
)
5454

5555
// AddAction serialises and adds an Action to the DB under the given sessionID.
56-
func (db *BoltDB) AddAction(action *Action) (uint64, error) {
56+
func (db *BoltDB) AddAction(_ context.Context, action *Action) (ActionLocator,
57+
error) {
58+
5759
var buf bytes.Buffer
5860
if err := SerializeAction(&buf, action); err != nil {
59-
return 0, err
61+
return nil, err
6062
}
6163

62-
var id uint64
64+
var locator kvdbActionLocator
6365
err := db.DB.Update(func(tx *bbolt.Tx) error {
6466
mainActionsBucket, err := getBucket(tx, actionsBucketKey)
6567
if err != nil {
@@ -82,7 +84,6 @@ func (db *BoltDB) AddAction(action *Action) (uint64, error) {
8284
if err != nil {
8385
return err
8486
}
85-
id = nextActionIndex
8687

8788
var actionIndex [8]byte
8889
byteOrder.PutUint64(actionIndex[:], nextActionIndex)
@@ -101,9 +102,9 @@ func (db *BoltDB) AddAction(action *Action) (uint64, error) {
101102
return err
102103
}
103104

104-
locator := ActionLocator{
105-
SessionID: action.SessionID,
106-
ActionID: nextActionIndex,
105+
locator = kvdbActionLocator{
106+
sessionID: action.SessionID,
107+
actionID: nextActionIndex,
107108
}
108109

109110
var buf bytes.Buffer
@@ -117,13 +118,25 @@ func (db *BoltDB) AddAction(action *Action) (uint64, error) {
117118
return actionsIndexBucket.Put(seqNoBytes[:], buf.Bytes())
118119
})
119120
if err != nil {
120-
return 0, err
121+
return nil, err
121122
}
122123

123-
return id, nil
124+
return &locator, nil
125+
}
126+
127+
// kvdbActionLocator helps us find an action in a KVDB database.
128+
type kvdbActionLocator struct {
129+
sessionID session.ID
130+
actionID uint64
124131
}
125132

126-
func putAction(tx *bbolt.Tx, al *ActionLocator, a *Action) error {
133+
// A compile-time check to ensure kvdbActionLocator implements the ActionLocator
134+
// interface.
135+
var _ ActionLocator = (*kvdbActionLocator)(nil)
136+
137+
func (al *kvdbActionLocator) isActionLocator() {}
138+
139+
func putAction(tx *bbolt.Tx, al *kvdbActionLocator, a *Action) error {
127140
var buf bytes.Buffer
128141
if err := SerializeAction(&buf, a); err != nil {
129142
return err
@@ -139,42 +152,49 @@ func putAction(tx *bbolt.Tx, al *ActionLocator, a *Action) error {
139152
return ErrNoSuchKeyFound
140153
}
141154

142-
sessBucket := actionsBucket.Bucket(al.SessionID[:])
155+
sessBucket := actionsBucket.Bucket(al.sessionID[:])
143156
if sessBucket == nil {
144157
return fmt.Errorf("session bucket for session ID %x does not "+
145-
"exist", al.SessionID)
158+
"exist", al.sessionID)
146159
}
147160

148161
var id [8]byte
149-
binary.BigEndian.PutUint64(id[:], al.ActionID)
162+
binary.BigEndian.PutUint64(id[:], al.actionID)
150163

151164
return sessBucket.Put(id[:], buf.Bytes())
152165
}
153166

154-
func getAction(actionsBkt *bbolt.Bucket, al *ActionLocator) (*Action, error) {
155-
sessBucket := actionsBkt.Bucket(al.SessionID[:])
167+
func getAction(actionsBkt *bbolt.Bucket, al *kvdbActionLocator) (*Action,
168+
error) {
169+
170+
sessBucket := actionsBkt.Bucket(al.sessionID[:])
156171
if sessBucket == nil {
157172
return nil, fmt.Errorf("session bucket for session ID "+
158-
"%x does not exist", al.SessionID)
173+
"%x does not exist", al.sessionID)
159174
}
160175

161176
var id [8]byte
162-
binary.BigEndian.PutUint64(id[:], al.ActionID)
177+
binary.BigEndian.PutUint64(id[:], al.actionID)
163178

164179
actionBytes := sessBucket.Get(id[:])
165-
return DeserializeAction(bytes.NewReader(actionBytes), al.SessionID)
180+
return DeserializeAction(bytes.NewReader(actionBytes), al.sessionID)
166181
}
167182

168183
// SetActionState finds the action specified by the ActionLocator and sets its
169184
// state to the given state.
170-
func (db *BoltDB) SetActionState(al *ActionLocator, state ActionState,
171-
errorReason string) error {
185+
func (db *BoltDB) SetActionState(_ context.Context, al ActionLocator,
186+
state ActionState, errorReason string) error {
172187

173188
if errorReason != "" && state != ActionStateError {
174189
return fmt.Errorf("error reason should only be set for " +
175190
"ActionStateError")
176191
}
177192

193+
locator, ok := al.(*kvdbActionLocator)
194+
if !ok {
195+
return fmt.Errorf("expected kvdbActionLocator, got %T", al)
196+
}
197+
178198
return db.DB.Update(func(tx *bbolt.Tx) error {
179199
mainActionsBucket, err := getBucket(tx, actionsBucketKey)
180200
if err != nil {
@@ -186,15 +206,15 @@ func (db *BoltDB) SetActionState(al *ActionLocator, state ActionState,
186206
return ErrNoSuchKeyFound
187207
}
188208

189-
action, err := getAction(actionsBucket, al)
209+
action, err := getAction(actionsBucket, locator)
190210
if err != nil {
191211
return err
192212
}
193213

194214
action.State = state
195215
action.ErrorReason = errorReason
196216

197-
return putAction(tx, al, action)
217+
return putAction(tx, locator, action)
198218
})
199219
}
200220

@@ -540,14 +560,14 @@ func DeserializeAction(r io.Reader, sessionID session.ID) (*Action, error) {
540560

541561
// serializeActionLocator binary serializes the given ActionLocator to the
542562
// writer using the tlv format.
543-
func serializeActionLocator(w io.Writer, al *ActionLocator) error {
563+
func serializeActionLocator(w io.Writer, al *kvdbActionLocator) error {
544564
if al == nil {
545565
return fmt.Errorf("action locator cannot be nil")
546566
}
547567

548568
var (
549-
sessionID = al.SessionID[:]
550-
actionID = al.ActionID
569+
sessionID = al.sessionID[:]
570+
actionID = al.actionID
551571
)
552572

553573
tlvRecords := []tlv.Record{
@@ -565,7 +585,7 @@ func serializeActionLocator(w io.Writer, al *ActionLocator) error {
565585

566586
// deserializeActionLocator deserializes an ActionLocator from the given reader,
567587
// expecting the data to be encoded in the tlv format.
568-
func deserializeActionLocator(r io.Reader) (*ActionLocator, error) {
588+
func deserializeActionLocator(r io.Reader) (*kvdbActionLocator, error) {
569589
var (
570590
sessionID []byte
571591
actionID uint64
@@ -588,8 +608,8 @@ func deserializeActionLocator(r io.Reader) (*ActionLocator, error) {
588608
return nil, err
589609
}
590610

591-
return &ActionLocator{
592-
SessionID: id,
593-
ActionID: actionID,
611+
return &kvdbActionLocator{
612+
sessionID: id,
613+
actionID: actionID,
594614
}, nil
595615
}

0 commit comments

Comments
 (0)