Skip to content

Commit 1ed4907

Browse files
committed
session: replace RevokeSession with ShiftState
1 parent 8f22fc9 commit 1ed4907

File tree

4 files changed

+27
-47
lines changed

4 files changed

+27
-47
lines changed

session/interface.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -215,10 +215,6 @@ type Store interface {
215215
// that are in the given states.
216216
ListSessionsByState(...State) ([]*Session, error)
217217

218-
// RevokeSession updates the state of the session with the given local
219-
// public key to be revoked.
220-
RevokeSession(*btcec.PublicKey) error
221-
222218
// UpdateSessionRemotePubKey can be used to add the given remote pub key
223219
// to the session with the given local pub key.
224220
UpdateSessionRemotePubKey(localPubKey,

session/kvdb_store.go

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -555,35 +555,6 @@ func (db *BoltStore) ShiftState(id ID, dest State) error {
555555
})
556556
}
557557

558-
// RevokeSession updates the state of the session with the given local
559-
// public key to be revoked.
560-
//
561-
// NOTE: this is part of the Store interface.
562-
func (db *BoltStore) RevokeSession(key *btcec.PublicKey) error {
563-
var session *Session
564-
return db.Update(func(tx *bbolt.Tx) error {
565-
sessionBucket, err := getBucket(tx, sessionBucketKey)
566-
if err != nil {
567-
return err
568-
}
569-
570-
sessionBytes := sessionBucket.Get(key.SerializeCompressed())
571-
if len(sessionBytes) == 0 {
572-
return ErrSessionNotFound
573-
}
574-
575-
session, err = DeserializeSession(bytes.NewReader(sessionBytes))
576-
if err != nil {
577-
return err
578-
}
579-
580-
session.State = StateRevoked
581-
session.RevokedAt = db.clock.Now().UTC()
582-
583-
return putSession(sessionBucket, session)
584-
})
585-
}
586-
587558
// GetSessionByID fetches the session with the given ID.
588559
//
589560
// NOTE: this is part of the Store interface.

