Skip to content

Commit c12e643

Browse files
committed
firewalldb: add ListGroupActions method
This method can be used to list actions in a session Group. Note that actions are still stored under session IDs.
1 parent 73110b6 commit c12e643

File tree

2 files changed

+179
-13
lines changed

2 files changed

+179
-13
lines changed

firewalldb/actions.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"context"
66
"encoding/binary"
7+
"errors"
78
"fmt"
89
"io"
910
"time"
@@ -386,6 +387,79 @@ func (db *DB) ListSessionActions(sessionID session.ID,
386387
return actions, lastIndex, totalCount, nil
387388
}
388389

390+
// ListGroupActions returns a list of the given session group's Actions that
391+
// pass the filterFn requirements.
392+
//
393+
// TODO: update to allow for pagination.
394+
func (db *DB) ListGroupActions(groupID session.ID,
395+
filterFn ListActionsFilterFn) ([]*Action, error) {
396+
397+
if filterFn == nil {
398+
filterFn = func(a *Action, reversed bool) (bool, bool) {
399+
return true, true
400+
}
401+
}
402+
403+
sessionIDs, err := db.sessionIDIndex.GetSessionIDs(groupID)
404+
if err != nil {
405+
return nil, err
406+
}
407+
408+
var (
409+
actions []*Action
410+
errDone = errors.New("done iterating")
411+
)
412+
err = db.View(func(tx *bbolt.Tx) error {
413+
mainActionsBucket, err := getBucket(tx, actionsBucketKey)
414+
if err != nil {
415+
return err
416+
}
417+
418+
actionsBucket := mainActionsBucket.Bucket(actionsKey)
419+
if actionsBucket == nil {
420+
return ErrNoSuchKeyFound
421+
}
422+
423+
// Iterate over each session ID in this group.
424+
for _, sessionID := range sessionIDs {
425+
sessionsBucket := actionsBucket.Bucket(sessionID[:])
426+
if sessionsBucket == nil {
427+
return nil
428+
}
429+
430+
err = sessionsBucket.ForEach(func(_, v []byte) error {
431+
action, err := DeserializeAction(
432+
bytes.NewReader(v), sessionID,
433+
)
434+
if err != nil {
435+
return err
436+
}
437+
438+
include, cont := filterFn(action, false)
439+
if include {
440+
actions = append(actions, action)
441+
}
442+
443+
if !cont {
444+
return errDone
445+
}
446+
447+
return nil
448+
})
449+
if err != nil {
450+
return err
451+
}
452+
}
453+
454+
return nil
455+
})
456+
if err != nil && !errors.Is(err, errDone) {
457+
return nil, err
458+
}
459+
460+
return actions, nil
461+
}
462+
389463
// SerializeAction binary serializes the given action to the writer using the
390464
// tlv format.
391465
func SerializeAction(w io.Writer, action *Action) error {

firewalldb/actions_test.go

Lines changed: 105 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,15 @@ import (
55
"testing"
66
"time"
77

8+
"github.com/lightninglabs/lightning-terminal/session"
89
"github.com/stretchr/testify/require"
910
)
1011

11-
// TestActionStorage tests that the ActionsDB CRUD logic.
12-
func TestActionStorage(t *testing.T) {
13-
tmpDir := t.TempDir()
12+
var (
13+
sessionID1 = intToSessionID(1)
14+
sessionID2 = intToSessionID(2)
1415

15-
db, err := NewDB(tmpDir, "test.db", nil)
16-
require.NoError(t, err)
17-
t.Cleanup(func() {
18-
_ = db.Close()
19-
})
20-
21-
sessionID1 := [4]byte{1, 1, 1, 1}
22-
action1 := &Action{
16+
action1 = &Action{
2317
SessionID: sessionID1,
2418
ActorName: "Autopilot",
2519
FeatureName: "auto-fees",
@@ -32,8 +26,7 @@ func TestActionStorage(t *testing.T) {
3226
State: ActionStateDone,
3327
}
3428

35-
sessionID2 := [4]byte{2, 2, 2, 2}
36-
action2 := &Action{
29+
action2 = &Action{
3730
SessionID: sessionID2,
3831
ActorName: "Autopilot",
3932
FeatureName: "rebalancer",
@@ -44,6 +37,17 @@ func TestActionStorage(t *testing.T) {
4437
AttemptedAt: time.Unix(12300, 0),
4538
State: ActionStateInit,
4639
}
40+
)
41+
42+
// TestActionStorage tests that the ActionsDB CRUD logic.
43+
func TestActionStorage(t *testing.T) {
44+
tmpDir := t.TempDir()
45+
46+
db, err := NewDB(tmpDir, "test.db", nil)
47+
require.NoError(t, err)
48+
t.Cleanup(func() {
49+
_ = db.Close()
50+
})
4751

4852
actionsStateFilterFn := func(state ActionState) ListActionsFilterFn {
4953
return func(a *Action, _ bool) (bool, bool) {
@@ -335,3 +339,91 @@ func TestListActions(t *testing.T) {
335339
{sessionID2, "6"},
336340
})
337341
}
342+
343+
// TestListGroupActions tests that the ListGroupActions correctly returns all
344+
// actions in a particular session group.
345+
func TestListGroupActions(t *testing.T) {
346+
group1 := intToSessionID(0)
347+
348+
// Link session 1 and session 2 to group 1.
349+
index := newMockSessionIDIndex()
350+
index.addPair(sessionID1, group1)
351+
index.addPair(sessionID2, group1)
352+
353+
db, err := NewDB(t.TempDir(), "test.db", index)
354+
require.NoError(t, err)
355+
t.Cleanup(func() {
356+
_ = db.Close()
357+
})
358+
359+
// There should not be any actions in group 1 yet.
360+
al, err := db.ListGroupActions(group1, nil)
361+
require.NoError(t, err)
362+
require.Empty(t, al)
363+
364+
// Add an action under session 1.
365+
_, err = db.AddAction(sessionID1, action1)
366+
require.NoError(t, err)
367+
368+
// There should now be one action in the group.
369+
al, err = db.ListGroupActions(group1, nil)
370+
require.NoError(t, err)
371+
require.Len(t, al, 1)
372+
require.Equal(t, sessionID1, al[0].SessionID)
373+
374+
// Add an action under session 2.
375+
_, err = db.AddAction(sessionID2, action2)
376+
require.NoError(t, err)
377+
378+
// There should now be actions in the group.
379+
al, err = db.ListGroupActions(group1, nil)
380+
require.NoError(t, err)
381+
require.Len(t, al, 2)
382+
require.Equal(t, sessionID1, al[0].SessionID)
383+
require.Equal(t, sessionID2, al[1].SessionID)
384+
}
385+
386+
type mockSessionIDIndex struct {
387+
sessionToGroupID map[session.ID]session.ID
388+
groupToSessionIDs map[session.ID][]session.ID
389+
}
390+
391+
var _ session.IDToGroupIndex = (*mockSessionIDIndex)(nil)
392+
393+
func newMockSessionIDIndex() *mockSessionIDIndex {
394+
return &mockSessionIDIndex{
395+
sessionToGroupID: make(map[session.ID]session.ID),
396+
groupToSessionIDs: make(map[session.ID][]session.ID),
397+
}
398+
}
399+
400+
func (m *mockSessionIDIndex) addPair(sessionID, groupID session.ID) {
401+
m.sessionToGroupID[sessionID] = groupID
402+
403+
m.groupToSessionIDs[groupID] = append(
404+
m.groupToSessionIDs[groupID], sessionID,
405+
)
406+
}
407+
408+
func (m *mockSessionIDIndex) GetGroupID(sessionID session.ID) (session.ID,
409+
error) {
410+
411+
id, ok := m.sessionToGroupID[sessionID]
412+
if !ok {
413+
return session.ID{}, fmt.Errorf("no group ID found for " +
414+
"session ID")
415+
}
416+
417+
return id, nil
418+
}
419+
420+
func (m *mockSessionIDIndex) GetSessionIDs(groupID session.ID) ([]session.ID,
421+
error) {
422+
423+
ids, ok := m.groupToSessionIDs[groupID]
424+
if !ok {
425+
return nil, fmt.Errorf("no session IDs found for group ID")
426+
}
427+
428+
return ids, nil
429+
}

0 commit comments

Comments
 (0)