Skip to content

Commit 8f7ea0c

Browse files
committed
feat(io): return () in write_all
1 parent 69988b5 commit 8f7ea0c

File tree

5 files changed

+57
-36
lines changed

5 files changed

+57
-36
lines changed

compio-io/src/compat.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,14 @@ impl<S: crate::AsyncWrite> SyncStream<S> {
153153
/// Flush all data in the write buffer.
154154
pub async fn flush_write_buf(&mut self) -> io::Result<usize> {
155155
let stream = &mut self.stream;
156-
let len = self.write_buffer.with(|b| stream.write_all(b)).await?;
156+
let len = self
157+
.write_buffer
158+
.with(|w| async {
159+
let len = w.buf_len();
160+
let BufResult(res, w) = stream.write_all(w).await;
161+
BufResult(res.map(|()| len), w)
162+
})
163+
.await?;
157164
self.write_buffer.reset();
158165
stream.flush().await?;
159166
Ok(len)

compio-io/src/util/mod.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//! IO related utilities functions for ease of use.
22
33
mod take;
4+
use compio_buf::{BufResult, IoBuf};
45
pub use take::Take;
56

67
mod null;
@@ -39,13 +40,13 @@ pub async fn copy<'a, R: AsyncRead, W: AsyncWrite>(
3940

4041
// When EOF is reached, we are terminating, so flush before that
4142
if read == 0 || buf.need_flush() {
42-
let written = buf.with(|w| writer.write_all(w)).await?;
43-
if written == 0 {
44-
return Err(std::io::Error::new(
45-
std::io::ErrorKind::WriteZero,
46-
"0 byte was written into the writer",
47-
));
48-
}
43+
let written = buf
44+
.with(|w| async {
45+
let len = w.buf_len();
46+
let BufResult(res, w) = writer.write_all(w).await;
47+
BufResult(res.map(|()| len), w)
48+
})
49+
.await?;
4950
total += written;
5051

5152
if buf.advance(written) {

compio-io/src/write/ext.rs

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,21 @@ macro_rules! write_scalar {
1111
use ::compio_buf::{arrayvec::ArrayVec, BufResult};
1212

1313
const LEN: usize = ::std::mem::size_of::<$t>();
14-
let BufResult(len, _) = self
14+
let BufResult(res, _) = self
1515
.write_all(ArrayVec::<u8, LEN>::from(num.$be()))
1616
.await;
17-
assert_eq!(len?, LEN, "`write_all` returned unexpected length");
18-
Ok(())
17+
res
1918
}
2019

2120
#[doc = concat!("Write a little endian `", stringify!($t), "` into the underlying writer.")]
2221
async fn [< write_ $t _le >](&mut self, num: $t) -> IoResult<()> {
2322
use ::compio_buf::{arrayvec::ArrayVec, BufResult};
2423

2524
const LEN: usize = ::std::mem::size_of::<$t>();
26-
let BufResult(len, _) = self
25+
let BufResult(res, _) = self
2726
.write_all(ArrayVec::<u8, LEN>::from(num.$le()))
2827
.await;
29-
assert_eq!(len?, LEN, "`write_all` returned unexpected length");
30-
Ok(())
28+
res
3129
}
3230
}
3331
};
@@ -57,24 +55,42 @@ macro_rules! loop_write_all {
5755
BufResult(Err(ref e), buf) if e.kind() == ::std::io::ErrorKind::Interrupted => {
5856
$buf = buf;
5957
}
60-
res => return res,
58+
BufResult(Err(e), buf) => return BufResult(Err(e), buf),
6159
}
6260
}
6361

64-
return BufResult(Ok($needle), $buf);
62+
return BufResult(Ok(()), $buf);
6563
};
6664
}
6765

