diff --git a/Cargo.toml b/Cargo.toml
index fbf3f11..e6422d9 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -55,6 +55,8 @@ proptest-derive = "0.5"
 rand = "0.8.5"
 tokio = { version = "1.24.2", default-features = false, features = ["io-util", "macros", "rt-multi-thread", "io-std"] }
 tokio-util = { version = "0.7", default-features = false, features = ["io"] }
+tracing = "0.1.40"
+tracing-subscriber = "0.3.18"
 
 [[test]]
 name = "brotli"
diff --git a/src/tokio/write/generic/encoder.rs b/src/tokio/write/generic/encoder.rs
index f5a83aa..c26dc24 100644
--- a/src/tokio/write/generic/encoder.rs
+++ b/src/tokio/write/generic/encoder.rs
@@ -13,20 +13,13 @@ use futures_core::ready;
 use pin_project_lite::pin_project;
 use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
 
-#[derive(Debug)]
-enum State {
-    Encoding,
-    Finishing,
-    Done,
-}
-
 pin_project! {
     #[derive(Debug)]
     pub struct Encoder<W, E> {
         #[pin]
         writer: BufWriter<W>,
         encoder: E,
-        state: State,
+        finished: bool,
     }
 }
 
@@ -35,7 +28,7 @@ impl<W: AsyncWrite, E: Encode> Encoder<W, E> {
         Self {
             writer: BufWriter::new(writer),
             encoder,
-            state: State::Encoding,
+            finished: false,
         }
     }
 }
@@ -62,97 +55,6 @@ impl<W, E> Encoder<W, E> {
     }
 }
 
