Skip to content

Commit 5eec94c

Browse files
committed
feat(io): compat async stream
1 parent 84ed6e6 commit 5eec94c

File tree

4 files changed

+253
-2
lines changed

4 files changed

+253
-2
lines changed

compio-io/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ repository = { workspace = true }
1414
compio-buf = { workspace = true, features = ["arrayvec"] }
1515
futures-util = { workspace = true }
1616
paste = { workspace = true }
17+
pin-project = { version = "1.1.3", optional = true }
1718

1819
[dev-dependencies]
1920
compio-runtime = { workspace = true }
@@ -22,7 +23,7 @@ tokio = { workspace = true, features = ["macros", "rt"] }
2223

2324
[features]
2425
default = []
25-
compat = []
26+
compat = ["futures-util/io", "dep:pin-project"]
2627

2728
# Nightly features
2829
allocator_api = ["compio-buf/allocator_api"]

compio-io/src/buffer.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,11 +162,20 @@ impl Buffer {
162162
///
163163
/// https://github.com/compio-rs/compio/issues/209
164164
pub async fn flush_to(&mut self, writer: &mut impl AsyncWrite) -> IoResult<usize> {
165+
if self.slice().is_empty() {
166+
return Ok(0);
167+
}
165168
let mut total = 0;
166169
loop {
167170
let written = self
168171
.with(|inner| async { writer.write(inner.into_slice()).await.into_inner() })
169172
.await?;
173+
if written == 0 {
174+
return Err(std::io::Error::new(
175+
std::io::ErrorKind::UnexpectedEof,
176+
"cannot flush all buffer data",
177+
));
178+
}
170179
total += written;
171180
if self.advance(written) {
172181
break;

compio-io/src/compat.rs

Lines changed: 184 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
//! Compat wrappers for interop with other crates.
22
3-
use std::io::{self, BufRead, Read, Write};
3+
use std::{
4+
future::Future,
5+
io::{self, BufRead, Read, Write},
6+
pin::Pin,
7+
task::{Context, Poll},
8+
};
49

510
use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, SetBufInit};
11+
use pin_project::pin_project;
612

713
use crate::{buffer::Buffer, util::DEFAULT_BUF_SIZE};
814

@@ -158,3 +164,180 @@ impl<S: crate::AsyncWrite> SyncStream<S> {
158164
Ok(len)
159165
}
160166
}
167+
168+
type PinBoxFuture<T> = Pin<Box<dyn Future<Output = T>>>;
169+
170+
/// A stream wrapper for [`futures_util::io`] traits.
171+
#[pin_project]
172+
pub struct AsyncStream<S> {
173+
#[pin]
174+
inner: SyncStream<S>,
175+
read_future: Option<PinBoxFuture<io::Result<usize>>>,
176+
write_future: Option<PinBoxFuture<io::Result<usize>>>,
177+
shutdown_future: Option<PinBoxFuture<io::Result<()>>>,
178+
}
179+
180+
impl<S> AsyncStream<S> {
181+
/// Create [`AsyncStream`] with the stream and default buffer size.
182+
pub fn new(stream: S) -> Self {
183+
Self::new_impl(SyncStream::new(stream))
184+
}
185+
186+
/// Create [`AsyncStream`] with the stream and buffer size.
187+
pub fn with_capacity(cap: usize, stream: S) -> Self {
188+
Self::new_impl(SyncStream::with_capacity(cap, stream))
189+
}
190+
191+
fn new_impl(inner: SyncStream<S>) -> Self {
192+
Self {
193+
inner,
194+
read_future: None,
195+
write_future: None,
196+
shutdown_future: None,
197+
}
198+
}
199+
200+
/// Get the reference of the inner stream.
201+
pub fn get_ref(&self) -> &S {
202+
self.inner.get_ref()
203+
}
204+
}
205+
206+
macro_rules! poll_future {
207+
($f:expr, $cx:expr, $e:expr) => {{
208+
let mut future = match $f.take() {
209+
Some(f) => f,
210+
None => Box::pin($e),
211+
};
212+
let f = future.as_mut();
213+
match f.poll($cx) {
214+
Poll::Pending => {
215+
$f.replace(future);
216+
return Poll::Pending;
217+
}
218+
Poll::Ready(res) => res,
219+
}
220+
}};
221+
}
222+
223+
macro_rules! poll_future_would_block {
224+
($f:expr, $cx:expr, $e:expr, $io:expr) => {{
225+
if let Some(mut f) = $f.take() {
226+
if f.as_mut().poll($cx).is_pending() {
227+
$f.replace(f);
228+
return Poll::Pending;
229+
}
230+
}
231+
232+
match $io {
233+
Ok(len) => Poll::Ready(Ok(len)),
234+
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
235+
$f.replace(Box::pin($e));
236+
$cx.waker().wake_by_ref();
237+
Poll::Pending
238+
}
239+
Err(e) => Poll::Ready(Err(e)),
240+
}
241+
}};
242+
}
243+
244+
impl<S: crate::AsyncRead + 'static> futures_util::AsyncRead for AsyncStream<S> {
245+
fn poll_read(
246+
self: Pin<&mut Self>,
247+
cx: &mut Context<'_>,
248+
buf: &mut [u8],
249+
) -> Poll<io::Result<usize>> {
250+
let this = self.project();
251+
// Safety:
252+
// - The futures won't live longer than the stream.
253+
// - `self` is pinned.
254+
// - The inner stream won't be moved.
255+
let inner: &'static mut SyncStream<S> =
256+
unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) };
257+
258+
poll_future_would_block!(
259+
this.read_future,
260+
cx,
261+
inner.fill_read_buf(),
262+
io::Read::read(inner, buf)
263+
)
264+
}
265+
}
266+
267+
impl<S: crate::AsyncRead + 'static> futures_util::AsyncBufRead for AsyncStream<S> {
268+
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
269+
let this = self.project();
270+
271+
let inner: &'static mut SyncStream<S> =
272+
unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) };
273+
poll_future_would_block!(
274+
this.read_future,
275+
cx,
276+
inner.fill_read_buf(),
277+
// Safety: anyway the slice won't be used after free.
278+
io::BufRead::fill_buf(inner).map(|slice| unsafe { &*(slice as *const _) })
279+
)
280+
}
281+
282+
fn consume(self: Pin<&mut Self>, amt: usize) {
283+
let this = self.project();
284+
285+
let inner: &'static mut SyncStream<S> =
286+
unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) };
287+
inner.consume(amt)
288+
}
289+
}
290+
291+
impl<S: crate::AsyncWrite + 'static> futures_util::AsyncWrite for AsyncStream<S> {
292+
fn poll_write(
293+
self: Pin<&mut Self>,
294+
cx: &mut Context<'_>,
295+
buf: &[u8],
296+
) -> Poll<io::Result<usize>> {
297+
let this = self.project();
298+
299+
if this.shutdown_future.is_some() {
300+
debug_assert!(this.write_future.is_none());
301+
return Poll::Pending;
302+
}
303+
304+
let inner: &'static mut SyncStream<S> =
305+
unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) };
306+
poll_future_would_block!(
307+
this.write_future,
308+
cx,
309+
inner.flush_write_buf(),
310+
io::Write::write(inner, buf)
311+
)
312+
}
313+
314+
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
315+
let this = self.project();
316+
317+
if this.shutdown_future.is_some() {
318+
debug_assert!(this.write_future.is_none());
319+
return Poll::Pending;
320+
}
321+
322+
let inner: &'static mut SyncStream<S> =
323+
unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) };
324+
let res = poll_future!(this.write_future, cx, inner.flush_write_buf());
325+
Poll::Ready(res.map(|_| ()))
326+
}
327+
328+
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
329+
let this = self.project();
330+
331+
// Avoid shutdown on flush because the inner buffer might be passed to the
332+
// driver.
333+
if this.write_future.is_some() {
334+
debug_assert!(this.shutdown_future.is_none());
335+
return Poll::Pending;
336+
}
337+
338+
let inner: &'static mut SyncStream<S> =
339+
unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) };
340+
let res = poll_future!(this.shutdown_future, cx, inner.get_mut().shutdown());
341+
Poll::Ready(res)
342+
}
343+
}

