Skip to content

Commit e31163f

Browse files
committed
feat: add new protocols to router at runtime
1 parent 1859de3 commit e31163f

File tree

1 file changed

+189
-20
lines changed

1 file changed

+189
-20
lines changed

iroh/src/protocol.rs

Lines changed: 189 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ use n0_future::{
4040
task::{self, AbortOnDropHandle, JoinSet},
4141
};
4242
use snafu::{Backtrace, Snafu};
43-
use tokio::sync::Mutex;
43+
use tokio::sync::{mpsc, oneshot, Mutex};
4444
use tokio_util::sync::CancellationToken;
4545
use tracing::{error, info_span, trace, warn, Instrument};
4646

@@ -85,6 +85,19 @@ pub struct Router {
8585
// `Router` needs to be `Clone + Send`, and we need to `task.await` in its `shutdown()` impl.
8686
task: Arc<Mutex<Option<AbortOnDropHandle<()>>>>,
8787
cancel_token: CancellationToken,
88+
tx: mpsc::Sender<ToRouterTask>,
89+
}
90+
91+
enum ToRouterTask {
92+
Accept {
93+
alpn: Vec<u8>,
94+
handler: Arc<dyn DynProtocolHandler>,
95+
reply: oneshot::Sender<()>,
96+
},
97+
StopAccepting {
98+
alpn: Vec<u8>,
99+
reply: oneshot::Sender<()>,
100+
},
88101
}
89102

90103
/// Builder for creating a [`Router`] for accepting protocols.
@@ -116,6 +129,14 @@ pub enum AcceptError {
116129
},
117130
}
118131

