Skip to content

Commit 4cabfec

Browse files
authored
Merge pull request #206 from Berrysoft/dev/split
feat(io,net): add split
2 parents a26dbe5 + 502685f commit 4cabfec

File tree

9 files changed

+372
-3
lines changed

9 files changed

+372
-3
lines changed

compio-io/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ repository = { workspace = true }
1212

1313
[dependencies]
1414
compio-buf = { workspace = true, features = ["arrayvec"] }
15+
futures-util = { workspace = true }
1516
paste = { workspace = true }
1617

1718
[dev-dependencies]

compio-io/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,13 @@ mod buffer;
107107
#[cfg(feature = "compat")]
108108
pub mod compat;
109109
mod read;
110+
mod split;
110111
pub mod util;
111112
mod write;
112113

113114
pub(crate) type IoResult<T> = std::io::Result<T>;
114115

115116
pub use read::*;
117+
pub use split::*;
116118
pub use util::{copy, null, repeat};
117119
pub use write::*;

compio-io/src/split.rs

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
use std::sync::Arc;
2+
3+
use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
4+
use futures_util::lock::Mutex;
5+
6+
use crate::{AsyncRead, AsyncWrite, IoResult};
7+
8+
/// Splits a single value implementing `AsyncRead + AsyncWrite` into separate
9+
/// [`AsyncRead`] and [`AsyncWrite`] handles.
10+
pub fn split<T: AsyncRead + AsyncWrite>(stream: T) -> (ReadHalf<T>, WriteHalf<T>) {
11+
let stream = Arc::new(Mutex::new(stream));
12+
(ReadHalf(stream.clone()), WriteHalf(stream))
13+
}
14+
15+
/// The readable half of a value returned from [`split`].
16+
#[derive(Debug)]
17+
pub struct ReadHalf<T>(Arc<Mutex<T>>);
18+
19+
impl<T: Unpin> ReadHalf<T> {
20+
/// Reunites with a previously split [`WriteHalf`].
21+
///
22+
/// # Panics
23+
///
24+
/// If this [`ReadHalf`] and the given [`WriteHalf`] do not originate from
25+
/// the same [`split`] operation this method will panic.
26+
/// This can be checked ahead of time by comparing the stored pointer
27+
/// of the two halves.
28+
#[track_caller]
29+
pub fn unsplit(self, w: WriteHalf<T>) -> T {
30+
if Arc::ptr_eq(&self.0, &w.0) {
31+
drop(w);
32+
let inner = Arc::try_unwrap(self.0).expect("`Arc::try_unwrap` failed");
33+
inner.into_inner()
34+
} else {
35+
#[cold]
36+
fn panic_unrelated() -> ! {
37+
panic!("Unrelated `WriteHalf` passed to `ReadHalf::unsplit`.")
38+
}
39+
40+
panic_unrelated()
41+
}
42+
}
43+
}
44+
45+
impl<T: AsyncRead> AsyncRead for ReadHalf<T> {
46+
async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
47+
self.0.lock().await.read(buf).await
48+
}
49+
50+
async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
51+
self.0.lock().await.read_vectored(buf).await
52+
}
53+
}
54+
55+
/// The writable half of a value returned from [`split`].
56+
#[derive(Debug)]
57+
pub struct WriteHalf<T>(Arc<Mutex<T>>);
58+
59+
impl<T: AsyncWrite> AsyncWrite for WriteHalf<T> {
60+
async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
61+
self.0.lock().await.write(buf).await
62+
}
63+
64+
async fn write_vectored<B: IoVectoredBuf>(&mut self, buf: B) -> BufResult<usize, B> {
65+
self.0.lock().await.write_vectored(buf).await
66+
}
67+
68+
async fn flush(&mut self) -> IoResult<()> {
69+
self.0.lock().await.flush().await
70+
}
71+
72+
async fn shutdown(&mut self) -> IoResult<()> {
73+
self.0.lock().await.shutdown().await
74+
}
75+
}

compio-io/tests/io.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::io::Cursor;
22

33
use compio_buf::{arrayvec::ArrayVec, BufResult, IoBuf, IoBufMut};
44
use compio_io::{
5-
AsyncRead, AsyncReadAt, AsyncReadAtExt, AsyncReadExt, AsyncWrite, AsyncWriteAt,
5+
split, AsyncRead, AsyncReadAt, AsyncReadAtExt, AsyncReadExt, AsyncWrite, AsyncWriteAt,
66
AsyncWriteAtExt, AsyncWriteExt,
77
};
88

