Skip to content

Commit 3b5007c

Browse files
committed
session: ensure past linked sessions are not active
This commit adds logic to the CreateSession method that checks that a all the past sessions in a linked set are no longer active.
1 parent c27f4eb commit 3b5007c

File tree

2 files changed

+69
-12
lines changed

2 files changed

+69
-12
lines changed

session/store.go

Lines changed: 51 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -85,15 +85,54 @@ func (db *DB) CreateSession(session *Session) error {
8585

8686
// If this is a linked session (meaning the group ID is
8787
// different from the ID) the make sure that the Group ID of
88-
// this session is an ID known by the store. We can do this by
89-
// checking that an entry for this ID exists in the id-to-key
90-
// index.
88+
// this session is an ID known by the store. We also need to
89+
// check that all older sessions in this group have been
90+
// revoked.
9191
if session.ID != session.GroupID {
9292
_, err = getKeyForID(sessionBucket, session.GroupID)
9393
if err != nil {
9494
return fmt.Errorf("unknown linked session "+
9595
"%x: %w", session.GroupID, err)
9696
}
97+
98+
// Fetch all the session IDs for this group. This will
99+
// through an error if this group does not exist.
100+
sessionIDs, err := getSessionIDs(
101+
sessionBucket, session.GroupID,
102+
)
103+
if err != nil {
104+
return err
105+
}
106+
107+
for _, id := range sessionIDs {
108+
keyBytes, err := getKeyForID(
109+
sessionBucket, id,
110+
)
111+
if err != nil {
112+
return err
113+
}
114+
115+
v := sessionBucket.Get(keyBytes)
116+
if len(v) == 0 {
117+
return ErrSessionNotFound
118+
}
119+
120+
sess, err := DeserializeSession(
121+
bytes.NewReader(v),
122+
)
123+
if err != nil {
124+
return err
125+
}
126+
127+
// Ensure that the session is no longer active.
128+
if sess.State == StateCreated ||
129+
sess.State == StateInUse {
130+
131+
return fmt.Errorf("session (id=%x) "+
132+
"in group %x is still active",
133+
sess.ID, sess.GroupID)
134+
}
135+
}
97136
}
98137

99138
// Add the mapping from session ID to session key to the ID
@@ -390,7 +429,12 @@ func (db *DB) GetSessionIDs(groupID ID) ([]ID, error) {
390429
err error
391430
)
392431
err = db.View(func(tx *bbolt.Tx) error {
393-
sessionIDs, err = getSessionIDs(tx, groupID)
432+
sessionBkt, err := getBucket(tx, sessionBucketKey)
433+
if err != nil {
434+
return err
435+
}
436+
437+
sessionIDs, err = getSessionIDs(sessionBkt, groupID)
394438

395439
return err
396440
})
@@ -419,7 +463,7 @@ func (db *DB) CheckSessionGroupPredicate(groupID ID,
419463
return err
420464
}
421465

422-
sessionIDs, err := getSessionIDs(tx, groupID)
466+
sessionIDs, err := getSessionIDs(sessionBkt, groupID)
423467
if err != nil {
424468
return err
425469
}
@@ -461,14 +505,9 @@ func (db *DB) CheckSessionGroupPredicate(groupID ID,
461505
}
462506

463507
// getSessionIDs returns all the session IDs associated with the given group ID.
464-
func getSessionIDs(tx *bbolt.Tx, groupID ID) ([]ID, error) {
508+
func getSessionIDs(sessionBkt *bbolt.Bucket, groupID ID) ([]ID, error) {
465509
var sessionIDs []ID
466510

467-
sessionBkt, err := getBucket(tx, sessionBucketKey)
468-
if err != nil {
469-
return nil, err
470-
}
471-
472511
groupIndexBkt := sessionBkt.Bucket(groupIDIndexKey)
473512
if groupIndexBkt == nil {
474513
return nil, ErrDBInitErr
@@ -486,7 +525,7 @@ func getSessionIDs(tx *bbolt.Tx, groupID ID) ([]ID, error) {
486525
groupID)
487526
}
488527

489-
err = sessionIDsBkt.ForEach(func(_,
528+
err := sessionIDsBkt.ForEach(func(_,
490529
sessionIDBytes []byte) error {
491530

492531
var sessionID ID

session/store_test.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,15 @@ func TestLinkingSessions(t *testing.T) {
108108
// Now persist the first session and retry persisting the second one
109109
// and assert that this now works.
110110
require.NoError(t, db.CreateSession(s1))
111+
112+
// Persisting the second session immediately should fail due to the
113+
// first session still being active.
114+
require.ErrorContains(t, db.CreateSession(s2), "is still active")
115+
116+
// Revoke the first session.
117+
require.NoError(t, db.RevokeSession(s1.LocalPublicKey))
118+
119+
// Persisting the second linked session should now work.
111120
require.NoError(t, db.CreateSession(s2))
112121
}
113122

@@ -132,7 +141,11 @@ func TestLinkedSessions(t *testing.T) {
132141

133142
// Persist the sessions.
134143
require.NoError(t, db.CreateSession(s1))
144+
145+
require.NoError(t, db.RevokeSession(s1.LocalPublicKey))
135146
require.NoError(t, db.CreateSession(s2))
147+
148+
require.NoError(t, db.RevokeSession(s2.LocalPublicKey))
136149
require.NoError(t, db.CreateSession(s3))
137150

138151
// Assert that the session ID to group ID index works as expected.
@@ -157,6 +170,8 @@ func TestLinkedSessions(t *testing.T) {
157170

158171
// Persist the sessions.
159172
require.NoError(t, db.CreateSession(s4))
173+
require.NoError(t, db.RevokeSession(s4.LocalPublicKey))
174+
160175
require.NoError(t, db.CreateSession(s5))
161176

162177
// Assert that the session ID to group ID index works as expected.
@@ -208,6 +223,9 @@ func TestCheckSessionGroupPredicate(t *testing.T) {
208223
require.NoError(t, err)
209224
require.False(t, ok)
210225

226+
// Revoke the first session.
227+
require.NoError(t, db.RevokeSession(s1.LocalPublicKey))
228+
211229
// Add a new session to the same group as the first one.
212230
s2 := newSession(t, db, "label 2", &s1.GroupID)
213231
require.NoError(t, db.CreateSession(s2))

0 commit comments

Comments
 (0)