@@ -40,7 +40,7 @@ use n0_future::{
40
40
task:: { self , AbortOnDropHandle , JoinSet } ,
41
41
} ;
42
42
use snafu:: { Backtrace , Snafu } ;
43
- use tokio:: sync:: Mutex ;
43
+ use tokio:: sync:: { mpsc , oneshot , Mutex } ;
44
44
use tokio_util:: sync:: CancellationToken ;
45
45
use tracing:: { error, info_span, trace, warn, Instrument } ;
46
46
@@ -85,6 +85,19 @@ pub struct Router {
85
85
// `Router` needs to be `Clone + Send`, and we need to `task.await` in its `shutdown()` impl.
86
86
task : Arc < Mutex < Option < AbortOnDropHandle < ( ) > > > > ,
87
87
cancel_token : CancellationToken ,
88
+ tx : mpsc:: Sender < ToRouterTask > ,
89
+ }
90
+
91
+ enum ToRouterTask {
92
+ Accept {
93
+ alpn : Vec < u8 > ,
94
+ handler : Arc < dyn DynProtocolHandler > ,
95
+ reply : oneshot:: Sender < ( ) > ,
96
+ } ,
97
+ StopAccepting {
98
+ alpn : Vec < u8 > ,
99
+ reply : oneshot:: Sender < ( ) > ,
100
+ } ,
88
101
}
89
102
90
103
/// Builder for creating a [`Router`] for accepting protocols.
@@ -116,6 +129,14 @@ pub enum AcceptError {
116
129
} ,
117
130
}
118
131
132
+ #[ allow( missing_docs) ]
133
+ #[ derive( Debug , Snafu ) ]
134
+ #[ non_exhaustive]
135
+ pub enum RouterError {
136
+ #[ snafu( display( "The router actor closed" ) ) ]
137
+ Closed { } ,
138
+ }
139
+
119
140
impl AcceptError {
120
141
/// Creates a new user error from an arbitrary error type.
121
142
pub fn from_err < T : std:: error:: Error + Send + Sync + ' static > ( value : T ) -> Self {
@@ -267,30 +288,40 @@ impl<P: ProtocolHandler> DynProtocolHandler for P {
267
288
268
289
/// A typed map of protocol handlers, mapping them from ALPNs.
269
290
#[ derive( Debug , Default ) ]
270
- pub ( crate ) struct ProtocolMap ( BTreeMap < Vec < u8 > , Box < dyn DynProtocolHandler > > ) ;
291
+ pub ( crate ) struct ProtocolMap ( std :: sync :: Mutex < BTreeMap < Vec < u8 > , Arc < dyn DynProtocolHandler > > > ) ;
271
292
272
293
impl ProtocolMap {
273
294
/// Returns the registered protocol handler for an ALPN as a [`Arc<dyn ProtocolHandler>`].
274
- pub ( crate ) fn get ( & self , alpn : & [ u8 ] ) -> Option < & dyn DynProtocolHandler > {
275
- self . 0 . get ( alpn) . map ( |p| & * * p )
295
+ pub ( crate ) fn get ( & self , alpn : & [ u8 ] ) -> Option < Arc < dyn DynProtocolHandler > > {
296
+ self . 0 . lock ( ) . expect ( "poisoned" ) . get ( alpn) . cloned ( )
276
297
}
277
298
278
299
/// Inserts a protocol handler.
279
- pub ( crate ) fn insert ( & mut self , alpn : Vec < u8 > , handler : impl ProtocolHandler ) {
280
- let handler = Box :: new ( handler) ;
281
- self . 0 . insert ( alpn, handler) ;
300
+ pub ( crate ) fn insert ( & self , alpn : Vec < u8 > , handler : Arc < dyn DynProtocolHandler > ) {
301
+ self . 0 . lock ( ) . expect ( "poisoned" ) . insert ( alpn, handler) ;
302
+ }
303
+
304
+ pub ( crate ) fn remove ( & self , alpn : & [ u8 ] ) {
305
+ self . 0 . lock ( ) . expect ( "poisoned" ) . remove ( alpn) ;
282
306
}
283
307
284
308
/// Returns an iterator of all registered ALPN protocol identifiers.
285
- pub ( crate ) fn alpns ( & self ) -> impl Iterator < Item = & Vec < u8 > > {
286
- self . 0 . keys ( )
309
+ pub ( crate ) fn alpns ( & self ) -> Vec < Vec < u8 > > {
310
+ self . 0 . lock ( ) . expect ( "poisoned" ) . keys ( ) . cloned ( ) . collect ( )
287
311
}
288
312
289
313
/// Shuts down all protocol handlers.
290
314
///
291
315
/// Calls and awaits [`ProtocolHandler::shutdown`] for all registered handlers concurrently.
292
316
pub ( crate ) async fn shutdown ( & self ) {
293
- let handlers = self . 0 . values ( ) . map ( |p| p. shutdown ( ) ) ;
317
+ let handlers: Vec < _ > = {
318
+ let inner = self . 0 . lock ( ) . expect ( "poisoned" ) ;
319
+ inner
320
+ . values ( )
321
+ . cloned ( )
322
+ . map ( |p| async move { p. shutdown ( ) . await } )
323
+ . collect ( )
324
+ } ;
294
325
join_all ( handlers) . await ;
295
326
}
296
327
}
@@ -311,6 +342,47 @@ impl Router {
311
342
self . cancel_token . is_cancelled ( )
312
343
}
313
344
345
+ /// Add a protocol to the list of accepted protocols.
346
+ ///
347
+ /// Configures the router to accept the [`ProtocolHandler`] when receiving a connection
348
+ /// with this `alpn`.
349
+ ///
350
+ /// Once the function yields, new connections with this `alpn` will be handled.
351
+ pub async fn accept (
352
+ & self ,
353
+ alpn : impl AsRef < [ u8 ] > ,
354
+ handler : impl ProtocolHandler ,
355
+ ) -> Result < ( ) , RouterError > {
356
+ let ( reply, reply_rx) = oneshot:: channel ( ) ;
357
+ self . tx
358
+ . send ( ToRouterTask :: Accept {
359
+ alpn : alpn. as_ref ( ) . to_vec ( ) ,
360
+ handler : Arc :: new ( handler) ,
361
+ reply,
362
+ } )
363
+ . await
364
+ . map_err ( |_| RouterError :: Closed { } ) ?;
365
+ reply_rx. await . map_err ( |_| RouterError :: Closed { } ) ?;
366
+ Ok ( ( ) )
367
+ }
368
+
369
+ /// Stops accepting a protocol.
370
+ ///
371
+ /// Note that this has only an effect on new connections. Existing connections that were
372
+ /// accepted with `alpn` won't be closed when calling [`Router::stop_accepting`].
373
+ pub async fn stop_accepting ( & self , alpn : impl AsRef < [ u8 ] > ) -> Result < ( ) , RouterError > {
374
+ let ( reply, reply_rx) = oneshot:: channel ( ) ;
375
+ self . tx
376
+ . send ( ToRouterTask :: StopAccepting {
377
+ alpn : alpn. as_ref ( ) . to_vec ( ) ,
378
+ reply,
379
+ } )
380
+ . await
381
+ . map_err ( |_| RouterError :: Closed { } ) ?;
382
+ reply_rx. await . map_err ( |_| RouterError :: Closed { } ) ?;
383
+ Ok ( ( ) )
384
+ }
385
+
314
386
/// Shuts down the accept loop cleanly.
315
387
///
316
388
/// When this function returns, all [`ProtocolHandler`]s will be shutdown and
@@ -348,8 +420,9 @@ impl RouterBuilder {
348
420
349
421
/// Configures the router to accept the [`ProtocolHandler`] when receiving a connection
350
422
/// with this `alpn`.
351
- pub fn accept ( mut self , alpn : impl AsRef < [ u8 ] > , handler : impl ProtocolHandler ) -> Self {
352
- self . protocols . insert ( alpn. as_ref ( ) . to_vec ( ) , handler) ;
423
+ pub fn accept ( self , alpn : impl AsRef < [ u8 ] > , handler : impl ProtocolHandler ) -> Self {
424
+ self . protocols
425
+ . insert ( alpn. as_ref ( ) . to_vec ( ) , Arc :: new ( handler) ) ;
353
426
self
354
427
}
355
428
@@ -361,14 +434,9 @@ impl RouterBuilder {
361
434
/// Spawns an accept loop and returns a handle to it encapsulated as the [`Router`].
362
435
pub fn spawn ( self ) -> Router {
363
436
// Update the endpoint with our alpns.
364
- let alpns = self
365
- . protocols
366
- . alpns ( )
367
- . map ( |alpn| alpn. to_vec ( ) )
368
- . collect :: < Vec < _ > > ( ) ;
437
+ self . endpoint . set_alpns ( self . protocols . alpns ( ) ) ;
369
438
370
439
let protocols = Arc :: new ( self . protocols ) ;
371
- self . endpoint . set_alpns ( alpns) ;
372
440
373
441
let mut join_set = JoinSet :: new ( ) ;
374
442
let endpoint = self . endpoint . clone ( ) ;
@@ -377,6 +445,8 @@ impl RouterBuilder {
377
445
let cancel = CancellationToken :: new ( ) ;
378
446
let cancel_token = cancel. clone ( ) ;
379
447
448
+ let ( tx, mut rx) = mpsc:: channel ( 8 ) ;
449
+
380
450
let run_loop_fut = async move {
381
451
// Make sure to cancel the token, if this future ever exits.
382
452
let _cancel_guard = cancel_token. clone ( ) . drop_guard ( ) ;
@@ -390,6 +460,20 @@ impl RouterBuilder {
390
460
_ = cancel_token. cancelled( ) => {
391
461
break ;
392
462
} ,
463
+ Some ( msg) = rx. recv( ) => {
464
+ match msg {
465
+ ToRouterTask :: Accept { alpn, handler, reply } => {
466
+ protocols. insert( alpn, handler) ;
467
+ endpoint. set_alpns( protocols. alpns( ) ) ;
468
+ reply. send( ( ) ) . ok( ) ;
469
+ }
470
+ ToRouterTask :: StopAccepting { alpn, reply } => {
471
+ protocols. remove( & alpn) ;
472
+ endpoint. set_alpns( protocols. alpns( ) ) ;
473
+ reply. send( ( ) ) . ok( ) ;
474
+ }
475
+ }
476
+ }
393
477
// handle task terminations and quit on panics.
394
478
Some ( res) = join_set. join_next( ) => {
395
479
match res {
@@ -452,6 +536,7 @@ impl RouterBuilder {
452
536
endpoint : self . endpoint ,
453
537
task : Arc :: new ( Mutex :: new ( Some ( task) ) ) ,
454
538
cancel_token : cancel,
539
+ tx,
455
540
}
456
541
}
457
542
}
@@ -542,12 +627,16 @@ impl<P: ProtocolHandler + Clone> ProtocolHandler for AccessLimit<P> {
542
627
mod tests {
543
628
use std:: { sync:: Mutex , time:: Duration } ;
544
629
630
+ use iroh_base:: NodeAddr ;
545
631
use n0_snafu:: { Result , ResultExt } ;
546
632
use n0_watcher:: Watcher ;
547
- use quinn:: ApplicationClose ;
633
+ use quinn:: { ApplicationClose , TransportErrorCode } ;
548
634
549
635
use super :: * ;
550
- use crate :: { endpoint:: ConnectionError , RelayMode } ;
636
+ use crate :: {
637
+ endpoint:: { ConnectError , ConnectionError } ,
638
+ RelayMode ,
639
+ } ;
551
640
552
641
#[ tokio:: test]
553
642
async fn test_shutdown ( ) -> Result {
@@ -674,4 +763,84 @@ mod tests {
674
763
) ;
675
764
Ok ( ( ) )
676
765
}
766
+
767
+ #[ tokio:: test]
768
+ async fn test_add_and_remove_protocol ( ) -> Result {
769
+ async fn connect_assert_ok ( endpoint : & Endpoint , addr : & NodeAddr , alpn : & [ u8 ] ) {
770
+ let conn = endpoint
771
+ . connect ( addr. clone ( ) , alpn)
772
+ . await
773
+ . expect ( "expected connection to succeed" ) ;
774
+ let reason = conn. closed ( ) . await ;
775
+ assert ! ( matches!( reason,
776
+ ConnectionError :: ApplicationClosed ( ApplicationClose { error_code, .. } ) if error_code == 42u32 . into( )
777
+ ) ) ;
778
+ }
779
+
780
+ async fn connect_assert_fail ( endpoint : & Endpoint , addr : & NodeAddr , alpn : & [ u8 ] ) {
781
+ let conn = endpoint. connect ( addr. clone ( ) , alpn) . await ;
782
+ assert ! ( matches!(
783
+ & conn,
784
+ Err ( ConnectError :: Connection { source, .. } )
785
+ if matches!(
786
+ source. as_ref( ) ,
787
+ ConnectionError :: ConnectionClosed ( frame)
788
+ if frame. error_code == TransportErrorCode :: crypto( rustls:: AlertDescription :: NoApplicationProtocol . into( ) )
789
+ )
790
+ ) ) ;
791
+ }
792
+
793
+ #[ derive( Debug , Clone , Default ) ]
794
+ struct TestProtocol ;
795
+
796
+ const ALPN_1 : & [ u8 ] = b"/iroh/test/1" ;
797
+ const ALPN_2 : & [ u8 ] = b"/iroh/test/2" ;
798
+
799
+ impl ProtocolHandler for TestProtocol {
800
+ async fn accept ( & self , connection : Connection ) -> Result < ( ) , AcceptError > {
801
+ connection. close ( 42u32 . into ( ) , b"bye" ) ;
802
+ Ok ( ( ) )
803
+ }
804
+ }
805
+
806
+ let server = Endpoint :: builder ( )
807
+ . relay_mode ( RelayMode :: Disabled )
808
+ . bind ( )
809
+ . await ?;
810
+ let router = Router :: builder ( server)
811
+ . accept ( ALPN_1 , TestProtocol :: default ( ) )
812
+ . spawn ( ) ;
813
+
814
+ let addr = router. endpoint ( ) . node_addr ( ) . initialized ( ) . await ?;
815
+
816
+ let client = Endpoint :: builder ( )
817
+ . relay_mode ( RelayMode :: Disabled )
818
+ . bind ( )
819
+ . await ?;
820
+
821
+ connect_assert_ok ( & client, & addr, ALPN_1 ) . await ;
822
+ connect_assert_fail ( & client, & addr, ALPN_2 ) . await ;
823
+
824
+ router. stop_accepting ( ALPN_1 ) . await ?;
825
+ connect_assert_fail ( & client, & addr, ALPN_1 ) . await ;
826
+ connect_assert_fail ( & client, & addr, ALPN_2 ) . await ;
827
+
828
+ router. accept ( ALPN_2 , TestProtocol ) . await ?;
829
+ connect_assert_fail ( & client, & addr, ALPN_1 ) . await ;
830
+ connect_assert_ok ( & client, & addr, ALPN_2 ) . await ;
831
+
832
+ router. accept ( ALPN_1 , TestProtocol ) . await ?;
833
+ connect_assert_ok ( & client, & addr, ALPN_1 ) . await ;
834
+ connect_assert_ok ( & client, & addr, ALPN_2 ) . await ;
835
+
836
+ router. stop_accepting ( ALPN_2 ) . await ?;
837
+ connect_assert_ok ( & client, & addr, ALPN_1 ) . await ;
838
+ connect_assert_fail ( & client, & addr, ALPN_2 ) . await ;
839
+
840
+ router. stop_accepting ( ALPN_1 ) . await ?;
841
+ connect_assert_fail ( & client, & addr, ALPN_1 ) . await ;
842
+ connect_assert_fail ( & client, & addr, ALPN_2 ) . await ;
843
+
844
+ Ok ( ( ) )
845
+ }
677
846
}
0 commit comments