Skip to content

Commit 502685f

Browse files
committed
feat(net): add reunite and tests
1 parent 24023fd commit 502685f

File tree

3 files changed

+135
-6
lines changed

3 files changed

+135
-6
lines changed

compio-io/src/split.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ pub fn split<T: AsyncRead + AsyncWrite>(stream: T) -> (ReadHalf<T>, WriteHalf<T>
1616
#[derive(Debug)]
1717
pub struct ReadHalf<T>(Arc<Mutex<T>>);
1818

19-
impl<T> ReadHalf<T> {
19+
impl<T: Unpin> ReadHalf<T> {
2020
/// Reunites with a previously split [`WriteHalf`].
2121
///
2222
/// # Panics
@@ -26,10 +26,7 @@ impl<T> ReadHalf<T> {
2626
/// This can be checked ahead of time by comparing the stored pointer
2727
/// of the two halves.
2828
#[track_caller]
29-
pub fn unsplit(self, w: WriteHalf<T>) -> T
30-
where
31-
T: Unpin,
32-
{
29+
pub fn unsplit(self, w: WriteHalf<T>) -> T {
3330
if Arc::ptr_eq(&self.0, &w.0) {
3431
drop(w);
3532
let inner = Arc::try_unwrap(self.0).expect("`Arc::try_unwrap` failed");

compio-net/src/split.rs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::{io, ops::Deref, sync::Arc};
1+
use std::{error::Error, fmt, io, ops::Deref, sync::Arc};
22

33
use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
44
use compio_io::{AsyncRead, AsyncWrite};
@@ -64,6 +64,22 @@ where
6464
#[derive(Debug)]
6565
pub struct OwnedReadHalf<T>(Arc<T>);
6666

67+
impl<T: Unpin> OwnedReadHalf<T> {
68+
/// Attempts to put the two halves of a `TcpStream` back together and
69+
/// recover the original socket. Succeeds only if the two halves
70+
/// originated from the same call to `into_split`.
71+
pub fn reunite(self, w: OwnedWriteHalf<T>) -> Result<T, ReuniteError<T>> {
72+
if Arc::ptr_eq(&self.0, &w.0) {
73+
drop(w);
74+
Ok(Arc::try_unwrap(self.0)
75+
.ok()
76+
.expect("`Arc::try_unwrap` failed"))
77+
} else {
78+
Err(ReuniteError(self, w))
79+
}
80+
}
81+
}
82+
6783
impl<T> AsyncRead for OwnedReadHalf<T>
6884
where
6985
for<'a> &'a T: AsyncRead,
@@ -101,3 +117,19 @@ where
101117
self.0.deref().shutdown().await
102118
}
103119
}
120+
121+
/// Error indicating that two halves were not from the same socket, and thus
122+
/// could not be reunited.
123+
#[derive(Debug)]
124+
pub struct ReuniteError<T>(pub OwnedReadHalf<T>, pub OwnedWriteHalf<T>);
125+
126+
impl<T> fmt::Display for ReuniteError<T> {
127+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128+
write!(
129+
f,
130+
"tried to reunite halves that are not from the same socket"
131+
)
132+
}
133+
}
134+
135+
impl<T: fmt::Debug> Error for ReuniteError<T> {}

compio-net/tests/split.rs

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
use std::io::{Read, Write};
2+
3+
use compio_buf::BufResult;
4+
use compio_io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
5+
use compio_net::{TcpStream, UnixListener, UnixStream};
6+
7+
#[compio_macros::test]
8+
async fn tcp_split() {
9+
const MSG: &[u8] = b"split";
10+
11+
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
12+
let addr = listener.local_addr().unwrap();
13+
14+
let handle = compio_runtime::spawn_blocking(move || {
15+
let (mut stream, _) = listener.accept().unwrap();
16+
stream.write_all(MSG).unwrap();
17+
18+
let mut read_buf = [0u8; 32];
19+
let read_len = stream.read(&mut read_buf).unwrap();
20+
assert_eq!(&read_buf[..read_len], MSG);
21+
});
22+
23+
let stream = TcpStream::connect(&addr).await.unwrap();
24+
let (mut read_half, mut write_half) = stream.into_split();
25+
26+
let read_buf = [0u8; 32];
27+
let (read_res, buf) = read_half.read(read_buf).await.unwrap();
28+
assert_eq!(read_res, MSG.len());
29+
assert_eq!(&buf[..MSG.len()], MSG);
30+
31+
write_half.write_all(MSG).await.unwrap();
32+
handle.await;
33+
}
34+
35+
#[compio_macros::test]
36+
async fn tcp_unsplit() {
37+
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
38+
let addr = listener.local_addr().unwrap();
39+
40+
let handle = compio_runtime::spawn_blocking(move || {
41+
drop(listener.accept().unwrap());
42+
drop(listener.accept().unwrap());
43+
});
44+
45+
let stream1 = TcpStream::connect(&addr).await.unwrap();
46+
let (read1, write1) = stream1.into_split();
47+
48+
let stream2 = TcpStream::connect(&addr).await.unwrap();
49+
let (_, write2) = stream2.into_split();
50+
51+
let read1 = match read1.reunite(write2) {
52+
Ok(_) => panic!("Reunite should not succeed"),
53+
Err(err) => err.0,
54+
};
55+
56+
read1.reunite(write1).expect("Reunite should succeed");
57+
58+
handle.await;
59+
}
60+
61+
#[compio_macros::test]
62+
async fn unix_split() {
63+
let dir = tempfile::Builder::new()
64+
.prefix("compio-uds-split-tests")
65+
.tempdir()
66+
.unwrap();
67+
let sock_path = dir.path().join("connect.sock");
68+
69+
let listener = UnixListener::bind(&sock_path).unwrap();
70+
71+
let client = UnixStream::connect(&sock_path).unwrap();
72+
let (server, _) = listener.accept().await.unwrap();
73+
74+
let (mut a_read, mut a_write) = server.into_split();
75+
let (mut b_read, mut b_write) = client.into_split();
76+
77+
let (a_response, b_response) = futures_util::future::try_join(
78+
send_recv_all(&mut a_read, &mut a_write, b"A"),
79+
send_recv_all(&mut b_read, &mut b_write, b"B"),
80+
)
81+
.await
82+
.unwrap();
83+
84+
assert_eq!(a_response, b"B");
85+
assert_eq!(b_response, b"A");
86+
}
87+
88+
async fn send_recv_all<R: AsyncRead, W: AsyncWrite>(
89+
read: &mut R,
90+
write: &mut W,
91+
input: &'static [u8],
92+
) -> std::io::Result<Vec<u8>> {
93+
write.write_all(input).await.0?;
94+
write.shutdown().await?;
95+
96+
let output = Vec::with_capacity(2);
97+
let BufResult(res, buf) = read.read_exact(output).await;
98+
assert_eq!(res.unwrap_err().kind(), std::io::ErrorKind::UnexpectedEof);
99+
Ok(buf)
100+
}

0 commit comments

Comments
 (0)