Skip to content

Commit e5f39dd

Browse files
committed
sweep: refactor storeRecord to updateRecord
To make it clear we are only updating fields, which will be handy for the following commit where we start tracking for spending notifications.
1 parent 7eea7a7 commit e5f39dd

File tree

2 files changed

+100
-51
lines changed

2 files changed

+100
-51
lines changed

sweep/fee_bumper.go

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -441,19 +441,19 @@ func (t *TxPublisher) storeInitialRecord(req *BumpRequest) *monitorRecord {
441441
return record
442442
}
443443

444-
// storeRecord stores the given record in the records map.
445-
func (t *TxPublisher) storeRecord(requestID uint64, sweepCtx *sweepTxCtx,
446-
req *BumpRequest, f FeeFunction) {
444+
// updateRecord updates the given record's tx and fee, and saves it in the
445+
// records map.
446+
func (t *TxPublisher) updateRecord(r *monitorRecord,
447+
sweepCtx *sweepTxCtx) *monitorRecord {
448+
449+
r.tx = sweepCtx.tx
450+
r.fee = sweepCtx.fee
451+
r.outpointToTxIndex = sweepCtx.outpointToTxIndex
447452

448453
// Register the record.
449-
t.records.Store(requestID, &monitorRecord{
450-
requestID: requestID,
451-
tx: sweepCtx.tx,
452-
req: req,
453-
feeFunction: f,
454-
fee: sweepCtx.fee,
455-
outpointToTxIndex: sweepCtx.outpointToTxIndex,
456-
})
454+
t.records.Store(r.requestID, r)
455+
456+
return r
457457
}
458458

459459
// NOTE: part of the `chainio.Consumer` interface.
@@ -463,11 +463,11 @@ func (t *TxPublisher) Name() string {
463463

464464
// initializeTx initializes a fee function and creates an RBF-compliant tx. If
465465
// succeeded, the initial tx is stored in the records map.
466-
func (t *TxPublisher) initializeTx(r *monitorRecord) error {
466+
func (t *TxPublisher) initializeTx(r *monitorRecord) (*monitorRecord, error) {
467467
// Create a fee bumping algorithm to be used for future RBF.
468468
feeAlgo, err := t.initializeFeeFunction(r.req)
469469
if err != nil {
470-
return fmt.Errorf("init fee function: %w", err)
470+
return nil, fmt.Errorf("init fee function: %w", err)
471471
}
472472

473473
// Attach the newly created fee function.
@@ -481,12 +481,12 @@ func (t *TxPublisher) initializeTx(r *monitorRecord) error {
481481

482482
// Create the initial tx to be broadcasted. This tx is guaranteed to
483483
// comply with the RBF restrictions.
484-
err = t.createRBFCompliantTx(r)
484+
record, err := t.createRBFCompliantTx(r)
485485
if err != nil {
486-
return fmt.Errorf("create RBF-compliant tx: %w", err)
486+
return nil, fmt.Errorf("create RBF-compliant tx: %w", err)
487487
}
488488

489-
return nil
489+
return record, nil
490490
}
491491

492492
// initializeFeeFunction initializes a fee function to be used for this request
@@ -522,7 +522,9 @@ func (t *TxPublisher) initializeFeeFunction(
522522
// so by creating a tx, validate it using `TestMempoolAccept`, and bump its fee
523523
// and redo the process until the tx is valid, or return an error when non-RBF
524524
// related errors occur or the budget has been used up.
525-
func (t *TxPublisher) createRBFCompliantTx(r *monitorRecord) error {
525+
func (t *TxPublisher) createRBFCompliantTx(
526+
r *monitorRecord) (*monitorRecord, error) {
527+
526528
f := r.feeFunction
527529

528530
for {
@@ -533,15 +535,15 @@ func (t *TxPublisher) createRBFCompliantTx(r *monitorRecord) error {
533535
switch {
534536
case err == nil:
535537
// The tx is valid, store it.
536-
t.storeRecord(r.requestID, sweepCtx, r.req, f)
538+
record := t.updateRecord(r, sweepCtx)
537539

538540
log.Infof("Created initial sweep tx=%v for %v inputs: "+
539541
"feerate=%v, fee=%v, inputs:\n%v",
540542
sweepCtx.tx.TxHash(), len(r.req.Inputs),
541543
f.FeeRate(), sweepCtx.fee,
542544
inputTypeSummary(r.req.Inputs))
543545

544-
return nil
546+
return record, nil
545547

546548
// If the error indicates the fees paid is not enough, we will
547549
// ask the fee function to increase the fee rate and retry.
@@ -572,7 +574,7 @@ func (t *TxPublisher) createRBFCompliantTx(r *monitorRecord) error {
572574
// cluster these inputs differetly.
573575
increased, err = f.Increment()
574576
if err != nil {
575-
return err
577+
return nil, err
576578
}
577579
}
578580

@@ -582,7 +584,7 @@ func (t *TxPublisher) createRBFCompliantTx(r *monitorRecord) error {
582584
// mempool acceptance.
583585
default:
584586
log.Debugf("Failed to create RBF-compliant tx: %v", err)
585-
return err
587+
return nil, err
586588
}
587589
}
588590
}
@@ -645,13 +647,7 @@ func (t *TxPublisher) createAndCheckTx(req *BumpRequest,
645647
// the event channel to the record. Any broadcast-related errors will not be
646648
// returned here, instead, they will be put inside the `BumpResult` and
647649
// returned to the caller.
648-
func (t *TxPublisher) broadcast(requestID uint64) (*BumpResult, error) {
649-
// Get the record being monitored.
650-
record, ok := t.records.Load(requestID)
651-
if !ok {
652-
return nil, fmt.Errorf("tx record %v not found", requestID)
653-
}
654-
650+
func (t *TxPublisher) broadcast(record *monitorRecord) (*BumpResult, error) {
655651
txid := record.tx.TxHash()
656652

657653
tx := record.tx
@@ -698,7 +694,7 @@ func (t *TxPublisher) broadcast(requestID uint64) (*BumpResult, error) {
698694
Fee: record.fee,
699695
FeeRate: record.feeFunction.FeeRate(),
700696
Err: err,
701-
requestID: requestID,
697+
requestID: record.requestID,
702698
}
703699

704700
return result, nil
@@ -1043,7 +1039,7 @@ func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord) {
10431039
// RBF rules.
10441040
//
10451041
// Create the initial tx to be broadcasted.
1046-
err = t.initializeTx(r)
1042+
record, err := t.initializeTx(r)
10471043
if err != nil {
10481044
log.Errorf("Initial broadcast failed: %v", err)
10491045

@@ -1054,7 +1050,7 @@ func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord) {
10541050
}
10551051

10561052
// Successfully created the first tx, now broadcast it.
1057-
result, err = t.broadcast(r.requestID)
1053+
result, err = t.broadcast(record)
10581054
if err != nil {
10591055
// The broadcast failed, which can only happen if the tx record
10601056
// cannot be found or the aux sweeper returns an error. In
@@ -1199,10 +1195,10 @@ func (t *TxPublisher) createAndPublishTx(
11991195

12001196
// The tx has been created without any errors, we now register a new
12011197
// record by overwriting the same requestID.
1202-
t.storeRecord(r.requestID, sweepCtx, r.req, r.feeFunction)
1198+
record := t.updateRecord(r, sweepCtx)
12031199

12041200
// Attempt to broadcast this new tx.
1205-
result, err := t.broadcast(r.requestID)
1201+
result, err := t.broadcast(record)
12061202
if err != nil {
12071203
log.Infof("Failed to broadcast replacement tx %v: %v",
12081204
sweepCtx.tx.TxHash(), err)

sweep/fee_bumper_test.go

Lines changed: 71 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -313,9 +313,9 @@ func TestInitializeFeeFunction(t *testing.T) {
313313
require.Equal(t, feerate, f.FeeRate())
314314
}
315315

316-
// TestStoreRecord correctly increases the request counter and saves the
316+
// TestUpdateRecord correctly updates the fields fee and tx, and saves the
317317
// record.
318-
func TestStoreRecord(t *testing.T) {
318+
func TestUpdateRecord(t *testing.T) {
319319
t.Parallel()
320320

321321
// Create a test input.
@@ -358,8 +358,15 @@ func TestStoreRecord(t *testing.T) {
358358
outpointToTxIndex: utxoIndex,
359359
}
360360

361+
// Create a test record.
362+
record := &monitorRecord{
363+
requestID: initialCounter,
364+
req: req,
365+
feeFunction: feeFunc,
366+
}
367+
361368
// Call the method under test.
362-
tp.storeRecord(initialCounter, sweepCtx, req, feeFunc)
369+
tp.updateRecord(record, sweepCtx)
363370

364371
// Read the saved record and compare.
365372
record, ok := tp.records.Load(initialCounter)
@@ -676,10 +683,19 @@ func TestCreateRBFCompliantTx(t *testing.T) {
676683
tc.setupMock()
677684

678685
// Call the method under test.
679-
err := tp.createRBFCompliantTx(record)
686+
rec, err := tp.createRBFCompliantTx(record)
680687

681688
// Check the result is as expected.
682689
require.ErrorIs(t, err, tc.expectedErr)
690+
691+
if tc.expectedErr != nil {
692+
return
693+
}
694+
695+
// Assert the returned record has the following fields
696+
// populated.
697+
require.NotEmpty(t, rec.tx)
698+
require.NotEmpty(t, rec.fee)
683699
})
684700
}
685701
}
@@ -721,13 +737,13 @@ func TestTxPublisherBroadcast(t *testing.T) {
721737
outpointToTxIndex: utxoIndex,
722738
}
723739

724-
tp.storeRecord(requestID, sweepCtx, req, m.feeFunc)
725-
726-
// Quickly check when the requestID cannot be found, an error is
727-
// returned.
728-
result, err := tp.broadcast(uint64(1000))
729-
require.Error(t, err)
730-
require.Nil(t, result)
740+
// Create a test record.
741+
record := &monitorRecord{
742+
requestID: requestID,
743+
req: req,
744+
feeFunction: m.feeFunc,
745+
}
746+
rec := tp.updateRecord(record, sweepCtx)
731747

732748
testCases := []struct {
733749
name string
@@ -782,7 +798,7 @@ func TestTxPublisherBroadcast(t *testing.T) {
782798
tc.setupMock()
783799

784800
// Call the method under test.
785-
result, err := tp.broadcast(requestID)
801+
result, err := tp.broadcast(rec)
786802

787803
// Check the result is as expected.
788804
require.ErrorIs(t, err, tc.expectedErr)
@@ -838,7 +854,15 @@ func TestRemoveResult(t *testing.T) {
838854
name: "remove on TxConfirmed",
839855
setupRecord: func() uint64 {
840856
rid := requestCounter.Add(1)
841-
tp.storeRecord(rid, sweepCtx, req, m.feeFunc)
857+
858+
// Create a test record.
859+
record := &monitorRecord{
860+
requestID: rid,
861+
req: req,
862+
feeFunction: m.feeFunc,
863+
}
864+
865+
tp.updateRecord(record, sweepCtx)
842866
tp.subscriberChans.Store(rid, nil)
843867

844868
return rid
@@ -854,7 +878,15 @@ func TestRemoveResult(t *testing.T) {
854878
name: "remove on TxFailed",
855879
setupRecord: func() uint64 {
856880
rid := requestCounter.Add(1)
857-
tp.storeRecord(rid, sweepCtx, req, m.feeFunc)
881+
882+
// Create a test record.
883+
record := &monitorRecord{
884+
requestID: rid,
885+
req: req,
886+
feeFunction: m.feeFunc,
887+
}
888+
889+
tp.updateRecord(record, sweepCtx)
858890
tp.subscriberChans.Store(rid, nil)
859891

860892
return rid
@@ -871,7 +903,15 @@ func TestRemoveResult(t *testing.T) {
871903
name: "noop when tx is not confirmed or failed",
872904
setupRecord: func() uint64 {
873905
rid := requestCounter.Add(1)
874-
tp.storeRecord(rid, sweepCtx, req, m.feeFunc)
906+
907+
// Create a test record.
908+
record := &monitorRecord{
909+
requestID: rid,
910+
req: req,
911+
feeFunction: m.feeFunc,
912+
}
913+
914+
tp.updateRecord(record, sweepCtx)
875915
tp.subscriberChans.Store(rid, nil)
876916

877917
return rid
@@ -937,8 +977,14 @@ func TestNotifyResult(t *testing.T) {
937977
fee: fee,
938978
outpointToTxIndex: utxoIndex,
939979
}
980+
// Create a test record.
981+
record := &monitorRecord{
982+
requestID: requestID,
983+
req: req,
984+
feeFunction: m.feeFunc,
985+
}
940986

941-
tp.storeRecord(requestID, sweepCtx, req, m.feeFunc)
987+
tp.updateRecord(record, sweepCtx)
942988

943989
// Create a subscription to the event.
944990
subscriber := make(chan *BumpResult, 1)
@@ -1250,7 +1296,14 @@ func TestHandleTxConfirmed(t *testing.T) {
12501296
outpointToTxIndex: utxoIndex,
12511297
}
12521298

1253-
tp.storeRecord(requestID, sweepCtx, req, m.feeFunc)
1299+
// Create a test record.
1300+
record := &monitorRecord{
1301+
requestID: requestID,
1302+
req: req,
1303+
feeFunction: m.feeFunc,
1304+
}
1305+
1306+
tp.updateRecord(record, sweepCtx)
12541307
record, ok := tp.records.Load(requestID)
12551308
require.True(t, ok)
12561309

@@ -1340,7 +1393,7 @@ func TestHandleFeeBumpTx(t *testing.T) {
13401393
outpointToTxIndex: utxoIndex,
13411394
}
13421395

1343-
tp.storeRecord(requestID, sweepCtx, req, m.feeFunc)
1396+
tp.updateRecord(record, sweepCtx)
13441397

13451398
// Create a subscription to the event.
13461399
subscriber := make(chan *BumpResult, 1)

0 commit comments

Comments
 (0)