Skip to content

Commit b463817

Browse files
authored
Merge pull request #240 from Berrysoft/dev/async-stream
feat(io): compat async stream for futures
2 parents cb93865 + 273dd44 commit b463817

File tree

4 files changed

+270
-2
lines changed

4 files changed

+270
-2
lines changed

compio-io/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ repository = { workspace = true }
1414
compio-buf = { workspace = true, features = ["arrayvec"] }
1515
futures-util = { workspace = true }
1616
paste = { workspace = true }
17+
pin-project-lite = { version = "0.2.14", optional = true }
1718

1819
[dev-dependencies]
1920
compio-runtime = { workspace = true }
@@ -22,7 +23,7 @@ tokio = { workspace = true, features = ["macros", "rt"] }
2223

2324
[features]
2425
default = []
25-
compat = []
26+
compat = ["futures-util/io", "dep:pin-project-lite"]
2627

2728
# Nightly features
2829
allocator_api = ["compio-buf/allocator_api"]

compio-io/src/buffer.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,11 +162,20 @@ impl Buffer {
162162
///
163163
/// https://github.com/compio-rs/compio/issues/209
164164
pub async fn flush_to(&mut self, writer: &mut impl AsyncWrite) -> IoResult<usize> {
165+
if self.slice().is_empty() {
166+
return Ok(0);
167+
}
165168
let mut total = 0;
166169
loop {
167170
let written = self
168171
.with(|inner| async { writer.write(inner.into_slice()).await.into_inner() })
169172
.await?;
173+
if written == 0 {
174+
return Err(std::io::Error::new(
175+
std::io::ErrorKind::UnexpectedEof,
176+
"cannot flush all buffer data",
177+
));
178+
}
170179
total += written;
171180
if self.advance(written) {
172181
break;

compio-io/src/compat.rs

Lines changed: 185 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
//! Compat wrappers for interop with other crates.
22
3-
use std::io::{self, BufRead, Read, Write};
3+
use std::{
4+
future::Future,
5+
io::{self, BufRead, Read, Write},
6+
pin::Pin,
7+
task::{Context, Poll},
8+
};
49

510
use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, SetBufInit};
11+
use pin_project_lite::pin_project;
612

713
use crate::{buffer::Buffer, util::DEFAULT_BUF_SIZE};
814

@@ -158,3 +164,181 @@ impl<S: crate::AsyncWrite> SyncStream<S> {
158164
Ok(len)
159165
}
160166
}
167+
168+
type PinBoxFuture<T> = Pin<Box<dyn Future<Output = T>>>;
169+
170+
pin_project! {
171+
/// A stream wrapper for [`futures_util::io`] traits.
172+
pub struct AsyncStream<S> {
173+
#[pin]
174+
inner: SyncStream<S>,
175+
read_future: Option<PinBoxFuture<io::Result<usize>>>,
176+
write_future: Option<PinBoxFuture<io::Result<usize>>>,
177+
shutdown_future: Option<PinBoxFuture<io::Result<()>>>,
178+
}
179+
}
180+
181+
impl<S> AsyncStream<S> {
182+
/// Create [`AsyncStream`] with the stream and default buffer size.
183+
pub fn new(stream: S) -> Self {
184+
Self::new_impl(SyncStream::new(stream))
185+
}
186+
187+
/// Create [`AsyncStream`] with the stream and buffer size.
188+
pub fn with_capacity(cap: usize, stream: S) -> Self {
189+
Self::new_impl(SyncStream::with_capacity(cap, stream))
190+
}
191+
192+
fn new_impl(inner: SyncStream<S>) -> Self {
193+
Self {
194+
inner,
195+
read_future: None,
196+
write_future: None,
197+
shutdown_future: None,
198+
}
199+
}
200+
201+
/// Get the reference of the inner stream.
202+
pub fn get_ref(&self) -> &S {
203+
self.inner.get_ref()
204+
}
205+
}
206+
207+
macro_rules! poll_future {
208+
($f:expr, $cx:expr, $e:expr) => {{
209+
let mut future = match $f.take() {
210+
Some(f) => f,
211+
None => Box::pin($e),
212+
};
213+
let f = future.as_mut();
214+
match f.poll($cx) {
215+
Poll::Pending => {
216+
$f.replace(future);
217+
return Poll::Pending;
218+
}
219+
Poll::Ready(res) => res,
220+
}
221+
}};
222+
}
223+
224+
macro_rules! poll_future_would_block {
225+
($f:expr, $cx:expr, $e:expr, $io:expr) => {{
226+
if let Some(mut f) = $f.take() {
227+
if f.as_mut().poll($cx).is_pending() {
228+
$f.replace(f);
229+
return Poll::Pending;
230+
}
231+
}
232+
233+
match $io {
234+
Ok(len) => Poll::Ready(Ok(len)),
235+
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
236+
$f.replace(Box::pin($e));
237+
$cx.waker().wake_by_ref();
238+
Poll::Pending
239+
}
240+
Err(e) => Poll::Ready(Err(e)),
241+
}
242+
}};
243+
}
244+
245+
impl<S: crate::AsyncRead + 'static> futures_util::AsyncRead for AsyncStream<S> {
246+
fn poll_read(
247+
self: Pin<&mut Self>,
248+
cx: &mut Context<'_>,
249+
buf: &mut [u8],
250+
) -> Poll<io::Result<usize>> {
251+
let this = self.project();
252+
// Safety:
253+
// - The futures won't live longer than the stream.
254+
// - `self` is pinned.
255+
// - The inner stream won't be moved.
256+
let inner: &'static mut SyncStream<S> =
257+
unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) };
258+
259+
poll_future_would_block!(
260+
this.read_future,
261+
cx,
262+
inner.fill_read_buf(),
263+
io::Read::read(inner, buf)
264+
)
265+
}
266+
}
267+
268+
impl<S: crate::AsyncRead + 'static> futures_util::AsyncBufRead for AsyncStream<S> {
269+
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
270+
let this = self.project();
271+
272+
let inner: &'static mut SyncStream<S> =
273+
unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) };
274+
poll_future_would_block!(
275+
this.read_future,
276+
cx,
277+
inner.fill_read_buf(),
278+
// Safety: anyway the slice won't be used after free.
279+
io::BufRead::fill_buf(inner).map(|slice| unsafe { &*(slice as *const _) })
280+
)
281+
}
282+
283+
fn consume(self: Pin<&mut Self>, amt: usize) {
284+
let this = self.project();
285+
286+
let inner: &'static mut SyncStream<S> =
287+
unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) };
288+
inner.consume(amt)
289+
}
290+
}
291+
292+
impl<S: crate::AsyncWrite + 'static> futures_util::AsyncWrite for AsyncStream<S> {
293+
fn poll_write(
294+
self: Pin<&mut Self>,
295+
cx: &mut Context<'_>,
296+
buf: &[u8],
297+
) -> Poll<io::Result<usize>> {
298+
let this = self.project();
299+
300+
if this.shutdown_future.is_some() {
301+
debug_assert!(this.write_future.is_none());
302+
return Poll::Pending;
303+
}
304+
305+
let inner: &'static mut SyncStream<S> =
306+
unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) };
307+
poll_future_would_block!(
308+
this.write_future,
309+
cx,
310+
inner.flush_write_buf(),
311+
io::Write::write(inner, buf)
312+
)
313+
}
314+
315+
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
316+
let this = self.project();
317+
318+
if this.shutdown_future.is_some() {
319+
debug_assert!(this.write_future.is_none());
320+
return Poll::Pending;
321+
}
322+
323+
let inner: &'static mut SyncStream<S> =
324+
unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) };
325+
let res = poll_future!(this.write_future, cx, inner.flush_write_buf());
326+
Poll::Ready(res.map(|_| ()))
327+
}
328+
329+
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
330+
let this = self.project();
331+
332+
// Avoid shutdown on flush because the inner buffer might be passed to the
333+
// driver.
334+
if this.write_future.is_some() {
335+
debug_assert!(this.shutdown_future.is_none());
336+
return Poll::Pending;
337+
}
338+
339+
let inner: &'static mut SyncStream<S> =
340+
unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) };
341+
let res = poll_future!(this.shutdown_future, cx, inner.get_mut().shutdown());
342+
Poll::Ready(res)
343+
}
344+
}

