Skip to content

Commit 12ad42b

Browse files
committed
refactor: cleanup and return types
1 parent 7116a98 commit 12ad42b

File tree

1 file changed

+110
-45
lines changed

1 file changed

+110
-45
lines changed

iroh/src/protocol.rs

Lines changed: 110 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,13 @@
3232
//! }
3333
//! }
3434
//! ```
35-
use std::{collections::BTreeMap, future::Future, pin::Pin, sync::Arc};
35+
use std::{collections::BTreeMap, future::Future, pin::Pin, sync::Arc, time::Duration};
3636

3737
use iroh_base::NodeId;
3838
use n0_future::{
3939
join_all,
4040
task::{self, AbortOnDropHandle, JoinSet},
41+
time,
4142
};
4243
use snafu::{Backtrace, Snafu};
4344
use tokio::sync::{mpsc, oneshot, Mutex};
@@ -92,11 +93,11 @@ enum ToRouterTask {
9293
Accept {
9394
alpn: Vec<u8>,
9495
handler: Arc<dyn DynProtocolHandler>,
95-
reply: oneshot::Sender<()>,
96+
reply: oneshot::Sender<AddProtocolOutcome>,
9697
},
9798
StopAccepting {
9899
alpn: Vec<u8>,
99-
reply: oneshot::Sender<()>,
100+
reply: oneshot::Sender<Result<(), StopAcceptingError>>,
100101
},
101102
}
102103

@@ -129,14 +130,6 @@ pub enum AcceptError {
129130
},
130131
}
131132

132-
#[allow(missing_docs)]
133-
#[derive(Debug, Snafu)]
134-
#[non_exhaustive]
135-
pub enum RouterError {
136-
#[snafu(display("The router actor closed"))]
137-
Closed {},
138-
}
139-
140133
impl AcceptError {
141134
/// Creates a new user error from an arbitrary error type.
142135
pub fn from_err<T: std::error::Error + Send + Sync + 'static>(value: T) -> Self {
@@ -158,6 +151,37 @@ impl From<quinn::ClosedStream> for AcceptError {
158151
}
159152
}
160153

154+
#[allow(missing_docs)]
155+
#[derive(Debug, Snafu)]
156+
#[non_exhaustive]
157+
pub enum RouterError {
158+
#[snafu(display("The router actor closed"))]
159+
Closed {},
160+
}
161+
162+
#[allow(missing_docs)]
163+
#[derive(Debug, Snafu)]
164+
#[snafu(module)]
165+
#[non_exhaustive]
166+
pub enum StopAcceptingError {
167+
#[snafu(display("The router actor closed"))]
168+
Closed {},
169+
#[snafu(display("The ALPN requested to be removed is not registered"))]
170+
UnknownAlpn {},
171+
}
172+
173+
/// Returned from [`Router::accept`]
174+
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
175+
pub enum AddProtocolOutcome {
176+
/// The protocol handler has been newly inserted.
177+
Inserted,
178+
/// The protocol handler replaced a previously registered protocol handler.
179+
Replaced,
180+
}
181+
182+
/// Timeout applied to [`ProtocolHandler::shutdown] futures.
183+
const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(30);
184+
161185
/// Handler for incoming connections.
162186
///
163187
/// A router accepts connections for arbitrary ALPN protocols.
@@ -286,40 +310,50 @@ impl<P: ProtocolHandler> DynProtocolHandler for P {
286310
}
287311
}
288312

313+
async fn shutdown_timeout(handler: Arc<dyn DynProtocolHandler>) -> Option<()> {
314+
time::timeout(SHUTDOWN_TIMEOUT, handler.shutdown())
315+
.await
316+
.ok()
317+
}
318+
289319
/// A typed map of protocol handlers, mapping them from ALPNs.
290320
#[derive(Debug, Default)]
291-
pub(crate) struct ProtocolMap(std::sync::Mutex<BTreeMap<Vec<u8>, Arc<dyn DynProtocolHandler>>>);
321+
pub(crate) struct ProtocolMap(std::sync::RwLock<BTreeMap<Vec<u8>, Arc<dyn DynProtocolHandler>>>);
292322

