Skip to content

Commit 6c36b01

Browse files
committed
session: add ListSessionsByType method
And use it to replace one call to ListSessions which uses a filter function which would be inefficient in SQL land.
1 parent 0023002 commit 6c36b01

File tree

4 files changed

+37
-4
lines changed

4 files changed

+37
-4
lines changed

session/interface.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,9 @@ type Store interface {
164164
// ListSessions returns all sessions currently known to the store.
165165
ListSessions(filterFn func(s *Session) bool) ([]*Session, error)
166166

167+
// ListSessionsByType returns all sessions of the given type.
168+
ListSessionsByType(t Type) ([]*Session, error)
169+
167170
// RevokeSession updates the state of the session with the given local
168171
// public key to be revoked.
169172
RevokeSession(*btcec.PublicKey) error

session/kvdb_store.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,16 @@ func (db *BoltStore) ListSessions(filterFn func(s *Session) bool) ([]*Session, e
371371
return db.listSessions(filterFn)
372372
}
373373

374+
// ListSessionsByType returns all sessions currently known to the store that
375+
// have the given type.
376+
//
377+
// NOTE: this is part of the Store interface.
378+
func (db *BoltStore) ListSessionsByType(t Type) ([]*Session, error) {
379+
return db.listSessions(func(s *Session) bool {
380+
return s.Type == t
381+
})
382+
}
383+
374384
// listSessions returns all sessions currently known to the store that pass the
375385
// given filter function.
376386
func (db *BoltStore) listSessions(filterFn func(s *Session) bool) ([]*Session,

session/store_test.go

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ func TestBasicSessionStore(t *testing.T) {
3131
clock.SetTime(testTime.Add(time.Second))
3232
s2 := newSession(t, db, clock, "session 2")
3333
clock.SetTime(testTime.Add(2 * time.Second))
34-
s3 := newSession(t, db, clock, "session 3")
34+
s3 := newSession(t, db, clock, "session 3", withType(TypeAutopilot))
3535
clock.SetTime(testTime.Add(3 * time.Second))
3636
s4 := newSession(t, db, clock, "session 4")
3737

@@ -64,6 +64,22 @@ func TestBasicSessionStore(t *testing.T) {
6464
assertEqualSessions(t, s2, sessions[1])
6565
assertEqualSessions(t, s3, sessions[2])
6666

67+
// Test the ListSessionsByType method.
68+
sessions, err = db.ListSessionsByType(TypeMacaroonAdmin)
69+
require.NoError(t, err)
70+
require.Equal(t, 2, len(sessions))
71+
assertEqualSessions(t, s1, sessions[0])
72+
assertEqualSessions(t, s2, sessions[1])
73+
74+
sessions, err = db.ListSessionsByType(TypeAutopilot)
75+
require.NoError(t, err)
76+
require.Equal(t, 1, len(sessions))
77+
assertEqualSessions(t, s3, sessions[0])
78+
79+
sessions, err = db.ListSessionsByType(TypeMacaroonReadonly)
80+
require.NoError(t, err)
81+
require.Empty(t, sessions)
82+
6783
// Ensure that we can retrieve each session by both its local pub key
6884
// and by its ID.
6985
for _, s := range []*Session{s1, s2, s3} {
@@ -310,6 +326,12 @@ func withLinkedGroupID(groupID *ID) testSessionModifier {
310326
}
311327
}
312328

329+
func withType(t Type) testSessionModifier {
330+
return func(s *Session) {
331+
s.Type = t
332+
}
333+
}
334+
313335
func newSession(t *testing.T, db Store, clock clock.Clock, label string,
314336
mods ...testSessionModifier) *Session {
315337

session_rpcserver.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1259,9 +1259,7 @@ func (s *sessionRpcServer) ListAutopilotSessions(_ context.Context,
12591259
_ *litrpc.ListAutopilotSessionsRequest) (
12601260
*litrpc.ListAutopilotSessionsResponse, error) {
12611261

1262-
sessions, err := s.cfg.db.ListSessions(func(s *session.Session) bool {
1263-
return s.Type == session.TypeAutopilot
1264-
})
1262+
sessions, err := s.cfg.db.ListSessionsByType(session.TypeAutopilot)
12651263
if err != nil {
12661264
return nil, fmt.Errorf("error fetching sessions: %v", err)
12671265
}

0 commit comments

Comments
 (0)