From 3c4277857cd51c100c9790a5f30d93ad30faa3be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?= Date: Mon, 9 Jun 2025 12:13:09 +0200 Subject: [PATCH 01/11] accounts: ensure that test SQL store is closed We add a helper function to the functions that creates the test SQL stores, in order to ensure that the store is properly closed when the test is cleaned up. --- accounts/sql_migration_test.go | 3 --- accounts/test_postgres.go | 4 ++-- accounts/test_sql.go | 22 ++++++++++++++++++++++ accounts/test_sqlite.go | 6 +++--- 4 files changed, 27 insertions(+), 8 deletions(-) create mode 100644 accounts/test_sql.go diff --git a/accounts/sql_migration_test.go b/accounts/sql_migration_test.go index d1e331e42..621c947a5 100644 --- a/accounts/sql_migration_test.go +++ b/accounts/sql_migration_test.go @@ -39,9 +39,6 @@ func TestAccountStoreMigration(t *testing.T) { *db.TransactionExecutor[SQLQueries]) { testDBStore := NewTestDB(t, clock) - t.Cleanup(func() { - require.NoError(t, testDBStore.Close()) - }) store, ok := testDBStore.(*SQLStore) require.True(t, ok) diff --git a/accounts/test_postgres.go b/accounts/test_postgres.go index 609eeb608..16665030d 100644 --- a/accounts/test_postgres.go +++ b/accounts/test_postgres.go @@ -16,7 +16,7 @@ var ErrDBClosed = errors.New("database is closed") // NewTestDB is a helper function that creates an SQLStore database for testing. func NewTestDB(t *testing.T, clock clock.Clock) Store { - return NewSQLStore(db.NewTestPostgresDB(t).BaseDB, clock) + return createStore(t, db.NewTestPostgresDB(t).BaseDB, clock) } // NewTestDBFromPath is a helper function that creates a new SQLStore with a @@ -24,5 +24,5 @@ func NewTestDB(t *testing.T, clock clock.Clock) Store { func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) Store { - return NewSQLStore(db.NewTestPostgresDB(t).BaseDB, clock) + return createStore(t, db.NewTestPostgresDB(t).BaseDB, clock) } diff --git a/accounts/test_sql.go b/accounts/test_sql.go new file mode 100644 index 000000000..3c1ee7f16 --- /dev/null +++ b/accounts/test_sql.go @@ -0,0 +1,22 @@ +//go:build test_db_postgres || test_db_sqlite + +package accounts + +import ( + "testing" + + "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightningnetwork/lnd/clock" + "github.com/stretchr/testify/require" +) + +// createStore is a helper function that creates a new SQLStore and ensure that +// it is closed when during the test cleanup. +func createStore(t *testing.T, sqlDB *db.BaseDB, clock clock.Clock) *SQLStore { + store := NewSQLStore(sqlDB, clock) + t.Cleanup(func() { + require.NoError(t, store.Close()) + }) + + return store +} diff --git a/accounts/test_sqlite.go b/accounts/test_sqlite.go index 0dd042a28..9d899b3e2 100644 --- a/accounts/test_sqlite.go +++ b/accounts/test_sqlite.go @@ -16,7 +16,7 @@ var ErrDBClosed = errors.New("database is closed") // NewTestDB is a helper function that creates an SQLStore database for testing. func NewTestDB(t *testing.T, clock clock.Clock) Store { - return NewSQLStore(db.NewTestSqliteDB(t).BaseDB, clock) + return createStore(t, db.NewTestSqliteDB(t).BaseDB, clock) } // NewTestDBFromPath is a helper function that creates a new SQLStore with a @@ -24,7 +24,7 @@ func NewTestDB(t *testing.T, clock clock.Clock) Store { func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) Store { - return NewSQLStore( - db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock, + return createStore( + t, db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock, ) } From a7ecf0a58e537dc1f1c3cd3b47990eacaeb6ebb3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?= Date: Wed, 30 Apr 2025 00:12:39 +0700 Subject: [PATCH 02/11] session: update NewTestDB funcs to return Store In preparation for upcoming migration tests from a kvdb to an SQL store, this commit updates the NewTestDB function to return the Store interface rather than a concrete store implementation. This change ensures that migration tests can call NewTestDB under any build tag while receiving a consistent return type. --- session/test_kvdb.go | 8 ++++---- session/test_postgres.go | 4 ++-- session/test_sql.go | 2 +- session/test_sqlite.go | 4 ++-- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/session/test_kvdb.go b/session/test_kvdb.go index 241448410..cc939159d 100644 --- a/session/test_kvdb.go +++ b/session/test_kvdb.go @@ -11,14 +11,14 @@ import ( ) // NewTestDB is a helper function that creates an BBolt database for testing. -func NewTestDB(t *testing.T, clock clock.Clock) *BoltStore { +func NewTestDB(t *testing.T, clock clock.Clock) Store { return NewTestDBFromPath(t, t.TempDir(), clock) } // NewTestDBFromPath is a helper function that creates a new BoltStore with a // connection to an existing BBolt database for testing. func NewTestDBFromPath(t *testing.T, dbPath string, - clock clock.Clock) *BoltStore { + clock clock.Clock) Store { acctStore := accounts.NewTestDB(t, clock) @@ -28,13 +28,13 @@ func NewTestDBFromPath(t *testing.T, dbPath string, // NewTestDBWithAccounts creates a new test session Store with access to an // existing accounts DB. func NewTestDBWithAccounts(t *testing.T, clock clock.Clock, - acctStore accounts.Store) *BoltStore { + acctStore accounts.Store) Store { return newDBFromPathWithAccounts(t, clock, t.TempDir(), acctStore) } func newDBFromPathWithAccounts(t *testing.T, clock clock.Clock, dbPath string, - acctStore accounts.Store) *BoltStore { + acctStore accounts.Store) Store { store, err := NewDB(dbPath, DBFilename, clock, acctStore) require.NoError(t, err) diff --git a/session/test_postgres.go b/session/test_postgres.go index db392fe7f..decf3bcc2 100644 --- a/session/test_postgres.go +++ b/session/test_postgres.go @@ -15,14 +15,14 @@ import ( var ErrDBClosed = errors.New("database is closed") // NewTestDB is a helper function that creates an SQLStore database for testing. -func NewTestDB(t *testing.T, clock clock.Clock) *SQLStore { +func NewTestDB(t *testing.T, clock clock.Clock) Store { return NewSQLStore(db.NewTestPostgresDB(t).BaseDB, clock) } // NewTestDBFromPath is a helper function that creates a new SQLStore with a // connection to an existing postgres database for testing. func NewTestDBFromPath(t *testing.T, dbPath string, - clock clock.Clock) *SQLStore { + clock clock.Clock) Store { return NewSQLStore(db.NewTestPostgresDB(t).BaseDB, clock) } diff --git a/session/test_sql.go b/session/test_sql.go index ab4b32a6c..ceb02194c 100644 --- a/session/test_sql.go +++ b/session/test_sql.go @@ -11,7 +11,7 @@ import ( ) func NewTestDBWithAccounts(t *testing.T, clock clock.Clock, - acctStore accounts.Store) *SQLStore { + acctStore accounts.Store) Store { accounts, ok := acctStore.(*accounts.SQLStore) require.True(t, ok) diff --git a/session/test_sqlite.go b/session/test_sqlite.go index 87519f4f1..dccbefe85 100644 --- a/session/test_sqlite.go +++ b/session/test_sqlite.go @@ -15,14 +15,14 @@ import ( var ErrDBClosed = errors.New("database is closed") // NewTestDB is a helper function that creates an SQLStore database for testing. -func NewTestDB(t *testing.T, clock clock.Clock) *SQLStore { +func NewTestDB(t *testing.T, clock clock.Clock) Store { return NewSQLStore(db.NewTestSqliteDB(t).BaseDB, clock) } // NewTestDBFromPath is a helper function that creates a new SQLStore with a // connection to an existing sqlite database for testing. func NewTestDBFromPath(t *testing.T, dbPath string, - clock clock.Clock) *SQLStore { + clock clock.Clock) Store { return NewSQLStore( db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock, From 0b70477dfa92c5daf164f2edc5b53d79c3c871c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?= Date: Sun, 8 Jun 2025 13:31:03 +0200 Subject: [PATCH 03/11] session: ensure that test SQL store is closed We add a helper function to the functions that creates the test SQL stores, in order to ensure that the store is properly closed when the test is cleaned up. --- session/test_postgres.go | 4 ++-- session/test_sql.go | 14 +++++++++++++- session/test_sqlite.go | 6 +++--- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/session/test_postgres.go b/session/test_postgres.go index decf3bcc2..cb5aa061d 100644 --- a/session/test_postgres.go +++ b/session/test_postgres.go @@ -16,7 +16,7 @@ var ErrDBClosed = errors.New("database is closed") // NewTestDB is a helper function that creates an SQLStore database for testing. func NewTestDB(t *testing.T, clock clock.Clock) Store { - return NewSQLStore(db.NewTestPostgresDB(t).BaseDB, clock) + return createStore(t, db.NewTestPostgresDB(t).BaseDB, clock) } // NewTestDBFromPath is a helper function that creates a new SQLStore with a @@ -24,5 +24,5 @@ func NewTestDB(t *testing.T, clock clock.Clock) Store { func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) Store { - return NewSQLStore(db.NewTestPostgresDB(t).BaseDB, clock) + return createStore(t, db.NewTestPostgresDB(t).BaseDB, clock) } diff --git a/session/test_sql.go b/session/test_sql.go index ceb02194c..a83186069 100644 --- a/session/test_sql.go +++ b/session/test_sql.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/lightninglabs/lightning-terminal/accounts" + "github.com/lightninglabs/lightning-terminal/db" "github.com/lightningnetwork/lnd/clock" "github.com/stretchr/testify/require" ) @@ -16,5 +17,16 @@ func NewTestDBWithAccounts(t *testing.T, clock clock.Clock, accounts, ok := acctStore.(*accounts.SQLStore) require.True(t, ok) - return NewSQLStore(accounts.BaseDB, clock) + return createStore(t, accounts.BaseDB, clock) +} + +// createStore is a helper function that creates a new SQLStore and ensure that +// it is closed when during the test cleanup. +func createStore(t *testing.T, sqlDB *db.BaseDB, clock clock.Clock) *SQLStore { + store := NewSQLStore(sqlDB, clock) + t.Cleanup(func() { + require.NoError(t, store.Close()) + }) + + return store } diff --git a/session/test_sqlite.go b/session/test_sqlite.go index dccbefe85..0ceb0e046 100644 --- a/session/test_sqlite.go +++ b/session/test_sqlite.go @@ -16,7 +16,7 @@ var ErrDBClosed = errors.New("database is closed") // NewTestDB is a helper function that creates an SQLStore database for testing. func NewTestDB(t *testing.T, clock clock.Clock) Store { - return NewSQLStore(db.NewTestSqliteDB(t).BaseDB, clock) + return createStore(t, db.NewTestSqliteDB(t).BaseDB, clock) } // NewTestDBFromPath is a helper function that creates a new SQLStore with a @@ -24,7 +24,7 @@ func NewTestDB(t *testing.T, clock clock.Clock) Store { func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) Store { - return NewSQLStore( - db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock, + return createStore( + t, db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock, ) } From dd4acca4bb4e92a2269bc162a4815270059b4d41 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?= Date: Mon, 9 Jun 2025 11:31:19 +0200 Subject: [PATCH 04/11] firewalldb: ensure that test SQL store is closed We add a helper function to the functions that creates the test SQL stores, in order to ensure that the store is properly closed when the test is cleaned up. --- firewalldb/actions_test.go | 9 --------- firewalldb/test_postgres.go | 4 ++-- firewalldb/test_sql.go | 16 ++++++++++++++-- firewalldb/test_sqlite.go | 6 +++--- 4 files changed, 19 insertions(+), 16 deletions(-) diff --git a/firewalldb/actions_test.go b/firewalldb/actions_test.go index c27e53e96..69990c1da 100644 --- a/firewalldb/actions_test.go +++ b/firewalldb/actions_test.go @@ -28,9 +28,6 @@ func TestActionStorage(t *testing.T) { sessDB := session.NewTestDBWithAccounts(t, clock, accountsDB) db := NewTestDBWithSessionsAndAccounts(t, sessDB, accountsDB, clock) - t.Cleanup(func() { - _ = db.Close() - }) // Assert that attempting to add an action for a session that does not // exist returns an error. @@ -198,9 +195,6 @@ func TestListActions(t *testing.T) { sessDB := session.NewTestDB(t, clock) db := NewTestDBWithSessions(t, sessDB, clock) - t.Cleanup(func() { - _ = db.Close() - }) // Add 2 sessions that we can reference. sess1, err := sessDB.NewSession( @@ -466,9 +460,6 @@ func TestListGroupActions(t *testing.T) { } db := NewTestDBWithSessions(t, sessDB, clock) - t.Cleanup(func() { - _ = db.Close() - }) // There should not be any actions in group 1 yet. al, _, _, err := db.ListActions(ctx, nil, WithActionGroupID(group1)) diff --git a/firewalldb/test_postgres.go b/firewalldb/test_postgres.go index f5777e4cb..324aea2c4 100644 --- a/firewalldb/test_postgres.go +++ b/firewalldb/test_postgres.go @@ -11,11 +11,11 @@ import ( // NewTestDB is a helper function that creates an BBolt database for testing. func NewTestDB(t *testing.T, clock clock.Clock) *SQLDB { - return NewSQLDB(db.NewTestPostgresDB(t).BaseDB, clock) + return createStore(t, db.NewTestPostgresDB(t).BaseDB, clock) } // NewTestDBFromPath is a helper function that creates a new BoltStore with a // connection to an existing BBolt database for testing. func NewTestDBFromPath(t *testing.T, _ string, clock clock.Clock) *SQLDB { - return NewSQLDB(db.NewTestPostgresDB(t).BaseDB, clock) + return createStore(t, db.NewTestPostgresDB(t).BaseDB, clock) } diff --git a/firewalldb/test_sql.go b/firewalldb/test_sql.go index 03dcfbebf..2f6c6e62e 100644 --- a/firewalldb/test_sql.go +++ b/firewalldb/test_sql.go @@ -7,6 +7,7 @@ import ( "time" "github.com/lightninglabs/lightning-terminal/accounts" + "github.com/lightninglabs/lightning-terminal/db" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/clock" "github.com/stretchr/testify/require" @@ -20,7 +21,7 @@ func NewTestDBWithSessions(t *testing.T, sessionStore session.Store, sessions, ok := sessionStore.(*session.SQLStore) require.True(t, ok) - return NewSQLDB(sessions.BaseDB, clock) + return createStore(t, sessions.BaseDB, clock) } // NewTestDBWithSessionsAndAccounts creates a new test SQLDB Store with access @@ -36,7 +37,7 @@ func NewTestDBWithSessionsAndAccounts(t *testing.T, sessionStore SessionDB, require.Equal(t, accounts.BaseDB, sessions.BaseDB) - return NewSQLDB(sessions.BaseDB, clock) + return createStore(t, sessions.BaseDB, clock) } func assertEqualActions(t *testing.T, expected, got *Action) { @@ -52,3 +53,14 @@ func assertEqualActions(t *testing.T, expected, got *Action) { expected.AttemptedAt = expectedAttemptedAt got.AttemptedAt = actualAttemptedAt } + +// createStore is a helper function that creates a new SQLDB and ensure that +// it is closed when during the test cleanup. +func createStore(t *testing.T, sqlDB *db.BaseDB, clock clock.Clock) *SQLDB { + store := NewSQLDB(sqlDB, clock) + t.Cleanup(func() { + require.NoError(t, store.Close()) + }) + + return store +} diff --git a/firewalldb/test_sqlite.go b/firewalldb/test_sqlite.go index 5496cb205..506b49bcd 100644 --- a/firewalldb/test_sqlite.go +++ b/firewalldb/test_sqlite.go @@ -11,13 +11,13 @@ import ( // NewTestDB is a helper function that creates an BBolt database for testing. func NewTestDB(t *testing.T, clock clock.Clock) *SQLDB { - return NewSQLDB(db.NewTestSqliteDB(t).BaseDB, clock) + return createStore(t, db.NewTestSqliteDB(t).BaseDB, clock) } // NewTestDBFromPath is a helper function that creates a new BoltStore with a // connection to an existing BBolt database for testing. func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) *SQLDB { - return NewSQLDB( - db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock, + return createStore( + t, db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock, ) } From 8a3ae240f7af2548b2e23b42036c352d131b824f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?= Date: Tue, 27 May 2025 16:52:50 +0200 Subject: [PATCH 05/11] firewalldb: export FirewallDBs interface In the upcoming migration of the firewall database to SQL, the helper functions that creates the test databases of different types, need to return a unified interface in order to not have to control the migration tests file by build tags. Therefore, we export the unified interface FirewallDBs, so that it can be returned public test DB creation functions --- firewalldb/db.go | 14 +++----------- firewalldb/interface.go | 8 ++++++++ 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/firewalldb/db.go b/firewalldb/db.go index b8d9ed06f..a8349a538 100644 --- a/firewalldb/db.go +++ b/firewalldb/db.go @@ -14,29 +14,21 @@ var ( ErrNoSuchKeyFound = fmt.Errorf("no such key found") ) -// firewallDBs is an interface that groups the RulesDB and PrivacyMapper -// interfaces. -type firewallDBs interface { - RulesDB - PrivacyMapper - ActionDB -} - // DB manages the firewall rules database. type DB struct { started sync.Once stopped sync.Once - firewallDBs + FirewallDBs cancel fn.Option[context.CancelFunc] } // NewDB creates a new firewall database. For now, it only contains the // underlying rules' and privacy mapper databases. -func NewDB(dbs firewallDBs) *DB { +func NewDB(dbs FirewallDBs) *DB { return &DB{ - firewallDBs: dbs, + FirewallDBs: dbs, } } diff --git a/firewalldb/interface.go b/firewalldb/interface.go index 5ee729e91..c2955bdc6 100644 --- a/firewalldb/interface.go +++ b/firewalldb/interface.go @@ -134,3 +134,11 @@ type ActionDB interface { // and feature name. GetActionsReadDB(groupID session.ID, featureName string) ActionsReadDB } + +// FirewallDBs is an interface that groups the RulesDB, PrivacyMapper and +// ActionDB interfaces. +type FirewallDBs interface { + RulesDB + PrivacyMapper + ActionDB +} From f9d48dfe7e66f78a7a6b7918a99da5df63b65f07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?= Date: Mon, 19 May 2025 13:58:38 +0200 Subject: [PATCH 06/11] firewalldb: update NewTestDB funcs to return FirewallDBs In the upcoming migration of the firewall database to SQL, the helper functions that creates the test databases of different types, need to return a unified interface in order to not have to control the migration tests file by build tags. Therefore, we update the `NewTestDB` functions to return the `FirewallDBs` interface instead of the specific store implementation type. --- firewalldb/test_kvdb.go | 18 +++++++++++------- firewalldb/test_postgres.go | 4 ++-- firewalldb/test_sql.go | 5 ++--- firewalldb/test_sqlite.go | 4 ++-- 4 files changed, 17 insertions(+), 14 deletions(-) diff --git a/firewalldb/test_kvdb.go b/firewalldb/test_kvdb.go index 6f7a49aa3..c3cd4533a 100644 --- a/firewalldb/test_kvdb.go +++ b/firewalldb/test_kvdb.go @@ -6,34 +6,37 @@ import ( "testing" "github.com/lightninglabs/lightning-terminal/accounts" + "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" "github.com/stretchr/testify/require" ) // NewTestDB is a helper function that creates an BBolt database for testing. -func NewTestDB(t *testing.T, clock clock.Clock) *BoltDB { +func NewTestDB(t *testing.T, clock clock.Clock) FirewallDBs { return NewTestDBFromPath(t, t.TempDir(), clock) } // NewTestDBFromPath is a helper function that creates a new BoltStore with a // connection to an existing BBolt database for testing. -func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) *BoltDB { +func NewTestDBFromPath(t *testing.T, dbPath string, + clock clock.Clock) FirewallDBs { + return newDBFromPathWithSessions(t, dbPath, nil, nil, clock) } // NewTestDBWithSessions creates a new test BoltDB Store with access to an // existing sessions DB. -func NewTestDBWithSessions(t *testing.T, sessStore SessionDB, - clock clock.Clock) *BoltDB { +func NewTestDBWithSessions(t *testing.T, sessStore session.Store, + clock clock.Clock) FirewallDBs { return newDBFromPathWithSessions(t, t.TempDir(), sessStore, nil, clock) } // NewTestDBWithSessionsAndAccounts creates a new test BoltDB Store with access // to an existing sessions DB and accounts DB. -func NewTestDBWithSessionsAndAccounts(t *testing.T, sessStore SessionDB, - acctStore AccountsDB, clock clock.Clock) *BoltDB { +func NewTestDBWithSessionsAndAccounts(t *testing.T, sessStore session.Store, + acctStore AccountsDB, clock clock.Clock) FirewallDBs { return newDBFromPathWithSessions( t, t.TempDir(), sessStore, acctStore, clock, @@ -41,7 +44,8 @@ func NewTestDBWithSessionsAndAccounts(t *testing.T, sessStore SessionDB, } func newDBFromPathWithSessions(t *testing.T, dbPath string, - sessStore SessionDB, acctStore AccountsDB, clock clock.Clock) *BoltDB { + sessStore session.Store, acctStore AccountsDB, + clock clock.Clock) FirewallDBs { store, err := NewBoltDB(dbPath, DBFilename, sessStore, acctStore, clock) require.NoError(t, err) diff --git a/firewalldb/test_postgres.go b/firewalldb/test_postgres.go index 324aea2c4..732b19b4a 100644 --- a/firewalldb/test_postgres.go +++ b/firewalldb/test_postgres.go @@ -10,12 +10,12 @@ import ( ) // NewTestDB is a helper function that creates an BBolt database for testing. -func NewTestDB(t *testing.T, clock clock.Clock) *SQLDB { +func NewTestDB(t *testing.T, clock clock.Clock) FirewallDBs { return createStore(t, db.NewTestPostgresDB(t).BaseDB, clock) } // NewTestDBFromPath is a helper function that creates a new BoltStore with a // connection to an existing BBolt database for testing. -func NewTestDBFromPath(t *testing.T, _ string, clock clock.Clock) *SQLDB { +func NewTestDBFromPath(t *testing.T, _ string, clock clock.Clock) FirewallDBs { return createStore(t, db.NewTestPostgresDB(t).BaseDB, clock) } diff --git a/firewalldb/test_sql.go b/firewalldb/test_sql.go index 2f6c6e62e..a412441f8 100644 --- a/firewalldb/test_sql.go +++ b/firewalldb/test_sql.go @@ -16,8 +16,7 @@ import ( // NewTestDBWithSessions creates a new test SQLDB Store with access to an // existing sessions DB. func NewTestDBWithSessions(t *testing.T, sessionStore session.Store, - clock clock.Clock) *SQLDB { - + clock clock.Clock) FirewallDBs { sessions, ok := sessionStore.(*session.SQLStore) require.True(t, ok) @@ -27,7 +26,7 @@ func NewTestDBWithSessions(t *testing.T, sessionStore session.Store, // NewTestDBWithSessionsAndAccounts creates a new test SQLDB Store with access // to an existing sessions DB and accounts DB. func NewTestDBWithSessionsAndAccounts(t *testing.T, sessionStore SessionDB, - acctStore AccountsDB, clock clock.Clock) *SQLDB { + acctStore AccountsDB, clock clock.Clock) FirewallDBs { sessions, ok := sessionStore.(*session.SQLStore) require.True(t, ok) diff --git a/firewalldb/test_sqlite.go b/firewalldb/test_sqlite.go index 506b49bcd..49b956d7d 100644 --- a/firewalldb/test_sqlite.go +++ b/firewalldb/test_sqlite.go @@ -10,13 +10,13 @@ import ( ) // NewTestDB is a helper function that creates an BBolt database for testing. -func NewTestDB(t *testing.T, clock clock.Clock) *SQLDB { +func NewTestDB(t *testing.T, clock clock.Clock) FirewallDBs { return createStore(t, db.NewTestSqliteDB(t).BaseDB, clock) } // NewTestDBFromPath is a helper function that creates a new BoltStore with a // connection to an existing BBolt database for testing. -func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) *SQLDB { +func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) FirewallDBs { return createStore( t, db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock, ) From f9f6e1b1f7d13446f242c48b0432a33836d591a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?= Date: Tue, 20 May 2025 10:39:39 +0200 Subject: [PATCH 07/11] db: add List All Kv Records query During the upcoming upcoming migration of the firewall database to SQL, we need to be able to check all kvstores records in the SQL database, to validate that the migration is successful in tests. This commits adds a query to list all kvstores records, which enables that functionality. --- db/sqlc/kvstores.sql.go | 36 ++++++++++++++++++++++++++++++++++++ db/sqlc/querier.go | 1 + db/sqlc/queries/kvstores.sql | 4 ++++ 3 files changed, 41 insertions(+) diff --git a/db/sqlc/kvstores.sql.go b/db/sqlc/kvstores.sql.go index b2e6632f4..c0949d173 100644 --- a/db/sqlc/kvstores.sql.go +++ b/db/sqlc/kvstores.sql.go @@ -257,6 +257,42 @@ func (q *Queries) InsertKVStoreRecord(ctx context.Context, arg InsertKVStoreReco return err } +const listAllKVStoresRecords = `-- name: ListAllKVStoresRecords :many +SELECT id, perm, rule_id, session_id, feature_id, entry_key, value +FROM kvstores +` + +func (q *Queries) ListAllKVStoresRecords(ctx context.Context) ([]Kvstore, error) { + rows, err := q.db.QueryContext(ctx, listAllKVStoresRecords) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Kvstore + for rows.Next() { + var i Kvstore + if err := rows.Scan( + &i.ID, + &i.Perm, + &i.RuleID, + &i.SessionID, + &i.FeatureID, + &i.EntryKey, + &i.Value, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const updateFeatureKVStoreRecord = `-- name: UpdateFeatureKVStoreRecord :exec UPDATE kvstores SET value = $1 diff --git a/db/sqlc/querier.go b/db/sqlc/querier.go index df89d0898..117a1fbc5 100644 --- a/db/sqlc/querier.go +++ b/db/sqlc/querier.go @@ -57,6 +57,7 @@ type Querier interface { ListAccountInvoices(ctx context.Context, accountID int64) ([]AccountInvoice, error) ListAccountPayments(ctx context.Context, accountID int64) ([]AccountPayment, error) ListAllAccounts(ctx context.Context) ([]Account, error) + ListAllKVStoresRecords(ctx context.Context) ([]Kvstore, error) ListSessions(ctx context.Context) ([]Session, error) ListSessionsByState(ctx context.Context, state int16) ([]Session, error) ListSessionsByType(ctx context.Context, type_ int16) ([]Session, error) diff --git a/db/sqlc/queries/kvstores.sql b/db/sqlc/queries/kvstores.sql index 7963e46a4..1ebfe3b0d 100644 --- a/db/sqlc/queries/kvstores.sql +++ b/db/sqlc/queries/kvstores.sql @@ -28,6 +28,10 @@ VALUES ($1, $2, $3, $4, $5, $6); DELETE FROM kvstores WHERE perm = false; +-- name: ListAllKVStoresRecords :many +SELECT * +FROM kvstores; + -- name: GetGlobalKVStoreRecord :one SELECT value FROM kvstores From 584c90876cbf057a18a135dbe0f82469a01c1ca5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?= Date: Tue, 6 May 2025 19:42:08 +0200 Subject: [PATCH 08/11] firewalldb: clarify bbolt kvstores illustration During the migration of the kvstores to SQL, we'll iterate over the buckets in the bbolt database, which holds all kvstores records. In order to understand why the migration iterates over the buckets in the specific order, we need to clarify the bbolt kvstores illustration docs, so that it correctly reflects how the records are actually stored in the bbolt database. --- firewalldb/kvstores_kvdb.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/firewalldb/kvstores_kvdb.go b/firewalldb/kvstores_kvdb.go index 51721d475..d1e8e35a6 100644 --- a/firewalldb/kvstores_kvdb.go +++ b/firewalldb/kvstores_kvdb.go @@ -16,13 +16,13 @@ the temporary store changes instead of just keeping an in-memory store is that we can then guarantee atomicity if changes are made to both the permanent and temporary stores. -rules -> perm -> rule-name -> global -> {k:v} - -> sessions -> group ID -> session-kv-store -> {k:v} - -> feature-kv-stores -> feature-name -> {k:v} +"rules" -> "perm" -> rule-name -> "global" -> {k:v} + "session-kv-store" -> group ID -> {k:v} + -> "feature-kv-stores" -> feature-name -> {k:v} - -> temp -> rule-name -> global -> {k:v} - -> sessions -> group ID -> session-kv-store -> {k:v} - -> feature-kv-stores -> feature-name -> {k:v} + -> "temp" -> rule-name -> "global" -> {k:v} + "session-kv-store" -> group ID -> {k:v} + -> "feature-kv-stores" -> feature-name -> {k:v} */ var ( From 10550d4ade2f64bd91e9181adfedb65e5cab328f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?= Date: Tue, 6 May 2025 19:44:31 +0200 Subject: [PATCH 09/11] firewalldb: add kvstores kvdb to SQL migration This commit introduces the migration logic for transitioning the kvstores store from kvdb to SQL. Note that as of this commit, the migration is not yet triggered by any production code, i.e. only tests execute the migration logic. --- firewalldb/sql_migration.go | 486 ++++++++++++++++++++++++++++ firewalldb/sql_migration_test.go | 537 +++++++++++++++++++++++++++++++ 2 files changed, 1023 insertions(+) create mode 100644 firewalldb/sql_migration.go create mode 100644 firewalldb/sql_migration_test.go diff --git a/firewalldb/sql_migration.go b/firewalldb/sql_migration.go new file mode 100644 index 000000000..092b61c8e --- /dev/null +++ b/firewalldb/sql_migration.go @@ -0,0 +1,486 @@ +package firewalldb + +import ( + "bytes" + "context" + "database/sql" + "errors" + "fmt" + "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightningnetwork/lnd/sqldb" + "go.etcd.io/bbolt" +) + +// kvParams is a type alias for the InsertKVStoreRecordParams, to shorten the +// line length in the migration code. +type kvParams = sqlc.InsertKVStoreRecordParams + +// MigrateFirewallDBToSQL runs the migration of the firwalldb stores from the +// bbolt database to a SQL database. The migration is done in a single +// transaction to ensure that all rows in the stores are migrated or none at +// all. +// +// Note that this migration currently only migrates the kvstores, but will be +// extended in the future to also migrate the privacy mapper and action stores. +// +// NOTE: As sessions may contain linked sessions and accounts, the sessions and +// accounts sql migration MUST be run prior to this migration. +func MigrateFirewallDBToSQL(ctx context.Context, kvStore *bbolt.DB, + tx SQLQueries) error { + + log.Infof("Starting migration of the rules DB to SQL") + + err := migrateKVStoresDBToSQL(ctx, kvStore, tx) + if err != nil { + return err + } + + log.Infof("The rules DB has been migrated from KV to SQL.") + + // TODO(viktor): Add migration for the privacy mapper and the action + // stores. + + return nil +} + +// migrateKVStoresDBToSQL runs the migration of all KV stores from the KV +// database to the SQL database. The function also asserts that the +// migrated values match the original values in the KV store. +// See the illustration in the firwalldb/kvstores_kvdb.go file to understand +// the structure of the KV stores, and why we process the buckets in the +// order we do. +// Note that this function and the subsequent functions are intentionally +// designed to loop over all buckets and values that exist in the KV store, +// so that we are sure that we actually find all stores and values that +// exist in the KV store, and can be sure that the kv store actually follows +// the expected structure. +func migrateKVStoresDBToSQL(ctx context.Context, kvStore *bbolt.DB, + sqlTx SQLQueries) error { + + log.Infof("Starting migration of the KV stores to SQL") + + // allParams will hold all the kvParams that are inserted into the + // SQL database during the migration. + var allParams []kvParams + + err := kvStore.View(func(kvTx *bbolt.Tx) error { + for _, perm := range []bool{true, false} { + mainBucket, err := getMainBucket(kvTx, false, perm) + if err != nil { + return err + } + + if mainBucket == nil { + // If the mainBucket doesn't exist, there are no + // records to migrate under that bucket, + // therefore we don't error, and just proceed + // to not migrate any records under that bucket. + continue + } + + err = mainBucket.ForEach(func(k, v []byte) error { + if v != nil { + return errors.New("expected only " + + "buckets under main bucket") + } + + ruleName := k + ruleNameBucket := mainBucket.Bucket(k) + if ruleNameBucket == nil { + return fmt.Errorf("rule bucket %s "+ + "not found", string(k)) + } + + ruleId, err := sqlTx.GetOrInsertRuleID( + ctx, string(ruleName), + ) + if err != nil { + return err + } + + params, err := processRuleBucket( + ctx, sqlTx, perm, ruleId, + ruleNameBucket, + ) + if err != nil { + return err + } + + allParams = append(allParams, params...) + + return nil + }) + if err != nil { + return err + } + } + + return nil + }) + if err != nil { + return err + } + + // After the migration is done, we validate that all inserted kvParams + // can match the original values in the KV store. Note that this is done + // after all values have been inserted, in order to ensure that the + // migration doesn't overwrite any values after they were inserted. + for _, param := range allParams { + switch { + case param.FeatureID.Valid && param.SessionID.Valid: + migratedValue, err := sqlTx.GetFeatureKVStoreRecord( + ctx, + sqlc.GetFeatureKVStoreRecordParams{ + Key: param.EntryKey, + Perm: param.Perm, + RuleID: param.RuleID, + SessionID: param.SessionID, + FeatureID: param.FeatureID, + }, + ) + if err != nil { + return fmt.Errorf("retreiving of migrated "+ + "feature specific kv store record "+ + "failed %w", err) + } + + if !bytes.Equal(migratedValue, param.Value) { + return fmt.Errorf("migrated feature specific "+ + "kv record value %x does not match "+ + "original value %x", migratedValue, + param.Value) + } + + case param.SessionID.Valid: + migratedValue, err := sqlTx.GetSessionKVStoreRecord( + ctx, + sqlc.GetSessionKVStoreRecordParams{ + Key: param.EntryKey, + Perm: param.Perm, + RuleID: param.RuleID, + SessionID: param.SessionID, + }, + ) + if err != nil { + return fmt.Errorf("retreiving of migrated "+ + "session wide kv store record "+ + "failed %w", err) + } + + if !bytes.Equal(migratedValue, param.Value) { + return fmt.Errorf("migrated session wide kv "+ + "record value %x does not match "+ + "original value %x", migratedValue, + param.Value) + } + + case !param.FeatureID.Valid && !param.SessionID.Valid: + migratedValue, err := sqlTx.GetGlobalKVStoreRecord( + ctx, + sqlc.GetGlobalKVStoreRecordParams{ + Key: param.EntryKey, + Perm: param.Perm, + RuleID: param.RuleID, + }, + ) + if err != nil { + return fmt.Errorf("retreiving of migrated "+ + "global kv store record failed %w", err) + } + + if !bytes.Equal(migratedValue, param.Value) { + return fmt.Errorf("migrated global kv record "+ + "value %x does not match original "+ + "value %x", migratedValue, param.Value) + } + + default: + return fmt.Errorf("unexpected combination of "+ + "FeatureID and SessionID for: %v", param) + } + } + + log.Infof("Migration of the KV stores to SQL completed. Total number "+ + "of rows migrated: %d", len(allParams)) + + return nil +} + +// processRuleBucket processes a single rule bucket, which contains the +// global and session-kv-store key buckets. +func processRuleBucket(ctx context.Context, sqlTx SQLQueries, perm bool, + ruleSqlId int64, ruleBucket *bbolt.Bucket) ([]kvParams, error) { + + var params []kvParams + + return params, ruleBucket.ForEach(func(k, v []byte) error { + switch { + case v != nil: + return errors.New("expected only buckets under " + + "rule-name bucket") + case bytes.Equal(k, globalKVStoreBucketKey): + globalBucket := ruleBucket.Bucket( + globalKVStoreBucketKey, + ) + if globalBucket == nil { + return fmt.Errorf("global bucket %s for rule "+ + "id %d not found", string(k), ruleSqlId) + } + + p, err := processGlobalRuleBucket( + ctx, sqlTx, perm, ruleSqlId, globalBucket, + ) + if err != nil { + return err + } + + params = append(params, p...) + + return nil + case bytes.Equal(k, sessKVStoreBucketKey): + sessionBucket := ruleBucket.Bucket( + sessKVStoreBucketKey, + ) + if sessionBucket == nil { + return fmt.Errorf("session bucket %s for rule "+ + "id %d not found", string(k), ruleSqlId) + } + + p, err := processSessionBucket( + ctx, sqlTx, perm, ruleSqlId, sessionBucket, + ) + if err != nil { + return err + } + + params = append(params, p...) + + return nil + default: + return fmt.Errorf("unexpected bucket %s under "+ + "rule-name bucket", string(k)) + } + }) +} + +// processGlobalRuleBucket processes the global bucket under a rule bucket, +// which contains the global key-value store records for the rule. +// It inserts the records into the SQL database and asserts that +// the migrated values match the original values in the KV store. +func processGlobalRuleBucket(ctx context.Context, sqlTx SQLQueries, perm bool, + ruleSqlId int64, globalBucket *bbolt.Bucket) ([]kvParams, error) { + + var params []kvParams + + return params, globalBucket.ForEach(func(k, v []byte) error { + if v == nil { + return errors.New("expected only key-values under " + + "global rule-name bucket") + } + + globalInsertParams := kvParams{ + EntryKey: string(k), + Value: v, + Perm: perm, + RuleID: ruleSqlId, + } + + err := sqlTx.InsertKVStoreRecord(ctx, globalInsertParams) + if err != nil { + return fmt.Errorf("inserting global kv store "+ + "record failed %w", err) + } + + params = append(params, globalInsertParams) + + return nil + }) +} + +// processSessionBucket processes the session-kv-store bucket under a rule +// bucket, which contains the group-id buckets for that rule. +func processSessionBucket(ctx context.Context, sqlTx SQLQueries, perm bool, + ruleSqlId int64, mainSessionBucket *bbolt.Bucket) ([]kvParams, error) { + + var params []kvParams + + return params, mainSessionBucket.ForEach(func(groupId, v []byte) error { + if v != nil { + return fmt.Errorf("expected only buckets under "+ + "%s bucket", string(sessKVStoreBucketKey)) + } + + groupBucket := mainSessionBucket.Bucket(groupId) + if groupBucket == nil { + return fmt.Errorf("group bucket for group id %s"+ + "not found", string(groupId)) + } + + p, err := processGroupBucket( + ctx, sqlTx, perm, ruleSqlId, groupId, groupBucket, + ) + if err != nil { + return err + } + + params = append(params, p...) + + return nil + }) +} + +// processGroupBucket processes a single group bucket, which contains the +// session-wide kv records and as well as the feature-kv-stores key bucket for +// that group. For the session-wide kv records, it inserts the records into the +// SQL database and asserts that the migrated values match the original values. +func processGroupBucket(ctx context.Context, sqlTx SQLQueries, perm bool, + ruleSqlId int64, groupAlias []byte, + groupBucket *bbolt.Bucket) ([]kvParams, error) { + + groupSqlId, err := sqlTx.GetSessionIDByAlias( + ctx, groupAlias, + ) + if errors.Is(err, sql.ErrNoRows) { + return nil, fmt.Errorf("session with group id %x "+ + "not found in sql db", groupAlias) + } else if err != nil { + return nil, err + } + + var params []kvParams + + return params, groupBucket.ForEach(func(k, v []byte) error { + switch { + case v != nil: + // This is a non-feature specific k:v store for the + // session, i.e. the session-wide store. + sessWideParams := kvParams{ + EntryKey: string(k), + Value: v, + Perm: perm, + RuleID: ruleSqlId, + SessionID: sqldb.SQLInt64(groupSqlId), + } + + err := sqlTx.InsertKVStoreRecord(ctx, sessWideParams) + if err != nil { + return fmt.Errorf("inserting session wide kv "+ + "store record failed %w", err) + } + + params = append(params, sessWideParams) + + return nil + case bytes.Equal(k, featureKVStoreBucketKey): + // This is a feature specific k:v store for the + // session, which will be stored under the feature-name + // under this bucket. + + featureStoreBucket := groupBucket.Bucket( + featureKVStoreBucketKey, + ) + if featureStoreBucket == nil { + return fmt.Errorf("feature store bucket %s "+ + "for group id %s not found", + string(featureKVStoreBucketKey), + string(groupAlias)) + } + + p, err := processFeatureStoreBucket( + ctx, sqlTx, perm, ruleSqlId, groupSqlId, + featureStoreBucket, + ) + if err != nil { + return err + } + + params = append(params, p...) + + return nil + default: + return fmt.Errorf("unexpected bucket %s found under "+ + "the %s bucket", string(k), + string(sessKVStoreBucketKey)) + } + }) +} + +// processFeatureStoreBucket processes the feature-kv-store bucket under a +// group bucket, which contains the feature specific buckets for that group. +func processFeatureStoreBucket(ctx context.Context, sqlTx SQLQueries, perm bool, + ruleSqlId int64, groupSqlId int64, + featureStoreBucket *bbolt.Bucket) ([]kvParams, error) { + + var params []kvParams + + return params, featureStoreBucket.ForEach(func(k, v []byte) error { + if v != nil { + return fmt.Errorf("expected only buckets under " + + "feature stores bucket") + } + + featureName := k + featureNameBucket := featureStoreBucket.Bucket(featureName) + if featureNameBucket == nil { + return fmt.Errorf("feature bucket %s not found", + string(featureName)) + } + + featureSqlId, err := sqlTx.GetOrInsertFeatureID( + ctx, string(featureName), + ) + if err != nil { + return err + } + + p, err := processFeatureNameBucket( + ctx, sqlTx, perm, ruleSqlId, groupSqlId, featureSqlId, + featureNameBucket, + ) + if err != nil { + return err + } + + params = append(params, p...) + + return nil + }) +} + +// processFeatureNameBucket processes a single feature name bucket, which +// contains the feature specific key-value store records for that group. +// It inserts the records into the SQL database and asserts that +// the migrated values match the original values in the KV store. +func processFeatureNameBucket(ctx context.Context, sqlTx SQLQueries, perm bool, + ruleSqlId int64, groupSqlId int64, featureSqlId int64, + featureNameBucket *bbolt.Bucket) ([]kvParams, error) { + + var params []kvParams + + return params, featureNameBucket.ForEach(func(k, v []byte) error { + if v == nil { + return fmt.Errorf("expected only key-values under "+ + "feature name bucket, but found bucket %s", + string(k)) + } + + featureParams := kvParams{ + EntryKey: string(k), + Value: v, + Perm: perm, + RuleID: ruleSqlId, + SessionID: sqldb.SQLInt64(groupSqlId), + FeatureID: sqldb.SQLInt64(featureSqlId), + } + + err := sqlTx.InsertKVStoreRecord(ctx, featureParams) + if err != nil { + return fmt.Errorf("inserting feature specific kv "+ + "store record failed %w", err) + } + + params = append(params, featureParams) + + return nil + }) +} diff --git a/firewalldb/sql_migration_test.go b/firewalldb/sql_migration_test.go new file mode 100644 index 000000000..c068671cc --- /dev/null +++ b/firewalldb/sql_migration_test.go @@ -0,0 +1,537 @@ +package firewalldb + +import ( + "context" + "database/sql" + "fmt" + "github.com/lightningnetwork/lnd/fn" + "testing" + "time" + + "github.com/lightninglabs/lightning-terminal/accounts" + "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/session" + "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb" + "github.com/stretchr/testify/require" + "golang.org/x/exp/rand" +) + +// kvStoreRecord represents a single KV entry inserted into the BoltDB. +type kvStoreRecord struct { + Perm bool + RuleName string + EntryKey string + Global bool + GroupID *session.ID + FeatureName fn.Option[string] // Set if the record is feature specific + Value []byte +} + +// TestFirewallDBMigration tests the migration of firewalldb from a bolt +// backed to a SQL database. Note that this test does not attempt to be a +// complete migration test. +// This test only tests the migration of the KV stores currently, but will +// be extended in the future to also test the migration of the privacy mapper +// and the actions store in the future. +func TestFirewallDBMigration(t *testing.T) { + t.Parallel() + + ctx := context.Background() + clock := clock.NewTestClock(time.Now()) + + // When using build tags that creates a kvdb store for NewTestDB, we + // skip this test as it is only applicable for postgres and sqlite tags. + store := NewTestDB(t, clock) + if _, ok := store.(*BoltDB); ok { + t.Skipf("Skipping Firewall DB migration test for kvdb build") + } + + makeSQLDB := func(t *testing.T, sessionsStore session.Store) (*SQLDB, + *db.TransactionExecutor[SQLQueries]) { + + testDBStore := NewTestDBWithSessions(t, sessionsStore, clock) + + store, ok := testDBStore.(*SQLDB) + require.True(t, ok) + + baseDB := store.BaseDB + + genericExecutor := db.NewTransactionExecutor( + baseDB, func(tx *sql.Tx) SQLQueries { + return baseDB.WithTx(tx) + }, + ) + + return store, genericExecutor + } + + // The assertMigrationResults function will currently assert that + // the migrated kv stores records in the SQLDB match the original kv + // stores records in the BoltDB. + assertMigrationResults := func(t *testing.T, sqlStore *SQLDB, + kvRecords []kvStoreRecord) { + + var ( + ruleIDs = make(map[string]int64) + groupIDs = make(map[string]int64) + featureIDs = make(map[string]int64) + err error + ) + + getRuleID := func(ruleName string) int64 { + ruleID, ok := ruleIDs[ruleName] + if !ok { + ruleID, err = sqlStore.GetRuleID( + ctx, ruleName, + ) + require.NoError(t, err) + + ruleIDs[ruleName] = ruleID + } + + return ruleID + } + + getGroupID := func(groupAlias []byte) int64 { + groupID, ok := groupIDs[string(groupAlias)] + if !ok { + groupID, err = sqlStore.GetSessionIDByAlias( + ctx, groupAlias, + ) + require.NoError(t, err) + + groupIDs[string(groupAlias)] = groupID + } + + return groupID + } + + getFeatureID := func(featureName string) int64 { + featureID, ok := featureIDs[featureName] + if !ok { + featureID, err = sqlStore.GetFeatureID( + ctx, featureName, + ) + require.NoError(t, err) + + featureIDs[featureName] = featureID + } + + return featureID + } + + // First we extract all migrated kv records from the SQLDB, + // in order to be able to compare them to the original kv + // records, to ensure that the migration was successful. + sqlKvRecords, err := sqlStore.ListAllKVStoresRecords(ctx) + require.NoError(t, err) + require.Equal(t, len(kvRecords), len(sqlKvRecords)) + + for _, kvRecord := range kvRecords { + ruleID := getRuleID(kvRecord.RuleName) + + if kvRecord.Global { + sqlVal, err := sqlStore.GetGlobalKVStoreRecord( + ctx, + sqlc.GetGlobalKVStoreRecordParams{ + Key: kvRecord.EntryKey, + Perm: kvRecord.Perm, + RuleID: ruleID, + }, + ) + require.NoError(t, err) + require.Equal(t, kvRecord.Value, sqlVal) + } else if kvRecord.FeatureName.IsNone() { + groupID := getGroupID(kvRecord.GroupID[:]) + + sqlVal, err := sqlStore.GetSessionKVStoreRecord( + ctx, + sqlc.GetSessionKVStoreRecordParams{ + Key: kvRecord.EntryKey, + Perm: kvRecord.Perm, + RuleID: ruleID, + SessionID: sql.NullInt64{ + Int64: groupID, + Valid: true, + }, + }, + ) + require.NoError(t, err) + require.Equal(t, kvRecord.Value, sqlVal) + } else { + groupID := getGroupID(kvRecord.GroupID[:]) + featureID := getFeatureID( + kvRecord.FeatureName.UnwrapOrFail(t), + ) + + sqlVal, err := sqlStore.GetFeatureKVStoreRecord( + ctx, + sqlc.GetFeatureKVStoreRecordParams{ + Key: kvRecord.EntryKey, + Perm: kvRecord.Perm, + RuleID: ruleID, + SessionID: sql.NullInt64{ + Int64: groupID, + Valid: true, + }, + FeatureID: sql.NullInt64{ + Int64: featureID, + Valid: true, + }, + }, + ) + require.NoError(t, err) + require.Equal(t, kvRecord.Value, sqlVal) + } + } + } + + // The tests slice contains all the tests that we will run for the + // migration of the firewalldb from a BoltDB to a SQLDB. + // Note that the tests currently only test the migration of the KV + // stores, but will be extended in the future to also test the migration + // of the privacy mapper and the actions store. + tests := []struct { + name string + populateDB func(t *testing.T, ctx context.Context, + boltDB *BoltDB, + sessionStore session.Store) []kvStoreRecord + }{ + { + name: "empty", + populateDB: func(t *testing.T, ctx context.Context, + boltDB *BoltDB, + sessionStore session.Store) []kvStoreRecord { + + // Don't populate the DB. + return make([]kvStoreRecord, 0) + }, + }, + { + name: "global records", + populateDB: globalRecords, + }, + { + name: "session specific records", + populateDB: sessionSpecificRecords, + }, + { + name: "feature specific records", + populateDB: featureSpecificRecords, + }, + { + name: "records at all levels", + populateDB: recordsAtAllLevels, + }, + { + name: "random records", + populateDB: randomKVRecords, + }, + } + + for _, test := range tests { + tc := test + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + // First let's create a sessions store to link to in + // the kvstores DB. In order to create the sessions + // store though, we also need to create an accounts + // store, that we link to the sessions store. + // Note that both of these stores will be sql stores due + // to the build tags enabled when running this test, + // which means we can also pass the sessions store to + // the sql version of the kv stores that we'll create + // in test, without also needing to migrate it. + accountStore := accounts.NewTestDB(t, clock) + sessionsStore := session.NewTestDBWithAccounts( + t, clock, accountStore, + ) + + // Create a new firewall store to populate with test + // data. + firewallStore, err := NewBoltDB( + t.TempDir(), DBFilename, sessionsStore, + accountStore, clock, + ) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, firewallStore.Close()) + }) + + // Populate the kv store. + records := test.populateDB( + t, ctx, firewallStore, sessionsStore, + ) + + // Create the SQL store that we will migrate the data + // to. + sqlStore, txEx := makeSQLDB(t, sessionsStore) + + // Perform the migration. + var opts sqldb.MigrationTxOptions + err = txEx.ExecTx(ctx, &opts, + func(tx SQLQueries) error { + return MigrateFirewallDBToSQL( + ctx, firewallStore.DB, tx, + ) + }, + ) + require.NoError(t, err) + + // Assert migration results. + assertMigrationResults(t, sqlStore, records) + }) + } +} + +// globalRecords populates the kv store with one global record for the temp +// store, and one for the perm store. +func globalRecords(t *testing.T, ctx context.Context, + boltDB *BoltDB, sessionStore session.Store) []kvStoreRecord { + + return insertTestKVRecords( + t, ctx, boltDB, sessionStore, true, fn.None[string](), + ) +} + +// sessionSpecificRecords populates the kv store with one session specific +// record for the local temp store, and one session specific record for the perm +// local store. +func sessionSpecificRecords(t *testing.T, ctx context.Context, + boltDB *BoltDB, sessionStore session.Store) []kvStoreRecord { + + return insertTestKVRecords( + t, ctx, boltDB, sessionStore, false, fn.None[string](), + ) +} + +// featureSpecificRecords populates the kv store with one feature specific +// record for the local temp store, and one feature specific record for the perm +// local store. +func featureSpecificRecords(t *testing.T, ctx context.Context, + boltDB *BoltDB, sessionStore session.Store) []kvStoreRecord { + + return insertTestKVRecords( + t, ctx, boltDB, sessionStore, false, fn.Some("test-feature"), + ) +} + +// recordsAtAllLevels uses adds a record at all possible levels of the kvstores, +// by utilizing all the other helper functions that populates the kvstores at +// different levels. +func recordsAtAllLevels(t *testing.T, ctx context.Context, + boltDB *BoltDB, sessionStore session.Store) []kvStoreRecord { + + gRecords := globalRecords(t, ctx, boltDB, sessionStore) + sRecords := sessionSpecificRecords(t, ctx, boltDB, sessionStore) + fRecords := featureSpecificRecords(t, ctx, boltDB, sessionStore) + + return append(gRecords, append(sRecords, fRecords...)...) +} + +// insertTestKVRecords populates the kv store with one record for the local temp +// store, and one record for the local store. The records will be feature +// specific if the featureNameOpt is set, otherwise they will be session +// specific. Both of the records will be inserted with the same +// session.GroupID, which is created in this function, as well as the same +// ruleName, entryKey and entryVal. +func insertTestKVRecords(t *testing.T, ctx context.Context, + boltDB *BoltDB, sessionStore session.Store, global bool, + featureNameOpt fn.Option[string]) []kvStoreRecord { + + var ( + ruleName = "test-rule" + entryKey = "test1" + entryVal = []byte{1, 2, 3} + ) + + // Create a session that we can reference. + sess, err := sessionStore.NewSession( + ctx, "test", session.TypeAutopilot, + time.Unix(1000, 0), "something", + ) + require.NoError(t, err) + + tempKvRecord := kvStoreRecord{ + RuleName: ruleName, + GroupID: &sess.GroupID, + FeatureName: featureNameOpt, + EntryKey: entryKey, + Value: entryVal, + Perm: false, + Global: global, + } + + insertKvRecord(t, ctx, boltDB, tempKvRecord) + + permKvRecord := kvStoreRecord{ + RuleName: ruleName, + GroupID: &sess.GroupID, + FeatureName: featureNameOpt, + EntryKey: entryKey, + Value: entryVal, + Perm: true, + Global: global, + } + + insertKvRecord(t, ctx, boltDB, permKvRecord) + + return []kvStoreRecord{tempKvRecord, permKvRecord} +} + +// insertTestKVRecords populates the kv store with passed record, and asserts +// that the record is inserted correctly. +func insertKvRecord(t *testing.T, ctx context.Context, + boltDB *BoltDB, record kvStoreRecord) { + + if record.Global && record.FeatureName.IsSome() { + t.Fatalf("cannot set both global and feature specific at the " + + "same time") + } + + kvStores := boltDB.GetKVStores( + record.RuleName, *record.GroupID, + record.FeatureName.UnwrapOr(""), + ) + + err := kvStores.Update(ctx, func(ctx context.Context, + tx KVStoreTx) error { + + switch { + case record.Global && !record.Perm: + return tx.GlobalTemp().Set( + ctx, record.EntryKey, record.Value, + ) + case record.Global && record.Perm: + return tx.Global().Set( + ctx, record.EntryKey, record.Value, + ) + case !record.Global && !record.Perm: + return tx.LocalTemp().Set( + ctx, record.EntryKey, record.Value, + ) + case !record.Global && record.Perm: + return tx.Local().Set( + ctx, record.EntryKey, record.Value, + ) + default: + return fmt.Errorf("unexpected global/perm "+ + "combination: global=%v, perm=%v", + record.Global, record.Perm) + } + }) + require.NoError(t, err) +} + +// randomKVRecords populates the kv store with random kv records that span +// across all possible combinations of different levels of records in the kv +// store. All values and different bucket names are randomly generated. +func randomKVRecords(t *testing.T, ctx context.Context, + boltDB *BoltDB, sessionStore session.Store) []kvStoreRecord { + + var ( + // We set the number of records to insert to 1000, as that + // should be enough to cover as many different + // combinations of records as possible, while still being + // fast enough to run in a reasonable time. + numberOfRecords = 1000 + insertedRecords = make([]kvStoreRecord, 0) + ruleName = "initial-rule" + groupId *session.ID + featureName = "initial-feature" + ) + + // Create a random session that we can reference for the initial group + // ID. + sess, err := sessionStore.NewSession( + ctx, "initial-session", session.Type(uint8(rand.Intn(5))), + time.Unix(1000, 0), randomString(rand.Intn(10)+1), + ) + require.NoError(t, err) + + groupId = &sess.GroupID + + // Generate random records. Note that many records will use the same + // rule name, group ID and feature name, to simulate the real world + // usage of the kv stores as much as possible. + for i := 0; i < numberOfRecords; i++ { + // On average, we will generate a new rule which will be used + // for the kv store record 10% of the time. + if rand.Intn(10) == 0 { + ruleName = fmt.Sprintf( + "rule-%s-%d", randomString(rand.Intn(30)+1), i, + ) + } + + // On average, we use the global store 25% of the time. + global := rand.Intn(4) == 0 + + // We'll use the perm store 50% of the time. + perm := rand.Intn(2) == 0 + + // For the non-global records, we will generate a new group ID + // 25% of the time. + if !global && rand.Intn(4) == 0 { + newSess, err := sessionStore.NewSession( + ctx, fmt.Sprintf("session-%d", i), + session.Type(uint8(rand.Intn(5))), + time.Unix(1000, 0), + randomString(rand.Intn(10)+1), + ) + require.NoError(t, err) + + groupId = &newSess.GroupID + } + + featureNameOpt := fn.None[string]() + + // For 50% of the non-global records, we insert a feature + // specific record. The other 50% will be session specific + // records. + if !global && rand.Intn(2) == 0 { + // 25% of the time, we will generate a new feature name. + if rand.Intn(4) == 0 { + featureName = fmt.Sprintf( + "feature-%s-%d", + randomString(rand.Intn(30)+1), i, + ) + } + + featureNameOpt = fn.Some(featureName) + } + + kvEntry := kvStoreRecord{ + RuleName: ruleName, + GroupID: groupId, + FeatureName: featureNameOpt, + EntryKey: fmt.Sprintf("key-%d", i), + Perm: perm, + Global: global, + // We'll generate a random value for all records, + Value: []byte(randomString(rand.Intn(100) + 1)), + } + + // Insert the record into the kv store. + insertKvRecord(t, ctx, boltDB, kvEntry) + + // Add the record to the list of inserted records. + insertedRecords = append(insertedRecords, kvEntry) + } + + return insertedRecords +} + +// randomString generates a random string of the passed length n. +func randomString(n int) string { + letterBytes := "abcdefghijklmnopqrstuvwxyz" + + b := make([]byte, n) + for i := range b { + b[i] = letterBytes[rand.Intn(len(letterBytes))] + } + return string(b) +} From 05d2dd3268f6d113e94e17570b8dac8dfa7621f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?= Date: Wed, 4 Jun 2025 19:17:45 +0200 Subject: [PATCH 10/11] mod: go get sqldb/v2 --- go.mod | 8 ++++++-- go.sum | 8 ++++---- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index 5a9aa5c8d..e1131a527 100644 --- a/go.mod +++ b/go.mod @@ -38,6 +38,7 @@ require ( github.com/lightningnetwork/lnd/fn/v2 v2.0.8 github.com/lightningnetwork/lnd/kvdb v1.4.16 github.com/lightningnetwork/lnd/sqldb v1.0.9 + github.com/lightningnetwork/lnd/sqldb/v2 v2.0.0-00010101000000-000000000000 github.com/lightningnetwork/lnd/tlv v1.3.1 github.com/lightningnetwork/lnd/tor v1.1.6 github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f @@ -60,7 +61,7 @@ require ( ) require ( - dario.cat/mergo v1.0.1 // indirect + dario.cat/mergo v1.0.2 // indirect github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect github.com/NebulousLabs/fastrand v0.0.0-20181203155948-6fb6489aac4e // indirect @@ -245,6 +246,9 @@ replace google.golang.org/protobuf => github.com/lightninglabs/protobuf-go-hex-d // automatically, so we need to add it manually. replace github.com/golang-migrate/migrate/v4 => github.com/lightninglabs/migrate/v4 v4.18.2-9023d66a-fork-pr-2 -replace github.com/lightningnetwork/lnd => github.com/lightningnetwork/lnd v0.19.0-beta +replace github.com/lightningnetwork/lnd => github.com/ViktorTigerstrom/lnd v0.0.0-20250604171448-07036473e46c + +// TODO: replace this with your own local fork +replace github.com/lightningnetwork/lnd/sqldb/v2 => ../../lnd_forked/lnd/sqldb go 1.23.9 diff --git a/go.sum b/go.sum index 41fd8b809..b02a8d628 100644 --- a/go.sum +++ b/go.sum @@ -30,8 +30,8 @@ cloud.google.com/go/storage v1.5.0/go.mod h1:tpKbwo567HUNpVclU5sGELwQWBDZ8gh0Zeo cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohlUTyfDhBk= cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= -dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s= -dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= +dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= +dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= @@ -50,6 +50,8 @@ github.com/NebulousLabs/go-upnp v0.0.0-20180202185039-29b680b06c82 h1:MG93+PZYs9 github.com/NebulousLabs/go-upnp v0.0.0-20180202185039-29b680b06c82/go.mod h1:GbuBk21JqF+driLX3XtJYNZjGa45YDoa9IqCTzNSfEc= github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 h1:TngWCqHvy9oXAN6lEVMRuU21PR1EtLVZJmdB18Gu3Rw= github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5/go.mod h1:lmUJ/7eu/Q8D7ML55dXQrVaamCz2vxCfdQBasLZfHKk= +github.com/ViktorTigerstrom/lnd v0.0.0-20250604171448-07036473e46c h1:RjJus6oMMn3SyStT7DRpLEovYHQaAdlxbGiq91jTpdg= +github.com/ViktorTigerstrom/lnd v0.0.0-20250604171448-07036473e46c/go.mod h1:AeAtPyAAV51d9EQxGXB4rrU2J9REwMxqf8N4bEpLfqU= github.com/Yawning/aez v0.0.0-20211027044916-e49e68abd344 h1:cDVUiFo+npB0ZASqnw4q90ylaVAbnYyx0JYqK4YcGok= github.com/Yawning/aez v0.0.0-20211027044916-e49e68abd344/go.mod h1:9pIqrY6SXNL8vjRQE5Hd/OL5GyK/9MrGUWs87z/eFfk= github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY= @@ -512,8 +514,6 @@ github.com/lightninglabs/taproot-assets/taprpc v1.0.4 h1:D1Zcjvaz5viyNXwecgj2yhQ github.com/lightninglabs/taproot-assets/taprpc v1.0.4/go.mod h1:Ccq0t2GsXzOtC8qF0U1ux/yTF5HcBbVrhCb0tb/jObM= github.com/lightningnetwork/lightning-onion v1.2.1-0.20240712235311-98bd56499dfb h1:yfM05S8DXKhuCBp5qSMZdtSwvJ+GFzl94KbXMNB1JDY= github.com/lightningnetwork/lightning-onion v1.2.1-0.20240712235311-98bd56499dfb/go.mod h1:c0kvRShutpj3l6B9WtTsNTBUtjSmjZXbJd9ZBRQOSKI= -github.com/lightningnetwork/lnd v0.19.0-beta h1:/8i2UdARiEpI2iAmPoSDcwZSSEuWqXyfsMxz/mLGbdw= -github.com/lightningnetwork/lnd v0.19.0-beta/go.mod h1:hu6zo1zcznx7nViiFlJY8qGDwwGw5LNLdGJ7ICz5Ysc= github.com/lightningnetwork/lnd/cert v1.2.2 h1:71YK6hogeJtxSxw2teq3eGeuy4rHGKcFf0d0Uy4qBjI= github.com/lightningnetwork/lnd/cert v1.2.2/go.mod h1:jQmFn/Ez4zhDgq2hnYSw8r35bqGVxViXhX6Cd7HXM6U= github.com/lightningnetwork/lnd/clock v1.1.1 h1:OfR3/zcJd2RhH0RU+zX/77c0ZiOnIMsDIBjgjWdZgA0= From 45d36a78e80cd216177a5f057373da6f5335d9cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?= Date: Tue, 10 Jun 2025 17:48:09 +0200 Subject: [PATCH 11/11] multi: use sqldb v2 in litd This commit updates litd to use the new sqldb v2 package. Note that this with just this commit, litd will not utilize the capabilities of sqldb v2 to run specific post migrations steps (such as migrating the kvdb to SQL). That functionality will be added in later commits. Instead, this commit just focuses on adding support for the new sqldb v2 package, and the functionality of the SQL stores are expected to remain the same as prior to this commit. --- accounts/sql_migration_test.go | 17 +++---- accounts/store_sql.go | 66 ++++++++++++++++++--------- accounts/test_sql.go | 11 +++-- accounts/test_sqlite.go | 12 +++-- config_dev.go | 67 ++++++++++++++++++++++++---- db/interfaces.go | 7 +-- db/post_migration_checks.go | 81 ++++++++++++++++++++++++++++++++++ db/postgres.go | 21 +++------ db/sql_migrations.go | 31 +++++++++++++ db/sqlc/db_custom.go | 34 ++++---------- db/sqlite.go | 14 ++---- firewalldb/actions_sql.go | 8 ++-- firewalldb/kvstores_sql.go | 3 +- firewalldb/sql_store.go | 47 +++++++++++++++----- firewalldb/test_sql.go | 10 +++-- firewalldb/test_sqlite.go | 16 ++++--- session/sql_store.go | 64 ++++++++++++++++++--------- session/test_sql.go | 11 +++-- session/test_sqlite.go | 12 +++-- 19 files changed, 375 insertions(+), 157 deletions(-) create mode 100644 db/post_migration_checks.go create mode 100644 db/sql_migrations.go diff --git a/accounts/sql_migration_test.go b/accounts/sql_migration_test.go index 621c947a5..6832e42aa 100644 --- a/accounts/sql_migration_test.go +++ b/accounts/sql_migration_test.go @@ -2,18 +2,17 @@ package accounts import ( "context" - "database/sql" "fmt" "testing" "time" - "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" - "github.com/lightningnetwork/lnd/sqldb" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" "golang.org/x/exp/rand" "pgregory.net/rapid" @@ -36,7 +35,7 @@ func TestAccountStoreMigration(t *testing.T) { } makeSQLDB := func(t *testing.T) (*SQLStore, - *db.TransactionExecutor[SQLQueries]) { + *SQLQueriesExecutor[SQLQueries]) { testDBStore := NewTestDB(t, clock) @@ -45,13 +44,9 @@ func TestAccountStoreMigration(t *testing.T) { baseDB := store.BaseDB - genericExecutor := db.NewTransactionExecutor( - baseDB, func(tx *sql.Tx) SQLQueries { - return baseDB.WithTx(tx) - }, - ) + queries := sqlc.NewForType(baseDB, baseDB.BackendType) - return store, genericExecutor + return store, NewSQLQueriesExecutor(baseDB, queries) } assertMigrationResults := func(t *testing.T, sqlStore *SQLStore, @@ -343,7 +338,7 @@ func TestAccountStoreMigration(t *testing.T) { return MigrateAccountStoreToSQL( ctx, kvStore, tx, ) - }, + }, sqldb.NoOpReset, ) require.NoError(t, err) diff --git a/accounts/store_sql.go b/accounts/store_sql.go index 830f16587..c7e8ab070 100644 --- a/accounts/store_sql.go +++ b/accounts/store_sql.go @@ -16,6 +16,7 @@ import ( "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/sqldb/v2" ) const ( @@ -33,6 +34,8 @@ const ( // //nolint:lll type SQLQueries interface { + sqldb.BaseQuerier + AddAccountInvoice(ctx context.Context, arg sqlc.AddAccountInvoiceParams) error DeleteAccount(ctx context.Context, id int64) error DeleteAccountPayment(ctx context.Context, arg sqlc.DeleteAccountPaymentParams) error @@ -53,12 +56,13 @@ type SQLQueries interface { GetAccountInvoice(ctx context.Context, arg sqlc.GetAccountInvoiceParams) (sqlc.AccountInvoice, error) } -// BatchedSQLQueries is a version of the SQLQueries that's capable -// of batched database operations. +// BatchedSQLQueries combines the SQLQueries interface with the BatchedTx +// interface, allowing for multiple queries to be executed in single SQL +// transaction. type BatchedSQLQueries interface { SQLQueries - db.BatchedTx[SQLQueries] + sqldb.BatchedTx[SQLQueries] } // SQLStore represents a storage backend. @@ -68,19 +72,37 @@ type SQLStore struct { db BatchedSQLQueries // BaseDB represents the underlying database connection. - *db.BaseDB + *sqldb.BaseDB clock clock.Clock } -// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries -// storage backend. -func NewSQLStore(sqlDB *db.BaseDB, clock clock.Clock) *SQLStore { - executor := db.NewTransactionExecutor( - sqlDB, func(tx *sql.Tx) SQLQueries { - return sqlDB.WithTx(tx) +type SQLQueriesExecutor[T sqldb.BaseQuerier] struct { + *sqldb.TransactionExecutor[T] + + SQLQueries +} + +func NewSQLQueriesExecutor(baseDB *sqldb.BaseDB, + queries *sqlc.Queries) *SQLQueriesExecutor[SQLQueries] { + + executor := sqldb.NewTransactionExecutor( + baseDB, func(tx *sql.Tx) SQLQueries { + return queries.WithTx(tx) }, ) + return &SQLQueriesExecutor[SQLQueries]{ + TransactionExecutor: executor, + SQLQueries: queries, + } +} + +// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries +// storage backend. +func NewSQLStore(sqlDB *sqldb.BaseDB, queries *sqlc.Queries, + clock clock.Clock) *SQLStore { + + executor := NewSQLQueriesExecutor(sqlDB, queries) return &SQLStore{ db: executor, @@ -157,7 +179,7 @@ func (s *SQLStore) NewAccount(ctx context.Context, balance lnwire.MilliSatoshi, } return nil - }) + }, sqldb.NoOpReset) if err != nil { return nil, err } @@ -299,7 +321,7 @@ func (s *SQLStore) AddAccountInvoice(ctx context.Context, alias AccountID, } return s.markAccountUpdated(ctx, db, acctID) - }) + }, sqldb.NoOpReset) } func getAccountIDByAlias(ctx context.Context, db SQLQueries, alias AccountID) ( @@ -377,7 +399,7 @@ func (s *SQLStore) UpdateAccountBalanceAndExpiry(ctx context.Context, } return s.markAccountUpdated(ctx, db, id) - }) + }, sqldb.NoOpReset) } // CreditAccount increases the balance of the account with the given alias by @@ -412,7 +434,7 @@ func (s *SQLStore) CreditAccount(ctx context.Context, alias AccountID, } return s.markAccountUpdated(ctx, db, id) - }) + }, sqldb.NoOpReset) } // DebitAccount decreases the balance of the account with the given alias by the @@ -453,7 +475,7 @@ func (s *SQLStore) DebitAccount(ctx context.Context, alias AccountID, } return s.markAccountUpdated(ctx, db, id) - }) + }, sqldb.NoOpReset) } // Account retrieves an account from the SQL store and un-marshals it. If the @@ -475,7 +497,7 @@ func (s *SQLStore) Account(ctx context.Context, alias AccountID) ( account, err = getAndMarshalAccount(ctx, db, id) return err - }) + }, sqldb.NoOpReset) return account, err } @@ -507,7 +529,7 @@ func (s *SQLStore) Accounts(ctx context.Context) ([]*OffChainBalanceAccount, } return nil - }) + }, sqldb.NoOpReset) return accounts, err } @@ -524,7 +546,7 @@ func (s *SQLStore) RemoveAccount(ctx context.Context, alias AccountID) error { } return db.DeleteAccount(ctx, id) - }) + }, sqldb.NoOpReset) } // UpsertAccountPayment updates or inserts a payment entry for the given @@ -634,7 +656,7 @@ func (s *SQLStore) UpsertAccountPayment(ctx context.Context, alias AccountID, } return s.markAccountUpdated(ctx, db, id) - }) + }, sqldb.NoOpReset) } // DeleteAccountPayment removes a payment entry from the account with the given @@ -677,7 +699,7 @@ func (s *SQLStore) DeleteAccountPayment(ctx context.Context, alias AccountID, } return s.markAccountUpdated(ctx, db, id) - }) + }, sqldb.NoOpReset) } // LastIndexes returns the last invoice add and settle index or @@ -704,7 +726,7 @@ func (s *SQLStore) LastIndexes(ctx context.Context) (uint64, uint64, error) { } return err - }) + }, sqldb.NoOpReset) return uint64(addIndex), uint64(settleIndex), err } @@ -729,7 +751,7 @@ func (s *SQLStore) StoreLastIndexes(ctx context.Context, addIndex, Name: settleIndexName, Value: int64(settleIndex), }) - }) + }, sqldb.NoOpReset) } // Close closes the underlying store. diff --git a/accounts/test_sql.go b/accounts/test_sql.go index 3c1ee7f16..ca2f43d6f 100644 --- a/accounts/test_sql.go +++ b/accounts/test_sql.go @@ -5,15 +5,20 @@ package accounts import ( "testing" - "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" ) // createStore is a helper function that creates a new SQLStore and ensure that // it is closed when during the test cleanup. -func createStore(t *testing.T, sqlDB *db.BaseDB, clock clock.Clock) *SQLStore { - store := NewSQLStore(sqlDB, clock) +func createStore(t *testing.T, sqlDB *sqldb.BaseDB, + clock clock.Clock) *SQLStore { + + queries := sqlc.NewForType(sqlDB, sqlDB.BackendType) + + store := NewSQLStore(sqlDB, queries, clock) t.Cleanup(func() { require.NoError(t, store.Close()) }) diff --git a/accounts/test_sqlite.go b/accounts/test_sqlite.go index 9d899b3e2..a31f990a6 100644 --- a/accounts/test_sqlite.go +++ b/accounts/test_sqlite.go @@ -8,6 +8,7 @@ import ( "github.com/lightninglabs/lightning-terminal/db" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" ) // ErrDBClosed is an error that is returned when a database operation is @@ -16,7 +17,10 @@ var ErrDBClosed = errors.New("database is closed") // NewTestDB is a helper function that creates an SQLStore database for testing. func NewTestDB(t *testing.T, clock clock.Clock) Store { - return createStore(t, db.NewTestSqliteDB(t).BaseDB, clock) + return createStore( + t, sqldb.NewTestSqliteDB(t, db.LitdMigrationStreams).BaseDB, + clock, + ) } // NewTestDBFromPath is a helper function that creates a new SQLStore with a @@ -24,7 +28,7 @@ func NewTestDB(t *testing.T, clock clock.Clock) Store { func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) Store { - return createStore( - t, db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock, - ) + tDb := sqldb.NewTestSqliteDBFromPath(t, dbPath, db.LitdMigrationStreams) + + return createStore(t, tDb.BaseDB, clock) } diff --git a/config_dev.go b/config_dev.go index 90b8b290f..5b6185aa0 100644 --- a/config_dev.go +++ b/config_dev.go @@ -4,6 +4,7 @@ package terminal import ( "fmt" + "github.com/lightninglabs/lightning-terminal/db/sqlc" "path/filepath" "github.com/lightninglabs/lightning-terminal/accounts" @@ -11,6 +12,7 @@ import ( "github.com/lightninglabs/lightning-terminal/firewalldb" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" ) const ( @@ -101,14 +103,36 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) { return stores, err } - sqlStore, err := db.NewSqliteStore(cfg.Sqlite) + sqlStore, err := sqldb.NewSqliteStore(&sqldb.SqliteConfig{ + SkipMigrations: cfg.Sqlite.SkipMigrations, + SkipMigrationDbBackup: cfg.Sqlite.SkipMigrationDbBackup, + }, cfg.Sqlite.DatabaseFileName) if err != nil { return stores, err } - acctStore := accounts.NewSQLStore(sqlStore.BaseDB, clock) - sessStore := session.NewSQLStore(sqlStore.BaseDB, clock) - firewallStore := firewalldb.NewSQLDB(sqlStore.BaseDB, clock) + if !cfg.Sqlite.SkipMigrations { + err = sqldb.ApplyAllMigrations( + sqlStore, db.LitdMigrationStreams, + ) + if err != nil { + return stores, fmt.Errorf("error applying "+ + "migrations to SQLlite store: %w", err, + ) + } + } + + queries := sqlc.NewForType(sqlStore, sqlStore.BackendType) + + acctStore := accounts.NewSQLStore( + sqlStore.BaseDB, queries, clock, + ) + sessStore := session.NewSQLStore( + sqlStore.BaseDB, queries, clock, + ) + firewallStore := firewalldb.NewSQLDB( + sqlStore.BaseDB, queries, clock, + ) stores.accounts = acctStore stores.sessions = sessStore @@ -116,14 +140,41 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) { stores.closeFns["sqlite"] = sqlStore.BaseDB.Close case DatabaseBackendPostgres: - sqlStore, err := db.NewPostgresStore(cfg.Postgres) + sqlStore, err := sqldb.NewPostgresStore(&sqldb.PostgresConfig{ + Dsn: cfg.Postgres.DSN(false), + MaxOpenConnections: cfg.Postgres.MaxOpenConnections, + MaxIdleConnections: cfg.Postgres.MaxIdleConnections, + ConnMaxLifetime: cfg.Postgres.ConnMaxLifetime, + ConnMaxIdleTime: cfg.Postgres.ConnMaxIdleTime, + RequireSSL: cfg.Postgres.RequireSSL, + SkipMigrations: cfg.Postgres.SkipMigrations, + }) if err != nil { return stores, err } - acctStore := accounts.NewSQLStore(sqlStore.BaseDB, clock) - sessStore := session.NewSQLStore(sqlStore.BaseDB, clock) - firewallStore := firewalldb.NewSQLDB(sqlStore.BaseDB, clock) + if !cfg.Sqlite.SkipMigrations { + err = sqldb.ApplyAllMigrations( + sqlStore, db.LitdMigrationStreams, + ) + if err != nil { + return stores, fmt.Errorf("error applying "+ + "migrations to Postgres store: %w", err, + ) + } + } + + queries := sqlc.NewForType(sqlStore, sqlStore.BackendType) + + acctStore := accounts.NewSQLStore( + sqlStore.BaseDB, queries, clock, + ) + sessStore := session.NewSQLStore( + sqlStore.BaseDB, queries, clock, + ) + firewallStore := firewalldb.NewSQLDB( + sqlStore.BaseDB, queries, clock, + ) stores.accounts = acctStore stores.sessions = sessStore diff --git a/db/interfaces.go b/db/interfaces.go index ba64520b4..bb39df9ea 100644 --- a/db/interfaces.go +++ b/db/interfaces.go @@ -8,6 +8,7 @@ import ( "time" "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightningnetwork/lnd/sqldb/v2" ) var ( @@ -56,7 +57,7 @@ type BatchedTx[Q any] interface { txBody func(Q) error) error // Backend returns the type of the database backend used. - Backend() sqlc.BackendType + Backend() sqldb.BackendType } // Tx represents a database transaction that can be committed or rolled back. @@ -277,7 +278,7 @@ func (t *TransactionExecutor[Q]) ExecTx(ctx context.Context, } // Backend returns the type of the database backend used. -func (t *TransactionExecutor[Q]) Backend() sqlc.BackendType { +func (t *TransactionExecutor[Q]) Backend() sqldb.BackendType { return t.BatchedQuerier.Backend() } @@ -301,7 +302,7 @@ func (s *BaseDB) BeginTx(ctx context.Context, opts TxOptions) (*sql.Tx, error) { } // Backend returns the type of the database backend used. -func (s *BaseDB) Backend() sqlc.BackendType { +func (s *BaseDB) Backend() sqldb.BackendType { return s.Queries.Backend() } diff --git a/db/post_migration_checks.go b/db/post_migration_checks.go new file mode 100644 index 000000000..106ac8ec7 --- /dev/null +++ b/db/post_migration_checks.go @@ -0,0 +1,81 @@ +package db + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/golang-migrate/migrate/v4" + "github.com/golang-migrate/migrate/v4/database" + "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightningnetwork/lnd/sqldb/v2" +) + +// postMigrationCheck is a function type for a function that performs a +// post-migration check on the database. +type postMigrationCheck func(context.Context, *sqlc.Queries) error + +var ( + // postMigrationChecks is a map of functions that are run after the + // database migration with the version specified in the key has been + // applied. These functions are used to perform additional checks on the + // database state that are not fully expressible in SQL. + postMigrationChecks = map[uint]postMigrationCheck{} +) + +// makePostStepCallbacks turns the post migration checks into a map of post +// step callbacks that can be used with the migrate package. The keys of the map +// are the migration versions, and the values are the callbacks that will be +// executed after the migration with the corresponding version is applied. +func makePostStepCallbacks(db *sqldb.BaseDB, + c map[uint]postMigrationCheck) map[uint]migrate.PostStepCallback { + + queries := sqlc.NewForType(db, db.BackendType) + executor := sqldb.NewTransactionExecutor( + db, func(tx *sql.Tx) *sqlc.Queries { + return queries.WithTx(tx) + }, + ) + + var ( + ctx = context.Background() + postStepCallbacks = make(map[uint]migrate.PostStepCallback) + ) + for version, check := range c { + runCheck := func(m *migrate.Migration, q *sqlc.Queries) error { + log.Infof("Running post-migration check for version %d", + version) + start := time.Now() + + err := check(ctx, q) + if err != nil { + return fmt.Errorf("post-migration "+ + "check failed for version %d: "+ + "%w", version, err) + } + + log.Infof("Post-migration check for version %d "+ + "completed in %v", version, time.Since(start)) + + return nil + } + + // We ignore the actual driver that's being returned here, since + // we use migrate.NewWithInstance() to create the migration + // instance from our already instantiated database backend that + // is also passed into this function. + postStepCallbacks[version] = func(m *migrate.Migration, + _ database.Driver) error { + + return executor.ExecTx( + ctx, sqldb.NewWriteTx(), + func(q *sqlc.Queries) error { + return runCheck(m, q) + }, sqldb.NoOpReset, + ) + } + } + + return postStepCallbacks +} diff --git a/db/postgres.go b/db/postgres.go index 16e41dc09..962629be6 100644 --- a/db/postgres.go +++ b/db/postgres.go @@ -9,6 +9,7 @@ import ( postgres_migrate "github.com/golang-migrate/migrate/v4/database/postgres" _ "github.com/golang-migrate/migrate/v4/source/file" "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" ) @@ -119,7 +120,7 @@ func NewPostgresStore(cfg *PostgresConfig) (*PostgresStore, error) { rawDb.SetConnMaxLifetime(connMaxLifetime) rawDb.SetConnMaxIdleTime(connMaxIdleTime) - queries := sqlc.NewPostgres(rawDb) + queries := sqlc.NewForType(rawDb, sqldb.BackendTypePostgres) s := &PostgresStore{ cfg: cfg, BaseDB: &BaseDB{ @@ -128,15 +129,6 @@ func NewPostgresStore(cfg *PostgresConfig) (*PostgresStore, error) { }, } - // Now that the database is open, populate the database with our set of - // schemas based on our embedded in-memory file system. - if !cfg.SkipMigrations { - if err := s.ExecuteMigrations(TargetLatest); err != nil { - return nil, fmt.Errorf("error executing migrations: "+ - "%w", err) - } - } - return s, nil } @@ -166,20 +158,17 @@ func (s *PostgresStore) ExecuteMigrations(target MigrationTarget, // NewTestPostgresDB is a helper function that creates a Postgres database for // testing. -func NewTestPostgresDB(t *testing.T) *PostgresStore { +func NewTestPostgresDB(t *testing.T) *sqldb.PostgresStore { t.Helper() t.Logf("Creating new Postgres DB for testing") - sqlFixture := NewTestPgFixture(t, DefaultPostgresFixtureLifetime, true) - store, err := NewPostgresStore(sqlFixture.GetConfig()) - require.NoError(t, err) - + sqlFixture := sqldb.NewTestPgFixture(t, DefaultPostgresFixtureLifetime) t.Cleanup(func() { sqlFixture.TearDown(t) }) - return store + return sqldb.NewTestPostgresDB(t, sqlFixture, LitdMigrationStreams) } // NewTestPostgresDBWithVersion is a helper function that creates a Postgres diff --git a/db/sql_migrations.go b/db/sql_migrations.go new file mode 100644 index 000000000..4b492ca5a --- /dev/null +++ b/db/sql_migrations.go @@ -0,0 +1,31 @@ +package db + +import ( + "github.com/golang-migrate/migrate/v4" + "github.com/golang-migrate/migrate/v4/database/pgx/v5" + "github.com/lightningnetwork/lnd/sqldb/v2" +) + +var ( + LitdMigrationStream = sqldb.MigrationStream{ + MigrateTableName: pgx.DefaultMigrationsTable, + SQLFileDirectory: "sqlc/migrations", + Schemas: sqlSchemas, + + // LatestMigrationVersion is the latest migration version of the + // database. This is used to implement downgrade protection for + // the daemon. + // + // NOTE: This MUST be updated when a new migration is added. + LatestMigrationVersion: LatestMigrationVersion, + + MakePostMigrationChecks: func( + db *sqldb.BaseDB) (map[uint]migrate.PostStepCallback, + error) { + + return makePostStepCallbacks(db, postMigrationChecks), + nil + }, + } + LitdMigrationStreams = []sqldb.MigrationStream{LitdMigrationStream} +) diff --git a/db/sqlc/db_custom.go b/db/sqlc/db_custom.go index f4bf7f611..af556eae7 100644 --- a/db/sqlc/db_custom.go +++ b/db/sqlc/db_custom.go @@ -2,21 +2,8 @@ package sqlc import ( "context" -) - -// BackendType is an enum that represents the type of database backend we're -// using. -type BackendType uint8 - -const ( - // BackendTypeUnknown indicates we're using an unknown backend. - BackendTypeUnknown BackendType = iota - // BackendTypeSqlite indicates we're using a SQLite backend. - BackendTypeSqlite - - // BackendTypePostgres indicates we're using a Postgres backend. - BackendTypePostgres + "github.com/lightningnetwork/lnd/sqldb/v2" ) // wrappedTX is a wrapper around a DBTX that also stores the database backend @@ -24,29 +11,24 @@ const ( type wrappedTX struct { DBTX - backendType BackendType + backendType sqldb.BackendType } // Backend returns the type of database backend we're using. -func (q *Queries) Backend() BackendType { +func (q *Queries) Backend() sqldb.BackendType { wtx, ok := q.db.(*wrappedTX) if !ok { // Shouldn't happen unless a new database backend type is added // but not initialized correctly. - return BackendTypeUnknown + return sqldb.BackendTypeUnknown } return wtx.backendType } -// NewSqlite creates a new Queries instance for a SQLite database. -func NewSqlite(db DBTX) *Queries { - return &Queries{db: &wrappedTX{db, BackendTypeSqlite}} -} - -// NewPostgres creates a new Queries instance for a Postgres database. -func NewPostgres(db DBTX) *Queries { - return &Queries{db: &wrappedTX{db, BackendTypePostgres}} +// NewForType creates a new Queries instance for the given database type. +func NewForType(db DBTX, typ sqldb.BackendType) *Queries { + return &Queries{db: &wrappedTX{db, typ}} } // CustomQueries defines a set of custom queries that we define in addition @@ -62,5 +44,5 @@ type CustomQueries interface { arg ListActionsParams) ([]Action, error) // Backend returns the type of the database backend used. - Backend() BackendType + Backend() sqldb.BackendType } diff --git a/db/sqlite.go b/db/sqlite.go index 803362fa8..6f69a7e5b 100644 --- a/db/sqlite.go +++ b/db/sqlite.go @@ -11,6 +11,7 @@ import ( "github.com/golang-migrate/migrate/v4" sqlite_migrate "github.com/golang-migrate/migrate/v4/database/sqlite" "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" _ "modernc.org/sqlite" // Register relevant drivers. ) @@ -132,7 +133,7 @@ func NewSqliteStore(cfg *SqliteConfig) (*SqliteStore, error) { db.SetMaxIdleConns(defaultMaxConns) db.SetConnMaxLifetime(defaultConnMaxLifetime) - queries := sqlc.NewSqlite(db) + queries := sqlc.NewForType(db, sqldb.BackendTypeSqlite) s := &SqliteStore{ cfg: cfg, BaseDB: &BaseDB{ @@ -140,16 +141,7 @@ func NewSqliteStore(cfg *SqliteConfig) (*SqliteStore, error) { Queries: queries, }, } - - // Now that the database is open, populate the database with our set of - // schemas based on our embedded in-memory file system. - if !cfg.SkipMigrations { - if err := s.ExecuteMigrations(s.backupAndMigrate); err != nil { - return nil, fmt.Errorf("error executing migrations: "+ - "%w", err) - } - } - + return s, nil } diff --git a/firewalldb/actions_sql.go b/firewalldb/actions_sql.go index 75c9d0a6d..4d5448313 100644 --- a/firewalldb/actions_sql.go +++ b/firewalldb/actions_sql.go @@ -12,7 +12,7 @@ import ( "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/fn" - "github.com/lightningnetwork/lnd/sqldb" + "github.com/lightningnetwork/lnd/sqldb/v2" ) // SQLAccountQueries is a subset of the sqlc.Queries interface that can be used @@ -167,7 +167,7 @@ func (s *SQLDB) AddAction(ctx context.Context, } return nil - }) + }, sqldb.NoOpReset) if err != nil { return nil, err } @@ -202,7 +202,7 @@ func (s *SQLDB) SetActionState(ctx context.Context, al ActionLocator, Valid: errReason != "", }, }) - }) + }, sqldb.NoOpReset) } // ListActions returns a list of Actions. The query IndexOffset and MaxNum @@ -350,7 +350,7 @@ func (s *SQLDB) ListActions(ctx context.Context, } return nil - }) + }, sqldb.NoOpReset) return actions, lastIndex, uint64(totalCount), err } diff --git a/firewalldb/kvstores_sql.go b/firewalldb/kvstores_sql.go index 0c3df2ddb..6173c3b09 100644 --- a/firewalldb/kvstores_sql.go +++ b/firewalldb/kvstores_sql.go @@ -11,6 +11,7 @@ import ( "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/sqldb/v2" ) // SQLKVStoreQueries is a subset of the sqlc.Queries interface that can be @@ -45,7 +46,7 @@ func (s *SQLDB) DeleteTempKVStores(ctx context.Context) error { return s.db.ExecTx(ctx, &writeTxOpts, func(tx SQLQueries) error { return tx.DeleteAllTempKVStores(ctx) - }) + }, sqldb.NoOpReset) } // GetKVStores constructs a new rules.KVStores in a namespace defined by the diff --git a/firewalldb/sql_store.go b/firewalldb/sql_store.go index f17010f2c..1be887ace 100644 --- a/firewalldb/sql_store.go +++ b/firewalldb/sql_store.go @@ -5,7 +5,9 @@ import ( "database/sql" "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" ) // SQLSessionQueries is a subset of the sqlc.Queries interface that can be used @@ -18,17 +20,20 @@ type SQLSessionQueries interface { // SQLQueries is a subset of the sqlc.Queries interface that can be used to // interact with various firewalldb tables. type SQLQueries interface { + sqldb.BaseQuerier + SQLKVStoreQueries SQLPrivacyPairQueries SQLActionQueries } -// BatchedSQLQueries is a version of the SQLQueries that's capable of batched -// database operations. +// BatchedSQLQueries combines the SQLQueries interface with the BatchedTx +// interface, allowing for multiple queries to be executed in single SQL +// transaction. type BatchedSQLQueries interface { SQLQueries - db.BatchedTx[SQLQueries] + sqldb.BatchedTx[SQLQueries] } // SQLDB represents a storage backend. @@ -38,11 +43,31 @@ type SQLDB struct { db BatchedSQLQueries // BaseDB represents the underlying database connection. - *db.BaseDB + *sqldb.BaseDB clock clock.Clock } +type SQLQueriesExecutor[T sqldb.BaseQuerier] struct { + *sqldb.TransactionExecutor[T] + + SQLQueries +} + +func NewSQLQueriesExecutor(baseDB *sqldb.BaseDB, + queries *sqlc.Queries) *SQLQueriesExecutor[SQLQueries] { + + executor := sqldb.NewTransactionExecutor( + baseDB, func(tx *sql.Tx) SQLQueries { + return queries.WithTx(tx) + }, + ) + return &SQLQueriesExecutor[SQLQueries]{ + TransactionExecutor: executor, + SQLQueries: queries, + } +} + // A compile-time assertion to ensure that SQLDB implements the RulesDB // interface. var _ RulesDB = (*SQLDB)(nil) @@ -53,12 +78,10 @@ var _ ActionDB = (*SQLDB)(nil) // NewSQLDB creates a new SQLStore instance given an open SQLQueries // storage backend. -func NewSQLDB(sqlDB *db.BaseDB, clock clock.Clock) *SQLDB { - executor := db.NewTransactionExecutor( - sqlDB, func(tx *sql.Tx) SQLQueries { - return sqlDB.WithTx(tx) - }, - ) +func NewSQLDB(sqlDB *sqldb.BaseDB, queries *sqlc.Queries, + clock clock.Clock) *SQLDB { + + executor := NewSQLQueriesExecutor(sqlDB, queries) return &SQLDB{ db: executor, @@ -88,7 +111,7 @@ func (e *sqlExecutor[T]) Update(ctx context.Context, var txOpts db.QueriesTxOptions return e.db.ExecTx(ctx, &txOpts, func(queries SQLQueries) error { return fn(ctx, e.wrapTx(queries)) - }) + }, sqldb.NoOpReset) } // View opens a database read transaction and executes the function f with the @@ -104,5 +127,5 @@ func (e *sqlExecutor[T]) View(ctx context.Context, return e.db.ExecTx(ctx, &txOpts, func(queries SQLQueries) error { return fn(ctx, e.wrapTx(queries)) - }) + }, sqldb.NoOpReset) } diff --git a/firewalldb/test_sql.go b/firewalldb/test_sql.go index a412441f8..b7e3d9052 100644 --- a/firewalldb/test_sql.go +++ b/firewalldb/test_sql.go @@ -6,10 +6,12 @@ import ( "testing" "time" + "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/accounts" - "github.com/lightninglabs/lightning-terminal/db" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" ) @@ -55,8 +57,10 @@ func assertEqualActions(t *testing.T, expected, got *Action) { // createStore is a helper function that creates a new SQLDB and ensure that // it is closed when during the test cleanup. -func createStore(t *testing.T, sqlDB *db.BaseDB, clock clock.Clock) *SQLDB { - store := NewSQLDB(sqlDB, clock) +func createStore(t *testing.T, sqlDB *sqldb.BaseDB, clock clock.Clock) *SQLDB { + queries := sqlc.NewForType(sqlDB, sqlDB.BackendType) + + store := NewSQLDB(sqlDB, queries, clock) t.Cleanup(func() { require.NoError(t, store.Close()) }) diff --git a/firewalldb/test_sqlite.go b/firewalldb/test_sqlite.go index 49b956d7d..ab184b5a6 100644 --- a/firewalldb/test_sqlite.go +++ b/firewalldb/test_sqlite.go @@ -7,17 +7,23 @@ import ( "github.com/lightninglabs/lightning-terminal/db" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" ) // NewTestDB is a helper function that creates an BBolt database for testing. func NewTestDB(t *testing.T, clock clock.Clock) FirewallDBs { - return createStore(t, db.NewTestSqliteDB(t).BaseDB, clock) + return createStore( + t, sqldb.NewTestSqliteDB(t, db.LitdMigrationStreams).BaseDB, + clock, + ) } // NewTestDBFromPath is a helper function that creates a new BoltStore with a // connection to an existing BBolt database for testing. -func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) FirewallDBs { - return createStore( - t, db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock, - ) +func NewTestDBFromPath(t *testing.T, dbPath string, + clock clock.Clock) FirewallDBs { + + tDb := sqldb.NewTestSqliteDBFromPath(t, dbPath, db.LitdMigrationStreams) + + return createStore(t, tDb.BaseDB, clock) } diff --git a/session/sql_store.go b/session/sql_store.go index b1d366fe7..26662a574 100644 --- a/session/sql_store.go +++ b/session/sql_store.go @@ -14,6 +14,7 @@ import ( "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/sqldb/v2" "gopkg.in/macaroon-bakery.v2/bakery" "gopkg.in/macaroon.v2" ) @@ -21,6 +22,8 @@ import ( // SQLQueries is a subset of the sqlc.Queries interface that can be used to // interact with session related tables. type SQLQueries interface { + sqldb.BaseQuerier + GetAliasBySessionID(ctx context.Context, id int64) ([]byte, error) GetSessionByID(ctx context.Context, id int64) (sqlc.Session, error) GetSessionsInGroup(ctx context.Context, groupID sql.NullInt64) ([]sqlc.Session, error) @@ -51,12 +54,13 @@ type SQLQueries interface { var _ Store = (*SQLStore)(nil) -// BatchedSQLQueries is a version of the SQLQueries that's capable of batched -// database operations. +// BatchedSQLQueries combines the SQLQueries interface with the BatchedTx +// interface, allowing for multiple queries to be executed in single SQL +// transaction. type BatchedSQLQueries interface { SQLQueries - db.BatchedTx[SQLQueries] + sqldb.BatchedTx[SQLQueries] } // SQLStore represents a storage backend. @@ -66,19 +70,37 @@ type SQLStore struct { db BatchedSQLQueries // BaseDB represents the underlying database connection. - *db.BaseDB + *sqldb.BaseDB clock clock.Clock } -// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries -// storage backend. -func NewSQLStore(sqlDB *db.BaseDB, clock clock.Clock) *SQLStore { - executor := db.NewTransactionExecutor( - sqlDB, func(tx *sql.Tx) SQLQueries { - return sqlDB.WithTx(tx) +type SQLQueriesExecutor[T sqldb.BaseQuerier] struct { + *sqldb.TransactionExecutor[T] + + SQLQueries +} + +func NewSQLQueriesExecutor(baseDB *sqldb.BaseDB, + queries *sqlc.Queries) *SQLQueriesExecutor[SQLQueries] { + + executor := sqldb.NewTransactionExecutor( + baseDB, func(tx *sql.Tx) SQLQueries { + return queries.WithTx(tx) }, ) + return &SQLQueriesExecutor[SQLQueries]{ + TransactionExecutor: executor, + SQLQueries: queries, + } +} + +// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries +// storage backend. +func NewSQLStore(sqlDB *sqldb.BaseDB, queries *sqlc.Queries, + clock clock.Clock) *SQLStore { + + executor := NewSQLQueriesExecutor(sqlDB, queries) return &SQLStore{ db: executor, @@ -281,7 +303,7 @@ func (s *SQLStore) NewSession(ctx context.Context, label string, typ Type, } return nil - }) + }, sqldb.NoOpReset) if err != nil { mappedSQLErr := db.MapSQLError(err) var uniqueConstraintErr *db.ErrSqlUniqueConstraintViolation @@ -325,7 +347,7 @@ func (s *SQLStore) ListSessionsByType(ctx context.Context, t Type) ([]*Session, } return nil - }) + }, sqldb.NoOpReset) return sessions, err } @@ -358,7 +380,7 @@ func (s *SQLStore) ListSessionsByState(ctx context.Context, state State) ( } return nil - }) + }, sqldb.NoOpReset) return sessions, err } @@ -417,7 +439,7 @@ func (s *SQLStore) ShiftState(ctx context.Context, alias ID, dest State) error { State: int16(dest), }, ) - }) + }, sqldb.NoOpReset) } // DeleteReservedSessions deletes all sessions that are in the StateReserved @@ -428,7 +450,7 @@ func (s *SQLStore) DeleteReservedSessions(ctx context.Context) error { var writeTxOpts db.QueriesTxOptions return s.db.ExecTx(ctx, &writeTxOpts, func(db SQLQueries) error { return db.DeleteSessionsWithState(ctx, int16(StateReserved)) - }) + }, sqldb.NoOpReset) } // GetSessionByLocalPub fetches the session with the given local pub key. @@ -458,7 +480,7 @@ func (s *SQLStore) GetSessionByLocalPub(ctx context.Context, } return nil - }) + }, sqldb.NoOpReset) if err != nil { return nil, err } @@ -491,7 +513,7 @@ func (s *SQLStore) ListAllSessions(ctx context.Context) ([]*Session, error) { } return nil - }) + }, sqldb.NoOpReset) return sessions, err } @@ -521,7 +543,7 @@ func (s *SQLStore) UpdateSessionRemotePubKey(ctx context.Context, alias ID, RemotePublicKey: remoteKey, }, ) - }) + }, sqldb.NoOpReset) } // getSqlUnusedAliasAndKeyPair can be used to generate a new, unused, local @@ -576,7 +598,7 @@ func (s *SQLStore) GetSession(ctx context.Context, alias ID) (*Session, error) { } return nil - }) + }, sqldb.NoOpReset) return sess, err } @@ -617,7 +639,7 @@ func (s *SQLStore) GetGroupID(ctx context.Context, sessionID ID) (ID, error) { legacyGroupID, err = IDFromBytes(legacyGroupIDB) return err - }) + }, sqldb.NoOpReset) if err != nil { return ID{}, err } @@ -666,7 +688,7 @@ func (s *SQLStore) GetSessionIDs(ctx context.Context, legacyGroupID ID) ([]ID, } return nil - }) + }, sqldb.NoOpReset) if err != nil { return nil, err } diff --git a/session/test_sql.go b/session/test_sql.go index a83186069..5623c8207 100644 --- a/session/test_sql.go +++ b/session/test_sql.go @@ -6,8 +6,9 @@ import ( "testing" "github.com/lightninglabs/lightning-terminal/accounts" - "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" ) @@ -22,8 +23,12 @@ func NewTestDBWithAccounts(t *testing.T, clock clock.Clock, // createStore is a helper function that creates a new SQLStore and ensure that // it is closed when during the test cleanup. -func createStore(t *testing.T, sqlDB *db.BaseDB, clock clock.Clock) *SQLStore { - store := NewSQLStore(sqlDB, clock) +func createStore(t *testing.T, sqlDB *sqldb.BaseDB, + clock clock.Clock) *SQLStore { + + queries := sqlc.NewForType(sqlDB, sqlDB.BackendType) + + store := NewSQLStore(sqlDB, queries, clock) t.Cleanup(func() { require.NoError(t, store.Close()) }) diff --git a/session/test_sqlite.go b/session/test_sqlite.go index 0ceb0e046..84d946ce2 100644 --- a/session/test_sqlite.go +++ b/session/test_sqlite.go @@ -8,6 +8,7 @@ import ( "github.com/lightninglabs/lightning-terminal/db" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" ) // ErrDBClosed is an error that is returned when a database operation is @@ -16,7 +17,10 @@ var ErrDBClosed = errors.New("database is closed") // NewTestDB is a helper function that creates an SQLStore database for testing. func NewTestDB(t *testing.T, clock clock.Clock) Store { - return createStore(t, db.NewTestSqliteDB(t).BaseDB, clock) + return createStore( + t, sqldb.NewTestSqliteDB(t, db.LitdMigrationStreams).BaseDB, + clock, + ) } // NewTestDBFromPath is a helper function that creates a new SQLStore with a @@ -24,7 +28,7 @@ func NewTestDB(t *testing.T, clock clock.Clock) Store { func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) Store { - return createStore( - t, db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock, - ) + tDb := sqldb.NewTestSqliteDBFromPath(t, dbPath, db.LitdMigrationStreams) + + return createStore(t, tDb.BaseDB, clock) }