Skip to content

Commit 3a153ac

Browse files
Forward vectored writes (#45)
* Migrate early-data test to rustls * Replace `match` with `if let` on `TlsState::EarlyData` * Extract client early data handling * Forward vectored writes
1 parent 925a87f commit 3a153ac

File tree

8 files changed

+328
-168
lines changed

8 files changed

+328
-168
lines changed

src/client.rs

Lines changed: 120 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ use std::os::unix::io::{AsRawFd, RawFd};
44
#[cfg(windows)]
55
use std::os::windows::io::{AsRawSocket, RawSocket};
66
use std::pin::Pin;
7+
#[cfg(feature = "early-data")]
8+
use std::task::Waker;
79
use std::task::{Context, Poll};
810

911
use rustls::ClientConnection;
@@ -20,7 +22,7 @@ pub struct TlsStream<IO> {
2022
pub(crate) state: TlsState,
2123

2224
#[cfg(feature = "early-data")]
23-
pub(crate) early_waker: Option<std::task::Waker>,
25+
pub(crate) early_waker: Option<Waker>,
2426
}
2527

2628
impl<IO> TlsStream<IO> {
@@ -152,78 +154,70 @@ where
152154
let mut stream =
153155
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
154156

155-
#[allow(clippy::match_single_binding)]
156-
match this.state {
157-
#[cfg(feature = "early-data")]
158-
TlsState::EarlyData(ref mut pos, ref mut data) => {
159-
use std::io::Write;
160-
161-
// write early data
162-
if let Some(mut early_data) = stream.session.early_data() {
163-
let len = match early_data.write(buf) {
164-
Ok(n) => n,
165-
Err(err) => return Poll::Ready(Err(err)),
166-
};
167-
if len != 0 {
168-
data.extend_from_slice(&buf[..len]);
169-
return Poll::Ready(Ok(len));
170-
}
171-
}
172-
173-
// complete handshake
174-
while stream.session.is_handshaking() {
175-
ready!(stream.handshake(cx))?;
176-
}
177-
178-
// write early data (fallback)
179-
if !stream.session.is_early_data_accepted() {
180-
while *pos < data.len() {
181-
let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
182-
*pos += len;
183-
}
184-
}
185-
186-
// end
187-
this.state = TlsState::Stream;
188-
189-
if let Some(waker) = this.early_waker.take() {
190-
waker.wake();
191-
}
192-
193-
stream.as_mut_pin().poll_write(cx, buf)
157+
#[cfg(feature = "early-data")]
158+
{
159+
let bufs = [io::IoSlice::new(buf)];
160+
let written = ready!(poll_handle_early_data(
161+
&mut this.state,
162+
&mut stream,
163+
&mut this.early_waker,
164+
cx,
165+
&bufs
166+
))?;
167+
if written != 0 {
168+
return Poll::Ready(Ok(written));
194169
}
195-
_ => stream.as_mut_pin().poll_write(cx, buf),
196170
}
171+
172+
stream.as_mut_pin().poll_write(cx, buf)
197173
}
198174

199-
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
175+
/// Note: that it does not guarantee the final data to be sent.
176+
/// To be cautious, you must manually call `flush`.
177+
fn poll_write_vectored(
178+
self: Pin<&mut Self>,
179+
cx: &mut Context<'_>,
180+
bufs: &[io::IoSlice<'_>],
181+
) -> Poll<io::Result<usize>> {
200182
let this = self.get_mut();
201183
let mut stream =
202184
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
203185

204186
#[cfg(feature = "early-data")]
205187
{
206-
if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state {
207-
// complete handshake
208-
while stream.session.is_handshaking() {
209-
ready!(stream.handshake(cx))?;
210-
}
188+
let written = ready!(poll_handle_early_data(
189+
&mut this.state,
190+
&mut stream,
191+
&mut this.early_waker,
192+
cx,
193+
bufs
194+
))?;
195+
if written != 0 {
196+
return Poll::Ready(Ok(written));
197+
}
198+
}
211199

212-
// write early data (fallback)
213-
if !stream.session.is_early_data_accepted() {
214-
while *pos < data.len() {
215-
let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
216-
*pos += len;
217-
}
218-
}
200+
stream.as_mut_pin().poll_write_vectored(cx, bufs)
201+
}
219202

220-
this.state = TlsState::Stream;
203+
#[inline]
204+
fn is_write_vectored(&self) -> bool {
205+
true
206+
}
221207

222-
if let Some(waker) = this.early_waker.take() {
223-
waker.wake();
224-
}
225-
}
226-
}
208+
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
209+
let this = self.get_mut();
210+
let mut stream =
211+
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
212+
213+
#[cfg(feature = "early-data")]
214+
ready!(poll_handle_early_data(
215+
&mut this.state,
216+
&mut stream,
217+
&mut this.early_waker,
218+
cx,
219+
&[]
220+
))?;
227221

228222
stream.as_mut_pin().poll_flush(cx)
229223
}
@@ -248,3 +242,69 @@ where
248242
stream.as_mut_pin().poll_shutdown(cx)
249243
}
250244
}
245+
246+
#[cfg(feature = "early-data")]
247+
fn poll_handle_early_data<IO>(
248+
state: &mut TlsState,
249+
stream: &mut Stream<IO, ClientConnection>,
250+
early_waker: &mut Option<Waker>,
251+
cx: &mut Context<'_>,
252+
bufs: &[io::IoSlice<'_>],
253+
) -> Poll<io::Result<usize>>
254+
where
255+
IO: AsyncRead + AsyncWrite + Unpin,
256+
{
257+
if let TlsState::EarlyData(pos, data) = state {
258+
use std::io::Write;
259+
260+
// write early data
261+
if let Some(mut early_data) = stream.session.early_data() {
262+
let mut written = 0;
263+
264+
for buf in bufs {
265+
if buf.is_empty() {
266+
continue;
267+
}
268+
269+
let len = match early_data.write(buf) {
270+
Ok(0) => break,
271+
Ok(n) => n,
272+
Err(err) => return Poll::Ready(Err(err)),
273+
};
274+
275+
written += len;
276+
data.extend_from_slice(&buf[..len]);
277+
278+
if len < buf.len() {
279+
break;
280+
}
281+
}
282+
283+
if written != 0 {
284+
return Poll::Ready(Ok(written));
285+
}
286+
}
287+
288+
// complete handshake
289+
while stream.session.is_handshaking() {
290+
ready!(stream.handshake(cx))?;
291+
}
292+
293+
// write early data (fallback)
294+
if !stream.session.is_early_data_accepted() {
295+
while *pos < data.len() {
296+
let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
297+
*pos += len;
298+
}
299+
}
300+
301+
// end
302+
*state = TlsState::Stream;
303+
304+
if let Some(waker) = early_waker.take() {
305+
waker.wake();
306+
}
307+
}
308+
309+
Poll::Ready(Ok(0))
310+
}

src/common/mod.rs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,43 @@ where
289289
Poll::Ready(Ok(pos))
290290
}
291291

