2
2
3
3
use anyhow:: { anyhow, Context } ;
4
4
use bytes:: { Bytes , BytesMut } ;
5
- use futures_lite:: stream:: Stream ;
5
+ use futures_lite:: { stream:: Stream , StreamExt } ;
6
+ use futures_util:: future:: FutureExt ;
6
7
use genawaiter:: sync:: { Co , Gen } ;
7
8
use iroh_net:: {
8
- dialer:: Dialer ,
9
- endpoint:: { get_remote_node_id , Connection } ,
9
+ dialer:: { ConnDirection , ConnManager , NewConnection } ,
10
+ endpoint:: Connection ,
10
11
key:: PublicKey ,
11
12
AddrInfo , Endpoint , NodeAddr ,
12
13
} ;
@@ -15,7 +16,7 @@ use rand_core::SeedableRng;
15
16
use std:: { collections:: HashMap , future:: Future , pin:: Pin , sync:: Arc , task:: Poll , time:: Instant } ;
16
17
use tokio:: {
17
18
sync:: { broadcast, mpsc, oneshot} ,
18
- task:: JoinHandle ,
19
+ task:: { JoinHandle , JoinSet } ,
19
20
} ;
20
21
use tracing:: { debug, error_span, trace, warn, Instrument } ;
21
22
@@ -82,7 +83,7 @@ impl Gossip {
82
83
/// Spawn a gossip actor and get a handle for it
83
84
pub fn from_endpoint ( endpoint : Endpoint , config : proto:: Config , my_addr : & AddrInfo ) -> Self {
84
85
let peer_id = endpoint. node_id ( ) ;
85
- let dialer = Dialer :: new ( endpoint. clone ( ) ) ;
86
+ let conn_manager = ConnManager :: new ( endpoint. clone ( ) , GOSSIP_ALPN ) ;
86
87
let state = proto:: State :: new (
87
88
peer_id,
88
89
encode_peer_data ( my_addr) . unwrap ( ) ,
@@ -97,12 +98,12 @@ impl Gossip {
97
98
let actor = Actor {
98
99
endpoint,
99
100
state,
100
- dialer,
101
+ conn_manager,
102
+ conn_tasks : Default :: default ( ) ,
101
103
to_actor_rx,
102
104
in_event_rx,
103
105
in_event_tx,
104
106
on_endpoints_rx,
105
- conns : Default :: default ( ) ,
106
107
conn_send_tx : Default :: default ( ) ,
107
108
pending_sends : Default :: default ( ) ,
108
109
timers : Timers :: new ( ) ,
@@ -231,9 +232,7 @@ impl Gossip {
231
232
///
232
233
/// Make sure to check the ALPN protocol yourself before passing the connection.
233
234
pub async fn handle_connection ( & self , conn : Connection ) -> anyhow:: Result < ( ) > {
234
- let peer_id = get_remote_node_id ( & conn) ?;
235
- self . send ( ToActor :: ConnIncoming ( peer_id, ConnOrigin :: Accept , conn) )
236
- . await ?;
235
+ self . send ( ToActor :: AcceptConn ( conn) ) . await ?;
237
236
Ok ( ( ) )
238
237
}
239
238
@@ -283,19 +282,13 @@ impl Future for JoinTopicFut {
283
282
}
284
283
}
285
284
286
- /// Whether a connection is initiated by us (Dial) or by the remote peer (Accept)
287
- #[ derive( Debug ) ]
288
- enum ConnOrigin {
289
- Accept ,
290
- Dial ,
291
- }
292
-
293
285
/// Input messages for the gossip [`Actor`].
294
286
#[ derive( derive_more:: Debug ) ]
295
287
enum ToActor {
296
288
/// Handle a new QUIC connection, either from accept (external to the actor) or from connect
297
289
/// (happens internally in the actor).
298
- ConnIncoming ( PublicKey , ConnOrigin , #[ debug( skip) ] Connection ) ,
290
+ // ConnIncoming(NewConnection),
291
+ AcceptConn ( iroh_net:: endpoint:: Connection ) ,
299
292
/// Join a topic with a list of peers. Reply with oneshot once at least one peer joined.
300
293
Join (
301
294
TopicId ,
@@ -330,7 +323,7 @@ struct Actor {
330
323
state : proto:: State < PublicKey , StdRng > ,
331
324
endpoint : Endpoint ,
332
325
/// Dial machine to connect to peers
333
- dialer : Dialer ,
326
+ conn_manager : ConnManager ,
334
327
/// Input messages to the actor
335
328
to_actor_rx : mpsc:: Receiver < ToActor > ,
336
329
/// Sender for the state input (cloned into the connection loops)
@@ -341,10 +334,9 @@ struct Actor {
341
334
on_endpoints_rx : mpsc:: Receiver < Vec < iroh_net:: config:: Endpoint > > ,
342
335
/// Queued timers
343
336
timers : Timers < Timer > ,
344
- /// Currently opened quinn connections to peers
345
- conns : HashMap < PublicKey , Connection > ,
346
337
/// Channels to send outbound messages into the connection loops
347
338
conn_send_tx : HashMap < PublicKey , mpsc:: Sender < ProtoMessage > > ,
339
+ conn_tasks : JoinSet < ( PublicKey , anyhow:: Result < ( ) > ) > ,
348
340
/// Queued messages that were to be sent before a dial completed
349
341
pending_sends : HashMap < PublicKey , Vec < ProtoMessage > > ,
350
342
/// Broadcast senders for active topic subscriptions from the application
@@ -353,6 +345,12 @@ struct Actor {
353
345
subscribers_all : Option < broadcast:: Sender < ( TopicId , Event ) > > ,
354
346
}
355
347
348
+ impl Drop for Actor {
349
+ fn drop ( & mut self ) {
350
+ self . conn_tasks . abort_all ( ) ;
351
+ }
352
+ }
353
+
356
354
impl Actor {
357
355
pub async fn run ( mut self ) -> anyhow:: Result < ( ) > {
358
356
let mut i = 0 ;
@@ -384,15 +382,31 @@ impl Actor {
384
382
}
385
383
}
386
384
}
387
- ( peer_id, res) = self . dialer. next_conn( ) => {
388
- trace!( ?i, "tick: dialer" ) ;
385
+ Some ( new_conn) = self . conn_manager. next( ) => {
386
+ trace!( ?i, "tick: conn_manager" ) ;
387
+ let node_id = new_conn. node_id;
388
+ if let Err ( err) = self . handle_new_connection( new_conn) . await {
389
+ warn!( peer=%node_id. fmt_short( ) , ?err, "failed to handle new connection" ) ;
390
+ self . conn_manager. remove( & node_id) ;
391
+ self . conn_send_tx. remove( & node_id) ;
392
+ }
393
+ }
394
+ Some ( res) = self . conn_tasks. join_next( ) , if !self . conn_tasks. is_empty( ) => {
389
395
match res {
390
- Ok ( conn) => {
391
- debug!( peer = ?peer_id, "dial successful" ) ;
392
- self . handle_to_actor_msg( ToActor :: ConnIncoming ( peer_id, ConnOrigin :: Dial , conn) , Instant :: now( ) ) . await . context( "dialer.next -> conn -> handle_to_actor_msg" ) ?;
393
- }
394
- Err ( err) => {
395
- warn!( peer = ?peer_id, "dial failed: {err}" ) ;
396
+ Err ( err) if !err. is_cancelled( ) => warn!( ?err, "connection loop panicked" ) ,
397
+ Err ( _err) => { } ,
398
+ Ok ( ( node_id, result) ) => {
399
+ self . conn_manager. remove( & node_id) ;
400
+ self . conn_send_tx. remove( & node_id) ;
401
+ self . handle_in_event( InEvent :: PeerDisconnected ( node_id) , Instant :: now( ) ) . await ?;
402
+ match result {
403
+ Ok ( ( ) ) => {
404
+ debug!( peer=%node_id. fmt_short( ) , "connection closed without error" ) ;
405
+ }
406
+ Err ( err) => {
407
+ debug!( peer=%node_id. fmt_short( ) , "connection closed with error {err:?}" ) ;
408
+ }
409
+ }
396
410
}
397
411
}
398
412
}
@@ -421,38 +435,9 @@ impl Actor {
421
435
async fn handle_to_actor_msg ( & mut self , msg : ToActor , now : Instant ) -> anyhow:: Result < ( ) > {
422
436
trace ! ( "handle to_actor {msg:?}" ) ;
423
437
match msg {
424
- ToActor :: ConnIncoming ( peer_id, origin, conn) => {
425
- self . conns . insert ( peer_id, conn. clone ( ) ) ;
426
- self . dialer . abort_dial ( & peer_id) ;
427
- let ( send_tx, send_rx) = mpsc:: channel ( SEND_QUEUE_CAP ) ;
428
- self . conn_send_tx . insert ( peer_id, send_tx. clone ( ) ) ;
429
-
430
- // Spawn a task for this connection
431
- let in_event_tx = self . in_event_tx . clone ( ) ;
432
- tokio:: spawn (
433
- async move {
434
- debug ! ( "connection established" ) ;
435
- match connection_loop ( peer_id, conn, origin, send_rx, & in_event_tx) . await {
436
- Ok ( ( ) ) => {
437
- debug ! ( "connection closed without error" )
438
- }
439
- Err ( err) => {
440
- debug ! ( "connection closed with error {err:?}" )
441
- }
442
- }
443
- in_event_tx
444
- . send ( InEvent :: PeerDisconnected ( peer_id) )
445
- . await
446
- . ok ( ) ;
447
- }
448
- . instrument ( error_span ! ( "gossip_conn" , peer = %peer_id. fmt_short( ) ) ) ,
449
- ) ;
450
-
451
- // Forward queued pending sends
452
- if let Some ( send_queue) = self . pending_sends . remove ( & peer_id) {
453
- for msg in send_queue {
454
- send_tx. send ( msg) . await ?;
455
- }
438
+ ToActor :: AcceptConn ( conn) => {
439
+ if let Err ( err) = self . conn_manager . push_accept ( conn) {
440
+ warn ! ( ?err, "failed to accept connection" ) ;
456
441
}
457
442
}
458
443
ToActor :: Join ( topic_id, peers, reply) => {
@@ -502,9 +487,6 @@ impl Actor {
502
487
} else {
503
488
debug ! ( "handle in_event {event:?}" ) ;
504
489
} ;
505
- if let InEvent :: PeerDisconnected ( peer) = & event {
506
- self . conn_send_tx . remove ( peer) ;
507
- }
508
490
let out = self . state . handle ( event, now) ;
509
491
for event in out {
510
492
if matches ! ( event, OutEvent :: ScheduleTimer ( _, _) ) {
@@ -518,10 +500,13 @@ impl Actor {
518
500
if let Err ( _err) = send. send ( message) . await {
519
501
warn ! ( "conn receiver for {peer_id:?} dropped" ) ;
520
502
self . conn_send_tx . remove ( & peer_id) ;
503
+ self . conn_manager . remove ( & peer_id) ;
521
504
}
522
505
} else {
523
- debug ! ( peer = ?peer_id, "dial" ) ;
524
- self . dialer . queue_dial ( peer_id, GOSSIP_ALPN ) ;
506
+ if !self . conn_manager . is_dialing ( & peer_id) {
507
+ debug ! ( peer = ?peer_id, "dial" ) ;
508
+ self . conn_manager . dial ( peer_id) ;
509
+ }
525
510
// TODO: Enforce max length
526
511
self . pending_sends . entry ( peer_id) . or_default ( ) . push ( message) ;
527
512
}
@@ -544,12 +529,11 @@ impl Actor {
544
529
self . timers . insert ( now + delay, timer) ;
545
530
}
546
531
OutEvent :: DisconnectPeer ( peer) => {
547
- if let Some ( conn) = self . conns . remove ( & peer) {
548
- conn. close ( 0u8 . into ( ) , b"close from disconnect" ) ;
549
- }
550
532
self . conn_send_tx . remove ( & peer) ;
551
533
self . pending_sends . remove ( & peer) ;
552
- self . dialer . abort_dial ( & peer) ;
534
+ if let Some ( conn) = self . conn_manager . remove ( & peer) {
535
+ conn. close ( 0u8 . into ( ) , b"close from disconnect" ) ;
536
+ }
553
537
}
554
538
OutEvent :: PeerData ( node_id, data) => match decode_peer_data ( & data) {
555
539
Err ( err) => warn ! ( "Failed to decode {data:?} from {node_id}: {err}" ) ,
@@ -566,6 +550,42 @@ impl Actor {
566
550
Ok ( ( ) )
567
551
}
568
552
553
+ async fn handle_new_connection ( & mut self , new_conn : NewConnection ) -> anyhow:: Result < ( ) > {
554
+ let NewConnection {
555
+ conn,
556
+ node_id : peer_id,
557
+ direction,
558
+ } = new_conn;
559
+ match conn {
560
+ Ok ( conn) => {
561
+ let ( send_tx, send_rx) = mpsc:: channel ( SEND_QUEUE_CAP ) ;
562
+ self . conn_send_tx . insert ( peer_id, send_tx. clone ( ) ) ;
563
+
564
+ // Spawn a task for this connection
565
+ let pending_sends = self . pending_sends . remove ( & peer_id) ;
566
+ let in_event_tx = self . in_event_tx . clone ( ) ;
567
+ debug ! ( peer=%peer_id. fmt_short( ) , ?direction, "connection established" ) ;
568
+ self . conn_tasks . spawn (
569
+ connection_loop (
570
+ peer_id,
571
+ conn,
572
+ direction,
573
+ send_rx,
574
+ in_event_tx,
575
+ pending_sends,
576
+ )
577
+ . map ( move |r| ( peer_id, r) )
578
+ . instrument ( error_span ! ( "gossip_conn" , peer = %peer_id. fmt_short( ) ) ) ,
579
+ ) ;
580
+ }
581
+ Err ( err) => {
582
+ warn ! ( peer=%peer_id. fmt_short( ) , "connecting to node failed: {err:?}" ) ;
583
+ }
584
+ }
585
+
586
+ Ok ( ( ) )
587
+ }
588
+
569
589
fn subscribe_all ( & mut self ) -> broadcast:: Receiver < ( TopicId , Event ) > {
570
590
if let Some ( tx) = self . subscribers_all . as_mut ( ) {
571
591
tx. subscribe ( )
@@ -602,20 +622,30 @@ async fn wait_for_neighbor_up(mut sub: broadcast::Receiver<Event>) -> anyhow::Re
602
622
async fn connection_loop (
603
623
from : PublicKey ,
604
624
conn : Connection ,
605
- origin : ConnOrigin ,
625
+ direction : ConnDirection ,
606
626
mut send_rx : mpsc:: Receiver < ProtoMessage > ,
607
- in_event_tx : & mpsc:: Sender < InEvent > ,
627
+ in_event_tx : mpsc:: Sender < InEvent > ,
628
+ mut pending_sends : Option < Vec < ProtoMessage > > ,
608
629
) -> anyhow:: Result < ( ) > {
609
- let ( mut send, mut recv) = match origin {
610
- ConnOrigin :: Accept => conn. accept_bi ( ) . await ?,
611
- ConnOrigin :: Dial => conn. open_bi ( ) . await ?,
630
+ let ( mut send, mut recv) = match direction {
631
+ ConnDirection :: Accept => conn. accept_bi ( ) . await ?,
632
+ ConnDirection :: Dial => conn. open_bi ( ) . await ?,
612
633
} ;
613
634
let mut send_buf = BytesMut :: new ( ) ;
614
635
let mut recv_buf = BytesMut :: new ( ) ;
636
+
637
+ // Forward queued pending sends
638
+ if let Some ( mut send_queue) = pending_sends. take ( ) {
639
+ for msg in send_queue. drain ( ..) {
640
+ write_message ( & mut send, & mut send_buf, & msg) . await ?;
641
+ }
642
+ }
643
+
644
+ // loop over sending and receiving messages
615
645
loop {
616
646
tokio:: select! {
617
647
biased;
618
- msg = send_rx. recv( ) => {
648
+ msg = send_rx. recv( ) , if !send_rx . is_closed ( ) => {
619
649
match msg {
620
650
None => break ,
621
651
Some ( msg) => write_message( & mut send, & mut send_buf, & msg) . await ?,
0 commit comments