132+
#[allow(missing_docs)]
133+
#[derive(Debug, Snafu)]
134+
#[non_exhaustive]
135+
pub enum RouterError {
136+
#[snafu(display("The router actor closed"))]
137+
Closed {},
138+
}
139+
119140
impl AcceptError {
120141
/// Creates a new user error from an arbitrary error type.
121142
pub fn from_err<T: std::error::Error + Send + Sync + 'static>(value: T) -> Self {
@@ -267,30 +288,40 @@ impl<P: ProtocolHandler> DynProtocolHandler for P {
267288

268289
/// A typed map of protocol handlers, mapping them from ALPNs.
269290
#[derive(Debug, Default)]
270-
pub(crate) struct ProtocolMap(BTreeMap<Vec<u8>, Box<dyn DynProtocolHandler>>);
291+
pub(crate) struct ProtocolMap(std::sync::Mutex<BTreeMap<Vec<u8>, Arc<dyn DynProtocolHandler>>>);
271292

272293
impl ProtocolMap {
273294
/// Returns the registered protocol handler for an ALPN as a [`Arc<dyn ProtocolHandler>`].
274-
pub(crate) fn get(&self, alpn: &[u8]) -> Option<&dyn DynProtocolHandler> {
275-
self.0.get(alpn).map(|p| &**p)
295+
pub(crate) fn get(&self, alpn: &[u8]) -> Option<Arc<dyn DynProtocolHandler>> {
296+
self.0.lock().expect("poisoned").get(alpn).cloned()
276297
}
277298

278299
/// Inserts a protocol handler.
279-
pub(crate) fn insert(&mut self, alpn: Vec<u8>, handler: impl ProtocolHandler) {
280-
let handler = Box::new(handler);
281-
self.0.insert(alpn, handler);
300+
pub(crate) fn insert(&self, alpn: Vec<u8>, handler: Arc<dyn DynProtocolHandler>) {
301+
self.0.lock().expect("poisoned").insert(alpn, handler);
302+
}
303+
304+
pub(crate) fn remove(&self, alpn: &[u8]) {
305+
self.0.lock().expect("poisoned").remove(alpn);
282306
}
283307

284308
/// Returns an iterator of all registered ALPN protocol identifiers.
285-
pub(crate) fn alpns(&self) -> impl Iterator<Item = &Vec<u8>> {
286-
self.0.keys()
309+
pub(crate) fn alpns(&self) -> Vec<Vec<u8>> {
310+
self.0.lock().expect("poisoned").keys().cloned().collect()
287311
}
288312

289313
/// Shuts down all protocol handlers.
290314
///
291315
/// Calls and awaits [`ProtocolHandler::shutdown`] for all registered handlers concurrently.
292316
pub(crate) async fn shutdown(&self) {
293-
let handlers = self.0.values().map(|p| p.shutdown());
317+
let handlers: Vec<_> = {
318+
let inner = self.0.lock().expect("poisoned");
319+
inner
320+
.values()
321+
.cloned()
322+
.map(|p| async move { p.shutdown().await })
323+
.collect()
324+
};
294325
join_all(handlers).await;
295326
}
296327
}
@@ -311,6 +342,47 @@ impl Router {
311342
self.cancel_token.is_cancelled()
312343
}
313344

345+
/// Add a protocol to the list of accepted protocols.
346+
///
347+
/// Configures the router to accept the [`ProtocolHandler`] when receiving a connection
348+
/// with this `alpn`.
349+
///
350+
/// Once the function yields, new connections with this `alpn` will be handled.
351+
pub async fn accept(
352+
&self,
353+
alpn: impl AsRef<[u8]>,
354+
handler: impl ProtocolHandler,
355+
) -> Result<(), RouterError> {
356+
let (reply, reply_rx) = oneshot::channel();
357+
self.tx
358+
.send(ToRouterTask::Accept {
359+
alpn: alpn.as_ref().to_vec(),
360+
handler: Arc::new(handler),
361+
reply,
362+
})
363+
.await
364+
.map_err(|_| RouterError::Closed {})?;
365+
reply_rx.await.map_err(|_| RouterError::Closed {})?;
366+
Ok(())
367+
}
368+
369+
/// Stops accepting a protocol.
370+
///
371+
/// Note that this has only an effect on new connections. Existing connections that were
372+
/// accepted with `alpn` won't be closed when calling [`Router::stop_accepting`].
373+
pub async fn stop_accepting(&self, alpn: impl AsRef<[u8]>) -> Result<(), RouterError> {
374+
let (reply, reply_rx) = oneshot::channel();
375+
self.tx
376+
.send(ToRouterTask::StopAccepting {
377+
alpn: alpn.as_ref().to_vec(),
378+
reply,
379+
})
380+
.await
381+
.map_err(|_| RouterError::Closed {})?;
382+
reply_rx.await.map_err(|_| RouterError::Closed {})?;
383+
Ok(())
384+
}
385+
314386
/// Shuts down the accept loop cleanly.
315387
///
316388
/// When this function returns, all [`ProtocolHandler`]s will be shutdown and
@@ -348,8 +420,9 @@ impl RouterBuilder {
348420

349421
/// Configures the router to accept the [`ProtocolHandler`] when receiving a connection
350422
/// with this `alpn`.
351-
pub fn accept(mut self, alpn: impl AsRef<[u8]>, handler: impl ProtocolHandler) -> Self {
352-
self.protocols.insert(alpn.as_ref().to_vec(), handler);
423+
pub fn accept(self, alpn: impl AsRef<[u8]>, handler: impl ProtocolHandler) -> Self {
424+
self.protocols
425+
.insert(alpn.as_ref().to_vec(), Arc::new(handler));
353426
self
354427
}
355428

@@ -361,14 +434,9 @@ impl RouterBuilder {
361434
/// Spawns an accept loop and returns a handle to it encapsulated as the [`Router`].
362435
pub fn spawn(self) -> Router {
363436
// Update the endpoint with our alpns.
364-
let alpns = self
365-
.protocols
366-
.alpns()
367-
.map(|alpn| alpn.to_vec())
368-
.collect::<Vec<_>>();
437+
self.endpoint.set_alpns(self.protocols.alpns());
369438

370439
let protocols = Arc::new(self.protocols);
371-
self.endpoint.set_alpns(alpns);
372440

373441
let mut join_set = JoinSet::new();
374442
let endpoint = self.endpoint.clone();
@@ -377,6 +445,8 @@ impl RouterBuilder {
377445
let cancel = CancellationToken::new();
378446
let cancel_token = cancel.clone();
379447

448+
let (tx, mut rx) = mpsc::channel(8);
449+
380450
let run_loop_fut = async move {
381451
// Make sure to cancel the token, if this future ever exits.
382452
let _cancel_guard = cancel_token.clone().drop_guard();
@@ -390,6 +460,20 @@ impl RouterBuilder {
390460
_ = cancel_token.cancelled() => {
391461
break;
392462
},
463+
Some(msg) = rx.recv() => {
464+
match msg {
465+
ToRouterTask::Accept { alpn, handler, reply } => {
466+
protocols.insert(alpn, handler);
467+
endpoint.set_alpns(protocols.alpns());
468+
reply.send(()).ok();
469+
}
470+
ToRouterTask::StopAccepting { alpn, reply } => {
471+
protocols.remove(&alpn);
472+
endpoint.set_alpns(protocols.alpns());
473+
reply.send(()).ok();
474+
}
475+
}
476+
}
393477
// handle task terminations and quit on panics.
394478
Some(res) = join_set.join_next() => {
395479
match res {
@@ -452,6 +536,7 @@ impl RouterBuilder {
452536
endpoint: self.endpoint,
453537
task: Arc::new(Mutex::new(Some(task))),
454538
cancel_token: cancel,
539+
tx,
455540
}
456541
}
457542
}
@@ -542,12 +627,16 @@ impl<P: ProtocolHandler + Clone> ProtocolHandler for AccessLimit<P> {
542627
mod tests {
543628
use std::{sync::Mutex, time::Duration};
544629

630+
use iroh_base::NodeAddr;
545631
use n0_snafu::{Result, ResultExt};
546632
use n0_watcher::Watcher;
547-
use quinn::ApplicationClose;
633+
use quinn::{ApplicationClose, TransportErrorCode};
548634

549635
use super::*;
550-
use crate::{endpoint::ConnectionError, RelayMode};
636+
use crate::{
637+
endpoint::{ConnectError, ConnectionError},
638+
RelayMode,
639+
};
551640

552641
#[tokio::test]
553642
async fn test_shutdown() -> Result {
@@ -674,4 +763,84 @@ mod tests {
674763
);
675764
Ok(())
676765
}
766+
767+
#[tokio::test]
768+
async fn test_add_and_remove_protocol() -> Result {
769+
async fn connect_assert_ok(endpoint: &Endpoint, addr: &NodeAddr, alpn: &[u8]) {
770+
let conn = endpoint
771+
.connect(addr.clone(), alpn)
772+
.await
773+
.expect("expected connection to succeed");
774+
let reason = conn.closed().await;
775+
assert!(matches!(reason,
776+
ConnectionError::ApplicationClosed(ApplicationClose { error_code, .. }) if error_code == 42u32.into()
777+
));
778+
}
779+
780+
async fn connect_assert_fail(endpoint: &Endpoint, addr: &NodeAddr, alpn: &[u8]) {
781+
let conn = endpoint.connect(addr.clone(), alpn).await;
782+
assert!(matches!(
783+
&conn,
784+
Err(ConnectError::Connection { source, .. })
785+
if matches!(
786+
source.as_ref(),
787+
ConnectionError::ConnectionClosed(frame)
788+
if frame.error_code == TransportErrorCode::crypto(rustls::AlertDescription::NoApplicationProtocol.into())
789+
)
790+
));
791+
}
792+
793+
#[derive(Debug, Clone, Default)]
794+
struct TestProtocol;
795+
796+
const ALPN_1: &[u8] = b"/iroh/test/1";
797+
const ALPN_2: &[u8] = b"/iroh/test/2";
798+
799+
impl ProtocolHandler for TestProtocol {
800+
async fn accept(&self, connection: Connection) -> Result<(), AcceptError> {
801+
connection.close(42u32.into(), b"bye");
802+
Ok(())
803+
}
804+
}
805+
806+
let server = Endpoint::builder()
807+
.relay_mode(RelayMode::Disabled)
808+
.bind()
809+
.await?;
810+
let router = Router::builder(server)
811+
.accept(ALPN_1, TestProtocol::default())
812+
.spawn();
813+
814+
let addr = router.endpoint().node_addr().initialized().await?;
815+
816+
let client = Endpoint::builder()
817+
.relay_mode(RelayMode::Disabled)
818+
.bind()
819+
.await?;
820+
821+
connect_assert_ok(&client, &addr, ALPN_1).await;
822+
connect_assert_fail(&client, &addr, ALPN_2).await;
823+
824+
router.stop_accepting(ALPN_1).await?;
825+
connect_assert_fail(&client, &addr, ALPN_1).await;
826+
connect_assert_fail(&client, &addr, ALPN_2).await;
827+
828+
router.accept(ALPN_2, TestProtocol).await?;
829+
connect_assert_fail(&client, &addr, ALPN_1).await;
830+
connect_assert_ok(&client, &addr, ALPN_2).await;
831+
832+
router.accept(ALPN_1, TestProtocol).await?;
833+
connect_assert_ok(&client, &addr, ALPN_1).await;
834+
connect_assert_ok(&client, &addr, ALPN_2).await;
835+
836+
router.stop_accepting(ALPN_2).await?;
837+
connect_assert_ok(&client, &addr, ALPN_1).await;
838+
connect_assert_fail(&client, &addr, ALPN_2).await;
839+
840+
router.stop_accepting(ALPN_1).await?;
841+
connect_assert_fail(&client, &addr, ALPN_1).await;
842+
connect_assert_fail(&client, &addr, ALPN_2).await;
843+
844+
Ok(())
845+
}
677846
}

0 commit comments

Comments
 (0)