Skip to content

Commit 32a34d1

Browse files
committed
session: remove Session Group Predicate method
This was used to check that all linked sessions are no longer active before attempting to register an autopilot session. But this is no longer needed since this is done within NewSession.
1 parent 013e7c0 commit 32a34d1

File tree

4 files changed

+1
-177
lines changed

4 files changed

+1
-177
lines changed

session/interface.go

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -216,12 +216,6 @@ type Store interface {
216216
// GetSessionByID fetches the session with the given ID.
217217
GetSessionByID(id ID) (*Session, error)
218218

219-
// CheckSessionGroupPredicate iterates over all the sessions in a group
220-
// and checks if each one passes the given predicate function. True is
221-
// returned if each session passes.
222-
CheckSessionGroupPredicate(groupID ID,
223-
fn func(s *Session) bool) (bool, error)
224-
225219
// DeleteReservedSessions deletes all sessions that are in the
226220
// StateReserved state.
227221
DeleteReservedSessions() error

session/kvdb_store.go

Lines changed: 1 addition & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -249,9 +249,7 @@ func (db *BoltStore) NewSession(label string, typ Type, expiry time.Time,
249249
}
250250

251251
// Ensure that the session is no longer active.
252-
if sess.State == StateCreated ||
253-
sess.State == StateInUse {
254-
252+
if !sess.State.Terminal() {
255253
return fmt.Errorf("session (id=%x) "+
256254
"in group %x is still active",
257255
sess.ID, sess.GroupID)
@@ -679,65 +677,6 @@ func (db *BoltStore) GetSessionIDs(groupID ID) ([]ID, error) {
679677
return sessionIDs, nil
680678
}
681679

682-
// CheckSessionGroupPredicate iterates over all the sessions in a group and
683-
// checks if each one passes the given predicate function. True is returned if
684-
// each session passes.
685-
//
686-
// NOTE: this is part of the Store interface.
687-
func (db *BoltStore) CheckSessionGroupPredicate(groupID ID,
688-
fn func(s *Session) bool) (bool, error) {
689-
690-
var (
691-
pass bool
692-
errFailedPred = errors.New("session failed predicate")
693-
)
694-
err := db.View(func(tx *bbolt.Tx) error {
695-
sessionBkt, err := getBucket(tx, sessionBucketKey)
696-
if err != nil {
697-
return err
698-
}
699-
700-
sessionIDs, err := getSessionIDs(sessionBkt, groupID)
701-
if err != nil {
702-
return err
703-
}
704-
705-
// Iterate over all the sessions.
706-
for _, id := range sessionIDs {
707-
key, err := getKeyForID(sessionBkt, id)
708-
if err != nil {
709-
return err
710-
}
711-
712-
v := sessionBkt.Get(key)
713-
if len(v) == 0 {
714-
return ErrSessionNotFound
715-
}
716-
717-
session, err := DeserializeSession(bytes.NewReader(v))
718-
if err != nil {
719-
return err
720-
}
721-
722-
if !fn(session) {
723-
return errFailedPred
724-
}
725-
}
726-
727-
pass = true
728-
729-
return nil
730-
})
731-
if errors.Is(err, errFailedPred) {
732-
return pass, nil
733-
}
734-
if err != nil {
735-
return pass, err
736-
}
737-
738-
return pass, nil
739-
}
740-
741680
// getSessionIDs returns all the session IDs associated with the given group ID.
742681
func getSessionIDs(sessionBkt *bbolt.Bucket, groupID ID) ([]ID, error) {
743682
var sessionIDs []ID

session/store_test.go

Lines changed: 0 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package session
22

33
import (
4-
"strings"
54
"testing"
65
"time"
76

@@ -291,97 +290,6 @@ func TestLinkedSessions(t *testing.T) {
291290
require.EqualValues(t, []ID{s4.ID, s5.ID}, sIDs)
292291
}
293292

294-
// TestCheckSessionGroupPredicate asserts that the CheckSessionGroupPredicate
295-
// method correctly checks if each session in a group passes a predicate.
296-
func TestCheckSessionGroupPredicate(t *testing.T) {
297-
t.Parallel()
298-
299-
// Set up a new DB.
300-
clock := clock.NewTestClock(testTime)
301-
db, err := NewDB(t.TempDir(), "test.db", clock)
302-
require.NoError(t, err)
303-
t.Cleanup(func() {
304-
_ = db.Close()
305-
})
306-
307-
// We will use the Label of the Session to test that the predicate
308-
// function is checked correctly.
309-
310-
// Add a new session to the DB.
311-
s1 := createSession(t, db, "label 1")
312-
313-
// Check that the group passes against an appropriate predicate.
314-
ok, err := db.CheckSessionGroupPredicate(
315-
s1.GroupID, func(s *Session) bool {
316-
return strings.Contains(s.Label, "label 1")
317-
},
318-
)
319-
require.NoError(t, err)
320-
require.True(t, ok)
321-
322-
// Check that the group fails against an appropriate predicate.
323-
ok, err = db.CheckSessionGroupPredicate(
324-
s1.GroupID, func(s *Session) bool {
325-
return strings.Contains(s.Label, "label 2")
326-
},
327-
)
328-
require.NoError(t, err)
329-
require.False(t, ok)
330-
331-
// Revoke the first session.
332-
require.NoError(t, db.ShiftState(s1.ID, StateRevoked))
333-
334-
// Add a new session to the same group as the first one.
335-
_ = createSession(t, db, "label 2", withLinkedGroupID(&s1.GroupID))
336-
337-
// Check that the group passes against an appropriate predicate.
338-
ok, err = db.CheckSessionGroupPredicate(
339-
s1.GroupID, func(s *Session) bool {
340-
return strings.Contains(s.Label, "label")
341-
},
342-
)
343-
require.NoError(t, err)
344-
require.True(t, ok)
345-
346-
// Check that the group fails against an appropriate predicate.
347-
ok, err = db.CheckSessionGroupPredicate(
348-
s1.GroupID, func(s *Session) bool {
349-
return strings.Contains(s.Label, "label 1")
350-
},
351-
)
352-
require.NoError(t, err)
353-
require.False(t, ok)
354-
355-
// Add a new session that is not linked to the first one.
356-
s3 := createSession(t, db, "completely different")
357-
358-
// Ensure that the first group is unaffected.
359-
ok, err = db.CheckSessionGroupPredicate(
360-
s1.GroupID, func(s *Session) bool {
361-
return strings.Contains(s.Label, "label")
362-
},
363-
)
364-
require.NoError(t, err)
365-
require.True(t, ok)
366-
367-
// And that the new session is evaluated separately.
368-
ok, err = db.CheckSessionGroupPredicate(
369-
s3.GroupID, func(s *Session) bool {
370-
return strings.Contains(s.Label, "label")
371-
},
372-
)
373-
require.NoError(t, err)
374-
require.False(t, ok)
375-
376-
ok, err = db.CheckSessionGroupPredicate(
377-
s3.GroupID, func(s *Session) bool {
378-
return strings.Contains(s.Label, "different")
379-
},
380-
)
381-
require.NoError(t, err)
382-
require.True(t, ok)
383-
}
384-
385293
// TestStateShift tests that the ShiftState method works as expected.
386294
func TestStateShift(t *testing.T) {
387295
// Set up a new DB.

session_rpcserver.go

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -874,23 +874,6 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
874874
"group %x", groupSess.ID, groupSess.GroupID)
875875
}
876876

877-
// Now we need to check that all the sessions in the group are
878-
// no longer active.
879-
ok, err := s.cfg.db.CheckSessionGroupPredicate(
880-
groupID, func(s *session.Session) bool {
881-
return s.State == session.StateRevoked ||
882-
s.State == session.StateExpired
883-
},
884-
)
885-
if err != nil {
886-
return nil, err
887-
}
888-
889-
if !ok {
890-
return nil, fmt.Errorf("a linked session in group "+
891-
"%x is still active", groupID)
892-
}
893-
894877
linkedGroupID = &groupID
895878
linkedGroupSession = groupSess
896879

0 commit comments

Comments
 (0)