Skip to content

Commit f9ffc0c

Browse files
committed
session: add GetSession method
1 parent 3d669d6 commit f9ffc0c

File tree

2 files changed

+36
-4
lines changed

2 files changed

+36
-4
lines changed

session/store.go

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,36 @@ func (db *DB) StoreSession(session *Session) error {
4545
})
4646
}
4747

48+
// GetSession fetches the session with the given key.
49+
func (db *DB) GetSession(key *btcec.PublicKey) (*Session, error) {
50+
var session *Session
51+
err := db.View(func(tx *bbolt.Tx) error {
52+
sessionBucket, err := getBucket(tx, sessionBucketKey)
53+
if err != nil {
54+
return err
55+
}
56+
57+
v := sessionBucket.Get(key.SerializeCompressed())
58+
if len(v) == 0 {
59+
return ErrSessionNotFound
60+
}
61+
62+
session, err = DeserializeSession(bytes.NewReader(v))
63+
if err != nil {
64+
return err
65+
}
66+
67+
return nil
68+
})
69+
if err != nil {
70+
return nil, err
71+
}
72+
73+
return session, nil
74+
}
75+
4876
// ListSessions returns all sessions currently known to the store.
49-
func (db *DB) ListSessions() ([]*Session, error) {
77+
func (db *DB) ListSessions(filterFn func(s *Session) bool) ([]*Session, error) {
5078
var sessions []*Session
5179
err := db.View(func(tx *bbolt.Tx) error {
5280
sessionBucket, err := getBucket(tx, sessionBucketKey)
@@ -65,12 +93,16 @@ func (db *DB) ListSessions() ([]*Session, error) {
6593
if err != nil {
6694
return err
6795
}
96+
97+
if filterFn != nil && !filterFn(session) {
98+
return nil
99+
}
100+
68101
sessions = append(sessions, session)
69102

70103
return nil
71104
})
72105
})
73-
74106
if err != nil {
75107
return nil, err
76108
}

session_rpcserver.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ func newSessionRPCServer(cfg *sessionRpcServerConfig) (*sessionRpcServer,
9292
// requests. This includes resuming all non-revoked sessions.
9393
func (s *sessionRpcServer) start() error {
9494
// Start up all previously created sessions.
95-
sessions, err := s.db.ListSessions()
95+
sessions, err := s.db.ListSessions(nil)
9696
if err != nil {
9797
return fmt.Errorf("error listing sessions: %v", err)
9898
}
@@ -456,7 +456,7 @@ func (s *sessionRpcServer) resumeSession(sess *session.Session) error {
456456
func (s *sessionRpcServer) ListSessions(_ context.Context,
457457
_ *litrpc.ListSessionsRequest) (*litrpc.ListSessionsResponse, error) {
458458

459-
sessions, err := s.db.ListSessions()
459+
sessions, err := s.db.ListSessions(nil)
460460
if err != nil {
461461
return nil, fmt.Errorf("error fetching sessions: %v", err)
462462
}

0 commit comments

Comments
 (0)