Skip to content

Commit 0181e86

Browse files
fix(mdns): move IO off main task
Resolves: #2591. Pull-Request: #4623.
1 parent d26e04a commit 0181e86

File tree

4 files changed

+171
-93
lines changed

4 files changed

+171
-93
lines changed

protocols/mdns/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
## 0.45.0 - unreleased
22

3+
- Don't perform IO in `Behaviour::poll`.
4+
See [PR 4623](https://github.com/libp2p/rust-libp2p/pull/4623).
35

46
## 0.44.0
57

protocols/mdns/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ keywords = ["peer-to-peer", "libp2p", "networking"]
1111
categories = ["network-programming", "asynchronous"]
1212

1313
[dependencies]
14+
async-std = { version = "1.12.0", optional = true }
1415
async-io = { version = "1.13.0", optional = true }
1516
data-encoding = "2.4.0"
1617
futures = "0.3.28"
@@ -28,7 +29,7 @@ void = "1.0.2"
2829

2930
[features]
3031
tokio = ["dep:tokio", "if-watch/tokio"]
31-
async-io = ["dep:async-io", "if-watch/smol"]
32+
async-io = ["dep:async-io", "dep:async-std", "if-watch/smol"]
3233

3334
[dev-dependencies]
3435
async-std = { version = "1.9.0", features = ["attributes"] }

protocols/mdns/src/behaviour.rs

Lines changed: 88 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ mod timer;
2525
use self::iface::InterfaceState;
2626
use crate::behaviour::{socket::AsyncSocket, timer::Builder};
2727
use crate::Config;
28-
use futures::Stream;
28+
use futures::channel::mpsc;
29+
use futures::{Stream, StreamExt};
2930
use if_watch::IfEvent;
3031
use libp2p_core::{Endpoint, Multiaddr};
3132
use libp2p_identity::PeerId;
@@ -36,6 +37,8 @@ use libp2p_swarm::{
3637
};
3738
use smallvec::SmallVec;
3839
use std::collections::hash_map::{Entry, HashMap};
40+
use std::future::Future;
41+
use std::sync::{Arc, RwLock};
3942
use std::{cmp, fmt, io, net::IpAddr, pin::Pin, task::Context, task::Poll, time::Instant};
4043

4144
/// An abstraction to allow for compatibility with various async runtimes.
@@ -47,16 +50,27 @@ pub trait Provider: 'static {
4750
/// The IfWatcher type.
4851
type Watcher: Stream<Item = std::io::Result<IfEvent>> + fmt::Debug + Unpin;
4952

53+
type TaskHandle: Abort;
54+
5055
/// Create a new instance of the `IfWatcher` type.
5156
fn new_watcher() -> Result<Self::Watcher, std::io::Error>;
57+
58+
fn spawn(task: impl Future<Output = ()> + Send + 'static) -> Self::TaskHandle;
59+
}
60+
61+
#[allow(unreachable_pub)] // Not re-exported.
62+
pub trait Abort {
63+
fn abort(self);
5264
}
5365

5466
/// The type of a [`Behaviour`] using the `async-io` implementation.
5567
#[cfg(feature = "async-io")]
5668
pub mod async_io {
5769
use super::Provider;
58-
use crate::behaviour::{socket::asio::AsyncUdpSocket, timer::asio::AsyncTimer};
70+
use crate::behaviour::{socket::asio::AsyncUdpSocket, timer::asio::AsyncTimer, Abort};
71+
use async_std::task::JoinHandle;
5972
use if_watch::smol::IfWatcher;
73+
use std::future::Future;
6074

6175
#[doc(hidden)]
6276
pub enum AsyncIo {}
@@ -65,10 +79,21 @@ pub mod async_io {
6579
type Socket = AsyncUdpSocket;
6680
type Timer = AsyncTimer;
6781
type Watcher = IfWatcher;
82+
type TaskHandle = JoinHandle<()>;
6883

6984
fn new_watcher() -> Result<Self::Watcher, std::io::Error> {
7085
IfWatcher::new()
7186
}
87+
88+
fn spawn(task: impl Future<Output = ()> + Send + 'static) -> JoinHandle<()> {
89+
async_std::task::spawn(task)
90+
}
91+
}
92+
93+
impl Abort for JoinHandle<()> {
94+
fn abort(self) {
95+
async_std::task::spawn(self.cancel());
96+
}
7297
}
7398

7499
pub type Behaviour = super::Behaviour<AsyncIo>;
@@ -78,8 +103,10 @@ pub mod async_io {
78103
#[cfg(feature = "tokio")]
79104
pub mod tokio {
80105
use super::Provider;
81-
use crate::behaviour::{socket::tokio::TokioUdpSocket, timer::tokio::TokioTimer};
106+
use crate::behaviour::{socket::tokio::TokioUdpSocket, timer::tokio::TokioTimer, Abort};
82107
use if_watch::tokio::IfWatcher;
108+
use std::future::Future;
109+
use tokio::task::JoinHandle;
83110

84111
#[doc(hidden)]
85112
pub enum Tokio {}
@@ -88,10 +115,21 @@ pub mod tokio {
88115
type Socket = TokioUdpSocket;
89116
type Timer = TokioTimer;
90117
type Watcher = IfWatcher;
118+
type TaskHandle = JoinHandle<()>;
91119

92120
fn new_watcher() -> Result<Self::Watcher, std::io::Error> {
93121
IfWatcher::new()
94122
}
123+
124+
fn spawn(task: impl Future<Output = ()> + Send + 'static) -> Self::TaskHandle {
125+
tokio::spawn(task)
126+
}
127+
}
128+
129+
impl Abort for JoinHandle<()> {
130+
fn abort(self) {
131+
JoinHandle::abort(&self)
132+
}
95133
}
96134

97135
pub type Behaviour = super::Behaviour<Tokio>;
@@ -110,8 +148,11 @@ where
110148
/// Iface watcher.
111149
if_watch: P::Watcher,
112150

113-
/// Mdns interface states.
114-
iface_states: HashMap<IpAddr, InterfaceState<P::Socket, P::Timer>>,
151+
/// Handles to tasks running the mDNS queries.
152+
if_tasks: HashMap<IpAddr, P::TaskHandle>,
153+
154+
query_response_receiver: mpsc::Receiver<(PeerId, Multiaddr, Instant)>,
155+
query_response_sender: mpsc::Sender<(PeerId, Multiaddr, Instant)>,
115156

116157
/// List of nodes that we have discovered, the address, and when their TTL expires.
117158
///
@@ -124,7 +165,11 @@ where
124165
/// `None` if `discovered_nodes` is empty.
125166
closest_expiration: Option<P::Timer>,
126167

127-
listen_addresses: ListenAddresses,
168+
/// The current set of listen addresses.
169+
///
170+
/// This is shared across all interface tasks using an [`RwLock`].
171+
/// The [`Behaviour`] updates this upon new [`FromSwarm`] events where as [`InterfaceState`]s read from it to answer inbound mDNS queries.
172+
listen_addresses: Arc<RwLock<ListenAddresses>>,
128173

129174
local_peer_id: PeerId,
130175
}
@@ -135,10 +180,14 @@ where
135180
{
136181
/// Builds a new `Mdns` behaviour.
137182
pub fn new(config: Config, local_peer_id: PeerId) -> io::Result<Self> {
183+
let (tx, rx) = mpsc::channel(10); // Chosen arbitrarily.
184+
138185
Ok(Self {
139186
config,
140187
if_watch: P::new_watcher()?,
141-
iface_states: Default::default(),
188+
if_tasks: Default::default(),
189+
query_response_receiver: rx,
190+
query_response_sender: tx,
142191
discovered_nodes: Default::default(),
143192
closest_expiration: Default::default(),
144193
listen_addresses: Default::default(),
@@ -147,6 +196,7 @@ where
147196
}
148197

149198
/// Returns true if the given `PeerId` is in the list of nodes discovered through mDNS.
199+
#[deprecated(note = "Use `discovered_nodes` iterator instead.")]
150200
pub fn has_node(&self, peer_id: &PeerId) -> bool {
151201
self.discovered_nodes().any(|p| p == peer_id)
152202
}
@@ -157,6 +207,7 @@ where
157207
}
158208

159209
/// Expires a node before the ttl.
210+
#[deprecated(note = "Unused API. Will be removed in the next release.")]
160211
pub fn expire_node(&mut self, peer_id: &PeerId) {
161212
let now = Instant::now();
162213
for (peer, _addr, expires) in &mut self.discovered_nodes {
@@ -225,28 +276,10 @@ where
225276
}
226277

227278
fn on_swarm_event(&mut self, event: FromSwarm<Self::ConnectionHandler>) {
228-
self.listen_addresses.on_swarm_event(&event);
229-
230-
match event {
231-
FromSwarm::NewListener(_) => {
232-
log::trace!("waking interface state because listening address changed");
233-
for iface in self.iface_states.values_mut() {
234-
iface.fire_timer();
235-
}
236-
}
237-
FromSwarm::ConnectionClosed(_)
238-
| FromSwarm::ConnectionEstablished(_)
239-
| FromSwarm::DialFailure(_)
240-
| FromSwarm::AddressChange(_)
241-
| FromSwarm::ListenFailure(_)
242-
| FromSwarm::NewListenAddr(_)
243-
| FromSwarm::ExpiredListenAddr(_)
244-
| FromSwarm::ListenerError(_)
245-
| FromSwarm::ListenerClosed(_)
246-
| FromSwarm::NewExternalAddrCandidate(_)
247-
| FromSwarm::ExternalAddrExpired(_)
248-
| FromSwarm::ExternalAddrConfirmed(_) => {}
249-
}
279+
self.listen_addresses
280+
.write()
281+
.unwrap_or_else(|e| e.into_inner())
282+
.on_swarm_event(&event);
250283
}
251284

252285
fn poll(
@@ -267,43 +300,50 @@ where
267300
{
268301
continue;
269302
}
270-
if let Entry::Vacant(e) = self.iface_states.entry(addr) {
271-
match InterfaceState::new(addr, self.config.clone(), self.local_peer_id) {
303+
if let Entry::Vacant(e) = self.if_tasks.entry(addr) {
304+
match InterfaceState::<P::Socket, P::Timer>::new(
305+
addr,
306+
self.config.clone(),
307+
self.local_peer_id,
308+
self.listen_addresses.clone(),
309+
self.query_response_sender.clone(),
310+
) {
272311
Ok(iface_state) => {
273-
e.insert(iface_state);
312+
e.insert(P::spawn(iface_state));
274313
}
275314
Err(err) => log::error!("failed to create `InterfaceState`: {}", err),
276315
}
277316
}
278317
}
279318
Ok(IfEvent::Down(inet)) => {
280-
if self.iface_states.contains_key(&inet.addr()) {
319+
if let Some(handle) = self.if_tasks.remove(&inet.addr()) {
281320
log::info!("dropping instance {}", inet.addr());
282-
self.iface_states.remove(&inet.addr());
321+
322+
handle.abort();
283323
}
284324
}
285325
Err(err) => log::error!("if watch returned an error: {}", err),
286326
}
287327
}
288328
// Emit discovered event.
289329
let mut discovered = Vec::new();
290-
for iface_state in self.iface_states.values_mut() {
291-
while let Poll::Ready((peer, addr, expiration)) =
292-
iface_state.poll(cx, &self.listen_addresses)
330+
331+
while let Poll::Ready(Some((peer, addr, expiration))) =
332+
self.query_response_receiver.poll_next_unpin(cx)
333+
{
334+
if let Some((_, _, cur_expires)) = self
335+
.discovered_nodes
336+
.iter_mut()
337+
.find(|(p, a, _)| *p == peer && *a == addr)
293338
{
294-
if let Some((_, _, cur_expires)) = self
295-
.discovered_nodes
296-
.iter_mut()
297-
.find(|(p, a, _)| *p == peer && *a == addr)
298-
{
299-
*cur_expires = cmp::max(*cur_expires, expiration);
300-
} else {
301-
log::info!("discovered: {} {}", peer, addr);
302-
self.discovered_nodes.push((peer, addr.clone(), expiration));
303-
discovered.push((peer, addr));
304-
}
339+
*cur_expires = cmp::max(*cur_expires, expiration);
340+
} else {
341+
log::info!("discovered: {} {}", peer, addr);
342+
self.discovered_nodes.push((peer, addr.clone(), expiration));
343+
discovered.push((peer, addr));
305344
}
306345
}
346+
307347
if !discovered.is_empty() {
308348
let event = Event::Discovered(discovered);
309349
return Poll::Ready(ToSwarm::GenerateEvent(event));

0 commit comments

Comments
 (0)