@@ -355,3 +355,19 @@ async fn read_to_end_at() {
355355
assert_eq!(len, 4);
356356
assert_eq!(buf, [4, 5, 1, 4]);
357357
}
358+
359+
#[tokio::test]
360+
async fn split_unsplit() {
361+
let src = Cursor::new([1, 1, 4, 5, 1, 4]);
362+
let (mut read, mut write) = split(src);
363+
364+
let (len, buf) = read.read([0, 0, 0]).await.unwrap();
365+
assert_eq!(len, 3);
366+
assert_eq!(buf, [1, 1, 4]);
367+
368+
let (len, _) = write.write([2, 2, 2]).await.unwrap();
369+
assert_eq!(len, 3);
370+
371+
let src = read.unsplit(write);
372+
assert_eq!(src.into_inner(), [1, 1, 4, 2, 2, 2]);
373+
}

compio-net/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@
77

88
mod resolve;
99
mod socket;
10+
pub(crate) mod split;
1011
mod tcp;
1112
mod udp;
1213
mod unix;
1314

1415
pub use resolve::ToSocketAddrsAsync;
1516
pub(crate) use resolve::{each_addr, first_addr_buf};
1617
pub(crate) use socket::*;
18+
pub use split::*;
1719
pub use tcp::*;
1820
pub use udp::*;
1921
pub use unix::*;

compio-net/src/split.rs

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
use std::{error::Error, fmt, io, ops::Deref, sync::Arc};
2+
3+
use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
4+
use compio_io::{AsyncRead, AsyncWrite};
5+
6+
pub(crate) fn split<T>(stream: &T) -> (ReadHalf<T>, WriteHalf<T>)
7+
where
8+
for<'a> &'a T: AsyncRead + AsyncWrite,
9+
{
10+
(ReadHalf(stream), WriteHalf(stream))
11+
}
12+
13+
/// Borrowed read half.
14+
#[derive(Debug)]
15+
pub struct ReadHalf<'a, T>(&'a T);
16+
17+
impl<T> AsyncRead for ReadHalf<'_, T>
18+
where
19+
for<'a> &'a T: AsyncRead,
20+
{
21+
async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
22+
self.0.read(buf).await
23+
}
24+
25+
async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
26+
self.0.read_vectored(buf).await
27+
}
28+
}
29+
30+
/// Borrowed write half.
31+
#[derive(Debug)]
32+
pub struct WriteHalf<'a, T>(&'a T);
33+
34+
impl<T> AsyncWrite for WriteHalf<'_, T>
35+
where
36+
for<'a> &'a T: AsyncWrite,
37+
{
38+
async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
39+
self.0.write(buf).await
40+
}
41+
42+
async fn write_vectored<B: IoVectoredBuf>(&mut self, buf: B) -> BufResult<usize, B> {
43+
self.0.write_vectored(buf).await
44+
}
45+
46+
async fn flush(&mut self) -> io::Result<()> {
47+
self.0.flush().await
48+
}
49+
50+
async fn shutdown(&mut self) -> io::Result<()> {
51+
self.0.shutdown().await
52+
}
53+
}
54+
55+
pub(crate) fn into_split<T>(stream: T) -> (OwnedReadHalf<T>, OwnedWriteHalf<T>)
56+
where
57+
for<'a> &'a T: AsyncRead + AsyncWrite,
58+
{
59+
let stream = Arc::new(stream);
60+
(OwnedReadHalf(stream.clone()), OwnedWriteHalf(stream))
61+
}
62+
63+
/// Owned read half.
64+
#[derive(Debug)]
65+
pub struct OwnedReadHalf<T>(Arc<T>);
66+
67+
impl<T: Unpin> OwnedReadHalf<T> {
68+
/// Attempts to put the two halves of a `TcpStream` back together and
69+
/// recover the original socket. Succeeds only if the two halves
70+
/// originated from the same call to `into_split`.
71+
pub fn reunite(self, w: OwnedWriteHalf<T>) -> Result<T, ReuniteError<T>> {
72+
if Arc::ptr_eq(&self.0, &w.0) {
73+
drop(w);
74+
Ok(Arc::try_unwrap(self.0)
75+
.ok()
76+
.expect("`Arc::try_unwrap` failed"))
77+
} else {
78+
Err(ReuniteError(self, w))
79+
}
80+
}
81+
}
82+
83+
impl<T> AsyncRead for OwnedReadHalf<T>
84+
where
85+
for<'a> &'a T: AsyncRead,
86+
{
87+
async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
88+
self.0.deref().read(buf).await
89+
}
90+
91+
async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
92+
self.0.deref().read_vectored(buf).await
93+
}
94+
}
95+
96+
/// Owned write half.
97+
#[derive(Debug)]
98+
pub struct OwnedWriteHalf<T>(Arc<T>);
99+
100+
impl<T> AsyncWrite for OwnedWriteHalf<T>
101+
where
102+
for<'a> &'a T: AsyncWrite,
103+
{
104+
async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
105+
self.0.deref().write(buf).await
106+
}
107+
108+
async fn write_vectored<B: IoVectoredBuf>(&mut self, buf: B) -> BufResult<usize, B> {
109+
self.0.deref().write_vectored(buf).await
110+
}
111+
112+
async fn flush(&mut self) -> io::Result<()> {
113+
self.0.deref().flush().await
114+
}
115+
116+
async fn shutdown(&mut self) -> io::Result<()> {
117+
self.0.deref().shutdown().await
118+
}
119+
}
120+
121+
/// Error indicating that two halves were not from the same socket, and thus
122+
/// could not be reunited.
123+
#[derive(Debug)]
124+
pub struct ReuniteError<T>(pub OwnedReadHalf<T>, pub OwnedWriteHalf<T>);
125+
126+
impl<T> fmt::Display for ReuniteError<T> {
127+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128+
write!(
129+
f,
130+
"tried to reunite halves that are not from the same socket"
131+
)
132+
}
133+
}
134+
135+
impl<T: fmt::Debug> Error for ReuniteError<T> {}

