Skip to content

Commit 8c9ba32

Browse files
committed
sweep: add method getSpentInputs
To track the input and its spending tx, which will be used later to detect unknown spends.
1 parent 0e87863 commit 8c9ba32

File tree

2 files changed

+175
-22
lines changed

2 files changed

+175
-22
lines changed

sweep/fee_bumper.go

Lines changed: 51 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -908,7 +908,7 @@ func (t *TxPublisher) processRecords() {
908908
// Check whether the inputs has been spent by a third party.
909909
//
910910
// NOTE: this check is only done for neutrino backend.
911-
if t.isThirdPartySpent(r.tx.TxHash(), r.req.Inputs) {
911+
if t.isThirdPartySpent(r) {
912912
failedRecords[requestID] = r
913913

914914
// Move to the next record.
@@ -1253,26 +1253,59 @@ func (t *TxPublisher) isConfirmed(txid chainhash.Hash) bool {
12531253
//
12541254
// NOTE: this check is only performed for neutrino backend as it has no
12551255
// reliable way to tell a tx has been replaced.
1256-
func (t *TxPublisher) isThirdPartySpent(txid chainhash.Hash,
1257-
inputs []input.Input) bool {
1258-
1256+
func (t *TxPublisher) isThirdPartySpent(r *monitorRecord) bool {
12591257
// Skip this check for if this is not neutrino backend.
12601258
if !t.isNeutrinoBackend() {
12611259
return false
12621260
}
12631261

1262+
txid := r.tx.TxHash()
1263+
spends := t.getSpentInputs(r)
1264+
1265+
// Iterate all the spending txns and check if they match the sweeping
1266+
// tx.
1267+
for op, spendingTx := range spends {
1268+
spendingTxID := spendingTx.TxHash()
1269+
1270+
// If the spending tx is the same as the sweeping tx
1271+
// then we are good.
1272+
if spendingTxID == txid {
1273+
continue
1274+
}
1275+
1276+
log.Warnf("Detected third party spent of output=%v "+
1277+
"in tx=%v", op, spendingTx.TxHash())
1278+
1279+
return true
1280+
}
1281+
1282+
return false
1283+
}
1284+
1285+
// getSpentInputs performs a non-blocking read on the spending subscriptions to
1286+
// see whether any of the monitored inputs has been spent. A map of inputs with
1287+
// their spending txns are returned if found.
1288+
func (t *TxPublisher) getSpentInputs(
1289+
r *monitorRecord) map[wire.OutPoint]*wire.MsgTx {
1290+
1291+
// Create a slice to record the inputs spent.
1292+
spentInputs := make(map[wire.OutPoint]*wire.MsgTx, len(r.req.Inputs))
1293+
12641294
// Iterate all the inputs and check if they have been spent already.
1265-
for _, inp := range inputs {
1295+
for _, inp := range r.req.Inputs {
12661296
op := inp.OutPoint()
12671297

12681298
// For wallet utxos, the height hint is not set - we don't need
12691299
// to monitor them for third party spend.
1300+
//
1301+
// TODO(yy): We need to properly lock wallet utxos before
1302+
// skipping this check as the same wallet utxo can be used by
1303+
// different sweeping txns.
12701304
heightHint := inp.HeightHint()
12711305
if heightHint == 0 {
1272-
log.Debugf("Skipped third party check for wallet "+
1273-
"input %v", op)
1274-
1275-
continue
1306+
heightHint = uint32(t.currentHeight.Load())
1307+
log.Debugf("Checking wallet input %v using heightHint "+
1308+
"%v", op, heightHint)
12761309
}
12771310

12781311
// If the input has already been spent after the height hint, a
@@ -1283,7 +1316,8 @@ func (t *TxPublisher) isThirdPartySpent(txid chainhash.Hash,
12831316
if err != nil {
12841317
log.Criticalf("Failed to register spend ntfn for "+
12851318
"input=%v: %v", op, err)
1286-
return false
1319+
1320+
return nil
12871321
}
12881322

12891323
// Remove the subscription when exit.
@@ -1294,28 +1328,24 @@ func (t *TxPublisher) isThirdPartySpent(txid chainhash.Hash,
12941328
case spend, ok := <-spendEvent.Spend:
12951329
if !ok {
12961330
log.Debugf("Spend ntfn for %v canceled", op)
1297-
return false
1298-
}
1299-
1300-
spendingTxID := spend.SpendingTx.TxHash()
13011331

1302-
// If the spending tx is the same as the sweeping tx
1303-
// then we are good.
1304-
if spendingTxID == txid {
13051332
continue
13061333
}
13071334

1308-
log.Warnf("Detected third party spent of output=%v "+
1309-
"in tx=%v", op, spend.SpendingTx.TxHash())
1335+
spendingTx := spend.SpendingTx
13101336

1311-
return true
1337+
log.Debugf("Detected spent of input=%v in tx=%v", op,
1338+
spendingTx.TxHash())
1339+
1340+
spentInputs[op] = spendingTx
13121341

13131342
// Move to the next input.
13141343
default:
1344+
log.Tracef("Input %v not spent yet", op)
13151345
}
13161346
}
13171347

1318-
return false
1348+
return spentInputs
13191349
}
13201350

13211351
// calcCurrentConfTarget calculates the current confirmation target based on

sweep/fee_bumper_test.go

Lines changed: 124 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ func createTestInput(value int64,
5555
PubKey: testPubKey,
5656
},
5757
},
58-
0,
58+
1,
5959
nil,
6060
)
6161

@@ -1776,3 +1776,126 @@ func TestHandleInitialBroadcastFail(t *testing.T) {
17761776
require.Equal(t, 0, tp.records.Len())
17771777
require.Equal(t, 0, tp.subscriberChans.Len())
17781778
}
1779+
1780+
// TestHasInputsSpent checks the expected outpoint:tx map is returned.
1781+
func TestHasInputsSpent(t *testing.T) {
1782+
t.Parallel()
1783+
1784+
// Create a publisher using the mocks.
1785+
tp, m := createTestPublisher(t)
1786+
1787+
// Create mock inputs.
1788+
op1 := wire.OutPoint{
1789+
Hash: chainhash.Hash{1},
1790+
Index: 1,
1791+
}
1792+
inp1 := &input.MockInput{}
1793+
heightHint1 := uint32(1)
1794+
defer inp1.AssertExpectations(t)
1795+
1796+
op2 := wire.OutPoint{
1797+
Hash: chainhash.Hash{1},
1798+
Index: 2,
1799+
}
1800+
inp2 := &input.MockInput{}
1801+
heightHint2 := uint32(2)
1802+
defer inp2.AssertExpectations(t)
1803+
1804+
op3 := wire.OutPoint{
1805+
Hash: chainhash.Hash{1},
1806+
Index: 3,
1807+
}
1808+
walletInp := &input.MockInput{}
1809+
heightHint3 := uint32(0)
1810+
defer walletInp.AssertExpectations(t)
1811+
1812+
// We expect all the inputs to call OutPoint and HeightHint.
1813+
inp1.On("OutPoint").Return(op1).Once()
1814+
inp2.On("OutPoint").Return(op2).Once()
1815+
walletInp.On("OutPoint").Return(op3).Once()
1816+
inp1.On("HeightHint").Return(heightHint1).Once()
1817+
inp2.On("HeightHint").Return(heightHint2).Once()
1818+
walletInp.On("HeightHint").Return(heightHint3).Once()
1819+
1820+
// We expect the normal inputs to call SignDesc.
1821+
pkScript1 := []byte{1}
1822+
sd1 := &input.SignDescriptor{
1823+
Output: &wire.TxOut{
1824+
PkScript: pkScript1,
1825+
},
1826+
}
1827+
inp1.On("SignDesc").Return(sd1).Once()
1828+
1829+
pkScript2 := []byte{1}
1830+
sd2 := &input.SignDescriptor{
1831+
Output: &wire.TxOut{
1832+
PkScript: pkScript2,
1833+
},
1834+
}
1835+
inp2.On("SignDesc").Return(sd2).Once()
1836+
1837+
pkScript3 := []byte{3}
1838+
sd3 := &input.SignDescriptor{
1839+
Output: &wire.TxOut{
1840+
PkScript: pkScript3,
1841+
},
1842+
}
1843+
walletInp.On("SignDesc").Return(sd3).Once()
1844+
1845+
// Mock RegisterSpendNtfn.
1846+
//
1847+
// spendingTx1 is the tx spending op1.
1848+
spendingTx1 := &wire.MsgTx{}
1849+
se1 := createTestSpendEvent(spendingTx1)
1850+
m.notifier.On("RegisterSpendNtfn",
1851+
&op1, pkScript1, heightHint1).Return(se1, nil).Once()
1852+
1853+
// Create the spending event that doesn't send an event.
1854+
se2 := &chainntnfs.SpendEvent{
1855+
Cancel: func() {},
1856+
}
1857+
m.notifier.On("RegisterSpendNtfn",
1858+
&op2, pkScript2, heightHint2).Return(se2, nil).Once()
1859+
1860+
se3 := &chainntnfs.SpendEvent{
1861+
Cancel: func() {},
1862+
}
1863+
m.notifier.On("RegisterSpendNtfn",
1864+
&op3, pkScript3, heightHint3).Return(se3, nil).Once()
1865+
1866+
// Prepare the test inputs.
1867+
inputs := []input.Input{inp1, inp2, walletInp}
1868+
1869+
// Prepare the test record.
1870+
record := &monitorRecord{
1871+
req: &BumpRequest{
1872+
Inputs: inputs,
1873+
},
1874+
}
1875+
1876+
// Call the method under test.
1877+
result := tp.getSpentInputs(record)
1878+
1879+
// Assert the expected map is created.
1880+
expected := map[wire.OutPoint]*wire.MsgTx{
1881+
op1: spendingTx1,
1882+
}
1883+
require.Equal(t, expected, result)
1884+
}
1885+
1886+
// createTestSpendEvent creates a SpendEvent which places the specified tx in
1887+
// the channel, which can be read by a spending subscriber.
1888+
func createTestSpendEvent(tx *wire.MsgTx) *chainntnfs.SpendEvent {
1889+
// Create a monitor record that's confirmed.
1890+
spendDetails := chainntnfs.SpendDetail{
1891+
SpendingTx: tx,
1892+
}
1893+
spendChan1 := make(chan *chainntnfs.SpendDetail, 1)
1894+
spendChan1 <- &spendDetails
1895+
1896+
// Create the spend events.
1897+
return &chainntnfs.SpendEvent{
1898+
Spend: spendChan1,
1899+
Cancel: func() {},
1900+
}
1901+
}

0 commit comments

Comments
 (0)