-impl<W: AsyncWrite, E: Encode> Encoder<W, E> {
-    fn do_poll_write(
-        self: Pin<&mut Self>,
-        cx: &mut Context<'_>,
-        input: &mut PartialBuffer<&[u8]>,
-    ) -> Poll<io::Result<()>> {
-        let mut this = self.project();
-
-        loop {
-            let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
-            let mut output = PartialBuffer::new(output);
-
-            *this.state = match this.state {
-                State::Encoding => {
-                    this.encoder.encode(input, &mut output)?;
-                    State::Encoding
-                }
-
-                State::Finishing | State::Done => {
-                    return Poll::Ready(Err(io::Error::new(
-                        io::ErrorKind::Other,
-                        "Write after shutdown",
-                    )))
-                }
-            };
-
-            let produced = output.written().len();
-            this.writer.as_mut().produce(produced);
-
-            if input.unwritten().is_empty() {
-                return Poll::Ready(Ok(()));
-            }
-        }
-    }
-
-    fn do_poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
-        let mut this = self.project();
-
-        loop {
-            let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
-            let mut output = PartialBuffer::new(output);
-
-            let done = match this.state {
-                State::Encoding => this.encoder.flush(&mut output)?,
-
-                State::Finishing | State::Done => {
-                    return Poll::Ready(Err(io::Error::new(
-                        io::ErrorKind::Other,
-                        "Flush after shutdown",
-                    )))
-                }
-            };
-
-            let produced = output.written().len();
-            this.writer.as_mut().produce(produced);
-
-            if done {
-                return Poll::Ready(Ok(()));
-            }
-        }
-    }
-
-    fn do_poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
-        let mut this = self.project();
-
-        loop {
-            let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
-            let mut output = PartialBuffer::new(output);
-
-            *this.state = match this.state {
-                State::Encoding | State::Finishing => {
-                    if this.encoder.finish(&mut output)? {
-                        State::Done
-                    } else {
-                        State::Finishing
-                    }
-                }
-
-                State::Done => State::Done,
-            };
-
-            let produced = output.written().len();
-            this.writer.as_mut().produce(produced);
-
-            if let State::Done = this.state {
-                return Poll::Ready(Ok(()));
-            }
-        }
-    }
-}
-
 impl<W: AsyncWrite, E: Encode> AsyncWrite for Encoder<W, E> {
     fn poll_write(
         self: Pin<&mut Self>,
@@ -163,24 +65,58 @@ impl<W: AsyncWrite, E: Encode> AsyncWrite for Encoder<W, E> {
             return Poll::Ready(Ok(0));
         }
 
-        let mut input = PartialBuffer::new(buf);
+        let mut this = self.project();
+
+        let mut encodeme = PartialBuffer::new(buf);
 
-        match self.do_poll_write(cx, &mut input)? {
-            Poll::Pending if input.written().is_empty() => Poll::Pending,
-            _ => Poll::Ready(Ok(input.written().len())),
+        loop {
+            let space = match this.writer.as_mut().poll_partial_flush_buf(cx) {
+                Poll::Ready(space) => space?,
+                // Input already handed to the encoder must be reported as written:
+                // returning `Pending` here would make the caller resubmit those bytes.
+                Poll::Pending if encodeme.written().is_empty() => return Poll::Pending,
+                Poll::Pending => break,
+            };
+            let mut space = PartialBuffer::new(space);
+            this.encoder.encode(&mut encodeme, &mut space)?;
+            let bytes_encoded = space.written().len();
+            this.writer.as_mut().produce(bytes_encoded);
+            if encodeme.unwritten().is_empty() {
+                break;
+            }
         }
+
+        Poll::Ready(Ok(encodeme.written().len()))
     }
 
-    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
-        ready!(self.as_mut().do_poll_flush(cx))?;
-        ready!(self.project().writer.as_mut().poll_flush(cx))?;
+    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        let mut this = self.project();
+        loop {
+            let mut space =
+                PartialBuffer::new(ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?);
+            let flushed = this.encoder.flush(&mut space)?;
+            let bytes_encoded = space.written().len();
+            this.writer.as_mut().produce(bytes_encoded);
+            if flushed {
+                break;
+            }
+        }
-        Poll::Ready(Ok(()))
+        // Draining the encoder only fills the buffer; the writer itself must be flushed too.
+        this.writer.poll_flush(cx)
     }
 
-    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
-        ready!(self.as_mut().do_poll_shutdown(cx))?;
-        ready!(self.project().writer.as_mut().poll_shutdown(cx))?;
-        Poll::Ready(Ok(()))
+    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        let mut this = self.project();
+        // `finished` persists across polls, so a repeated `poll_shutdown` (after the
+        // writer returned `Pending`) does not call `finish` on the encoder again.
+        while !*this.finished {
+            let mut space =
+                PartialBuffer::new(ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?);
+            *this.finished = this.encoder.finish(&mut space)?;
+            let bytes_encoded = space.written().len();
+            this.writer.as_mut().produce(bytes_encoded);
+        }
+        this.writer.poll_shutdown(cx)
     }
 }
 