6866
macro_rules! loop_write_vectored {
69-
(
70-
$buf:ident,
71-
$tracker:ident :
72-
$tracker_ty:ty,
73-
$iter:ident,loop
74-
$read_expr:expr
75-
) => {
76-
loop_write_vectored!($buf, $tracker: $tracker_ty, res, $iter, loop $read_expr, break None)
77-
};
67+
($buf:ident, $tracker:ident : $tracker_ty:ty, $iter:ident,loop $read_expr:expr) => {{
68+
let mut $iter = match $buf.owned_iter() {
69+
Ok(buf) => buf,
70+
Err(buf) => return BufResult(Ok(()), buf),
71+
};
72+
let mut $tracker: $tracker_ty = 0;
73+
74+
loop {
75+
let len = $iter.buf_len();
76+
if len == 0 {
77+
continue;
78+
}
79+
80+
match $read_expr.await {
81+
BufResult(Ok(()), ret) => {
82+
$iter = ret;
83+
$tracker += len as $tracker_ty;
84+
}
85+
BufResult(Err(e), $iter) => return BufResult(Err(e), $iter.into_inner()),
86+
};
87+
88+
match $iter.next() {
89+
Ok(next) => $iter = next,
90+
Err(buf) => return BufResult(Ok(()), buf),
91+
}
92+
}
93+
}};
7894
(
7995
$buf:ident,
8096
$tracker:ident :
@@ -130,7 +146,7 @@ pub trait AsyncWriteExt: AsyncWrite {
130146
}
131147

132148
/// Write the entire contents of a buffer into this writer.
133-
async fn write_all<T: IoBuf>(&mut self, mut buf: T) -> BufResult<usize, T> {
149+
async fn write_all<T: IoBuf>(&mut self, mut buf: T) -> BufResult<(), T> {
134150
loop_write_all!(
135151
buf,
136152
buf.buf_len(),
@@ -142,8 +158,8 @@ pub trait AsyncWriteExt: AsyncWrite {
142158
/// Write the entire contents of a buffer into this writer. Like
143159
/// [`AsyncWrite::write_vectored`], except that it tries to write the entire
144160
/// contents of the buffer into this writer.
145-
async fn write_vectored_all<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
146-
loop_write_vectored!(buf, total: usize, iter, loop self.write_all(iter))
161+
async fn write_vectored_all<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<(), T> {
162+
loop_write_vectored!(buf, _total: usize, iter, loop self.write_all(iter))
147163
}
148164

149165
write_scalar!(u8, to_be_bytes, to_le_bytes);
@@ -168,7 +184,7 @@ impl<A: AsyncWrite + ?Sized> AsyncWriteExt for A {}
168184
pub trait AsyncWriteAtExt: AsyncWriteAt {
169185
/// Like [`AsyncWriteAt::write_at`], except that it tries to write the
170186
/// entire contents of the buffer into this writer.
171-
async fn write_all_at<T: IoBuf>(&mut self, mut buf: T, pos: u64) -> BufResult<usize, T> {
187+
async fn write_all_at<T: IoBuf>(&mut self, mut buf: T, pos: u64) -> BufResult<(), T> {
172188
loop_write_all!(
173189
buf,
174190
buf.buf_len(),
@@ -183,7 +199,7 @@ pub trait AsyncWriteAtExt: AsyncWriteAt {
183199
&mut self,
184200
buf: T,
185201
pos: u64,
186-
) -> BufResult<usize, T> {
202+
) -> BufResult<(), T> {
187203
loop_write_vectored!(buf, total: u64, iter, loop self.write_all_at(iter, pos + total))
188204
}
189205
}

compio-io/tests/io.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -280,12 +280,10 @@ impl AsyncWriteAt for WriteOne {
280280
async fn write_all() {
281281
let mut dst = WriteOne(vec![]);
282282

283-
let (len, _) = dst.write_all([1, 1, 4, 5, 1, 4]).await.unwrap();
284-
assert_eq!(len, 6);
283+
let ((), _) = dst.write_all([1, 1, 4, 5, 1, 4]).await.unwrap();
285284
assert_eq!(dst.0, [1, 1, 4, 5, 1, 4]);
286285

287-
let (len, _) = dst.write_all_at([114, 114, 114], 2).await.unwrap();
288-
assert_eq!(len, 3);
286+
let ((), _) = dst.write_all_at([114, 114, 114], 2).await.unwrap();
289287
assert_eq!(dst.0, [1, 1, 114, 114, 114, 4]);
290288
}
291289

compio-net/tests/unix_stream.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@ async fn accept_read_write() -> std::io::Result<()> {
1414
let mut client = UnixStream::connect(&sock_path)?;
1515
let (mut server, _) = listener.accept().await?;
1616

17-
let write_len = client.write_all("hello").await.0?;
18-
assert_eq!(write_len, 5);
17+
client.write_all("hello").await.0?;
1918
drop(client);
2019

2120
let buf = Vec::with_capacity(5);

0 commit comments

Comments
 (0)