293323
impl ProtocolMap {
294324
/// Returns the registered protocol handler for an ALPN as a [`Arc<dyn ProtocolHandler>`].
295325
pub(crate) fn get(&self, alpn: &[u8]) -> Option<Arc<dyn DynProtocolHandler>> {
296-
self.0.lock().expect("poisoned").get(alpn).cloned()
326+
self.0.read().expect("poisoned").get(alpn).cloned()
297327
}
298328

299329
/// Inserts a protocol handler.
300-
pub(crate) fn insert(&self, alpn: Vec<u8>, handler: Arc<dyn DynProtocolHandler>) {
301-
self.0.lock().expect("poisoned").insert(alpn, handler);
330+
pub(crate) fn insert(
331+
&self,
332+
alpn: Vec<u8>,
333+
handler: Arc<dyn DynProtocolHandler>,
334+
) -> Option<Arc<dyn DynProtocolHandler>> {
335+
self.0.write().expect("poisoned").insert(alpn, handler)
302336
}
303337

304-
pub(crate) fn remove(&self, alpn: &[u8]) {
305-
self.0.lock().expect("poisoned").remove(alpn);
338+
pub(crate) fn remove(&self, alpn: &[u8]) -> Option<Arc<dyn DynProtocolHandler>> {
339+
self.0.write().expect("poisoned").remove(alpn)
306340
}
307341

308342
/// Returns an iterator of all registered ALPN protocol identifiers.
309343
pub(crate) fn alpns(&self) -> Vec<Vec<u8>> {
310-
self.0.lock().expect("poisoned").keys().cloned().collect()
344+
self.0.read().expect("poisoned").keys().cloned().collect()
311345
}
312346

313347
/// Shuts down all protocol handlers.
314348
///
315349
/// Calls and awaits [`ProtocolHandler::shutdown`] for all registered handlers concurrently.
316350
pub(crate) async fn shutdown(&self) {
317351
let handlers: Vec<_> = {
318-
let inner = self.0.lock().expect("poisoned");
352+
let inner = self.0.read().expect("poisoned");
319353
inner
320354
.values()
321355
.cloned()
322-
.map(|p| async move { p.shutdown().await })
356+
.map(|handler| shutdown_timeout(handler))
323357
.collect()
324358
};
325359
join_all(handlers).await;
@@ -342,17 +376,21 @@ impl Router {
342376
self.cancel_token.is_cancelled()
343377
}
344378

345-
/// Add a protocol to the list of accepted protocols.
379+
/// Adds a protocol to the list of accepted protocols.
346380
///
347381
/// Configures the router to accept the [`ProtocolHandler`] when receiving a connection
348382
/// with this `alpn`.
349383
///
350384
/// Once the function yields, new connections with this `alpn` will be handled.
385+
///
386+
/// If a protocol handler was already registered for `alpn`, the previous handler will be shutdown.
387+
///
388+
/// Returns `true` if
351389
pub async fn accept(
352390
&self,
353391
alpn: impl AsRef<[u8]>,
354392
handler: impl ProtocolHandler,
355-
) -> Result<(), RouterError> {
393+
) -> Result<AddProtocolOutcome, RouterError> {
356394
let (reply, reply_rx) = oneshot::channel();
357395
self.tx
358396
.send(ToRouterTask::Accept {
@@ -362,25 +400,23 @@ impl Router {
362400
})
363401
.await
364402
.map_err(|_| RouterError::Closed {})?;
365-
reply_rx.await.map_err(|_| RouterError::Closed {})?;
366-
Ok(())
403+
reply_rx.await.map_err(|_| RouterError::Closed {})
367404
}
368405

369406
/// Stops accepting a protocol.
370407
///
371408
/// Note that this has only an effect on new connections. Existing connections that were
372409
/// accepted with `alpn` won't be closed when calling [`Router::stop_accepting`].
373-
pub async fn stop_accepting(&self, alpn: impl AsRef<[u8]>) -> Result<(), RouterError> {
410+
pub async fn stop_accepting(&self, alpn: impl AsRef<[u8]>) -> Result<(), StopAcceptingError> {
374411
let (reply, reply_rx) = oneshot::channel();
375412
self.tx
376413
.send(ToRouterTask::StopAccepting {
377414
alpn: alpn.as_ref().to_vec(),
378415
reply,
379416
})
380417
.await
381-
.map_err(|_| RouterError::Closed {})?;
382-
reply_rx.await.map_err(|_| RouterError::Closed {})?;
383-
Ok(())
418+
.map_err(|_| StopAcceptingError::Closed {})?;
419+
reply_rx.await.map_err(|_| StopAcceptingError::Closed {})?
384420
}
385421

386422
/// Shuts down the accept loop cleanly.
@@ -463,14 +499,23 @@ impl RouterBuilder {
463499
Some(msg) = rx.recv() => {
464500
match msg {
465501
ToRouterTask::Accept { alpn, handler, reply } => {
466-
protocols.insert(alpn, handler);
502+
let outcome = if let Some(previous) = protocols.insert(alpn, handler) {
503+
join_set.spawn(shutdown_timeout(previous));
504+
AddProtocolOutcome::Replaced
505+
} else {
506+
AddProtocolOutcome::Inserted
507+
};
467508
endpoint.set_alpns(protocols.alpns());
468-
reply.send(()).ok();
509+
reply.send(outcome).ok();
469510
}
470511
ToRouterTask::StopAccepting { alpn, reply } => {
471-
protocols.remove(&alpn);
472-
endpoint.set_alpns(protocols.alpns());
473-
reply.send(()).ok();
512+
if let Some(handler) = protocols.remove(&alpn) {
513+
join_set.spawn(shutdown_timeout(handler));
514+
endpoint.set_alpns(protocols.alpns());
515+
reply.send(Ok(())).ok();
516+
} else {
517+
reply.send(Err(StopAcceptingError::UnknownAlpn {})).ok();
518+
}
474519
}
475520
}
476521
}
@@ -766,14 +811,19 @@ mod tests {
766811

767812
#[tokio::test]
768813
async fn test_add_and_remove_protocol() -> Result {
769-
async fn connect_assert_ok(endpoint: &Endpoint, addr: &NodeAddr, alpn: &[u8]) {
814+
async fn connect_assert_ok(
815+
endpoint: &Endpoint,
816+
addr: &NodeAddr,
817+
alpn: &[u8],
818+
expected_code: u32,
819+
) {
770820
let conn = endpoint
771821
.connect(addr.clone(), alpn)
772822
.await
773823
.expect("expected connection to succeed");
774824
let reason = conn.closed().await;
775825
assert!(matches!(reason,
776-
ConnectionError::ApplicationClosed(ApplicationClose { error_code, .. }) if error_code == 42u32.into()
826+
ConnectionError::ApplicationClosed(ApplicationClose { error_code, .. }) if error_code == expected_code.into()
777827
));
778828
}
779829

@@ -791,14 +841,14 @@ mod tests {
791841
}
792842

793843
#[derive(Debug, Clone, Default)]
794-
struct TestProtocol;
844+
struct TestProtocol(u32);
795845

796846
const ALPN_1: &[u8] = b"/iroh/test/1";
797847
const ALPN_2: &[u8] = b"/iroh/test/2";
798848

799849
impl ProtocolHandler for TestProtocol {
800850
async fn accept(&self, connection: Connection) -> Result<(), AcceptError> {
801-
connection.close(42u32.into(), b"bye");
851+
connection.close(self.0.into(), b"bye");
802852
Ok(())
803853
}
804854
}
@@ -808,7 +858,7 @@ mod tests {
808858
.bind()
809859
.await?;
810860
let router = Router::builder(server)
811-
.accept(ALPN_1, TestProtocol::default())
861+
.accept(ALPN_1, TestProtocol(1))
812862
.spawn();
813863

814864
let addr = router.endpoint().node_addr().initialized().await?;
@@ -818,29 +868,44 @@ mod tests {
818868
.bind()
819869
.await?;
820870

821-
connect_assert_ok(&client, &addr, ALPN_1).await;
871+
connect_assert_ok(&client, &addr, ALPN_1, 1).await;
822872
connect_assert_fail(&client, &addr, ALPN_2).await;
823873

824874
router.stop_accepting(ALPN_1).await?;
825875
connect_assert_fail(&client, &addr, ALPN_1).await;
826876
connect_assert_fail(&client, &addr, ALPN_2).await;
827877

828-
router.accept(ALPN_2, TestProtocol).await?;
878+
let outcome = router.accept(ALPN_2, TestProtocol(2)).await?;
879+
assert_eq!(outcome, AddProtocolOutcome::Inserted);
829880
connect_assert_fail(&client, &addr, ALPN_1).await;
830-
connect_assert_ok(&client, &addr, ALPN_2).await;
881+
connect_assert_ok(&client, &addr, ALPN_2, 2).await;
831882

832-
router.accept(ALPN_1, TestProtocol).await?;
833-
connect_assert_ok(&client, &addr, ALPN_1).await;
834-
connect_assert_ok(&client, &addr, ALPN_2).await;
883+
let outcome = router.accept(ALPN_1, TestProtocol(3)).await?;
884+
assert_eq!(outcome, AddProtocolOutcome::Inserted);
885+
connect_assert_ok(&client, &addr, ALPN_1, 3).await;
886+
connect_assert_ok(&client, &addr, ALPN_2, 2).await;
887+
888+
let outcome = router.accept(ALPN_1, TestProtocol(4)).await?;
889+
assert_eq!(outcome, AddProtocolOutcome::Replaced);
890+
connect_assert_ok(&client, &addr, ALPN_1, 4).await;
835891

836892
router.stop_accepting(ALPN_2).await?;
837-
connect_assert_ok(&client, &addr, ALPN_1).await;
893+
connect_assert_ok(&client, &addr, ALPN_1, 4).await;
838894
connect_assert_fail(&client, &addr, ALPN_2).await;
839895

840896
router.stop_accepting(ALPN_1).await?;
841897
connect_assert_fail(&client, &addr, ALPN_1).await;
842898
connect_assert_fail(&client, &addr, ALPN_2).await;
843899

900+
assert!(matches!(
901+
router.stop_accepting(ALPN_1).await,
902+
Err(StopAcceptingError::UnknownAlpn {})
903+
));
904+
assert!(matches!(
905+
router.stop_accepting(ALPN_2).await,
906+
Err(StopAcceptingError::UnknownAlpn {})
907+
));
908+
844909
Ok(())
845910
}
846911
}

0 commit comments

Comments
 (0)