diff --git a/tests/issues.rs b/tests/issues.rs
new file mode 100644
index 0000000..3d52cd4
--- /dev/null
+++ b/tests/issues.rs
@@ -0,0 +1,153 @@
+#![cfg(all(feature = "tokio", feature = "zstd"))]
+
+use std::{
+    io,
+    pin::Pin,
+    task::{ready, Context, Poll},
+};
+
+use async_compression::tokio::write::ZstdEncoder;
+use tokio::io::{AsyncWrite, AsyncWriteExt as _};
+use tracing_subscriber::fmt::format::FmtSpan;
+
+/// This issue covers our state machine being invalid when using adapters
+/// like [`tokio_util::codec`].
+///
+/// After the first [`poll_shutdown`] call,
+/// we must expect any number of [`poll_flush`] and [`poll_shutdown`] calls,
+/// until [`poll_shutdown`] returns [`Poll::Ready`],
+/// according to the documentation on [`AsyncWrite`].
+///
+///
+///
+/// [`tokio_util::codec`]: https://docs.rs/tokio-util/latest/tokio_util/codec
+/// [`poll_shutdown`]: AsyncWrite::poll_shutdown
+/// [`poll_flush`]: AsyncWrite::poll_flush
+#[test]
+fn issue_246() {
+    tracing_subscriber::fmt()
+        .without_time()
+        .with_ansi(false)
+        .with_level(false)
+        .with_test_writer()
+        .with_target(false)
+        .with_span_events(FmtSpan::NEW)
+        .init();
+    let mut zstd_encoder = Wrapper::new(Trace::new(ZstdEncoder::new(DelayedShutdown::default())));
+    futures::executor::block_on(zstd_encoder.shutdown()).unwrap();
+}
+
+pin_project_lite::pin_project! {
+    /// A simple wrapper struct that follows the [`AsyncWrite`] protocol.
+    /// This is a stand-in for combinators like `tokio_util::codec`s.
+    struct Wrapper<T> {
+        #[pin] inner: T
+    }
+}
+
+impl<T> Wrapper<T> {
+    fn new(inner: T) -> Self {
+        Self { inner }
+    }
+}
+
+impl<T: AsyncWrite> AsyncWrite for Wrapper<T> {
+    #[tracing::instrument(name = "Wrapper::poll_write", skip_all, ret)]
+    fn poll_write(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<io::Result<usize>> {
+        self.project().inner.poll_write(cx, buf)
+    }
+
+    #[tracing::instrument(name = "Wrapper::poll_flush", skip_all, ret)]
+    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        self.project().inner.poll_flush(cx)
+    }
+
+    /// To quote the [`AsyncWrite`] docs:
+    /// > Invocation of a shutdown implies an invocation of flush.
+    /// > Once this method returns Ready it implies that a flush successfully happened before the shutdown happened.
+    /// > That is, callers don't need to call flush before calling shutdown.
+    /// > They can rely that by calling shutdown any pending buffered data will be written out.
+    #[tracing::instrument(name = "Wrapper::poll_shutdown", skip_all, ret)]
+    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        let mut this = self.project();
+        ready!(this.inner.as_mut().poll_flush(cx))?;
+        this.inner.poll_shutdown(cx)
+    }
+}
+
+pin_project_lite::pin_project! {
+    /// Yields [`Poll::Pending`] the first time [`AsyncWrite::poll_shutdown`] is called.
+    #[derive(Default)]
+    struct DelayedShutdown {
+        contents: Vec<u8>,
+        num_times_shutdown_called: u8,
+    }
+}
+
+impl AsyncWrite for DelayedShutdown {
+    #[tracing::instrument(name = "DelayedShutdown::poll_write", skip_all, ret)]
+    fn poll_write(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<io::Result<usize>> {
+        let _ = cx;
+        self.project().contents.extend_from_slice(buf);
+        Poll::Ready(Ok(buf.len()))
+    }
+
+    #[tracing::instrument(name = "DelayedShutdown::poll_flush", skip_all, ret)]
+    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        let _ = cx;
+        Poll::Ready(Ok(()))
+    }
+
+    #[tracing::instrument(name = "DelayedShutdown::poll_shutdown", skip_all, ret)]
+    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        match self.project().num_times_shutdown_called {
+            it @ 0 => {
+                *it += 1;
+                cx.waker().wake_by_ref();
+                Poll::Pending
+            }
+            _ => Poll::Ready(Ok(())),
+        }
+    }
+}
+
+pin_project_lite::pin_project! {
+    /// A wrapper which traces all calls
+    struct Trace<T> {
+        #[pin] inner: T
+    }
+}
+
+impl<T> Trace<T> {
+    fn new(inner: T) -> Self {
+        Self { inner }
+    }
+}
+
+impl<T: AsyncWrite> AsyncWrite for Trace<T> {
+    #[tracing::instrument(name = "Trace::poll_write", skip_all, ret)]
+    fn poll_write(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<io::Result<usize>> {
+        self.project().inner.poll_write(cx, buf)
+    }
+    #[tracing::instrument(name = "Trace::poll_flush", skip_all, ret)]
+    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        self.project().inner.poll_flush(cx)
+    }
+
+    #[tracing::instrument(name = "Trace::poll_shutdown", skip_all, ret)]
+    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        self.project().inner.poll_shutdown(cx)
+    }
+}