Skip to content

Commit eac7d69

Browse files
committed
fix: improve connection handling in iroh-gossip
1 parent 2f7d42d commit eac7d69

File tree

4 files changed

+253
-79
lines changed

4 files changed

+253
-79
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

iroh-gossip/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ iroh-base = { version = "0.16.0", path = "../iroh-base" }
3232

3333
# net dependencies (optional)
3434
futures-lite = { version = "2.3", optional = true }
35+
futures-util = { version = "0.3.30", optional = true }
3536
iroh-net = { path = "../iroh-net", version = "0.16.0", optional = true, default-features = false, features = ["test-utils"] }
3637
tokio = { version = "1", optional = true, features = ["io-util", "sync", "rt", "macros", "net", "fs"] }
3738
tokio-util = { version = "0.7.8", optional = true, features = ["codec"] }
@@ -46,7 +47,7 @@ url = "2.4.0"
4647

4748
[features]
4849
default = ["net"]
49-
net = ["dep:futures-lite", "dep:iroh-net", "dep:tokio", "dep:tokio-util"]
50+
net = ["dep:futures-lite", "dep:futures-util", "dep:iroh-net", "dep:tokio", "dep:tokio-util"]
5051

5152
[[example]]
5253
name = "chat"

iroh-gossip/src/net.rs

Lines changed: 106 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
33
use anyhow::{anyhow, Context};
44
use bytes::{Bytes, BytesMut};
5-
use futures_lite::stream::Stream;
5+
use futures_lite::{stream::Stream, StreamExt};
6+
use futures_util::future::FutureExt;
67
use genawaiter::sync::{Co, Gen};
78
use iroh_net::{
8-
dialer::Dialer,
9-
endpoint::{get_remote_node_id, Connection},
9+
dialer::{ConnDirection, ConnManager, NewConnection},
10+
endpoint::Connection,
1011
key::PublicKey,
1112
AddrInfo, Endpoint, NodeAddr,
1213
};
@@ -15,7 +16,7 @@ use rand_core::SeedableRng;
1516
use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc, task::Poll, time::Instant};
1617
use tokio::{
1718
sync::{broadcast, mpsc, oneshot},
18-
task::JoinHandle,
19+
task::{JoinHandle, JoinSet},
1920
};
2021
use tracing::{debug, error_span, trace, warn, Instrument};
2122

