Skip to content

Commit d3a2626

Browse files
committed
session: separate methods for creating vs updating a session
1 parent 92575c8 commit d3a2626

File tree

5 files changed

+156
-24
lines changed

5 files changed

+156
-24
lines changed

session/interface.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,9 @@ func NewSession(label string, typ Type, expiry time.Time, serverAddr string,
123123
// Store is the interface a persistent storage must implement for storing and
124124
// retrieving Terminal Connect sessions.
125125
type Store interface {
126-
// StoreSession stores a session in the store. If a session with the
127-
// same local public key already exists, the existing record is updated/
128-
// overwritten instead.
129-
StoreSession(*Session) error
126+
// CreateSession adds a new session to the store. If a session with the same
127+
// local public key already exists an error is returned.
128+
CreateSession(*Session) error
130129

131130
// GetSession fetches the session with the given key.
132131
GetSession(key *btcec.PublicKey) (*Session, error)
@@ -137,4 +136,9 @@ type Store interface {
137136
// RevokeSession updates the state of the session with the given local
138137
// public key to be revoked.
139138
RevokeSession(*btcec.PublicKey) error
139+
140+
// UpdateSessionRemotePubKey can be used to add the given remote pub key
141+
// to the session with the given local pub key.
142+
UpdateSessionRemotePubKey(localPubKey,
143+
remotePubKey *btcec.PublicKey) error
140144
}

session/server.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ func newMailboxSession() *mailboxSession {
3333

3434
func (m *mailboxSession) start(session *Session,
3535
serverCreator GRPCServerCreator, authData []byte,
36-
onUpdate func(sess *Session) error,
36+
onUpdate func(local, remote *btcec.PublicKey) error,
3737
onNewStatus func(s mailbox.ServerStatus)) error {
3838

3939
tlsConfig := &tls.Config{}
@@ -46,8 +46,7 @@ func (m *mailboxSession) start(session *Session,
4646
keys := mailbox.NewConnData(
4747
ecdh, session.RemotePublicKey, session.PairingSecret[:],
4848
authData, func(key *btcec.PublicKey) error {
49-
session.RemotePublicKey = key
50-
return onUpdate(session)
49+
return onUpdate(session.LocalPublicKey, key)
5150
}, nil,
5251
)
5352

@@ -105,7 +104,7 @@ func NewServer(serverCreator GRPCServerCreator) *Server {
105104
}
106105

107106
func (s *Server) StartSession(session *Session, authData []byte,
108-
onUpdate func(sess *Session) error,
107+
onUpdate func(local, remote *btcec.PublicKey) error,
109108
onNewStatus func(s mailbox.ServerStatus)) (chan struct{}, error) {
110109

111110
s.activeSessionsMtx.Lock()

session/store.go

Lines changed: 62 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package session
33
import (
44
"bytes"
55
"errors"
6+
"fmt"
67
"time"
78

89
"github.com/btcsuite/btcd/btcec/v2"
@@ -25,12 +26,11 @@ func getSessionKey(session *Session) []byte {
2526
return session.LocalPublicKey.SerializeCompressed()
2627
}
2728

28-
// StoreSession stores a session in the store. If a session with the
29-
// same local public key already exists, the existing record is updated/
30-
// overwritten instead.
29+
// CreateSession adds a new session to the store. If a session with the same
30+
// local public key already exists an error is returned.
3131
//
3232
// NOTE: this is part of the Store interface.
33-
func (db *DB) StoreSession(session *Session) error {
33+
func (db *DB) CreateSession(session *Session) error {
3434
var buf bytes.Buffer
3535
if err := SerializeSession(&buf, session); err != nil {
3636
return err
@@ -43,10 +43,55 @@ func (db *DB) StoreSession(session *Session) error {
4343
return err
4444
}
4545

46+
if len(sessionBucket.Get(sessionKey)) != 0 {
47+
return fmt.Errorf("session with local public "+
48+
"key(%x) already exists",
49+
session.LocalPublicKey.SerializeCompressed())
50+
}
51+
4652
return sessionBucket.Put(sessionKey, buf.Bytes())
4753
})
4854
}
4955

56+
// UpdateSessionRemotePubKey can be used to add the given remote pub key
57+
// to the session with the given local pub key.
58+
//
59+
// NOTE: this is part of the Store interface.
60+
func (db *DB) UpdateSessionRemotePubKey(localPubKey,
61+
remotePubKey *btcec.PublicKey) error {
62+
63+
key := localPubKey.SerializeCompressed()
64+
65+
return db.Update(func(tx *bbolt.Tx) error {
66+
sessionBucket, err := getBucket(tx, sessionBucketKey)
67+
if err != nil {
68+
return err
69+
}
70+
71+
serialisedSession := sessionBucket.Get(key)
72+
73+
if len(serialisedSession) == 0 {
74+
return ErrSessionNotFound
75+
}
76+
77+
session, err := DeserializeSession(
78+
bytes.NewReader(serialisedSession),
79+
)
80+
if err != nil {
81+
return err
82+
}
83+
84+
session.RemotePublicKey = remotePubKey
85+
86+
var buf bytes.Buffer
87+
if err := SerializeSession(&buf, session); err != nil {
88+
return err
89+
}
90+
91+
return sessionBucket.Put(key, buf.Bytes())
92+
})
93+
}
94+
5095
// GetSession fetches the session with the given key.
5196
//
5297
// NOTE: this is part of the Store interface.
@@ -122,7 +167,7 @@ func (db *DB) ListSessions(filterFn func(s *Session) bool) ([]*Session, error) {
122167
// NOTE: this is part of the Store interface.
123168
func (db *DB) RevokeSession(key *btcec.PublicKey) error {
124169
var session *Session
125-
err := db.View(func(tx *bbolt.Tx) error {
170+
return db.Update(func(tx *bbolt.Tx) error {
126171
sessionBucket, err := getBucket(tx, sessionBucketKey)
127172
if err != nil {
128173
return err
@@ -134,14 +179,18 @@ func (db *DB) RevokeSession(key *btcec.PublicKey) error {
134179
}
135180

136181
session, err = DeserializeSession(bytes.NewReader(sessionBytes))
137-
return err
138-
})
139-
if err != nil {
140-
return err
141-
}
182+
if err != nil {
183+
return err
184+
}
142185

143-
session.State = StateRevoked
144-
session.RevokedAt = time.Now()
186+
session.State = StateRevoked
187+
session.RevokedAt = time.Now()
145188

146-
return db.StoreSession(session)
189+
var buf bytes.Buffer
190+
if err := SerializeSession(&buf, session); err != nil {
191+
return err
192+
}
193+
194+
return sessionBucket.Put(key.SerializeCompressed(), buf.Bytes())
195+
})
147196
}

session/store_test.go

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
package session
2+
3+
import (
4+
"testing"
5+
"time"
6+
7+
"github.com/btcsuite/btcd/btcec/v2"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
// TestBasicSessionStore tests the basic getters and setters of the session
12+
// store.
13+
func TestBasicSessionStore(t *testing.T) {
14+
// Set up a new DB.
15+
db, err := NewDB(t.TempDir(), "test.db")
16+
require.NoError(t, err)
17+
t.Cleanup(func() {
18+
_ = db.Close()
19+
})
20+
21+
// Create a few sessions.
22+
s1 := newSession(t, "session 1")
23+
s2 := newSession(t, "session 2")
24+
s3 := newSession(t, "session 3")
25+
26+
// Persist session 1.
27+
require.NoError(t, db.CreateSession(s1))
28+
29+
// Trying to persist session 1 again should fail.
30+
require.ErrorContains(t, db.CreateSession(s1), "already exists")
31+
32+
// Persist a few more sessions.
33+
require.NoError(t, db.CreateSession(s2))
34+
require.NoError(t, db.CreateSession(s3))
35+
36+
// Ensure that we can retrieve each session.
37+
for _, s := range []*Session{s1, s2, s3} {
38+
session, err := db.GetSession(s.LocalPublicKey)
39+
require.NoError(t, err)
40+
require.Equal(t, s.Label, session.Label)
41+
}
42+
43+
// Fetch session 1 and assert that it currently has no remote pub key.
44+
session1, err := db.GetSession(s1.LocalPublicKey)
45+
require.NoError(t, err)
46+
require.Nil(t, session1.RemotePublicKey)
47+
48+
// Use the update method to add a remote key.
49+
remotePriv, err := btcec.NewPrivateKey()
50+
require.NoError(t, err)
51+
remotePub := remotePriv.PubKey()
52+
53+
err = db.UpdateSessionRemotePubKey(session1.LocalPublicKey, remotePub)
54+
require.NoError(t, err)
55+
56+
// Assert that the session now does have the remote pub key.
57+
session1, err = db.GetSession(s1.LocalPublicKey)
58+
require.NoError(t, err)
59+
require.True(t, remotePub.IsEqual(session1.RemotePublicKey))
60+
61+
// Check that the session's state is currently StateCreated.
62+
require.Equal(t, session1.State, StateCreated)
63+
64+
// Now revoke the session and assert that the state is revoked.
65+
require.NoError(t, db.RevokeSession(s1.LocalPublicKey))
66+
session1, err = db.GetSession(s1.LocalPublicKey)
67+
require.NoError(t, err)
68+
require.Equal(t, session1.State, StateRevoked)
69+
}
70+
71+
func newSession(t *testing.T, label string) *Session {
72+
session, err := NewSession(
73+
label, TypeMacaroonAdmin,
74+
time.Date(99999, 1, 1, 0, 0, 0, 0, time.UTC),
75+
"foo.bar.baz:1234", true, nil, nil, nil, true,
76+
)
77+
require.NoError(t, err)
78+
79+
return session
80+
}

session_rpcserver.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ func (s *sessionRpcServer) AddSession(_ context.Context,
323323
return nil, fmt.Errorf("error creating new session: %v", err)
324324
}
325325

326-
if err := s.db.StoreSession(sess); err != nil {
326+
if err := s.db.CreateSession(sess); err != nil {
327327
return nil, fmt.Errorf("error storing session: %v", err)
328328
}
329329

@@ -477,7 +477,7 @@ func (s *sessionRpcServer) resumeSession(sess *session.Session) error {
477477

478478
authData := []byte(fmt.Sprintf("%s: %s", HeaderMacaroon, mac))
479479
sessionClosedSub, err := s.sessionServer.StartSession(
480-
sess, authData, s.db.StoreSession, onNewStatus,
480+
sess, authData, s.db.UpdateSessionRemotePubKey, onNewStatus,
481481
)
482482
if err != nil {
483483
return err
@@ -1015,7 +1015,7 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
10151015
// We only persist this session if we successfully retrieved the
10161016
// autopilot's static key.
10171017
sess.RemotePublicKey = remoteKey
1018-
if err := s.db.StoreSession(sess); err != nil {
1018+
if err := s.db.CreateSession(sess); err != nil {
10191019
return nil, fmt.Errorf("error storing session: %v", err)
10201020
}
10211021

0 commit comments

Comments
 (0)