diff --git a/iroh/src/protocol.rs b/iroh/src/protocol.rs index 529ad11870f..9396b56c9e9 100644 --- a/iroh/src/protocol.rs +++ b/iroh/src/protocol.rs @@ -32,17 +32,18 @@ //! } //! } //! ``` -use std::{collections::BTreeMap, future::Future, pin::Pin, sync::Arc}; +use std::{collections::BTreeMap, future::Future, pin::Pin, sync::Arc, time::Duration}; use iroh_base::NodeId; use n0_future::{ join_all, task::{self, AbortOnDropHandle, JoinSet}, + time, }; use snafu::{Backtrace, Snafu}; -use tokio::sync::Mutex; +use tokio::sync::{mpsc, oneshot, Mutex}; use tokio_util::sync::CancellationToken; -use tracing::{error, info_span, trace, warn, Instrument}; +use tracing::{debug, error, info_span, trace, warn, Instrument}; use crate::{ endpoint::{Connecting, Connection, RemoteNodeIdError}, @@ -85,6 +86,19 @@ pub struct Router { // `Router` needs to be `Clone + Send`, and we need to `task.await` in its `shutdown()` impl. task: Arc>>>, cancel_token: CancellationToken, + tx: mpsc::Sender, +} + +enum ToRouterTask { + Accept { + alpn: Vec, + handler: Arc, + reply: oneshot::Sender, + }, + StopAccepting { + alpn: Vec, + reply: oneshot::Sender>, + }, } /// Builder for creating a [`Router`] for accepting protocols. @@ -137,6 +151,37 @@ impl From for AcceptError { } } +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum RouterError { + #[snafu(display("Endpoint closed"))] + Closed {}, +} + +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[snafu(module)] +#[non_exhaustive] +pub enum StopAcceptingError { + #[snafu(display("Endpoint closed"))] + Closed {}, + #[snafu(display("The ALPN requested to be removed is not registered"))] + UnknownAlpn {}, +} + +/// Returned from [`Router::accept`] +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub enum AddProtocolOutcome { + /// The protocol handler has been newly inserted. + Inserted, + /// The protocol handler replaced a previously registered protocol handler. + Replaced, +} + +/// Timeout applied to [`ProtocolHandler::shutdown] futures. +const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(30); + /// Handler for incoming connections. /// /// A router accepts connections for arbitrary ALPN protocols. @@ -265,33 +310,58 @@ impl DynProtocolHandler for P { } } +async fn shutdown_timeout(alpn: Vec, handler: Arc) -> Option<()> { + if let Err(_elapsed) = time::timeout(SHUTDOWN_TIMEOUT, handler.shutdown()).await { + debug!( + alpn = String::from_utf8_lossy(&alpn).to_string(), + "Protocol handler exceeded the shutdown timeout and was aborted" + ); + None + } else { + Some(()) + } +} + /// A typed map of protocol handlers, mapping them from ALPNs. #[derive(Debug, Default)] -pub(crate) struct ProtocolMap(BTreeMap, Box>); +pub(crate) struct ProtocolMap(std::sync::RwLock, Arc>>); impl ProtocolMap { /// Returns the registered protocol handler for an ALPN as a [`Arc`]. - pub(crate) fn get(&self, alpn: &[u8]) -> Option<&dyn DynProtocolHandler> { - self.0.get(alpn).map(|p| &**p) + pub(crate) fn get(&self, alpn: &[u8]) -> Option> { + self.0.read().expect("poisoned").get(alpn).cloned() } /// Inserts a protocol handler. - pub(crate) fn insert(&mut self, alpn: Vec, handler: impl ProtocolHandler) { - let handler = Box::new(handler); - self.0.insert(alpn, handler); + pub(crate) fn insert( + &self, + alpn: Vec, + handler: Arc, + ) -> Option> { + self.0.write().expect("poisoned").insert(alpn, handler) + } + + pub(crate) fn remove(&self, alpn: &[u8]) -> Option> { + self.0.write().expect("poisoned").remove(alpn) } /// Returns an iterator of all registered ALPN protocol identifiers. - pub(crate) fn alpns(&self) -> impl Iterator> { - self.0.keys() + pub(crate) fn alpns(&self) -> Vec> { + self.0.read().expect("poisoned").keys().cloned().collect() } /// Shuts down all protocol handlers. /// /// Calls and awaits [`ProtocolHandler::shutdown`] for all registered handlers concurrently. pub(crate) async fn shutdown(&self) { - let handlers = self.0.values().map(|p| p.shutdown()); - join_all(handlers).await; + let mut futures = Vec::new(); + { + let mut inner = self.0.write().expect("poisoned"); + while let Some((alpn, handler)) = inner.pop_first() { + futures.push(shutdown_timeout(alpn, handler)); + } + } + join_all(futures).await; } } @@ -311,7 +381,55 @@ impl Router { self.cancel_token.is_cancelled() } - /// Shuts down the accept loop cleanly. + /// Accepts incoming connections with this `alpn` via [`ProtocolHandler`]. + /// + /// After this function returns, new connections with this `alpn` will be handled + /// by the passed `handler`. + /// + /// If a protocol handler was already registered for `alpn`, the previous handler will be + /// shutdown. Existing connections will not be aborted by the router, but some protocol + /// handlers may abort existing connections in their [`Router::shutdown`] implementation. + /// Consult the documentation of the protocol handler to see if that is the case. + pub async fn accept( + &self, + alpn: impl AsRef<[u8]>, + handler: impl ProtocolHandler, + ) -> Result { + let (reply, reply_rx) = oneshot::channel(); + self.tx + .send(ToRouterTask::Accept { + alpn: alpn.as_ref().to_vec(), + handler: Arc::new(handler), + reply, + }) + .await + .map_err(|_| RouterError::Closed {})?; + reply_rx.await.map_err(|_| RouterError::Closed {}) + } + + /// Stops accepting connections with this `alpn`. + /// + /// After this function returns, new connections with `alpn` will no longer be accepted. + /// + /// If a protocol handler was registered for `alpn`, the handler will be + /// shutdown. Existing connections will not be aborted by the router, but some protocol + /// handlers may abort existing connections in their [`Router::shutdown`] implementation. + /// Consult the documentation of the protocol handler to see if that is the case. + /// + /// Returns an error if the router has been shutdown or no protocol is registered for `alpn`. + pub async fn stop_accepting(&self, alpn: impl AsRef<[u8]>) -> Result<(), StopAcceptingError> { + let (reply, reply_rx) = oneshot::channel(); + self.tx + .send(ToRouterTask::StopAccepting { + alpn: alpn.as_ref().to_vec(), + reply, + }) + .await + .map_err(|_| StopAcceptingError::Closed {})?; + reply_rx.await.map_err(|_| StopAcceptingError::Closed {})? + } + + /// Shuts down the accept loop and endpoint cleanly. /// /// When this function returns, all [`ProtocolHandler`]s will be shutdown and /// `Endpoint::close` will have been called. @@ -346,10 +464,10 @@ impl RouterBuilder { } } - /// Configures the router to accept the [`ProtocolHandler`] when receiving a connection - /// with this `alpn`. - pub fn accept(mut self, alpn: impl AsRef<[u8]>, handler: impl ProtocolHandler) -> Self { - self.protocols.insert(alpn.as_ref().to_vec(), handler); + /// Configures the router to accept incoming connections with this `alpn` via [`ProtocolHandler`]. + pub fn accept(self, alpn: impl AsRef<[u8]>, handler: impl ProtocolHandler) -> Self { + self.protocols + .insert(alpn.as_ref().to_vec(), Arc::new(handler)); self } @@ -361,14 +479,9 @@ impl RouterBuilder { /// Spawns an accept loop and returns a handle to it encapsulated as the [`Router`]. pub fn spawn(self) -> Router { // Update the endpoint with our alpns. - let alpns = self - .protocols - .alpns() - .map(|alpn| alpn.to_vec()) - .collect::>(); + self.endpoint.set_alpns(self.protocols.alpns()); let protocols = Arc::new(self.protocols); - self.endpoint.set_alpns(alpns); let mut join_set = JoinSet::new(); let endpoint = self.endpoint.clone(); @@ -377,6 +490,8 @@ impl RouterBuilder { let cancel = CancellationToken::new(); let cancel_token = cancel.clone(); + let (tx, mut rx) = mpsc::channel(8); + let run_loop_fut = async move { // Make sure to cancel the token, if this future ever exits. let _cancel_guard = cancel_token.clone().drop_guard(); @@ -390,6 +505,29 @@ impl RouterBuilder { _ = cancel_token.cancelled() => { break; }, + Some(msg) = rx.recv() => { + match msg { + ToRouterTask::Accept { alpn, handler, reply } => { + let outcome = if let Some(previous) = protocols.insert(alpn.clone(), handler) { + join_set.spawn(shutdown_timeout(alpn, previous)); + AddProtocolOutcome::Replaced + } else { + AddProtocolOutcome::Inserted + }; + endpoint.set_alpns(protocols.alpns()); + reply.send(outcome).ok(); + } + ToRouterTask::StopAccepting { alpn, reply } => { + if let Some(handler) = protocols.remove(&alpn) { + join_set.spawn(shutdown_timeout(alpn, handler)); + endpoint.set_alpns(protocols.alpns()); + reply.send(Ok(())).ok(); + } else { + reply.send(Err(StopAcceptingError::UnknownAlpn {})).ok(); + } + } + } + } // handle task terminations and quit on panics. Some(res) = join_set.join_next() => { match res { @@ -436,7 +574,7 @@ impl RouterBuilder { endpoint.close().await; // Finally, we abort the remaining accept tasks. This should be a noop because we already cancelled // the futures above. - tracing::debug!("Shutting down remaining tasks"); + debug!("Shutting down remaining tasks"); join_set.abort_all(); while let Some(res) = join_set.join_next().await { match res { @@ -452,6 +590,7 @@ impl RouterBuilder { endpoint: self.endpoint, task: Arc::new(Mutex::new(Some(task))), cancel_token: cancel, + tx, } } } @@ -542,12 +681,16 @@ impl ProtocolHandler for AccessLimit

