Skip to content

Commit 0023002

Browse files
committed
session: ensure listed sessions are sorted
Sorted by creation time. Also add a test to cover this.
1 parent c71df78 commit 0023002

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

session/kvdb_store.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"fmt"
88
"os"
99
"path/filepath"
10+
"sort"
1011
"time"
1112

1213
"github.com/btcsuite/btcd/btcec/v2"
@@ -367,6 +368,14 @@ func (db *BoltStore) GetSession(key *btcec.PublicKey) (*Session, error) {
367368
//
368369
// NOTE: this is part of the Store interface.
369370
func (db *BoltStore) ListSessions(filterFn func(s *Session) bool) ([]*Session, error) {
371+
return db.listSessions(filterFn)
372+
}
373+
374+
// listSessions returns all sessions currently known to the store that pass the
375+
// given filter function.
376+
func (db *BoltStore) listSessions(filterFn func(s *Session) bool) ([]*Session,
377+
error) {
378+
370379
var sessions []*Session
371380
err := db.View(func(tx *bbolt.Tx) error {
372381
sessionBucket, err := getBucket(tx, sessionBucketKey)
@@ -399,6 +408,11 @@ func (db *BoltStore) ListSessions(filterFn func(s *Session) bool) ([]*Session, e
399408
return nil, err
400409
}
401410

411+
// Make sure to sort the sessions by creation time.
412+
sort.Slice(sessions, func(i, j int) bool {
413+
return sessions[i].CreatedAt.Before(sessions[j].CreatedAt)
414+
})
415+
402416
return sessions, nil
403417
}
404418

session/store_test.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,16 @@ func TestBasicSessionStore(t *testing.T) {
2323
_ = db.Close()
2424
})
2525

26-
// Create a few sessions.
26+
// Create a few sessions. We increment the time by one second between
27+
// each session to ensure that the created at time is unique and hence
28+
// that the ListSessions method returns the sessions in a deterministic
29+
// order.
2730
s1 := newSession(t, db, clock, "session 1")
31+
clock.SetTime(testTime.Add(time.Second))
2832
s2 := newSession(t, db, clock, "session 2")
33+
clock.SetTime(testTime.Add(2 * time.Second))
2934
s3 := newSession(t, db, clock, "session 3")
35+
clock.SetTime(testTime.Add(3 * time.Second))
3036
s4 := newSession(t, db, clock, "session 4")
3137

3238
// Persist session 1. This should now succeed.
@@ -50,6 +56,14 @@ func TestBasicSessionStore(t *testing.T) {
5056
require.NoError(t, db.CreateSession(s2))
5157
require.NoError(t, db.CreateSession(s3))
5258

59+
// Check that all sessions are returned in ListSessions.
60+
sessions, err := db.ListSessions(nil)
61+
require.NoError(t, err)
62+
require.Equal(t, 3, len(sessions))
63+
assertEqualSessions(t, s1, sessions[0])
64+
assertEqualSessions(t, s2, sessions[1])
65+
assertEqualSessions(t, s3, sessions[2])
66+
5367
// Ensure that we can retrieve each session by both its local pub key
5468
// and by its ID.
5569
for _, s := range []*Session{s1, s2, s3} {

0 commit comments

Comments
 (0)