Skip to content

Commit 52c725b

Browse files
fix: avoid race condition between pending frames and closing stream (#156)
Currently, we have a `garbage_collect` function that checks whether any of our streams have been dropped. This can cause a race condition where the channel between a `Stream` and the `Connection` still has pending frames for a stream but dropping a stream causes us to already send a `FIN` flag for the stream. We fix this by maintaining a single channel for each stream. When a stream gets dropped, the `Receiver` becomes disconnected. We use this information to queue the correct frame (`FIN` vs `RST`) into the buffer. At this point, all previous frames have already been processed and the race condition is thus not present. Additionally, this also allows us to implement `Stream::poll_flush` by forwarding to the underlying `Sender`. Note that at present day, this only checks whether there is _space_ in the channel, not whether the items have been emitted by the `Receiver`. We have a PR upstream that might fix this: rust-lang/futures-rs#2746 Fixes: #117.
1 parent 88ed4df commit 52c725b

File tree

8 files changed

+212
-167
lines changed

8 files changed

+212
-167
lines changed

test-harness/Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,3 @@ log = "0.4.17"
1616
[dev-dependencies]
1717
env_logger = "0.10"
1818
constrained-connection = "0.1"
19-

yamux/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ nohash-hasher = "0.2"
1616
parking_lot = "0.12"
1717
rand = "0.8.3"
1818
static_assertions = "1"
19+
pin-project = "1.1.0"
1920

2021
[dev-dependencies]
2122
anyhow = "1"
@@ -26,6 +27,7 @@ quickcheck = "1.0"
2627
tokio = { version = "1.0", features = ["net", "rt-multi-thread", "macros", "time"] }
2728
tokio-util = { version = "0.7", features = ["compat"] }
2829
constrained-connection = "0.1"
30+
futures_ringbuf = "0.3.1"
2931

3032
[[bench]]
3133
name = "concurrent"

yamux/src/connection.rs

Lines changed: 108 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -96,16 +96,18 @@ use crate::{
9696
error::ConnectionError,
9797
frame::header::{self, Data, GoAway, Header, Ping, StreamId, Tag, WindowUpdate, CONNECTION_ID},
9898
frame::{self, Frame},
99-
Config, WindowUpdateMode, DEFAULT_CREDIT, MAX_COMMAND_BACKLOG,
99+
Config, WindowUpdateMode, DEFAULT_CREDIT,
100100
};
101101
use cleanup::Cleanup;
102102
use closing::Closing;
103+
use futures::stream::SelectAll;
103104
use futures::{channel::mpsc, future::Either, prelude::*, sink::SinkExt, stream::Fuse};
104105
use nohash_hasher::IntMap;
105106
use std::collections::VecDeque;
106-
use std::task::Context;
107+
use std::task::{Context, Waker};
107108
use std::{fmt, sync::Arc, task::Poll};
108109

110+
use crate::tagged_stream::TaggedStream;
109111
pub use stream::{Packet, State, Stream};
110112

111113
/// How the connection is used.
@@ -347,10 +349,11 @@ struct Active<T> {
347349
config: Arc<Config>,
348350
socket: Fuse<frame::Io<T>>,
349351
next_id: u32,
352+
350353
streams: IntMap<StreamId, Stream>,
351-
stream_sender: mpsc::Sender<StreamCommand>,
352-
stream_receiver: mpsc::Receiver<StreamCommand>,
353-
dropped_streams: Vec<StreamId>,
354+
stream_receivers: SelectAll<TaggedStream<StreamId, mpsc::Receiver<StreamCommand>>>,
355+
no_streams_waker: Option<Waker>,
356+
354357
pending_frames: VecDeque<Frame<()>>,
355358
}
356359

@@ -360,7 +363,7 @@ pub(crate) enum StreamCommand {
360363
/// A new frame should be sent to the remote.
361364
SendFrame(Frame<Either<Data, WindowUpdate>>),
362365
/// Close a stream.
363-
CloseStream { id: StreamId, ack: bool },
366+
CloseStream { ack: bool },
364367
}
365368

366369
/// Possible actions as a result of incoming frame handling.
@@ -408,28 +411,26 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
408411
fn new(socket: T, cfg: Config, mode: Mode) -> Self {
409412
let id = Id::random();
410413
log::debug!("new connection: {} ({:?})", id, mode);
411-
let (stream_sender, stream_receiver) = mpsc::channel(MAX_COMMAND_BACKLOG);
412414
let socket = frame::Io::new(id, socket, cfg.max_buffer_size).fuse();
413415
Active {
414416
id,
415417
mode,
416418
config: Arc::new(cfg),
417419
socket,
418420
streams: IntMap::default(),
419-
stream_sender,
420-
stream_receiver,
421+
stream_receivers: SelectAll::default(),
422+
no_streams_waker: None,
421423
next_id: match mode {
422424
Mode::Client => 1,
423425
Mode::Server => 2,
424426
},
425-
dropped_streams: Vec::new(),
426427
pending_frames: VecDeque::default(),
427428
}
428429
}
429430

430431
/// Gracefully close the connection to the remote.
431432
fn close(self) -> Closing<T> {
432-
Closing::new(self.stream_receiver, self.pending_frames, self.socket)
433+
Closing::new(self.stream_receivers, self.pending_frames, self.socket)
433434
}
434435

435436
/// Cleanup all our resources.
@@ -438,13 +439,11 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
438439
fn cleanup(mut self, error: ConnectionError) -> Cleanup {
439440
self.drop_all_streams();
440441

441-
Cleanup::new(self.stream_receiver, error)
442+
Cleanup::new(self.stream_receivers, error)
442443
}
443444

444445
fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<Stream>> {
445446
loop {
446-
self.garbage_collect();
447-
448447
if self.socket.poll_ready_unpin(cx).is_ready() {
449448
if let Some(frame) = self.pending_frames.pop_front() {
450449
self.socket.start_send_unpin(frame)?;
@@ -457,17 +456,21 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
457456
Poll::Pending => {}
458457
}
459458

460-
match self.stream_receiver.poll_next_unpin(cx) {
461-
Poll::Ready(Some(StreamCommand::SendFrame(frame))) => {
462-
self.on_send_frame(frame);
459+
match self.stream_receivers.poll_next_unpin(cx) {
460+
Poll::Ready(Some((_, Some(StreamCommand::SendFrame(frame))))) => {
461+
self.on_send_frame(frame.into());
463462
continue;
464463
}
465-
Poll::Ready(Some(StreamCommand::CloseStream { id, ack })) => {
464+
Poll::Ready(Some((id, Some(StreamCommand::CloseStream { ack })))) => {
466465
self.on_close_stream(id, ack);
467466
continue;
468467
}
468+
Poll::Ready(Some((id, None))) => {
469+
self.on_drop_stream(id);
470+
continue;
471+
}
469472
Poll::Ready(None) => {
470-
debug_assert!(false, "Only closed during shutdown")
473+
self.no_streams_waker = Some(cx.waker().clone());
471474
}
472475
Poll::Pending => {}
473476
}
@@ -508,16 +511,11 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
508511
self.pending_frames.push_back(frame.into());
509512
}
510513

511-
let stream = {
512-
let config = self.config.clone();
513-
let sender = self.stream_sender.clone();
514-
let window = self.config.receive_window;
515-
let mut stream = Stream::new(id, self.id, config, window, DEFAULT_CREDIT, sender);
516-
if extra_credit == 0 {
517-
stream.set_flag(stream::Flag::Syn)
518-
}
519-
stream
520-
};
514+
let mut stream = self.make_new_stream(id, self.config.receive_window, DEFAULT_CREDIT);
515+
516+
if extra_credit == 0 {
517+
stream.set_flag(stream::Flag::Syn)
518+
}
521519

522520
log::debug!("{}: new outbound {} of {}", self.id, stream, self);
523521
self.streams.insert(id, stream.clone());
@@ -541,6 +539,69 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
541539
.push_back(Frame::close_stream(id, ack).into());
542540
}
543541

542+
fn on_drop_stream(&mut self, id: StreamId) {
543+
let stream = self.streams.remove(&id).expect("stream not found");
544+
545+
log::trace!("{}: removing dropped {}", self.id, stream);
546+
let stream_id = stream.id();
547+
let frame = {
548+
let mut shared = stream.shared();
549+
let frame = match shared.update_state(self.id, stream_id, State::Closed) {
550+
// The stream was dropped without calling `poll_close`.
551+
// We reset the stream to inform the remote of the closure.
552+
State::Open => {
553+
let mut header = Header::data(stream_id, 0);
554+
header.rst();
555+
Some(Frame::new(header))
556+
}
557+
// The stream was dropped without calling `poll_close`.
558+
// We have already received a FIN from remote and send one
559+
// back which closes the stream for good.
560+
State::RecvClosed => {
561+
let mut header = Header::data(stream_id, 0);
562+
header.fin();
563+
Some(Frame::new(header))
564+
}
565+
// The stream was properly closed. We already sent our FIN frame.
566+
// The remote may be out of credit though and blocked on
567+
// writing more data. We may need to reset the stream.
568+
State::SendClosed => {
569+
if self.config.window_update_mode == WindowUpdateMode::OnRead
570+
&& shared.window == 0
571+
{
572+
// The remote may be waiting for a window update
573+
// which we will never send, so reset the stream now.
574+
let mut header = Header::data(stream_id, 0);
575+
header.rst();
576+
Some(Frame::new(header))
577+
} else {
578+
// The remote has either still credit or will be given more
579+
// (due to an enqueued window update or because the update
580+
// mode is `OnReceive`) or we already have inbound frames in
581+
// the socket buffer which will be processed later. In any
582+
// case we will reply with an RST in `Connection::on_data`
583+
// because the stream will no longer be known.
584+
None
585+
}
586+
}
587+
// The stream was properly closed. We already have sent our FIN frame. The
588+
// remote end has already done so in the past.
589+
State::Closed => None,
590+
};
591+
if let Some(w) = shared.reader.take() {
592+
w.wake()
593+
}
594+
if let Some(w) = shared.writer.take() {
595+
w.wake()
596+
}
597+
frame
598+
};
599+
if let Some(f) = frame {
600+
log::trace!("{}/{}: sending: {}", self.id, stream_id, f.header());
601+
self.pending_frames.push_back(f.into());
602+
}
603+
}
604+
544605
/// Process the result of reading from the socket.
545606
///
546607
/// Unless `frame` is `Ok(Some(_))` we will assume the connection got closed
@@ -628,12 +689,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
628689
log::error!("{}: maximum number of streams reached", self.id);
629690
return Action::Terminate(Frame::internal_error());
630691
}
631-
let mut stream = {
632-
let config = self.config.clone();
633-
let credit = DEFAULT_CREDIT;
634-
let sender = self.stream_sender.clone();
635-
Stream::new(stream_id, self.id, config, credit, credit, sender)
636-
};
692+
let mut stream = self.make_new_stream(stream_id, DEFAULT_CREDIT, DEFAULT_CREDIT);
637693
let mut window_update = None;
638694
{
639695
let mut shared = stream.shared();
@@ -748,15 +804,11 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
748804
log::error!("{}: maximum number of streams reached", self.id);
749805
return Action::Terminate(Frame::protocol_error());
750806
}
751-
let stream = {
752-
let credit = frame.header().credit() + DEFAULT_CREDIT;
753-
let config = self.config.clone();
754-
let sender = self.stream_sender.clone();
755-
let mut stream =
756-
Stream::new(stream_id, self.id, config, DEFAULT_CREDIT, credit, sender);
757-
stream.set_flag(stream::Flag::Ack);
758-
stream
759-
};
807+
808+
let credit = frame.header().credit() + DEFAULT_CREDIT;
809+
let mut stream = self.make_new_stream(stream_id, DEFAULT_CREDIT, credit);
810+
stream.set_flag(stream::Flag::Ack);
811+
760812
if is_finish {
761813
stream
762814
.shared()
@@ -821,6 +873,18 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
821873
Action::None
822874
}
823875

876+
fn make_new_stream(&mut self, id: StreamId, window: u32, credit: u32) -> Stream {
877+
let config = self.config.clone();
878+
879+
let (sender, receiver) = mpsc::channel(10); // 10 is an arbitrary number.
880+
self.stream_receivers.push(TaggedStream::new(id, receiver));
881+
if let Some(waker) = self.no_streams_waker.take() {
882+
waker.wake();
883+
}
884+
885+
Stream::new(id, self.id, config, window, credit, sender)
886+
}
887+
824888
fn next_stream_id(&mut self) -> Result<StreamId> {
825889
let proposed = StreamId::new(self.next_id);
826890
self.next_id = self
@@ -844,79 +908,6 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
844908
Mode::Server => id.is_client(),
845909
}
846910
}
847-
848-
/// Remove stale streams and create necessary messages to be sent to the remote.
849-
fn garbage_collect(&mut self) {
850-
let conn_id = self.id;
851-
let win_update_mode = self.config.window_update_mode;
852-
for stream in self.streams.values_mut() {
853-
if stream.strong_count() > 1 {
854-
continue;
855-
}
856-
log::trace!("{}: removing dropped {}", conn_id, stream);
857-
let stream_id = stream.id();
858-
let frame = {
859-
let mut shared = stream.shared();
860-
let frame = match shared.update_state(conn_id, stream_id, State::Closed) {
861-
// The stream was dropped without calling `poll_close`.
862-
// We reset the stream to inform the remote of the closure.
863-
State::Open => {
864-
let mut header = Header::data(stream_id, 0);
865-
header.rst();
866-
Some(Frame::new(header))
867-
}
868-
// The stream was dropped without calling `poll_close`.
869-
// We have already received a FIN from remote and send one
870-
// back which closes the stream for good.
871-
State::RecvClosed => {
872-
let mut header = Header::data(stream_id, 0);
873-
header.fin();
874-
Some(Frame::new(header))
875-
}
876-
// The stream was properly closed. We either already have
877-
// or will at some later point send our FIN frame.
878-
// The remote may be out of credit though and blocked on
879-
// writing more data. We may need to reset the stream.
880-
State::SendClosed => {
881-
if win_update_mode == WindowUpdateMode::OnRead && shared.window == 0 {
882-
// The remote may be waiting for a window update
883-
// which we will never send, so reset the stream now.
884-
let mut header = Header::data(stream_id, 0);
885-
header.rst();
886-
Some(Frame::new(header))
887-
} else {
888-
// The remote has either still credit or will be given more
889-
// (due to an enqueued window update or because the update
890-
// mode is `OnReceive`) or we already have inbound frames in
891-
// the socket buffer which will be processed later. In any
892-
// case we will reply with an RST in `Connection::on_data`
893-
// because the stream will no longer be known.
894-
None
895-
}
896-
}
897-
// The stream was properly closed. We either already have
898-
// or will at some later point send our FIN frame. The
899-
// remote end has already done so in the past.
900-
State::Closed => None,
901-
};
902-
if let Some(w) = shared.reader.take() {
903-
w.wake()
904-
}
905-
if let Some(w) = shared.writer.take() {
906-
w.wake()
907-
}
908-
frame
909-
};
910-
if let Some(f) = frame {
911-
log::trace!("{}/{}: sending: {}", self.id, stream_id, f.header());
912-
self.pending_frames.push_back(f.into());
913-
}
914-
self.dropped_streams.push(stream_id)
915-
}
916-
for id in self.dropped_streams.drain(..) {
917-
self.streams.remove(&id);
918-
}
919-
}
920911
}
921912

922913
impl<T> Active<T> {

0 commit comments

Comments
 (0)