Skip to content

Commit 71cac86

Browse files
committed
session: add a CheckSessionGroupPredicate method
This method can be used to check that each session in a group passes for a given predicate.
1 parent a14d7ae commit 71cac86

File tree

3 files changed

+191
-20
lines changed

3 files changed

+191
-20
lines changed

session/interface.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,5 +171,11 @@ type Store interface {
171171
// GetSessionByID fetches the session with the given ID.
172172
GetSessionByID(id ID) (*Session, error)
173173

174+
// CheckSessionGroupPredicate iterates over all the sessions in a group
175+
// and checks if each one passes the given predicate function. True is
176+
// returned if each session passes.
177+
CheckSessionGroupPredicate(groupID ID,
178+
fn func(s *Session) bool) (bool, error)
179+
174180
IDToGroupIndex
175181
}

session/store.go

Lines changed: 96 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -385,39 +385,115 @@ func (db *DB) GetGroupID(sessionID ID) (ID, error) {
385385
//
386386
// NOTE: this is part of the IDToGroupIndex interface.
387387
func (db *DB) GetSessionIDs(groupID ID) ([]ID, error) {
388-
var sessionIDs []ID
388+
var (
389+
sessionIDs []ID
390+
err error
391+
)
392+
err = db.View(func(tx *bbolt.Tx) error {
393+
sessionIDs, err = getSessionIDs(tx, groupID)
394+
395+
return err
396+
})
397+
if err != nil {
398+
return nil, err
399+
}
400+
401+
return sessionIDs, nil
402+
}
403+
404+
// CheckSessionGroupPredicate iterates over all the sessions in a group and
405+
// checks if each one passes the given predicate function. True is returned if
406+
// each session passes.
407+
//
408+
// NOTE: this is part of the Store interface.
409+
func (db *DB) CheckSessionGroupPredicate(groupID ID,
410+
fn func(s *Session) bool) (bool, error) {
411+
412+
var (
413+
pass bool
414+
errFailedPred = errors.New("session failed predicate")
415+
)
389416
err := db.View(func(tx *bbolt.Tx) error {
390417
sessionBkt, err := getBucket(tx, sessionBucketKey)
391418
if err != nil {
392419
return err
393420
}
394421

395-
groupIndexBkt := sessionBkt.Bucket(groupIDIndexKey)
396-
if groupIndexBkt == nil {
397-
return ErrDBInitErr
422+
sessionIDs, err := getSessionIDs(tx, groupID)
423+
if err != nil {
424+
return err
398425
}
399426

400-
groupIDBkt := groupIndexBkt.Bucket(groupID[:])
401-
if groupIDBkt == nil {
402-
return fmt.Errorf("no sessions for group ID %v",
403-
groupID)
404-
}
427+
// Iterate over all the sessions.
428+
for _, id := range sessionIDs {
429+
key, err := getKeyForID(sessionBkt, id)
430+
if err != nil {
431+
return err
432+
}
433+
434+
v := sessionBkt.Get(key)
435+
if len(v) == 0 {
436+
return ErrSessionNotFound
437+
}
438+
439+
session, err := DeserializeSession(bytes.NewReader(v))
440+
if err != nil {
441+
return err
442+
}
405443

406-
sessionIDsBkt := groupIDBkt.Bucket(sessionIDKey)
407-
if sessionIDsBkt == nil {
408-
return fmt.Errorf("no sessions for group ID %v",
409-
groupID)
444+
if !fn(session) {
445+
return errFailedPred
446+
}
410447
}
411448

412-
return sessionIDsBkt.ForEach(func(_,
413-
sessionIDBytes []byte) error {
449+
pass = true
414450

415-
var sessionID ID
416-
copy(sessionID[:], sessionIDBytes)
417-
sessionIDs = append(sessionIDs, sessionID)
451+
return nil
452+
})
453+
if errors.Is(err, errFailedPred) {
454+
return pass, nil
455+
}
456+
if err != nil {
457+
return pass, err
458+
}
418459

419-
return nil
420-
})
460+
return pass, nil
461+
}
462+
463+
// getSessionIDs returns all the session IDs associated with the given group ID.
464+
func getSessionIDs(tx *bbolt.Tx, groupID ID) ([]ID, error) {
465+
var sessionIDs []ID
466+
467+
sessionBkt, err := getBucket(tx, sessionBucketKey)
468+
if err != nil {
469+
return nil, err
470+
}
471+
472+
groupIndexBkt := sessionBkt.Bucket(groupIDIndexKey)
473+
if groupIndexBkt == nil {
474+
return nil, ErrDBInitErr
475+
}
476+
477+
groupIDBkt := groupIndexBkt.Bucket(groupID[:])
478+
if groupIDBkt == nil {
479+
return nil, fmt.Errorf("no sessions for group ID %v",
480+
groupID)
481+
}
482+
483+
sessionIDsBkt := groupIDBkt.Bucket(sessionIDKey)
484+
if sessionIDsBkt == nil {
485+
return nil, fmt.Errorf("no sessions for group ID %v",
486+
groupID)
487+
}
488+
489+
err = sessionIDsBkt.ForEach(func(_,
490+
sessionIDBytes []byte) error {
491+
492+
var sessionID ID
493+
copy(sessionID[:], sessionIDBytes)
494+
sessionIDs = append(sessionIDs, sessionID)
495+
496+
return nil
421497
})
422498
if err != nil {
423499
return nil, err

session/store_test.go

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package session
22

33
import (
4+
"strings"
45
"testing"
56
"time"
67

@@ -172,6 +173,94 @@ func TestLinkedSessions(t *testing.T) {
172173
require.EqualValues(t, []ID{s4.ID, s5.ID}, sIDs)
173174
}
174175

176+
// TestCheckSessionGroupPredicate asserts that the CheckSessionGroupPredicate
177+
// method correctly checks if each session in a group passes a predicate.
178+
func TestCheckSessionGroupPredicate(t *testing.T) {
179+
// Set up a new DB.
180+
db, err := NewDB(t.TempDir(), "test.db")
181+
require.NoError(t, err)
182+
t.Cleanup(func() {
183+
_ = db.Close()
184+
})
185+
186+
// We will use the Label of the Session to test that the predicate
187+
// function is checked correctly.
188+
189+
// Add a new session to the DB.
190+
s1 := newSession(t, db, "label 1", nil)
191+
require.NoError(t, db.CreateSession(s1))
192+
193+
// Check that the group passes against an appropriate predicate.
194+
ok, err := db.CheckSessionGroupPredicate(
195+
s1.GroupID, func(s *Session) bool {
196+
return strings.Contains(s.Label, "label 1")
197+
},
198+
)
199+
require.NoError(t, err)
200+
require.True(t, ok)
201+
202+
// Check that the group fails against an appropriate predicate.
203+
ok, err = db.CheckSessionGroupPredicate(
204+
s1.GroupID, func(s *Session) bool {
205+
return strings.Contains(s.Label, "label 2")
206+
},
207+
)
208+
require.NoError(t, err)
209+
require.False(t, ok)
210+
211+
// Add a new session to the same group as the first one.
212+
s2 := newSession(t, db, "label 2", &s1.GroupID)
213+
require.NoError(t, db.CreateSession(s2))
214+
215+
// Check that the group passes against an appropriate predicate.
216+
ok, err = db.CheckSessionGroupPredicate(
217+
s1.GroupID, func(s *Session) bool {
218+
return strings.Contains(s.Label, "label")
219+
},
220+
)
221+
require.NoError(t, err)
222+
require.True(t, ok)
223+
224+
// Check that the group fails against an appropriate predicate.
225+
ok, err = db.CheckSessionGroupPredicate(
226+
s1.GroupID, func(s *Session) bool {
227+
return strings.Contains(s.Label, "label 1")
228+
},
229+
)
230+
require.NoError(t, err)
231+
require.False(t, ok)
232+
233+
// Add a new session that is not linked to the first one.
234+
s3 := newSession(t, db, "completely different", nil)
235+
require.NoError(t, db.CreateSession(s3))
236+
237+
// Ensure that the first group is unaffected.
238+
ok, err = db.CheckSessionGroupPredicate(
239+
s1.GroupID, func(s *Session) bool {
240+
return strings.Contains(s.Label, "label")
241+
},
242+
)
243+
require.NoError(t, err)
244+
require.True(t, ok)
245+
246+
// And that the new session is evaluated separately.
247+
ok, err = db.CheckSessionGroupPredicate(
248+
s3.GroupID, func(s *Session) bool {
249+
return strings.Contains(s.Label, "label")
250+
},
251+
)
252+
require.NoError(t, err)
253+
require.False(t, ok)
254+
255+
ok, err = db.CheckSessionGroupPredicate(
256+
s3.GroupID, func(s *Session) bool {
257+
return strings.Contains(s.Label, "different")
258+
},
259+
)
260+
require.NoError(t, err)
261+
require.True(t, ok)
262+
}
263+
175264
func newSession(t *testing.T, db Store, label string,
176265
linkedGroupID *ID) *Session {
177266

0 commit comments

Comments
 (0)