diff --git a/Cargo.toml b/Cargo.toml index 19389a5f8..00a5b1687 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,14 +33,6 @@ lightning-background-processor = { git = "https://github.com/lightningdevkit/rus lightning-rapid-gossip-sync = { git = "https://github.com/lightningdevkit/rust-lightning", rev="7b85ebadb64058127350b83fb4b76dcb409ea518" } lightning-transaction-sync = { git = "https://github.com/lightningdevkit/rust-lightning", rev="7b85ebadb64058127350b83fb4b76dcb409ea518", features = ["esplora-async"] } -#lightning = { git = "https://github.com/tnull/rust-lightning", branch="2023-03-expose-impl-writeable-tlv-based-enum-common", features = ["max_level_trace", "std"] } -#lightning-invoice = { git = "https://github.com/tnull/rust-lightning", branch="2023-03-expose-impl-writeable-tlv-based-enum-common" } -#lightning-net-tokio = { git = "https://github.com/tnull/rust-lightning", branch="2023-03-expose-impl-writeable-tlv-based-enum-common" } -#lightning-persister = { git = "https://github.com/tnull/rust-lightning", branch="2023-03-expose-impl-writeable-tlv-based-enum-common" } -#lightning-background-processor = { git = "https://github.com/tnull/rust-lightning", branch="2023-03-expose-impl-writeable-tlv-based-enum-common" } -#lightning-rapid-gossip-sync = { git = "https://github.com/tnull/rust-lightning", branch="2023-03-expose-impl-writeable-tlv-based-enum-common" } -#lightning-transaction-sync = { git = "https://github.com/tnull/rust-lightning", branch="2023-03-expose-impl-writeable-tlv-based-enum-common", features = ["esplora-async"] } - #lightning = { path = "../rust-lightning/lightning", features = ["max_level_trace", "std"] } #lightning-invoice = { path = "../rust-lightning/lightning-invoice" } #lightning-net-tokio = { path = "../rust-lightning/lightning-net-tokio" } @@ -49,7 +41,7 @@ lightning-transaction-sync = { git = "https://github.com/lightningdevkit/rust-li #lightning-rapid-gossip-sync = { path = "../rust-lightning/lightning-rapid-gossip-sync" } #lightning-transaction-sync = { path = "../rust-lightning/lightning-transaction-sync", features = ["esplora-async"] } -bdk = { version = "=0.27.1", default-features = false, features = ["async-interface", "use-esplora-async", "sqlite-bundled"]} +bdk = { version = "0.27.1", default-features = false, features = ["async-interface", "use-esplora-async", "sqlite-bundled"]} reqwest = { version = "0.11", default-features = false, features = ["json", "rustls-tls"] } rusqlite = { version = "0.28.0", features = ["bundled"] } bitcoin = "0.29.2" @@ -59,13 +51,14 @@ chrono = "0.4" futures = "0.3" serde_json = { version = "1.0" } tokio = { version = "1", default-features = false, features = [ "rt-multi-thread", "time", "sync" ] } -esplora-client = { version = "=0.3", default-features = false } +esplora-client = { version = "0.4", default-features = false } +libc = "0.2" [dev-dependencies] electrsd = { version = "0.22.0", features = ["legacy", "esplora_a33e97e1", "bitcoind_23_0"] } electrum-client = "0.12.0" -once_cell = "1.16.0" proptest = "1.0.0" +regex = "1.5.6" [profile.release] panic = "abort" diff --git a/src/event.rs b/src/event.rs index 10f0c364a..c655bc22a 100644 --- a/src/event.rs +++ b/src/event.rs @@ -1,8 +1,12 @@ -use crate::{ - hex_utils, ChannelManager, Config, Error, KeysManager, NetworkGraph, PaymentInfo, - PaymentInfoStorage, PaymentStatus, Wallet, +use crate::{hex_utils, ChannelManager, Config, Error, KeysManager, NetworkGraph, Wallet}; + +use crate::payment_store::{ + PaymentDetails, PaymentDetailsUpdate, PaymentDirection, PaymentStatus, PaymentStore, }; +use crate::io::{ + KVStore, TransactionalWrite, EVENT_QUEUE_PERSISTENCE_KEY, EVENT_QUEUE_PERSISTENCE_NAMESPACE, +}; use crate::logger::{log_error, log_info, Logger}; use lightning::chain::chaininterface::{BroadcasterInterface, ConfirmationTarget, FeeEstimator}; @@ -13,19 +17,15 @@ use lightning::util::errors::APIError; use lightning::util::events::Event as LdkEvent; use lightning::util::events::EventHandler as LdkEventHandler; use lightning::util::events::PaymentPurpose; -use lightning::util::persist::KVStorePersister; use lightning::util::ser::{Readable, ReadableArgs, Writeable, Writer}; use bitcoin::secp256k1::Secp256k1; use rand::{thread_rng, Rng}; -use std::collections::{hash_map, VecDeque}; +use std::collections::VecDeque; use std::ops::Deref; use std::sync::{Arc, Condvar, Mutex}; use std::time::Duration; -/// The event queue will be persisted under this key. -pub(crate) const EVENTS_PERSISTENCE_KEY: &str = "events"; - /// An event emitted by [`Node`], which should be handled by the user. /// /// [`Node`]: [`crate::Node`] @@ -85,30 +85,33 @@ impl_writeable_tlv_based_enum!(Event, }; ); -pub(crate) struct EventQueue +pub struct EventQueue where - K::Target: KVStorePersister, + K::Target: KVStore, + L::Target: Logger, { queue: Mutex>, notifier: Condvar, - persister: K, + kv_store: K, + logger: L, } -impl EventQueue +impl EventQueue where - K::Target: KVStorePersister, + K::Target: KVStore, + L::Target: Logger, { - pub(crate) fn new(persister: K) -> Self { + pub(crate) fn new(kv_store: K, logger: L) -> Self { let queue: Mutex> = Mutex::new(VecDeque::new()); let notifier = Condvar::new(); - Self { queue, notifier, persister } + Self { queue, notifier, kv_store, logger } } pub(crate) fn add_event(&self, event: Event) -> Result<(), Error> { { let mut locked_queue = self.queue.lock().unwrap(); locked_queue.push_back(event); - self.persist_queue(&locked_queue)?; + self.write_queue_and_commit(&locked_queue)?; } self.notifier.notify_one(); @@ -125,32 +128,64 @@ where { let mut locked_queue = self.queue.lock().unwrap(); locked_queue.pop_front(); - self.persist_queue(&locked_queue)?; + self.write_queue_and_commit(&locked_queue)?; } self.notifier.notify_one(); Ok(()) } - fn persist_queue(&self, locked_queue: &VecDeque) -> Result<(), Error> { - self.persister - .persist(EVENTS_PERSISTENCE_KEY, &EventQueueSerWrapper(locked_queue)) - .map_err(|_| Error::PersistenceFailed)?; + fn write_queue_and_commit(&self, locked_queue: &VecDeque) -> Result<(), Error> { + let mut writer = self + .kv_store + .write(EVENT_QUEUE_PERSISTENCE_NAMESPACE, EVENT_QUEUE_PERSISTENCE_KEY) + .map_err(|e| { + log_error!( + self.logger, + "Getting writer for key {}/{} failed due to: {}", + EVENT_QUEUE_PERSISTENCE_NAMESPACE, + EVENT_QUEUE_PERSISTENCE_KEY, + e + ); + Error::PersistenceFailed + })?; + EventQueueSerWrapper(locked_queue).write(&mut writer).map_err(|e| { + log_error!( + self.logger, + "Writing event queue data to key {}/{} failed due to: {}", + EVENT_QUEUE_PERSISTENCE_NAMESPACE, + EVENT_QUEUE_PERSISTENCE_KEY, + e + ); + Error::PersistenceFailed + })?; + writer.commit().map_err(|e| { + log_error!( + self.logger, + "Committing event queue data to key {}/{} failed due to: {}", + EVENT_QUEUE_PERSISTENCE_NAMESPACE, + EVENT_QUEUE_PERSISTENCE_KEY, + e + ); + Error::PersistenceFailed + })?; Ok(()) } } -impl ReadableArgs for EventQueue +impl ReadableArgs<(K, L)> for EventQueue where - K::Target: KVStorePersister, + K::Target: KVStore, + L::Target: Logger, { #[inline] fn read( - reader: &mut R, persister: K, + reader: &mut R, args: (K, L), ) -> Result { + let (kv_store, logger) = args; let read_queue: EventQueueDeserWrapper = Readable::read(reader)?; let queue: Mutex> = Mutex::new(read_queue.0); let notifier = Condvar::new(); - Ok(Self { queue, notifier, persister }) + Ok(Self { queue, notifier, kv_store, logger }) } } @@ -181,34 +216,32 @@ impl Writeable for EventQueueSerWrapper<'_> { } } -pub(crate) struct EventHandler +pub(crate) struct EventHandler where - K::Target: KVStorePersister, + K::Target: KVStore, L::Target: Logger, { wallet: Arc>, - event_queue: Arc>, + event_queue: Arc>, channel_manager: Arc, network_graph: Arc, keys_manager: Arc, - inbound_payments: Arc, - outbound_payments: Arc, + payment_store: Arc>, tokio_runtime: Arc, logger: L, _config: Arc, } -impl EventHandler +impl EventHandler where - K::Target: KVStorePersister, + K::Target: KVStore, L::Target: Logger, { pub fn new( - wallet: Arc>, event_queue: Arc>, + wallet: Arc>, event_queue: Arc>, channel_manager: Arc, network_graph: Arc, - keys_manager: Arc, inbound_payments: Arc, - outbound_payments: Arc, tokio_runtime: Arc, - logger: L, _config: Arc, + keys_manager: Arc, payment_store: Arc>, + tokio_runtime: Arc, logger: L, _config: Arc, ) -> Self { Self { event_queue, @@ -216,8 +249,7 @@ where channel_manager, network_graph, keys_manager, - inbound_payments, - outbound_payments, + payment_store, logger, tokio_runtime, _config, @@ -225,9 +257,9 @@ where } } -impl LdkEventHandler for EventHandler +impl LdkEventHandler for EventHandler where - K::Target: KVStorePersister, + K::Target: KVStore, L::Target: Logger, { fn handle_event(&self, event: LdkEvent) { @@ -298,9 +330,28 @@ where via_channel_id: _, via_user_channel_id: _, } => { + if let Some(info) = self.payment_store.get(&payment_hash) { + if info.status == PaymentStatus::Succeeded { + log_info!( + self.logger, + "Refused duplicate inbound payment from payment hash {} of {}msat", + hex_utils::to_string(&payment_hash.0), + amount_msat, + ); + self.channel_manager.fail_htlc_backwards(&payment_hash); + + let update = PaymentDetailsUpdate { + status: Some(PaymentStatus::Failed), + ..PaymentDetailsUpdate::new(payment_hash) + }; + self.payment_store.update(&update).expect("Failed to access payment store"); + return; + } + } + log_info!( self.logger, - "Received payment from payment hash {} of {} msats", + "Received payment from payment hash {} of {}msat", hex_utils::to_string(&payment_hash.0), amount_msat, ); @@ -326,7 +377,12 @@ where hex_utils::to_string(&payment_hash.0), ); self.channel_manager.fail_htlc_backwards(&payment_hash); - self.inbound_payments.lock().unwrap().remove(&payment_hash); + + let update = PaymentDetailsUpdate { + status: Some(PaymentStatus::Failed), + ..PaymentDetailsUpdate::new(payment_hash) + }; + self.payment_store.update(&update).expect("Failed to access payment store"); } } LdkEvent::PaymentClaimed { @@ -337,59 +393,97 @@ where } => { log_info!( self.logger, - "Claimed payment from payment hash {} of {} msats.", + "Claimed payment from payment hash {} of {}msat.", hex_utils::to_string(&payment_hash.0), amount_msat, ); - let (payment_preimage, payment_secret) = match purpose { + match purpose { PaymentPurpose::InvoicePayment { payment_preimage, payment_secret, .. } => { - (payment_preimage, Some(payment_secret)) - } - PaymentPurpose::SpontaneousPayment(preimage) => (Some(preimage), None), - }; - let mut payments = self.inbound_payments.lock().unwrap(); - match payments.entry(payment_hash) { - hash_map::Entry::Occupied(mut e) => { - let payment = e.get_mut(); - payment.status = PaymentStatus::Succeeded; - payment.preimage = payment_preimage; - payment.secret = payment_secret; - payment.amount_msat = Some(amount_msat); + let update = PaymentDetailsUpdate { + preimage: Some(payment_preimage), + secret: Some(Some(payment_secret)), + amount_msat: Some(Some(amount_msat)), + status: Some(PaymentStatus::Succeeded), + ..PaymentDetailsUpdate::new(payment_hash) + }; + match self.payment_store.update(&update) { + Ok(true) => (), + Ok(false) => { + log_error!( + self.logger, + "Payment with hash {} couldn't be found in store", + hex_utils::to_string(&payment_hash.0) + ); + debug_assert!(false); + } + Err(e) => { + log_error!( + self.logger, + "Failed to update payment with hash {}: {}", + hex_utils::to_string(&payment_hash.0), + e + ); + debug_assert!(false); + } + } } - hash_map::Entry::Vacant(e) => { - e.insert(PaymentInfo { - preimage: payment_preimage, - secret: payment_secret, - status: PaymentStatus::Succeeded, + PaymentPurpose::SpontaneousPayment(preimage) => { + let payment = PaymentDetails { + preimage: Some(preimage), + hash: payment_hash, + secret: None, amount_msat: Some(amount_msat), - }); + direction: PaymentDirection::Inbound, + status: PaymentStatus::Succeeded, + }; + + match self.payment_store.insert(payment) { + Ok(false) => (), + Ok(true) => { + log_error!( + self.logger, + "Spontaneous payment with hash {} was previosly known", + hex_utils::to_string(&payment_hash.0) + ); + debug_assert!(false); + } + Err(e) => { + log_error!( + self.logger, + "Failed to insert payment with hash {}: {}", + hex_utils::to_string(&payment_hash.0), + e + ); + debug_assert!(false); + } + } } - } + }; + self.event_queue .add_event(Event::PaymentReceived { payment_hash, amount_msat }) .expect("Failed to push to event queue"); } LdkEvent::PaymentSent { payment_preimage, payment_hash, fee_paid_msat, .. } => { - let mut payments = self.outbound_payments.lock().unwrap(); - for (hash, payment) in payments.iter_mut() { - if *hash == payment_hash { - payment.preimage = Some(payment_preimage); - payment.status = PaymentStatus::Succeeded; - log_info!( - self.logger, - "Successfully sent payment of {} msats{} from \ - payment hash {:?} with preimage {:?}", - payment.amount_msat.unwrap(), - if let Some(fee) = fee_paid_msat { - format!(" (fee {} msats)", fee) - } else { - "".to_string() - }, - hex_utils::to_string(&payment_hash.0), - hex_utils::to_string(&payment_preimage.0) - ); - break; - } + if let Some(mut payment) = self.payment_store.get(&payment_hash) { + payment.preimage = Some(payment_preimage); + payment.status = PaymentStatus::Succeeded; + self.payment_store + .insert(payment.clone()) + .expect("Failed to access payment store"); + log_info!( + self.logger, + "Successfully sent payment of {}msat{} from \ + payment hash {:?} with preimage {:?}", + payment.amount_msat.unwrap(), + if let Some(fee) = fee_paid_msat { + format!(" (fee {} msat)", fee) + } else { + "".to_string() + }, + hex_utils::to_string(&payment_hash.0), + hex_utils::to_string(&payment_preimage.0) + ); } self.event_queue .add_event(Event::PaymentSuccessful { payment_hash }) @@ -402,12 +496,11 @@ where hex_utils::to_string(&payment_hash.0) ); - let mut payments = self.outbound_payments.lock().unwrap(); - if payments.contains_key(&payment_hash) { - let payment = payments.get_mut(&payment_hash).unwrap(); - assert_eq!(payment.status, PaymentStatus::Pending); - payment.status = PaymentStatus::Failed; - } + let update = PaymentDetailsUpdate { + status: Some(PaymentStatus::Failed), + ..PaymentDetailsUpdate::new(payment_hash) + }; + self.payment_store.update(&update).expect("Failed to access payment store"); self.event_queue .add_event(Event::PaymentFailed { payment_hash }) .expect("Failed to push to event queue"); @@ -493,7 +586,7 @@ where if claim_from_onchain_tx { log_info!( self.logger, - "Forwarded payment{}{}, earning {} msats in fees from claiming onchain.", + "Forwarded payment{}{}, earning {}msat in fees from claiming onchain.", from_prev_str, to_next_str, fee_earned, @@ -501,7 +594,7 @@ where } else { log_info!( self.logger, - "Forwarded payment{}{}, earning {} msats in fees.", + "Forwarded payment{}{}, earning {}msat in fees.", from_prev_str, to_next_str, fee_earned, @@ -541,33 +634,36 @@ where #[cfg(test)] mod tests { use super::*; - use crate::tests::test_utils::TestPersister; + use crate::test::utils::{TestLogger, TestStore}; #[test] fn event_queue_persistence() { - let persister = Arc::new(TestPersister::new()); - let event_queue = EventQueue::new(Arc::clone(&persister)); + let store = Arc::new(TestStore::new()); + let logger = Arc::new(TestLogger::new()); + let event_queue = EventQueue::new(Arc::clone(&store), Arc::clone(&logger)); let expected_event = Event::ChannelReady { channel_id: [23u8; 32], user_channel_id: 2323 }; event_queue.add_event(expected_event.clone()).unwrap(); - assert!(persister.get_and_clear_did_persist()); + assert!(store.get_and_clear_did_persist()); // Check we get the expected event and that it is returned until we mark it handled. for _ in 0..5 { assert_eq!(event_queue.next_event(), expected_event); - assert_eq!(false, persister.get_and_clear_did_persist()); + assert_eq!(false, store.get_and_clear_did_persist()); } // Check we can read back what we persisted. - let persisted_bytes = persister.get_persisted_bytes(EVENTS_PERSISTENCE_KEY).unwrap(); + let persisted_bytes = store + .get_persisted_bytes(EVENT_QUEUE_PERSISTENCE_NAMESPACE, EVENT_QUEUE_PERSISTENCE_KEY) + .unwrap(); let deser_event_queue = - EventQueue::read(&mut &persisted_bytes[..], Arc::clone(&persister)).unwrap(); + EventQueue::read(&mut &persisted_bytes[..], (Arc::clone(&store), logger)).unwrap(); assert_eq!(deser_event_queue.next_event(), expected_event); - assert!(!persister.get_and_clear_did_persist()); + assert!(!store.get_and_clear_did_persist()); // Check we persisted on `event_handled()` event_queue.event_handled().unwrap(); - assert!(persister.get_and_clear_did_persist()); + assert!(store.get_and_clear_did_persist()); } } diff --git a/src/io/fs_store.rs b/src/io/fs_store.rs new file mode 100644 index 000000000..0925ae301 --- /dev/null +++ b/src/io/fs_store.rs @@ -0,0 +1,343 @@ +#[cfg(target_os = "windows")] +extern crate winapi; + +use super::{KVStore, TransactionalWrite}; + +use std::collections::HashMap; +use std::fs; +use std::io::{BufReader, BufWriter, Read, Write}; +use std::path::{Path, PathBuf}; +use std::str::FromStr; +use std::sync::{Arc, Mutex, RwLock}; + +#[cfg(not(target_os = "windows"))] +use std::os::unix::io::AsRawFd; + +use lightning::util::persist::KVStorePersister; +use lightning::util::ser::Writeable; + +use rand::distributions::Alphanumeric; +use rand::{thread_rng, Rng}; + +#[cfg(target_os = "windows")] +use {std::ffi::OsStr, std::os::windows::ffi::OsStrExt}; + +#[cfg(target_os = "windows")] +macro_rules! call { + ($e: expr) => { + if $e != 0 { + return Ok(()); + } else { + return Err(std::io::Error::last_os_error()); + } + }; +} + +#[cfg(target_os = "windows")] +fn path_to_windows_str>(path: T) -> Vec { + path.as_ref().encode_wide().chain(Some(0)).collect() +} + +pub struct FilesystemStore { + dest_dir: PathBuf, + locks: Mutex>>>, +} + +impl FilesystemStore { + pub fn new(dest_dir: PathBuf) -> Self { + let locks = Mutex::new(HashMap::new()); + Self { dest_dir, locks } + } +} + +impl KVStore for FilesystemStore { + type Reader = FilesystemReader; + type Writer = FilesystemWriter; + + fn read(&self, namespace: &str, key: &str) -> std::io::Result { + let mut outer_lock = self.locks.lock().unwrap(); + let lock_key = (namespace.to_string(), key.to_string()); + let inner_lock_ref = Arc::clone(&outer_lock.entry(lock_key).or_default()); + + let mut dest_file = self.dest_dir.clone(); + dest_file.push(namespace); + dest_file.push(key); + FilesystemReader::new(dest_file, inner_lock_ref) + } + + fn write(&self, namespace: &str, key: &str) -> std::io::Result { + let mut outer_lock = self.locks.lock().unwrap(); + let lock_key = (namespace.to_string(), key.to_string()); + let inner_lock_ref = Arc::clone(&outer_lock.entry(lock_key).or_default()); + + let mut dest_file = self.dest_dir.clone(); + dest_file.push(namespace); + dest_file.push(key); + FilesystemWriter::new(dest_file, inner_lock_ref) + } + + fn remove(&self, namespace: &str, key: &str) -> std::io::Result { + let mut outer_lock = self.locks.lock().unwrap(); + let lock_key = (namespace.to_string(), key.to_string()); + let inner_lock_ref = Arc::clone(&outer_lock.entry(lock_key.clone()).or_default()); + + let _guard = inner_lock_ref.write().unwrap(); + + let mut dest_file = self.dest_dir.clone(); + dest_file.push(namespace); + dest_file.push(key); + + if !dest_file.is_file() { + return Ok(false); + } + + fs::remove_file(&dest_file)?; + #[cfg(not(target_os = "windows"))] + { + let msg = format!("Could not retrieve parent directory of {}.", dest_file.display()); + let parent_directory = dest_file + .parent() + .ok_or(std::io::Error::new(std::io::ErrorKind::InvalidInput, msg))?; + let dir_file = fs::OpenOptions::new().read(true).open(parent_directory)?; + unsafe { + // The above call to `fs::remove_file` corresponds to POSIX `unlink`, whose changes + // to the inode might get cached (and hence possibly lost on crash), depending on + // the target platform and file system. + // + // In order to assert we permanently removed the file in question we therefore + // call `fsync` on the parent directory on platforms that support it, + libc::fsync(dir_file.as_raw_fd()); + } + } + + if dest_file.is_file() { + return Err(std::io::Error::new(std::io::ErrorKind::Other, "Removing key failed")); + } + + if Arc::strong_count(&inner_lock_ref) == 2 { + // It's safe to remove the lock entry if we're the only one left holding a strong + // reference. Checking this is necessary to ensure we continue to distribute references to the + // same lock as long as some Writers/Readers are around. However, we still want to + // clean up the table when possible. + // + // Note that this by itself is still leaky as lock entries will remain when more Readers/Writers are + // around, but is preferable to doing nothing *or* something overly complex such as + // implementing yet another RAII structure just for this pupose. + outer_lock.remove(&lock_key); + } + + // Garbage collect all lock entries that are not referenced anymore. + outer_lock.retain(|_, v| Arc::strong_count(&v) > 1); + + Ok(true) + } + + fn list(&self, namespace: &str) -> std::io::Result> { + let mut prefixed_dest = self.dest_dir.clone(); + prefixed_dest.push(namespace); + + let mut keys = Vec::new(); + + if !Path::new(&prefixed_dest).exists() { + return Ok(Vec::new()); + } + + for entry in fs::read_dir(prefixed_dest.clone())? { + let entry = entry?; + let p = entry.path(); + + if !p.is_file() { + continue; + } + + if let Some(ext) = p.extension() { + if ext == "tmp" { + continue; + } + } + + if let Ok(relative_path) = p.strip_prefix(prefixed_dest.clone()) { + keys.push(relative_path.display().to_string()) + } + } + + Ok(keys) + } +} + +pub struct FilesystemReader { + inner: BufReader, + lock_ref: Arc>, +} + +impl FilesystemReader { + fn new(dest_file: PathBuf, lock_ref: Arc>) -> std::io::Result { + let f = fs::File::open(dest_file.clone())?; + let inner = BufReader::new(f); + Ok(Self { inner, lock_ref }) + } +} + +impl Read for FilesystemReader { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + let _guard = self.lock_ref.read().unwrap(); + self.inner.read(buf) + } +} + +pub struct FilesystemWriter { + dest_file: PathBuf, + parent_directory: PathBuf, + tmp_file: PathBuf, + tmp_writer: BufWriter, + lock_ref: Arc>, +} + +impl FilesystemWriter { + fn new(dest_file: PathBuf, lock_ref: Arc>) -> std::io::Result { + let msg = format!("Could not retrieve parent directory of {}.", dest_file.display()); + let parent_directory = dest_file + .parent() + .ok_or(std::io::Error::new(std::io::ErrorKind::InvalidInput, msg))? + .to_path_buf(); + fs::create_dir_all(parent_directory.clone())?; + + // Do a crazy dance with lots of fsync()s to be overly cautious here... + // We never want to end up in a state where we've lost the old data, or end up using the + // old data on power loss after we've returned. + // The way to atomically write a file on Unix platforms is: + // open(tmpname), write(tmpfile), fsync(tmpfile), close(tmpfile), rename(), fsync(dir) + let mut tmp_file = dest_file.clone(); + let mut rng = thread_rng(); + let rand_str: String = (0..7).map(|_| rng.sample(Alphanumeric) as char).collect(); + let ext = format!("{}.tmp", rand_str); + tmp_file.set_extension(ext); + + let tmp_writer = BufWriter::new(fs::File::create(&tmp_file)?); + + Ok(Self { dest_file, parent_directory, tmp_file, tmp_writer, lock_ref }) + } +} + +impl Write for FilesystemWriter { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + Ok(self.tmp_writer.write(buf)?) + } + + fn flush(&mut self) -> std::io::Result<()> { + self.tmp_writer.flush()?; + self.tmp_writer.get_ref().sync_all()?; + Ok(()) + } +} + +impl TransactionalWrite for FilesystemWriter { + fn commit(&mut self) -> std::io::Result<()> { + self.flush()?; + + let _guard = self.lock_ref.write().unwrap(); + // Fsync the parent directory on Unix. + #[cfg(not(target_os = "windows"))] + { + fs::rename(&self.tmp_file, &self.dest_file)?; + let dir_file = fs::OpenOptions::new().read(true).open(self.parent_directory.clone())?; + unsafe { + libc::fsync(dir_file.as_raw_fd()); + } + } + + #[cfg(target_os = "windows")] + { + if dest_file.exists() { + unsafe { + winapi::um::winbase::ReplaceFileW( + path_to_windows_str(dest_file).as_ptr(), + path_to_windows_str(tmp_file).as_ptr(), + std::ptr::null(), + winapi::um::winbase::REPLACEFILE_IGNORE_MERGE_ERRORS, + std::ptr::null_mut() as *mut winapi::ctypes::c_void, + std::ptr::null_mut() as *mut winapi::ctypes::c_void, + ) + }; + } else { + call!(unsafe { + winapi::um::winbase::MoveFileExW( + path_to_windows_str(tmp_file).as_ptr(), + path_to_windows_str(dest_file).as_ptr(), + winapi::um::winbase::MOVEFILE_WRITE_THROUGH + | winapi::um::winbase::MOVEFILE_REPLACE_EXISTING, + ) + }); + } + } + Ok(()) + } +} + +impl KVStorePersister for FilesystemStore { + fn persist(&self, prefixed_key: &str, object: &W) -> lightning::io::Result<()> { + let msg = format!("Could not persist file for key {}.", prefixed_key); + let dest_file = PathBuf::from_str(prefixed_key).map_err(|_| { + lightning::io::Error::new(lightning::io::ErrorKind::InvalidInput, msg.clone()) + })?; + + let parent_directory = dest_file.parent().ok_or(lightning::io::Error::new( + lightning::io::ErrorKind::InvalidInput, + msg.clone(), + ))?; + let namespace = parent_directory.display().to_string(); + + let dest_without_namespace = dest_file + .strip_prefix(&namespace) + .map_err(|_| lightning::io::Error::new(lightning::io::ErrorKind::InvalidInput, msg))?; + let key = dest_without_namespace.display().to_string(); + let mut writer = self.write(&namespace, &key)?; + object.write(&mut writer)?; + Ok(writer.commit()?) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test::utils::random_storage_path; + use lightning::util::persist::KVStorePersister; + use lightning::util::ser::{Readable, Writeable}; + + use proptest::prelude::*; + proptest! { + #[test] + fn read_write_remove_list_persist(data in any::<[u8; 32]>()) { + let rand_dir = random_storage_path(); + + let fs_store = FilesystemStore::new(rand_dir.into()); + let namespace = "testspace"; + let key = "testkey"; + + // Test the basic KVStore operations. + let mut writer = fs_store.write(namespace, key).unwrap(); + data.write(&mut writer).unwrap(); + writer.commit().unwrap(); + + let listed_keys = fs_store.list(namespace).unwrap(); + assert_eq!(listed_keys.len(), 1); + assert_eq!(listed_keys[0], "testkey"); + + let mut reader = fs_store.read(namespace, key).unwrap(); + let read_data: [u8; 32] = Readable::read(&mut reader).unwrap(); + assert_eq!(data, read_data); + + fs_store.remove(namespace, key).unwrap(); + + let listed_keys = fs_store.list(namespace).unwrap(); + assert_eq!(listed_keys.len(), 0); + + // Test KVStorePersister + let prefixed_key = format!("{}/{}", namespace, key); + fs_store.persist(&prefixed_key, &data).unwrap(); + let mut reader = fs_store.read(namespace, key).unwrap(); + let read_data: [u8; 32] = Readable::read(&mut reader).unwrap(); + assert_eq!(data, read_data); + } + } +} diff --git a/src/io/mod.rs b/src/io/mod.rs new file mode 100644 index 000000000..1c3dd3c69 --- /dev/null +++ b/src/io/mod.rs @@ -0,0 +1,73 @@ +pub(crate) mod fs_store; +pub(crate) mod utils; + +use std::io::{Read, Write}; + +// The namespacs and keys LDK uses for persisting +pub(crate) const CHANNEL_MANAGER_PERSISTENCE_NAMESPACE: &str = ""; +pub(crate) const CHANNEL_MANAGER_PERSISTENCE_KEY: &str = "manager"; + +pub(crate) const CHANNEL_MONITOR_PERSISTENCE_NAMESPACE: &str = "monitors"; + +pub(crate) const NETWORK_GRAPH_PERSISTENCE_NAMESPACE: &str = ""; +pub(crate) const NETWORK_GRAPH_PERSISTENCE_KEY: &str = "network_graph"; + +pub(crate) const SCORER_PERSISTENCE_NAMESPACE: &str = ""; +pub(crate) const SCORER_PERSISTENCE_KEY: &str = "scorer"; + +/// The event queue will be persisted under this key. +pub(crate) const EVENT_QUEUE_PERSISTENCE_NAMESPACE: &str = ""; +pub(crate) const EVENT_QUEUE_PERSISTENCE_KEY: &str = "events"; + +/// The peer information will be persisted under this key. +pub(crate) const PEER_INFO_PERSISTENCE_NAMESPACE: &str = ""; +pub(crate) const PEER_INFO_PERSISTENCE_KEY: &str = "peers"; + +/// The payment information will be persisted under this prefix. +pub(crate) const PAYMENT_INFO_PERSISTENCE_NAMESPACE: &str = "payments"; + +/// Provides an interface that allows to store and retrieve persisted values that are associated +/// with given keys. +/// +/// In order to avoid collisions the key space is segmented based on the given `namespace`s. +/// Implementations of this trait are free to handle them in different ways, as long as +/// per-namespace key uniqueness is asserted. +/// +/// Keys and namespaces are required to be valid ASCII strings and the empty namespace (`""`) is +/// assumed to be valid namespace. +pub trait KVStore { + type Reader: Read; + type Writer: TransactionalWrite; + /// Returns a [`Read`] for the given `namespace` and `key` from which [`Readable`]s may be + /// read. + /// + /// Returns an `Err` if the given `key` could not be found in the given `namespace`. + /// + /// [`Readable`]: lightning::util::ser::Readable + fn read(&self, namespace: &str, key: &str) -> std::io::Result; + /// Returns a [`TransactionalWrite`] for the given `key` to which [`Writeable`]s may be written. + /// + /// Will create the given `namespace` if not already present in the store. + /// + /// Note that [`TransactionalWrite::commit`] MUST be called to commit the written data, otherwise + /// the changes won't be persisted. + /// + /// [`Writeable`]: lightning::util::ser::Writeable + fn write(&self, namespace: &str, key: &str) -> std::io::Result; + /// Removes any data that had previously been persisted under the given `key`. + /// + /// Returns `true` if the `key` was present in the given `namespace`, and `false` otherwise. + fn remove(&self, namespace: &str, key: &str) -> std::io::Result; + /// Returns a list of keys that are stored under the given `namespace`. + /// + /// Will return an empty list if the `namespace` is unknown. + fn list(&self, namespace: &str) -> std::io::Result>; +} + +/// A [`Write`] asserting data consistency. +/// +/// Note that any changes need to be `commit`ed for them to take effect, and are lost otherwise. +pub trait TransactionalWrite: Write { + /// Persist the previously made changes. + fn commit(&mut self) -> std::io::Result<()>; +} diff --git a/src/io/utils.rs b/src/io/utils.rs new file mode 100644 index 000000000..00baf77eb --- /dev/null +++ b/src/io/utils.rs @@ -0,0 +1,174 @@ +use super::*; +use crate::WALLET_KEYS_SEED_LEN; + +use crate::peer_store::PeerStore; +use crate::{EventQueue, PaymentDetails}; + +use lightning::chain::channelmonitor::ChannelMonitor; +use lightning::chain::keysinterface::{EntropySource, SignerProvider}; +use lightning::routing::gossip::NetworkGraph; +use lightning::routing::scoring::{ProbabilisticScorer, ProbabilisticScoringParameters}; +use lightning::util::logger::Logger; +use lightning::util::ser::{Readable, ReadableArgs}; + +use bitcoin::hash_types::{BlockHash, Txid}; +use bitcoin::hashes::hex::FromHex; +use rand::{thread_rng, RngCore}; + +use std::fs; +use std::io::Write; +use std::ops::Deref; +use std::path::Path; + +use super::KVStore; + +pub(crate) fn read_or_generate_seed_file(keys_seed_path: &str) -> [u8; WALLET_KEYS_SEED_LEN] { + if Path::new(&keys_seed_path).exists() { + let seed = fs::read(keys_seed_path).expect("Failed to read keys seed file"); + assert_eq!( + seed.len(), + WALLET_KEYS_SEED_LEN, + "Failed to read keys seed file: unexpected length" + ); + let mut key = [0; WALLET_KEYS_SEED_LEN]; + key.copy_from_slice(&seed); + key + } else { + let mut key = [0; WALLET_KEYS_SEED_LEN]; + thread_rng().fill_bytes(&mut key); + + let mut f = fs::File::create(keys_seed_path).expect("Failed to create keys seed file"); + f.write_all(&key).expect("Failed to write node keys seed to disk"); + f.sync_all().expect("Failed to sync node keys seed to disk"); + key + } +} + +/// Read previously persisted [`ChannelMonitor`]s from the store. +pub(crate) fn read_channel_monitors( + kv_store: K, entropy_source: ES, signer_provider: SP, +) -> std::io::Result::Signer>)>> +where + K::Target: KVStore, + ES::Target: EntropySource + Sized, + SP::Target: SignerProvider + Sized, +{ + let mut res = Vec::new(); + + for stored_key in kv_store.list(CHANNEL_MONITOR_PERSISTENCE_NAMESPACE)? { + let txid = Txid::from_hex(stored_key.split_at(64).0).map_err(|_| { + std::io::Error::new(std::io::ErrorKind::InvalidData, "Invalid tx ID in stored key") + })?; + + let index: u16 = stored_key.split_at(65).1.parse().map_err(|_| { + std::io::Error::new(std::io::ErrorKind::InvalidData, "Invalid tx index in stored key") + })?; + + match <(BlockHash, ChannelMonitor<::Signer>)>::read( + &mut kv_store.read(CHANNEL_MONITOR_PERSISTENCE_NAMESPACE, &stored_key)?, + (&*entropy_source, &*signer_provider), + ) { + Ok((block_hash, channel_monitor)) => { + if channel_monitor.get_funding_txo().0.txid != txid + || channel_monitor.get_funding_txo().0.index != index + { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "ChannelMonitor was stored under the wrong key", + )); + } + res.push((block_hash, channel_monitor)); + } + Err(e) => { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Failed to deserialize ChannelMonitor: {}", e), + )) + } + } + } + Ok(res) +} + +/// Read a previously persisted [`NetworkGraph`] from the store. +pub(crate) fn read_network_graph( + kv_store: K, logger: L, +) -> Result, std::io::Error> +where + K::Target: KVStore, + L::Target: Logger, +{ + let mut reader = + kv_store.read(NETWORK_GRAPH_PERSISTENCE_NAMESPACE, NETWORK_GRAPH_PERSISTENCE_KEY)?; + let graph = NetworkGraph::read(&mut reader, logger).map_err(|_| { + std::io::Error::new(std::io::ErrorKind::InvalidData, "Failed to deserialize NetworkGraph") + })?; + Ok(graph) +} + +/// Read a previously persisted [`Scorer`] from the store. +pub(crate) fn read_scorer>, L: Deref>( + kv_store: K, network_graph: G, logger: L, +) -> Result, std::io::Error> +where + K::Target: KVStore, + L::Target: Logger, +{ + let params = ProbabilisticScoringParameters::default(); + let mut reader = kv_store.read(SCORER_PERSISTENCE_NAMESPACE, SCORER_PERSISTENCE_KEY)?; + let args = (params, network_graph, logger); + let scorer = ProbabilisticScorer::read(&mut reader, args).map_err(|_| { + std::io::Error::new(std::io::ErrorKind::InvalidData, "Failed to deserialize Scorer") + })?; + Ok(scorer) +} + +/// Read previously persisted events from the store. +pub(crate) fn read_event_queue( + kv_store: K, logger: L, +) -> Result, std::io::Error> +where + K::Target: KVStore, + L::Target: Logger, +{ + let mut reader = + kv_store.read(EVENT_QUEUE_PERSISTENCE_NAMESPACE, EVENT_QUEUE_PERSISTENCE_KEY)?; + let event_queue = EventQueue::read(&mut reader, (kv_store, logger)).map_err(|_| { + std::io::Error::new(std::io::ErrorKind::InvalidData, "Failed to deserialize EventQueue") + })?; + Ok(event_queue) +} + +/// Read previously persisted peer info from the store. +pub(crate) fn read_peer_info( + kv_store: K, logger: L, +) -> Result, std::io::Error> +where + K::Target: KVStore, + L::Target: Logger, +{ + let mut reader = kv_store.read(PEER_INFO_PERSISTENCE_NAMESPACE, PEER_INFO_PERSISTENCE_KEY)?; + let peer_info = PeerStore::read(&mut reader, (kv_store, logger)).map_err(|_| { + std::io::Error::new(std::io::ErrorKind::InvalidData, "Failed to deserialize PeerStore") + })?; + Ok(peer_info) +} + +/// Read previously persisted payments information from the store. +pub(crate) fn read_payments(kv_store: K) -> Result, std::io::Error> +where + K::Target: KVStore, +{ + let mut res = Vec::new(); + + for stored_key in kv_store.list(PAYMENT_INFO_PERSISTENCE_NAMESPACE)? { + let payment = PaymentDetails::read( + &mut kv_store.read(PAYMENT_INFO_PERSISTENCE_NAMESPACE, &stored_key)?, + ) + .map_err(|_| { + std::io::Error::new(std::io::ErrorKind::InvalidData, "Failed to deserialize Payment") + })?; + res.push(payment); + } + Ok(res) +} diff --git a/src/io_utils.rs b/src/io_utils.rs deleted file mode 100644 index 6453aac75..000000000 --- a/src/io_utils.rs +++ /dev/null @@ -1,62 +0,0 @@ -use crate::{Config, FilesystemLogger, NetworkGraph, Scorer, WALLET_KEYS_SEED_LEN}; - -use lightning::routing::scoring::{ProbabilisticScorer, ProbabilisticScoringParameters}; -use lightning::util::ser::ReadableArgs; - -use rand::{thread_rng, RngCore}; - -use std::fs; -use std::io::{BufReader, Write}; -use std::path::Path; -use std::sync::Arc; - -pub(crate) fn read_or_generate_seed_file(keys_seed_path: &str) -> [u8; WALLET_KEYS_SEED_LEN] { - if Path::new(&keys_seed_path).exists() { - let seed = fs::read(keys_seed_path).expect("Failed to read keys seed file"); - assert_eq!( - seed.len(), - WALLET_KEYS_SEED_LEN, - "Failed to read keys seed file: unexpected length" - ); - let mut key = [0; WALLET_KEYS_SEED_LEN]; - key.copy_from_slice(&seed); - key - } else { - let mut key = [0; WALLET_KEYS_SEED_LEN]; - thread_rng().fill_bytes(&mut key); - - let mut f = fs::File::create(keys_seed_path).expect("Failed to create keys seed file"); - f.write_all(&key).expect("Failed to write node keys seed to disk"); - f.sync_all().expect("Failed to sync node keys seed to disk"); - key - } -} - -pub(crate) fn read_network_graph(config: &Config, logger: Arc) -> NetworkGraph { - let ldk_data_dir = format!("{}/ldk", config.storage_dir_path); - let network_graph_path = format!("{}/network_graph", ldk_data_dir); - - if let Ok(file) = fs::File::open(network_graph_path) { - if let Ok(graph) = NetworkGraph::read(&mut BufReader::new(file), Arc::clone(&logger)) { - return graph; - } - } - - NetworkGraph::new(config.network, logger) -} - -pub(crate) fn read_scorer( - config: &Config, network_graph: Arc, logger: Arc, -) -> Scorer { - let ldk_data_dir = format!("{}/ldk", config.storage_dir_path); - let scorer_path = format!("{}/scorer", ldk_data_dir); - - let params = ProbabilisticScoringParameters::default(); - if let Ok(file) = fs::File::open(scorer_path) { - let args = (params.clone(), Arc::clone(&network_graph), Arc::clone(&logger)); - if let Ok(scorer) = ProbabilisticScorer::read(&mut BufReader::new(file), args) { - return scorer; - } - } - ProbabilisticScorer::new(params, network_graph, logger) -} diff --git a/src/lib.rs b/src/lib.rs index 8f8b783ec..99672a038 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -69,11 +69,12 @@ mod error; mod event; mod hex_utils; -mod io_utils; +mod io; mod logger; +mod payment_store; mod peer_store; #[cfg(test)] -mod tests; +mod test; mod types; mod wallet; @@ -84,12 +85,15 @@ pub use lightning_invoice; pub use error::Error; pub use event::Event; use event::{EventHandler, EventQueue}; -use peer_store::{PeerInfo, PeerInfoStorage}; +use io::fs_store::FilesystemStore; +use io::{KVStore, CHANNEL_MANAGER_PERSISTENCE_KEY, CHANNEL_MANAGER_PERSISTENCE_NAMESPACE}; +use payment_store::PaymentStore; +pub use payment_store::{PaymentDetails, PaymentDirection, PaymentStatus}; +use peer_store::{PeerInfo, PeerStore}; use types::{ ChainMonitor, ChannelManager, GossipSync, KeysManager, NetworkGraph, OnionMessenger, - PaymentInfoStorage, PeerManager, Scorer, + PeerManager, Scorer, }; -pub use types::{PaymentInfo, PaymentStatus}; use wallet::Wallet; use logger::{log_error, log_info, FilesystemLogger, Logger}; @@ -103,6 +107,7 @@ use lightning::ln::channelmanager::{ use lightning::ln::peer_handler::{IgnoringMessageHandler, MessageHandler}; use lightning::ln::{PaymentHash, PaymentPreimage}; use lightning::routing::gossip::P2PGossipSync; +use lightning::routing::scoring::{ProbabilisticScorer, ProbabilisticScoringParameters}; use lightning::routing::utxo::UtxoLookup; use lightning::util::config::{ChannelHandshakeConfig, ChannelHandshakeLimits, UserConfig}; @@ -110,7 +115,6 @@ use lightning::util::ser::ReadableArgs; use lightning_background_processor::BackgroundProcessor; use lightning_background_processor::GossipSync as BPGossipSync; -use lightning_persister::FilesystemPersister; use lightning_transaction_sync::EsploraSyncClient; @@ -129,7 +133,6 @@ use bitcoin::BlockHash; use rand::Rng; -use std::collections::HashMap; use std::convert::{TryFrom, TryInto}; use std::default::Default; use std::fs; @@ -286,13 +289,13 @@ impl Builder { match entropy_source { WalletEntropySource::SeedBytes(bytes) => bytes.clone(), WalletEntropySource::SeedFile(seed_path) => { - io_utils::read_or_generate_seed_file(seed_path) + io::utils::read_or_generate_seed_file(seed_path) } } } else { // Default to read or generate from the default location generate a seed file. let seed_path = format!("{}/keys_seed", config.storage_dir_path); - io_utils::read_or_generate_seed_file(&seed_path) + io::utils::read_or_generate_seed_file(&seed_path) }; let xprv = bitcoin::util::bip32::ExtendedPrivKey::new_master(config.network, &seed_bytes) @@ -328,8 +331,7 @@ impl Builder { let wallet = Arc::new(Wallet::new(blockchain, bdk_wallet, Arc::clone(&logger))); - // Initialize Persist - let persister = Arc::new(FilesystemPersister::new(ldk_data_dir.clone())); + let kv_store = Arc::new(FilesystemStore::new(ldk_data_dir.clone().into())); // Initialize the ChainMonitor let chain_monitor: Arc = Arc::new(chainmonitor::ChainMonitor::new( @@ -337,7 +339,7 @@ impl Builder { Arc::clone(&wallet), Arc::clone(&logger), Arc::clone(&wallet), - Arc::clone(&persister), + Arc::clone(&kv_store), )); // Initialize the KeysManager @@ -354,12 +356,38 @@ impl Builder { // Initialize the network graph, scorer, and router let network_graph = - Arc::new(io_utils::read_network_graph(config.as_ref(), Arc::clone(&logger))); - let scorer = Arc::new(Mutex::new(io_utils::read_scorer( - config.as_ref(), + match io::utils::read_network_graph(Arc::clone(&kv_store), Arc::clone(&logger)) { + Ok(graph) => Arc::new(graph), + Err(e) => { + if e.kind() == std::io::ErrorKind::NotFound { + Arc::new(NetworkGraph::new(config.network, Arc::clone(&logger))) + } else { + log_error!(logger, "Failed to read network graph: {}", e.to_string()); + panic!("Failed to read network graph: {}", e.to_string()); + } + } + }; + + let scorer = match io::utils::read_scorer( + Arc::clone(&kv_store), Arc::clone(&network_graph), Arc::clone(&logger), - ))); + ) { + Ok(scorer) => Arc::new(Mutex::new(scorer)), + Err(e) => { + if e.kind() == std::io::ErrorKind::NotFound { + let params = ProbabilisticScoringParameters::default(); + Arc::new(Mutex::new(ProbabilisticScorer::new( + params, + Arc::clone(&network_graph), + Arc::clone(&logger), + ))) + } else { + log_error!(logger, "Failed to read scorer: {}", e.to_string()); + panic!("Failed to read scorer: {}", e.to_string()); + } + } + }; let router = Arc::new(DefaultRouter::new( Arc::clone(&network_graph), @@ -368,16 +396,26 @@ impl Builder { Arc::clone(&scorer), )); - // Read ChannelMonitor state from disk - let mut channel_monitors = persister - .read_channelmonitors(Arc::clone(&keys_manager), Arc::clone(&keys_manager)) - .expect("Failed to read channel monitors from disk"); + // Read ChannelMonitor state from store + let mut channel_monitors = match io::utils::read_channel_monitors( + Arc::clone(&kv_store), + Arc::clone(&keys_manager), + Arc::clone(&keys_manager), + ) { + Ok(monitors) => monitors, + Err(e) => { + log_error!(logger, "Failed to read channel monitors: {}", e.to_string()); + panic!("Failed to read channel monitors: {}", e.to_string()); + } + }; // Initialize the ChannelManager let mut user_config = UserConfig::default(); user_config.channel_handshake_limits.force_announced_channel_preference = false; let channel_manager = { - if let Ok(mut f) = fs::File::open(format!("{}/manager", ldk_data_dir)) { + if let Ok(mut reader) = kv_store + .read(CHANNEL_MANAGER_PERSISTENCE_NAMESPACE, CHANNEL_MANAGER_PERSISTENCE_KEY) + { let channel_monitor_references = channel_monitors.iter_mut().map(|(_, chanmon)| chanmon).collect(); let read_args = ChannelManagerReadArgs::new( @@ -393,8 +431,8 @@ impl Builder { channel_monitor_references, ); let (_hash, channel_manager) = - <(BlockHash, ChannelManager)>::read(&mut f, read_args) - .expect("Failed to read channel manager from disk"); + <(BlockHash, ChannelManager)>::read(&mut reader, read_args) + .expect("Failed to read channel manager from store"); channel_manager } else { // We're starting a fresh node. @@ -464,31 +502,40 @@ impl Builder { )); // Init payment info storage - // TODO: persist payment info to disk - let inbound_payments = Arc::new(Mutex::new(HashMap::new())); - let outbound_payments = Arc::new(Mutex::new(HashMap::new())); - - // Restore event handler from disk or create a new one. - let event_queue = if let Ok(mut f) = - fs::File::open(format!("{}/{}", ldk_data_dir, event::EVENTS_PERSISTENCE_KEY)) - { - Arc::new( - EventQueue::read(&mut f, Arc::clone(&persister)) - .expect("Failed to read event queue from disk."), - ) - } else { - Arc::new(EventQueue::new(Arc::clone(&persister))) + let payment_store = match io::utils::read_payments(Arc::clone(&kv_store)) { + Ok(payments) => { + Arc::new(PaymentStore::new(payments, Arc::clone(&kv_store), Arc::clone(&logger))) + } + Err(e) => { + log_error!(logger, "Failed to read payment information: {}", e.to_string()); + panic!("Failed to read payment information: {}", e.to_string()); + } }; - let peer_store = if let Ok(mut f) = - fs::File::open(format!("{}/{}", ldk_data_dir, peer_store::PEER_INFO_PERSISTENCE_KEY)) + let event_queue = + match io::utils::read_event_queue(Arc::clone(&kv_store), Arc::clone(&logger)) { + Ok(event_queue) => Arc::new(event_queue), + Err(e) => { + if e.kind() == std::io::ErrorKind::NotFound { + Arc::new(EventQueue::new(Arc::clone(&kv_store), Arc::clone(&logger))) + } else { + log_error!(logger, "Failed to read event queue: {}", e.to_string()); + panic!("Failed to read event queue: {}", e.to_string()); + } + } + }; + + let peer_store = match io::utils::read_peer_info(Arc::clone(&kv_store), Arc::clone(&logger)) { - Arc::new( - PeerInfoStorage::read(&mut f, Arc::clone(&persister)) - .expect("Failed to read peer information from disk."), - ) - } else { - Arc::new(PeerInfoStorage::new(Arc::clone(&persister))) + Ok(peer_store) => Arc::new(peer_store), + Err(e) => { + if e.kind() == std::io::ErrorKind::NotFound { + Arc::new(PeerStore::new(Arc::clone(&kv_store), Arc::clone(&logger))) + } else { + log_error!(logger, "Failed to read peer store: {}", e.to_string()); + panic!("Failed to read peer store: {}", e.to_string()); + } + } }; let running = RwLock::new(None); @@ -505,12 +552,11 @@ impl Builder { keys_manager, network_graph, gossip_sync, - persister, + kv_store, logger, scorer, - inbound_payments, - outbound_payments, peer_store, + payment_store, } } } @@ -532,19 +578,18 @@ pub struct Node { config: Arc, wallet: Arc>, tx_sync: Arc>>, - event_queue: Arc>>, + event_queue: Arc, Arc>>, channel_manager: Arc, chain_monitor: Arc, peer_manager: Arc, keys_manager: Arc, network_graph: Arc, gossip_sync: Arc, - persister: Arc, + kv_store: Arc, logger: Arc, scorer: Arc>, - inbound_payments: Arc, - outbound_payments: Arc, - peer_store: Arc>, + peer_store: Arc, Arc>>, + payment_store: Arc, Arc>>, } impl Node { @@ -604,8 +649,7 @@ impl Node { Arc::clone(&self.channel_manager), Arc::clone(&self.network_graph), Arc::clone(&self.keys_manager), - Arc::clone(&self.inbound_payments), - Arc::clone(&self.outbound_payments), + Arc::clone(&self.payment_store), Arc::clone(&tokio_runtime), Arc::clone(&self.logger), Arc::clone(&self.config), @@ -742,7 +786,7 @@ impl Node { // Setup background processing let _background_processor = BackgroundProcessor::start( - Arc::clone(&self.persister), + Arc::clone(&self.kv_store), Arc::clone(&event_handler), Arc::clone(&self.chain_monitor), Arc::clone(&self.channel_manager), @@ -958,9 +1002,13 @@ impl Node { return Err(Error::NotRunning); } - let mut outbound_payments_lock = self.outbound_payments.lock().unwrap(); - let payment_hash = PaymentHash((*invoice.payment_hash()).into_inner()); + + if self.payment_store.contains(&payment_hash) { + log_error!(self.logger, "Payment error: an invoice must not get paid twice."); + return Err(Error::NonUniquePaymentHash); + } + let payment_secret = Some(*invoice.payment_secret()); match lightning_invoice::payment::pay_invoice( @@ -973,15 +1021,15 @@ impl Node { let amt_msat = invoice.amount_milli_satoshis().unwrap(); log_info!(self.logger, "Initiated sending {}msat to {}", amt_msat, payee_pubkey); - outbound_payments_lock.insert( - payment_hash, - PaymentInfo { - preimage: None, - secret: payment_secret, - status: PaymentStatus::Pending, - amount_msat: invoice.amount_milli_satoshis(), - }, - ); + let payment = PaymentDetails { + preimage: None, + hash: payment_hash, + secret: payment_secret, + amount_msat: invoice.amount_milli_satoshis(), + direction: PaymentDirection::Outbound, + status: PaymentStatus::Pending, + }; + self.payment_store.insert(payment)?; Ok(payment_hash) } @@ -992,15 +1040,16 @@ impl Node { Err(payment::PaymentError::Sending(e)) => { log_error!(self.logger, "Failed to send payment: {:?}", e); - outbound_payments_lock.insert( - payment_hash, - PaymentInfo { - preimage: None, - secret: payment_secret, - status: PaymentStatus::Failed, - amount_msat: invoice.amount_milli_satoshis(), - }, - ); + let payment = PaymentDetails { + preimage: None, + hash: payment_hash, + secret: payment_secret, + amount_msat: invoice.amount_milli_satoshis(), + direction: PaymentDirection::Outbound, + status: PaymentStatus::Failed, + }; + self.payment_store.insert(payment)?; + Err(Error::PaymentFailed) } } @@ -1019,8 +1068,6 @@ impl Node { return Err(Error::NotRunning); } - let mut outbound_payments_lock = self.outbound_payments.lock().unwrap(); - if let Some(invoice_amount_msat) = invoice.amount_milli_satoshis() { if amount_msat < invoice_amount_msat { log_error!( @@ -1030,8 +1077,13 @@ impl Node { } } - let payment_id = PaymentId(invoice.payment_hash().into_inner()); let payment_hash = PaymentHash((*invoice.payment_hash()).into_inner()); + if self.payment_store.contains(&payment_hash) { + log_error!(self.logger, "Payment error: an invoice must not get paid twice."); + return Err(Error::NonUniquePaymentHash); + } + + let payment_id = PaymentId(invoice.payment_hash().into_inner()); let payment_secret = Some(*invoice.payment_secret()); let expiry_time = invoice.duration_since_epoch().saturating_add(invoice.expiry_time()); let mut payment_params = PaymentParameters::from_node_id( @@ -1067,15 +1119,15 @@ impl Node { payee_pubkey ); - outbound_payments_lock.insert( - payment_hash, - PaymentInfo { - preimage: None, - secret: payment_secret, - status: PaymentStatus::Pending, - amount_msat: Some(amount_msat), - }, - ); + let payment = PaymentDetails { + hash: payment_hash, + preimage: None, + secret: payment_secret, + amount_msat: Some(amount_msat), + direction: PaymentDirection::Outbound, + status: PaymentStatus::Pending, + }; + self.payment_store.insert(payment)?; Ok(payment_hash) } @@ -1086,15 +1138,16 @@ impl Node { Err(payment::PaymentError::Sending(e)) => { log_error!(self.logger, "Failed to send payment: {:?}", e); - outbound_payments_lock.insert( - payment_hash, - PaymentInfo { - preimage: None, - secret: payment_secret, - status: PaymentStatus::Failed, - amount_msat: Some(amount_msat), - }, - ); + let payment = PaymentDetails { + hash: payment_hash, + preimage: None, + secret: payment_secret, + amount_msat: Some(amount_msat), + direction: PaymentDirection::Outbound, + status: PaymentStatus::Failed, + }; + self.payment_store.insert(payment)?; + Err(Error::PaymentFailed) } } @@ -1108,8 +1161,6 @@ impl Node { return Err(Error::NotRunning); } - let mut outbound_payments_lock = self.outbound_payments.lock().unwrap(); - let pubkey = hex_utils::to_compressed_pubkey(node_id).ok_or(Error::PeerInfoParseFailed)?; let payment_preimage = PaymentPreimage(self.keys_manager.get_secure_random_bytes()); @@ -1131,28 +1182,32 @@ impl Node { ) { Ok(_payment_id) => { log_info!(self.logger, "Initiated sending {}msat to {}.", amount_msat, node_id); - outbound_payments_lock.insert( - payment_hash, - PaymentInfo { - preimage: None, - secret: None, - status: PaymentStatus::Pending, - amount_msat: Some(amount_msat), - }, - ); + + let payment = PaymentDetails { + hash: payment_hash, + preimage: Some(payment_preimage), + secret: None, + status: PaymentStatus::Pending, + direction: PaymentDirection::Outbound, + amount_msat: Some(amount_msat), + }; + self.payment_store.insert(payment)?; + Ok(payment_hash) } Err(e) => { log_error!(self.logger, "Failed to send payment: {:?}", e); - outbound_payments_lock.insert( - payment_hash, - PaymentInfo { - preimage: None, - secret: None, - status: PaymentStatus::Failed, - amount_msat: Some(amount_msat), - }, - ); + + let payment = PaymentDetails { + hash: payment_hash, + preimage: Some(payment_preimage), + secret: None, + status: PaymentStatus::Failed, + direction: PaymentDirection::Outbound, + amount_msat: Some(amount_msat), + }; + self.payment_store.insert(payment)?; + Err(Error::PaymentFailed) } } @@ -1177,8 +1232,6 @@ impl Node { fn receive_payment_inner( &self, amount_msat: Option, description: &str, expiry_secs: u32, ) -> Result { - let mut inbound_payments_lock = self.inbound_payments.lock().unwrap(); - let currency = match self.config.network { bitcoin::Network::Bitcoin => Currency::Bitcoin, bitcoin::Network::Testnet => Currency::BitcoinTestnet, @@ -1207,37 +1260,51 @@ impl Node { }; let payment_hash = PaymentHash((*invoice.payment_hash()).into_inner()); - inbound_payments_lock.insert( - payment_hash, - PaymentInfo { - preimage: None, - secret: Some(*invoice.payment_secret()), - status: PaymentStatus::Pending, - amount_msat, - }, - ); + let payment = PaymentDetails { + hash: payment_hash, + preimage: None, + secret: Some(invoice.payment_secret().clone()), + amount_msat, + direction: PaymentDirection::Inbound, + status: PaymentStatus::Pending, + }; + + self.payment_store.insert(payment)?; + Ok(invoice) } - /// Query for information about the status of a specific payment. - pub fn payment_info(&self, payment_hash: &[u8; 32]) -> Option { - let payment_hash = PaymentHash(*payment_hash); - - { - let outbound_payments_lock = self.outbound_payments.lock().unwrap(); - if let Some(payment_info) = outbound_payments_lock.get(&payment_hash) { - return Some((*payment_info).clone()); - } - } + /// Retrieve the details of a specific payment with the given hash. + /// + /// Returns `Some` if the payment was known and `None` otherwise. + pub fn payment(&self, payment_hash: &PaymentHash) -> Option { + self.payment_store.get(payment_hash) + } - { - let inbound_payments_lock = self.inbound_payments.lock().unwrap(); - if let Some(payment_info) = inbound_payments_lock.get(&payment_hash) { - return Some((*payment_info).clone()); - } - } + /// Remove the payment with the given hash from the store. + /// + /// Returns `true` if the payment was present and `false` otherwise. + pub fn remove_payment(&self, payment_hash: &PaymentHash) -> Result { + self.payment_store.remove(&payment_hash) + } - None + /// Retrieves all payments that match the given predicate. + /// + /// For example, you could retrieve all stored outbound payments as follows: + /// ``` + /// # use ldk_node::{Builder, Config, PaymentDirection}; + /// # use ldk_node::bitcoin::Network; + /// # let mut config = Config::default(); + /// # config.network = Network::Regtest; + /// # config.storage_dir_path = "/tmp/ldk_node_test/".to_string(); + /// # let builder = Builder::from_config(config); + /// # let node = builder.build(); + /// node.list_payments_with_filter(|p| p.direction == PaymentDirection::Outbound); + /// ``` + pub fn list_payments_with_filter bool>( + &self, f: F, + ) -> Vec { + self.payment_store.list_filter(f) } } diff --git a/src/payment_store.rs b/src/payment_store.rs new file mode 100644 index 000000000..80da72c0f --- /dev/null +++ b/src/payment_store.rs @@ -0,0 +1,266 @@ +use crate::hex_utils; +use crate::io::{KVStore, TransactionalWrite, PAYMENT_INFO_PERSISTENCE_NAMESPACE}; +use crate::logger::{log_error, Logger}; +use crate::Error; + +use lightning::ln::{PaymentHash, PaymentPreimage, PaymentSecret}; +use lightning::util::ser::Writeable; +use lightning::{impl_writeable_tlv_based, impl_writeable_tlv_based_enum}; + +use std::collections::HashMap; +use std::iter::FromIterator; +use std::ops::Deref; +use std::sync::Mutex; + +/// Represents a payment. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct PaymentDetails { + /// The payment hash, i.e., the hash of the `preimage`. + pub hash: PaymentHash, + /// The pre-image used by the payment. + pub preimage: Option, + /// The secret used by the payment. + pub secret: Option, + /// The amount transferred. + pub amount_msat: Option, + /// The direction of the payment. + pub direction: PaymentDirection, + /// The status of the payment. + pub status: PaymentStatus, +} + +impl_writeable_tlv_based!(PaymentDetails, { + (0, hash, required), + (1, preimage, required), + (2, secret, required), + (3, amount_msat, required), + (4, direction, required), + (5, status, required) +}); + +/// Represents the direction of a payment. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum PaymentDirection { + /// The payment is inbound. + Inbound, + /// The payment is outbound. + Outbound, +} + +impl_writeable_tlv_based_enum!(PaymentDirection, + (0, Inbound) => {}, + (1, Outbound) => {}; +); + +/// Represents the current status of a payment. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum PaymentStatus { + /// The payment is still pending. + Pending, + /// The payment suceeded. + Succeeded, + /// The payment failed. + Failed, +} + +impl_writeable_tlv_based_enum!(PaymentStatus, + (0, Pending) => {}, + (1, Succeeded) => {}, + (2, Failed) => {}; +); + +#[derive(Clone, Debug, PartialEq, Eq)] +pub(crate) struct PaymentDetailsUpdate { + pub hash: PaymentHash, + pub preimage: Option>, + pub secret: Option>, + pub amount_msat: Option>, + pub direction: Option, + pub status: Option, +} + +impl PaymentDetailsUpdate { + pub fn new(hash: PaymentHash) -> Self { + Self { + hash, + preimage: None, + secret: None, + amount_msat: None, + direction: None, + status: None, + } + } +} + +pub(crate) struct PaymentStore +where + K::Target: KVStore, + L::Target: Logger, +{ + payments: Mutex>, + kv_store: K, + logger: L, +} + +impl PaymentStore +where + K::Target: KVStore, + L::Target: Logger, +{ + pub(crate) fn new(payments: Vec, kv_store: K, logger: L) -> Self { + let payments = Mutex::new(HashMap::from_iter( + payments.into_iter().map(|payment| (payment.hash, payment)), + )); + Self { payments, kv_store, logger } + } + + pub(crate) fn insert(&self, payment: PaymentDetails) -> Result { + let mut locked_payments = self.payments.lock().unwrap(); + + let hash = payment.hash.clone(); + let updated = locked_payments.insert(hash.clone(), payment.clone()).is_some(); + self.write_info_and_commit(&hash, &payment)?; + Ok(updated) + } + + pub(crate) fn remove(&self, hash: &PaymentHash) -> Result { + let store_key = hex_utils::to_string(&hash.0); + self.kv_store.remove(PAYMENT_INFO_PERSISTENCE_NAMESPACE, &store_key).map_err(|e| { + log_error!( + self.logger, + "Removing payment data for key {}/{} failed due to: {}", + PAYMENT_INFO_PERSISTENCE_NAMESPACE, + store_key, + e + ); + Error::PersistenceFailed + }) + } + + pub(crate) fn get(&self, hash: &PaymentHash) -> Option { + self.payments.lock().unwrap().get(hash).cloned() + } + + pub(crate) fn contains(&self, hash: &PaymentHash) -> bool { + self.payments.lock().unwrap().contains_key(hash) + } + + pub(crate) fn update(&self, update: &PaymentDetailsUpdate) -> Result { + let mut updated = false; + let mut locked_payments = self.payments.lock().unwrap(); + + if let Some(payment) = locked_payments.get_mut(&update.hash) { + if let Some(preimage_opt) = update.preimage { + payment.preimage = preimage_opt; + } + + if let Some(secret_opt) = update.secret { + payment.secret = secret_opt; + } + + if let Some(amount_opt) = update.amount_msat { + payment.amount_msat = amount_opt; + } + + if let Some(status) = update.status { + payment.status = status; + } + + self.write_info_and_commit(&update.hash, payment)?; + updated = true; + } + + Ok(updated) + } + + pub(crate) fn list_filter bool>( + &self, f: F, + ) -> Vec { + self.payments + .lock() + .unwrap() + .iter() + .map(|(_, p)| p) + .filter(f) + .cloned() + .collect::>() + } + + fn write_info_and_commit( + &self, hash: &PaymentHash, payment: &PaymentDetails, + ) -> Result<(), Error> { + let store_key = hex_utils::to_string(&hash.0); + let mut writer = + self.kv_store.write(PAYMENT_INFO_PERSISTENCE_NAMESPACE, &store_key).map_err(|e| { + log_error!( + self.logger, + "Getting writer for key {}/{} failed due to: {}", + PAYMENT_INFO_PERSISTENCE_NAMESPACE, + store_key, + e + ); + Error::PersistenceFailed + })?; + payment.write(&mut writer).map_err(|e| { + log_error!( + self.logger, + "Writing payment data for key {}/{} failed due to: {}", + PAYMENT_INFO_PERSISTENCE_NAMESPACE, + store_key, + e + ); + Error::PersistenceFailed + })?; + writer.commit().map_err(|e| { + log_error!( + self.logger, + "Committing payment data for key {}/{} failed due to: {}", + PAYMENT_INFO_PERSISTENCE_NAMESPACE, + store_key, + e + ); + Error::PersistenceFailed + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test::utils::{TestLogger, TestStore}; + use std::sync::Arc; + + #[test] + fn persistence_guard_persists_on_drop() { + let store = Arc::new(TestStore::new()); + let logger = Arc::new(TestLogger::new()); + let payment_store = PaymentStore::new(Vec::new(), Arc::clone(&store), logger); + + let hash = PaymentHash([42u8; 32]); + assert!(!payment_store.contains(&hash)); + + let payment = PaymentDetails { + hash, + preimage: None, + secret: None, + amount_msat: None, + direction: PaymentDirection::Inbound, + status: PaymentStatus::Pending, + }; + + assert!(!store.get_and_clear_did_persist()); + + assert_eq!(Ok(false), payment_store.insert(payment.clone())); + assert!(store.get_and_clear_did_persist()); + + assert_eq!(Ok(true), payment_store.insert(payment)); + assert!(store.get_and_clear_did_persist()); + + let mut update = PaymentDetailsUpdate::new(hash); + update.status = Some(PaymentStatus::Succeeded); + assert_eq!(Ok(true), payment_store.update(&update)); + assert!(store.get_and_clear_did_persist()); + + assert_eq!(PaymentStatus::Succeeded, payment_store.get(&hash).unwrap().status); + } +} diff --git a/src/peer_store.rs b/src/peer_store.rs index 6147dee81..96f32752c 100644 --- a/src/peer_store.rs +++ b/src/peer_store.rs @@ -1,7 +1,10 @@ use crate::hex_utils; +use crate::io::{ + KVStore, TransactionalWrite, PEER_INFO_PERSISTENCE_KEY, PEER_INFO_PERSISTENCE_NAMESPACE, +}; +use crate::logger::{log_error, Logger}; use crate::Error; -use lightning::util::persist::KVStorePersister; use lightning::util::ser::{Readable, ReadableArgs, Writeable, Writer}; use bitcoin::secp256k1::PublicKey; @@ -9,44 +12,41 @@ use bitcoin::secp256k1::PublicKey; use std::collections::HashMap; use std::convert::TryFrom; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs}; -use std::sync::{Arc, RwLock}; - -/// The peer information will be persisted under this key. -pub(crate) const PEER_INFO_PERSISTENCE_KEY: &str = "peers"; - -pub(crate) struct PeerInfoStorage { +use std::ops::Deref; +use std::sync::RwLock; + +pub struct PeerStore +where + K::Target: KVStore, + L::Target: Logger, +{ peers: RwLock>, - persister: Arc, + kv_store: K, + logger: L, } -impl PeerInfoStorage { - pub(crate) fn new(persister: Arc) -> Self { +impl PeerStore +where + K::Target: KVStore, + L::Target: Logger, +{ + pub(crate) fn new(kv_store: K, logger: L) -> Self { let peers = RwLock::new(HashMap::new()); - Self { peers, persister } + Self { peers, kv_store, logger } } pub(crate) fn add_peer(&self, peer_info: PeerInfo) -> Result<(), Error> { let mut locked_peers = self.peers.write().unwrap(); locked_peers.insert(peer_info.pubkey, peer_info); - - self.persister - .persist(PEER_INFO_PERSISTENCE_KEY, &PeerInfoStorageSerWrapper(&*locked_peers)) - .map_err(|_| Error::PersistenceFailed)?; - - Ok(()) + self.write_peers_and_commit(&*locked_peers) } pub(crate) fn remove_peer(&self, peer_pubkey: &PublicKey) -> Result<(), Error> { let mut locked_peers = self.peers.write().unwrap(); locked_peers.remove(peer_pubkey); - - self.persister - .persist(PEER_INFO_PERSISTENCE_KEY, &PeerInfoStorageSerWrapper(&*locked_peers)) - .map_err(|_| Error::PersistenceFailed)?; - - Ok(()) + self.write_peers_and_commit(&*locked_peers) } pub(crate) fn list_peers(&self) -> Vec { @@ -56,23 +56,66 @@ impl PeerInfoStorage { pub(crate) fn get_peer(&self, peer_pubkey: &PublicKey) -> Option { self.peers.read().unwrap().get(peer_pubkey).cloned() } + + fn write_peers_and_commit( + &self, locked_peers: &HashMap, + ) -> Result<(), Error> { + let mut writer = self + .kv_store + .write(PEER_INFO_PERSISTENCE_NAMESPACE, PEER_INFO_PERSISTENCE_KEY) + .map_err(|e| { + log_error!( + self.logger, + "Getting writer for key {}/{} failed due to: {}", + PEER_INFO_PERSISTENCE_NAMESPACE, + PEER_INFO_PERSISTENCE_KEY, + e + ); + Error::PersistenceFailed + })?; + PeerStoreSerWrapper(&*locked_peers).write(&mut writer).map_err(|e| { + log_error!( + self.logger, + "Writing peer data to key {}/{} failed due to: {}", + PEER_INFO_PERSISTENCE_NAMESPACE, + PEER_INFO_PERSISTENCE_KEY, + e + ); + Error::PersistenceFailed + })?; + writer.commit().map_err(|e| { + log_error!( + self.logger, + "Committing peer data to key {}/{} failed due to: {}", + PEER_INFO_PERSISTENCE_NAMESPACE, + PEER_INFO_PERSISTENCE_KEY, + e + ); + Error::PersistenceFailed + }) + } } -impl ReadableArgs> for PeerInfoStorage { +impl ReadableArgs<(K, L)> for PeerStore +where + K::Target: KVStore, + L::Target: Logger, +{ #[inline] fn read( - reader: &mut R, persister: Arc, + reader: &mut R, args: (K, L), ) -> Result { - let read_peers: PeerInfoStorageDeserWrapper = Readable::read(reader)?; + let (kv_store, logger) = args; + let read_peers: PeerStoreDeserWrapper = Readable::read(reader)?; let peers: RwLock> = RwLock::new(read_peers.0); - Ok(Self { peers, persister }) + Ok(Self { peers, kv_store, logger }) } } #[derive(Clone, Debug, PartialEq, Eq)] -pub(crate) struct PeerInfoStorageDeserWrapper(HashMap); +pub(crate) struct PeerStoreDeserWrapper(HashMap); -impl Readable for PeerInfoStorageDeserWrapper { +impl Readable for PeerStoreDeserWrapper { fn read( reader: &mut R, ) -> Result { @@ -87,9 +130,9 @@ impl Readable for PeerInfoStorageDeserWrapper { } } -pub(crate) struct PeerInfoStorageSerWrapper<'a>(&'a HashMap); +pub(crate) struct PeerStoreSerWrapper<'a>(&'a HashMap); -impl Writeable for PeerInfoStorageSerWrapper<'_> { +impl Writeable for PeerStoreSerWrapper<'_> { fn write(&self, writer: &mut W) -> Result<(), lightning::io::Error> { (self.0.len() as u16).write(writer)?; for (k, v) in self.0.iter() { @@ -173,14 +216,16 @@ impl TryFrom for PeerInfo { #[cfg(test)] mod tests { use super::*; - use crate::tests::test_utils::TestPersister; + use crate::test::utils::{TestLogger, TestStore}; use proptest::prelude::*; use std::str::FromStr; + use std::sync::Arc; #[test] fn peer_info_persistence() { - let persister = Arc::new(TestPersister::new()); - let peer_store = PeerInfoStorage::new(Arc::clone(&persister)); + let store = Arc::new(TestStore::new()); + let logger = Arc::new(TestLogger::new()); + let peer_store = PeerStore::new(Arc::clone(&store), Arc::clone(&logger)); let pubkey = PublicKey::from_str( "0276607124ebe6a6c9338517b6f485825b27c2dcc0b9fc2aa6a4c0df91194e5993", @@ -189,18 +234,20 @@ mod tests { let address: SocketAddr = "127.0.0.1:9738".parse().unwrap(); let expected_peer_info = PeerInfo { pubkey, address }; peer_store.add_peer(expected_peer_info.clone()).unwrap(); - assert!(persister.get_and_clear_did_persist()); + assert!(store.get_and_clear_did_persist()); // Check we can read back what we persisted. - let persisted_bytes = persister.get_persisted_bytes(PEER_INFO_PERSISTENCE_KEY).unwrap(); + let persisted_bytes = store + .get_persisted_bytes(PEER_INFO_PERSISTENCE_NAMESPACE, PEER_INFO_PERSISTENCE_KEY) + .unwrap(); let deser_peer_store = - PeerInfoStorage::read(&mut &persisted_bytes[..], Arc::clone(&persister)).unwrap(); + PeerStore::read(&mut &persisted_bytes[..], (Arc::clone(&store), logger)).unwrap(); let peers = deser_peer_store.list_peers(); assert_eq!(peers.len(), 1); assert_eq!(peers[0], expected_peer_info); assert_eq!(deser_peer_store.get_peer(&pubkey), Some(expected_peer_info)); - assert!(!persister.get_and_clear_did_persist()); + assert!(!store.get_and_clear_did_persist()); } #[test] diff --git a/src/tests/functional_tests.rs b/src/test/functional_tests.rs similarity index 52% rename from src/tests/functional_tests.rs rename to src/test/functional_tests.rs index e83741cd8..e2074af57 100644 --- a/src/tests/functional_tests.rs +++ b/src/test/functional_tests.rs @@ -1,179 +1,32 @@ -use crate::tests::test_utils::expect_event; -use crate::{Builder, Config, Error, Event}; +use crate::test::utils::*; +use crate::test::utils::{expect_event, random_config}; +use crate::{Builder, Error, Event, PaymentDirection, PaymentStatus}; -use bitcoin::{Address, Amount, OutPoint, Txid}; -use bitcoind::bitcoincore_rpc::RpcApi; -use electrsd::bitcoind::bitcoincore_rpc::bitcoincore_rpc_json::AddressType; -use electrsd::{bitcoind, bitcoind::BitcoinD, ElectrsD}; -use electrum_client::ElectrumApi; +use bitcoin::Amount; -use once_cell::sync::OnceCell; -use rand::distributions::Alphanumeric; -use rand::{thread_rng, Rng}; - -use std::env; -use std::sync::Mutex; use std::time::Duration; - -static BITCOIND: OnceCell = OnceCell::new(); -static ELECTRSD: OnceCell = OnceCell::new(); -static PREMINE: OnceCell<()> = OnceCell::new(); -static MINER_LOCK: Mutex<()> = Mutex::new(()); - -fn get_bitcoind() -> &'static BitcoinD { - BITCOIND.get_or_init(|| { - let bitcoind_exe = - env::var("BITCOIND_EXE").ok().or_else(|| bitcoind::downloaded_exe_path().ok()).expect( - "you need to provide an env var BITCOIND_EXE or specify a bitcoind version feature", - ); - let mut conf = bitcoind::Conf::default(); - conf.network = "regtest"; - BitcoinD::with_conf(bitcoind_exe, &conf).unwrap() - }) -} - -fn get_electrsd() -> &'static ElectrsD { - ELECTRSD.get_or_init(|| { - let bitcoind = get_bitcoind(); - let electrs_exe = - env::var("ELECTRS_EXE").ok().or_else(electrsd::downloaded_exe_path).expect( - "you need to provide env var ELECTRS_EXE or specify an electrsd version feature", - ); - let mut conf = electrsd::Conf::default(); - conf.http_enabled = true; - conf.network = "regtest"; - ElectrsD::with_conf(electrs_exe, &bitcoind, &conf).unwrap() - }) -} - -fn generate_blocks_and_wait(num: usize) { - let _miner = MINER_LOCK.lock().unwrap(); - let cur_height = get_bitcoind().client.get_block_count().unwrap(); - let address = - get_bitcoind().client.get_new_address(Some("test"), Some(AddressType::Legacy)).unwrap(); - let _block_hashes = get_bitcoind().client.generate_to_address(num as u64, &address).unwrap(); - wait_for_block(cur_height as usize + num); -} - -fn wait_for_block(min_height: usize) { - let mut header = get_electrsd().client.block_headers_subscribe().unwrap(); - loop { - if header.height >= min_height { - break; - } - header = exponential_backoff_poll(|| { - get_electrsd().trigger().unwrap(); - get_electrsd().client.ping().unwrap(); - get_electrsd().client.block_headers_pop().unwrap() - }); - } -} - -fn wait_for_tx(txid: Txid) { - let mut tx_res = get_electrsd().client.transaction_get(&txid); - loop { - if tx_res.is_ok() { - break; - } - tx_res = exponential_backoff_poll(|| { - get_electrsd().trigger().unwrap(); - get_electrsd().client.ping().unwrap(); - Some(get_electrsd().client.transaction_get(&txid)) - }); - } -} - -fn wait_for_outpoint_spend(outpoint: OutPoint) { - let tx = get_electrsd().client.transaction_get(&outpoint.txid).unwrap(); - let txout_script = tx.output.get(outpoint.vout as usize).unwrap().clone().script_pubkey; - let mut is_spent = !get_electrsd().client.script_get_history(&txout_script).unwrap().is_empty(); - loop { - if is_spent { - break; - } - - is_spent = exponential_backoff_poll(|| { - get_electrsd().trigger().unwrap(); - get_electrsd().client.ping().unwrap(); - Some(!get_electrsd().client.script_get_history(&txout_script).unwrap().is_empty()) - }); - } -} - -fn exponential_backoff_poll(mut poll: F) -> T -where - F: FnMut() -> Option, -{ - let mut delay = Duration::from_millis(64); - let mut tries = 0; - loop { - match poll() { - Some(data) => break data, - None if delay.as_millis() < 512 => { - delay = delay.mul_f32(2.0); - } - - None => {} - } - assert!(tries < 10, "Reached max tries."); - tries += 1; - std::thread::sleep(delay); - } -} - -fn premine_and_distribute_funds(addrs: Vec
, amount: Amount) { - PREMINE.get_or_init(|| { - generate_blocks_and_wait(101); - }); - - for addr in addrs { - let txid = get_bitcoind() - .client - .send_to_address(&addr, amount, None, None, None, None, None, None) - .unwrap(); - wait_for_tx(txid); - } - - generate_blocks_and_wait(1); -} - -fn rand_config() -> Config { - let mut config = Config::default(); - - let esplora_url = get_electrsd().esplora_url.as_ref().unwrap(); - - println!("Setting esplora server URL: {}", esplora_url); - config.esplora_server_url = format!("http://{}", esplora_url); - - let mut rng = thread_rng(); - let rand_dir: String = (0..7).map(|_| rng.sample(Alphanumeric) as char).collect(); - let rand_path = format!("/tmp/{}", rand_dir); - println!("Setting random LDK storage dir: {}", rand_dir); - config.storage_dir_path = rand_path; - - let rand_port: u16 = rng.gen_range(5000..8000); - println!("Setting random LDK listening port: {}", rand_port); - let listening_address = format!("127.0.0.1:{}", rand_port); - config.listening_address = Some(listening_address); - - config -} - #[test] fn channel_full_cycle() { + let (bitcoind, electrsd) = setup_bitcoind_and_electrsd(); println!("== Node A =="); - let config_a = rand_config(); + let esplora_url = electrsd.esplora_url.as_ref().unwrap(); + let config_a = random_config(esplora_url); let node_a = Builder::from_config(config_a).build(); node_a.start().unwrap(); let addr_a = node_a.new_funding_address().unwrap(); println!("\n== Node B =="); - let config_b = rand_config(); + let config_b = random_config(esplora_url); let node_b = Builder::from_config(config_b).build(); node_b.start().unwrap(); let addr_b = node_b.new_funding_address().unwrap(); - premine_and_distribute_funds(vec![addr_a, addr_b], Amount::from_sat(100000)); + premine_and_distribute_funds( + &bitcoind, + &electrsd, + vec![addr_a, addr_b], + Amount::from_sat(100000), + ); node_a.sync_wallets().unwrap(); node_b.sync_wallets().unwrap(); assert_eq!(node_a.on_chain_balance().unwrap().get_spendable(), 100000); @@ -193,10 +46,10 @@ fn channel_full_cycle() { } }; - wait_for_tx(funding_txo.txid); + wait_for_tx(&electrsd, funding_txo.txid); println!("\n .. generating blocks, syncing wallets .. "); - generate_blocks_and_wait(6); + generate_blocks_and_wait(&bitcoind, &electrsd, 6); node_a.sync_wallets().unwrap(); node_b.sync_wallets().unwrap(); @@ -219,13 +72,39 @@ fn channel_full_cycle() { }; println!("\nB receive_payment"); - let invoice = node_b.receive_payment(1000000, &"asdf", 9217).unwrap(); + let invoice_amount = 1000000; + let invoice = node_b.receive_payment(invoice_amount, &"asdf", 9217).unwrap(); println!("\nA send_payment"); - node_a.send_payment(invoice).unwrap(); + let payment_hash = node_a.send_payment(invoice.clone()).unwrap(); + + let outbound_payments_a = + node_a.list_payments_with_filter(|p| p.direction == PaymentDirection::Outbound); + assert_eq!(outbound_payments_a.len(), 1); + + let inbound_payments_a = + node_a.list_payments_with_filter(|p| p.direction == PaymentDirection::Inbound); + assert_eq!(inbound_payments_a.len(), 0); + + let outbound_payments_b = + node_b.list_payments_with_filter(|p| p.direction == PaymentDirection::Outbound); + assert_eq!(outbound_payments_b.len(), 0); + + let inbound_payments_b = + node_b.list_payments_with_filter(|p| p.direction == PaymentDirection::Inbound); + assert_eq!(inbound_payments_b.len(), 1); expect_event!(node_a, PaymentSuccessful); expect_event!(node_b, PaymentReceived); + assert_eq!(node_a.payment(&payment_hash).unwrap().status, PaymentStatus::Succeeded); + assert_eq!(node_a.payment(&payment_hash).unwrap().direction, PaymentDirection::Outbound); + assert_eq!(node_a.payment(&payment_hash).unwrap().amount_msat, Some(invoice_amount)); + assert_eq!(node_b.payment(&payment_hash).unwrap().status, PaymentStatus::Succeeded); + assert_eq!(node_b.payment(&payment_hash).unwrap().direction, PaymentDirection::Inbound); + assert_eq!(node_b.payment(&payment_hash).unwrap().amount_msat, Some(invoice_amount)); + + // Assert we fail duplicate outbound payments. + assert_eq!(Err(Error::NonUniquePaymentHash), node_a.send_payment(invoice)); // Test under-/overpayment let invoice_amount = 1000000; @@ -239,7 +118,7 @@ fn channel_full_cycle() { let invoice = node_b.receive_payment(invoice_amount, &"asdf", 9217).unwrap(); let overpaid_amount = invoice_amount + 100; - node_a.send_payment_using_amount(invoice, overpaid_amount).unwrap(); + let payment_hash = node_a.send_payment_using_amount(invoice, overpaid_amount).unwrap(); expect_event!(node_a, PaymentSuccessful); let received_amount = match node_b.next_event() { ref e @ Event::PaymentReceived { amount_msat, .. } => { @@ -252,12 +131,19 @@ fn channel_full_cycle() { } }; assert_eq!(received_amount, overpaid_amount); + assert_eq!(node_a.payment(&payment_hash).unwrap().status, PaymentStatus::Succeeded); + assert_eq!(node_a.payment(&payment_hash).unwrap().direction, PaymentDirection::Outbound); + assert_eq!(node_a.payment(&payment_hash).unwrap().amount_msat, Some(overpaid_amount)); + assert_eq!(node_b.payment(&payment_hash).unwrap().status, PaymentStatus::Succeeded); + assert_eq!(node_b.payment(&payment_hash).unwrap().direction, PaymentDirection::Inbound); + assert_eq!(node_b.payment(&payment_hash).unwrap().amount_msat, Some(overpaid_amount)); // Test "zero-amount" invoice payment let variable_amount_invoice = node_b.receive_variable_amount_payment(&"asdf", 9217).unwrap(); let determined_amount = 1234567; assert_eq!(Err(Error::InvalidInvoice), node_a.send_payment(variable_amount_invoice.clone())); - node_a.send_payment_using_amount(variable_amount_invoice, determined_amount).unwrap(); + let payment_hash = + node_a.send_payment_using_amount(variable_amount_invoice, determined_amount).unwrap(); expect_event!(node_a, PaymentSuccessful); let received_amount = match node_b.next_event() { @@ -271,14 +157,20 @@ fn channel_full_cycle() { } }; assert_eq!(received_amount, determined_amount); + assert_eq!(node_a.payment(&payment_hash).unwrap().status, PaymentStatus::Succeeded); + assert_eq!(node_a.payment(&payment_hash).unwrap().direction, PaymentDirection::Outbound); + assert_eq!(node_a.payment(&payment_hash).unwrap().amount_msat, Some(determined_amount)); + assert_eq!(node_b.payment(&payment_hash).unwrap().status, PaymentStatus::Succeeded); + assert_eq!(node_b.payment(&payment_hash).unwrap().direction, PaymentDirection::Inbound); + assert_eq!(node_b.payment(&payment_hash).unwrap().amount_msat, Some(determined_amount)); node_b.close_channel(&channel_id, &node_a.node_id()).unwrap(); expect_event!(node_a, ChannelClosed); expect_event!(node_b, ChannelClosed); - wait_for_outpoint_spend(funding_txo.into_bitcoin_outpoint()); + wait_for_outpoint_spend(&electrsd, funding_txo.into_bitcoin_outpoint()); - generate_blocks_and_wait(1); + generate_blocks_and_wait(&bitcoind, &electrsd, 1); node_a.sync_wallets().unwrap(); node_b.sync_wallets().unwrap(); @@ -293,19 +185,26 @@ fn channel_full_cycle() { #[test] fn channel_open_fails_when_funds_insufficient() { + let (bitcoind, electrsd) = setup_bitcoind_and_electrsd(); println!("== Node A =="); - let config_a = rand_config(); + let esplora_url = electrsd.esplora_url.as_ref().unwrap(); + let config_a = random_config(&esplora_url); let node_a = Builder::from_config(config_a).build(); node_a.start().unwrap(); let addr_a = node_a.new_funding_address().unwrap(); println!("\n== Node B =="); - let config_b = rand_config(); + let config_b = random_config(&esplora_url); let node_b = Builder::from_config(config_b).build(); node_b.start().unwrap(); let addr_b = node_b.new_funding_address().unwrap(); - premine_and_distribute_funds(vec![addr_a, addr_b], Amount::from_sat(100000)); + premine_and_distribute_funds( + &bitcoind, + &electrsd, + vec![addr_a, addr_b], + Amount::from_sat(100000), + ); node_a.sync_wallets().unwrap(); node_b.sync_wallets().unwrap(); assert_eq!(node_a.on_chain_balance().unwrap().get_spendable(), 100000); @@ -321,7 +220,9 @@ fn channel_open_fails_when_funds_insufficient() { #[test] fn connect_to_public_testnet_esplora() { - let mut config = rand_config(); + let (_bitcoind, electrsd) = setup_bitcoind_and_electrsd(); + let esplora_url = electrsd.esplora_url.as_ref().unwrap(); + let mut config = random_config(&esplora_url); config.esplora_server_url = "https://blockstream.info/testnet/api".to_string(); config.network = bitcoin::Network::Testnet; let node = Builder::from_config(config).build(); @@ -332,14 +233,16 @@ fn connect_to_public_testnet_esplora() { #[test] fn start_stop_reinit() { - let config = rand_config(); + let (bitcoind, electrsd) = setup_bitcoind_and_electrsd(); + let esplora_url = electrsd.esplora_url.as_ref().unwrap(); + let config = random_config(&esplora_url); let node = Builder::from_config(config.clone()).build(); let expected_node_id = node.node_id(); let funding_address = node.new_funding_address().unwrap(); let expected_amount = Amount::from_sat(100000); - premine_and_distribute_funds(vec![funding_address], expected_amount); + premine_and_distribute_funds(&bitcoind, &electrsd, vec![funding_address], expected_amount); assert_eq!(node.on_chain_balance().unwrap().get_total(), 0); node.start().unwrap(); diff --git a/src/tests/mod.rs b/src/test/mod.rs similarity index 56% rename from src/tests/mod.rs rename to src/test/mod.rs index 5c32fa2af..f856f4878 100644 --- a/src/tests/mod.rs +++ b/src/test/mod.rs @@ -1,2 +1,2 @@ pub mod functional_tests; -pub mod test_utils; +pub mod utils; diff --git a/src/test/utils.rs b/src/test/utils.rs new file mode 100644 index 000000000..618afaa93 --- /dev/null +++ b/src/test/utils.rs @@ -0,0 +1,423 @@ +use crate::io::{KVStore, TransactionalWrite}; +use crate::Config; +use lightning::util::logger::{Level, Logger, Record}; +use lightning::util::persist::KVStorePersister; +use lightning::util::ser::Writeable; + +use bitcoin::{Address, Amount, OutPoint, Txid}; + +use bitcoind::bitcoincore_rpc::RpcApi; +use electrsd::bitcoind::bitcoincore_rpc::bitcoincore_rpc_json::AddressType; +use electrsd::{bitcoind, bitcoind::BitcoinD, ElectrsD}; +use electrum_client::ElectrumApi; + +use regex; + +use rand::distributions::Alphanumeric; +use rand::{thread_rng, Rng}; +use std::collections::hash_map; +use std::collections::HashMap; +use std::env; +use std::io::{BufWriter, Cursor, Read, Write}; +use std::path::PathBuf; +use std::str::FromStr; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Mutex, RwLock}; +use std::time::Duration; + +macro_rules! expect_event { + ($node: expr, $event_type: ident) => {{ + match $node.next_event() { + ref e @ Event::$event_type { .. } => { + println!("{} got event {:?}", std::stringify!($node), e); + $node.event_handled(); + } + ref e => { + panic!("{} got unexpected event!: {:?}", std::stringify!($node), e); + } + } + }}; +} + +pub(crate) use expect_event; + +pub(crate) struct TestStore { + persisted_bytes: RwLock>>>>>, + did_persist: Arc, +} + +impl TestStore { + pub fn new() -> Self { + let persisted_bytes = RwLock::new(HashMap::new()); + let did_persist = Arc::new(AtomicBool::new(false)); + Self { persisted_bytes, did_persist } + } + + pub fn get_persisted_bytes(&self, namespace: &str, key: &str) -> Option> { + if let Some(outer_ref) = self.persisted_bytes.read().unwrap().get(namespace) { + if let Some(inner_ref) = outer_ref.get(key) { + let locked = inner_ref.read().unwrap(); + return Some((*locked).clone()); + } + } + None + } + + pub fn get_and_clear_did_persist(&self) -> bool { + self.did_persist.swap(false, Ordering::Relaxed) + } +} + +impl KVStore for TestStore { + type Reader = TestReader; + type Writer = TestWriter; + + fn read(&self, namespace: &str, key: &str) -> std::io::Result { + if let Some(outer_ref) = self.persisted_bytes.read().unwrap().get(namespace) { + if let Some(inner_ref) = outer_ref.get(key) { + Ok(TestReader::new(Arc::clone(inner_ref))) + } else { + let msg = format!("Key not found: {}", key); + Err(std::io::Error::new(std::io::ErrorKind::NotFound, msg)) + } + } else { + let msg = format!("Namespace not found: {}", namespace); + Err(std::io::Error::new(std::io::ErrorKind::NotFound, msg)) + } + } + + fn write(&self, namespace: &str, key: &str) -> std::io::Result { + let mut guard = self.persisted_bytes.write().unwrap(); + let outer_e = guard.entry(namespace.to_string()).or_insert(HashMap::new()); + let inner_e = outer_e.entry(key.to_string()).or_insert(Arc::new(RwLock::new(Vec::new()))); + Ok(TestWriter::new(Arc::clone(&inner_e), Arc::clone(&self.did_persist))) + } + + fn remove(&self, namespace: &str, key: &str) -> std::io::Result { + match self.persisted_bytes.write().unwrap().entry(namespace.to_string()) { + hash_map::Entry::Occupied(mut e) => { + self.did_persist.store(true, Ordering::SeqCst); + Ok(e.get_mut().remove(&key.to_string()).is_some()) + } + hash_map::Entry::Vacant(_) => Ok(false), + } + } + + fn list(&self, namespace: &str) -> std::io::Result> { + match self.persisted_bytes.write().unwrap().entry(namespace.to_string()) { + hash_map::Entry::Occupied(e) => Ok(e.get().keys().cloned().collect()), + hash_map::Entry::Vacant(_) => Ok(Vec::new()), + } + } +} + +impl KVStorePersister for TestStore { + fn persist(&self, prefixed_key: &str, object: &W) -> std::io::Result<()> { + let msg = format!("Could not persist file for key {}.", prefixed_key); + let dest_file = PathBuf::from_str(prefixed_key).map_err(|_| { + lightning::io::Error::new(lightning::io::ErrorKind::InvalidInput, msg.clone()) + })?; + + let parent_directory = dest_file.parent().ok_or(lightning::io::Error::new( + lightning::io::ErrorKind::InvalidInput, + msg.clone(), + ))?; + let namespace = parent_directory.display().to_string(); + + let dest_without_namespace = dest_file + .strip_prefix(&namespace) + .map_err(|_| lightning::io::Error::new(lightning::io::ErrorKind::InvalidInput, msg))?; + let key = dest_without_namespace.display().to_string(); + let mut writer = self.write(&namespace, &key)?; + object.write(&mut writer)?; + writer.commit()?; + Ok(()) + } +} + +pub struct TestReader { + entry_ref: Arc>>, +} + +impl TestReader { + pub fn new(entry_ref: Arc>>) -> Self { + Self { entry_ref } + } +} + +impl Read for TestReader { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + let bytes = self.entry_ref.read().unwrap().clone(); + let mut reader = Cursor::new(bytes); + reader.read(buf) + } +} + +pub struct TestWriter { + tmp_inner: BufWriter>, + entry_ref: Arc>>, + did_persist: Arc, +} + +impl TestWriter { + pub fn new(entry_ref: Arc>>, did_persist: Arc) -> Self { + let tmp_inner = BufWriter::new(Vec::new()); + Self { tmp_inner, entry_ref, did_persist } + } +} + +impl Write for TestWriter { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.tmp_inner.write(buf) + } + + fn flush(&mut self) -> std::io::Result<()> { + self.tmp_inner.flush() + } +} + +impl TransactionalWrite for TestWriter { + fn commit(&mut self) -> std::io::Result<()> { + self.flush()?; + let bytes_ref = self.tmp_inner.get_ref(); + let mut guard = self.entry_ref.write().unwrap(); + guard.clone_from(bytes_ref); + self.did_persist.store(true, Ordering::SeqCst); + Ok(()) + } +} + +// Copied over from upstream LDK +#[allow(dead_code)] +pub struct TestLogger { + level: Level, + pub(crate) id: String, + pub lines: Mutex>, +} + +impl TestLogger { + #[allow(dead_code)] + pub fn new() -> TestLogger { + Self::with_id("".to_owned()) + } + + #[allow(dead_code)] + pub fn with_id(id: String) -> TestLogger { + TestLogger { level: Level::Trace, id, lines: Mutex::new(HashMap::new()) } + } + + #[allow(dead_code)] + pub fn enable(&mut self, level: Level) { + self.level = level; + } + + #[allow(dead_code)] + pub fn assert_log(&self, module: String, line: String, count: usize) { + let log_entries = self.lines.lock().unwrap(); + assert_eq!(log_entries.get(&(module, line)), Some(&count)); + } + + /// Search for the number of occurrence of the logged lines which + /// 1. belongs to the specified module and + /// 2. contains `line` in it. + /// And asserts if the number of occurrences is the same with the given `count` + #[allow(dead_code)] + pub fn assert_log_contains(&self, module: &str, line: &str, count: usize) { + let log_entries = self.lines.lock().unwrap(); + let l: usize = log_entries + .iter() + .filter(|&(&(ref m, ref l), _c)| m == module && l.contains(line)) + .map(|(_, c)| c) + .sum(); + assert_eq!(l, count) + } + + /// Search for the number of occurrences of logged lines which + /// 1. belong to the specified module and + /// 2. match the given regex pattern. + /// Assert that the number of occurrences equals the given `count` + #[allow(dead_code)] + pub fn assert_log_regex(&self, module: &str, pattern: regex::Regex, count: usize) { + let log_entries = self.lines.lock().unwrap(); + let l: usize = log_entries + .iter() + .filter(|&(&(ref m, ref l), _c)| m == module && pattern.is_match(&l)) + .map(|(_, c)| c) + .sum(); + assert_eq!(l, count) + } +} + +impl Logger for TestLogger { + fn log(&self, record: &Record) { + *self + .lines + .lock() + .unwrap() + .entry((record.module_path.to_string(), format!("{}", record.args))) + .or_insert(0) += 1; + if record.level >= self.level { + #[cfg(feature = "std")] + println!( + "{:<5} {} [{} : {}, {}] {}", + record.level.to_string(), + self.id, + record.module_path, + record.file, + record.line, + record.args + ); + } + } +} + +pub fn random_storage_path() -> String { + let mut rng = thread_rng(); + let rand_dir: String = (0..7).map(|_| rng.sample(Alphanumeric) as char).collect(); + format!("/tmp/{}", rand_dir) +} + +pub fn random_port() -> u16 { + let mut rng = thread_rng(); + rng.gen_range(5000..65535) +} + +pub fn random_config(esplora_url: &str) -> Config { + let mut config = Config::default(); + + println!("Setting esplora server URL: {}", esplora_url); + config.esplora_server_url = format!("http://{}", esplora_url); + + let rand_dir = random_storage_path(); + println!("Setting random LDK storage dir: {}", rand_dir); + config.storage_dir_path = rand_dir; + + let rand_port = random_port(); + println!("Setting random LDK listening port: {}", rand_port); + let listening_address = format!("127.0.0.1:{}", rand_port); + config.listening_address = Some(listening_address); + + config +} + +pub fn setup_bitcoind_and_electrsd() -> (BitcoinD, ElectrsD) { + let bitcoind_exe = + env::var("BITCOIND_EXE").ok().or_else(|| bitcoind::downloaded_exe_path().ok()).expect( + "you need to provide an env var BITCOIND_EXE or specify a bitcoind version feature", + ); + let mut bitcoind_conf = bitcoind::Conf::default(); + bitcoind_conf.network = "regtest"; + let bitcoind = BitcoinD::with_conf(bitcoind_exe, &bitcoind_conf).unwrap(); + + let electrs_exe = env::var("ELECTRS_EXE") + .ok() + .or_else(electrsd::downloaded_exe_path) + .expect("you need to provide env var ELECTRS_EXE or specify an electrsd version feature"); + let mut electrsd_conf = electrsd::Conf::default(); + electrsd_conf.http_enabled = true; + electrsd_conf.network = "regtest"; + let electrsd = ElectrsD::with_conf(electrs_exe, &bitcoind, &electrsd_conf).unwrap(); + (bitcoind, electrsd) +} + +pub fn generate_blocks_and_wait(bitcoind: &BitcoinD, electrsd: &ElectrsD, num: usize) { + let cur_height = bitcoind.client.get_block_count().expect("failed to get current block height"); + let address = bitcoind + .client + .get_new_address(Some("test"), Some(AddressType::Legacy)) + .expect("failed to get new address"); + // TODO: expect this Result once the WouldBlock issue is resolved upstream. + let _block_hashes_res = bitcoind.client.generate_to_address(num as u64, &address); + wait_for_block(electrsd, cur_height as usize + num); +} + +pub fn wait_for_block(electrsd: &ElectrsD, min_height: usize) { + let mut header = match electrsd.client.block_headers_subscribe() { + Ok(header) => header, + Err(_) => { + // While subscribing should succeed the first time around, we ran into some cases where + // it didn't. Since we can't proceed without subscribing, we try again after a delay + // and panic if it still fails. + std::thread::sleep(Duration::from_secs(1)); + electrsd.client.block_headers_subscribe().expect("failed to subscribe to block headers") + } + }; + loop { + if header.height >= min_height { + break; + } + header = exponential_backoff_poll(|| { + electrsd.trigger().expect("failed to trigger electrsd"); + electrsd.client.ping().expect("failed to ping electrsd"); + electrsd.client.block_headers_pop().expect("failed to pop block header") + }); + } +} + +pub fn wait_for_tx(electrsd: &ElectrsD, txid: Txid) { + let mut tx_res = electrsd.client.transaction_get(&txid); + loop { + if tx_res.is_ok() { + break; + } + tx_res = exponential_backoff_poll(|| { + electrsd.trigger().unwrap(); + electrsd.client.ping().unwrap(); + Some(electrsd.client.transaction_get(&txid)) + }); + } +} + +pub fn wait_for_outpoint_spend(electrsd: &ElectrsD, outpoint: OutPoint) { + let tx = electrsd.client.transaction_get(&outpoint.txid).unwrap(); + let txout_script = tx.output.get(outpoint.vout as usize).unwrap().clone().script_pubkey; + let mut is_spent = !electrsd.client.script_get_history(&txout_script).unwrap().is_empty(); + loop { + if is_spent { + break; + } + + is_spent = exponential_backoff_poll(|| { + electrsd.trigger().unwrap(); + electrsd.client.ping().unwrap(); + Some(!electrsd.client.script_get_history(&txout_script).unwrap().is_empty()) + }); + } +} + +pub fn exponential_backoff_poll(mut poll: F) -> T +where + F: FnMut() -> Option, +{ + let mut delay = Duration::from_millis(64); + let mut tries = 0; + loop { + match poll() { + Some(data) => break data, + None if delay.as_millis() < 512 => { + delay = delay.mul_f32(2.0); + } + + None => {} + } + assert!(tries < 10, "Reached max tries."); + tries += 1; + std::thread::sleep(delay); + } +} + +pub fn premine_and_distribute_funds( + bitcoind: &BitcoinD, electrsd: &ElectrsD, addrs: Vec
, amount: Amount, +) { + generate_blocks_and_wait(bitcoind, electrsd, 101); + + for addr in addrs { + let txid = bitcoind + .client + .send_to_address(&addr, amount, None, None, None, None, None, None) + .unwrap(); + wait_for_tx(electrsd, txid); + } + + generate_blocks_and_wait(bitcoind, electrsd, 1); +} diff --git a/src/tests/test_utils.rs b/src/tests/test_utils.rs deleted file mode 100644 index b4b832d45..000000000 --- a/src/tests/test_utils.rs +++ /dev/null @@ -1,55 +0,0 @@ -use lightning::util::persist::KVStorePersister; -use lightning::util::ser::Writeable; - -use std::collections::HashMap; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Mutex; - -macro_rules! expect_event { - ($node: expr, $event_type: ident) => {{ - match $node.next_event() { - ref e @ Event::$event_type { .. } => { - println!("{} got event {:?}", std::stringify!($node), e); - $node.event_handled(); - } - ref e => { - panic!("{} got unexpected event!: {:?}", std::stringify!($node), e); - } - } - }}; -} - -pub(crate) use expect_event; - -pub(crate) struct TestPersister { - persisted_bytes: Mutex>>, - did_persist: AtomicBool, -} - -impl TestPersister { - pub fn new() -> Self { - let persisted_bytes = Mutex::new(HashMap::new()); - let did_persist = AtomicBool::new(false); - Self { persisted_bytes, did_persist } - } - - pub fn get_persisted_bytes(&self, key: &str) -> Option> { - let persisted_bytes_lock = self.persisted_bytes.lock().unwrap(); - persisted_bytes_lock.get(key).cloned() - } - - pub fn get_and_clear_did_persist(&self) -> bool { - self.did_persist.swap(false, Ordering::SeqCst) - } -} - -impl KVStorePersister for TestPersister { - fn persist(&self, key: &str, object: &W) -> std::io::Result<()> { - let mut persisted_bytes_lock = self.persisted_bytes.lock().unwrap(); - let mut bytes = Vec::new(); - object.write(&mut bytes)?; - persisted_bytes_lock.insert(key.to_owned(), bytes); - self.did_persist.store(true, Ordering::SeqCst); - Ok(()) - } -} diff --git a/src/types.rs b/src/types.rs index deba8dd3e..457dab04e 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,56 +1,27 @@ +use crate::io::fs_store::FilesystemStore; use crate::logger::FilesystemLogger; use crate::wallet::{Wallet, WalletKeysManager}; use lightning::chain::chainmonitor; use lightning::chain::keysinterface::InMemorySigner; use lightning::ln::peer_handler::IgnoringMessageHandler; -use lightning::ln::{PaymentHash, PaymentPreimage, PaymentSecret}; use lightning::routing::gossip; use lightning::routing::gossip::P2PGossipSync; use lightning::routing::router::DefaultRouter; use lightning::routing::scoring::ProbabilisticScorer; use lightning::routing::utxo::UtxoLookup; use lightning_net_tokio::SocketDescriptor; -use lightning_persister::FilesystemPersister; use lightning_transaction_sync::EsploraSyncClient; -use std::collections::HashMap; use std::sync::{Arc, Mutex}; -// Structs wrapping the particular information which should easily be -// understandable, parseable, and transformable, i.e., we'll try to avoid -// exposing too many technical detail here. -/// Represents a payment. -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct PaymentInfo { - /// The pre-image used by the payment. - pub preimage: Option, - /// The secret used by the payment. - pub secret: Option, - /// The status of the payment. - pub status: PaymentStatus, - /// The amount transferred. - pub amount_msat: Option, -} - -/// Represents the current status of a payment. -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum PaymentStatus { - /// The payment is still pending. - Pending, - /// The payment suceeded. - Succeeded, - /// The payment failed. - Failed, -} - pub(crate) type ChainMonitor = chainmonitor::ChainMonitor< InMemorySigner, Arc>>, Arc>, Arc>, Arc, - Arc, + Arc, >; pub(crate) type PeerManager = lightning::ln::peer_handler::PeerManager< @@ -85,8 +56,6 @@ pub(crate) type GossipSync = pub(crate) type NetworkGraph = gossip::NetworkGraph>; -pub(crate) type PaymentInfoStorage = Mutex>; - pub(crate) type OnionMessenger = lightning::onion_message::OnionMessenger< Arc>, Arc>,