session/store_test.go

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ func TestBasicSessionStore(t *testing.T) {
106106
require.Equal(t, session1.State, StateCreated)
107107

108108
// Now revoke the session and assert that the state is revoked.
109-
require.NoError(t, db.RevokeSession(s1.LocalPublicKey))
109+
require.NoError(t, db.ShiftState(s1.ID, StateRevoked))
110110
s1, err = db.GetSession(s1.LocalPublicKey)
111111
require.NoError(t, err)
112112
require.Equal(t, s1.State, StateRevoked)
@@ -225,7 +225,7 @@ func TestLinkingSessions(t *testing.T) {
225225
require.ErrorContains(t, db.CreateSession(s2), "is still active")
226226

227227
// Revoke the first session.
228-
require.NoError(t, db.RevokeSession(s1.LocalPublicKey))
228+
require.NoError(t, db.ShiftState(s1.ID, StateRevoked))
229229

230230
// Persisting the second linked session should now work.
231231
require.NoError(t, db.CreateSession(s2))
@@ -248,16 +248,20 @@ func TestLinkedSessions(t *testing.T) {
248248
// the same group. The group ID is equivalent to the session ID of the
249249
// first session.
250250
s1 := newSession(t, db, clock, "session 1")
251-
s2 := newSession(t, db, clock, "session 2", withLinkedGroupID(&s1.GroupID))
252-
s3 := newSession(t, db, clock, "session 3", withLinkedGroupID(&s2.GroupID))
251+
s2 := newSession(
252+
t, db, clock, "session 2", withLinkedGroupID(&s1.GroupID),
253+
)
254+
s3 := newSession(
255+
t, db, clock, "session 3", withLinkedGroupID(&s2.GroupID),
256+
)
253257

254258
// Persist the sessions.
255259
require.NoError(t, db.CreateSession(s1))
256260

257-
require.NoError(t, db.RevokeSession(s1.LocalPublicKey))
261+
require.NoError(t, db.ShiftState(s1.ID, StateRevoked))
258262
require.NoError(t, db.CreateSession(s2))
259263

260-
require.NoError(t, db.RevokeSession(s2.LocalPublicKey))
264+
require.NoError(t, db.ShiftState(s2.ID, StateRevoked))
261265
require.NoError(t, db.CreateSession(s3))
262266

263267
// Assert that the session ID to group ID index works as expected.
@@ -282,7 +286,7 @@ func TestLinkedSessions(t *testing.T) {
282286

283287
// Persist the sessions.
284288
require.NoError(t, db.CreateSession(s4))
285-
require.NoError(t, db.RevokeSession(s4.LocalPublicKey))
289+
require.NoError(t, db.ShiftState(s4.ID, StateRevoked))
286290

287291
require.NoError(t, db.CreateSession(s5))
288292

@@ -337,7 +341,7 @@ func TestCheckSessionGroupPredicate(t *testing.T) {
337341
require.False(t, ok)
338342

339343
// Revoke the first session.
340-
require.NoError(t, db.RevokeSession(s1.LocalPublicKey))
344+
require.NoError(t, db.ShiftState(s1.ID, StateRevoked))
341345

342346
// Add a new session to the same group as the first one.
343347
s2 := newSession(t, db, clock, "label 2", withLinkedGroupID(&s1.GroupID))

session_rpcserver.go

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,8 @@ func (s *sessionRpcServer) start(ctx context.Context) error {
154154
err)
155155

156156
if perm {
157-
err := s.cfg.db.RevokeSession(
158-
sess.LocalPublicKey,
157+
err := s.cfg.db.ShiftState(
158+
sess.ID, session.StateRevoked,
159159
)
160160
if err != nil {
161161
log.Errorf("error revoking "+
@@ -360,7 +360,8 @@ func (s *sessionRpcServer) resumeSession(ctx context.Context,
360360
log.Debugf("Not resuming session %x with expiry %s",
361361
pubKeyBytes, sess.Expiry)
362362

363-
if err := s.cfg.db.RevokeSession(pubKey); err != nil {
363+
err := s.cfg.db.ShiftState(sess.ID, session.StateRevoked)
364+
if err != nil {
364365
return fmt.Errorf("error revoking session: %v", err)
365366
}
366367

@@ -436,7 +437,9 @@ func (s *sessionRpcServer) resumeSession(ctx context.Context,
436437
log.Debugf("Deadline for session %x has already "+
437438
"passed. Revoking session", pubKeyBytes)
438439

439-
return s.cfg.db.RevokeSession(pubKey)
440+
return s.cfg.db.ShiftState(
441+
sess.ID, session.StateRevoked,
442+
)
440443
}
441444

442445
// Start the deadline timer.
@@ -515,7 +518,7 @@ func (s *sessionRpcServer) resumeSession(ctx context.Context,
515518
log.Debugf("Error stopping session: %v", err)
516519
}
517520

518-
err = s.cfg.db.RevokeSession(pubKey)
521+
err = s.cfg.db.ShiftState(sess.ID, session.StateRevoked)
519522
if err != nil {
520523
log.Debugf("error revoking session: %v", err)
521524
}
@@ -557,7 +560,13 @@ func (s *sessionRpcServer) RevokeSession(ctx context.Context,
557560
return nil, fmt.Errorf("error parsing public key: %v", err)
558561
}
559562

560-
if err := s.cfg.db.RevokeSession(pubKey); err != nil {
563+
sess, err := s.cfg.db.GetSession(pubKey)
564+
if err != nil {
565+
return nil, fmt.Errorf("error fetching session: %v", err)
566+
}
567+
568+
err = s.cfg.db.ShiftState(sess.ID, session.StateRevoked)
569+
if err != nil {
561570
return nil, fmt.Errorf("error revoking session: %v", err)
562571
}
563572

0 commit comments

Comments
 (0)