Skip to content

Commit 4874081

Browse files
committed
sweepbatcher: consider change when presigning
Presigning sweeps takes change outputs into account. Each sweep belonging to the same sweep group points to the same change output, if existent. sweepbatcher.presign scans all passed sweeps for change outputs and passes them to constructUnsignedTx.
1 parent 23ee318 commit 4874081

File tree

2 files changed

+207
-21
lines changed

2 files changed

+207
-21
lines changed

sweepbatcher/presigned.go

Lines changed: 81 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ func ensurePresigned(ctx context.Context, newSweeps []*sweep,
5151
outpoint: s.outpoint,
5252
value: s.value,
5353
presigned: s.presigned,
54+
change: s.change,
5455
}
5556
}
5657

@@ -66,14 +67,20 @@ func ensurePresigned(ctx context.Context, newSweeps []*sweep,
6667
return fmt.Errorf("failed to find destination address: %w", err)
6768
}
6869

70+
// Get the change outputs for each sweep group.
71+
changeOutputs, err := getChangeOutputs(sweeps, chainParams)
72+
if err != nil {
73+
return fmt.Errorf("failed to get change outputs: %w", err)
74+
}
75+
6976
// Set LockTime to 0. It is not critical.
7077
const currentHeight = 0
7178

7279
// Check if we can sign with minimum fee rate.
7380
const feeRate = chainfee.FeePerKwFloor
7481

