Skip to content

Move persist into async part of the sweeper #3819

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 52 additions & 37 deletions lightning/src/util/sweep.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,8 @@ where
output_spender: O, change_destination_source: D, kv_store: K, logger: L,
) -> Self {
let outputs = Vec::new();
let sweeper_state = Mutex::new(SweeperState { outputs, best_block });
let sweeper_state =
Mutex::new(SweeperState { persistent: PersistentSweeperState { outputs, best_block } });
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Persistent is a bit clunky terminology, IMO, but nbd.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed offline, couldn't come up with something that is clearly better. Leaving as is.

Self {
sweeper_state,
pending_sweep: AtomicBool::new(false),
Expand Down Expand Up @@ -437,27 +438,27 @@ where
},
};

if state_lock.outputs.iter().find(|o| o.descriptor == output_info.descriptor).is_some()
{
let mut outputs = state_lock.persistent.outputs.iter();
if outputs.find(|o| o.descriptor == output_info.descriptor).is_some() {
continue;
}

state_lock.outputs.push(output_info);
state_lock.persistent.outputs.push(output_info);
}
self.persist_state(&*state_lock).map_err(|e| {
self.persist_state(&state_lock.persistent).map_err(|e| {
log_error!(self.logger, "Error persisting OutputSweeper: {:?}", e);
})
}

/// Returns a list of the currently tracked spendable outputs.
pub fn tracked_spendable_outputs(&self) -> Vec<TrackedSpendableOutput> {
self.sweeper_state.lock().unwrap().outputs.clone()
self.sweeper_state.lock().unwrap().persistent.outputs.clone()
}

/// Gets the latest best block which was connected either via the [`Listen`] or
/// [`Confirm`] interfaces.
pub fn current_best_block(&self) -> BestBlock {
self.sweeper_state.lock().unwrap().best_block
self.sweeper_state.lock().unwrap().persistent.best_block
}

/// Regenerates and broadcasts the spending transaction for any outputs that are pending. This method will be a
Expand Down Expand Up @@ -505,8 +506,9 @@ where
{
let sweeper_state = self.sweeper_state.lock().unwrap();

let cur_height = sweeper_state.best_block.height;
let has_respends = sweeper_state.outputs.iter().any(|o| filter_fn(o, cur_height));
let cur_height = sweeper_state.persistent.best_block.height;
let has_respends =
sweeper_state.persistent.outputs.iter().any(|o| filter_fn(o, cur_height));
if !has_respends {
return Ok(());
}
Expand All @@ -520,10 +522,11 @@ where
{
let mut sweeper_state = self.sweeper_state.lock().unwrap();

let cur_height = sweeper_state.best_block.height;
let cur_hash = sweeper_state.best_block.block_hash;
let cur_height = sweeper_state.persistent.best_block.height;
let cur_hash = sweeper_state.persistent.best_block.block_hash;

let respend_descriptors: Vec<&SpendableOutputDescriptor> = sweeper_state
.persistent
.outputs
.iter()
.filter(|o| filter_fn(*o, cur_height))
Expand All @@ -536,7 +539,11 @@ where
}

let spending_tx = self
.spend_outputs(&sweeper_state, &respend_descriptors, change_destination_script)
.spend_outputs(
&sweeper_state.persistent,
&respend_descriptors,
change_destination_script,
)
.map_err(|e| {
log_error!(self.logger, "Error spending outputs: {:?}", e);
})?;
Expand All @@ -550,7 +557,7 @@ where
// As we didn't modify the state so far, the same filter_fn yields the same elements as
// above.
let respend_outputs =
sweeper_state.outputs.iter_mut().filter(|o| filter_fn(&**o, cur_height));
sweeper_state.persistent.outputs.iter_mut().filter(|o| filter_fn(&**o, cur_height));
for output_info in respend_outputs {
if let Some(filter) = self.chain_data_source.as_ref() {
let watched_output = output_info.to_watched_output(cur_hash);
Expand All @@ -560,7 +567,7 @@ where
output_info.status.broadcast(cur_hash, cur_height, spending_tx.clone());
}

self.persist_state(&sweeper_state).map_err(|e| {
self.persist_state(&sweeper_state.persistent).map_err(|e| {
log_error!(self.logger, "Error persisting OutputSweeper: {:?}", e);
})?;

Expand All @@ -571,10 +578,10 @@ where
}

fn prune_confirmed_outputs(&self, sweeper_state: &mut SweeperState) {
let cur_height = sweeper_state.best_block.height;
let cur_height = sweeper_state.persistent.best_block.height;

// Prune all outputs that have sufficient depth by now.
sweeper_state.outputs.retain(|o| {
sweeper_state.persistent.outputs.retain(|o| {
if let Some(confirmation_height) = o.status.confirmation_height() {
// We wait at least `PRUNE_DELAY_BLOCKS` as before that
// `Event::SpendableOutputs` from lingering monitors might get replayed.
Expand All @@ -590,7 +597,7 @@ where
});
}

fn persist_state(&self, sweeper_state: &SweeperState) -> Result<(), io::Error> {
fn persist_state(&self, sweeper_state: &PersistentSweeperState) -> Result<(), io::Error> {
self.kv_store
.write(
OUTPUT_SWEEPER_PERSISTENCE_PRIMARY_NAMESPACE,
Expand All @@ -612,7 +619,7 @@ where
}

fn spend_outputs(
&self, sweeper_state: &SweeperState, descriptors: &[&SpendableOutputDescriptor],
&self, sweeper_state: &PersistentSweeperState, descriptors: &[&SpendableOutputDescriptor],
change_destination_script: ScriptBuf,
) -> Result<Transaction, ()> {
let tx_feerate =
Expand All @@ -635,7 +642,7 @@ where
) {
let confirmation_hash = header.block_hash();
for (_, tx) in txdata {
for output_info in sweeper_state.outputs.iter_mut() {
for output_info in sweeper_state.persistent.outputs.iter_mut() {
if output_info.is_spent_in(*tx) {
output_info.status.confirmed(confirmation_hash, height, (*tx).clone())
}
Expand All @@ -646,7 +653,7 @@ where
fn best_block_updated_internal(
&self, sweeper_state: &mut SweeperState, header: &Header, height: u32,
) {
sweeper_state.best_block = BestBlock::new(header.block_hash(), height);
sweeper_state.persistent.best_block = BestBlock::new(header.block_hash(), height);
self.prune_confirmed_outputs(sweeper_state);
}
}
Expand All @@ -666,15 +673,15 @@ where
&self, header: &Header, txdata: &chain::transaction::TransactionData, height: u32,
) {
let mut state_lock = self.sweeper_state.lock().unwrap();
assert_eq!(state_lock.best_block.block_hash, header.prev_blockhash,
assert_eq!(state_lock.persistent.best_block.block_hash, header.prev_blockhash,
"Blocks must be connected in chain-order - the connected header must build on the last connected header");
assert_eq!(state_lock.best_block.height, height - 1,
assert_eq!(state_lock.persistent.best_block.height, height - 1,
"Blocks must be connected in chain-order - the connected block height must be one greater than the previous height");

self.transactions_confirmed_internal(&mut *state_lock, header, txdata, height);
self.best_block_updated_internal(&mut *state_lock, header, height);

let _ = self.persist_state(&*state_lock).map_err(|e| {
let _ = self.persist_state(&state_lock.persistent).map_err(|e| {
log_error!(self.logger, "Error persisting OutputSweeper: {:?}", e);
});
}
Expand All @@ -685,20 +692,20 @@ where
let new_height = height - 1;
let block_hash = header.block_hash();

assert_eq!(state_lock.best_block.block_hash, block_hash,
assert_eq!(state_lock.persistent.best_block.block_hash, block_hash,
"Blocks must be disconnected in chain-order - the disconnected header must be the last connected header");
assert_eq!(state_lock.best_block.height, height,
assert_eq!(state_lock.persistent.best_block.height, height,
"Blocks must be disconnected in chain-order - the disconnected block must have the correct height");
state_lock.best_block = BestBlock::new(header.prev_blockhash, new_height);
state_lock.persistent.best_block = BestBlock::new(header.prev_blockhash, new_height);

for output_info in state_lock.outputs.iter_mut() {
for output_info in state_lock.persistent.outputs.iter_mut() {
if output_info.status.confirmation_hash() == Some(block_hash) {
debug_assert_eq!(output_info.status.confirmation_height(), Some(height));
output_info.status.unconfirmed();
}
}

self.persist_state(&*state_lock).unwrap_or_else(|e| {
self.persist_state(&state_lock.persistent).unwrap_or_else(|e| {
log_error!(self.logger, "Error persisting OutputSweeper: {:?}", e);
});
}
Expand All @@ -720,7 +727,7 @@ where
) {
let mut state_lock = self.sweeper_state.lock().unwrap();
self.transactions_confirmed_internal(&mut *state_lock, header, txdata, height);
self.persist_state(&*state_lock).unwrap_or_else(|e| {
self.persist_state(&state_lock.persistent).unwrap_or_else(|e| {
log_error!(self.logger, "Error persisting OutputSweeper: {:?}", e);
});
}
Expand All @@ -730,6 +737,7 @@ where

// Get what height was unconfirmed.
let unconf_height = state_lock
.persistent
.outputs
.iter()
.find(|o| o.status.latest_spending_tx().map(|tx| tx.compute_txid()) == Some(*txid))
Expand All @@ -738,12 +746,13 @@ where
if let Some(unconf_height) = unconf_height {
// Unconfirm all >= this height.
state_lock
.persistent
.outputs
.iter_mut()
.filter(|o| o.status.confirmation_height() >= Some(unconf_height))
.for_each(|o| o.status.unconfirmed());

self.persist_state(&*state_lock).unwrap_or_else(|e| {
self.persist_state(&state_lock.persistent).unwrap_or_else(|e| {
log_error!(self.logger, "Error persisting OutputSweeper: {:?}", e);
});
}
Expand All @@ -752,14 +761,15 @@ where
fn best_block_updated(&self, header: &Header, height: u32) {
let mut state_lock = self.sweeper_state.lock().unwrap();
self.best_block_updated_internal(&mut *state_lock, header, height);
let _ = self.persist_state(&*state_lock).map_err(|e| {
let _ = self.persist_state(&state_lock.persistent).map_err(|e| {
log_error!(self.logger, "Error persisting OutputSweeper: {:?}", e);
});
}

fn get_relevant_txids(&self) -> Vec<(Txid, u32, Option<BlockHash>)> {
let state_lock = self.sweeper_state.lock().unwrap();
state_lock
.persistent
.outputs
.iter()
.filter_map(|o| match o.status {
Expand All @@ -779,13 +789,18 @@ where
}
}

#[derive(Debug, Clone)]
#[derive(Debug)]
struct SweeperState {
persistent: PersistentSweeperState,
}

#[derive(Debug, Clone)]
struct PersistentSweeperState {
outputs: Vec<TrackedSpendableOutput>,
best_block: BestBlock,
}

impl_writeable_tlv_based!(SweeperState, {
impl_writeable_tlv_based!(PersistentSweeperState, {
(0, outputs, required_vec),
(2, best_block, required),
});
Expand Down Expand Up @@ -831,7 +846,7 @@ where
kv_store,
logger,
) = args;
let state = SweeperState::read(reader)?;
let state = PersistentSweeperState::read(reader)?;
let best_block = state.best_block;

if let Some(filter) = chain_data_source.as_ref() {
Expand All @@ -841,7 +856,7 @@ where
}
}

let sweeper_state = Mutex::new(state);
let sweeper_state = Mutex::new(SweeperState { persistent: state });
Ok(Self {
sweeper_state,
pending_sweep: AtomicBool::new(false),
Expand Down Expand Up @@ -880,7 +895,7 @@ where
kv_store,
logger,
) = args;
let state = SweeperState::read(reader)?;
let state = PersistentSweeperState::read(reader)?;
let best_block = state.best_block;

if let Some(filter) = chain_data_source.as_ref() {
Expand All @@ -890,7 +905,7 @@ where
}
}

let sweeper_state = Mutex::new(state);
let sweeper_state = Mutex::new(SweeperState { persistent: state });
Ok((
best_block,
OutputSweeper {
Expand Down