compio-io/tests/compat.rs

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
use std::io::Cursor;
2+
3+
use compio_io::compat::AsyncStream;
4+
use futures_util::{AsyncReadExt, AsyncWriteExt};
5+
6+
#[tokio::test]
7+
async fn async_compat_read() {
8+
let src = &[1u8, 1, 4, 5, 1, 4, 1, 9, 1, 9, 8, 1, 0][..];
9+
let mut stream = AsyncStream::new(src);
10+
11+
let mut buf = [0; 6];
12+
let len = stream.read(&mut buf).await.unwrap();
13+
14+
assert_eq!(len, 6);
15+
assert_eq!(buf, [1, 1, 4, 5, 1, 4]);
16+
17+
let mut buf = [0; 20];
18+
let len = stream.read(&mut buf).await.unwrap();
19+
assert_eq!(len, 7);
20+
assert_eq!(&buf[..7], [1, 9, 1, 9, 8, 1, 0]);
21+
}
22+
23+
#[tokio::test]
24+
async fn async_compat_write() {
25+
let dst = Cursor::new([0u8; 10]);
26+
let mut stream = AsyncStream::new(dst);
27+
28+
let len = stream.write(&[1, 1, 4, 5, 1, 4]).await.unwrap();
29+
stream.flush().await.unwrap();
30+
31+
assert_eq!(len, 6);
32+
assert_eq!(stream.get_ref().position(), 6);
33+
assert_eq!(stream.get_ref().get_ref(), &[1, 1, 4, 5, 1, 4, 0, 0, 0, 0]);
34+
35+
let dst = Cursor::new([0u8; 10]);
36+
let mut stream = AsyncStream::with_capacity(10, dst);
37+
let len = stream
38+
.write(&[1, 1, 4, 5, 1, 4, 1, 9, 1, 9, 8, 1, 0])
39+
.await
40+
.unwrap();
41+
assert_eq!(len, 10);
42+
43+
stream.flush().await.unwrap();
44+
assert_eq!(stream.get_ref().get_ref(), &[1, 1, 4, 5, 1, 4, 1, 9, 1, 9]);
45+
}
46+
47+
#[tokio::test]
48+
async fn async_compat_flush_fail() {
49+
let dst = Cursor::new([0u8; 10]);
50+
let mut stream = AsyncStream::new(dst);
51+
let len = stream
52+
.write(&[1, 1, 4, 5, 1, 4, 1, 9, 1, 9, 8, 1, 0])
53+
.await
54+
.unwrap();
55+
assert_eq!(len, 13);
56+
let err = stream.flush().await.unwrap_err();
57+
assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
58+
}

0 commit comments

Comments
 (0)