Skip to content

sweepbatcher: batch change outputs #976

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 81 additions & 19 deletions sweepbatcher/presigned.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ func ensurePresigned(ctx context.Context, newSweeps []*sweep,
outpoint: s.outpoint,
value: s.value,
presigned: s.presigned,
change: s.change,
}
}

Expand All @@ -66,14 +67,20 @@ func ensurePresigned(ctx context.Context, newSweeps []*sweep,
return fmt.Errorf("failed to find destination address: %w", err)
}

// Get the change outputs for each sweep group.
changeOutputs, err := getChangeOutputs(sweeps, chainParams)
if err != nil {
return fmt.Errorf("failed to get change outputs: %w", err)
}

// Set LockTime to 0. It is not critical.
const currentHeight = 0

// Check if we can sign with minimum fee rate.
const feeRate = chainfee.FeePerKwFloor

tx, _, _, _, err := constructUnsignedTx(
sweeps, destAddr, currentHeight, feeRate,
sweeps, destAddr, changeOutputs, currentHeight, feeRate,
)
if err != nil {
return fmt.Errorf("failed to construct unsigned tx "+
Expand Down Expand Up @@ -257,7 +264,7 @@ func (b *batch) presign(ctx context.Context, newSweeps []*sweep) error {

err = presign(
ctx, b.cfg.presignedHelper, destAddr, primarySweepID,
sweeps, nextBlockFeeRate,
sweeps, nextBlockFeeRate, b.cfg.chainParams,
)
if err != nil {
return fmt.Errorf("failed to presign a transaction "+
Expand Down Expand Up @@ -299,7 +306,8 @@ type presigner interface {
// 10x of the current next block feerate.
func presign(ctx context.Context, presigner presigner, destAddr btcutil.Address,
primarySweepID wire.OutPoint, sweeps []sweep,
nextBlockFeeRate chainfee.SatPerKWeight) error {
nextBlockFeeRate chainfee.SatPerKWeight,
chainParams *chaincfg.Params) error {

if presigner == nil {
return fmt.Errorf("presigner is not installed")
Expand Down Expand Up @@ -328,6 +336,12 @@ func presign(ctx context.Context, presigner presigner, destAddr btcutil.Address,
return fmt.Errorf("timeout is invalid: %d", timeout)
}

// Get the change outputs of each sweep group.
changeOutputs, err := getChangeOutputs(sweeps, chainParams)
if err != nil {
return fmt.Errorf("failed to get change outputs: %w", err)
}

// Go from the floor (1.01 sat/vbyte) to 2k sat/vbyte with step of 1.2x.
const (
start = chainfee.FeePerKwFloor
Expand All @@ -353,7 +367,7 @@ func presign(ctx context.Context, presigner presigner, destAddr btcutil.Address,
for fr := start; fr <= stop; fr = (fr * factorPPM) / 1_000_000 {
// Construct an unsigned transaction for this fee rate.
tx, _, feeForWeight, fee, err := constructUnsignedTx(
sweeps, destAddr, currentHeight, fr,
sweeps, destAddr, changeOutputs, currentHeight, fr,
)
if err != nil {
return fmt.Errorf("failed to construct unsigned tx "+
Expand Down Expand Up @@ -438,9 +452,15 @@ func (b *batch) publishPresigned(ctx context.Context) (btcutil.Amount, error,
err), false
}

changeOutputs, err := getChangeOutputs(sweeps, b.cfg.chainParams)
if err != nil {
return 0, fmt.Errorf("failed to get change outputs: %w", err),
false
}

// Construct unsigned batch transaction.
tx, weight, _, fee, err := constructUnsignedTx(
sweeps, address, currentHeight, feeRate,
sweeps, address, changeOutputs, currentHeight, feeRate,
)
if err != nil {
return 0, fmt.Errorf("failed to construct tx: %w", err),
Expand Down Expand Up @@ -493,10 +513,12 @@ func (b *batch) publishPresigned(ctx context.Context) (btcutil.Amount, error,
signedFeeRate := chainfee.NewSatPerKWeight(fee, realWeight)

numSweeps := len(tx.TxIn)
numChange := len(tx.TxOut) - 1
b.Infof("attempting to publish custom signed tx=%v, desiredFeerate=%v,"+
" signedFeeRate=%v, weight=%v, fee=%v, sweeps=%d, destAddr=%s",
" signedFeeRate=%v, weight=%v, fee=%v, sweeps=%d, "+
"changeOutputs=%d, destAddr=%s",
txHash, feeRate, signedFeeRate, realWeight, fee, numSweeps,
address)
numChange, address)
b.debugLogTx("serialized batch", tx)

// Publish the transaction.
Expand Down Expand Up @@ -557,6 +579,46 @@ func getPresignedSweepsDestAddr(ctx context.Context, helper destPkScripter,
return address, nil
}

// getChangeOutputs retrieves the change output references of each sweep and
// de-duplicates them. The function must be used in presigned mode only.
func getChangeOutputs(sweeps []sweep, chainParams *chaincfg.Params) (
map[*wire.TxOut]btcutil.Address, error) {

changeOutputs := make(map[*wire.TxOut]btcutil.Address)
for _, sweep := range sweeps {
// If the sweep has a change output, add it to the changeOutputs
// map to avoid duplicates.
if sweep.change != nil {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

flatten depth by if == nil { continue }

// If the change output is already in the map, skip it.
if _, exists := changeOutputs[sweep.change]; exists {
continue
}

// Convert the change output's pkScript to an
// address.
changePkScript, err := txscript.ParsePkScript(
sweep.change.PkScript,
)
if err != nil {
return nil, fmt.Errorf("failed to parse "+
"change output pkScript: %w", err)
}

address, err := changePkScript.Address(chainParams)
if err != nil {
return nil, fmt.Errorf("pkScript.Address "+
"failed for pkScript %x returned for "+
"change output: %w",
sweep.change.PkScript, err)
}

changeOutputs[sweep.change] = address
}
}

return changeOutputs, nil
}

// CheckSignedTx makes sure that signedTx matches the unsignedTx. It checks
// according to criteria specified in the description of PresignedHelper.SignTx.
func CheckSignedTx(unsignedTx, signedTx *wire.MsgTx, inputAmt btcutil.Amount,
Expand Down Expand Up @@ -593,23 +655,23 @@ func CheckSignedTx(unsignedTx, signedTx *wire.MsgTx, inputAmt btcutil.Amount,
}

// Compare outputs.
if len(unsignedTx.TxOut) != 1 {
return fmt.Errorf("unsigned tx has %d outputs, want 1",
len(unsignedTx.TxOut))
}
if len(signedTx.TxOut) != 1 {
return fmt.Errorf("the signed tx has %d outputs, want 1",
if len(unsignedTx.TxOut) != len(signedTx.TxOut) {
return fmt.Errorf("unsigned tx has %d outputs, signed tx has "+
"%d outputs, should be equal", len(unsignedTx.TxOut),
len(signedTx.TxOut))
}
unsignedOut := unsignedTx.TxOut[0]
signedOut := signedTx.TxOut[0]
if !bytes.Equal(unsignedOut.PkScript, signedOut.PkScript) {
return fmt.Errorf("mismatch of output pkScript: %x, %x",
unsignedOut.PkScript, signedOut.PkScript)
for i, o := range unsignedTx.TxOut {
if !bytes.Equal(o.PkScript, signedTx.TxOut[i].PkScript) {
return fmt.Errorf("mismatch of output pkScript: %x, %x",
o.PkScript, signedTx.TxOut[i].PkScript)
}
}

// The first output is always the batch output.
batchOutput := signedTx.TxOut[0]

// Find the feerate of signedTx.
fee := inputAmt - btcutil.Amount(signedOut.Value)
fee := inputAmt - btcutil.Amount(batchOutput.Value)
weight := lntypes.WeightUnit(
blockchain.GetTransactionWeight(btcutil.NewTx(signedTx)),
)
Expand Down
128 changes: 126 additions & 2 deletions sweepbatcher/presigned_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1011,6 +1011,7 @@ func TestPresign(t *testing.T) {
ctx, tc.presigner, tc.destAddr,
tc.primarySweepID, tc.sweeps,
tc.nextBlockFeeRate,
&chaincfg.RegressionNetParams,
)
if tc.wantErr != "" {
require.Error(t, err)
Expand Down Expand Up @@ -1460,7 +1461,8 @@ func TestCheckSignedTx(t *testing.T) {
},
inputAmt: 3_000_000,
minRelayFee: 253,
wantErr: "unsigned tx has 2 outputs, want 1",
wantErr: "unsigned tx has 2 outputs, signed tx " +
"has 1 outputs, should be equal",
},

{
Expand Down Expand Up @@ -1517,7 +1519,8 @@ func TestCheckSignedTx(t *testing.T) {
},
inputAmt: 3_000_000,
minRelayFee: 253,
wantErr: "the signed tx has 2 outputs, want 1",
wantErr: "unsigned tx has 1 outputs, signed tx " +
"has 2 outputs, should be equal",
},

{
Expand Down Expand Up @@ -1642,3 +1645,124 @@ func TestCheckSignedTx(t *testing.T) {
})
}
}

// TestGetChangeOutputs tests that the change aggregation across sweeps works as
// intended. Each sweep of a sweep group should have a pointer to the same
// change output which is aggregated in getChangeOutput.
func TestGetChangeOutputs(t *testing.T) {
// Prepare the necessary data for test cases.
op1 := wire.OutPoint{
Hash: chainhash.Hash{1, 1, 1},
Index: 1,
}
op2 := wire.OutPoint{
Hash: chainhash.Hash{2, 2, 2},
Index: 2,
}
op3 := wire.OutPoint{
Hash: chainhash.Hash{3, 3, 3},
Index: 3,
}

batchPkScript, err := txscript.PayToAddrScript(destAddr)
require.NoError(t, err)

changeOutput1 := &wire.TxOut{
Value: 100_000,
PkScript: batchPkScript,
}
changeOutput2 := &wire.TxOut{
Value: 200_000,
PkScript: batchPkScript,
}

cases := []struct {
name string
sweeps []sweep
wantOutputs map[*wire.TxOut]btcutil.Address
wantErr string
}{
{
name: "no change",
sweeps: []sweep{
{
outpoint: op1,
value: 1_000_000,
change: nil,
},
},
wantOutputs: map[*wire.TxOut]btcutil.Address{},
},
{
name: "single sweep, single change",
sweeps: []sweep{
{
outpoint: op1,
value: 1_000_000,
change: changeOutput1,
},
},
wantOutputs: map[*wire.TxOut]btcutil.Address{
changeOutput1: destAddr,
},
},
{
name: "double sweep, single change",
sweeps: []sweep{
{
outpoint: op1,
value: 1_000_000,
change: changeOutput1,
},
{
outpoint: op2,
value: 1_000_000,
change: changeOutput1,
},
},
wantOutputs: map[*wire.TxOut]btcutil.Address{
changeOutput1: destAddr,
},
},
{
name: "double sweep, double change",
sweeps: []sweep{
{
outpoint: op1,
value: 1_000_000,
change: changeOutput1,
},
{
outpoint: op2,
value: 1_000_000,
change: changeOutput1,
},
{
outpoint: op3,
value: 1_000_000,
change: changeOutput2,
},
},
wantOutputs: map[*wire.TxOut]btcutil.Address{
changeOutput1: destAddr,
changeOutput2: destAddr,
},
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
changeOutputs, err := getChangeOutputs(
tc.sweeps, &chaincfg.RegressionNetParams,
)
if tc.wantErr != "" {
require.Error(t, err)
require.ErrorContains(t, err, tc.wantErr)
} else {
require.NoError(t, err)
}

require.Equal(t, tc.wantOutputs, changeOutputs)
})
}
}
Loading