Skip to content

Commit 28781b5

Browse files
committed
psbt: Add sighash types to PSBT when not DEFAULT or ALL
When an atypical sighash type is specified by the user, add it to the PSBT so that further signing can enforce sighash type matching.
1 parent 15ce1bd commit 28781b5

File tree

2 files changed

+57
-1
lines changed

2 files changed

+57
-1
lines changed

src/psbt.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,13 @@ PSBTError SignPSBTInput(const SigningProvider& provider, PartiallySignedTransact
424424
if (input.sighash_type && input.sighash_type != sighash) {
425425
return PSBTError::SIGHASH_MISMATCH;
426426
}
427+
// Set the PSBT sighash field when sighash is not DEFAULT or ALL
428+
// DEFAULT is allowed for non-taproot inputs since DEFAULT may be passed for them (e.g. the psbt being signed also has taproot inputs)
429+
// Note that signing already aliases DEFAULT to ALL for non-taproot inputs.
430+
if (utxo.scriptPubKey.IsPayToTaproot() ? sighash != SIGHASH_DEFAULT :
431+
(sighash != SIGHASH_DEFAULT && sighash != SIGHASH_ALL)) {
432+
input.sighash_type = sighash;
433+
}
427434

428435
// Check all existing signatures use the sighash type
429436
if (sighash == SIGHASH_DEFAULT) {
@@ -522,7 +529,8 @@ bool FinalizePSBT(PartiallySignedTransaction& psbtx)
522529
bool complete = true;
523530
const PrecomputedTransactionData txdata = PrecomputePSBTData(psbtx);
524531
for (unsigned int i = 0; i < psbtx.tx->vin.size(); ++i) {
525-
complete &= (SignPSBTInput(DUMMY_SIGNING_PROVIDER, psbtx, i, &txdata, std::nullopt, nullptr, true) == PSBTError::OK);
532+
PSBTInput& input = psbtx.inputs.at(i);
533+
complete &= (SignPSBTInput(DUMMY_SIGNING_PROVIDER, psbtx, i, &txdata, input.sighash_type, nullptr, true) == PSBTError::OK);
526534
}
527535

528536
return complete;

test/functional/rpc_psbt.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,53 @@ def test_sighash_mismatch(self):
301301

302302
wallet.unloadwallet()
303303

304+
def test_sighash_adding(self):
305+
self.log.info("Test adding of sighash type field")
306+
self.nodes[0].createwallet("sighash_adding")
307+
wallet = self.nodes[0].get_wallet_rpc("sighash_adding")
308+
def_wallet = self.nodes[0].get_wallet_rpc(self.default_wallet_name)
309+
310+
outputs = [{wallet.getnewaddress(address_type="bech32"): 1}]
311+
outputs.append({wallet.getnewaddress(address_type="bech32m"): 1})
312+
descs = wallet.listdescriptors(True)["descriptors"]
313+
def_wallet.send(outputs)
314+
self.generate(self.nodes[0], 6)
315+
utxos = wallet.listunspent()
316+
317+
# Make a PSBT
318+
psbt = wallet.walletcreatefundedpsbt(utxos, [{def_wallet.getnewaddress(): 0.5}])["psbt"]
319+
320+
# Process the PSBT with the wallet
321+
wallet_psbt = wallet.walletprocesspsbt(psbt=psbt, sighashtype="ALL|ANYONECANPAY", finalize=False)["psbt"]
322+
323+
# Separately process the PSBT with descriptors
324+
desc_psbt = self.nodes[0].descriptorprocesspsbt(psbt=psbt, descriptors=descs, sighashtype="ALL|ANYONECANPAY", finalize=False)["psbt"]
325+
326+
for psbt in [wallet_psbt, desc_psbt]:
327+
# Check that the PSBT has a sighash field on all inputs
328+
dec_psbt = self.nodes[0].decodepsbt(psbt)
329+
for input in dec_psbt["inputs"]:
330+
assert_equal(input["sighash"], "ALL|ANYONECANPAY")
331+
332+
# Make sure we can still finalize the transaction
333+
fin_res = self.nodes[0].finalizepsbt(psbt)
334+
assert_equal(fin_res["complete"], True)
335+
fin_hex = fin_res["hex"]
336+
assert_equal(self.nodes[0].testmempoolaccept([fin_hex])[0]["allowed"], True)
337+
338+
# Change the sighash field to a different value and make sure we can no longer finalize
339+
mod_psbt = PSBT.from_base64(psbt)
340+
mod_psbt.i[0].map[PSBT_IN_SIGHASH_TYPE] = (SIGHASH_ALL).to_bytes(4, byteorder="little")
341+
mod_psbt.i[1].map[PSBT_IN_SIGHASH_TYPE] = (SIGHASH_ALL).to_bytes(4, byteorder="little")
342+
psbt = mod_psbt.to_base64()
343+
fin_res = self.nodes[0].finalizepsbt(psbt)
344+
assert_equal(fin_res["complete"], False)
345+
346+
self.nodes[0].sendrawtransaction(fin_hex)
347+
self.generate(self.nodes[0], 1)
348+
349+
wallet.unloadwallet()
350+
304351
def assert_change_type(self, psbtx, expected_type):
305352
"""Assert that the given PSBT has a change output with the given type."""
306353

@@ -1139,6 +1186,7 @@ def test_psbt_input_keys(psbt_input, keys):
11391186
assert_raises_rpc_error(-8, "'all' is not a valid sighash parameter.", self.nodes[2].descriptorprocesspsbt, psbt, [descriptor], sighashtype="all")
11401187

11411188
self.test_sighash_mismatch()
1189+
self.test_sighash_adding()
11421190

11431191
if __name__ == '__main__':
11441192
PSBTTest(__file__).main()

0 commit comments

Comments
 (0)