@@ -82,7 +83,7 @@ impl Gossip {
8283
/// Spawn a gossip actor and get a handle for it
8384
pub fn from_endpoint(endpoint: Endpoint, config: proto::Config, my_addr: &AddrInfo) -> Self {
8485
let peer_id = endpoint.node_id();
85-
let dialer = Dialer::new(endpoint.clone());
86+
let conn_manager = ConnManager::new(endpoint.clone(), GOSSIP_ALPN);
8687
let state = proto::State::new(
8788
peer_id,
8889
encode_peer_data(my_addr).unwrap(),
@@ -97,12 +98,12 @@ impl Gossip {
9798
let actor = Actor {
9899
endpoint,
99100
state,
100-
dialer,
101+
conn_manager,
102+
conn_tasks: Default::default(),
101103
to_actor_rx,
102104
in_event_rx,
103105
in_event_tx,
104106
on_endpoints_rx,
105-
conns: Default::default(),
106107
conn_send_tx: Default::default(),
107108
pending_sends: Default::default(),
108109
timers: Timers::new(),
@@ -231,9 +232,7 @@ impl Gossip {
231232
///
232233
/// Make sure to check the ALPN protocol yourself before passing the connection.
233234
pub async fn handle_connection(&self, conn: Connection) -> anyhow::Result<()> {
234-
let peer_id = get_remote_node_id(&conn)?;
235-
self.send(ToActor::ConnIncoming(peer_id, ConnOrigin::Accept, conn))
236-
.await?;
235+
self.send(ToActor::AcceptConn(conn)).await?;
237236
Ok(())
238237
}
239238

@@ -283,19 +282,13 @@ impl Future for JoinTopicFut {
283282
}
284283
}
285284

286-
/// Whether a connection is initiated by us (Dial) or by the remote peer (Accept)
287-
#[derive(Debug)]
288-
enum ConnOrigin {
289-
Accept,
290-
Dial,
291-
}
292-
293285
/// Input messages for the gossip [`Actor`].
294286
#[derive(derive_more::Debug)]
295287
enum ToActor {
296288
/// Handle a new QUIC connection, either from accept (external to the actor) or from connect
297289
/// (happens internally in the actor).
298-
ConnIncoming(PublicKey, ConnOrigin, #[debug(skip)] Connection),
290+
// ConnIncoming(NewConnection),
291+
AcceptConn(iroh_net::endpoint::Connection),
299292
/// Join a topic with a list of peers. Reply with oneshot once at least one peer joined.
300293
Join(
301294
TopicId,
@@ -330,7 +323,7 @@ struct Actor {
330323
state: proto::State<PublicKey, StdRng>,
331324
endpoint: Endpoint,
332325
/// Dial machine to connect to peers
333-
dialer: Dialer,
326+
conn_manager: ConnManager,
334327
/// Input messages to the actor
335328
to_actor_rx: mpsc::Receiver<ToActor>,
336329
/// Sender for the state input (cloned into the connection loops)
@@ -341,10 +334,9 @@ struct Actor {
341334
on_endpoints_rx: mpsc::Receiver<Vec<iroh_net::config::Endpoint>>,
342335
/// Queued timers
343336
timers: Timers<Timer>,
344-
/// Currently opened quinn connections to peers
345-
conns: HashMap<PublicKey, Connection>,
346337
/// Channels to send outbound messages into the connection loops
347338
conn_send_tx: HashMap<PublicKey, mpsc::Sender<ProtoMessage>>,
339+
conn_tasks: JoinSet<(PublicKey, anyhow::Result<()>)>,
348340
/// Queued messages that were to be sent before a dial completed
349341
pending_sends: HashMap<PublicKey, Vec<ProtoMessage>>,
350342
/// Broadcast senders for active topic subscriptions from the application
@@ -353,6 +345,12 @@ struct Actor {
353345
subscribers_all: Option<broadcast::Sender<(TopicId, Event)>>,
354346
}
355347

348+
impl Drop for Actor {
349+
fn drop(&mut self) {
350+
self.conn_tasks.abort_all();
351+
}
352+
}
353+
356354
impl Actor {
357355
pub async fn run(mut self) -> anyhow::Result<()> {
358356
let mut i = 0;
@@ -384,15 +382,31 @@ impl Actor {
384382
}
385383
}
386384
}
387-
(peer_id, res) = self.dialer.next_conn() => {
388-
trace!(?i, "tick: dialer");
385+
Some(new_conn) = self.conn_manager.next() => {
386+
trace!(?i, "tick: conn_manager");
387+
let node_id = new_conn.node_id;
388+
if let Err(err) = self.handle_new_connection(new_conn).await {
389+
warn!(peer=%node_id.fmt_short(), ?err, "failed to handle new connection");
390+
self.conn_manager.remove(&node_id);
391+
self.conn_send_tx.remove(&node_id);
392+
}
393+
}
394+
Some(res) = self.conn_tasks.join_next(), if !self.conn_tasks.is_empty() => {
389395
match res {
390-
Ok(conn) => {
391-
debug!(peer = ?peer_id, "dial successful");
392-
self.handle_to_actor_msg(ToActor::ConnIncoming(peer_id, ConnOrigin::Dial, conn), Instant::now()).await.context("dialer.next -> conn -> handle_to_actor_msg")?;
393-
}
394-
Err(err) => {
395-
warn!(peer = ?peer_id, "dial failed: {err}");
396+
Err(err) if !err.is_cancelled() => warn!(?err, "connection loop panicked"),
397+
Err(_err) => {},
398+
Ok((node_id, result)) => {
399+
self.conn_manager.remove(&node_id);
400+
self.conn_send_tx.remove(&node_id);
401+
self.handle_in_event(InEvent::PeerDisconnected(node_id), Instant::now()).await ?;
402+
match result {
403+
Ok(()) => {
404+
debug!(peer=%node_id.fmt_short(), "connection closed without error");
405+
}
406+
Err(err) => {
407+
debug!(peer=%node_id.fmt_short(), "connection closed with error {err:?}");
408+
}
409+
}
396410
}
397411
}
398412
}
@@ -421,38 +435,9 @@ impl Actor {
421435
async fn handle_to_actor_msg(&mut self, msg: ToActor, now: Instant) -> anyhow::Result<()> {
422436
trace!("handle to_actor {msg:?}");
423437
match msg {
424-
ToActor::ConnIncoming(peer_id, origin, conn) => {
425-
self.conns.insert(peer_id, conn.clone());
426-
self.dialer.abort_dial(&peer_id);
427-
let (send_tx, send_rx) = mpsc::channel(SEND_QUEUE_CAP);
428-
self.conn_send_tx.insert(peer_id, send_tx.clone());
429-
430-
// Spawn a task for this connection
431-
let in_event_tx = self.in_event_tx.clone();
432-
tokio::spawn(
433-
async move {
434-
debug!("connection established");
435-
match connection_loop(peer_id, conn, origin, send_rx, &in_event_tx).await {
436-
Ok(()) => {
437-
debug!("connection closed without error")
438-
}
439-
Err(err) => {
440-
debug!("connection closed with error {err:?}")
441-
}
442-
}
443-
in_event_tx
444-
.send(InEvent::PeerDisconnected(peer_id))
445-
.await
446-
.ok();
447-
}
448-
.instrument(error_span!("gossip_conn", peer = %peer_id.fmt_short())),
449-
);
450-
451-
// Forward queued pending sends
452-
if let Some(send_queue) = self.pending_sends.remove(&peer_id) {
453-
for msg in send_queue {
454-
send_tx.send(msg).await?;
455-
}
438+
ToActor::AcceptConn(conn) => {
439+
if let Err(err) = self.conn_manager.push_accept(conn) {
440+
warn!(?err, "failed to accept connection");
456441
}
457442
}
458443
ToActor::Join(topic_id, peers, reply) => {
@@ -502,9 +487,6 @@ impl Actor {
502487
} else {
503488
debug!("handle in_event {event:?}");
504489
};
505-
if let InEvent::PeerDisconnected(peer) = &event {
506-
self.conn_send_tx.remove(peer);
507-
}
508490
let out = self.state.handle(event, now);
509491
for event in out {
510492
if matches!(event, OutEvent::ScheduleTimer(_, _)) {
@@ -518,10 +500,13 @@ impl Actor {
518500
if let Err(_err) = send.send(message).await {
519501
warn!("conn receiver for {peer_id:?} dropped");
520502
self.conn_send_tx.remove(&peer_id);
503+
self.conn_manager.remove(&peer_id);
521504
}
522505
} else {
523-
debug!(peer = ?peer_id, "dial");
524-
self.dialer.queue_dial(peer_id, GOSSIP_ALPN);
506+
if !self.conn_manager.is_dialing(&peer_id) {
507+
debug!(peer = ?peer_id, "dial");
508+
self.conn_manager.dial(peer_id);
509+
}
525510
// TODO: Enforce max length
526511
self.pending_sends.entry(peer_id).or_default().push(message);
527512
}
@@ -544,12 +529,11 @@ impl Actor {
544529
self.timers.insert(now + delay, timer);
545530
}
546531
OutEvent::DisconnectPeer(peer) => {
547-
if let Some(conn) = self.conns.remove(&peer) {
548-
conn.close(0u8.into(), b"close from disconnect");
549-
}
550532
self.conn_send_tx.remove(&peer);
551533
self.pending_sends.remove(&peer);
552-
self.dialer.abort_dial(&peer);
534+
if let Some(conn) = self.conn_manager.remove(&peer) {
535+
conn.close(0u8.into(), b"close from disconnect");
536+
}
553537
}
554538
OutEvent::PeerData(node_id, data) => match decode_peer_data(&data) {
555539
Err(err) => warn!("Failed to decode {data:?} from {node_id}: {err}"),
@@ -566,6 +550,42 @@ impl Actor {
566550
Ok(())
567551
}
568552

553+
async fn handle_new_connection(&mut self, new_conn: NewConnection) -> anyhow::Result<()> {
554+
let NewConnection {
555+
conn,
556+
node_id: peer_id,
557+
direction,
558+
} = new_conn;
559+
match conn {
560+
Ok(conn) => {
561+
let (send_tx, send_rx) = mpsc::channel(SEND_QUEUE_CAP);
562+
self.conn_send_tx.insert(peer_id, send_tx.clone());
563+
564+
// Spawn a task for this connection
565+
let pending_sends = self.pending_sends.remove(&peer_id);
566+
let in_event_tx = self.in_event_tx.clone();
567+
debug!(peer=%peer_id.fmt_short(), ?direction, "connection established");
568+
self.conn_tasks.spawn(
569+
connection_loop(
570+
peer_id,
571+
conn,
572+
direction,
573+
send_rx,
574+
in_event_tx,
575+
pending_sends,
576+
)
577+
.map(move |r| (peer_id, r))
578+
.instrument(error_span!("gossip_conn", peer = %peer_id.fmt_short())),
579+
);
580+
}
581+
Err(err) => {
582+
warn!(peer=%peer_id.fmt_short(), "connecting to node failed: {err:?}");
583+
}
584+
}
585+
586+
Ok(())
587+
}
588+
569589
fn subscribe_all(&mut self) -> broadcast::Receiver<(TopicId, Event)> {
570590
if let Some(tx) = self.subscribers_all.as_mut() {
571591
tx.subscribe()
@@ -602,20 +622,30 @@ async fn wait_for_neighbor_up(mut sub: broadcast::Receiver<Event>) -> anyhow::Re
602622
async fn connection_loop(
603623
from: PublicKey,
604624
conn: Connection,
605-
origin: ConnOrigin,
625+
direction: ConnDirection,
606626
mut send_rx: mpsc::Receiver<ProtoMessage>,
607-
in_event_tx: &mpsc::Sender<InEvent>,
627+
in_event_tx: mpsc::Sender<InEvent>,
628+
mut pending_sends: Option<Vec<ProtoMessage>>,
608629
) -> anyhow::Result<()> {
609-
let (mut send, mut recv) = match origin {
610-
ConnOrigin::Accept => conn.accept_bi().await?,
611-
ConnOrigin::Dial => conn.open_bi().await?,
630+
let (mut send, mut recv) = match direction {
631+
ConnDirection::Accept => conn.accept_bi().await?,
632+
ConnDirection::Dial => conn.open_bi().await?,
612633
};
613634
let mut send_buf = BytesMut::new();
614635
let mut recv_buf = BytesMut::new();
636+
637+
// Forward queued pending sends
638+
if let Some(mut send_queue) = pending_sends.take() {
639+
for msg in send_queue.drain(..) {
640+
write_message(&mut send, &mut send_buf, &msg).await?;
641+
}
642+
}
643+
644+
// loop over sending and receiving messages
615645
loop {
616646
tokio::select! {
617647
biased;
618-
msg = send_rx.recv() => {
648+
msg = send_rx.recv(), if !send_rx.is_closed() => {
619649
match msg {
620650
None => break,
621651
Some(msg) => write_message(&mut send, &mut send_buf, &msg).await?,

0 commit comments

Comments
 (0)