diff --git a/Cargo.toml b/Cargo.toml index 5b7eca8e3..e88826aee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -66,7 +66,6 @@ bip39 = "2.0.0" rand = "0.8.5" chrono = { version = "0.4", default-features = false, features = ["clock"] } -futures = "0.3" tokio = { version = "1", default-features = false, features = [ "rt-multi-thread", "time", "sync" ] } esplora-client = { version = "0.6", default-features = false } libc = "0.2" diff --git a/src/builder.rs b/src/builder.rs index 5edbd55ab..6b4da6b57 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -2,6 +2,7 @@ use crate::config::{ Config, BDK_CLIENT_CONCURRENCY, BDK_CLIENT_STOP_GAP, DEFAULT_ESPLORA_SERVER_URL, WALLET_KEYS_SEED_LEN, }; +use crate::connection::ConnectionManager; use crate::event::EventQueue; use crate::fee_estimator::OnchainFeeEstimator; use crate::gossip::GossipSource; @@ -895,6 +896,9 @@ fn build_with_store_internal( liquidity_source.as_ref().map(|l| l.set_peer_manager(Arc::clone(&peer_manager))); + let connection_manager = + Arc::new(ConnectionManager::new(Arc::clone(&peer_manager), Arc::clone(&logger))); + let output_sweeper = match io::utils::read_output_sweeper( Arc::clone(&tx_broadcaster), Arc::clone(&fee_estimator), @@ -991,6 +995,7 @@ fn build_with_store_internal( chain_monitor, output_sweeper, peer_manager, + connection_manager, keys_manager, network_graph, gossip_source, diff --git a/src/connection.rs b/src/connection.rs new file mode 100644 index 000000000..9d956d6be --- /dev/null +++ b/src/connection.rs @@ -0,0 +1,147 @@ +use crate::logger::{log_error, log_info, Logger}; +use crate::types::PeerManager; +use crate::Error; + +use lightning::ln::msgs::SocketAddress; + +use bitcoin::secp256k1::PublicKey; + +use std::collections::hash_map::{self, HashMap}; +use std::net::ToSocketAddrs; +use std::ops::Deref; +use std::sync::{Arc, Mutex}; +use std::time::Duration; + +pub(crate) struct ConnectionManager +where + L::Target: Logger, +{ + pending_connections: + Mutex>>>>, + peer_manager: Arc, + logger: 
L,
+}
+
+impl<L: Deref + Clone + Sync + Send> ConnectionManager<L>
+where
+	L::Target: Logger,
+{
+	pub(crate) fn new(peer_manager: Arc<PeerManager>, logger: L) -> Self {
+		let pending_connections = Mutex::new(HashMap::new());
+		Self { pending_connections, peer_manager, logger }
+	}
+
+	pub(crate) async fn connect_peer_if_necessary(
+		&self, node_id: PublicKey, addr: SocketAddress,
+	) -> Result<(), Error> {
+		if self.peer_manager.peer_by_node_id(&node_id).is_some() {
+			return Ok(());
+		}
+
+		self.do_connect_peer(node_id, addr).await
+	}
+
+	pub(crate) async fn do_connect_peer(
+		&self, node_id: PublicKey, addr: SocketAddress,
+	) -> Result<(), Error> {
+		// First, we check if there is already an outbound connection in flight. If so, we just
+		// await on the corresponding oneshot channel: the task driving the connection future will
+		// send us the result.
+		let pending_ready_receiver_opt = self.register_or_subscribe_pending_connection(&node_id);
+		if let Some(pending_connection_ready_receiver) = pending_ready_receiver_opt {
+			return pending_connection_ready_receiver.await.map_err(|e| {
+				debug_assert!(false, "Failed to receive connection result: {:?}", e);
+				log_error!(self.logger, "Failed to receive connection result: {:?}", e);
+				Error::ConnectionFailed
+			})?;
+		}
+
+		log_info!(self.logger, "Connecting to peer: {}@{}", node_id, addr);
+
+		let socket_addr = addr
+			.to_socket_addrs()
+			.map_err(|e| {
+				log_error!(self.logger, "Failed to resolve network address {}: {}", addr, e);
+				self.propagate_result_to_subscribers(&node_id, Err(Error::InvalidSocketAddress));
+				Error::InvalidSocketAddress
+			})?
+			.next()
+			.ok_or_else(|| {
+				log_error!(self.logger, "Failed to resolve network address {}", addr);
+				self.propagate_result_to_subscribers(&node_id, Err(Error::InvalidSocketAddress));
+				Error::InvalidSocketAddress
+			})?;
+
+		let connection_future = lightning_net_tokio::connect_outbound(
+			Arc::clone(&self.peer_manager),
+			node_id,
+			socket_addr,
+		);
+
+		let res = match connection_future.await {
+			Some(connection_closed_future) => {
+				let mut connection_closed_future = Box::pin(connection_closed_future);
+				loop {
+					tokio::select! {
+						_ = &mut connection_closed_future => {
+							log_info!(self.logger, "Peer connection closed: {}@{}", node_id, addr);
+							break Err(Error::ConnectionFailed);
+						},
+						_ = tokio::time::sleep(Duration::from_millis(10)) => {},
+					};
+
+					match self.peer_manager.peer_by_node_id(&node_id) {
+						Some(_) => break Ok(()),
+						None => continue,
+					}
+				}
+			},
+			None => {
+				log_error!(self.logger, "Failed to connect to peer: {}@{}", node_id, addr);
+				Err(Error::ConnectionFailed)
+			},
+		};
+
+		self.propagate_result_to_subscribers(&node_id, res);
+
+		res
+	}
+
+	fn register_or_subscribe_pending_connection(
+		&self, node_id: &PublicKey,
+	) -> Option<tokio::sync::oneshot::Receiver<Result<(), Error>>> {
+		let mut pending_connections_lock = self.pending_connections.lock().unwrap();
+		match pending_connections_lock.entry(*node_id) {
+			hash_map::Entry::Occupied(mut entry) => {
+				let (tx, rx) = tokio::sync::oneshot::channel();
+				entry.get_mut().push(tx);
+				Some(rx)
+			},
+			hash_map::Entry::Vacant(entry) => {
+				entry.insert(Vec::new());
+				None
+			},
+		}
+	}
+
+	fn propagate_result_to_subscribers(&self, node_id: &PublicKey, res: Result<(), Error>) {
+		// Send the result to any other tasks that might be waiting on it by now.
+ let mut pending_connections_lock = self.pending_connections.lock().unwrap(); + if let Some(connection_ready_senders) = pending_connections_lock.remove(node_id) { + for sender in connection_ready_senders { + let _ = sender.send(res).map_err(|e| { + debug_assert!( + false, + "Failed to send connection result to subscribers: {:?}", + e + ); + log_error!( + self.logger, + "Failed to send connection result to subscribers: {:?}", + e + ); + }); + } + } + } +} diff --git a/src/error.rs b/src/error.rs index 0182b3092..c5234a6d4 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,6 +1,6 @@ use std::fmt; -#[derive(Debug, PartialEq, Eq)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] /// An error that possibly needs to be handled by the user. pub enum Error { /// Returned when trying to start [`crate::Node`] while it is already running. diff --git a/src/event.rs b/src/event.rs index 61dc748d4..cd11e41a8 100644 --- a/src/event.rs +++ b/src/event.rs @@ -291,7 +291,7 @@ impl Future for EventFuture { } } -pub(crate) struct EventHandler +pub(crate) struct EventHandler where L::Target: Logger, { @@ -307,7 +307,7 @@ where config: Arc, } -impl EventHandler +impl EventHandler where L::Target: Logger, { diff --git a/src/lib.rs b/src/lib.rs index 3f240e980..f6082d4d3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -78,6 +78,7 @@ mod balance; mod builder; mod config; +mod connection; mod error; mod event; mod fee_estimator; @@ -124,6 +125,7 @@ use config::{ LDK_PAYMENT_RETRY_TIMEOUT, NODE_ANN_BCAST_INTERVAL, PEER_RECONNECTION_INTERVAL, RGS_SYNC_INTERVAL, WALLET_SYNC_INTERVAL_MINIMUM_SECS, }; +use connection::ConnectionManager; use event::{EventHandler, EventQueue}; use gossip::GossipSource; use liquidity::LiquiditySource; @@ -187,6 +189,7 @@ pub struct Node { chain_monitor: Arc, output_sweeper: Arc, peer_manager: Arc, + connection_manager: Arc>>, keys_manager: Arc, network_graph: Arc, gossip_source: Arc, @@ -498,6 +501,7 @@ impl Node { } // Regularly reconnect to persisted peers. 
+ let connect_cm = Arc::clone(&self.connection_manager); let connect_pm = Arc::clone(&self.peer_manager); let connect_logger = Arc::clone(&self.logger); let connect_peer_store = Arc::clone(&self.peer_store); @@ -518,11 +522,9 @@ impl Node { .collect::>(); for peer_info in connect_peer_store.list_peers().iter().filter(|info| !pm_peers.contains(&info.node_id)) { - let res = do_connect_peer( + let res = connect_cm.do_connect_peer( peer_info.node_id, peer_info.address.clone(), - Arc::clone(&connect_pm), - Arc::clone(&connect_logger), ).await; match res { Ok(_) => { @@ -871,14 +873,13 @@ impl Node { let con_node_id = peer_info.node_id; let con_addr = peer_info.address.clone(); - let con_logger = Arc::clone(&self.logger); - let con_pm = Arc::clone(&self.peer_manager); + let con_cm = Arc::clone(&self.connection_manager); // We need to use our main runtime here as a local runtime might not be around to poll // connection futures going forward. tokio::task::block_in_place(move || { runtime.block_on(async move { - connect_peer_if_necessary(con_node_id, con_addr, con_pm, con_logger).await + con_cm.connect_peer_if_necessary(con_node_id, con_addr).await }) })?; @@ -944,14 +945,13 @@ impl Node { let con_node_id = peer_info.node_id; let con_addr = peer_info.address.clone(); - let con_logger = Arc::clone(&self.logger); - let con_pm = Arc::clone(&self.peer_manager); + let con_cm = Arc::clone(&self.connection_manager); // We need to use our main runtime here as a local runtime might not be around to poll // connection futures going forward. 
tokio::task::block_in_place(move || { runtime.block_on(async move { - connect_peer_if_necessary(con_node_id, con_addr, con_pm, con_logger).await + con_cm.connect_peer_if_necessary(con_node_id, con_addr).await }) })?; @@ -1601,14 +1601,13 @@ impl Node { let con_node_id = peer_info.node_id; let con_addr = peer_info.address.clone(); - let con_logger = Arc::clone(&self.logger); - let con_pm = Arc::clone(&self.peer_manager); + let con_cm = Arc::clone(&self.connection_manager); // We need to use our main runtime here as a local runtime might not be around to poll // connection futures going forward. tokio::task::block_in_place(move || { runtime.block_on(async move { - connect_peer_if_necessary(con_node_id, con_addr, con_pm, con_logger).await + con_cm.connect_peer_if_necessary(con_node_id, con_addr).await }) })?; @@ -1849,56 +1848,3 @@ pub struct NodeStatus { /// Will be `None` if we have no public channels or we haven't broadcasted since the [`Node`] was initialized. pub latest_node_announcement_broadcast_timestamp: Option, } - -async fn connect_peer_if_necessary( - node_id: PublicKey, addr: SocketAddress, peer_manager: Arc, - logger: Arc, -) -> Result<(), Error> { - if peer_manager.peer_by_node_id(&node_id).is_some() { - return Ok(()); - } - - do_connect_peer(node_id, addr, peer_manager, logger).await -} - -async fn do_connect_peer( - node_id: PublicKey, addr: SocketAddress, peer_manager: Arc, - logger: Arc, -) -> Result<(), Error> { - log_info!(logger, "Connecting to peer: {}@{}", node_id, addr); - - let socket_addr = addr - .to_socket_addrs() - .map_err(|e| { - log_error!(logger, "Failed to resolve network address: {}", e); - Error::InvalidSocketAddress - })? 
- .next() - .ok_or(Error::ConnectionFailed)?; - - match lightning_net_tokio::connect_outbound(Arc::clone(&peer_manager), node_id, socket_addr) - .await - { - Some(connection_closed_future) => { - let mut connection_closed_future = Box::pin(connection_closed_future); - loop { - match futures::poll!(&mut connection_closed_future) { - std::task::Poll::Ready(_) => { - log_info!(logger, "Peer connection closed: {}@{}", node_id, addr); - return Err(Error::ConnectionFailed); - }, - std::task::Poll::Pending => {}, - } - // Avoid blocking the tokio context by sleeping a bit - match peer_manager.peer_by_node_id(&node_id) { - Some(_) => return Ok(()), - None => tokio::time::sleep(Duration::from_millis(10)).await, - } - } - }, - None => { - log_error!(logger, "Failed to connect to peer: {}@{}", node_id, addr); - Err(Error::ConnectionFailed) - }, - } -} diff --git a/tests/integration_tests_rust.rs b/tests/integration_tests_rust.rs index 71867f8f2..f0e222fd3 100644 --- a/tests/integration_tests_rust.rs +++ b/tests/integration_tests_rust.rs @@ -333,3 +333,33 @@ fn do_connection_restart_behavior(persist: bool) { assert!(node_b.list_peers().is_empty()); } } + +#[test] +fn concurrent_connections_succeed() { + let (_bitcoind, electrsd) = setup_bitcoind_and_electrsd(); + let (node_a, node_b) = setup_two_nodes(&electrsd, false); + + let node_a = Arc::new(node_a); + let node_b = Arc::new(node_b); + + let node_id_b = node_b.node_id(); + let node_addr_b = node_b.listening_addresses().unwrap().first().unwrap().clone(); + + while !node_b.status().is_listening { + std::thread::sleep(std::time::Duration::from_millis(10)); + } + + let mut handles = Vec::new(); + for _ in 0..10 { + let thread_node = Arc::clone(&node_a); + let thread_addr = node_addr_b.clone(); + let handle = std::thread::spawn(move || { + thread_node.connect(node_id_b, thread_addr, false).unwrap(); + }); + handles.push(handle); + } + + for h in handles { + h.join().unwrap(); + } +}