292+
fn poll_write_vectored(
293+
mut self: Pin<&mut Self>,
294+
cx: &mut Context<'_>,
295+
bufs: &[IoSlice<'_>],
296+
) -> Poll<io::Result<usize>> {
297+
if bufs.iter().all(|buf| buf.is_empty()) {
298+
return Poll::Ready(Ok(0));
299+
}
300+
301+
loop {
302+
let mut would_block = false;
303+
let written = self.session.writer().write_vectored(bufs)?;
304+
305+
while self.session.wants_write() {
306+
match self.write_io(cx) {
307+
Poll::Ready(Ok(0)) | Poll::Pending => {
308+
would_block = true;
309+
break;
310+
}
311+
Poll::Ready(Ok(_)) => (),
312+
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
313+
}
314+
}
315+
316+
return match (written, would_block) {
317+
(0, true) => Poll::Pending,
318+
(0, false) => continue,
319+
(n, _) => Poll::Ready(Ok(n)),
320+
};
321+
}
322+
}
323+
324+
#[inline]
325+
fn is_write_vectored(&self) -> bool {
326+
true
327+
}
328+
292329
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
293330
self.session.writer().flush()?;
294331
while self.session.wants_write() {

src/common/test_stream.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,15 @@ impl AsyncWrite for Expected {
122122

123123
#[tokio::test]
124124
async fn stream_good() -> io::Result<()> {
125+
stream_good_impl(false).await
126+
}
127+
128+
#[tokio::test]
129+
async fn stream_good_vectored() -> io::Result<()> {
130+
stream_good_impl(true).await
131+
}
132+
133+
async fn stream_good_impl(vectored: bool) -> io::Result<()> {
125134
const FILE: &[u8] = include_bytes!("../../README.md");
126135

127136
let (server, mut client) = make_pair();
@@ -139,7 +148,7 @@ async fn stream_good() -> io::Result<()> {
139148
dbg!(stream.read_to_end(&mut buf).await)?;
140149
assert_eq!(buf, FILE);
141150

142-
dbg!(stream.write_all(b"Hello World!").await)?;
151+
dbg!(utils::write(&mut stream, b"Hello World!", vectored).await)?;
143152
stream.session.send_close_notify();
144153

145154
dbg!(stream.shutdown().await)?;

src/lib.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,26 @@ where
564564
}
565565
}
566566

567+
#[inline]
568+
fn poll_write_vectored(
569+
self: Pin<&mut Self>,
570+
cx: &mut Context<'_>,
571+
bufs: &[io::IoSlice<'_>],
572+
) -> Poll<io::Result<usize>> {
573+
match self.get_mut() {
574+
TlsStream::Client(x) => Pin::new(x).poll_write_vectored(cx, bufs),
575+
TlsStream::Server(x) => Pin::new(x).poll_write_vectored(cx, bufs),
576+
}
577+
}
578+
579+
#[inline]
580+
fn is_write_vectored(&self) -> bool {
581+
match self {
582+
TlsStream::Client(x) => x.is_write_vectored(),
583+
TlsStream::Server(x) => x.is_write_vectored(),
584+
}
585+
}
586+
567587
#[inline]
568588
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
569589
match self.get_mut() {

src/server.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,24 @@ where
113113
stream.as_mut_pin().poll_write(cx, buf)
114114
}
115115

116+
/// Note: that it does not guarantee the final data to be sent.
117+
/// To be cautious, you must manually call `flush`.
118+
fn poll_write_vectored(
119+
self: Pin<&mut Self>,
120+
cx: &mut Context<'_>,
121+
bufs: &[io::IoSlice<'_>],
122+
) -> Poll<io::Result<usize>> {
123+
let this = self.get_mut();
124+
let mut stream =
125+
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
126+
stream.as_mut_pin().poll_write_vectored(cx, bufs)
127+
}
128+
129+
#[inline]
130+
fn is_write_vectored(&self) -> bool {
131+
true
132+
}
133+
116134
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
117135
let this = self.get_mut();
118136
let mut stream =

0 commit comments

Comments
 (0)