diff --git a/db/sqlc/kvstores.sql.go b/db/sqlc/kvstores.sql.go index b2e6632f4..b46719eec 100644 --- a/db/sqlc/kvstores.sql.go +++ b/db/sqlc/kvstores.sql.go @@ -25,7 +25,7 @@ DELETE FROM kvstores WHERE entry_key = $1 AND rule_id = $2 AND perm = $3 - AND session_id = $4 + AND group_id = $4 AND feature_id = $5 ` @@ -33,7 +33,7 @@ type DeleteFeatureKVStoreRecordParams struct { Key string RuleID int64 Perm bool - SessionID sql.NullInt64 + GroupID sql.NullInt64 FeatureID sql.NullInt64 } @@ -42,7 +42,7 @@ func (q *Queries) DeleteFeatureKVStoreRecord(ctx context.Context, arg DeleteFeat arg.Key, arg.RuleID, arg.Perm, - arg.SessionID, + arg.GroupID, arg.FeatureID, ) return err @@ -53,7 +53,7 @@ DELETE FROM kvstores WHERE entry_key = $1 AND rule_id = $2 AND perm = $3 - AND session_id IS NULL + AND group_id IS NULL AND feature_id IS NULL ` @@ -68,28 +68,28 @@ func (q *Queries) DeleteGlobalKVStoreRecord(ctx context.Context, arg DeleteGloba return err } -const deleteSessionKVStoreRecord = `-- name: DeleteSessionKVStoreRecord :exec +const deleteGroupKVStoreRecord = `-- name: DeleteGroupKVStoreRecord :exec DELETE FROM kvstores WHERE entry_key = $1 AND rule_id = $2 AND perm = $3 - AND session_id = $4 + AND group_id = $4 AND feature_id IS NULL ` -type DeleteSessionKVStoreRecordParams struct { - Key string - RuleID int64 - Perm bool - SessionID sql.NullInt64 +type DeleteGroupKVStoreRecordParams struct { + Key string + RuleID int64 + Perm bool + GroupID sql.NullInt64 } -func (q *Queries) DeleteSessionKVStoreRecord(ctx context.Context, arg DeleteSessionKVStoreRecordParams) error { - _, err := q.db.ExecContext(ctx, deleteSessionKVStoreRecord, +func (q *Queries) DeleteGroupKVStoreRecord(ctx context.Context, arg DeleteGroupKVStoreRecordParams) error { + _, err := q.db.ExecContext(ctx, deleteGroupKVStoreRecord, arg.Key, arg.RuleID, arg.Perm, - arg.SessionID, + arg.GroupID, ) return err } @@ -113,7 +113,7 @@ FROM kvstores WHERE entry_key = $1 AND rule_id = $2 AND perm = $3 - AND session_id = $4 + AND group_id = $4 AND feature_id = $5 ` @@ -121,7 +121,7 @@ type GetFeatureKVStoreRecordParams struct { Key string RuleID int64 Perm bool - SessionID sql.NullInt64 + GroupID sql.NullInt64 FeatureID sql.NullInt64 } @@ -130,7 +130,7 @@ func (q *Queries) GetFeatureKVStoreRecord(ctx context.Context, arg GetFeatureKVS arg.Key, arg.RuleID, arg.Perm, - arg.SessionID, + arg.GroupID, arg.FeatureID, ) var value []byte @@ -144,7 +144,7 @@ FROM kvstores WHERE entry_key = $1 AND rule_id = $2 AND perm = $3 - AND session_id IS NULL + AND group_id IS NULL AND feature_id IS NULL ` @@ -161,6 +161,35 @@ func (q *Queries) GetGlobalKVStoreRecord(ctx context.Context, arg GetGlobalKVSto return value, err } +const getGroupKVStoreRecord = `-- name: GetGroupKVStoreRecord :one +SELECT value +FROM kvstores +WHERE entry_key = $1 + AND rule_id = $2 + AND perm = $3 + AND group_id = $4 + AND feature_id IS NULL +` + +type GetGroupKVStoreRecordParams struct { + Key string + RuleID int64 + Perm bool + GroupID sql.NullInt64 +} + +func (q *Queries) GetGroupKVStoreRecord(ctx context.Context, arg GetGroupKVStoreRecordParams) ([]byte, error) { + row := q.db.QueryRowContext(ctx, getGroupKVStoreRecord, + arg.Key, + arg.RuleID, + arg.Perm, + arg.GroupID, + ) + var value []byte + err := row.Scan(&value) + return value, err +} + const getOrInsertFeatureID = `-- name: GetOrInsertFeatureID :one INSERT INTO features (name) VALUES ($1) @@ -202,44 +231,15 @@ func (q *Queries) GetRuleID(ctx context.Context, name string) (int64, error) { return id, err } -const getSessionKVStoreRecord = `-- name: GetSessionKVStoreRecord :one -SELECT value -FROM kvstores -WHERE entry_key = $1 - AND rule_id = $2 - AND perm = $3 - AND session_id = $4 - AND feature_id IS NULL -` - -type GetSessionKVStoreRecordParams struct { - Key string - RuleID int64 - Perm bool - SessionID sql.NullInt64 -} - -func (q *Queries) GetSessionKVStoreRecord(ctx context.Context, arg GetSessionKVStoreRecordParams) ([]byte, error) { - row := q.db.QueryRowContext(ctx, getSessionKVStoreRecord, - arg.Key, - arg.RuleID, - arg.Perm, - arg.SessionID, - ) - var value []byte - err := row.Scan(&value) - return value, err -} - const insertKVStoreRecord = `-- name: InsertKVStoreRecord :exec -INSERT INTO kvstores (perm, rule_id, session_id, feature_id, entry_key, value) +INSERT INTO kvstores (perm, rule_id, group_id, feature_id, entry_key, value) VALUES ($1, $2, $3, $4, $5, $6) ` type InsertKVStoreRecordParams struct { Perm bool RuleID int64 - SessionID sql.NullInt64 + GroupID sql.NullInt64 FeatureID sql.NullInt64 EntryKey string Value []byte @@ -249,7 +249,7 @@ func (q *Queries) InsertKVStoreRecord(ctx context.Context, arg InsertKVStoreReco _, err := q.db.ExecContext(ctx, insertKVStoreRecord, arg.Perm, arg.RuleID, - arg.SessionID, + arg.GroupID, arg.FeatureID, arg.EntryKey, arg.Value, @@ -257,13 +257,49 @@ func (q *Queries) InsertKVStoreRecord(ctx context.Context, arg InsertKVStoreReco return err } +const listAllKVStoresRecords = `-- name: ListAllKVStoresRecords :many +SELECT id, perm, rule_id, group_id, feature_id, entry_key, value +FROM kvstores +` + +func (q *Queries) ListAllKVStoresRecords(ctx context.Context) ([]Kvstore, error) { + rows, err := q.db.QueryContext(ctx, listAllKVStoresRecords) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Kvstore + for rows.Next() { + var i Kvstore + if err := rows.Scan( + &i.ID, + &i.Perm, + &i.RuleID, + &i.GroupID, + &i.FeatureID, + &i.EntryKey, + &i.Value, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const updateFeatureKVStoreRecord = `-- name: UpdateFeatureKVStoreRecord :exec UPDATE kvstores SET value = $1 WHERE entry_key = $2 AND rule_id = $3 AND perm = $4 - AND session_id = $5 + AND group_id = $5 AND feature_id = $6 ` @@ -272,7 +308,7 @@ type UpdateFeatureKVStoreRecordParams struct { Key string RuleID int64 Perm bool - SessionID sql.NullInt64 + GroupID sql.NullInt64 FeatureID sql.NullInt64 } @@ -282,7 +318,7 @@ func (q *Queries) UpdateFeatureKVStoreRecord(ctx context.Context, arg UpdateFeat arg.Key, arg.RuleID, arg.Perm, - arg.SessionID, + arg.GroupID, arg.FeatureID, ) return err @@ -294,7 +330,7 @@ SET value = $1 WHERE entry_key = $2 AND rule_id = $3 AND perm = $4 - AND session_id IS NULL + AND group_id IS NULL AND feature_id IS NULL ` @@ -315,31 +351,31 @@ func (q *Queries) UpdateGlobalKVStoreRecord(ctx context.Context, arg UpdateGloba return err } -const updateSessionKVStoreRecord = `-- name: UpdateSessionKVStoreRecord :exec +const updateGroupKVStoreRecord = `-- name: UpdateGroupKVStoreRecord :exec UPDATE kvstores SET value = $1 WHERE entry_key = $2 AND rule_id = $3 AND perm = $4 - AND session_id = $5 + AND group_id = $5 AND feature_id IS NULL ` -type UpdateSessionKVStoreRecordParams struct { - Value []byte - Key string - RuleID int64 - Perm bool - SessionID sql.NullInt64 +type UpdateGroupKVStoreRecordParams struct { + Value []byte + Key string + RuleID int64 + Perm bool + GroupID sql.NullInt64 } -func (q *Queries) UpdateSessionKVStoreRecord(ctx context.Context, arg UpdateSessionKVStoreRecordParams) error { - _, err := q.db.ExecContext(ctx, updateSessionKVStoreRecord, +func (q *Queries) UpdateGroupKVStoreRecord(ctx context.Context, arg UpdateGroupKVStoreRecordParams) error { + _, err := q.db.ExecContext(ctx, updateGroupKVStoreRecord, arg.Value, arg.Key, arg.RuleID, arg.Perm, - arg.SessionID, + arg.GroupID, ) return err } diff --git a/db/sqlc/migrations/000003_kvstores.up.sql b/db/sqlc/migrations/000003_kvstores.up.sql index d2f0653a5..e49ed9622 100644 --- a/db/sqlc/migrations/000003_kvstores.up.sql +++ b/db/sqlc/migrations/000003_kvstores.up.sql @@ -21,7 +21,7 @@ CREATE TABLE IF NOT EXISTS features ( CREATE UNIQUE INDEX IF NOT EXISTS features_name_idx ON features (name); -- kvstores houses key-value pairs under various namespaces determined --- by the rule name, session ID, and feature name. +-- by the rule name, group ID, and feature name. CREATE TABLE IF NOT EXISTS kvstores ( -- The auto incrementing primary key. id INTEGER PRIMARY KEY, @@ -35,15 +35,15 @@ CREATE TABLE IF NOT EXISTS kvstores ( -- kv_store. rule_id BIGINT REFERENCES rules(id) NOT NULL, - -- The session ID that this kv_store belongs to. - -- If this is set, then this kv_store is a session-specific + -- The group ID that this kv_store belongs to. + -- If this is set, then this kv_store is a session-group specific -- kv_store for the given rule. - session_id BIGINT REFERENCES sessions(id) ON DELETE CASCADE, + group_id BIGINT REFERENCES sessions(id) ON DELETE CASCADE, -- The feature name that this kv_store belongs to. -- If this is set, then this kv_store is a feature-specific - -- kvstore under the given session ID and rule name. - -- If this is set, then session_id must also be set. + -- kvstore under the given group ID and rule name. + -- If this is set, then group_id must also be set. feature_id BIGINT REFERENCES features(id), -- The key of the entry. @@ -54,4 +54,4 @@ CREATE TABLE IF NOT EXISTS kvstores ( ); CREATE UNIQUE INDEX IF NOT EXISTS kvstores_lookup_idx - ON kvstores (entry_key, rule_id, perm, session_id, feature_id); + ON kvstores (entry_key, rule_id, perm, group_id, feature_id); diff --git a/db/sqlc/models.go b/db/sqlc/models.go index 357360c9e..d19e66e10 100644 --- a/db/sqlc/models.go +++ b/db/sqlc/models.go @@ -63,7 +63,7 @@ type Kvstore struct { ID int64 Perm bool RuleID int64 - SessionID sql.NullInt64 + GroupID sql.NullInt64 FeatureID sql.NullInt64 EntryKey string Value []byte diff --git a/db/sqlc/querier.go b/db/sqlc/querier.go index df89d0898..d76d5e6e3 100644 --- a/db/sqlc/querier.go +++ b/db/sqlc/querier.go @@ -16,7 +16,7 @@ type Querier interface { DeleteAllTempKVStores(ctx context.Context) error DeleteFeatureKVStoreRecord(ctx context.Context, arg DeleteFeatureKVStoreRecordParams) error DeleteGlobalKVStoreRecord(ctx context.Context, arg DeleteGlobalKVStoreRecordParams) error - DeleteSessionKVStoreRecord(ctx context.Context, arg DeleteSessionKVStoreRecordParams) error + DeleteGroupKVStoreRecord(ctx context.Context, arg DeleteGroupKVStoreRecordParams) error DeleteSessionsWithState(ctx context.Context, state int16) error GetAccount(ctx context.Context, id int64) (Account, error) GetAccountByLabel(ctx context.Context, label sql.NullString) (Account, error) @@ -29,6 +29,7 @@ type Querier interface { GetFeatureID(ctx context.Context, name string) (int64, error) GetFeatureKVStoreRecord(ctx context.Context, arg GetFeatureKVStoreRecordParams) ([]byte, error) GetGlobalKVStoreRecord(ctx context.Context, arg GetGlobalKVStoreRecordParams) ([]byte, error) + GetGroupKVStoreRecord(ctx context.Context, arg GetGroupKVStoreRecordParams) ([]byte, error) GetOrInsertFeatureID(ctx context.Context, name string) (int64, error) GetOrInsertRuleID(ctx context.Context, name string) (int64, error) GetPseudoForReal(ctx context.Context, arg GetPseudoForRealParams) (string, error) @@ -40,7 +41,6 @@ type Querier interface { GetSessionByLocalPublicKey(ctx context.Context, localPublicKey []byte) (Session, error) GetSessionFeatureConfigs(ctx context.Context, sessionID int64) ([]SessionFeatureConfig, error) GetSessionIDByAlias(ctx context.Context, alias []byte) (int64, error) - GetSessionKVStoreRecord(ctx context.Context, arg GetSessionKVStoreRecordParams) ([]byte, error) GetSessionMacaroonCaveats(ctx context.Context, sessionID int64) ([]SessionMacaroonCaveat, error) GetSessionMacaroonPermissions(ctx context.Context, sessionID int64) ([]SessionMacaroonPermission, error) GetSessionPrivacyFlags(ctx context.Context, sessionID int64) ([]SessionPrivacyFlag, error) @@ -57,6 +57,7 @@ type Querier interface { ListAccountInvoices(ctx context.Context, accountID int64) ([]AccountInvoice, error) ListAccountPayments(ctx context.Context, accountID int64) ([]AccountPayment, error) ListAllAccounts(ctx context.Context) ([]Account, error) + ListAllKVStoresRecords(ctx context.Context) ([]Kvstore, error) ListSessions(ctx context.Context) ([]Session, error) ListSessionsByState(ctx context.Context, state int16) ([]Session, error) ListSessionsByType(ctx context.Context, type_ int16) ([]Session, error) @@ -70,7 +71,7 @@ type Querier interface { UpdateAccountLastUpdate(ctx context.Context, arg UpdateAccountLastUpdateParams) (int64, error) UpdateFeatureKVStoreRecord(ctx context.Context, arg UpdateFeatureKVStoreRecordParams) error UpdateGlobalKVStoreRecord(ctx context.Context, arg UpdateGlobalKVStoreRecordParams) error - UpdateSessionKVStoreRecord(ctx context.Context, arg UpdateSessionKVStoreRecordParams) error + UpdateGroupKVStoreRecord(ctx context.Context, arg UpdateGroupKVStoreRecordParams) error UpdateSessionState(ctx context.Context, arg UpdateSessionStateParams) error UpsertAccountPayment(ctx context.Context, arg UpsertAccountPaymentParams) error } diff --git a/db/sqlc/queries/kvstores.sql b/db/sqlc/queries/kvstores.sql index 7963e46a4..6acc27468 100644 --- a/db/sqlc/queries/kvstores.sql +++ b/db/sqlc/queries/kvstores.sql @@ -21,29 +21,33 @@ FROM features WHERE name = sqlc.arg('name'); -- name: InsertKVStoreRecord :exec -INSERT INTO kvstores (perm, rule_id, session_id, feature_id, entry_key, value) +INSERT INTO kvstores (perm, rule_id, group_id, feature_id, entry_key, value) VALUES ($1, $2, $3, $4, $5, $6); -- name: DeleteAllTempKVStores :exec DELETE FROM kvstores WHERE perm = false; +-- name: ListAllKVStoresRecords :many +SELECT * +FROM kvstores; + -- name: GetGlobalKVStoreRecord :one SELECT value FROM kvstores WHERE entry_key = sqlc.arg('key') AND rule_id = sqlc.arg('rule_id') AND perm = sqlc.arg('perm') - AND session_id IS NULL + AND group_id IS NULL AND feature_id IS NULL; --- name: GetSessionKVStoreRecord :one +-- name: GetGroupKVStoreRecord :one SELECT value FROM kvstores WHERE entry_key = sqlc.arg('key') AND rule_id = sqlc.arg('rule_id') AND perm = sqlc.arg('perm') - AND session_id = sqlc.arg('session_id') + AND group_id = sqlc.arg('group_id') AND feature_id IS NULL; -- name: GetFeatureKVStoreRecord :one @@ -52,7 +56,7 @@ FROM kvstores WHERE entry_key = sqlc.arg('key') AND rule_id = sqlc.arg('rule_id') AND perm = sqlc.arg('perm') - AND session_id = sqlc.arg('session_id') + AND group_id = sqlc.arg('group_id') AND feature_id = sqlc.arg('feature_id'); -- name: DeleteGlobalKVStoreRecord :exec @@ -60,15 +64,15 @@ DELETE FROM kvstores WHERE entry_key = sqlc.arg('key') AND rule_id = sqlc.arg('rule_id') AND perm = sqlc.arg('perm') - AND session_id IS NULL + AND group_id IS NULL AND feature_id IS NULL; --- name: DeleteSessionKVStoreRecord :exec +-- name: DeleteGroupKVStoreRecord :exec DELETE FROM kvstores WHERE entry_key = sqlc.arg('key') AND rule_id = sqlc.arg('rule_id') AND perm = sqlc.arg('perm') - AND session_id = sqlc.arg('session_id') + AND group_id = sqlc.arg('group_id') AND feature_id IS NULL; -- name: DeleteFeatureKVStoreRecord :exec @@ -76,7 +80,7 @@ DELETE FROM kvstores WHERE entry_key = sqlc.arg('key') AND rule_id = sqlc.arg('rule_id') AND perm = sqlc.arg('perm') - AND session_id = sqlc.arg('session_id') + AND group_id = sqlc.arg('group_id') AND feature_id = sqlc.arg('feature_id'); -- name: UpdateGlobalKVStoreRecord :exec @@ -85,16 +89,16 @@ SET value = $1 WHERE entry_key = sqlc.arg('key') AND rule_id = sqlc.arg('rule_id') AND perm = sqlc.arg('perm') - AND session_id IS NULL + AND group_id IS NULL AND feature_id IS NULL; --- name: UpdateSessionKVStoreRecord :exec +-- name: UpdateGroupKVStoreRecord :exec UPDATE kvstores SET value = $1 WHERE entry_key = sqlc.arg('key') AND rule_id = sqlc.arg('rule_id') AND perm = sqlc.arg('perm') - AND session_id = sqlc.arg('session_id') + AND group_id = sqlc.arg('group_id') AND feature_id IS NULL; -- name: UpdateFeatureKVStoreRecord :exec @@ -103,5 +107,5 @@ SET value = $1 WHERE entry_key = sqlc.arg('key') AND rule_id = sqlc.arg('rule_id') AND perm = sqlc.arg('perm') - AND session_id = sqlc.arg('session_id') + AND group_id = sqlc.arg('group_id') AND feature_id = sqlc.arg('feature_id'); diff --git a/firewalldb/actions_test.go b/firewalldb/actions_test.go index c27e53e96..69990c1da 100644 --- a/firewalldb/actions_test.go +++ b/firewalldb/actions_test.go @@ -28,9 +28,6 @@ func TestActionStorage(t *testing.T) { sessDB := session.NewTestDBWithAccounts(t, clock, accountsDB) db := NewTestDBWithSessionsAndAccounts(t, sessDB, accountsDB, clock) - t.Cleanup(func() { - _ = db.Close() - }) // Assert that attempting to add an action for a session that does not // exist returns an error. @@ -198,9 +195,6 @@ func TestListActions(t *testing.T) { sessDB := session.NewTestDB(t, clock) db := NewTestDBWithSessions(t, sessDB, clock) - t.Cleanup(func() { - _ = db.Close() - }) // Add 2 sessions that we can reference. sess1, err := sessDB.NewSession( @@ -466,9 +460,6 @@ func TestListGroupActions(t *testing.T) { } db := NewTestDBWithSessions(t, sessDB, clock) - t.Cleanup(func() { - _ = db.Close() - }) // There should not be any actions in group 1 yet. al, _, _, err := db.ListActions(ctx, nil, WithActionGroupID(group1)) diff --git a/firewalldb/db.go b/firewalldb/db.go index b8d9ed06f..a8349a538 100644 --- a/firewalldb/db.go +++ b/firewalldb/db.go @@ -14,29 +14,21 @@ var ( ErrNoSuchKeyFound = fmt.Errorf("no such key found") ) -// firewallDBs is an interface that groups the RulesDB and PrivacyMapper -// interfaces. -type firewallDBs interface { - RulesDB - PrivacyMapper - ActionDB -} - // DB manages the firewall rules database. type DB struct { started sync.Once stopped sync.Once - firewallDBs + FirewallDBs cancel fn.Option[context.CancelFunc] } // NewDB creates a new firewall database. For now, it only contains the // underlying rules' and privacy mapper databases. -func NewDB(dbs firewallDBs) *DB { +func NewDB(dbs FirewallDBs) *DB { return &DB{ - firewallDBs: dbs, + FirewallDBs: dbs, } } diff --git a/firewalldb/interface.go b/firewalldb/interface.go index 5ee729e91..c2955bdc6 100644 --- a/firewalldb/interface.go +++ b/firewalldb/interface.go @@ -134,3 +134,11 @@ type ActionDB interface { // and feature name. GetActionsReadDB(groupID session.ID, featureName string) ActionsReadDB } + +// FirewallDBs is an interface that groups the RulesDB, PrivacyMapper and +// ActionDB interfaces. +type FirewallDBs interface { + RulesDB + PrivacyMapper + ActionDB +} diff --git a/firewalldb/kvstores_kvdb.go b/firewalldb/kvstores_kvdb.go index 51721d475..78676e3ed 100644 --- a/firewalldb/kvstores_kvdb.go +++ b/firewalldb/kvstores_kvdb.go @@ -16,13 +16,13 @@ the temporary store changes instead of just keeping an in-memory store is that we can then guarantee atomicity if changes are made to both the permanent and temporary stores. -rules -> perm -> rule-name -> global -> {k:v} - -> sessions -> group ID -> session-kv-store -> {k:v} - -> feature-kv-stores -> feature-name -> {k:v} +"rules" -> "perm" -> -> "global" -> {k:v} + -> "session-kv-store" -> -> {k:v} + -> "feature-kv-stores" -> -> {k:v} - -> temp -> rule-name -> global -> {k:v} - -> sessions -> group ID -> session-kv-store -> {k:v} - -> feature-kv-stores -> feature-name -> {k:v} + -> "temp" -> -> "global" -> {k:v} + -> "session-kv-store" -> -> {k:v} + -> "feature-kv-stores" -> -> {k:v} */ var ( diff --git a/firewalldb/kvstores_sql.go b/firewalldb/kvstores_sql.go index 0c3df2ddb..248892130 100644 --- a/firewalldb/kvstores_sql.go +++ b/firewalldb/kvstores_sql.go @@ -22,13 +22,13 @@ type SQLKVStoreQueries interface { DeleteFeatureKVStoreRecord(ctx context.Context, arg sqlc.DeleteFeatureKVStoreRecordParams) error DeleteGlobalKVStoreRecord(ctx context.Context, arg sqlc.DeleteGlobalKVStoreRecordParams) error - DeleteSessionKVStoreRecord(ctx context.Context, arg sqlc.DeleteSessionKVStoreRecordParams) error + DeleteGroupKVStoreRecord(ctx context.Context, arg sqlc.DeleteGroupKVStoreRecordParams) error GetFeatureKVStoreRecord(ctx context.Context, arg sqlc.GetFeatureKVStoreRecordParams) ([]byte, error) GetGlobalKVStoreRecord(ctx context.Context, arg sqlc.GetGlobalKVStoreRecordParams) ([]byte, error) - GetSessionKVStoreRecord(ctx context.Context, arg sqlc.GetSessionKVStoreRecordParams) ([]byte, error) + GetGroupKVStoreRecord(ctx context.Context, arg sqlc.GetGroupKVStoreRecordParams) ([]byte, error) UpdateFeatureKVStoreRecord(ctx context.Context, arg sqlc.UpdateFeatureKVStoreRecordParams) error UpdateGlobalKVStoreRecord(ctx context.Context, arg sqlc.UpdateGlobalKVStoreRecordParams) error - UpdateSessionKVStoreRecord(ctx context.Context, arg sqlc.UpdateSessionKVStoreRecordParams) error + UpdateGroupKVStoreRecord(ctx context.Context, arg sqlc.UpdateGroupKVStoreRecordParams) error InsertKVStoreRecord(ctx context.Context, arg sqlc.InsertKVStoreRecordParams) error DeleteAllTempKVStores(ctx context.Context) error GetOrInsertFeatureID(ctx context.Context, name string) (int64, error) @@ -198,7 +198,7 @@ func (s *sqlKVStore) Get(ctx context.Context, key string) ([]byte, error) { // // NOTE: part of the KVStore interface. func (s *sqlKVStore) Set(ctx context.Context, key string, value []byte) error { - ruleID, sessionID, featureID, err := s.genNamespaceFields(ctx, false) + ruleID, groupID, featureID, err := s.genNamespaceFields(ctx, false) if err != nil { return err } @@ -219,7 +219,7 @@ func (s *sqlKVStore) Set(ctx context.Context, key string, value []byte) error { Value: value, Perm: s.params.perm, RuleID: ruleID, - SessionID: sessionID, + GroupID: groupID, FeatureID: featureID, }, ) @@ -233,26 +233,26 @@ func (s *sqlKVStore) Set(ctx context.Context, key string, value []byte) error { // Otherwise, the key exists but the value needs to be updated. switch { - case sessionID.Valid && featureID.Valid: + case groupID.Valid && featureID.Valid: return s.queries.UpdateFeatureKVStoreRecord( ctx, sqlc.UpdateFeatureKVStoreRecordParams{ Key: key, Value: value, Perm: s.params.perm, - SessionID: sessionID, + GroupID: groupID, RuleID: ruleID, FeatureID: featureID, }, ) - case sessionID.Valid: - return s.queries.UpdateSessionKVStoreRecord( - ctx, sqlc.UpdateSessionKVStoreRecordParams{ - Key: key, - Value: value, - Perm: s.params.perm, - SessionID: sessionID, - RuleID: ruleID, + case groupID.Valid: + return s.queries.UpdateGroupKVStoreRecord( + ctx, sqlc.UpdateGroupKVStoreRecordParams{ + Key: key, + Value: value, + Perm: s.params.perm, + GroupID: groupID, + RuleID: ruleID, }, ) @@ -278,7 +278,7 @@ func (s *sqlKVStore) Del(ctx context.Context, key string) error { // Note: we pass in true here for "read-only" since because this is a // Delete, if the record does not exist, we don't need to create one. // But no need to error out if it doesn't exist. - ruleID, sessionID, featureID, err := s.genNamespaceFields(ctx, true) + ruleID, groupID, featureID, err := s.genNamespaceFields(ctx, true) if errors.Is(err, sql.ErrNoRows) || errors.Is(err, session.ErrUnknownGroup) { @@ -288,24 +288,24 @@ func (s *sqlKVStore) Del(ctx context.Context, key string) error { } switch { - case sessionID.Valid && featureID.Valid: + case groupID.Valid && featureID.Valid: return s.queries.DeleteFeatureKVStoreRecord( ctx, sqlc.DeleteFeatureKVStoreRecordParams{ Key: key, Perm: s.params.perm, - SessionID: sessionID, + GroupID: groupID, RuleID: ruleID, FeatureID: featureID, }, ) - case sessionID.Valid: - return s.queries.DeleteSessionKVStoreRecord( - ctx, sqlc.DeleteSessionKVStoreRecordParams{ - Key: key, - Perm: s.params.perm, - SessionID: sessionID, - RuleID: ruleID, + case groupID.Valid: + return s.queries.DeleteGroupKVStoreRecord( + ctx, sqlc.DeleteGroupKVStoreRecordParams{ + Key: key, + Perm: s.params.perm, + GroupID: groupID, + RuleID: ruleID, }, ) @@ -326,30 +326,30 @@ func (s *sqlKVStore) Del(ctx context.Context, key string) error { // get fetches the value under the given key from the underlying kv store given // the namespace fields. func (s *sqlKVStore) get(ctx context.Context, key string) ([]byte, error) { - ruleID, sessionID, featureID, err := s.genNamespaceFields(ctx, true) + ruleID, groupID, featureID, err := s.genNamespaceFields(ctx, true) if err != nil { return nil, err } switch { - case sessionID.Valid && featureID.Valid: + case groupID.Valid && featureID.Valid: return s.queries.GetFeatureKVStoreRecord( ctx, sqlc.GetFeatureKVStoreRecordParams{ Key: key, Perm: s.params.perm, - SessionID: sessionID, + GroupID: groupID, RuleID: ruleID, FeatureID: featureID, }, ) - case sessionID.Valid: - return s.queries.GetSessionKVStoreRecord( - ctx, sqlc.GetSessionKVStoreRecordParams{ - Key: key, - Perm: s.params.perm, - SessionID: sessionID, - RuleID: ruleID, + case groupID.Valid: + return s.queries.GetGroupKVStoreRecord( + ctx, sqlc.GetGroupKVStoreRecordParams{ + Key: key, + Perm: s.params.perm, + GroupID: groupID, + RuleID: ruleID, }, ) @@ -373,7 +373,7 @@ func (s *sqlKVStore) genNamespaceFields(ctx context.Context, readOnly bool) (int64, sql.NullInt64, sql.NullInt64, error) { var ( - sessionID sql.NullInt64 + groupID sql.NullInt64 featureID sql.NullInt64 ruleID int64 err error @@ -382,8 +382,8 @@ func (s *sqlKVStore) genNamespaceFields(ctx context.Context, // If a group ID is specified, then we first check that this group ID // is a known session alias. s.params.groupID.WhenSome(func(id session.ID) { - var groupID int64 - groupID, err = s.queries.GetSessionIDByAlias(ctx, id[:]) + var dbGroupID int64 + dbGroupID, err = s.queries.GetSessionIDByAlias(ctx, id[:]) if errors.Is(err, sql.ErrNoRows) { err = session.ErrUnknownGroup @@ -392,20 +392,20 @@ func (s *sqlKVStore) genNamespaceFields(ctx context.Context, return } - sessionID = sql.NullInt64{ - Int64: groupID, + groupID = sql.NullInt64{ + Int64: dbGroupID, Valid: true, } }) if err != nil { - return ruleID, sessionID, featureID, err + return ruleID, groupID, featureID, err } // We only insert a new rule name into the DB if this is a write call. if readOnly { ruleID, err = s.queries.GetRuleID(ctx, s.params.ruleName) if err != nil { - return 0, sessionID, featureID, + return 0, groupID, featureID, fmt.Errorf("unable to get rule ID: %w", err) } } else { @@ -413,7 +413,7 @@ func (s *sqlKVStore) genNamespaceFields(ctx context.Context, ctx, s.params.ruleName, ) if err != nil { - return 0, sessionID, featureID, + return 0, groupID, featureID, fmt.Errorf("unable to get or insert rule "+ "ID: %w", err) } @@ -441,5 +441,5 @@ func (s *sqlKVStore) genNamespaceFields(ctx context.Context, } }) - return ruleID, sessionID, featureID, err + return ruleID, groupID, featureID, err } diff --git a/firewalldb/sql_migration.go b/firewalldb/sql_migration.go new file mode 100644 index 000000000..1e114c12c --- /dev/null +++ b/firewalldb/sql_migration.go @@ -0,0 +1,492 @@ +package firewalldb + +import ( + "bytes" + "context" + "database/sql" + "errors" + "fmt" + + "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/sqldb" + "go.etcd.io/bbolt" +) + +// kvEntry represents a single KV entry inserted into the BoltDB. +type kvEntry struct { + perm bool + ruleName string + key string + value []byte + + // groupAlias is the legacy session group alias that the entry is + // associated with. For global entries, this will be fn.None[[]byte]. + groupAlias fn.Option[[]byte] + + // featureName is the name of the feature that the entry is associated + // with. If the entry is not feature specific, this will be + // fn.None[string]. + featureName fn.Option[string] +} + +// sqlKvEntry represents a single KV entry inserted into the SQL DB, containing +// the same fields as the kvEntry, but with additional fields that represent the +// SQL IDs of the rule, session group, and feature. +type sqlKvEntry struct { + *kvEntry + + ruleID int64 + + // groupID is the sql session group ID that the entry is associated + // with. For global entries, this will be Valid=false. + groupID sql.NullInt64 + + // featureID is the sql feature ID that the entry is associated with. + // This is only set if the entry is feature specific, and will be + // Valid=false for other types entries. If this is set, then groupID + // will also be set. + featureID sql.NullInt64 +} + +// namespacedKey returns a string representation of the kvEntry purely used for +// logging purposes. +func (e *kvEntry) namespacedKey() string { + ns := fmt.Sprintf("perm: %t, rule: %s", e.perm, e.ruleName) + + e.groupAlias.WhenSome(func(alias []byte) { + ns += fmt.Sprintf(", group: %s", alias) + }) + + e.featureName.WhenSome(func(feature string) { + ns += fmt.Sprintf(", feature: %s", feature) + }) + + ns += fmt.Sprintf(", key: %s", e.key) + + return ns +} + +// MigrateFirewallDBToSQL runs the migration of the firwalldb stores from the +// bbolt database to a SQL database. The migration is done in a single +// transaction to ensure that all rows in the stores are migrated or none at +// all. +// +// Note that this migration currently only migrates the kvstores, but will be +// extended in the future to also migrate the privacy mapper and action stores. +// +// NOTE: As sessions may contain linked sessions and accounts, the sessions and +// accounts sql migration MUST be run prior to this migration. +func MigrateFirewallDBToSQL(ctx context.Context, kvStore *bbolt.DB, + sqlTx SQLQueries) error { + + log.Infof("Starting migration of the rules DB to SQL") + + err := migrateKVStoresDBToSQL(ctx, kvStore, sqlTx) + if err != nil { + return err + } + + log.Infof("The rules DB has been migrated from KV to SQL.") + + // TODO(viktor): Add migration for the privacy mapper and the action + // stores. + + return nil +} + +// migrateKVStoresDBToSQL runs the migration of all KV stores from the KV +// database to the SQL database. The function also asserts that the +// migrated values match the original values in the KV store. +func migrateKVStoresDBToSQL(ctx context.Context, kvStore *bbolt.DB, + sqlTx SQLQueries) error { + + log.Infof("Starting migration of the KV stores to SQL") + + var pairs []*kvEntry + + // 1) Collect all key-value pairs from the KV store. + err := kvStore.View(func(tx *bbolt.Tx) error { + var err error + pairs, err = collectAllPairs(tx) + return err + }) + if err != nil { + return fmt.Errorf("collecting all kv pairs failed: %w", err) + } + + var insertedPairs []*sqlKvEntry + + // 2) Insert all collected key-value pairs into the SQL database. + for _, entry := range pairs { + insertedPair, err := insertPair(ctx, sqlTx, entry) + if err != nil { + return fmt.Errorf("inserting kv pair %v failed: %w", + entry.key, err) + } + + insertedPairs = append(insertedPairs, insertedPair) + } + + // 3) Validate the migrated values against the original values. + for _, insertedPair := range insertedPairs { + // Fetch the appropriate SQL entry's value. + migratedValue, err := getSQLValue(ctx, sqlTx, insertedPair) + if err != nil { + return fmt.Errorf("getting SQL value for key %s "+ + "failed: %w", insertedPair.namespacedKey(), err) + } + + // Compare the value of the migrated entry with the original + // value from the KV store. + // NOTE: if the insert a []byte{} value into the sqldb as the + // entry value, and then retrieve it, the value will be + // returned as nil. The bytes.Equal will pass in that case, + // and therefore such cases won't error out. The kvdb instance + // can store []byte{} values. + if !bytes.Equal(migratedValue, insertedPair.value) { + return fmt.Errorf("migrated value for key %s "+ + "does not match original value: "+ + "migrated %x, original %x", + insertedPair.namespacedKey(), migratedValue, + insertedPair.value) + } + } + + log.Infof("Migration of the KV stores to SQL completed. Total number "+ + "of rows migrated: %d", len(pairs)) + + return nil +} + +// collectAllPairs collects all key-value pairs from the KV store, and returns +// them as a slice of kvEntry structs. The function expects the KV store to be +// stuctured as described in the comment in the firewalldb/kvstores_kvdb.go +// file. Any other structure will result in an error. +// Note that this function and the subsequent functions are intentionally +// designed to iterate over all buckets and values that exist in the KV store. +// That ensures that we find all stores and values that exist in the KV store, +// and can be sure that the kv store actually follows the expected structure. +func collectAllPairs(tx *bbolt.Tx) ([]*kvEntry, error) { + var entries []*kvEntry + for _, perm := range []bool{true, false} { + mainBucket, err := getMainBucket(tx, false, perm) + if err != nil { + return nil, err + } + + if mainBucket == nil { + // If the mainBucket doesn't exist, there are no entries + // to migrate under that bucket, therefore we don't + // error, and just proceed to not migrate any entries + // under that bucket. + continue + } + + // Loop over each rule-name bucket. + err = mainBucket.ForEach(func(rule, v []byte) error { + if v != nil { + return errors.New("expected only " + + "buckets under main bucket") + } + + ruleBucket := mainBucket.Bucket(rule) + if ruleBucket == nil { + return fmt.Errorf("rule bucket %s not found", + rule) + } + + pairs, err := collectRulePairs( + ruleBucket, perm, string(rule), + ) + if err != nil { + return err + } + + entries = append(entries, pairs...) + + return nil + }) + if err != nil { + return nil, err + } + } + + return entries, nil +} + +// collectRulePairs processes a single rule bucket, which should contain the +// global and session-kv-store key buckets. +func collectRulePairs(bkt *bbolt.Bucket, perm bool, rule string) ([]*kvEntry, + error) { + + var params []*kvEntry + + err := verifyBktKeys( + bkt, true, globalKVStoreBucketKey, sessKVStoreBucketKey, + ) + if err != nil { + return params, fmt.Errorf("verifying rule bucket %s keys "+ + "failed: %w", rule, err) + } + + if globalBkt := bkt.Bucket(globalKVStoreBucketKey); globalBkt != nil { + p, err := collectKVPairs( + globalBkt, true, perm, rule, + fn.None[[]byte](), fn.None[string](), + ) + if err != nil { + return nil, fmt.Errorf("collecting global kv pairs "+ + "failed: %w", err) + } + + params = append(params, p...) + } + + if sessBkt := bkt.Bucket(sessKVStoreBucketKey); sessBkt != nil { + err := sessBkt.ForEach(func(groupAlias, v []byte) error { + if v != nil { + return fmt.Errorf("expected only buckets "+ + "under %s bucket", sessKVStoreBucketKey) + } + + groupBucket := sessBkt.Bucket(groupAlias) + if groupBucket == nil { + return fmt.Errorf("group bucket for group "+ + "alias %s not found", groupAlias) + } + + kvPairs, err := collectKVPairs( + groupBucket, false, perm, rule, + fn.Some(groupAlias), fn.None[string](), + ) + if err != nil { + return fmt.Errorf("collecting group kv "+ + "pairs failed: %w", err) + } + + params = append(params, kvPairs...) + + err = verifyBktKeys( + groupBucket, false, featureKVStoreBucketKey, + ) + if err != nil { + return fmt.Errorf("verification of group "+ + "bucket %s keys failed: %w", groupAlias, + err) + } + + ftBkt := groupBucket.Bucket(featureKVStoreBucketKey) + if ftBkt == nil { + return nil + } + + return ftBkt.ForEach(func(ftName, v []byte) error { + if v != nil { + return fmt.Errorf("expected only "+ + "buckets under %s bucket", + featureKVStoreBucketKey) + } + + // The feature name should exist, as per the + // verification above. + featureBucket := ftBkt.Bucket(ftName) + if featureBucket == nil { + return fmt.Errorf("feature bucket "+ + "%s not found", ftName) + } + + featurePairs, err := collectKVPairs( + featureBucket, true, perm, rule, + fn.Some(groupAlias), + fn.Some(string(ftName)), + ) + if err != nil { + return fmt.Errorf("collecting "+ + "feature kv pairs failed: %w", + err) + } + + params = append(params, featurePairs...) + + return nil + }) + }) + if err != nil { + return nil, fmt.Errorf("collecting session kv pairs "+ + "failed: %w", err) + } + } + + return params, nil +} + +// collectKVPairs collects all key-value pairs from the given bucket, and +// returns them as a slice of kvEntry structs. If the errorOnBuckets parameter +// is set to true, then the function will return an error if the bucket +// contains any sub-buckets. Note that when the errorOnBuckets parameter is +// set to false, the function will not collect any key-value pairs from the +// sub-buckets, and will just ignore them. +func collectKVPairs(bkt *bbolt.Bucket, errorOnBuckets, perm bool, + ruleName string, groupAlias fn.Option[[]byte], + featureName fn.Option[string]) ([]*kvEntry, error) { + + var params []*kvEntry + + return params, bkt.ForEach(func(key, value []byte) error { + // If the value is nil, then this is a bucket, which we + // don't want to process here, as we only want to collect + // the key-value pairs, not the buckets. If we should + // error on buckets, then we return an error here. + if value == nil { + if errorOnBuckets { + return fmt.Errorf("unexpected bucket %s found "+ + "in when collecting kv pairs", key) + } + + return nil + } + + params = append(params, &kvEntry{ + perm: perm, + ruleName: ruleName, + key: string(key), + featureName: featureName, + groupAlias: groupAlias, + value: value, + }) + + return nil + }) +} + +// insertPair inserts a single key-value pair into the SQL database. +func insertPair(ctx context.Context, tx SQLQueries, + entry *kvEntry) (*sqlKvEntry, error) { + + ruleID, err := tx.GetOrInsertRuleID(ctx, entry.ruleName) + if err != nil { + return nil, err + } + + p := sqlc.InsertKVStoreRecordParams{ + Perm: entry.perm, + RuleID: ruleID, + EntryKey: entry.key, + Value: entry.value, + } + + entry.groupAlias.WhenSome(func(alias []byte) { + var groupID int64 + groupID, err = tx.GetSessionIDByAlias(ctx, alias) + if err != nil { + err = fmt.Errorf("getting group id by alias %x "+ + "failed: %w", alias, err) + return + } + + p.GroupID = sqldb.SQLInt64(groupID) + }) + if err != nil { + return nil, err + } + + entry.featureName.WhenSome(func(feature string) { + var featureID int64 + featureID, err = tx.GetOrInsertFeatureID(ctx, feature) + if err != nil { + err = fmt.Errorf("getting/inserting feature id for %s "+ + "failed: %w", feature, err) + return + } + + p.FeatureID = sqldb.SQLInt64(featureID) + }) + if err != nil { + return nil, err + } + + err = tx.InsertKVStoreRecord(ctx, p) + if err != nil { + return nil, err + } + + return &sqlKvEntry{ + kvEntry: entry, + ruleID: p.RuleID, + groupID: p.GroupID, + featureID: p.FeatureID, + }, nil +} + +// getSQLValue retrieves the key value for the given kvEntry from the SQL +// database. +func getSQLValue(ctx context.Context, tx SQLQueries, + entry *sqlKvEntry) ([]byte, error) { + + switch { + case entry.featureID.Valid && entry.groupID.Valid: + return tx.GetFeatureKVStoreRecord( + ctx, sqlc.GetFeatureKVStoreRecordParams{ + Perm: entry.perm, + RuleID: entry.ruleID, + GroupID: entry.groupID, + FeatureID: entry.featureID, + Key: entry.key, + }, + ) + case entry.groupID.Valid: + return tx.GetGroupKVStoreRecord( + ctx, sqlc.GetGroupKVStoreRecordParams{ + Perm: entry.perm, + RuleID: entry.ruleID, + GroupID: entry.groupID, + Key: entry.key, + }, + ) + case !entry.featureID.Valid && !entry.groupID.Valid: + return tx.GetGlobalKVStoreRecord( + ctx, sqlc.GetGlobalKVStoreRecordParams{ + Perm: entry.perm, + RuleID: entry.ruleID, + Key: entry.key, + }, + ) + default: + return nil, fmt.Errorf("invalid combination of feature and "+ + "session ID: featureID valid: %v, groupID valid: %v", + entry.featureID.Valid, entry.groupID.Valid) + } +} + +// verifyBktKeys checks that the given bucket only contains buckets with the +// passed keys, and optionally also key-value pairs. If the errorOnKeyValues +// parameter is set to true, the function will error if it finds key-value pairs +// in the bucket. +func verifyBktKeys(bkt *bbolt.Bucket, errorOnKeyValues bool, + keys ...[]byte) error { + + return bkt.ForEach(func(key, v []byte) error { + if v != nil { + // If we allow key-values, then we can just continue + // to the next key. Else we need to error out, as we + // only expect buckets under the passed bucket. + if errorOnKeyValues { + return fmt.Errorf("unexpected key-value pair "+ + "found: key=%s, value=%x", key, v) + } + + return nil + } + + for _, expectedKey := range keys { + if bytes.Equal(key, expectedKey) { + // If this is an expected key, we can continue + // to the next key. + return nil + } + } + + return fmt.Errorf("unexpected key found: %s", key) + }) +} diff --git a/firewalldb/sql_migration_test.go b/firewalldb/sql_migration_test.go new file mode 100644 index 000000000..1298e3e53 --- /dev/null +++ b/firewalldb/sql_migration_test.go @@ -0,0 +1,664 @@ +package firewalldb + +import ( + "bytes" + "context" + "database/sql" + "fmt" + "testing" + "time" + + "github.com/lightninglabs/lightning-terminal/accounts" + "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/session" + "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/sqldb" + "github.com/stretchr/testify/require" + "golang.org/x/exp/rand" +) + +const ( + testRuleName = "test-rule" + testRuleName2 = "test-rule-2" + testFeatureName = "test-feature" + testFeatureName2 = "test-feature-2" + testEntryKey = "test-entry-key" + testEntryKey2 = "test-entry-key-2" + testEntryKey3 = "test-entry-key-3" + testEntryKey4 = "test-entry-key-4" +) + +var ( + testEntryValue = []byte{1, 2, 3} +) + +// TestFirewallDBMigration tests the migration of firewalldb from a bolt +// backend to a SQL database. Note that this test does not attempt to be a +// complete migration test. +// This test only tests the migration of the KV stores currently, but will +// be extended in the future to also test the migration of the privacy mapper +// and the actions store in the future. +func TestFirewallDBMigration(t *testing.T) { + t.Parallel() + + ctx := context.Background() + clock := clock.NewTestClock(time.Now()) + + // When using build tags that creates a kvdb store for NewTestDB, we + // skip this test as it is only applicable for postgres and sqlite tags. + store := NewTestDB(t, clock) + if _, ok := store.(*BoltDB); ok { + t.Skipf("Skipping Firewall DB migration test for kvdb build") + } + + makeSQLDB := func(t *testing.T, sessionsStore session.Store) (*SQLDB, + *db.TransactionExecutor[SQLQueries]) { + + testDBStore := NewTestDBWithSessions(t, sessionsStore, clock) + + store, ok := testDBStore.(*SQLDB) + require.True(t, ok) + + baseDB := store.BaseDB + + genericExecutor := db.NewTransactionExecutor( + baseDB, func(tx *sql.Tx) SQLQueries { + return baseDB.WithTx(tx) + }, + ) + + return store, genericExecutor + } + + // The assertMigrationResults function will currently assert that + // the migrated kv stores entries in the SQLDB match the original kv + // stores entries in the BoltDB. + assertMigrationResults := func(t *testing.T, sqlStore *SQLDB, + kvEntries []*kvEntry) { + + var ( + ruleIDs = make(map[string]int64) + groupIDs = make(map[string]int64) + featureIDs = make(map[string]int64) + err error + ) + + getRuleID := func(ruleName string) int64 { + ruleID, ok := ruleIDs[ruleName] + if !ok { + ruleID, err = sqlStore.GetRuleID(ctx, ruleName) + require.NoError(t, err) + + ruleIDs[ruleName] = ruleID + } + + return ruleID + } + + getGroupID := func(groupAlias []byte) int64 { + groupID, ok := groupIDs[string(groupAlias)] + if !ok { + groupID, err = sqlStore.GetSessionIDByAlias( + ctx, groupAlias, + ) + require.NoError(t, err) + + groupIDs[string(groupAlias)] = groupID + } + + return groupID + } + + getFeatureID := func(featureName string) int64 { + featureID, ok := featureIDs[featureName] + if !ok { + featureID, err = sqlStore.GetFeatureID( + ctx, featureName, + ) + require.NoError(t, err) + + featureIDs[featureName] = featureID + } + + return featureID + } + + // First we extract all migrated kv entries from the SQLDB, + // in order to be able to compare them to the original kv + // entries, to ensure that the migration was successful. + sqlKvEntries, err := sqlStore.ListAllKVStoresRecords(ctx) + require.NoError(t, err) + require.Equal(t, len(kvEntries), len(sqlKvEntries)) + + // We then iterate over the original kv entries that were + // migrated from the BoltDB to the SQLDB, and assert that they + // match the migrated SQL kv entries. + // NOTE: when fetching kv entries that were inserted into the + // sql store with the entry value []byte{}, a nil value is + // returned. Therefore, require.Equal would error on such cases, + // while bytes.Equal would not. Therefore, the comparison below + // uses bytes.Equal to compare the values. + for _, entry := range kvEntries { + ruleID := getRuleID(entry.ruleName) + + if entry.groupAlias.IsNone() { + sqlVal, err := sqlStore.GetGlobalKVStoreRecord( + ctx, + sqlc.GetGlobalKVStoreRecordParams{ + Key: entry.key, + Perm: entry.perm, + RuleID: ruleID, + }, + ) + require.NoError(t, err) + // See docs for the loop above on why + // bytes.Equal is used here. + require.True( + t, bytes.Equal(entry.value, sqlVal), + ) + } else if entry.featureName.IsNone() { + groupAlias := entry.groupAlias.UnwrapOrFail(t) + groupID := getGroupID(groupAlias[:]) + + v, err := sqlStore.GetGroupKVStoreRecord( + ctx, + sqlc.GetGroupKVStoreRecordParams{ + Key: entry.key, + Perm: entry.perm, + RuleID: ruleID, + GroupID: sql.NullInt64{ + Int64: groupID, + Valid: true, + }, + }, + ) + require.NoError(t, err) + // See docs for the loop above on why + // bytes.Equal is used here. + require.True( + t, bytes.Equal(entry.value, v), + ) + } else { + groupAlias := entry.groupAlias.UnwrapOrFail(t) + groupID := getGroupID(groupAlias[:]) + featureID := getFeatureID( + entry.featureName.UnwrapOrFail(t), + ) + + sqlVal, err := sqlStore.GetFeatureKVStoreRecord( + ctx, + sqlc.GetFeatureKVStoreRecordParams{ + Key: entry.key, + Perm: entry.perm, + RuleID: ruleID, + GroupID: sql.NullInt64{ + Int64: groupID, + Valid: true, + }, + FeatureID: sql.NullInt64{ + Int64: featureID, + Valid: true, + }, + }, + ) + require.NoError(t, err) + // See docs for the loop above on why + // bytes.Equal is used here. + require.True( + t, bytes.Equal(entry.value, sqlVal), + ) + } + } + } + + // The tests slice contains all the tests that we will run for the + // migration of the firewalldb from a BoltDB to a SQLDB. + // Note that the tests currently only test the migration of the KV + // stores, but will be extended in the future to also test the migration + // of the privacy mapper and the actions store. + tests := []struct { + name string + populateDB func(t *testing.T, ctx context.Context, + boltDB *BoltDB, sessionStore session.Store) []*kvEntry + }{ + { + name: "empty", + populateDB: func(t *testing.T, ctx context.Context, + boltDB *BoltDB, + sessionStore session.Store) []*kvEntry { + + // Don't populate the DB. + return make([]*kvEntry, 0) + }, + }, + { + name: "global entries", + populateDB: globalEntries, + }, + { + name: "session specific entries", + populateDB: sessionSpecificEntries, + }, + { + name: "feature specific entries", + populateDB: featureSpecificEntries, + }, + { + name: "all entry combinations", + populateDB: allEntryCombinations, + }, + { + name: "random entries", + populateDB: randomKVEntries, + }, + } + + for _, test := range tests { + tc := test + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + // First let's create a sessions store to link to in + // the kvstores DB. In order to create the sessions + // store though, we also need to create an accounts + // store, that we link to the sessions store. + // Note that both of these stores will be sql stores due + // to the build tags enabled when running this test, + // which means we can also pass the sessions store to + // the sql version of the kv stores that we'll create + // in test, without also needing to migrate it. + accountStore := accounts.NewTestDB(t, clock) + sessionsStore := session.NewTestDBWithAccounts( + t, clock, accountStore, + ) + + // Create a new firewall store to populate with test + // data. + firewallStore, err := NewBoltDB( + t.TempDir(), DBFilename, sessionsStore, + accountStore, clock, + ) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, firewallStore.Close()) + }) + + // Populate the kv store. + entries := test.populateDB( + t, ctx, firewallStore, sessionsStore, + ) + + // Create the SQL store that we will migrate the data + // to. + sqlStore, txEx := makeSQLDB(t, sessionsStore) + + // Perform the migration. + var opts sqldb.MigrationTxOptions + err = txEx.ExecTx(ctx, &opts, + func(tx SQLQueries) error { + return MigrateFirewallDBToSQL( + ctx, firewallStore.DB, tx, + ) + }, + ) + require.NoError(t, err) + + // Assert migration results. + assertMigrationResults(t, sqlStore, entries) + }) + } +} + +// globalEntries populates the kv store with one global entry for the temp +// store, and one for the perm store. +func globalEntries(t *testing.T, ctx context.Context, boltDB *BoltDB, + _ session.Store) []*kvEntry { + + return insertTempAndPermEntry( + t, ctx, boltDB, testRuleName, fn.None[[]byte](), + fn.None[string](), testEntryKey, testEntryValue, + ) +} + +// sessionSpecificEntries populates the kv store with one session specific +// entry for the local temp store, and one session specific entry for the perm +// local store. +func sessionSpecificEntries(t *testing.T, ctx context.Context, boltDB *BoltDB, + sessionStore session.Store) []*kvEntry { + + groupAlias := getNewSessionAlias(t, ctx, sessionStore) + + return insertTempAndPermEntry( + t, ctx, boltDB, testRuleName, groupAlias, fn.None[string](), + testEntryKey, testEntryValue, + ) +} + +// featureSpecificEntries populates the kv store with one feature specific +// entry for the local temp store, and one feature specific entry for the perm +// local store. +func featureSpecificEntries(t *testing.T, ctx context.Context, boltDB *BoltDB, + sessionStore session.Store) []*kvEntry { + + groupAlias := getNewSessionAlias(t, ctx, sessionStore) + + return insertTempAndPermEntry( + t, ctx, boltDB, testRuleName, groupAlias, + fn.Some(testFeatureName), testEntryKey, testEntryValue, + ) +} + +// allEntryCombinations adds all types of different entries at all possible +// levels of the kvstores, including multple entries with the same +// ruleName, groupAlias and featureName. The test aims to cover all possible +// combinations of entries in the kvstores, including nil and empty entry +// values. That therefore ensures that the migrations don't overwrite or miss +// any entries when the entry set is more complex than just a single entry at +// each level. +func allEntryCombinations(t *testing.T, ctx context.Context, boltDB *BoltDB, + sessionStore session.Store) []*kvEntry { + + var result []*kvEntry + add := func(entry []*kvEntry) { + result = append(result, entry...) + } + + // First lets create standard entries at all levels, which represents + // the entries added by other tests. + add(globalEntries(t, ctx, boltDB, sessionStore)) + add(sessionSpecificEntries(t, ctx, boltDB, sessionStore)) + add(featureSpecificEntries(t, ctx, boltDB, sessionStore)) + + groupAlias := getNewSessionAlias(t, ctx, sessionStore) + + // Now lets add a few more entries at with different rule names and + // features, just to ensure that we cover entries in different rule and + // feature tables. + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, fn.None[[]byte](), + fn.None[string](), testEntryKey, testEntryValue, + )) + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, groupAlias, + fn.None[string](), testEntryKey, testEntryValue, + )) + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, groupAlias, + fn.Some(testFeatureName), testEntryKey, testEntryValue, + )) + // Let's also create an entry with a different feature name that's still + // referencing the same group ID as the previous entry. + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, groupAlias, + fn.Some(testFeatureName2), testEntryKey, testEntryValue, + )) + + // Finally, lets add a few entries with nil and empty values set for the + // actual key value, at all different levels, to ensure that tests don't + // break if the value is nil or empty. + var ( + nilValue []byte = nil + nilSliceValue = []byte(nil) + emptyValue = []byte{} + ) + + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, fn.None[[]byte](), + fn.None[string](), testEntryKey2, nilValue, + )) + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, fn.None[[]byte](), + fn.None[string](), testEntryKey3, nilSliceValue, + )) + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, fn.None[[]byte](), + fn.None[string](), testEntryKey4, emptyValue, + )) + + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, groupAlias, + fn.None[string](), testEntryKey2, nilValue, + )) + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, groupAlias, + fn.None[string](), testEntryKey3, nilSliceValue, + )) + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, groupAlias, + fn.None[string](), testEntryKey4, emptyValue, + )) + + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, groupAlias, + fn.Some(testFeatureName), testEntryKey2, nilValue, + )) + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, groupAlias, + fn.Some(testFeatureName), testEntryKey3, nilSliceValue, + )) + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, groupAlias, + fn.Some(testFeatureName), testEntryKey4, emptyValue, + )) + + return result +} + +func getNewSessionAlias(t *testing.T, ctx context.Context, + sessionStore session.Store) fn.Option[[]byte] { + + sess, err := sessionStore.NewSession( + ctx, "test", session.TypeAutopilot, + time.Unix(1000, 0), "something", + ) + require.NoError(t, err) + + return fn.Some(sess.GroupID[:]) +} + +// insertTempAndPermEntry populates the kv store with one entry for the temp +// store, and one entry for the perm store. Both of the entries will be inserted +// with the same groupAlias, ruleName, entryKey and entryValue. +func insertTempAndPermEntry(t *testing.T, ctx context.Context, + boltDB *BoltDB, ruleName string, groupAlias fn.Option[[]byte], + featureNameOpt fn.Option[string], entryKey string, + entryValue []byte) []*kvEntry { + + tempKvEntry := &kvEntry{ + ruleName: ruleName, + groupAlias: groupAlias, + featureName: featureNameOpt, + key: entryKey, + value: entryValue, + perm: false, + } + + insertKvEntry(t, ctx, boltDB, tempKvEntry) + + permKvEntry := &kvEntry{ + ruleName: ruleName, + groupAlias: groupAlias, + featureName: featureNameOpt, + key: entryKey, + value: entryValue, + perm: true, + } + + insertKvEntry(t, ctx, boltDB, permKvEntry) + + return []*kvEntry{tempKvEntry, permKvEntry} +} + +// insertKvEntry populates the kv store with passed entry, and asserts that the +// entry is inserted correctly. +func insertKvEntry(t *testing.T, ctx context.Context, + boltDB *BoltDB, entry *kvEntry) { + + if entry.groupAlias.IsNone() && entry.featureName.IsSome() { + t.Fatalf("cannot set both global and feature specific at the " + + "same time") + } + + // We get the kv stores that the entry will be inserted into. Note that + // we set an empty group ID if the entry is global, as the group ID + // will not be used when fetching the actual kv store that's used for + // global entries. + groupID := [4]byte{} + if entry.groupAlias.IsSome() { + copy(groupID[:], entry.groupAlias.UnwrapOrFail(t)) + } + + kvStores := boltDB.GetKVStores( + entry.ruleName, groupID, entry.featureName.UnwrapOr(""), + ) + + err := kvStores.Update(ctx, func(ctx context.Context, + tx KVStoreTx) error { + + store := tx.Global() + + switch { + case entry.groupAlias.IsNone() && !entry.perm: + store = tx.GlobalTemp() + case entry.groupAlias.IsSome() && !entry.perm: + store = tx.LocalTemp() + case entry.groupAlias.IsSome() && entry.perm: + store = tx.Local() + } + + return store.Set(ctx, entry.key, entry.value) + }) + require.NoError(t, err) +} + +// randomKVEntries populates the kv store with random kv entries that span +// across all possible combinations of different levels of entries in the kv +// store. All values and different bucket names are randomly generated. +func randomKVEntries(t *testing.T, ctx context.Context, + boltDB *BoltDB, sessionStore session.Store) []*kvEntry { + + var ( + // We set the number of entries to insert to 1000, as that + // should be enough to cover as many different + // combinations of entries as possible, while still being + // fast enough to run in a reasonable time. + numberOfEntries = 1000 + insertedEntries = make([]*kvEntry, 0) + ruleName = "initial-rule" + groupAlias []byte + featureName = "initial-feature" + ) + + // Create a random session that we can reference for the initial group + // ID. + sess, err := sessionStore.NewSession( + ctx, "initial-session", session.Type(1), time.Unix(1000, 0), + "serverAddr.test", + ) + require.NoError(t, err) + + groupAlias = sess.GroupID[:] + + // Generate random entries. Note that many entries will use the same + // rule name, group ID and feature name, to simulate the real world + // usage of the kv stores as much as possible. + for i := 0; i < numberOfEntries; i++ { + // On average, we will generate a new rule which will be used + // for the kv store entry 10% of the time. + if rand.Intn(10) == 0 { + ruleName = fmt.Sprintf( + "rule-%s-%d", randomString(rand.Intn(30)+1), i, + ) + } + + // On average, we use the global store 25% of the time. + global := rand.Intn(4) == 0 + + // We'll use the perm store 50% of the time. + perm := rand.Intn(2) == 0 + + // For the non-global entries, we will generate a new group + // alias 25% of the time. + if !global && rand.Intn(4) == 0 { + newSess, err := sessionStore.NewSession( + ctx, fmt.Sprintf("session-%d", i), + session.Type(uint8(rand.Intn(5))), + time.Unix(1000, 0), + randomString(rand.Intn(10)+1), + ) + require.NoError(t, err) + + groupAlias = newSess.GroupID[:] + } + + featureNameOpt := fn.None[string]() + + // For 50% of the non-global entries, we insert a feature + // specific entry. The other 50% will be session specific + // entries. + if !global && rand.Intn(2) == 0 { + // 25% of the time, we will generate a new feature name. + if rand.Intn(4) == 0 { + featureName = fmt.Sprintf( + "feature-%s-%d", + randomString(rand.Intn(30)+1), i, + ) + } + + featureNameOpt = fn.Some(featureName) + } + + groupAliasOpt := fn.None[[]byte]() + if !global { + // If the entry is not global, we set the group ID + // to the latest session's group ID. + groupAliasOpt = fn.Some(groupAlias[:]) + } + + entry := &kvEntry{ + ruleName: ruleName, + groupAlias: groupAliasOpt, + featureName: featureNameOpt, + key: fmt.Sprintf("key-%d", i), + perm: perm, + } + + // When setting a value for the entry, 25% of the time, we will + // set a nil or empty value. + if rand.Intn(4) == 0 { + // in 50% of these cases, we will set the value to nil, + // and in the other 50% we will set it to an empty + // value + if rand.Intn(2) == 0 { + entry.value = nil + } else { + entry.value = []byte{} + } + } else { + // Else generate a random value for all entries, + entry.value = []byte(randomString(rand.Intn(100) + 1)) + } + + // Insert the entry into the kv store. + insertKvEntry(t, ctx, boltDB, entry) + + // Add the entry to the list of inserted entries. + insertedEntries = append(insertedEntries, entry) + } + + return insertedEntries +} + +// randomString generates a random string of the passed length n. +func randomString(n int) string { + letterBytes := "abcdefghijklmnopqrstuvwxyz" + + b := make([]byte, n) + for i := range b { + b[i] = letterBytes[rand.Intn(len(letterBytes))] + } + return string(b) +} diff --git a/firewalldb/test_kvdb.go b/firewalldb/test_kvdb.go index 6f7a49aa3..c3cd4533a 100644 --- a/firewalldb/test_kvdb.go +++ b/firewalldb/test_kvdb.go @@ -6,34 +6,37 @@ import ( "testing" "github.com/lightninglabs/lightning-terminal/accounts" + "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" "github.com/stretchr/testify/require" ) // NewTestDB is a helper function that creates an BBolt database for testing. -func NewTestDB(t *testing.T, clock clock.Clock) *BoltDB { +func NewTestDB(t *testing.T, clock clock.Clock) FirewallDBs { return NewTestDBFromPath(t, t.TempDir(), clock) } // NewTestDBFromPath is a helper function that creates a new BoltStore with a // connection to an existing BBolt database for testing. -func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) *BoltDB { +func NewTestDBFromPath(t *testing.T, dbPath string, + clock clock.Clock) FirewallDBs { + return newDBFromPathWithSessions(t, dbPath, nil, nil, clock) } // NewTestDBWithSessions creates a new test BoltDB Store with access to an // existing sessions DB. -func NewTestDBWithSessions(t *testing.T, sessStore SessionDB, - clock clock.Clock) *BoltDB { +func NewTestDBWithSessions(t *testing.T, sessStore session.Store, + clock clock.Clock) FirewallDBs { return newDBFromPathWithSessions(t, t.TempDir(), sessStore, nil, clock) } // NewTestDBWithSessionsAndAccounts creates a new test BoltDB Store with access // to an existing sessions DB and accounts DB. -func NewTestDBWithSessionsAndAccounts(t *testing.T, sessStore SessionDB, - acctStore AccountsDB, clock clock.Clock) *BoltDB { +func NewTestDBWithSessionsAndAccounts(t *testing.T, sessStore session.Store, + acctStore AccountsDB, clock clock.Clock) FirewallDBs { return newDBFromPathWithSessions( t, t.TempDir(), sessStore, acctStore, clock, @@ -41,7 +44,8 @@ func NewTestDBWithSessionsAndAccounts(t *testing.T, sessStore SessionDB, } func newDBFromPathWithSessions(t *testing.T, dbPath string, - sessStore SessionDB, acctStore AccountsDB, clock clock.Clock) *BoltDB { + sessStore session.Store, acctStore AccountsDB, + clock clock.Clock) FirewallDBs { store, err := NewBoltDB(dbPath, DBFilename, sessStore, acctStore, clock) require.NoError(t, err) diff --git a/firewalldb/test_postgres.go b/firewalldb/test_postgres.go index f5777e4cb..732b19b4a 100644 --- a/firewalldb/test_postgres.go +++ b/firewalldb/test_postgres.go @@ -10,12 +10,12 @@ import ( ) // NewTestDB is a helper function that creates an BBolt database for testing. -func NewTestDB(t *testing.T, clock clock.Clock) *SQLDB { - return NewSQLDB(db.NewTestPostgresDB(t).BaseDB, clock) +func NewTestDB(t *testing.T, clock clock.Clock) FirewallDBs { + return createStore(t, db.NewTestPostgresDB(t).BaseDB, clock) } // NewTestDBFromPath is a helper function that creates a new BoltStore with a // connection to an existing BBolt database for testing. -func NewTestDBFromPath(t *testing.T, _ string, clock clock.Clock) *SQLDB { - return NewSQLDB(db.NewTestPostgresDB(t).BaseDB, clock) +func NewTestDBFromPath(t *testing.T, _ string, clock clock.Clock) FirewallDBs { + return createStore(t, db.NewTestPostgresDB(t).BaseDB, clock) } diff --git a/firewalldb/test_sql.go b/firewalldb/test_sql.go index 03dcfbebf..a412441f8 100644 --- a/firewalldb/test_sql.go +++ b/firewalldb/test_sql.go @@ -7,6 +7,7 @@ import ( "time" "github.com/lightninglabs/lightning-terminal/accounts" + "github.com/lightninglabs/lightning-terminal/db" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/clock" "github.com/stretchr/testify/require" @@ -15,18 +16,17 @@ import ( // NewTestDBWithSessions creates a new test SQLDB Store with access to an // existing sessions DB. func NewTestDBWithSessions(t *testing.T, sessionStore session.Store, - clock clock.Clock) *SQLDB { - + clock clock.Clock) FirewallDBs { sessions, ok := sessionStore.(*session.SQLStore) require.True(t, ok) - return NewSQLDB(sessions.BaseDB, clock) + return createStore(t, sessions.BaseDB, clock) } // NewTestDBWithSessionsAndAccounts creates a new test SQLDB Store with access // to an existing sessions DB and accounts DB. func NewTestDBWithSessionsAndAccounts(t *testing.T, sessionStore SessionDB, - acctStore AccountsDB, clock clock.Clock) *SQLDB { + acctStore AccountsDB, clock clock.Clock) FirewallDBs { sessions, ok := sessionStore.(*session.SQLStore) require.True(t, ok) @@ -36,7 +36,7 @@ func NewTestDBWithSessionsAndAccounts(t *testing.T, sessionStore SessionDB, require.Equal(t, accounts.BaseDB, sessions.BaseDB) - return NewSQLDB(sessions.BaseDB, clock) + return createStore(t, sessions.BaseDB, clock) } func assertEqualActions(t *testing.T, expected, got *Action) { @@ -52,3 +52,14 @@ func assertEqualActions(t *testing.T, expected, got *Action) { expected.AttemptedAt = expectedAttemptedAt got.AttemptedAt = actualAttemptedAt } + +// createStore is a helper function that creates a new SQLDB and ensure that +// it is closed when during the test cleanup. +func createStore(t *testing.T, sqlDB *db.BaseDB, clock clock.Clock) *SQLDB { + store := NewSQLDB(sqlDB, clock) + t.Cleanup(func() { + require.NoError(t, store.Close()) + }) + + return store +} diff --git a/firewalldb/test_sqlite.go b/firewalldb/test_sqlite.go index 5496cb205..49b956d7d 100644 --- a/firewalldb/test_sqlite.go +++ b/firewalldb/test_sqlite.go @@ -10,14 +10,14 @@ import ( ) // NewTestDB is a helper function that creates an BBolt database for testing. -func NewTestDB(t *testing.T, clock clock.Clock) *SQLDB { - return NewSQLDB(db.NewTestSqliteDB(t).BaseDB, clock) +func NewTestDB(t *testing.T, clock clock.Clock) FirewallDBs { + return createStore(t, db.NewTestSqliteDB(t).BaseDB, clock) } // NewTestDBFromPath is a helper function that creates a new BoltStore with a // connection to an existing BBolt database for testing. -func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) *SQLDB { - return NewSQLDB( - db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock, +func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) FirewallDBs { + return createStore( + t, db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock, ) }