32
32
//! }
33
33
//! }
34
34
//! ```
35
- use std:: { collections:: BTreeMap , future:: Future , pin:: Pin , sync:: Arc } ;
35
+ use std:: { collections:: BTreeMap , future:: Future , pin:: Pin , sync:: Arc , time :: Duration } ;
36
36
37
37
use iroh_base:: NodeId ;
38
38
use n0_future:: {
39
39
join_all,
40
40
task:: { self , AbortOnDropHandle , JoinSet } ,
41
+ time,
41
42
} ;
42
43
use snafu:: { Backtrace , Snafu } ;
43
44
use tokio:: sync:: { mpsc, oneshot, Mutex } ;
@@ -92,11 +93,11 @@ enum ToRouterTask {
92
93
Accept {
93
94
alpn : Vec < u8 > ,
94
95
handler : Arc < dyn DynProtocolHandler > ,
95
- reply : oneshot:: Sender < ( ) > ,
96
+ reply : oneshot:: Sender < AddProtocolOutcome > ,
96
97
} ,
97
98
StopAccepting {
98
99
alpn : Vec < u8 > ,
99
- reply : oneshot:: Sender < ( ) > ,
100
+ reply : oneshot:: Sender < Result < ( ) , StopAcceptingError > > ,
100
101
} ,
101
102
}
102
103
@@ -129,14 +130,6 @@ pub enum AcceptError {
129
130
} ,
130
131
}
131
132
132
- #[ allow( missing_docs) ]
133
- #[ derive( Debug , Snafu ) ]
134
- #[ non_exhaustive]
135
- pub enum RouterError {
136
- #[ snafu( display( "The router actor closed" ) ) ]
137
- Closed { } ,
138
- }
139
-
140
133
impl AcceptError {
141
134
/// Creates a new user error from an arbitrary error type.
142
135
pub fn from_err < T : std:: error:: Error + Send + Sync + ' static > ( value : T ) -> Self {
@@ -158,6 +151,37 @@ impl From<quinn::ClosedStream> for AcceptError {
158
151
}
159
152
}
160
153
154
+ #[ allow( missing_docs) ]
155
+ #[ derive( Debug , Snafu ) ]
156
+ #[ non_exhaustive]
157
+ pub enum RouterError {
158
+ #[ snafu( display( "The router actor closed" ) ) ]
159
+ Closed { } ,
160
+ }
161
+
162
+ #[ allow( missing_docs) ]
163
+ #[ derive( Debug , Snafu ) ]
164
+ #[ snafu( module) ]
165
+ #[ non_exhaustive]
166
+ pub enum StopAcceptingError {
167
+ #[ snafu( display( "The router actor closed" ) ) ]
168
+ Closed { } ,
169
+ #[ snafu( display( "The ALPN requested to be removed is not registered" ) ) ]
170
+ UnknownAlpn { } ,
171
+ }
172
+
173
+ /// Returned from [`Router::accept`]
174
+ #[ derive( Debug , Clone , Copy , Eq , PartialEq ) ]
175
+ pub enum AddProtocolOutcome {
176
+ /// The protocol handler has been newly inserted.
177
+ Inserted ,
178
+ /// The protocol handler replaced a previously registered protocol handler.
179
+ Replaced ,
180
+ }
181
+
182
+ /// Timeout applied to [`ProtocolHandler::shutdown] futures.
183
+ const SHUTDOWN_TIMEOUT : Duration = Duration :: from_secs ( 30 ) ;
184
+
161
185
/// Handler for incoming connections.
162
186
///
163
187
/// A router accepts connections for arbitrary ALPN protocols.
@@ -286,40 +310,50 @@ impl<P: ProtocolHandler> DynProtocolHandler for P {
286
310
}
287
311
}
288
312
313
+ async fn shutdown_timeout ( handler : Arc < dyn DynProtocolHandler > ) -> Option < ( ) > {
314
+ time:: timeout ( SHUTDOWN_TIMEOUT , handler. shutdown ( ) )
315
+ . await
316
+ . ok ( )
317
+ }
318
+
289
319
/// A typed map of protocol handlers, mapping them from ALPNs.
290
320
#[ derive( Debug , Default ) ]
291
- pub ( crate ) struct ProtocolMap ( std:: sync:: Mutex < BTreeMap < Vec < u8 > , Arc < dyn DynProtocolHandler > > > ) ;
321
+ pub ( crate ) struct ProtocolMap ( std:: sync:: RwLock < BTreeMap < Vec < u8 > , Arc < dyn DynProtocolHandler > > > ) ;
292
322
293
323
impl ProtocolMap {
294
324
/// Returns the registered protocol handler for an ALPN as a [`Arc<dyn ProtocolHandler>`].
295
325
pub ( crate ) fn get ( & self , alpn : & [ u8 ] ) -> Option < Arc < dyn DynProtocolHandler > > {
296
- self . 0 . lock ( ) . expect ( "poisoned" ) . get ( alpn) . cloned ( )
326
+ self . 0 . read ( ) . expect ( "poisoned" ) . get ( alpn) . cloned ( )
297
327
}
298
328
299
329
/// Inserts a protocol handler.
300
- pub ( crate ) fn insert ( & self , alpn : Vec < u8 > , handler : Arc < dyn DynProtocolHandler > ) {
301
- self . 0 . lock ( ) . expect ( "poisoned" ) . insert ( alpn, handler) ;
330
+ pub ( crate ) fn insert (
331
+ & self ,
332
+ alpn : Vec < u8 > ,
333
+ handler : Arc < dyn DynProtocolHandler > ,
334
+ ) -> Option < Arc < dyn DynProtocolHandler > > {
335
+ self . 0 . write ( ) . expect ( "poisoned" ) . insert ( alpn, handler)
302
336
}
303
337
304
- pub ( crate ) fn remove ( & self , alpn : & [ u8 ] ) {
305
- self . 0 . lock ( ) . expect ( "poisoned" ) . remove ( alpn) ;
338
+ pub ( crate ) fn remove ( & self , alpn : & [ u8 ] ) -> Option < Arc < dyn DynProtocolHandler > > {
339
+ self . 0 . write ( ) . expect ( "poisoned" ) . remove ( alpn)
306
340
}
307
341
308
342
/// Returns an iterator of all registered ALPN protocol identifiers.
309
343
pub ( crate ) fn alpns ( & self ) -> Vec < Vec < u8 > > {
310
- self . 0 . lock ( ) . expect ( "poisoned" ) . keys ( ) . cloned ( ) . collect ( )
344
+ self . 0 . read ( ) . expect ( "poisoned" ) . keys ( ) . cloned ( ) . collect ( )
311
345
}
312
346
313
347
/// Shuts down all protocol handlers.
314
348
///
315
349
/// Calls and awaits [`ProtocolHandler::shutdown`] for all registered handlers concurrently.
316
350
pub ( crate ) async fn shutdown ( & self ) {
317
351
let handlers: Vec < _ > = {
318
- let inner = self . 0 . lock ( ) . expect ( "poisoned" ) ;
352
+ let inner = self . 0 . read ( ) . expect ( "poisoned" ) ;
319
353
inner
320
354
. values ( )
321
355
. cloned ( )
322
- . map ( |p| async move { p . shutdown ( ) . await } )
356
+ . map ( |handler| shutdown_timeout ( handler ) )
323
357
. collect ( )
324
358
} ;
325
359
join_all ( handlers) . await ;
@@ -342,17 +376,21 @@ impl Router {
342
376
self . cancel_token . is_cancelled ( )
343
377
}
344
378
345
- /// Add a protocol to the list of accepted protocols.
379
+ /// Adds a protocol to the list of accepted protocols.
346
380
///
347
381
/// Configures the router to accept the [`ProtocolHandler`] when receiving a connection
348
382
/// with this `alpn`.
349
383
///
350
384
/// Once the function yields, new connections with this `alpn` will be handled.
385
+ ///
386
+ /// If a protocol handler was already registered for `alpn`, the previous handler will be shutdown.
387
+ ///
388
+ /// Returns `true` if
351
389
pub async fn accept (
352
390
& self ,
353
391
alpn : impl AsRef < [ u8 ] > ,
354
392
handler : impl ProtocolHandler ,
355
- ) -> Result < ( ) , RouterError > {
393
+ ) -> Result < AddProtocolOutcome , RouterError > {
356
394
let ( reply, reply_rx) = oneshot:: channel ( ) ;
357
395
self . tx
358
396
. send ( ToRouterTask :: Accept {
@@ -362,25 +400,23 @@ impl Router {
362
400
} )
363
401
. await
364
402
. map_err ( |_| RouterError :: Closed { } ) ?;
365
- reply_rx. await . map_err ( |_| RouterError :: Closed { } ) ?;
366
- Ok ( ( ) )
403
+ reply_rx. await . map_err ( |_| RouterError :: Closed { } )
367
404
}
368
405
369
406
/// Stops accepting a protocol.
370
407
///
371
408
/// Note that this has only an effect on new connections. Existing connections that were
372
409
/// accepted with `alpn` won't be closed when calling [`Router::stop_accepting`].
373
- pub async fn stop_accepting ( & self , alpn : impl AsRef < [ u8 ] > ) -> Result < ( ) , RouterError > {
410
+ pub async fn stop_accepting ( & self , alpn : impl AsRef < [ u8 ] > ) -> Result < ( ) , StopAcceptingError > {
374
411
let ( reply, reply_rx) = oneshot:: channel ( ) ;
375
412
self . tx
376
413
. send ( ToRouterTask :: StopAccepting {
377
414
alpn : alpn. as_ref ( ) . to_vec ( ) ,
378
415
reply,
379
416
} )
380
417
. await
381
- . map_err ( |_| RouterError :: Closed { } ) ?;
382
- reply_rx. await . map_err ( |_| RouterError :: Closed { } ) ?;
383
- Ok ( ( ) )
418
+ . map_err ( |_| StopAcceptingError :: Closed { } ) ?;
419
+ reply_rx. await . map_err ( |_| StopAcceptingError :: Closed { } ) ?
384
420
}
385
421
386
422
/// Shuts down the accept loop cleanly.
@@ -463,14 +499,23 @@ impl RouterBuilder {
463
499
Some ( msg) = rx. recv( ) => {
464
500
match msg {
465
501
ToRouterTask :: Accept { alpn, handler, reply } => {
466
- protocols. insert( alpn, handler) ;
502
+ let outcome = if let Some ( previous) = protocols. insert( alpn, handler) {
503
+ join_set. spawn( shutdown_timeout( previous) ) ;
504
+ AddProtocolOutcome :: Replaced
505
+ } else {
506
+ AddProtocolOutcome :: Inserted
507
+ } ;
467
508
endpoint. set_alpns( protocols. alpns( ) ) ;
468
- reply. send( ( ) ) . ok( ) ;
509
+ reply. send( outcome ) . ok( ) ;
469
510
}
470
511
ToRouterTask :: StopAccepting { alpn, reply } => {
471
- protocols. remove( & alpn) ;
472
- endpoint. set_alpns( protocols. alpns( ) ) ;
473
- reply. send( ( ) ) . ok( ) ;
512
+ if let Some ( handler) = protocols. remove( & alpn) {
513
+ join_set. spawn( shutdown_timeout( handler) ) ;
514
+ endpoint. set_alpns( protocols. alpns( ) ) ;
515
+ reply. send( Ok ( ( ) ) ) . ok( ) ;
516
+ } else {
517
+ reply. send( Err ( StopAcceptingError :: UnknownAlpn { } ) ) . ok( ) ;
518
+ }
474
519
}
475
520
}
476
521
}
@@ -766,14 +811,19 @@ mod tests {
766
811
767
812
#[ tokio:: test]
768
813
async fn test_add_and_remove_protocol ( ) -> Result {
769
- async fn connect_assert_ok ( endpoint : & Endpoint , addr : & NodeAddr , alpn : & [ u8 ] ) {
814
+ async fn connect_assert_ok (
815
+ endpoint : & Endpoint ,
816
+ addr : & NodeAddr ,
817
+ alpn : & [ u8 ] ,
818
+ expected_code : u32 ,
819
+ ) {
770
820
let conn = endpoint
771
821
. connect ( addr. clone ( ) , alpn)
772
822
. await
773
823
. expect ( "expected connection to succeed" ) ;
774
824
let reason = conn. closed ( ) . await ;
775
825
assert ! ( matches!( reason,
776
- ConnectionError :: ApplicationClosed ( ApplicationClose { error_code, .. } ) if error_code == 42u32 . into( )
826
+ ConnectionError :: ApplicationClosed ( ApplicationClose { error_code, .. } ) if error_code == expected_code . into( )
777
827
) ) ;
778
828
}
779
829
@@ -791,14 +841,14 @@ mod tests {
791
841
}
792
842
793
843
#[ derive( Debug , Clone , Default ) ]
794
- struct TestProtocol ;
844
+ struct TestProtocol ( u32 ) ;
795
845
796
846
const ALPN_1 : & [ u8 ] = b"/iroh/test/1" ;
797
847
const ALPN_2 : & [ u8 ] = b"/iroh/test/2" ;
798
848
799
849
impl ProtocolHandler for TestProtocol {
800
850
async fn accept ( & self , connection : Connection ) -> Result < ( ) , AcceptError > {
801
- connection. close ( 42u32 . into ( ) , b"bye" ) ;
851
+ connection. close ( self . 0 . into ( ) , b"bye" ) ;
802
852
Ok ( ( ) )
803
853
}
804
854
}
@@ -808,7 +858,7 @@ mod tests {
808
858
. bind ( )
809
859
. await ?;
810
860
let router = Router :: builder ( server)
811
- . accept ( ALPN_1 , TestProtocol :: default ( ) )
861
+ . accept ( ALPN_1 , TestProtocol ( 1 ) )
812
862
. spawn ( ) ;
813
863
814
864
let addr = router. endpoint ( ) . node_addr ( ) . initialized ( ) . await ?;
@@ -818,29 +868,44 @@ mod tests {
818
868
. bind ( )
819
869
. await ?;
820
870
821
- connect_assert_ok ( & client, & addr, ALPN_1 ) . await ;
871
+ connect_assert_ok ( & client, & addr, ALPN_1 , 1 ) . await ;
822
872
connect_assert_fail ( & client, & addr, ALPN_2 ) . await ;
823
873
824
874
router. stop_accepting ( ALPN_1 ) . await ?;
825
875
connect_assert_fail ( & client, & addr, ALPN_1 ) . await ;
826
876
connect_assert_fail ( & client, & addr, ALPN_2 ) . await ;
827
877
828
- router. accept ( ALPN_2 , TestProtocol ) . await ?;
878
+ let outcome = router. accept ( ALPN_2 , TestProtocol ( 2 ) ) . await ?;
879
+ assert_eq ! ( outcome, AddProtocolOutcome :: Inserted ) ;
829
880
connect_assert_fail ( & client, & addr, ALPN_1 ) . await ;
830
- connect_assert_ok ( & client, & addr, ALPN_2 ) . await ;
881
+ connect_assert_ok ( & client, & addr, ALPN_2 , 2 ) . await ;
831
882
832
- router. accept ( ALPN_1 , TestProtocol ) . await ?;
833
- connect_assert_ok ( & client, & addr, ALPN_1 ) . await ;
834
- connect_assert_ok ( & client, & addr, ALPN_2 ) . await ;
883
+ let outcome = router. accept ( ALPN_1 , TestProtocol ( 3 ) ) . await ?;
884
+ assert_eq ! ( outcome, AddProtocolOutcome :: Inserted ) ;
885
+ connect_assert_ok ( & client, & addr, ALPN_1 , 3 ) . await ;
886
+ connect_assert_ok ( & client, & addr, ALPN_2 , 2 ) . await ;
887
+
888
+ let outcome = router. accept ( ALPN_1 , TestProtocol ( 4 ) ) . await ?;
889
+ assert_eq ! ( outcome, AddProtocolOutcome :: Replaced ) ;
890
+ connect_assert_ok ( & client, & addr, ALPN_1 , 4 ) . await ;
835
891
836
892
router. stop_accepting ( ALPN_2 ) . await ?;
837
- connect_assert_ok ( & client, & addr, ALPN_1 ) . await ;
893
+ connect_assert_ok ( & client, & addr, ALPN_1 , 4 ) . await ;
838
894
connect_assert_fail ( & client, & addr, ALPN_2 ) . await ;
839
895
840
896
router. stop_accepting ( ALPN_1 ) . await ?;
841
897
connect_assert_fail ( & client, & addr, ALPN_1 ) . await ;
842
898
connect_assert_fail ( & client, & addr, ALPN_2 ) . await ;
843
899
900
+ assert ! ( matches!(
901
+ router. stop_accepting( ALPN_1 ) . await ,
902
+ Err ( StopAcceptingError :: UnknownAlpn { } )
903
+ ) ) ;
904
+ assert ! ( matches!(
905
+ router. stop_accepting( ALPN_2 ) . await ,
906
+ Err ( StopAcceptingError :: UnknownAlpn { } )
907
+ ) ) ;
908
+
844
909
Ok ( ( ) )
845
910
}
846
911
}
0 commit comments