Skip to content

Commit 73110b6

Browse files
committed
multi: give firewallDB access to session ID index
1 parent 7c7e467 commit 73110b6

File tree

6 files changed

+45
-36
lines changed

6 files changed

+45
-36
lines changed

firewalldb/actions_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import (
1212
func TestActionStorage(t *testing.T) {
1313
tmpDir := t.TempDir()
1414

15-
db, err := NewDB(tmpDir, "test.db")
15+
db, err := NewDB(tmpDir, "test.db", nil)
1616
require.NoError(t, err)
1717
t.Cleanup(func() {
1818
_ = db.Close()
@@ -147,7 +147,7 @@ func TestActionStorage(t *testing.T) {
147147
func TestListActions(t *testing.T) {
148148
tmpDir := t.TempDir()
149149

150-
db, err := NewDB(tmpDir, "test.db")
150+
db, err := NewDB(tmpDir, "test.db", nil)
151151
require.NoError(t, err)
152152
t.Cleanup(func() {
153153
_ = db.Close()

firewalldb/db.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"path/filepath"
99
"time"
1010

11+
"github.com/lightninglabs/lightning-terminal/session"
1112
"go.etcd.io/bbolt"
1213
)
1314

@@ -40,10 +41,14 @@ var (
4041
// DB is a bolt-backed persistent store.
4142
type DB struct {
4243
*bbolt.DB
44+
45+
sessionIDIndex session.IDToGroupIndex
4346
}
4447

4548
// NewDB creates a new bolt database that can be found at the given directory.
46-
func NewDB(dir, fileName string) (*DB, error) {
49+
func NewDB(dir, fileName string, sessionIDIndex session.IDToGroupIndex) (*DB,
50+
error) {
51+
4752
firstInit := false
4853
path := filepath.Join(dir, fileName)
4954

@@ -66,7 +71,10 @@ func NewDB(dir, fileName string) (*DB, error) {
6671
return nil, err
6772
}
6873

69-
return &DB{DB: db}, nil
74+
return &DB{
75+
DB: db,
76+
sessionIDIndex: sessionIDIndex,
77+
}, nil
7078
}
7179

7280
// fileExists reports whether the named file or directory exists.

firewalldb/kvstores_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ func TestKVStoreTxs(t *testing.T) {
1818
ctx := context.Background()
1919
tmpDir := t.TempDir()
2020

21-
db, err := NewDB(tmpDir, "test.db")
21+
db, err := NewDB(tmpDir, "test.db", nil)
2222
require.NoError(t, err)
2323
t.Cleanup(func() {
2424
_ = db.Close()
@@ -65,7 +65,7 @@ func TestTempAndPermStores(t *testing.T) {
6565
ctx := context.Background()
6666
tmpDir := t.TempDir()
6767

68-
db, err := NewDB(tmpDir, "test.db")
68+
db, err := NewDB(tmpDir, "test.db", nil)
6969
require.NoError(t, err)
7070
t.Cleanup(func() {
7171
_ = db.Close()
@@ -113,7 +113,7 @@ func TestTempAndPermStores(t *testing.T) {
113113
require.NoError(t, db.Close())
114114

115115
// Restart it.
116-
db, err = NewDB(tmpDir, "test.db")
116+
db, err = NewDB(tmpDir, "test.db", nil)
117117
require.NoError(t, err)
118118
t.Cleanup(func() {
119119
_ = db.Close()
@@ -147,7 +147,7 @@ func TestKVStoreNameSpaces(t *testing.T) {
147147
ctx := context.Background()
148148
tmpDir := t.TempDir()
149149

150-
db, err := NewDB(tmpDir, "test.db")
150+
db, err := NewDB(tmpDir, "test.db", nil)
151151
require.NoError(t, err)
152152
t.Cleanup(func() {
153153
_ = db.Close()

firewalldb/privacy_mapper_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import (
1010
// TestPrivacyMapStorage tests the privacy mapper CRUD logic.
1111
func TestPrivacyMapStorage(t *testing.T) {
1212
tmpDir := t.TempDir()
13-
db, err := NewDB(tmpDir, "test.db")
13+
db, err := NewDB(tmpDir, "test.db", nil)
1414
require.NoError(t, err)
1515
t.Cleanup(func() {
1616
_ = db.Close()
@@ -68,7 +68,7 @@ func TestPrivacyMapStorage(t *testing.T) {
6868
// `Update` function, then all the changes prior should be rolled back.
6969
func TestPrivacyMapTxs(t *testing.T) {
7070
tmpDir := t.TempDir()
71-
db, err := NewDB(tmpDir, "test.db")
71+
db, err := NewDB(tmpDir, "test.db", nil)
7272
require.NoError(t, err)
7373
t.Cleanup(func() {
7474
_ = db.Close()

session_rpcserver.go

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ type sessionRpcServer struct {
4141
litrpc.UnimplementedAutopilotServer
4242

4343
cfg *sessionRpcServerConfig
44-
db *session.DB
4544
sessionServer *session.Server
4645

4746
// sessRegMu is a mutex that should be held between acquiring an unused
@@ -57,8 +56,8 @@ type sessionRpcServer struct {
5756
// sessionRpcServerConfig holds the values used to configure the
5857
// sessionRpcServer.
5958
type sessionRpcServerConfig struct {
59+
db *session.DB
6060
basicAuth string
61-
dbDir string
6261
grpcOptions []grpc.ServerOption
6362
registerGrpcServers func(server *grpc.Server)
6463
superMacBaker session.MacaroonBaker
@@ -74,12 +73,6 @@ type sessionRpcServerConfig struct {
7473
func newSessionRPCServer(cfg *sessionRpcServerConfig) (*sessionRpcServer,
7574
error) {
7675

77-
// Create an instance of the local Terminal Connect session store DB.
78-
db, err := session.NewDB(cfg.dbDir, session.DBFilename)
79-
if err != nil {
80-
return nil, fmt.Errorf("error creating session DB: %v", err)
81-
}
82-
8376
// Create the gRPC server that handles adding/removing sessions and the
8477
// actual mailbox server that spins up the Terminal Connect server
8578
// interface.
@@ -96,7 +89,6 @@ func newSessionRPCServer(cfg *sessionRpcServerConfig) (*sessionRpcServer,
9689

9790
return &sessionRpcServer{
9891
cfg: cfg,
99-
db: db,
10092
sessionServer: server,
10193
quit: make(chan struct{}),
10294
}, nil
@@ -106,7 +98,7 @@ func newSessionRPCServer(cfg *sessionRpcServerConfig) (*sessionRpcServer,
10698
// requests. This includes resuming all non-revoked sessions.
10799
func (s *sessionRpcServer) start() error {
108100
// Start up all previously created sessions.
109-
sessions, err := s.db.ListSessions(nil)
101+
sessions, err := s.cfg.db.ListSessions(nil)
110102
if err != nil {
111103
return fmt.Errorf("error listing sessions: %v", err)
112104
}
@@ -157,7 +149,7 @@ func (s *sessionRpcServer) start() error {
157149
err)
158150

159151
if perm {
160-
err := s.db.RevokeSession(
152+
err := s.cfg.db.RevokeSession(
161153
sess.LocalPublicKey,
162154
)
163155
if err != nil {
@@ -182,7 +174,7 @@ func (s *sessionRpcServer) start() error {
182174
func (s *sessionRpcServer) stop() error {
183175
var returnErr error
184176
s.stopOnce.Do(func() {
185-
if err := s.db.Close(); err != nil {
177+
if err := s.cfg.db.Close(); err != nil {
186178
log.Errorf("Error closing session DB: %v", err)
187179
returnErr = err
188180
}
@@ -323,7 +315,7 @@ func (s *sessionRpcServer) AddSession(_ context.Context,
323315
s.sessRegMu.Lock()
324316
defer s.sessRegMu.Unlock()
325317

326-
id, localPrivKey, err := s.db.GetUnusedIDAndKeyPair()
318+
id, localPrivKey, err := s.cfg.db.GetUnusedIDAndKeyPair()
327319
if err != nil {
328320
return nil, err
329321
}
@@ -336,7 +328,7 @@ func (s *sessionRpcServer) AddSession(_ context.Context,
336328
return nil, fmt.Errorf("error creating new session: %v", err)
337329
}
338330

339-
if err := s.db.CreateSession(sess); err != nil {
331+
if err := s.cfg.db.CreateSession(sess); err != nil {
340332
return nil, fmt.Errorf("error storing session: %v", err)
341333
}
342334

@@ -375,7 +367,7 @@ func (s *sessionRpcServer) resumeSession(sess *session.Session) error {
375367
log.Debugf("Not resuming session %x with expiry %s",
376368
pubKeyBytes, sess.Expiry)
377369

378-
if err := s.db.RevokeSession(pubKey); err != nil {
370+
if err := s.cfg.db.RevokeSession(pubKey); err != nil {
379371
return fmt.Errorf("error revoking session: %v", err)
380372
}
381373

@@ -455,7 +447,7 @@ func (s *sessionRpcServer) resumeSession(sess *session.Session) error {
455447
log.Debugf("Deadline for session %x has already "+
456448
"passed. Revoking session", pubKeyBytes)
457449

458-
return s.db.RevokeSession(pubKey)
450+
return s.cfg.db.RevokeSession(pubKey)
459451
}
460452

461453
// Start the deadline timer.
@@ -490,7 +482,7 @@ func (s *sessionRpcServer) resumeSession(sess *session.Session) error {
490482

491483
authData := []byte(fmt.Sprintf("%s: %s", HeaderMacaroon, mac))
492484
sessionClosedSub, err := s.sessionServer.StartSession(
493-
sess, authData, s.db.UpdateSessionRemotePubKey, onNewStatus,
485+
sess, authData, s.cfg.db.UpdateSessionRemotePubKey, onNewStatus,
494486
)
495487
if err != nil {
496488
return err
@@ -535,7 +527,7 @@ func (s *sessionRpcServer) resumeSession(sess *session.Session) error {
535527
log.Debugf("Error stopping session: %v", err)
536528
}
537529

538-
err = s.db.RevokeSession(pubKey)
530+
err = s.cfg.db.RevokeSession(pubKey)
539531
if err != nil {
540532
log.Debugf("error revoking session: %v", err)
541533
}
@@ -548,7 +540,7 @@ func (s *sessionRpcServer) resumeSession(sess *session.Session) error {
548540
func (s *sessionRpcServer) ListSessions(_ context.Context,
549541
_ *litrpc.ListSessionsRequest) (*litrpc.ListSessionsResponse, error) {
550542

551-
sessions, err := s.db.ListSessions(nil)
543+
sessions, err := s.cfg.db.ListSessions(nil)
552544
if err != nil {
553545
return nil, fmt.Errorf("error fetching sessions: %v", err)
554546
}
@@ -577,7 +569,7 @@ func (s *sessionRpcServer) RevokeSession(ctx context.Context,
577569
return nil, fmt.Errorf("error parsing public key: %v", err)
578570
}
579571

580-
if err := s.db.RevokeSession(pubKey); err != nil {
572+
if err := s.cfg.db.RevokeSession(pubKey); err != nil {
581573
return nil, fmt.Errorf("error revoking session: %v", err)
582574
}
583575

@@ -995,7 +987,7 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
995987
s.sessRegMu.Lock()
996988
defer s.sessRegMu.Unlock()
997989

998-
id, localPrivKey, err := s.db.GetUnusedIDAndKeyPair()
990+
id, localPrivKey, err := s.cfg.db.GetUnusedIDAndKeyPair()
999991
if err != nil {
1000992
return nil, err
1001993
}
@@ -1037,7 +1029,7 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
10371029
// We only persist this session if we successfully retrieved the
10381030
// autopilot's static key.
10391031
sess.RemotePublicKey = remoteKey
1040-
if err := s.db.CreateSession(sess); err != nil {
1032+
if err := s.cfg.db.CreateSession(sess); err != nil {
10411033
return nil, fmt.Errorf("error storing session: %v", err)
10421034
}
10431035

@@ -1061,7 +1053,7 @@ func (s *sessionRpcServer) ListAutopilotSessions(_ context.Context,
10611053
_ *litrpc.ListAutopilotSessionsRequest) (
10621054
*litrpc.ListAutopilotSessionsResponse, error) {
10631055

1064-
sessions, err := s.db.ListSessions(func(s *session.Session) bool {
1056+
sessions, err := s.cfg.db.ListSessions(func(s *session.Session) bool {
10651057
return s.Type == session.TypeAutopilot
10661058
})
10671059
if err != nil {
@@ -1092,7 +1084,7 @@ func (s *sessionRpcServer) RevokeAutopilotSession(ctx context.Context,
10921084
return nil, fmt.Errorf("error parsing public key: %v", err)
10931085
}
10941086

1095-
sess, err := s.db.GetSession(pubKey)
1087+
sess, err := s.cfg.db.GetSession(pubKey)
10961088
if err != nil {
10971089
return nil, err
10981090
}

terminal.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ type LightningTerminal struct {
184184
accountRpcServer *accounts.RPCServer
185185

186186
firewallDB *firewalldb.DB
187+
sessionDB *session.DB
187188

188189
restHandler http.Handler
189190
restCancel func()
@@ -317,12 +318,20 @@ func (g *LightningTerminal) start() error {
317318

318319
g.ruleMgrs = rules.NewRuleManagerSet()
319320

321+
// Create an instance of the local Terminal Connect session store DB.
320322
networkDir := filepath.Join(g.cfg.LitDir, g.cfg.Network)
321-
g.firewallDB, err = firewalldb.NewDB(networkDir, firewalldb.DBFilename)
323+
g.sessionDB, err = session.NewDB(networkDir, session.DBFilename)
322324
if err != nil {
323325
return fmt.Errorf("error creating session DB: %v", err)
324326
}
325327

328+
g.firewallDB, err = firewalldb.NewDB(
329+
networkDir, firewalldb.DBFilename, g.sessionDB,
330+
)
331+
if err != nil {
332+
return fmt.Errorf("error creating firewall DB: %v", err)
333+
}
334+
326335
if !g.cfg.Autopilot.Disable {
327336
if g.cfg.Autopilot.Address == "" &&
328337
len(g.cfg.Autopilot.DialOpts) == 0 {
@@ -353,8 +362,8 @@ func (g *LightningTerminal) start() error {
353362
}
354363

355364
g.sessionRpcServer, err = newSessionRPCServer(&sessionRpcServerConfig{
365+
db: g.sessionDB,
356366
basicAuth: g.rpcProxy.basicAuth,
357-
dbDir: filepath.Join(g.cfg.LitDir, g.cfg.Network),
358367
grpcOptions: []grpc.ServerOption{
359368
grpc.CustomCodec(grpcProxy.Codec()), // nolint: staticcheck,
360369
grpc.ChainStreamInterceptor(

0 commit comments

Comments
 (0)