@@ -85,15 +85,54 @@ func (db *DB) CreateSession(session *Session) error {
85
85
86
86
// If this is a linked session (meaning the group ID is
87
87
// different from the ID) the make sure that the Group ID of
88
- // this session is an ID known by the store. We can do this by
89
- // checking that an entry for this ID exists in the id-to-key
90
- // index .
88
+ // this session is an ID known by the store. We also need to
89
+ // check that all older sessions in this group have been
90
+ // revoked .
91
91
if session .ID != session .GroupID {
92
92
_ , err = getKeyForID (sessionBucket , session .GroupID )
93
93
if err != nil {
94
94
return fmt .Errorf ("unknown linked session " +
95
95
"%x: %w" , session .GroupID , err )
96
96
}
97
+
98
+ // Fetch all the session IDs for this group. This will
99
+ // through an error if this group does not exist.
100
+ sessionIDs , err := getSessionIDs (
101
+ sessionBucket , session .GroupID ,
102
+ )
103
+ if err != nil {
104
+ return err
105
+ }
106
+
107
+ for _ , id := range sessionIDs {
108
+ keyBytes , err := getKeyForID (
109
+ sessionBucket , id ,
110
+ )
111
+ if err != nil {
112
+ return err
113
+ }
114
+
115
+ v := sessionBucket .Get (keyBytes )
116
+ if len (v ) == 0 {
117
+ return ErrSessionNotFound
118
+ }
119
+
120
+ sess , err := DeserializeSession (
121
+ bytes .NewReader (v ),
122
+ )
123
+ if err != nil {
124
+ return err
125
+ }
126
+
127
+ // Ensure that the session is no longer active.
128
+ if sess .State == StateCreated ||
129
+ sess .State == StateInUse {
130
+
131
+ return fmt .Errorf ("session (id=%x) " +
132
+ "in group %x is still active" ,
133
+ sess .ID , sess .GroupID )
134
+ }
135
+ }
97
136
}
98
137
99
138
// Add the mapping from session ID to session key to the ID
@@ -390,7 +429,12 @@ func (db *DB) GetSessionIDs(groupID ID) ([]ID, error) {
390
429
err error
391
430
)
392
431
err = db .View (func (tx * bbolt.Tx ) error {
393
- sessionIDs , err = getSessionIDs (tx , groupID )
432
+ sessionBkt , err := getBucket (tx , sessionBucketKey )
433
+ if err != nil {
434
+ return err
435
+ }
436
+
437
+ sessionIDs , err = getSessionIDs (sessionBkt , groupID )
394
438
395
439
return err
396
440
})
@@ -419,7 +463,7 @@ func (db *DB) CheckSessionGroupPredicate(groupID ID,
419
463
return err
420
464
}
421
465
422
- sessionIDs , err := getSessionIDs (tx , groupID )
466
+ sessionIDs , err := getSessionIDs (sessionBkt , groupID )
423
467
if err != nil {
424
468
return err
425
469
}
@@ -461,14 +505,9 @@ func (db *DB) CheckSessionGroupPredicate(groupID ID,
461
505
}
462
506
463
507
// getSessionIDs returns all the session IDs associated with the given group ID.
464
- func getSessionIDs (tx * bbolt.Tx , groupID ID ) ([]ID , error ) {
508
+ func getSessionIDs (sessionBkt * bbolt.Bucket , groupID ID ) ([]ID , error ) {
465
509
var sessionIDs []ID
466
510
467
- sessionBkt , err := getBucket (tx , sessionBucketKey )
468
- if err != nil {
469
- return nil , err
470
- }
471
-
472
511
groupIndexBkt := sessionBkt .Bucket (groupIDIndexKey )
473
512
if groupIndexBkt == nil {
474
513
return nil , ErrDBInitErr
@@ -486,7 +525,7 @@ func getSessionIDs(tx *bbolt.Tx, groupID ID) ([]ID, error) {
486
525
groupID )
487
526
}
488
527
489
- err = sessionIDsBkt .ForEach (func (_ ,
528
+ err : = sessionIDsBkt .ForEach (func (_ ,
490
529
sessionIDBytes []byte ) error {
491
530
492
531
var sessionID ID
0 commit comments