compio-io/tests/compat.rs

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
use std::io::Cursor;
2+
3+
use compio_io::compat::AsyncStream;
4+
use futures_util::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt};
5+
6+
#[tokio::test]
7+
async fn async_compat_read() {
8+
let src = &[1u8, 1, 4, 5, 1, 4, 1, 9, 1, 9, 8, 1, 0][..];
9+
let mut stream = AsyncStream::new(src);
10+
11+
let mut buf = [0; 6];
12+
let len = stream.read(&mut buf).await.unwrap();
13+
14+
assert_eq!(len, 6);
15+
assert_eq!(buf, [1, 1, 4, 5, 1, 4]);
16+
17+
let mut buf = [0; 20];
18+
let len = stream.read(&mut buf).await.unwrap();
19+
assert_eq!(len, 7);
20+
assert_eq!(&buf[..7], [1, 9, 1, 9, 8, 1, 0]);
21+
}
22+
23+
#[tokio::test]
24+
async fn async_compat_bufread() {
25+
let src = &[1u8, 1, 4, 5, 1, 4, 1, 9, 1, 9, 8, 1, 0][..];
26+
let mut stream = AsyncStream::new(src);
27+
28+
let slice = stream.fill_buf().await.unwrap();
29+
assert_eq!(slice, [1, 1, 4, 5, 1, 4, 1, 9, 1, 9, 8, 1, 0]);
30+
stream.consume_unpin(6);
31+
32+
let mut buf = [0; 7];
33+
let len = stream.read(&mut buf).await.unwrap();
34+
35+
assert_eq!(len, 7);
36+
assert_eq!(buf, [1, 9, 1, 9, 8, 1, 0]);
37+
}
38+
39+
#[tokio::test]
40+
async fn async_compat_write() {
41+
let dst = Cursor::new([0u8; 10]);
42+
let mut stream = AsyncStream::new(dst);
43+
44+
let len = stream.write(&[1, 1, 4, 5, 1, 4]).await.unwrap();
45+
stream.flush().await.unwrap();
46+
47+
assert_eq!(len, 6);
48+
assert_eq!(stream.get_ref().position(), 6);
49+
assert_eq!(stream.get_ref().get_ref(), &[1, 1, 4, 5, 1, 4, 0, 0, 0, 0]);
50+
51+
let dst = Cursor::new([0u8; 10]);
52+
let mut stream = AsyncStream::with_capacity(10, dst);
53+
let len = stream
54+
.write(&[1, 1, 4, 5, 1, 4, 1, 9, 1, 9, 8, 1, 0])
55+
.await
56+
.unwrap();
57+
assert_eq!(len, 10);
58+
59+
stream.flush().await.unwrap();
60+
assert_eq!(stream.get_ref().get_ref(), &[1, 1, 4, 5, 1, 4, 1, 9, 1, 9]);
61+
}
62+
63+
#[tokio::test]
64+
async fn async_compat_flush_fail() {
65+
let dst = Cursor::new([0u8; 10]);
66+
let mut stream = AsyncStream::new(dst);
67+
let len = stream
68+
.write(&[1, 1, 4, 5, 1, 4, 1, 9, 1, 9, 8, 1, 0])
69+
.await
70+
.unwrap();
71+
assert_eq!(len, 13);
72+
let err = stream.flush().await.unwrap_err();
73+
assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
74+
}

0 commit comments

Comments
 (0)