diff --git a/accounts/sql_migration_test.go b/accounts/sql_migration_test.go index d1e331e42..6832e42aa 100644 --- a/accounts/sql_migration_test.go +++ b/accounts/sql_migration_test.go @@ -2,18 +2,17 @@ package accounts import ( "context" - "database/sql" "fmt" "testing" "time" - "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" - "github.com/lightningnetwork/lnd/sqldb" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" "golang.org/x/exp/rand" "pgregory.net/rapid" @@ -36,25 +35,18 @@ func TestAccountStoreMigration(t *testing.T) { } makeSQLDB := func(t *testing.T) (*SQLStore, - *db.TransactionExecutor[SQLQueries]) { + *SQLQueriesExecutor[SQLQueries]) { testDBStore := NewTestDB(t, clock) - t.Cleanup(func() { - require.NoError(t, testDBStore.Close()) - }) store, ok := testDBStore.(*SQLStore) require.True(t, ok) baseDB := store.BaseDB - genericExecutor := db.NewTransactionExecutor( - baseDB, func(tx *sql.Tx) SQLQueries { - return baseDB.WithTx(tx) - }, - ) + queries := sqlc.NewForType(baseDB, baseDB.BackendType) - return store, genericExecutor + return store, NewSQLQueriesExecutor(baseDB, queries) } assertMigrationResults := func(t *testing.T, sqlStore *SQLStore, @@ -346,7 +338,7 @@ func TestAccountStoreMigration(t *testing.T) { return MigrateAccountStoreToSQL( ctx, kvStore, tx, ) - }, + }, sqldb.NoOpReset, ) require.NoError(t, err) diff --git a/accounts/store_sql.go b/accounts/store_sql.go index 830f16587..c7e8ab070 100644 --- a/accounts/store_sql.go +++ b/accounts/store_sql.go @@ -16,6 +16,7 @@ import ( "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/sqldb/v2" ) const ( @@ 
-33,6 +34,8 @@ const ( // //nolint:lll type SQLQueries interface { + sqldb.BaseQuerier + AddAccountInvoice(ctx context.Context, arg sqlc.AddAccountInvoiceParams) error DeleteAccount(ctx context.Context, id int64) error DeleteAccountPayment(ctx context.Context, arg sqlc.DeleteAccountPaymentParams) error @@ -53,12 +56,13 @@ type SQLQueries interface { GetAccountInvoice(ctx context.Context, arg sqlc.GetAccountInvoiceParams) (sqlc.AccountInvoice, error) } -// BatchedSQLQueries is a version of the SQLQueries that's capable -// of batched database operations. +// BatchedSQLQueries combines the SQLQueries interface with the BatchedTx +// interface, allowing for multiple queries to be executed in single SQL +// transaction. type BatchedSQLQueries interface { SQLQueries - db.BatchedTx[SQLQueries] + sqldb.BatchedTx[SQLQueries] } // SQLStore represents a storage backend. @@ -68,19 +72,37 @@ type SQLStore struct { db BatchedSQLQueries // BaseDB represents the underlying database connection. - *db.BaseDB + *sqldb.BaseDB clock clock.Clock } -// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries -// storage backend. -func NewSQLStore(sqlDB *db.BaseDB, clock clock.Clock) *SQLStore { - executor := db.NewTransactionExecutor( - sqlDB, func(tx *sql.Tx) SQLQueries { - return sqlDB.WithTx(tx) +type SQLQueriesExecutor[T sqldb.BaseQuerier] struct { + *sqldb.TransactionExecutor[T] + + SQLQueries +} + +func NewSQLQueriesExecutor(baseDB *sqldb.BaseDB, + queries *sqlc.Queries) *SQLQueriesExecutor[SQLQueries] { + + executor := sqldb.NewTransactionExecutor( + baseDB, func(tx *sql.Tx) SQLQueries { + return queries.WithTx(tx) }, ) + return &SQLQueriesExecutor[SQLQueries]{ + TransactionExecutor: executor, + SQLQueries: queries, + } +} + +// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries +// storage backend. 
+func NewSQLStore(sqlDB *sqldb.BaseDB, queries *sqlc.Queries, + clock clock.Clock) *SQLStore { + + executor := NewSQLQueriesExecutor(sqlDB, queries) return &SQLStore{ db: executor, @@ -157,7 +179,7 @@ func (s *SQLStore) NewAccount(ctx context.Context, balance lnwire.MilliSatoshi, } return nil - }) + }, sqldb.NoOpReset) if err != nil { return nil, err } @@ -299,7 +321,7 @@ func (s *SQLStore) AddAccountInvoice(ctx context.Context, alias AccountID, } return s.markAccountUpdated(ctx, db, acctID) - }) + }, sqldb.NoOpReset) } func getAccountIDByAlias(ctx context.Context, db SQLQueries, alias AccountID) ( @@ -377,7 +399,7 @@ func (s *SQLStore) UpdateAccountBalanceAndExpiry(ctx context.Context, } return s.markAccountUpdated(ctx, db, id) - }) + }, sqldb.NoOpReset) } // CreditAccount increases the balance of the account with the given alias by @@ -412,7 +434,7 @@ func (s *SQLStore) CreditAccount(ctx context.Context, alias AccountID, } return s.markAccountUpdated(ctx, db, id) - }) + }, sqldb.NoOpReset) } // DebitAccount decreases the balance of the account with the given alias by the @@ -453,7 +475,7 @@ func (s *SQLStore) DebitAccount(ctx context.Context, alias AccountID, } return s.markAccountUpdated(ctx, db, id) - }) + }, sqldb.NoOpReset) } // Account retrieves an account from the SQL store and un-marshals it. 
If the @@ -475,7 +497,7 @@ func (s *SQLStore) Account(ctx context.Context, alias AccountID) ( account, err = getAndMarshalAccount(ctx, db, id) return err - }) + }, sqldb.NoOpReset) return account, err } @@ -507,7 +529,7 @@ func (s *SQLStore) Accounts(ctx context.Context) ([]*OffChainBalanceAccount, } return nil - }) + }, sqldb.NoOpReset) return accounts, err } @@ -524,7 +546,7 @@ func (s *SQLStore) RemoveAccount(ctx context.Context, alias AccountID) error { } return db.DeleteAccount(ctx, id) - }) + }, sqldb.NoOpReset) } // UpsertAccountPayment updates or inserts a payment entry for the given @@ -634,7 +656,7 @@ func (s *SQLStore) UpsertAccountPayment(ctx context.Context, alias AccountID, } return s.markAccountUpdated(ctx, db, id) - }) + }, sqldb.NoOpReset) } // DeleteAccountPayment removes a payment entry from the account with the given @@ -677,7 +699,7 @@ func (s *SQLStore) DeleteAccountPayment(ctx context.Context, alias AccountID, } return s.markAccountUpdated(ctx, db, id) - }) + }, sqldb.NoOpReset) } // LastIndexes returns the last invoice add and settle index or @@ -704,7 +726,7 @@ func (s *SQLStore) LastIndexes(ctx context.Context) (uint64, uint64, error) { } return err - }) + }, sqldb.NoOpReset) return uint64(addIndex), uint64(settleIndex), err } @@ -729,7 +751,7 @@ func (s *SQLStore) StoreLastIndexes(ctx context.Context, addIndex, Name: settleIndexName, Value: int64(settleIndex), }) - }) + }, sqldb.NoOpReset) } // Close closes the underlying store. diff --git a/accounts/test_postgres.go b/accounts/test_postgres.go index 609eeb608..16665030d 100644 --- a/accounts/test_postgres.go +++ b/accounts/test_postgres.go @@ -16,7 +16,7 @@ var ErrDBClosed = errors.New("database is closed") // NewTestDB is a helper function that creates an SQLStore database for testing. 
func NewTestDB(t *testing.T, clock clock.Clock) Store { - return NewSQLStore(db.NewTestPostgresDB(t).BaseDB, clock) + return createStore(t, db.NewTestPostgresDB(t).BaseDB, clock) } // NewTestDBFromPath is a helper function that creates a new SQLStore with a @@ -24,5 +24,5 @@ func NewTestDB(t *testing.T, clock clock.Clock) Store { func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) Store { - return NewSQLStore(db.NewTestPostgresDB(t).BaseDB, clock) + return createStore(t, db.NewTestPostgresDB(t).BaseDB, clock) } diff --git a/accounts/test_sql.go b/accounts/test_sql.go new file mode 100644 index 000000000..ca2f43d6f --- /dev/null +++ b/accounts/test_sql.go @@ -0,0 +1,27 @@ +//go:build test_db_postgres || test_db_sqlite + +package accounts + +import ( + "testing" + + "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" + "github.com/stretchr/testify/require" +) + +// createStore is a helper function that creates a new SQLStore and ensure that +// it is closed when during the test cleanup. +func createStore(t *testing.T, sqlDB *sqldb.BaseDB, + clock clock.Clock) *SQLStore { + + queries := sqlc.NewForType(sqlDB, sqlDB.BackendType) + + store := NewSQLStore(sqlDB, queries, clock) + t.Cleanup(func() { + require.NoError(t, store.Close()) + }) + + return store +} diff --git a/accounts/test_sqlite.go b/accounts/test_sqlite.go index 0dd042a28..a31f990a6 100644 --- a/accounts/test_sqlite.go +++ b/accounts/test_sqlite.go @@ -8,6 +8,7 @@ import ( "github.com/lightninglabs/lightning-terminal/db" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" ) // ErrDBClosed is an error that is returned when a database operation is @@ -16,7 +17,10 @@ var ErrDBClosed = errors.New("database is closed") // NewTestDB is a helper function that creates an SQLStore database for testing. 
func NewTestDB(t *testing.T, clock clock.Clock) Store { - return NewSQLStore(db.NewTestSqliteDB(t).BaseDB, clock) + return createStore( + t, sqldb.NewTestSqliteDB(t, db.LitdMigrationStreams).BaseDB, + clock, + ) } // NewTestDBFromPath is a helper function that creates a new SQLStore with a @@ -24,7 +28,7 @@ func NewTestDB(t *testing.T, clock clock.Clock) Store { func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) Store { - return NewSQLStore( - db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock, - ) + tDb := sqldb.NewTestSqliteDBFromPath(t, dbPath, db.LitdMigrationStreams) + + return createStore(t, tDb.BaseDB, clock) } diff --git a/config_dev.go b/config_dev.go index 90b8b290f..5b6185aa0 100644 --- a/config_dev.go +++ b/config_dev.go @@ -4,6 +4,7 @@ package terminal import ( "fmt" + "github.com/lightninglabs/lightning-terminal/db/sqlc" "path/filepath" "github.com/lightninglabs/lightning-terminal/accounts" @@ -11,6 +12,7 @@ import ( "github.com/lightninglabs/lightning-terminal/firewalldb" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" ) const ( @@ -101,14 +103,36 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) { return stores, err } - sqlStore, err := db.NewSqliteStore(cfg.Sqlite) + sqlStore, err := sqldb.NewSqliteStore(&sqldb.SqliteConfig{ + SkipMigrations: cfg.Sqlite.SkipMigrations, + SkipMigrationDbBackup: cfg.Sqlite.SkipMigrationDbBackup, + }, cfg.Sqlite.DatabaseFileName) if err != nil { return stores, err } - acctStore := accounts.NewSQLStore(sqlStore.BaseDB, clock) - sessStore := session.NewSQLStore(sqlStore.BaseDB, clock) - firewallStore := firewalldb.NewSQLDB(sqlStore.BaseDB, clock) + if !cfg.Sqlite.SkipMigrations { + err = sqldb.ApplyAllMigrations( + sqlStore, db.LitdMigrationStreams, + ) + if err != nil { + return stores, fmt.Errorf("error applying "+ + "migrations to SQLlite store: %w", err, + ) + } + } + + queries := 
sqlc.NewForType(sqlStore, sqlStore.BackendType) + + acctStore := accounts.NewSQLStore( + sqlStore.BaseDB, queries, clock, + ) + sessStore := session.NewSQLStore( + sqlStore.BaseDB, queries, clock, + ) + firewallStore := firewalldb.NewSQLDB( + sqlStore.BaseDB, queries, clock, + ) stores.accounts = acctStore stores.sessions = sessStore @@ -116,14 +140,41 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) { stores.closeFns["sqlite"] = sqlStore.BaseDB.Close case DatabaseBackendPostgres: - sqlStore, err := db.NewPostgresStore(cfg.Postgres) + sqlStore, err := sqldb.NewPostgresStore(&sqldb.PostgresConfig{ + Dsn: cfg.Postgres.DSN(false), + MaxOpenConnections: cfg.Postgres.MaxOpenConnections, + MaxIdleConnections: cfg.Postgres.MaxIdleConnections, + ConnMaxLifetime: cfg.Postgres.ConnMaxLifetime, + ConnMaxIdleTime: cfg.Postgres.ConnMaxIdleTime, + RequireSSL: cfg.Postgres.RequireSSL, + SkipMigrations: cfg.Postgres.SkipMigrations, + }) if err != nil { return stores, err } - acctStore := accounts.NewSQLStore(sqlStore.BaseDB, clock) - sessStore := session.NewSQLStore(sqlStore.BaseDB, clock) - firewallStore := firewalldb.NewSQLDB(sqlStore.BaseDB, clock) + if !cfg.Sqlite.SkipMigrations { + err = sqldb.ApplyAllMigrations( + sqlStore, db.LitdMigrationStreams, + ) + if err != nil { + return stores, fmt.Errorf("error applying "+ + "migrations to Postgres store: %w", err, + ) + } + } + + queries := sqlc.NewForType(sqlStore, sqlStore.BackendType) + + acctStore := accounts.NewSQLStore( + sqlStore.BaseDB, queries, clock, + ) + sessStore := session.NewSQLStore( + sqlStore.BaseDB, queries, clock, + ) + firewallStore := firewalldb.NewSQLDB( + sqlStore.BaseDB, queries, clock, + ) stores.accounts = acctStore stores.sessions = sessStore diff --git a/db/interfaces.go b/db/interfaces.go index ba64520b4..bb39df9ea 100644 --- a/db/interfaces.go +++ b/db/interfaces.go @@ -8,6 +8,7 @@ import ( "time" "github.com/lightninglabs/lightning-terminal/db/sqlc" + 
"github.com/lightningnetwork/lnd/sqldb/v2" ) var ( @@ -56,7 +57,7 @@ type BatchedTx[Q any] interface { txBody func(Q) error) error // Backend returns the type of the database backend used. - Backend() sqlc.BackendType + Backend() sqldb.BackendType } // Tx represents a database transaction that can be committed or rolled back. @@ -277,7 +278,7 @@ func (t *TransactionExecutor[Q]) ExecTx(ctx context.Context, } // Backend returns the type of the database backend used. -func (t *TransactionExecutor[Q]) Backend() sqlc.BackendType { +func (t *TransactionExecutor[Q]) Backend() sqldb.BackendType { return t.BatchedQuerier.Backend() } @@ -301,7 +302,7 @@ func (s *BaseDB) BeginTx(ctx context.Context, opts TxOptions) (*sql.Tx, error) { } // Backend returns the type of the database backend used. -func (s *BaseDB) Backend() sqlc.BackendType { +func (s *BaseDB) Backend() sqldb.BackendType { return s.Queries.Backend() } diff --git a/db/post_migration_checks.go b/db/post_migration_checks.go new file mode 100644 index 000000000..106ac8ec7 --- /dev/null +++ b/db/post_migration_checks.go @@ -0,0 +1,81 @@ +package db + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/golang-migrate/migrate/v4" + "github.com/golang-migrate/migrate/v4/database" + "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightningnetwork/lnd/sqldb/v2" +) + +// postMigrationCheck is a function type for a function that performs a +// post-migration check on the database. +type postMigrationCheck func(context.Context, *sqlc.Queries) error + +var ( + // postMigrationChecks is a map of functions that are run after the + // database migration with the version specified in the key has been + // applied. These functions are used to perform additional checks on the + // database state that are not fully expressible in SQL. 
+ postMigrationChecks = map[uint]postMigrationCheck{} +) + +// makePostStepCallbacks turns the post migration checks into a map of post +// step callbacks that can be used with the migrate package. The keys of the map +// are the migration versions, and the values are the callbacks that will be +// executed after the migration with the corresponding version is applied. +func makePostStepCallbacks(db *sqldb.BaseDB, + c map[uint]postMigrationCheck) map[uint]migrate.PostStepCallback { + + queries := sqlc.NewForType(db, db.BackendType) + executor := sqldb.NewTransactionExecutor( + db, func(tx *sql.Tx) *sqlc.Queries { + return queries.WithTx(tx) + }, + ) + + var ( + ctx = context.Background() + postStepCallbacks = make(map[uint]migrate.PostStepCallback) + ) + for version, check := range c { + runCheck := func(m *migrate.Migration, q *sqlc.Queries) error { + log.Infof("Running post-migration check for version %d", + version) + start := time.Now() + + err := check(ctx, q) + if err != nil { + return fmt.Errorf("post-migration "+ + "check failed for version %d: "+ + "%w", version, err) + } + + log.Infof("Post-migration check for version %d "+ + "completed in %v", version, time.Since(start)) + + return nil + } + + // We ignore the actual driver that's being returned here, since + // we use migrate.NewWithInstance() to create the migration + // instance from our already instantiated database backend that + // is also passed into this function. 
+ postStepCallbacks[version] = func(m *migrate.Migration, + _ database.Driver) error { + + return executor.ExecTx( + ctx, sqldb.NewWriteTx(), + func(q *sqlc.Queries) error { + return runCheck(m, q) + }, sqldb.NoOpReset, + ) + } + } + + return postStepCallbacks +} diff --git a/db/postgres.go b/db/postgres.go index 16e41dc09..962629be6 100644 --- a/db/postgres.go +++ b/db/postgres.go @@ -9,6 +9,7 @@ import ( postgres_migrate "github.com/golang-migrate/migrate/v4/database/postgres" _ "github.com/golang-migrate/migrate/v4/source/file" "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" ) @@ -119,7 +120,7 @@ func NewPostgresStore(cfg *PostgresConfig) (*PostgresStore, error) { rawDb.SetConnMaxLifetime(connMaxLifetime) rawDb.SetConnMaxIdleTime(connMaxIdleTime) - queries := sqlc.NewPostgres(rawDb) + queries := sqlc.NewForType(rawDb, sqldb.BackendTypePostgres) s := &PostgresStore{ cfg: cfg, BaseDB: &BaseDB{ @@ -128,15 +129,6 @@ func NewPostgresStore(cfg *PostgresConfig) (*PostgresStore, error) { }, } - // Now that the database is open, populate the database with our set of - // schemas based on our embedded in-memory file system. - if !cfg.SkipMigrations { - if err := s.ExecuteMigrations(TargetLatest); err != nil { - return nil, fmt.Errorf("error executing migrations: "+ - "%w", err) - } - } - return s, nil } @@ -166,20 +158,17 @@ func (s *PostgresStore) ExecuteMigrations(target MigrationTarget, // NewTestPostgresDB is a helper function that creates a Postgres database for // testing. 
-func NewTestPostgresDB(t *testing.T) *PostgresStore { +func NewTestPostgresDB(t *testing.T) *sqldb.PostgresStore { t.Helper() t.Logf("Creating new Postgres DB for testing") - sqlFixture := NewTestPgFixture(t, DefaultPostgresFixtureLifetime, true) - store, err := NewPostgresStore(sqlFixture.GetConfig()) - require.NoError(t, err) - + sqlFixture := sqldb.NewTestPgFixture(t, DefaultPostgresFixtureLifetime) t.Cleanup(func() { sqlFixture.TearDown(t) }) - return store + return sqldb.NewTestPostgresDB(t, sqlFixture, LitdMigrationStreams) } // NewTestPostgresDBWithVersion is a helper function that creates a Postgres diff --git a/db/sql_migrations.go b/db/sql_migrations.go new file mode 100644 index 000000000..4b492ca5a --- /dev/null +++ b/db/sql_migrations.go @@ -0,0 +1,31 @@ +package db + +import ( + "github.com/golang-migrate/migrate/v4" + "github.com/golang-migrate/migrate/v4/database/pgx/v5" + "github.com/lightningnetwork/lnd/sqldb/v2" +) + +var ( + LitdMigrationStream = sqldb.MigrationStream{ + MigrateTableName: pgx.DefaultMigrationsTable, + SQLFileDirectory: "sqlc/migrations", + Schemas: sqlSchemas, + + // LatestMigrationVersion is the latest migration version of the + // database. This is used to implement downgrade protection for + // the daemon. + // + // NOTE: This MUST be updated when a new migration is added. + LatestMigrationVersion: LatestMigrationVersion, + + MakePostMigrationChecks: func( + db *sqldb.BaseDB) (map[uint]migrate.PostStepCallback, + error) { + + return makePostStepCallbacks(db, postMigrationChecks), + nil + }, + } + LitdMigrationStreams = []sqldb.MigrationStream{LitdMigrationStream} +) diff --git a/db/sqlc/db_custom.go b/db/sqlc/db_custom.go index f4bf7f611..af556eae7 100644 --- a/db/sqlc/db_custom.go +++ b/db/sqlc/db_custom.go @@ -2,21 +2,8 @@ package sqlc import ( "context" -) - -// BackendType is an enum that represents the type of database backend we're -// using. 
-type BackendType uint8 - -const ( - // BackendTypeUnknown indicates we're using an unknown backend. - BackendTypeUnknown BackendType = iota - // BackendTypeSqlite indicates we're using a SQLite backend. - BackendTypeSqlite - - // BackendTypePostgres indicates we're using a Postgres backend. - BackendTypePostgres + "github.com/lightningnetwork/lnd/sqldb/v2" ) // wrappedTX is a wrapper around a DBTX that also stores the database backend @@ -24,29 +11,24 @@ const ( type wrappedTX struct { DBTX - backendType BackendType + backendType sqldb.BackendType } // Backend returns the type of database backend we're using. -func (q *Queries) Backend() BackendType { +func (q *Queries) Backend() sqldb.BackendType { wtx, ok := q.db.(*wrappedTX) if !ok { // Shouldn't happen unless a new database backend type is added // but not initialized correctly. - return BackendTypeUnknown + return sqldb.BackendTypeUnknown } return wtx.backendType } -// NewSqlite creates a new Queries instance for a SQLite database. -func NewSqlite(db DBTX) *Queries { - return &Queries{db: &wrappedTX{db, BackendTypeSqlite}} -} - -// NewPostgres creates a new Queries instance for a Postgres database. -func NewPostgres(db DBTX) *Queries { - return &Queries{db: &wrappedTX{db, BackendTypePostgres}} +// NewForType creates a new Queries instance for the given database type. +func NewForType(db DBTX, typ sqldb.BackendType) *Queries { + return &Queries{db: &wrappedTX{db, typ}} } // CustomQueries defines a set of custom queries that we define in addition @@ -62,5 +44,5 @@ type CustomQueries interface { arg ListActionsParams) ([]Action, error) // Backend returns the type of the database backend used. 
- Backend() BackendType + Backend() sqldb.BackendType } diff --git a/db/sqlc/kvstores.sql.go b/db/sqlc/kvstores.sql.go index b2e6632f4..c0949d173 100644 --- a/db/sqlc/kvstores.sql.go +++ b/db/sqlc/kvstores.sql.go @@ -257,6 +257,42 @@ func (q *Queries) InsertKVStoreRecord(ctx context.Context, arg InsertKVStoreReco return err } +const listAllKVStoresRecords = `-- name: ListAllKVStoresRecords :many +SELECT id, perm, rule_id, session_id, feature_id, entry_key, value +FROM kvstores +` + +func (q *Queries) ListAllKVStoresRecords(ctx context.Context) ([]Kvstore, error) { + rows, err := q.db.QueryContext(ctx, listAllKVStoresRecords) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Kvstore + for rows.Next() { + var i Kvstore + if err := rows.Scan( + &i.ID, + &i.Perm, + &i.RuleID, + &i.SessionID, + &i.FeatureID, + &i.EntryKey, + &i.Value, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const updateFeatureKVStoreRecord = `-- name: UpdateFeatureKVStoreRecord :exec UPDATE kvstores SET value = $1 diff --git a/db/sqlc/querier.go b/db/sqlc/querier.go index df89d0898..117a1fbc5 100644 --- a/db/sqlc/querier.go +++ b/db/sqlc/querier.go @@ -57,6 +57,7 @@ type Querier interface { ListAccountInvoices(ctx context.Context, accountID int64) ([]AccountInvoice, error) ListAccountPayments(ctx context.Context, accountID int64) ([]AccountPayment, error) ListAllAccounts(ctx context.Context) ([]Account, error) + ListAllKVStoresRecords(ctx context.Context) ([]Kvstore, error) ListSessions(ctx context.Context) ([]Session, error) ListSessionsByState(ctx context.Context, state int16) ([]Session, error) ListSessionsByType(ctx context.Context, type_ int16) ([]Session, error) diff --git a/db/sqlc/queries/kvstores.sql b/db/sqlc/queries/kvstores.sql index 7963e46a4..1ebfe3b0d 100644 --- 
a/db/sqlc/queries/kvstores.sql +++ b/db/sqlc/queries/kvstores.sql @@ -28,6 +28,10 @@ VALUES ($1, $2, $3, $4, $5, $6); DELETE FROM kvstores WHERE perm = false; +-- name: ListAllKVStoresRecords :many +SELECT * +FROM kvstores; + -- name: GetGlobalKVStoreRecord :one SELECT value FROM kvstores diff --git a/db/sqlite.go b/db/sqlite.go index 803362fa8..6f69a7e5b 100644 --- a/db/sqlite.go +++ b/db/sqlite.go @@ -11,6 +11,7 @@ import ( "github.com/golang-migrate/migrate/v4" sqlite_migrate "github.com/golang-migrate/migrate/v4/database/sqlite" "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" _ "modernc.org/sqlite" // Register relevant drivers. ) @@ -132,7 +133,7 @@ func NewSqliteStore(cfg *SqliteConfig) (*SqliteStore, error) { db.SetMaxIdleConns(defaultMaxConns) db.SetConnMaxLifetime(defaultConnMaxLifetime) - queries := sqlc.NewSqlite(db) + queries := sqlc.NewForType(db, sqldb.BackendTypeSqlite) s := &SqliteStore{ cfg: cfg, BaseDB: &BaseDB{ @@ -140,16 +141,7 @@ func NewSqliteStore(cfg *SqliteConfig) (*SqliteStore, error) { Queries: queries, }, } - - // Now that the database is open, populate the database with our set of - // schemas based on our embedded in-memory file system. 
- if !cfg.SkipMigrations { - if err := s.ExecuteMigrations(s.backupAndMigrate); err != nil { - return nil, fmt.Errorf("error executing migrations: "+ - "%w", err) - } - } - + return s, nil } diff --git a/firewalldb/actions_sql.go b/firewalldb/actions_sql.go index 75c9d0a6d..4d5448313 100644 --- a/firewalldb/actions_sql.go +++ b/firewalldb/actions_sql.go @@ -12,7 +12,7 @@ import ( "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/fn" - "github.com/lightningnetwork/lnd/sqldb" + "github.com/lightningnetwork/lnd/sqldb/v2" ) // SQLAccountQueries is a subset of the sqlc.Queries interface that can be used @@ -167,7 +167,7 @@ func (s *SQLDB) AddAction(ctx context.Context, } return nil - }) + }, sqldb.NoOpReset) if err != nil { return nil, err } @@ -202,7 +202,7 @@ func (s *SQLDB) SetActionState(ctx context.Context, al ActionLocator, Valid: errReason != "", }, }) - }) + }, sqldb.NoOpReset) } // ListActions returns a list of Actions. The query IndexOffset and MaxNum @@ -350,7 +350,7 @@ func (s *SQLDB) ListActions(ctx context.Context, } return nil - }) + }, sqldb.NoOpReset) return actions, lastIndex, uint64(totalCount), err } diff --git a/firewalldb/actions_test.go b/firewalldb/actions_test.go index c27e53e96..69990c1da 100644 --- a/firewalldb/actions_test.go +++ b/firewalldb/actions_test.go @@ -28,9 +28,6 @@ func TestActionStorage(t *testing.T) { sessDB := session.NewTestDBWithAccounts(t, clock, accountsDB) db := NewTestDBWithSessionsAndAccounts(t, sessDB, accountsDB, clock) - t.Cleanup(func() { - _ = db.Close() - }) // Assert that attempting to add an action for a session that does not // exist returns an error. @@ -198,9 +195,6 @@ func TestListActions(t *testing.T) { sessDB := session.NewTestDB(t, clock) db := NewTestDBWithSessions(t, sessDB, clock) - t.Cleanup(func() { - _ = db.Close() - }) // Add 2 sessions that we can reference. 
sess1, err := sessDB.NewSession( @@ -466,9 +460,6 @@ func TestListGroupActions(t *testing.T) { } db := NewTestDBWithSessions(t, sessDB, clock) - t.Cleanup(func() { - _ = db.Close() - }) // There should not be any actions in group 1 yet. al, _, _, err := db.ListActions(ctx, nil, WithActionGroupID(group1)) diff --git a/firewalldb/db.go b/firewalldb/db.go index b8d9ed06f..a8349a538 100644 --- a/firewalldb/db.go +++ b/firewalldb/db.go @@ -14,29 +14,21 @@ var ( ErrNoSuchKeyFound = fmt.Errorf("no such key found") ) -// firewallDBs is an interface that groups the RulesDB and PrivacyMapper -// interfaces. -type firewallDBs interface { - RulesDB - PrivacyMapper - ActionDB -} - // DB manages the firewall rules database. type DB struct { started sync.Once stopped sync.Once - firewallDBs + FirewallDBs cancel fn.Option[context.CancelFunc] } // NewDB creates a new firewall database. For now, it only contains the // underlying rules' and privacy mapper databases. -func NewDB(dbs firewallDBs) *DB { +func NewDB(dbs FirewallDBs) *DB { return &DB{ - firewallDBs: dbs, + FirewallDBs: dbs, } } diff --git a/firewalldb/interface.go b/firewalldb/interface.go index 5ee729e91..c2955bdc6 100644 --- a/firewalldb/interface.go +++ b/firewalldb/interface.go @@ -134,3 +134,11 @@ type ActionDB interface { // and feature name. GetActionsReadDB(groupID session.ID, featureName string) ActionsReadDB } + +// FirewallDBs is an interface that groups the RulesDB, PrivacyMapper and +// ActionDB interfaces. +type FirewallDBs interface { + RulesDB + PrivacyMapper + ActionDB +} diff --git a/firewalldb/kvstores_kvdb.go b/firewalldb/kvstores_kvdb.go index 51721d475..d1e8e35a6 100644 --- a/firewalldb/kvstores_kvdb.go +++ b/firewalldb/kvstores_kvdb.go @@ -16,13 +16,13 @@ the temporary store changes instead of just keeping an in-memory store is that we can then guarantee atomicity if changes are made to both the permanent and temporary stores. 
-rules -> perm -> rule-name -> global -> {k:v} - -> sessions -> group ID -> session-kv-store -> {k:v} - -> feature-kv-stores -> feature-name -> {k:v} +"rules" -> "perm" -> rule-name -> "global" -> {k:v} + "session-kv-store" -> group ID -> {k:v} + -> "feature-kv-stores" -> feature-name -> {k:v} - -> temp -> rule-name -> global -> {k:v} - -> sessions -> group ID -> session-kv-store -> {k:v} - -> feature-kv-stores -> feature-name -> {k:v} + -> "temp" -> rule-name -> "global" -> {k:v} + "session-kv-store" -> group ID -> {k:v} + -> "feature-kv-stores" -> feature-name -> {k:v} */ var ( diff --git a/firewalldb/kvstores_sql.go b/firewalldb/kvstores_sql.go index 0c3df2ddb..6173c3b09 100644 --- a/firewalldb/kvstores_sql.go +++ b/firewalldb/kvstores_sql.go @@ -11,6 +11,7 @@ import ( "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/sqldb/v2" ) // SQLKVStoreQueries is a subset of the sqlc.Queries interface that can be @@ -45,7 +46,7 @@ func (s *SQLDB) DeleteTempKVStores(ctx context.Context) error { return s.db.ExecTx(ctx, &writeTxOpts, func(tx SQLQueries) error { return tx.DeleteAllTempKVStores(ctx) - }) + }, sqldb.NoOpReset) } // GetKVStores constructs a new rules.KVStores in a namespace defined by the diff --git a/firewalldb/sql_migration.go b/firewalldb/sql_migration.go new file mode 100644 index 000000000..092b61c8e --- /dev/null +++ b/firewalldb/sql_migration.go @@ -0,0 +1,486 @@ +package firewalldb + +import ( + "bytes" + "context" + "database/sql" + "errors" + "fmt" + "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightningnetwork/lnd/sqldb" + "go.etcd.io/bbolt" +) + +// kvParams is a type alias for the InsertKVStoreRecordParams, to shorten the +// line length in the migration code. 
+type kvParams = sqlc.InsertKVStoreRecordParams + +// MigrateFirewallDBToSQL runs the migration of the firwalldb stores from the +// bbolt database to a SQL database. The migration is done in a single +// transaction to ensure that all rows in the stores are migrated or none at +// all. +// +// Note that this migration currently only migrates the kvstores, but will be +// extended in the future to also migrate the privacy mapper and action stores. +// +// NOTE: As sessions may contain linked sessions and accounts, the sessions and +// accounts sql migration MUST be run prior to this migration. +func MigrateFirewallDBToSQL(ctx context.Context, kvStore *bbolt.DB, + tx SQLQueries) error { + + log.Infof("Starting migration of the rules DB to SQL") + + err := migrateKVStoresDBToSQL(ctx, kvStore, tx) + if err != nil { + return err + } + + log.Infof("The rules DB has been migrated from KV to SQL.") + + // TODO(viktor): Add migration for the privacy mapper and the action + // stores. + + return nil +} + +// migrateKVStoresDBToSQL runs the migration of all KV stores from the KV +// database to the SQL database. The function also asserts that the +// migrated values match the original values in the KV store. +// See the illustration in the firwalldb/kvstores_kvdb.go file to understand +// the structure of the KV stores, and why we process the buckets in the +// order we do. +// Note that this function and the subsequent functions are intentionally +// designed to loop over all buckets and values that exist in the KV store, +// so that we are sure that we actually find all stores and values that +// exist in the KV store, and can be sure that the kv store actually follows +// the expected structure. +func migrateKVStoresDBToSQL(ctx context.Context, kvStore *bbolt.DB, + sqlTx SQLQueries) error { + + log.Infof("Starting migration of the KV stores to SQL") + + // allParams will hold all the kvParams that are inserted into the + // SQL database during the migration. 
+ var allParams []kvParams + + err := kvStore.View(func(kvTx *bbolt.Tx) error { + for _, perm := range []bool{true, false} { + mainBucket, err := getMainBucket(kvTx, false, perm) + if err != nil { + return err + } + + if mainBucket == nil { + // If the mainBucket doesn't exist, there are no + // records to migrate under that bucket, + // therefore we don't error, and just proceed + // to not migrate any records under that bucket. + continue + } + + err = mainBucket.ForEach(func(k, v []byte) error { + if v != nil { + return errors.New("expected only " + + "buckets under main bucket") + } + + ruleName := k + ruleNameBucket := mainBucket.Bucket(k) + if ruleNameBucket == nil { + return fmt.Errorf("rule bucket %s "+ + "not found", string(k)) + } + + ruleId, err := sqlTx.GetOrInsertRuleID( + ctx, string(ruleName), + ) + if err != nil { + return err + } + + params, err := processRuleBucket( + ctx, sqlTx, perm, ruleId, + ruleNameBucket, + ) + if err != nil { + return err + } + + allParams = append(allParams, params...) + + return nil + }) + if err != nil { + return err + } + } + + return nil + }) + if err != nil { + return err + } + + // After the migration is done, we validate that all inserted kvParams + // can match the original values in the KV store. Note that this is done + // after all values have been inserted, in order to ensure that the + // migration doesn't overwrite any values after they were inserted. 
+ for _, param := range allParams { + switch { + case param.FeatureID.Valid && param.SessionID.Valid: + migratedValue, err := sqlTx.GetFeatureKVStoreRecord( + ctx, + sqlc.GetFeatureKVStoreRecordParams{ + Key: param.EntryKey, + Perm: param.Perm, + RuleID: param.RuleID, + SessionID: param.SessionID, + FeatureID: param.FeatureID, + }, + ) + if err != nil { + return fmt.Errorf("retreiving of migrated "+ + "feature specific kv store record "+ + "failed %w", err) + } + + if !bytes.Equal(migratedValue, param.Value) { + return fmt.Errorf("migrated feature specific "+ + "kv record value %x does not match "+ + "original value %x", migratedValue, + param.Value) + } + + case param.SessionID.Valid: + migratedValue, err := sqlTx.GetSessionKVStoreRecord( + ctx, + sqlc.GetSessionKVStoreRecordParams{ + Key: param.EntryKey, + Perm: param.Perm, + RuleID: param.RuleID, + SessionID: param.SessionID, + }, + ) + if err != nil { + return fmt.Errorf("retreiving of migrated "+ + "session wide kv store record "+ + "failed %w", err) + } + + if !bytes.Equal(migratedValue, param.Value) { + return fmt.Errorf("migrated session wide kv "+ + "record value %x does not match "+ + "original value %x", migratedValue, + param.Value) + } + + case !param.FeatureID.Valid && !param.SessionID.Valid: + migratedValue, err := sqlTx.GetGlobalKVStoreRecord( + ctx, + sqlc.GetGlobalKVStoreRecordParams{ + Key: param.EntryKey, + Perm: param.Perm, + RuleID: param.RuleID, + }, + ) + if err != nil { + return fmt.Errorf("retreiving of migrated "+ + "global kv store record failed %w", err) + } + + if !bytes.Equal(migratedValue, param.Value) { + return fmt.Errorf("migrated global kv record "+ + "value %x does not match original "+ + "value %x", migratedValue, param.Value) + } + + default: + return fmt.Errorf("unexpected combination of "+ + "FeatureID and SessionID for: %v", param) + } + } + + log.Infof("Migration of the KV stores to SQL completed. 
Total number "+ + "of rows migrated: %d", len(allParams)) + + return nil +} + +// processRuleBucket processes a single rule bucket, which contains the +// global and session-kv-store key buckets. +func processRuleBucket(ctx context.Context, sqlTx SQLQueries, perm bool, + ruleSqlId int64, ruleBucket *bbolt.Bucket) ([]kvParams, error) { + + var params []kvParams + + return params, ruleBucket.ForEach(func(k, v []byte) error { + switch { + case v != nil: + return errors.New("expected only buckets under " + + "rule-name bucket") + case bytes.Equal(k, globalKVStoreBucketKey): + globalBucket := ruleBucket.Bucket( + globalKVStoreBucketKey, + ) + if globalBucket == nil { + return fmt.Errorf("global bucket %s for rule "+ + "id %d not found", string(k), ruleSqlId) + } + + p, err := processGlobalRuleBucket( + ctx, sqlTx, perm, ruleSqlId, globalBucket, + ) + if err != nil { + return err + } + + params = append(params, p...) + + return nil + case bytes.Equal(k, sessKVStoreBucketKey): + sessionBucket := ruleBucket.Bucket( + sessKVStoreBucketKey, + ) + if sessionBucket == nil { + return fmt.Errorf("session bucket %s for rule "+ + "id %d not found", string(k), ruleSqlId) + } + + p, err := processSessionBucket( + ctx, sqlTx, perm, ruleSqlId, sessionBucket, + ) + if err != nil { + return err + } + + params = append(params, p...) + + return nil + default: + return fmt.Errorf("unexpected bucket %s under "+ + "rule-name bucket", string(k)) + } + }) +} + +// processGlobalRuleBucket processes the global bucket under a rule bucket, +// which contains the global key-value store records for the rule. +// It inserts the records into the SQL database and asserts that +// the migrated values match the original values in the KV store. 
+func processGlobalRuleBucket(ctx context.Context, sqlTx SQLQueries, perm bool, + ruleSqlId int64, globalBucket *bbolt.Bucket) ([]kvParams, error) { + + var params []kvParams + + return params, globalBucket.ForEach(func(k, v []byte) error { + if v == nil { + return errors.New("expected only key-values under " + + "global rule-name bucket") + } + + globalInsertParams := kvParams{ + EntryKey: string(k), + Value: v, + Perm: perm, + RuleID: ruleSqlId, + } + + err := sqlTx.InsertKVStoreRecord(ctx, globalInsertParams) + if err != nil { + return fmt.Errorf("inserting global kv store "+ + "record failed %w", err) + } + + params = append(params, globalInsertParams) + + return nil + }) +} + +// processSessionBucket processes the session-kv-store bucket under a rule +// bucket, which contains the group-id buckets for that rule. +func processSessionBucket(ctx context.Context, sqlTx SQLQueries, perm bool, + ruleSqlId int64, mainSessionBucket *bbolt.Bucket) ([]kvParams, error) { + + var params []kvParams + + return params, mainSessionBucket.ForEach(func(groupId, v []byte) error { + if v != nil { + return fmt.Errorf("expected only buckets under "+ + "%s bucket", string(sessKVStoreBucketKey)) + } + + groupBucket := mainSessionBucket.Bucket(groupId) + if groupBucket == nil { + return fmt.Errorf("group bucket for group id %s"+ + "not found", string(groupId)) + } + + p, err := processGroupBucket( + ctx, sqlTx, perm, ruleSqlId, groupId, groupBucket, + ) + if err != nil { + return err + } + + params = append(params, p...) + + return nil + }) +} + +// processGroupBucket processes a single group bucket, which contains the +// session-wide kv records and as well as the feature-kv-stores key bucket for +// that group. For the session-wide kv records, it inserts the records into the +// SQL database and asserts that the migrated values match the original values. 
+func processGroupBucket(ctx context.Context, sqlTx SQLQueries, perm bool, + ruleSqlId int64, groupAlias []byte, + groupBucket *bbolt.Bucket) ([]kvParams, error) { + + groupSqlId, err := sqlTx.GetSessionIDByAlias( + ctx, groupAlias, + ) + if errors.Is(err, sql.ErrNoRows) { + return nil, fmt.Errorf("session with group id %x "+ + "not found in sql db", groupAlias) + } else if err != nil { + return nil, err + } + + var params []kvParams + + return params, groupBucket.ForEach(func(k, v []byte) error { + switch { + case v != nil: + // This is a non-feature specific k:v store for the + // session, i.e. the session-wide store. + sessWideParams := kvParams{ + EntryKey: string(k), + Value: v, + Perm: perm, + RuleID: ruleSqlId, + SessionID: sqldb.SQLInt64(groupSqlId), + } + + err := sqlTx.InsertKVStoreRecord(ctx, sessWideParams) + if err != nil { + return fmt.Errorf("inserting session wide kv "+ + "store record failed %w", err) + } + + params = append(params, sessWideParams) + + return nil + case bytes.Equal(k, featureKVStoreBucketKey): + // This is a feature specific k:v store for the + // session, which will be stored under the feature-name + // under this bucket. + + featureStoreBucket := groupBucket.Bucket( + featureKVStoreBucketKey, + ) + if featureStoreBucket == nil { + return fmt.Errorf("feature store bucket %s "+ + "for group id %s not found", + string(featureKVStoreBucketKey), + string(groupAlias)) + } + + p, err := processFeatureStoreBucket( + ctx, sqlTx, perm, ruleSqlId, groupSqlId, + featureStoreBucket, + ) + if err != nil { + return err + } + + params = append(params, p...) + + return nil + default: + return fmt.Errorf("unexpected bucket %s found under "+ + "the %s bucket", string(k), + string(sessKVStoreBucketKey)) + } + }) +} + +// processFeatureStoreBucket processes the feature-kv-store bucket under a +// group bucket, which contains the feature specific buckets for that group. 
+func processFeatureStoreBucket(ctx context.Context, sqlTx SQLQueries, perm bool, + ruleSqlId int64, groupSqlId int64, + featureStoreBucket *bbolt.Bucket) ([]kvParams, error) { + + var params []kvParams + + return params, featureStoreBucket.ForEach(func(k, v []byte) error { + if v != nil { + return fmt.Errorf("expected only buckets under " + + "feature stores bucket") + } + + featureName := k + featureNameBucket := featureStoreBucket.Bucket(featureName) + if featureNameBucket == nil { + return fmt.Errorf("feature bucket %s not found", + string(featureName)) + } + + featureSqlId, err := sqlTx.GetOrInsertFeatureID( + ctx, string(featureName), + ) + if err != nil { + return err + } + + p, err := processFeatureNameBucket( + ctx, sqlTx, perm, ruleSqlId, groupSqlId, featureSqlId, + featureNameBucket, + ) + if err != nil { + return err + } + + params = append(params, p...) + + return nil + }) +} + +// processFeatureNameBucket processes a single feature name bucket, which +// contains the feature specific key-value store records for that group. +// It inserts the records into the SQL database and asserts that +// the migrated values match the original values in the KV store. 
+func processFeatureNameBucket(ctx context.Context, sqlTx SQLQueries, perm bool, + ruleSqlId int64, groupSqlId int64, featureSqlId int64, + featureNameBucket *bbolt.Bucket) ([]kvParams, error) { + + var params []kvParams + + return params, featureNameBucket.ForEach(func(k, v []byte) error { + if v == nil { + return fmt.Errorf("expected only key-values under "+ + "feature name bucket, but found bucket %s", + string(k)) + } + + featureParams := kvParams{ + EntryKey: string(k), + Value: v, + Perm: perm, + RuleID: ruleSqlId, + SessionID: sqldb.SQLInt64(groupSqlId), + FeatureID: sqldb.SQLInt64(featureSqlId), + } + + err := sqlTx.InsertKVStoreRecord(ctx, featureParams) + if err != nil { + return fmt.Errorf("inserting feature specific kv "+ + "store record failed %w", err) + } + + params = append(params, featureParams) + + return nil + }) +} diff --git a/firewalldb/sql_migration_test.go b/firewalldb/sql_migration_test.go new file mode 100644 index 000000000..c068671cc --- /dev/null +++ b/firewalldb/sql_migration_test.go @@ -0,0 +1,537 @@ +package firewalldb + +import ( + "context" + "database/sql" + "fmt" + "github.com/lightningnetwork/lnd/fn" + "testing" + "time" + + "github.com/lightninglabs/lightning-terminal/accounts" + "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/session" + "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb" + "github.com/stretchr/testify/require" + "golang.org/x/exp/rand" +) + +// kvStoreRecord represents a single KV entry inserted into the BoltDB. +type kvStoreRecord struct { + Perm bool + RuleName string + EntryKey string + Global bool + GroupID *session.ID + FeatureName fn.Option[string] // Set if the record is feature specific + Value []byte +} + +// TestFirewallDBMigration tests the migration of firewalldb from a bolt +// backed to a SQL database. 
Note that this test does not attempt to be a +// complete migration test. +// This test only tests the migration of the KV stores currently, but will +// be extended in the future to also test the migration of the privacy mapper +// and the actions store in the future. +func TestFirewallDBMigration(t *testing.T) { + t.Parallel() + + ctx := context.Background() + clock := clock.NewTestClock(time.Now()) + + // When using build tags that creates a kvdb store for NewTestDB, we + // skip this test as it is only applicable for postgres and sqlite tags. + store := NewTestDB(t, clock) + if _, ok := store.(*BoltDB); ok { + t.Skipf("Skipping Firewall DB migration test for kvdb build") + } + + makeSQLDB := func(t *testing.T, sessionsStore session.Store) (*SQLDB, + *db.TransactionExecutor[SQLQueries]) { + + testDBStore := NewTestDBWithSessions(t, sessionsStore, clock) + + store, ok := testDBStore.(*SQLDB) + require.True(t, ok) + + baseDB := store.BaseDB + + genericExecutor := db.NewTransactionExecutor( + baseDB, func(tx *sql.Tx) SQLQueries { + return baseDB.WithTx(tx) + }, + ) + + return store, genericExecutor + } + + // The assertMigrationResults function will currently assert that + // the migrated kv stores records in the SQLDB match the original kv + // stores records in the BoltDB. 
+ assertMigrationResults := func(t *testing.T, sqlStore *SQLDB, + kvRecords []kvStoreRecord) { + + var ( + ruleIDs = make(map[string]int64) + groupIDs = make(map[string]int64) + featureIDs = make(map[string]int64) + err error + ) + + getRuleID := func(ruleName string) int64 { + ruleID, ok := ruleIDs[ruleName] + if !ok { + ruleID, err = sqlStore.GetRuleID( + ctx, ruleName, + ) + require.NoError(t, err) + + ruleIDs[ruleName] = ruleID + } + + return ruleID + } + + getGroupID := func(groupAlias []byte) int64 { + groupID, ok := groupIDs[string(groupAlias)] + if !ok { + groupID, err = sqlStore.GetSessionIDByAlias( + ctx, groupAlias, + ) + require.NoError(t, err) + + groupIDs[string(groupAlias)] = groupID + } + + return groupID + } + + getFeatureID := func(featureName string) int64 { + featureID, ok := featureIDs[featureName] + if !ok { + featureID, err = sqlStore.GetFeatureID( + ctx, featureName, + ) + require.NoError(t, err) + + featureIDs[featureName] = featureID + } + + return featureID + } + + // First we extract all migrated kv records from the SQLDB, + // in order to be able to compare them to the original kv + // records, to ensure that the migration was successful. 
+ sqlKvRecords, err := sqlStore.ListAllKVStoresRecords(ctx) + require.NoError(t, err) + require.Equal(t, len(kvRecords), len(sqlKvRecords)) + + for _, kvRecord := range kvRecords { + ruleID := getRuleID(kvRecord.RuleName) + + if kvRecord.Global { + sqlVal, err := sqlStore.GetGlobalKVStoreRecord( + ctx, + sqlc.GetGlobalKVStoreRecordParams{ + Key: kvRecord.EntryKey, + Perm: kvRecord.Perm, + RuleID: ruleID, + }, + ) + require.NoError(t, err) + require.Equal(t, kvRecord.Value, sqlVal) + } else if kvRecord.FeatureName.IsNone() { + groupID := getGroupID(kvRecord.GroupID[:]) + + sqlVal, err := sqlStore.GetSessionKVStoreRecord( + ctx, + sqlc.GetSessionKVStoreRecordParams{ + Key: kvRecord.EntryKey, + Perm: kvRecord.Perm, + RuleID: ruleID, + SessionID: sql.NullInt64{ + Int64: groupID, + Valid: true, + }, + }, + ) + require.NoError(t, err) + require.Equal(t, kvRecord.Value, sqlVal) + } else { + groupID := getGroupID(kvRecord.GroupID[:]) + featureID := getFeatureID( + kvRecord.FeatureName.UnwrapOrFail(t), + ) + + sqlVal, err := sqlStore.GetFeatureKVStoreRecord( + ctx, + sqlc.GetFeatureKVStoreRecordParams{ + Key: kvRecord.EntryKey, + Perm: kvRecord.Perm, + RuleID: ruleID, + SessionID: sql.NullInt64{ + Int64: groupID, + Valid: true, + }, + FeatureID: sql.NullInt64{ + Int64: featureID, + Valid: true, + }, + }, + ) + require.NoError(t, err) + require.Equal(t, kvRecord.Value, sqlVal) + } + } + } + + // The tests slice contains all the tests that we will run for the + // migration of the firewalldb from a BoltDB to a SQLDB. + // Note that the tests currently only test the migration of the KV + // stores, but will be extended in the future to also test the migration + // of the privacy mapper and the actions store. 
+ tests := []struct { + name string + populateDB func(t *testing.T, ctx context.Context, + boltDB *BoltDB, + sessionStore session.Store) []kvStoreRecord + }{ + { + name: "empty", + populateDB: func(t *testing.T, ctx context.Context, + boltDB *BoltDB, + sessionStore session.Store) []kvStoreRecord { + + // Don't populate the DB. + return make([]kvStoreRecord, 0) + }, + }, + { + name: "global records", + populateDB: globalRecords, + }, + { + name: "session specific records", + populateDB: sessionSpecificRecords, + }, + { + name: "feature specific records", + populateDB: featureSpecificRecords, + }, + { + name: "records at all levels", + populateDB: recordsAtAllLevels, + }, + { + name: "random records", + populateDB: randomKVRecords, + }, + } + + for _, test := range tests { + tc := test + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + // First let's create a sessions store to link to in + // the kvstores DB. In order to create the sessions + // store though, we also need to create an accounts + // store, that we link to the sessions store. + // Note that both of these stores will be sql stores due + // to the build tags enabled when running this test, + // which means we can also pass the sessions store to + // the sql version of the kv stores that we'll create + // in test, without also needing to migrate it. + accountStore := accounts.NewTestDB(t, clock) + sessionsStore := session.NewTestDBWithAccounts( + t, clock, accountStore, + ) + + // Create a new firewall store to populate with test + // data. + firewallStore, err := NewBoltDB( + t.TempDir(), DBFilename, sessionsStore, + accountStore, clock, + ) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, firewallStore.Close()) + }) + + // Populate the kv store. + records := test.populateDB( + t, ctx, firewallStore, sessionsStore, + ) + + // Create the SQL store that we will migrate the data + // to. + sqlStore, txEx := makeSQLDB(t, sessionsStore) + + // Perform the migration. 
+ var opts sqldb.MigrationTxOptions + err = txEx.ExecTx(ctx, &opts, + func(tx SQLQueries) error { + return MigrateFirewallDBToSQL( + ctx, firewallStore.DB, tx, + ) + }, + ) + require.NoError(t, err) + + // Assert migration results. + assertMigrationResults(t, sqlStore, records) + }) + } +} + +// globalRecords populates the kv store with one global record for the temp +// store, and one for the perm store. +func globalRecords(t *testing.T, ctx context.Context, + boltDB *BoltDB, sessionStore session.Store) []kvStoreRecord { + + return insertTestKVRecords( + t, ctx, boltDB, sessionStore, true, fn.None[string](), + ) +} + +// sessionSpecificRecords populates the kv store with one session specific +// record for the local temp store, and one session specific record for the perm +// local store. +func sessionSpecificRecords(t *testing.T, ctx context.Context, + boltDB *BoltDB, sessionStore session.Store) []kvStoreRecord { + + return insertTestKVRecords( + t, ctx, boltDB, sessionStore, false, fn.None[string](), + ) +} + +// featureSpecificRecords populates the kv store with one feature specific +// record for the local temp store, and one feature specific record for the perm +// local store. +func featureSpecificRecords(t *testing.T, ctx context.Context, + boltDB *BoltDB, sessionStore session.Store) []kvStoreRecord { + + return insertTestKVRecords( + t, ctx, boltDB, sessionStore, false, fn.Some("test-feature"), + ) +} + +// recordsAtAllLevels uses adds a record at all possible levels of the kvstores, +// by utilizing all the other helper functions that populates the kvstores at +// different levels. +func recordsAtAllLevels(t *testing.T, ctx context.Context, + boltDB *BoltDB, sessionStore session.Store) []kvStoreRecord { + + gRecords := globalRecords(t, ctx, boltDB, sessionStore) + sRecords := sessionSpecificRecords(t, ctx, boltDB, sessionStore) + fRecords := featureSpecificRecords(t, ctx, boltDB, sessionStore) + + return append(gRecords, append(sRecords, fRecords...)...) 
+}
+
+// insertTestKVRecords populates the kv store with one record for the local temp
+// store, and one record for the local store. The records will be feature
+// specific if the featureNameOpt is set, otherwise they will be session
+// specific. Both of the records will be inserted with the same
+// session.GroupID, which is created in this function, as well as the same
+// ruleName, entryKey and entryVal.
+func insertTestKVRecords(t *testing.T, ctx context.Context,
+	boltDB *BoltDB, sessionStore session.Store, global bool,
+	featureNameOpt fn.Option[string]) []kvStoreRecord {
+
+	var (
+		ruleName = "test-rule"
+		entryKey = "test1"
+		entryVal = []byte{1, 2, 3}
+	)
+
+	// Create a session that we can reference.
+	sess, err := sessionStore.NewSession(
+		ctx, "test", session.TypeAutopilot,
+		time.Unix(1000, 0), "something",
+	)
+	require.NoError(t, err)
+
+	tempKvRecord := kvStoreRecord{
+		RuleName:    ruleName,
+		GroupID:     &sess.GroupID,
+		FeatureName: featureNameOpt,
+		EntryKey:    entryKey,
+		Value:       entryVal,
+		Perm:        false,
+		Global:      global,
+	}
+
+	insertKvRecord(t, ctx, boltDB, tempKvRecord)
+
+	permKvRecord := kvStoreRecord{
+		RuleName:    ruleName,
+		GroupID:     &sess.GroupID,
+		FeatureName: featureNameOpt,
+		EntryKey:    entryKey,
+		Value:       entryVal,
+		Perm:        true,
+		Global:      global,
+	}
+
+	insertKvRecord(t, ctx, boltDB, permKvRecord)
+
+	return []kvStoreRecord{tempKvRecord, permKvRecord}
+}
+
+// insertKvRecord populates the kv store with the passed record, and asserts
+// that the record is inserted correctly.
+func insertKvRecord(t *testing.T, ctx context.Context, + boltDB *BoltDB, record kvStoreRecord) { + + if record.Global && record.FeatureName.IsSome() { + t.Fatalf("cannot set both global and feature specific at the " + + "same time") + } + + kvStores := boltDB.GetKVStores( + record.RuleName, *record.GroupID, + record.FeatureName.UnwrapOr(""), + ) + + err := kvStores.Update(ctx, func(ctx context.Context, + tx KVStoreTx) error { + + switch { + case record.Global && !record.Perm: + return tx.GlobalTemp().Set( + ctx, record.EntryKey, record.Value, + ) + case record.Global && record.Perm: + return tx.Global().Set( + ctx, record.EntryKey, record.Value, + ) + case !record.Global && !record.Perm: + return tx.LocalTemp().Set( + ctx, record.EntryKey, record.Value, + ) + case !record.Global && record.Perm: + return tx.Local().Set( + ctx, record.EntryKey, record.Value, + ) + default: + return fmt.Errorf("unexpected global/perm "+ + "combination: global=%v, perm=%v", + record.Global, record.Perm) + } + }) + require.NoError(t, err) +} + +// randomKVRecords populates the kv store with random kv records that span +// across all possible combinations of different levels of records in the kv +// store. All values and different bucket names are randomly generated. +func randomKVRecords(t *testing.T, ctx context.Context, + boltDB *BoltDB, sessionStore session.Store) []kvStoreRecord { + + var ( + // We set the number of records to insert to 1000, as that + // should be enough to cover as many different + // combinations of records as possible, while still being + // fast enough to run in a reasonable time. + numberOfRecords = 1000 + insertedRecords = make([]kvStoreRecord, 0) + ruleName = "initial-rule" + groupId *session.ID + featureName = "initial-feature" + ) + + // Create a random session that we can reference for the initial group + // ID. 
+ sess, err := sessionStore.NewSession( + ctx, "initial-session", session.Type(uint8(rand.Intn(5))), + time.Unix(1000, 0), randomString(rand.Intn(10)+1), + ) + require.NoError(t, err) + + groupId = &sess.GroupID + + // Generate random records. Note that many records will use the same + // rule name, group ID and feature name, to simulate the real world + // usage of the kv stores as much as possible. + for i := 0; i < numberOfRecords; i++ { + // On average, we will generate a new rule which will be used + // for the kv store record 10% of the time. + if rand.Intn(10) == 0 { + ruleName = fmt.Sprintf( + "rule-%s-%d", randomString(rand.Intn(30)+1), i, + ) + } + + // On average, we use the global store 25% of the time. + global := rand.Intn(4) == 0 + + // We'll use the perm store 50% of the time. + perm := rand.Intn(2) == 0 + + // For the non-global records, we will generate a new group ID + // 25% of the time. + if !global && rand.Intn(4) == 0 { + newSess, err := sessionStore.NewSession( + ctx, fmt.Sprintf("session-%d", i), + session.Type(uint8(rand.Intn(5))), + time.Unix(1000, 0), + randomString(rand.Intn(10)+1), + ) + require.NoError(t, err) + + groupId = &newSess.GroupID + } + + featureNameOpt := fn.None[string]() + + // For 50% of the non-global records, we insert a feature + // specific record. The other 50% will be session specific + // records. + if !global && rand.Intn(2) == 0 { + // 25% of the time, we will generate a new feature name. + if rand.Intn(4) == 0 { + featureName = fmt.Sprintf( + "feature-%s-%d", + randomString(rand.Intn(30)+1), i, + ) + } + + featureNameOpt = fn.Some(featureName) + } + + kvEntry := kvStoreRecord{ + RuleName: ruleName, + GroupID: groupId, + FeatureName: featureNameOpt, + EntryKey: fmt.Sprintf("key-%d", i), + Perm: perm, + Global: global, + // We'll generate a random value for all records, + Value: []byte(randomString(rand.Intn(100) + 1)), + } + + // Insert the record into the kv store. 
+ insertKvRecord(t, ctx, boltDB, kvEntry) + + // Add the record to the list of inserted records. + insertedRecords = append(insertedRecords, kvEntry) + } + + return insertedRecords +} + +// randomString generates a random string of the passed length n. +func randomString(n int) string { + letterBytes := "abcdefghijklmnopqrstuvwxyz" + + b := make([]byte, n) + for i := range b { + b[i] = letterBytes[rand.Intn(len(letterBytes))] + } + return string(b) +} diff --git a/firewalldb/sql_store.go b/firewalldb/sql_store.go index f17010f2c..1be887ace 100644 --- a/firewalldb/sql_store.go +++ b/firewalldb/sql_store.go @@ -5,7 +5,9 @@ import ( "database/sql" "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" ) // SQLSessionQueries is a subset of the sqlc.Queries interface that can be used @@ -18,17 +20,20 @@ type SQLSessionQueries interface { // SQLQueries is a subset of the sqlc.Queries interface that can be used to // interact with various firewalldb tables. type SQLQueries interface { + sqldb.BaseQuerier + SQLKVStoreQueries SQLPrivacyPairQueries SQLActionQueries } -// BatchedSQLQueries is a version of the SQLQueries that's capable of batched -// database operations. +// BatchedSQLQueries combines the SQLQueries interface with the BatchedTx +// interface, allowing for multiple queries to be executed in single SQL +// transaction. type BatchedSQLQueries interface { SQLQueries - db.BatchedTx[SQLQueries] + sqldb.BatchedTx[SQLQueries] } // SQLDB represents a storage backend. @@ -38,11 +43,31 @@ type SQLDB struct { db BatchedSQLQueries // BaseDB represents the underlying database connection. 
- *db.BaseDB + *sqldb.BaseDB clock clock.Clock } +type SQLQueriesExecutor[T sqldb.BaseQuerier] struct { + *sqldb.TransactionExecutor[T] + + SQLQueries +} + +func NewSQLQueriesExecutor(baseDB *sqldb.BaseDB, + queries *sqlc.Queries) *SQLQueriesExecutor[SQLQueries] { + + executor := sqldb.NewTransactionExecutor( + baseDB, func(tx *sql.Tx) SQLQueries { + return queries.WithTx(tx) + }, + ) + return &SQLQueriesExecutor[SQLQueries]{ + TransactionExecutor: executor, + SQLQueries: queries, + } +} + // A compile-time assertion to ensure that SQLDB implements the RulesDB // interface. var _ RulesDB = (*SQLDB)(nil) @@ -53,12 +78,10 @@ var _ ActionDB = (*SQLDB)(nil) // NewSQLDB creates a new SQLStore instance given an open SQLQueries // storage backend. -func NewSQLDB(sqlDB *db.BaseDB, clock clock.Clock) *SQLDB { - executor := db.NewTransactionExecutor( - sqlDB, func(tx *sql.Tx) SQLQueries { - return sqlDB.WithTx(tx) - }, - ) +func NewSQLDB(sqlDB *sqldb.BaseDB, queries *sqlc.Queries, + clock clock.Clock) *SQLDB { + + executor := NewSQLQueriesExecutor(sqlDB, queries) return &SQLDB{ db: executor, @@ -88,7 +111,7 @@ func (e *sqlExecutor[T]) Update(ctx context.Context, var txOpts db.QueriesTxOptions return e.db.ExecTx(ctx, &txOpts, func(queries SQLQueries) error { return fn(ctx, e.wrapTx(queries)) - }) + }, sqldb.NoOpReset) } // View opens a database read transaction and executes the function f with the @@ -104,5 +127,5 @@ func (e *sqlExecutor[T]) View(ctx context.Context, return e.db.ExecTx(ctx, &txOpts, func(queries SQLQueries) error { return fn(ctx, e.wrapTx(queries)) - }) + }, sqldb.NoOpReset) } diff --git a/firewalldb/test_kvdb.go b/firewalldb/test_kvdb.go index 6f7a49aa3..c3cd4533a 100644 --- a/firewalldb/test_kvdb.go +++ b/firewalldb/test_kvdb.go @@ -6,34 +6,37 @@ import ( "testing" "github.com/lightninglabs/lightning-terminal/accounts" + "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" 
"github.com/stretchr/testify/require" ) // NewTestDB is a helper function that creates an BBolt database for testing. -func NewTestDB(t *testing.T, clock clock.Clock) *BoltDB { +func NewTestDB(t *testing.T, clock clock.Clock) FirewallDBs { return NewTestDBFromPath(t, t.TempDir(), clock) } // NewTestDBFromPath is a helper function that creates a new BoltStore with a // connection to an existing BBolt database for testing. -func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) *BoltDB { +func NewTestDBFromPath(t *testing.T, dbPath string, + clock clock.Clock) FirewallDBs { + return newDBFromPathWithSessions(t, dbPath, nil, nil, clock) } // NewTestDBWithSessions creates a new test BoltDB Store with access to an // existing sessions DB. -func NewTestDBWithSessions(t *testing.T, sessStore SessionDB, - clock clock.Clock) *BoltDB { +func NewTestDBWithSessions(t *testing.T, sessStore session.Store, + clock clock.Clock) FirewallDBs { return newDBFromPathWithSessions(t, t.TempDir(), sessStore, nil, clock) } // NewTestDBWithSessionsAndAccounts creates a new test BoltDB Store with access // to an existing sessions DB and accounts DB. 
-func NewTestDBWithSessionsAndAccounts(t *testing.T, sessStore SessionDB, - acctStore AccountsDB, clock clock.Clock) *BoltDB { +func NewTestDBWithSessionsAndAccounts(t *testing.T, sessStore session.Store, + acctStore AccountsDB, clock clock.Clock) FirewallDBs { return newDBFromPathWithSessions( t, t.TempDir(), sessStore, acctStore, clock, @@ -41,7 +44,8 @@ func NewTestDBWithSessionsAndAccounts(t *testing.T, sessStore SessionDB, } func newDBFromPathWithSessions(t *testing.T, dbPath string, - sessStore SessionDB, acctStore AccountsDB, clock clock.Clock) *BoltDB { + sessStore session.Store, acctStore AccountsDB, + clock clock.Clock) FirewallDBs { store, err := NewBoltDB(dbPath, DBFilename, sessStore, acctStore, clock) require.NoError(t, err) diff --git a/firewalldb/test_postgres.go b/firewalldb/test_postgres.go index f5777e4cb..732b19b4a 100644 --- a/firewalldb/test_postgres.go +++ b/firewalldb/test_postgres.go @@ -10,12 +10,12 @@ import ( ) // NewTestDB is a helper function that creates an BBolt database for testing. -func NewTestDB(t *testing.T, clock clock.Clock) *SQLDB { - return NewSQLDB(db.NewTestPostgresDB(t).BaseDB, clock) +func NewTestDB(t *testing.T, clock clock.Clock) FirewallDBs { + return createStore(t, db.NewTestPostgresDB(t).BaseDB, clock) } // NewTestDBFromPath is a helper function that creates a new BoltStore with a // connection to an existing BBolt database for testing. 
-func NewTestDBFromPath(t *testing.T, _ string, clock clock.Clock) *SQLDB { - return NewSQLDB(db.NewTestPostgresDB(t).BaseDB, clock) +func NewTestDBFromPath(t *testing.T, _ string, clock clock.Clock) FirewallDBs { + return createStore(t, db.NewTestPostgresDB(t).BaseDB, clock) } diff --git a/firewalldb/test_sql.go b/firewalldb/test_sql.go index 03dcfbebf..b7e3d9052 100644 --- a/firewalldb/test_sql.go +++ b/firewalldb/test_sql.go @@ -6,27 +6,29 @@ import ( "testing" "time" + "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/accounts" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" ) // NewTestDBWithSessions creates a new test SQLDB Store with access to an // existing sessions DB. func NewTestDBWithSessions(t *testing.T, sessionStore session.Store, - clock clock.Clock) *SQLDB { - + clock clock.Clock) FirewallDBs { sessions, ok := sessionStore.(*session.SQLStore) require.True(t, ok) - return NewSQLDB(sessions.BaseDB, clock) + return createStore(t, sessions.BaseDB, clock) } // NewTestDBWithSessionsAndAccounts creates a new test SQLDB Store with access // to an existing sessions DB and accounts DB. 
 func NewTestDBWithSessionsAndAccounts(t *testing.T, sessionStore SessionDB,
-	acctStore AccountsDB, clock clock.Clock) *SQLDB {
+	acctStore AccountsDB, clock clock.Clock) FirewallDBs {
 
 	sessions, ok := sessionStore.(*session.SQLStore)
 	require.True(t, ok)
@@ -36,7 +38,7 @@ func NewTestDBWithSessionsAndAccounts(t *testing.T, sessionStore SessionDB,
 
 	require.Equal(t, accounts.BaseDB, sessions.BaseDB)
 
-	return NewSQLDB(sessions.BaseDB, clock)
+	return createStore(t, sessions.BaseDB, clock)
 }
 
 func assertEqualActions(t *testing.T, expected, got *Action) {
@@ -52,3 +54,16 @@ func assertEqualActions(t *testing.T, expected, got *Action) {
 	expected.AttemptedAt = expectedAttemptedAt
 	got.AttemptedAt = actualAttemptedAt
 }
+
+// createStore is a helper function that creates a new SQLDB and ensures that
+// it is closed during the test cleanup.
+func createStore(t *testing.T, sqlDB *sqldb.BaseDB, clock clock.Clock) *SQLDB {
+	queries := sqlc.NewForType(sqlDB, sqlDB.BackendType)
+
+	store := NewSQLDB(sqlDB, queries, clock)
+	t.Cleanup(func() {
+		require.NoError(t, store.Close())
+	})
+
+	return store
+}
diff --git a/firewalldb/test_sqlite.go b/firewalldb/test_sqlite.go
index 5496cb205..ab184b5a6 100644
--- a/firewalldb/test_sqlite.go
+++ b/firewalldb/test_sqlite.go
@@ -7,17 +7,23 @@ import (
 
 	"github.com/lightninglabs/lightning-terminal/db"
 	"github.com/lightningnetwork/lnd/clock"
+	"github.com/lightningnetwork/lnd/sqldb/v2"
 )
 
 // NewTestDB is a helper function that creates an BBolt database for testing.
-func NewTestDB(t *testing.T, clock clock.Clock) *SQLDB {
-	return NewSQLDB(db.NewTestSqliteDB(t).BaseDB, clock)
+func NewTestDB(t *testing.T, clock clock.Clock) FirewallDBs {
+	return createStore(
+		t, sqldb.NewTestSqliteDB(t, db.LitdMigrationStreams).BaseDB,
+		clock,
+	)
 }
 
 // NewTestDBFromPath is a helper function that creates a new BoltStore with a
 // connection to an existing BBolt database for testing.
-func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) *SQLDB { - return NewSQLDB( - db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock, - ) +func NewTestDBFromPath(t *testing.T, dbPath string, + clock clock.Clock) FirewallDBs { + + tDb := sqldb.NewTestSqliteDBFromPath(t, dbPath, db.LitdMigrationStreams) + + return createStore(t, tDb.BaseDB, clock) } diff --git a/go.mod b/go.mod index 5a9aa5c8d..e1131a527 100644 --- a/go.mod +++ b/go.mod @@ -38,6 +38,7 @@ require ( github.com/lightningnetwork/lnd/fn/v2 v2.0.8 github.com/lightningnetwork/lnd/kvdb v1.4.16 github.com/lightningnetwork/lnd/sqldb v1.0.9 + github.com/lightningnetwork/lnd/sqldb/v2 v2.0.0-00010101000000-000000000000 github.com/lightningnetwork/lnd/tlv v1.3.1 github.com/lightningnetwork/lnd/tor v1.1.6 github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f @@ -60,7 +61,7 @@ require ( ) require ( - dario.cat/mergo v1.0.1 // indirect + dario.cat/mergo v1.0.2 // indirect github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect github.com/NebulousLabs/fastrand v0.0.0-20181203155948-6fb6489aac4e // indirect @@ -245,6 +246,9 @@ replace google.golang.org/protobuf => github.com/lightninglabs/protobuf-go-hex-d // automatically, so we need to add it manually. 
replace github.com/golang-migrate/migrate/v4 => github.com/lightninglabs/migrate/v4 v4.18.2-9023d66a-fork-pr-2 -replace github.com/lightningnetwork/lnd => github.com/lightningnetwork/lnd v0.19.0-beta +replace github.com/lightningnetwork/lnd => github.com/ViktorTigerstrom/lnd v0.0.0-20250604171448-07036473e46c + +// TODO: replace this with your own local fork +replace github.com/lightningnetwork/lnd/sqldb/v2 => ../../lnd_forked/lnd/sqldb go 1.23.9 diff --git a/go.sum b/go.sum index 41fd8b809..b02a8d628 100644 --- a/go.sum +++ b/go.sum @@ -30,8 +30,8 @@ cloud.google.com/go/storage v1.5.0/go.mod h1:tpKbwo567HUNpVclU5sGELwQWBDZ8gh0Zeo cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohlUTyfDhBk= cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= -dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s= -dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= +dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= +dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= @@ -50,6 +50,8 @@ github.com/NebulousLabs/go-upnp v0.0.0-20180202185039-29b680b06c82 h1:MG93+PZYs9 github.com/NebulousLabs/go-upnp v0.0.0-20180202185039-29b680b06c82/go.mod h1:GbuBk21JqF+driLX3XtJYNZjGa45YDoa9IqCTzNSfEc= github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 h1:TngWCqHvy9oXAN6lEVMRuU21PR1EtLVZJmdB18Gu3Rw= github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5/go.mod 
h1:lmUJ/7eu/Q8D7ML55dXQrVaamCz2vxCfdQBasLZfHKk= +github.com/ViktorTigerstrom/lnd v0.0.0-20250604171448-07036473e46c h1:RjJus6oMMn3SyStT7DRpLEovYHQaAdlxbGiq91jTpdg= +github.com/ViktorTigerstrom/lnd v0.0.0-20250604171448-07036473e46c/go.mod h1:AeAtPyAAV51d9EQxGXB4rrU2J9REwMxqf8N4bEpLfqU= github.com/Yawning/aez v0.0.0-20211027044916-e49e68abd344 h1:cDVUiFo+npB0ZASqnw4q90ylaVAbnYyx0JYqK4YcGok= github.com/Yawning/aez v0.0.0-20211027044916-e49e68abd344/go.mod h1:9pIqrY6SXNL8vjRQE5Hd/OL5GyK/9MrGUWs87z/eFfk= github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY= @@ -512,8 +514,6 @@ github.com/lightninglabs/taproot-assets/taprpc v1.0.4 h1:D1Zcjvaz5viyNXwecgj2yhQ github.com/lightninglabs/taproot-assets/taprpc v1.0.4/go.mod h1:Ccq0t2GsXzOtC8qF0U1ux/yTF5HcBbVrhCb0tb/jObM= github.com/lightningnetwork/lightning-onion v1.2.1-0.20240712235311-98bd56499dfb h1:yfM05S8DXKhuCBp5qSMZdtSwvJ+GFzl94KbXMNB1JDY= github.com/lightningnetwork/lightning-onion v1.2.1-0.20240712235311-98bd56499dfb/go.mod h1:c0kvRShutpj3l6B9WtTsNTBUtjSmjZXbJd9ZBRQOSKI= -github.com/lightningnetwork/lnd v0.19.0-beta h1:/8i2UdARiEpI2iAmPoSDcwZSSEuWqXyfsMxz/mLGbdw= -github.com/lightningnetwork/lnd v0.19.0-beta/go.mod h1:hu6zo1zcznx7nViiFlJY8qGDwwGw5LNLdGJ7ICz5Ysc= github.com/lightningnetwork/lnd/cert v1.2.2 h1:71YK6hogeJtxSxw2teq3eGeuy4rHGKcFf0d0Uy4qBjI= github.com/lightningnetwork/lnd/cert v1.2.2/go.mod h1:jQmFn/Ez4zhDgq2hnYSw8r35bqGVxViXhX6Cd7HXM6U= github.com/lightningnetwork/lnd/clock v1.1.1 h1:OfR3/zcJd2RhH0RU+zX/77c0ZiOnIMsDIBjgjWdZgA0= diff --git a/session/sql_store.go b/session/sql_store.go index b1d366fe7..26662a574 100644 --- a/session/sql_store.go +++ b/session/sql_store.go @@ -14,6 +14,7 @@ import ( "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/sqldb/v2" "gopkg.in/macaroon-bakery.v2/bakery" "gopkg.in/macaroon.v2" ) @@ -21,6 +22,8 @@ 
import ( // SQLQueries is a subset of the sqlc.Queries interface that can be used to // interact with session related tables. type SQLQueries interface { + sqldb.BaseQuerier + GetAliasBySessionID(ctx context.Context, id int64) ([]byte, error) GetSessionByID(ctx context.Context, id int64) (sqlc.Session, error) GetSessionsInGroup(ctx context.Context, groupID sql.NullInt64) ([]sqlc.Session, error) @@ -51,12 +54,13 @@ type SQLQueries interface { var _ Store = (*SQLStore)(nil) -// BatchedSQLQueries is a version of the SQLQueries that's capable of batched -// database operations. +// BatchedSQLQueries combines the SQLQueries interface with the BatchedTx +// interface, allowing for multiple queries to be executed in single SQL +// transaction. type BatchedSQLQueries interface { SQLQueries - db.BatchedTx[SQLQueries] + sqldb.BatchedTx[SQLQueries] } // SQLStore represents a storage backend. @@ -66,19 +70,37 @@ type SQLStore struct { db BatchedSQLQueries // BaseDB represents the underlying database connection. - *db.BaseDB + *sqldb.BaseDB clock clock.Clock } -// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries -// storage backend. -func NewSQLStore(sqlDB *db.BaseDB, clock clock.Clock) *SQLStore { - executor := db.NewTransactionExecutor( - sqlDB, func(tx *sql.Tx) SQLQueries { - return sqlDB.WithTx(tx) +type SQLQueriesExecutor[T sqldb.BaseQuerier] struct { + *sqldb.TransactionExecutor[T] + + SQLQueries +} + +func NewSQLQueriesExecutor(baseDB *sqldb.BaseDB, + queries *sqlc.Queries) *SQLQueriesExecutor[SQLQueries] { + + executor := sqldb.NewTransactionExecutor( + baseDB, func(tx *sql.Tx) SQLQueries { + return queries.WithTx(tx) }, ) + return &SQLQueriesExecutor[SQLQueries]{ + TransactionExecutor: executor, + SQLQueries: queries, + } +} + +// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries +// storage backend. 
+func NewSQLStore(sqlDB *sqldb.BaseDB, queries *sqlc.Queries, + clock clock.Clock) *SQLStore { + + executor := NewSQLQueriesExecutor(sqlDB, queries) return &SQLStore{ db: executor, @@ -281,7 +303,7 @@ func (s *SQLStore) NewSession(ctx context.Context, label string, typ Type, } return nil - }) + }, sqldb.NoOpReset) if err != nil { mappedSQLErr := db.MapSQLError(err) var uniqueConstraintErr *db.ErrSqlUniqueConstraintViolation @@ -325,7 +347,7 @@ func (s *SQLStore) ListSessionsByType(ctx context.Context, t Type) ([]*Session, } return nil - }) + }, sqldb.NoOpReset) return sessions, err } @@ -358,7 +380,7 @@ func (s *SQLStore) ListSessionsByState(ctx context.Context, state State) ( } return nil - }) + }, sqldb.NoOpReset) return sessions, err } @@ -417,7 +439,7 @@ func (s *SQLStore) ShiftState(ctx context.Context, alias ID, dest State) error { State: int16(dest), }, ) - }) + }, sqldb.NoOpReset) } // DeleteReservedSessions deletes all sessions that are in the StateReserved @@ -428,7 +450,7 @@ func (s *SQLStore) DeleteReservedSessions(ctx context.Context) error { var writeTxOpts db.QueriesTxOptions return s.db.ExecTx(ctx, &writeTxOpts, func(db SQLQueries) error { return db.DeleteSessionsWithState(ctx, int16(StateReserved)) - }) + }, sqldb.NoOpReset) } // GetSessionByLocalPub fetches the session with the given local pub key. 
@@ -458,7 +480,7 @@ func (s *SQLStore) GetSessionByLocalPub(ctx context.Context, } return nil - }) + }, sqldb.NoOpReset) if err != nil { return nil, err } @@ -491,7 +513,7 @@ func (s *SQLStore) ListAllSessions(ctx context.Context) ([]*Session, error) { } return nil - }) + }, sqldb.NoOpReset) return sessions, err } @@ -521,7 +543,7 @@ func (s *SQLStore) UpdateSessionRemotePubKey(ctx context.Context, alias ID, RemotePublicKey: remoteKey, }, ) - }) + }, sqldb.NoOpReset) } // getSqlUnusedAliasAndKeyPair can be used to generate a new, unused, local @@ -576,7 +598,7 @@ func (s *SQLStore) GetSession(ctx context.Context, alias ID) (*Session, error) { } return nil - }) + }, sqldb.NoOpReset) return sess, err } @@ -617,7 +639,7 @@ func (s *SQLStore) GetGroupID(ctx context.Context, sessionID ID) (ID, error) { legacyGroupID, err = IDFromBytes(legacyGroupIDB) return err - }) + }, sqldb.NoOpReset) if err != nil { return ID{}, err } @@ -666,7 +688,7 @@ func (s *SQLStore) GetSessionIDs(ctx context.Context, legacyGroupID ID) ([]ID, } return nil - }) + }, sqldb.NoOpReset) if err != nil { return nil, err } diff --git a/session/test_kvdb.go b/session/test_kvdb.go index 241448410..cc939159d 100644 --- a/session/test_kvdb.go +++ b/session/test_kvdb.go @@ -11,14 +11,14 @@ import ( ) // NewTestDB is a helper function that creates an BBolt database for testing. -func NewTestDB(t *testing.T, clock clock.Clock) *BoltStore { +func NewTestDB(t *testing.T, clock clock.Clock) Store { return NewTestDBFromPath(t, t.TempDir(), clock) } // NewTestDBFromPath is a helper function that creates a new BoltStore with a // connection to an existing BBolt database for testing. func NewTestDBFromPath(t *testing.T, dbPath string, - clock clock.Clock) *BoltStore { + clock clock.Clock) Store { acctStore := accounts.NewTestDB(t, clock) @@ -28,13 +28,13 @@ func NewTestDBFromPath(t *testing.T, dbPath string, // NewTestDBWithAccounts creates a new test session Store with access to an // existing accounts DB. 
func NewTestDBWithAccounts(t *testing.T, clock clock.Clock, - acctStore accounts.Store) *BoltStore { + acctStore accounts.Store) Store { return newDBFromPathWithAccounts(t, clock, t.TempDir(), acctStore) } func newDBFromPathWithAccounts(t *testing.T, clock clock.Clock, dbPath string, - acctStore accounts.Store) *BoltStore { + acctStore accounts.Store) Store { store, err := NewDB(dbPath, DBFilename, clock, acctStore) require.NoError(t, err) diff --git a/session/test_postgres.go b/session/test_postgres.go index db392fe7f..cb5aa061d 100644 --- a/session/test_postgres.go +++ b/session/test_postgres.go @@ -15,14 +15,14 @@ import ( var ErrDBClosed = errors.New("database is closed") // NewTestDB is a helper function that creates an SQLStore database for testing. -func NewTestDB(t *testing.T, clock clock.Clock) *SQLStore { - return NewSQLStore(db.NewTestPostgresDB(t).BaseDB, clock) +func NewTestDB(t *testing.T, clock clock.Clock) Store { + return createStore(t, db.NewTestPostgresDB(t).BaseDB, clock) } // NewTestDBFromPath is a helper function that creates a new SQLStore with a // connection to an existing postgres database for testing. 
 func NewTestDBFromPath(t *testing.T, dbPath string,
-	clock clock.Clock) *SQLStore {
+	clock clock.Clock) Store {
 
-	return NewSQLStore(db.NewTestPostgresDB(t).BaseDB, clock)
+	return createStore(t, db.NewTestPostgresDB(t).BaseDB, clock)
 }
diff --git a/session/test_sql.go b/session/test_sql.go
index ab4b32a6c..5623c8207 100644
--- a/session/test_sql.go
+++ b/session/test_sql.go
@@ -6,15 +6,32 @@ import (
 	"testing"
 
 	"github.com/lightninglabs/lightning-terminal/accounts"
+	"github.com/lightninglabs/lightning-terminal/db/sqlc"
 	"github.com/lightningnetwork/lnd/clock"
+	"github.com/lightningnetwork/lnd/sqldb/v2"
 	"github.com/stretchr/testify/require"
 )
 
 func NewTestDBWithAccounts(t *testing.T, clock clock.Clock,
-	acctStore accounts.Store) *SQLStore {
+	acctStore accounts.Store) Store {
 
 	accounts, ok := acctStore.(*accounts.SQLStore)
 	require.True(t, ok)
 
-	return NewSQLStore(accounts.BaseDB, clock)
+	return createStore(t, accounts.BaseDB, clock)
+}
+
+// createStore is a helper function that creates a new SQLStore and ensures that
+// it is closed during the test cleanup.
+func createStore(t *testing.T, sqlDB *sqldb.BaseDB,
+	clock clock.Clock) *SQLStore {
+
+	queries := sqlc.NewForType(sqlDB, sqlDB.BackendType)
+
+	store := NewSQLStore(sqlDB, queries, clock)
+	t.Cleanup(func() {
+		require.NoError(t, store.Close())
+	})
+
+	return store
 }
diff --git a/session/test_sqlite.go b/session/test_sqlite.go
index 87519f4f1..84d946ce2 100644
--- a/session/test_sqlite.go
+++ b/session/test_sqlite.go
@@ -8,6 +8,7 @@ import (
 
 	"github.com/lightninglabs/lightning-terminal/db"
 	"github.com/lightningnetwork/lnd/clock"
+	"github.com/lightningnetwork/lnd/sqldb/v2"
 )
 
 // ErrDBClosed is an error that is returned when a database operation is
@@ -15,16 +16,19 @@ import (
 var ErrDBClosed = errors.New("database is closed")
 
 // NewTestDB is a helper function that creates an SQLStore database for testing.
-func NewTestDB(t *testing.T, clock clock.Clock) *SQLStore { - return NewSQLStore(db.NewTestSqliteDB(t).BaseDB, clock) +func NewTestDB(t *testing.T, clock clock.Clock) Store { + return createStore( + t, sqldb.NewTestSqliteDB(t, db.LitdMigrationStreams).BaseDB, + clock, + ) } // NewTestDBFromPath is a helper function that creates a new SQLStore with a // connection to an existing sqlite database for testing. func NewTestDBFromPath(t *testing.T, dbPath string, - clock clock.Clock) *SQLStore { + clock clock.Clock) Store { - return NewSQLStore( - db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock, - ) + tDb := sqldb.NewTestSqliteDBFromPath(t, dbPath, db.LitdMigrationStreams) + + return createStore(t, tDb.BaseDB, clock) }