Skip to content

Commit bff2f24

Browse files
authored
Merge pull request #9873 from ellemouton/sqldbHelpers
sqldb: re-usable TxOptions and NoOpReset
2 parents 8e96bd0 + 9cbc1f8 commit bff2f24

File tree

10 files changed

+98
-149
lines changed

10 files changed

+98
-149
lines changed

batch/batch.go

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,6 @@ var errSolo = errors.New(
1414
"batch function returned an error and should be re-run solo",
1515
)
1616

17-
// txOpts implements the sqldb.TxOptions interface. It is used to indicate that
18-
// the transaction can be read-only or not transaction.
19-
type txOpts struct {
20-
readOnly bool
21-
}
22-
23-
// ReadOnly returns true if the transaction should be read only.
24-
//
25-
// NOTE: This is part of the sqldb.TxOptions interface.
26-
func (t *txOpts) ReadOnly() bool {
27-
return t.readOnly
28-
}
29-
3017
type request[Q any] struct {
3118
*Request[Q]
3219
errChan chan error
@@ -38,7 +25,7 @@ type batch[Q any] struct {
3825
reqs []*request[Q]
3926
clear func(b *batch[Q])
4027
locker sync.Locker
41-
txOpts txOpts
28+
txOpts sqldb.TxOptions
4229
}
4330

4431
// trigger is the entry point for the batch and ensures that run is started at
@@ -68,7 +55,7 @@ func (b *batch[Q]) run(ctx context.Context) {
6855
// that fail will be retried individually.
6956
for len(b.reqs) > 0 {
7057
var failIdx = -1
71-
err := b.db.ExecTx(ctx, &b.txOpts, func(tx Q) error {
58+
err := b.db.ExecTx(ctx, b.txOpts, func(tx Q) error {
7259
for i, req := range b.reqs {
7360
err := req.Do(tx)
7461
if err != nil {

batch/batch_test.go

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,7 @@ func benchmarkSQLBatching(b *testing.B, sqlite bool) {
550550
}
551551

552552
ctx := context.Background()
553-
var opts txOpts
553+
opts := sqldb.WriteTxOpt()
554554

555555
// writeRecord is a helper that adds a single new invoice to the
556556
// database. It uses the 'i' argument to create a unique hash for the
@@ -578,13 +578,12 @@ func benchmarkSQLBatching(b *testing.B, sqlite bool) {
578578
var hash [8]byte
579579
binary.BigEndian.PutUint64(hash[:], uint64(N-1))
580580

581-
err := tx.ExecTx(
582-
ctx, &txOpts{}, func(queries *sqlc.Queries) error {
583-
_, err := queries.GetInvoiceByHash(ctx, hash[:])
584-
require.NoError(b, err)
581+
err := tx.ExecTx(ctx, opts, func(queries *sqlc.Queries) error {
582+
_, err := queries.GetInvoiceByHash(ctx, hash[:])
583+
require.NoError(b, err)
585584

586-
return nil
587-
}, func() {},
585+
return nil
586+
}, func() {},
588587
)
589588
require.NoError(b, err)
590589
}
@@ -602,7 +601,7 @@ func benchmarkSQLBatching(b *testing.B, sqlite bool) {
602601
defer wg.Done()
603602

604603
err := db.ExecTx(
605-
ctx, &opts,
604+
ctx, opts,
606605
func(tx *sqlc.Queries) error {
607606
writeRecord(b, tx, int64(j))
608607
return nil
@@ -624,7 +623,7 @@ func benchmarkSQLBatching(b *testing.B, sqlite bool) {
624623
b.ResetTimer()
625624

626625
err := db.ExecTx(
627-
ctx, &opts,
626+
ctx, opts,
628627
func(tx *sqlc.Queries) error {
629628
for i := 0; i < b.N; i++ {
630629
writeRecord(b, tx, int64(i))

batch/scheduler.go

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,7 @@ func (s *TimeScheduler[Q]) Execute(ctx context.Context, r *Request[Q]) error {
6565
// By default, we assume that the batch is read-only,
6666
// and we only upgrade it to read-write if a request
6767
// is added that is not read-only.
68-
txOpts: txOpts{
69-
readOnly: true,
70-
},
68+
txOpts: sqldb.ReadTxOpt(),
7169
}
7270
trigger := s.b.trigger
7371
time.AfterFunc(s.duration, func() {
@@ -78,8 +76,8 @@ func (s *TimeScheduler[Q]) Execute(ctx context.Context, r *Request[Q]) error {
7876

7977
// We only upgrade the batch to read-write if the new request is not
8078
// read-only. If it is already read-write, we don't need to do anything.
81-
if s.b.txOpts.readOnly && !r.Opts.ReadOnly {
82-
s.b.txOpts.readOnly = false
79+
if s.b.txOpts.ReadOnly() && !r.Opts.ReadOnly {
80+
s.b.txOpts = sqldb.WriteTxOpt()
8381
}
8482

8583
// If this is a non-lazy request, we'll execute the batch immediately.
@@ -109,7 +107,7 @@ func (s *TimeScheduler[Q]) Execute(ctx context.Context, r *Request[Q]) error {
109107
}
110108

111109
// Otherwise, run the request on its own.
112-
commitErr := s.db.ExecTx(ctx, &txOpts, func(tx Q) error {
110+
commitErr := s.db.ExecTx(ctx, txOpts, func(tx Q) error {
113111
return req.Do(tx)
114112
}, func() {
115113
if req.Reset != nil {

graph/db/sql_store.go

Lines changed: 21 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -139,27 +139,6 @@ func NewSQLStore(db BatchedSQLQueries, kvStore *KVStore,
139139
return s, nil
140140
}
141141

142-
// TxOptions defines the set of db txn options the SQLQueries
143-
// understands.
144-
type TxOptions struct {
145-
// readOnly governs if a read only transaction is needed or not.
146-
readOnly bool
147-
}
148-
149-
// ReadOnly returns true if the transaction should be read only.
150-
//
151-
// NOTE: This implements the TxOptions.
152-
func (a *TxOptions) ReadOnly() bool {
153-
return a.readOnly
154-
}
155-
156-
// NewReadTx creates a new read transaction option set.
157-
func NewReadTx() *TxOptions {
158-
return &TxOptions{
159-
readOnly: true,
160-
}
161-
}
162-
163142
// AddLightningNode adds a vertex/node to the graph database. If the node is not
164143
// in the database from before, this will add a new, unconnected one to the
165144
// graph. If it is present from before, this will update that node's
@@ -192,16 +171,13 @@ func (s *SQLStore) FetchLightningNode(pubKey route.Vertex) (
192171

193172
ctx := context.TODO()
194173

195-
var (
196-
readTx = NewReadTx()
197-
node *models.LightningNode
198-
)
199-
err := s.db.ExecTx(ctx, readTx, func(db SQLQueries) error {
174+
var node *models.LightningNode
175+
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
200176
var err error
201177
_, node, err = getNodeByPubKey(ctx, db, pubKey)
202178

203179
return err
204-
}, func() {})
180+
}, sqldb.NoOpReset)
205181
if err != nil {
206182
return nil, fmt.Errorf("unable to fetch node: %w", err)
207183
}
@@ -222,11 +198,10 @@ func (s *SQLStore) HasLightningNode(pubKey [33]byte) (time.Time, bool,
222198
ctx := context.TODO()
223199

224200
var (
225-
readTx = NewReadTx()
226201
exists bool
227202
lastUpdate time.Time
228203
)
229-
err := s.db.ExecTx(ctx, readTx, func(db SQLQueries) error {
204+
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
230205
dbNode, err := db.GetNodeByPubKey(
231206
ctx, sqlc.GetNodeByPubKeyParams{
232207
Version: int16(ProtocolV1),
@@ -246,7 +221,7 @@ func (s *SQLStore) HasLightningNode(pubKey [33]byte) (time.Time, bool,
246221
}
247222

248223
return nil
249-
}, func() {})
224+
}, sqldb.NoOpReset)
250225
if err != nil {
251226
return time.Time{}, false,
252227
fmt.Errorf("unable to fetch node: %w", err)
@@ -266,11 +241,10 @@ func (s *SQLStore) AddrsForNode(nodePub *btcec.PublicKey) (bool, []net.Addr,
266241
ctx := context.TODO()
267242

268243
var (
269-
readTx = NewReadTx()
270244
addresses []net.Addr
271245
known bool
272246
)
273-
err := s.db.ExecTx(ctx, readTx, func(db SQLQueries) error {
247+
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
274248
var err error
275249
known, addresses, err = getNodeAddresses(
276250
ctx, db, nodePub.SerializeCompressed(),
@@ -281,7 +255,7 @@ func (s *SQLStore) AddrsForNode(nodePub *btcec.PublicKey) (bool, []net.Addr,
281255
}
282256

283257
return nil
284-
}, func() {})
258+
}, sqldb.NoOpReset)
285259
if err != nil {
286260
return false, nil, fmt.Errorf("unable to get addresses for "+
287261
"node(%x): %w", nodePub.SerializeCompressed(), err)
@@ -297,8 +271,7 @@ func (s *SQLStore) AddrsForNode(nodePub *btcec.PublicKey) (bool, []net.Addr,
297271
func (s *SQLStore) DeleteLightningNode(pubKey route.Vertex) error {
298272
ctx := context.TODO()
299273

300-
var writeTxOpts TxOptions
301-
err := s.db.ExecTx(ctx, &writeTxOpts, func(db SQLQueries) error {
274+
err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
302275
res, err := db.DeleteNodeByPubKey(
303276
ctx, sqlc.DeleteNodeByPubKeyParams{
304277
Version: int16(ProtocolV1),
@@ -321,7 +294,7 @@ func (s *SQLStore) DeleteLightningNode(pubKey route.Vertex) error {
321294
}
322295

323296
return err
324-
}, func() {})
297+
}, sqldb.NoOpReset)
325298
if err != nil {
326299
return fmt.Errorf("unable to delete node: %w", err)
327300
}
@@ -346,11 +319,10 @@ func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
346319
// NOTE: part of the V1Store interface.
347320
func (s *SQLStore) LookupAlias(pub *btcec.PublicKey) (string, error) {
348321
var (
349-
ctx = context.TODO()
350-
readTx = NewReadTx()
351-
alias string
322+
ctx = context.TODO()
323+
alias string
352324
)
353-
err := s.db.ExecTx(ctx, readTx, func(db SQLQueries) error {
325+
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
354326
dbNode, err := db.GetNodeByPubKey(
355327
ctx, sqlc.GetNodeByPubKeyParams{
356328
Version: int16(ProtocolV1),
@@ -370,7 +342,7 @@ func (s *SQLStore) LookupAlias(pub *btcec.PublicKey) (string, error) {
370342
alias = dbNode.Alias.String
371343

372344
return nil
373-
}, func() {})
345+
}, sqldb.NoOpReset)
374346
if err != nil {
375347
return "", fmt.Errorf("unable to look up alias: %w", err)
376348
}
@@ -387,11 +359,8 @@ func (s *SQLStore) LookupAlias(pub *btcec.PublicKey) (string, error) {
387359
func (s *SQLStore) SourceNode() (*models.LightningNode, error) {
388360
ctx := context.TODO()
389361

390-
var (
391-
readTx = NewReadTx()
392-
node *models.LightningNode
393-
)
394-
err := s.db.ExecTx(ctx, readTx, func(db SQLQueries) error {
362+
var node *models.LightningNode
363+
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
395364
_, nodePub, err := getSourceNode(ctx, db, ProtocolV1)
396365
if err != nil {
397366
return fmt.Errorf("unable to fetch V1 source node: %w",
@@ -401,7 +370,7 @@ func (s *SQLStore) SourceNode() (*models.LightningNode, error) {
401370
_, node, err = getNodeByPubKey(ctx, db, nodePub)
402371

403372
return err
404-
}, func() {})
373+
}, sqldb.NoOpReset)
405374
if err != nil {
406375
return nil, fmt.Errorf("unable to fetch source node: %w", err)
407376
}
@@ -416,9 +385,8 @@ func (s *SQLStore) SourceNode() (*models.LightningNode, error) {
416385
// NOTE: part of the V1Store interface.
417386
func (s *SQLStore) SetSourceNode(node *models.LightningNode) error {
418387
ctx := context.TODO()
419-
var writeTxOpts TxOptions
420388

421-
return s.db.ExecTx(ctx, &writeTxOpts, func(db SQLQueries) error {
389+
return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
422390
id, err := upsertNode(ctx, db, node)
423391
if err != nil {
424392
return fmt.Errorf("unable to upsert source node: %w",
@@ -442,7 +410,7 @@ func (s *SQLStore) SetSourceNode(node *models.LightningNode) error {
442410
}
443411

444412
return db.AddSourceNode(ctx, id)
445-
}, func() {})
413+
}, sqldb.NoOpReset)
446414
}
447415

448416
// NodeUpdatesInHorizon returns all the known lightning node which have an
@@ -456,11 +424,8 @@ func (s *SQLStore) NodeUpdatesInHorizon(startTime,
456424

457425
ctx := context.TODO()
458426

459-
var (
460-
readTx = NewReadTx()
461-
nodes []models.LightningNode
462-
)
463-
err := s.db.ExecTx(ctx, readTx, func(db SQLQueries) error {
427+
var nodes []models.LightningNode
428+
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
464429
dbNodes, err := db.GetNodesByLastUpdateRange(
465430
ctx, sqlc.GetNodesByLastUpdateRangeParams{
466431
StartTime: sqldb.SQLInt64(startTime.Unix()),
@@ -482,7 +447,7 @@ func (s *SQLStore) NodeUpdatesInHorizon(startTime,
482447
}
483448

484449
return nil
485-
}, func() {})
450+
}, sqldb.NoOpReset)
486451
if err != nil {
487452
return nil, fmt.Errorf("unable to fetch nodes: %w", err)
488453
}

invoices/kv_sql_migration_test.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,14 +70,13 @@ func TestMigrationWithChannelDB(t *testing.T) {
7070
ctxb := context.Background()
7171

7272
const batchSize = 11
73-
var opts sqldb.MigrationTxOptions
7473
err := sqlStore.ExecTx(
75-
ctxb, &opts, func(tx *sqlc.Queries) error {
74+
ctxb, sqldb.WriteTxOpt(), func(tx *sqlc.Queries) error {
7675
return invpkg.MigrateInvoicesToSQL(
7776
ctxb, kvStore.Backend, kvStore, tx,
7877
batchSize,
7978
)
80-
}, func() {},
79+
}, sqldb.NoOpReset,
8180
)
8281
require.NoError(t, err)
8382

invoices/sql_migration_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -317,15 +317,15 @@ func testMigrateSingleInvoiceRapid(t *rapid.T, store *SQLStore, mpp bool,
317317
invoices[hash] = invoice
318318
}
319319

320-
var ops SQLInvoiceQueriesTxOptions
321-
err := store.db.ExecTx(ctxb, &ops, func(tx SQLInvoiceQueries) error {
320+
ops := sqldb.WriteTxOpt()
321+
err := store.db.ExecTx(ctxb, ops, func(tx SQLInvoiceQueries) error {
322322
for hash, invoice := range invoices {
323323
err := MigrateSingleInvoice(ctxb, tx, invoice, hash)
324324
require.NoError(t, err)
325325
}
326326

327327
return nil
328-
}, func() {})
328+
}, sqldb.NoOpReset)
329329
require.NoError(t, err)
330330

331331
// Fetch and compare each migrated invoice from the store with the

0 commit comments

Comments
 (0)