{ mod tests { use std::{sync::Mutex, time::Duration}; + use iroh_base::NodeAddr; use n0_snafu::{Result, ResultExt}; use n0_watcher::Watcher; - use quinn::ApplicationClose; + use quinn::{ApplicationClose, TransportErrorCode}; use super::*; - use crate::{endpoint::ConnectionError, RelayMode}; + use crate::{ + endpoint::{ConnectError, ConnectionError}, + RelayMode, + }; #[tokio::test] async fn test_shutdown() -> Result { @@ -674,4 +817,104 @@ mod tests { ); Ok(()) } + + #[tokio::test] + async fn test_add_and_remove_protocol() -> Result { + async fn connect_assert_ok( + endpoint: &Endpoint, + addr: &NodeAddr, + alpn: &[u8], + expected_code: u32, + ) { + let conn = endpoint + .connect(addr.clone(), alpn) + .await + .expect("expected connection to succeed"); + let reason = conn.closed().await; + assert!(matches!(reason, + ConnectionError::ApplicationClosed(ApplicationClose { error_code, .. }) if error_code == expected_code.into() + )); + } + + async fn connect_assert_fail(endpoint: &Endpoint, addr: &NodeAddr, alpn: &[u8]) { + let conn = endpoint.connect(addr.clone(), alpn).await; + assert!(matches!( + &conn, + Err(ConnectError::Connection { source, .. }) + if matches!( + source.as_ref(), + ConnectionError::ConnectionClosed(frame) + if frame.error_code == TransportErrorCode::crypto(rustls::AlertDescription::NoApplicationProtocol.into()) + ) + )); + } + + #[derive(Debug, Clone, Default)] + struct TestProtocol(u32); + + const ALPN_1: &[u8] = b"/iroh/test/1"; + const ALPN_2: &[u8] = b"/iroh/test/2"; + + impl ProtocolHandler for TestProtocol { + async fn accept(&self, connection: Connection) -> Result<(), AcceptError> { + connection.close(self.0.into(), b"bye"); + Ok(()) + } + } + + let server = Endpoint::builder() + .relay_mode(RelayMode::Disabled) + .bind() + .await?; + let router = Router::builder(server) + .accept(ALPN_1, TestProtocol(1)) + .spawn(); + + let addr = router.endpoint().node_addr().initialized().await?; + + let client = Endpoint::builder() + .relay_mode(RelayMode::Disabled) + .bind() + .await?; + + connect_assert_ok(&client, &addr, ALPN_1, 1).await; + connect_assert_fail(&client, &addr, ALPN_2).await; + + router.stop_accepting(ALPN_1).await?; + connect_assert_fail(&client, &addr, ALPN_1).await; + connect_assert_fail(&client, &addr, ALPN_2).await; + + let outcome = router.accept(ALPN_2, TestProtocol(2)).await?; + assert_eq!(outcome, AddProtocolOutcome::Inserted); + connect_assert_fail(&client, &addr, ALPN_1).await; + connect_assert_ok(&client, &addr, ALPN_2, 2).await; + + let outcome = router.accept(ALPN_1, TestProtocol(3)).await?; + assert_eq!(outcome, AddProtocolOutcome::Inserted); + connect_assert_ok(&client, &addr, ALPN_1, 3).await; + connect_assert_ok(&client, &addr, ALPN_2, 2).await; + + let outcome = router.accept(ALPN_1, TestProtocol(4)).await?; + assert_eq!(outcome, AddProtocolOutcome::Replaced); + connect_assert_ok(&client, &addr, ALPN_1, 4).await; + + router.stop_accepting(ALPN_2).await?; + connect_assert_ok(&client, &addr, ALPN_1, 4).await; + connect_assert_fail(&client, &addr, ALPN_2).await; + + router.stop_accepting(ALPN_1).await?; + connect_assert_fail(&client, &addr, ALPN_1).await; + connect_assert_fail(&client, &addr, ALPN_2).await; + + assert!(matches!( + router.stop_accepting(ALPN_1).await, + Err(StopAcceptingError::UnknownAlpn {}) + )); + assert!(matches!( + router.stop_accepting(ALPN_2).await, + Err(StopAcceptingError::UnknownAlpn {}) + )); + + Ok(()) + } }