7582
tx, _, _, _, err := constructUnsignedTx(
76-
sweeps, destAddr, currentHeight, feeRate,
83+
sweeps, destAddr, changeOutputs, currentHeight, feeRate,
7784
)
7885
if err != nil {
7986
return fmt.Errorf("failed to construct unsigned tx "+
@@ -257,7 +264,7 @@ func (b *batch) presign(ctx context.Context, newSweeps []*sweep) error {
257264

258265
err = presign(
259266
ctx, b.cfg.presignedHelper, destAddr, primarySweepID,
260-
sweeps, nextBlockFeeRate,
267+
sweeps, nextBlockFeeRate, b.cfg.chainParams,
261268
)
262269
if err != nil {
263270
return fmt.Errorf("failed to presign a transaction "+
@@ -299,7 +306,8 @@ type presigner interface {
299306
// 10x of the current next block feerate.
300307
func presign(ctx context.Context, presigner presigner, destAddr btcutil.Address,
301308
primarySweepID wire.OutPoint, sweeps []sweep,
302-
nextBlockFeeRate chainfee.SatPerKWeight) error {
309+
nextBlockFeeRate chainfee.SatPerKWeight,
310+
chainParams *chaincfg.Params) error {
303311

304312
if presigner == nil {
305313
return fmt.Errorf("presigner is not installed")
@@ -328,6 +336,12 @@ func presign(ctx context.Context, presigner presigner, destAddr btcutil.Address,
328336
return fmt.Errorf("timeout is invalid: %d", timeout)
329337
}
330338

339+
// Get the change outputs of each sweep group.
340+
changeOutputs, err := getChangeOutputs(sweeps, chainParams)
341+
if err != nil {
342+
return fmt.Errorf("failed to get change outputs: %w", err)
343+
}
344+
331345
// Go from the floor (1.01 sat/vbyte) to 2k sat/vbyte with step of 1.2x.
332346
const (
333347
start = chainfee.FeePerKwFloor
@@ -353,7 +367,7 @@ func presign(ctx context.Context, presigner presigner, destAddr btcutil.Address,
353367
for fr := start; fr <= stop; fr = (fr * factorPPM) / 1_000_000 {
354368
// Construct an unsigned transaction for this fee rate.
355369
tx, _, feeForWeight, fee, err := constructUnsignedTx(
356-
sweeps, destAddr, currentHeight, fr,
370+
sweeps, destAddr, changeOutputs, currentHeight, fr,
357371
)
358372
if err != nil {
359373
return fmt.Errorf("failed to construct unsigned tx "+
@@ -438,9 +452,15 @@ func (b *batch) publishPresigned(ctx context.Context) (btcutil.Amount, error,
438452
err), false
439453
}
440454

455+
changeOutputs, err := getChangeOutputs(sweeps, b.cfg.chainParams)
456+
if err != nil {
457+
return 0, fmt.Errorf("failed to get change outputs: %w", err),
458+
false
459+
}
460+
441461
// Construct unsigned batch transaction.
442462
tx, weight, _, fee, err := constructUnsignedTx(
443-
sweeps, address, currentHeight, feeRate,
463+
sweeps, address, changeOutputs, currentHeight, feeRate,
444464
)
445465
if err != nil {
446466
return 0, fmt.Errorf("failed to construct tx: %w", err),
@@ -493,10 +513,12 @@ func (b *batch) publishPresigned(ctx context.Context) (btcutil.Amount, error,
493513
signedFeeRate := chainfee.NewSatPerKWeight(fee, realWeight)
494514

495515
numSweeps := len(tx.TxIn)
516+
numChange := len(tx.TxOut) - 1
496517
b.Infof("attempting to publish custom signed tx=%v, desiredFeerate=%v,"+
497-
" signedFeeRate=%v, weight=%v, fee=%v, sweeps=%d, destAddr=%s",
518+
" signedFeeRate=%v, weight=%v, fee=%v, sweeps=%d, "+
519+
"changeOutputs=%d, destAddr=%s",
498520
txHash, feeRate, signedFeeRate, realWeight, fee, numSweeps,
499-
address)
521+
numChange, address)
500522
b.debugLogTx("serialized batch", tx)
501523

502524
// Publish the transaction.
@@ -557,6 +579,46 @@ func getPresignedSweepsDestAddr(ctx context.Context, helper destPkScripter,
557579
return address, nil
558580
}
559581

582+
// getChangeOutputs retrieves the change output references of each sweep and
583+
// de-duplicates them. The function must be used in presigned mode only.
584+
func getChangeOutputs(sweeps []sweep, chainParams *chaincfg.Params) (
585+
map[*wire.TxOut]btcutil.Address, error) {
586+
587+
changeOutputs := make(map[*wire.TxOut]btcutil.Address)
588+
for _, sweep := range sweeps {
589+
// If the sweep has a change output, add it to the changeOutputs
590+
// map to avoid duplicates.
591+
if sweep.change != nil {
592+
// If the change output is already in the map, skip it.
593+
if _, exists := changeOutputs[sweep.change]; exists {
594+
continue
595+
}
596+
597+
// Convert the change output's pkScript to an
598+
// address.
599+
changePkScript, err := txscript.ParsePkScript(
600+
sweep.change.PkScript,
601+
)
602+
if err != nil {
603+
return nil, fmt.Errorf("failed to parse "+
604+
"change output pkScript: %w", err)
605+
}
606+
607+
address, err := changePkScript.Address(chainParams)
608+
if err != nil {
609+
return nil, fmt.Errorf("pkScript.Address "+
610+
"failed for pkScript %x returned for "+
611+
"change output: %w",
612+
sweep.change.PkScript, err)
613+
}
614+
615+
changeOutputs[sweep.change] = address
616+
}
617+
}
618+
619+
return changeOutputs, nil
620+
}
621+
560622
// CheckSignedTx makes sure that signedTx matches the unsignedTx. It checks
561623
// according to criteria specified in the description of PresignedHelper.SignTx.
562624
func CheckSignedTx(unsignedTx, signedTx *wire.MsgTx, inputAmt btcutil.Amount,
@@ -593,23 +655,23 @@ func CheckSignedTx(unsignedTx, signedTx *wire.MsgTx, inputAmt btcutil.Amount,
593655
}
594656

595657
// Compare outputs.
596-
if len(unsignedTx.TxOut) != 1 {
597-
return fmt.Errorf("unsigned tx has %d outputs, want 1",
598-
len(unsignedTx.TxOut))
599-
}
600-
if len(signedTx.TxOut) != 1 {
601-
return fmt.Errorf("the signed tx has %d outputs, want 1",
658+
if len(unsignedTx.TxOut) != len(signedTx.TxOut) {
659+
return fmt.Errorf("unsigned tx has %d outputs, signed tx has "+
660+
"%d outputs, should be equal", len(unsignedTx.TxOut),
602661
len(signedTx.TxOut))
603662
}
604-
unsignedOut := unsignedTx.TxOut[0]
605-
signedOut := signedTx.TxOut[0]
606-
if !bytes.Equal(unsignedOut.PkScript, signedOut.PkScript) {
607-
return fmt.Errorf("mismatch of output pkScript: %x, %x",
608-
unsignedOut.PkScript, signedOut.PkScript)
663+
for i, o := range unsignedTx.TxOut {
664+
if !bytes.Equal(o.PkScript, signedTx.TxOut[i].PkScript) {
665+
return fmt.Errorf("mismatch of output pkScript: %x, %x",
666+
o.PkScript, signedTx.TxOut[i].PkScript)
667+
}
609668
}
610669

670+
// The first output is always the batch output.
671+
batchOutput := signedTx.TxOut[0]
672+
611673
// Find the feerate of signedTx.
612-
fee := inputAmt - btcutil.Amount(signedOut.Value)
674+
fee := inputAmt - btcutil.Amount(batchOutput.Value)
613675
weight := lntypes.WeightUnit(
614676
blockchain.GetTransactionWeight(btcutil.NewTx(signedTx)),
615677
)

sweepbatcher/presigned_test.go

Lines changed: 126 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,6 +1011,7 @@ func TestPresign(t *testing.T) {
10111011
ctx, tc.presigner, tc.destAddr,
10121012
tc.primarySweepID, tc.sweeps,
10131013
tc.nextBlockFeeRate,
1014+
&chaincfg.RegressionNetParams,
10141015
)
10151016
if tc.wantErr != "" {
10161017
require.Error(t, err)
@@ -1460,7 +1461,8 @@ func TestCheckSignedTx(t *testing.T) {
14601461
},
14611462
inputAmt: 3_000_000,
14621463
minRelayFee: 253,
1463-
wantErr: "unsigned tx has 2 outputs, want 1",
1464+
wantErr: "unsigned tx has 2 outputs, signed tx " +
1465+
"has 1 outputs, should be equal",
14641466
},
14651467

14661468
{
@@ -1517,7 +1519,8 @@ func TestCheckSignedTx(t *testing.T) {
15171519
},
15181520
inputAmt: 3_000_000,
15191521
minRelayFee: 253,
1520-
wantErr: "the signed tx has 2 outputs, want 1",
1522+
wantErr: "unsigned tx has 1 outputs, signed tx " +
1523+
"has 2 outputs, should be equal",
15211524
},
15221525

15231526
{
@@ -1642,3 +1645,124 @@ func TestCheckSignedTx(t *testing.T) {
16421645
})
16431646
}
16441647
}
1648+
1649+
// TestGetChangeOutputs tests that the change aggregation across sweeps works as
1650+
// intended. Each sweep of a sweep group should have a pointer to the same
1651+
// change output which is aggregated in getChangeOutput.
1652+
func TestGetChangeOutputs(t *testing.T) {
1653+
// Prepare the necessary data for test cases.
1654+
op1 := wire.OutPoint{
1655+
Hash: chainhash.Hash{1, 1, 1},
1656+
Index: 1,
1657+
}
1658+
op2 := wire.OutPoint{
1659+
Hash: chainhash.Hash{2, 2, 2},
1660+
Index: 2,
1661+
}
1662+
op3 := wire.OutPoint{
1663+
Hash: chainhash.Hash{3, 3, 3},
1664+
Index: 3,
1665+
}
1666+
1667+
batchPkScript, err := txscript.PayToAddrScript(destAddr)
1668+
require.NoError(t, err)
1669+
1670+
changeOutput1 := &wire.TxOut{
1671+
Value: 100_000,
1672+
PkScript: batchPkScript,
1673+
}
1674+
changeOutput2 := &wire.TxOut{
1675+
Value: 200_000,
1676+
PkScript: batchPkScript,
1677+
}
1678+
1679+
cases := []struct {
1680+
name string
1681+
sweeps []sweep
1682+
wantOutputs map[*wire.TxOut]btcutil.Address
1683+
wantErr string
1684+
}{
1685+
{
1686+
name: "no change",
1687+
sweeps: []sweep{
1688+
{
1689+
outpoint: op1,
1690+
value: 1_000_000,
1691+
change: nil,
1692+
},
1693+
},
1694+
wantOutputs: map[*wire.TxOut]btcutil.Address{},
1695+
},
1696+
{
1697+
name: "single sweep, single change",
1698+
sweeps: []sweep{
1699+
{
1700+
outpoint: op1,
1701+
value: 1_000_000,
1702+
change: changeOutput1,
1703+
},
1704+
},
1705+
wantOutputs: map[*wire.TxOut]btcutil.Address{
1706+
changeOutput1: destAddr,
1707+
},
1708+
},
1709+
{
1710+
name: "double sweep, single change",
1711+
sweeps: []sweep{
1712+
{
1713+
outpoint: op1,
1714+
value: 1_000_000,
1715+
change: changeOutput1,
1716+
},
1717+
{
1718+
outpoint: op2,
1719+
value: 1_000_000,
1720+
change: changeOutput1,
1721+
},
1722+
},
1723+
wantOutputs: map[*wire.TxOut]btcutil.Address{
1724+
changeOutput1: destAddr,
1725+
},
1726+
},
1727+
{
1728+
name: "double sweep, double change",
1729+
sweeps: []sweep{
1730+
{
1731+
outpoint: op1,
1732+
value: 1_000_000,
1733+
change: changeOutput1,
1734+
},
1735+
{
1736+
outpoint: op2,
1737+
value: 1_000_000,
1738+
change: changeOutput1,
1739+
},
1740+
{
1741+
outpoint: op3,
1742+
value: 1_000_000,
1743+
change: changeOutput2,
1744+
},
1745+
},
1746+
wantOutputs: map[*wire.TxOut]btcutil.Address{
1747+
changeOutput1: destAddr,
1748+
changeOutput2: destAddr,
1749+
},
1750+
},
1751+
}
1752+
1753+
for _, tc := range cases {
1754+
t.Run(tc.name, func(t *testing.T) {
1755+
changeOutputs, err := getChangeOutputs(
1756+
tc.sweeps, &chaincfg.RegressionNetParams,
1757+
)
1758+
if tc.wantErr != "" {
1759+
require.Error(t, err)
1760+
require.ErrorContains(t, err, tc.wantErr)
1761+
} else {
1762+
require.NoError(t, err)
1763+
}
1764+
1765+
require.Equal(t, tc.wantOutputs, changeOutputs)
1766+
})
1767+
}
1768+
}

0 commit comments

Comments
 (0)