compio-net/src/tcp.rs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use compio_io::{AsyncRead, AsyncWrite};
55
use compio_runtime::{impl_attachable, impl_try_as_raw_fd};
66
use socket2::{Protocol, SockAddr, Type};
77

8-
use crate::{Socket, ToSocketAddrsAsync};
8+
use crate::{OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, ToSocketAddrsAsync, WriteHalf};
99

1010
/// A TCP socket server, listening for connections.
1111
///
@@ -203,6 +203,25 @@ impl TcpStream {
203203
.local_addr()
204204
.map(|addr| addr.as_socket().expect("should be SocketAddr"))
205205
}
206+
207+
/// Splits a [`TcpStream`] into a read half and a write half, which can be
208+
/// used to read and write the stream concurrently.
209+
///
210+
/// This method is more efficient than
211+
/// [`into_split`](TcpStream::into_split), but the halves cannot
212+
/// be moved into independently spawned tasks.
213+
pub fn split(&self) -> (ReadHalf<Self>, WriteHalf<Self>) {
214+
crate::split(self)
215+
}
216+
217+
/// Splits a [`TcpStream`] into a read half and a write half, which can be
218+
/// used to read and write the stream concurrently.
219+
///
220+
/// Unlike [`split`](TcpStream::split), the owned halves can be moved to
221+
/// separate tasks, however this comes at the cost of a heap allocation.
222+
pub fn into_split(self) -> (OwnedReadHalf<Self>, OwnedWriteHalf<Self>) {
223+
crate::into_split(self)
224+
}
206225
}
207226

208227
impl AsyncRead for TcpStream {

compio-net/src/unix.rs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use compio_io::{AsyncRead, AsyncWrite};
55
use compio_runtime::{impl_attachable, impl_try_as_raw_fd};
66
use socket2::{Domain, SockAddr, Type};
77

8-
use crate::Socket;
8+
use crate::{OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, WriteHalf};
99

1010
/// A Unix socket server, listening for connections.
1111
///
@@ -159,6 +159,25 @@ impl UnixStream {
159159
pub fn local_addr(&self) -> io::Result<SockAddr> {
160160
self.inner.local_addr()
161161
}
162+
163+
/// Splits a [`UnixStream`] into a read half and a write half, which can be
164+
/// used to read and write the stream concurrently.
165+
///
166+
/// This method is more efficient than
167+
/// [`into_split`](UnixStream::into_split), but the halves cannot
168+
/// be moved into independently spawned tasks.
169+
pub fn split(&self) -> (ReadHalf<Self>, WriteHalf<Self>) {
170+
crate::split(self)
171+
}
172+
173+
/// Splits a [`UnixStream`] into a read half and a write half, which can be
174+
/// used to read and write the stream concurrently.
175+
///
176+
/// Unlike [`split`](UnixStream::split), the owned halves can be moved to
177+
/// separate tasks, however this comes at the cost of a heap allocation.
178+
pub fn into_split(self) -> (OwnedReadHalf<Self>, OwnedWriteHalf<Self>) {
179+
crate::into_split(self)
180+
}
162181
}
163182

164183
impl AsyncRead for UnixStream {

0 commit comments

Comments
 (0)