Skip to content

Commit ba69687

Browse files
Replace SimStream with DuplexStream (#288)
1 parent a5f642d commit ba69687

File tree

2 files changed

+66
-194
lines changed

2 files changed

+66
-194
lines changed

lambda/src/client.rs

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,10 @@ mod endpoint_tests {
6666
use http::{uri::PathAndQuery, HeaderValue, Method, Request, Response, StatusCode, Uri};
6767
use hyper::{server::conn::Http, service::service_fn, Body};
6868
use serde_json::json;
69+
use simulated::DuplexStreamWrapper;
6970
use std::convert::TryFrom;
7071
use tokio::{
71-
io::{AsyncRead, AsyncWrite},
72+
io::{self, AsyncRead, AsyncWrite},
7273
select,
7374
sync::{self, oneshot},
7475
};
@@ -161,14 +162,14 @@ mod endpoint_tests {
161162
#[tokio::test]
162163
async fn test_next_event() -> Result<(), Error> {
163164
let base = Uri::from_static("http://localhost:9001");
164-
let (client, server) = crate::simulated::chan();
165+
let (client, server) = io::duplex(64);
165166

166167
let (tx, rx) = sync::oneshot::channel();
167168
let server = tokio::spawn(async {
168169
handle(server, rx).await.expect("Unable to handle request");
169170
});
170171

171-
let conn = simulated::Connector { inner: client };
172+
let conn = simulated::Connector::with(base.clone(), DuplexStreamWrapper::new(client))?;
172173
let client = Client::with(base, conn);
173174

174175
let req = NextEventRequest.into_req()?;
@@ -189,15 +190,15 @@ mod endpoint_tests {
189190

190191
#[tokio::test]
191192
async fn test_ok_response() -> Result<(), Error> {
192-
let (client, server) = crate::simulated::chan();
193+
let (client, server) = io::duplex(64);
193194
let (tx, rx) = sync::oneshot::channel();
194195
let base = Uri::from_static("http://localhost:9001");
195196

196197
let server = tokio::spawn(async {
197198
handle(server, rx).await.expect("Unable to handle request");
198199
});
199200

200-
let conn = simulated::Connector { inner: client };
201+
let conn = simulated::Connector::with(base.clone(), DuplexStreamWrapper::new(client))?;
201202
let client = Client::with(base, conn);
202203

203204
let req = EventCompletionRequest {
@@ -220,15 +221,15 @@ mod endpoint_tests {
220221

221222
#[tokio::test]
222223
async fn test_error_response() -> Result<(), Error> {
223-
let (client, server) = crate::simulated::chan();
224+
let (client, server) = io::duplex(200);
224225
let (tx, rx) = sync::oneshot::channel();
225226
let base = Uri::from_static("http://localhost:9001");
226227

227228
let server = tokio::spawn(async {
228229
handle(server, rx).await.expect("Unable to handle request");
229230
});
230231

231-
let conn = simulated::Connector { inner: client };
232+
let conn = simulated::Connector::with(base.clone(), DuplexStreamWrapper::new(client))?;
232233
let client = Client::with(base, conn);
233234

234235
let req = EventErrorRequest {
@@ -253,14 +254,14 @@ mod endpoint_tests {
253254

254255
#[tokio::test]
255256
async fn successful_end_to_end_run() -> Result<(), Error> {
256-
let (client, server) = crate::simulated::chan();
257+
let (client, server) = io::duplex(64);
257258
let (tx, rx) = sync::oneshot::channel();
258259
let base = Uri::from_static("http://localhost:9001");
259260

260261
let server = tokio::spawn(async {
261262
handle(server, rx).await.expect("Unable to handle request");
262263
});
263-
let conn = simulated::Connector { inner: client };
264+
let conn = simulated::Connector::with(base.clone(), DuplexStreamWrapper::new(client))?;
264265

265266
let runtime = Runtime::builder()
266267
.with_endpoint(base)

lambda/src/simulated.rs

Lines changed: 56 additions & 185 deletions
Original file line numberDiff line numberDiff line change
@@ -1,229 +1,100 @@
11
use http::Uri;
22
use hyper::client::connect::Connection;
33
use std::{
4-
cmp::min,
5-
collections::VecDeque,
4+
collections::HashMap,
65
future::Future,
76
io::Result as IoResult,
87
pin::Pin,
98
sync::{Arc, Mutex},
10-
task::{Context, Poll, Waker},
9+
task::{Context, Poll},
1110
};
12-
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11+
use tokio::io::{AsyncRead, AsyncWrite, DuplexStream, ReadBuf};
1312

14-
/// Creates a pair of `AsyncRead`/`AsyncWrite` data streams, where the write end of each member of the pair
15-
/// is the read end of the other member of the pair. This allows us to emulate the behavior of a TcpStream
16-
/// but in-memory, deterministically, and with full control over failure injection.
17-
pub(crate) fn chan() -> (SimStream, SimStream) {
18-
// Set up two reference-counted, lock-guarded byte VecDeques, one for each direction of the
19-
// connection
20-
let one = Arc::new(Mutex::new(BufferState::new()));
21-
let two = Arc::new(Mutex::new(BufferState::new()));
22-
23-
// Use buf1 for the read-side of left, use buf2 for the write-side of left
24-
let left = SimStream {
25-
read: ReadHalf { buffer: one.clone() },
26-
write: WriteHalf { buffer: two.clone() },
27-
};
28-
29-
// Now swap the buffers for right
30-
let right = SimStream {
31-
read: ReadHalf { buffer: two },
32-
write: WriteHalf { buffer: one },
33-
};
34-
35-
(left, right)
36-
}
13+
use crate::Error;
3714

3815
#[derive(Clone)]
3916
pub struct Connector {
40-
pub inner: SimStream,
41-
}
42-
43-
impl hyper::service::Service<Uri> for Connector {
44-
type Response = SimStream;
45-
type Error = std::io::Error;
46-
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
47-
48-
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
49-
Poll::Ready(Ok(()))
50-
}
51-
52-
fn call(&mut self, _: Uri) -> Self::Future {
53-
let inner = self.inner.clone();
54-
Box::pin(async move { Ok(inner) })
55-
}
17+
inner: Arc<Mutex<HashMap<Uri, DuplexStreamWrapper>>>,
5618
}
5719

58-
impl Connection for SimStream {
59-
fn connected(&self) -> hyper::client::connect::Connected {
60-
hyper::client::connect::Connected::new()
61-
}
62-
}
63-
64-
/// A struct that implements AsyncRead + AsyncWrite (similarly to TcpStream) using in-memory
65-
/// bytes only. Unfortunately tokio does not provide an operation that is the opposite of
66-
/// `tokio::io::split`, as that would negate the need for this struct.
67-
// TODO: Implement the ability to explicitly close a connection
68-
#[derive(Debug, Clone)]
69-
pub struct SimStream {
70-
read: ReadHalf,
71-
write: WriteHalf,
72-
}
73-
74-
/// Delegates to the underlying `write` member's methods
75-
impl AsyncWrite for SimStream {
76-
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<IoResult<usize>> {
77-
Pin::new(&mut self.write).poll_write(cx, buf)
78-
}
79-
80-
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
81-
Pin::new(&mut self.write).poll_flush(cx)
82-
}
83-
84-
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
85-
Pin::new(&mut self.write).poll_shutdown(cx)
86-
}
87-
}
20+
pub struct DuplexStreamWrapper(DuplexStream);
8821

89-
/// Delegates to the underlying `read` member's methods
90-
impl AsyncRead for SimStream {
91-
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<IoResult<()>> {
92-
Pin::new(&mut self.read).poll_read(cx, buf)
22+
impl DuplexStreamWrapper {
23+
pub(crate) fn new(stream: DuplexStream) -> DuplexStreamWrapper {
24+
DuplexStreamWrapper(stream)
9325
}
9426
}
9527

96-
/// A buffer for use with ReadHalf/WriteHalf that allows bytes to be written at one end of a
97-
/// dequeue and read from the other end. If a `read_waker` is provided, the BufferState will call
98-
/// `wake()` when there is new data to be read.
99-
#[derive(Debug, Clone)]
100-
pub struct BufferState {
101-
buffer: VecDeque<u8>,
102-
read_waker: Option<Waker>,
103-
}
104-
105-
impl BufferState {
106-
/// Creates a new `BufferState`.
107-
fn new() -> Self {
108-
BufferState {
109-
buffer: VecDeque::new(),
110-
read_waker: None,
28+
impl Connector {
29+
pub fn new() -> Self {
30+
#[allow(clippy::mutable_key_type)]
31+
let map = HashMap::new();
32+
Connector {
33+
inner: Arc::new(Mutex::new(map)),
11134
}
11235
}
113-
/// Writes data to the front of the deque byte buffer
114-
fn write(&mut self, buf: &[u8]) {
115-
for b in buf {
116-
self.buffer.push_front(*b)
117-
}
11836

119-
// If somebody is waiting on this data, wake them up.
120-
if let Some(waker) = self.read_waker.take() {
121-
waker.wake();
37+
pub fn insert(&self, uri: Uri, stream: DuplexStreamWrapper) -> Result<(), Error> {
38+
match self.inner.lock() {
39+
Ok(mut map) => {
40+
map.insert(uri, stream);
41+
Ok(())
42+
}
43+
Err(_) => Err("mutex was poisoned".into()),
12244
}
12345
}
12446

125-
/// Read data from the end of the deque byte buffer
126-
fn read(&mut self, to_buf: &mut ReadBuf<'_>) -> usize {
127-
// Read no more bytes than we have available, and no more bytes than we were asked for
128-
let bytes_to_read = min(to_buf.remaining(), self.buffer.len());
129-
for _ in 0..bytes_to_read {
130-
to_buf.put_slice(&[self.buffer.pop_back().unwrap()]);
47+
pub fn with(uri: Uri, stream: DuplexStreamWrapper) -> Result<Self, Error> {
48+
let connector = Connector::new();
49+
match connector.insert(uri, stream) {
50+
Ok(_) => Ok(connector),
51+
Err(e) => Err(e),
13152
}
132-
133-
bytes_to_read
13453
}
13554
}
13655

137-
/// An AsyncWrite implementation that uses a VecDeque of bytes as a buffer. The WriteHalf will
138-
/// add new bytes to the front of the deque using push_front.
139-
///
140-
/// Intended for use with ReadHalf to read from the VecDeque
141-
#[derive(Debug, Clone)]
142-
pub struct WriteHalf {
143-
buffer: Arc<Mutex<BufferState>>,
144-
}
145-
146-
impl AsyncWrite for WriteHalf {
147-
fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<IoResult<usize>> {
148-
// Acquire the lock for the buffer
149-
let mut write_to = self
150-
.buffer
151-
.lock()
152-
.expect("Lock was poisoned when acquiring buffer lock for WriteHalf");
153-
154-
// write the bytes
155-
write_to.write(buf);
156-
157-
// This operation completes immediately
158-
Poll::Ready(Ok(buf.len()))
159-
}
56+
impl hyper::service::Service<Uri> for Connector {
57+
type Response = DuplexStreamWrapper;
58+
type Error = crate::Error;
59+
#[allow(clippy::type_complexity)]
60+
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
16061

161-
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<IoResult<()>> {
62+
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
16263
Poll::Ready(Ok(()))
16364
}
16465

165-
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<IoResult<()>> {
166-
Poll::Ready(Ok(()))
66+
fn call(&mut self, uri: Uri) -> Self::Future {
67+
let res = match self.inner.lock() {
68+
Ok(mut map) if map.contains_key(&uri) => Ok(map.remove(&uri).unwrap()),
69+
Ok(_) => Err(format!("Uri {} is not in map", uri).into()),
70+
Err(_) => Err("mutex was poisoned".into()),
71+
};
72+
Box::pin(async move { res })
16773
}
16874
}
16975

170-
#[derive(Debug, Clone)]
171-
pub struct ReadHalf {
172-
buffer: Arc<Mutex<BufferState>>,
76+
impl Connection for DuplexStreamWrapper {
77+
fn connected(&self) -> hyper::client::connect::Connected {
78+
hyper::client::connect::Connected::new()
79+
}
17380
}
17481

175-
impl AsyncRead for ReadHalf {
176-
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<IoResult<()>> {
177-
// Acquire the lock for the buffer
178-
let mut read_from = self
179-
.buffer
180-
.lock()
181-
.expect("Lock was poisoned when acquiring buffer lock for ReadHalf");
182-
183-
let bytes_read = read_from.read(buf);
184-
185-
// Returning Poll::Ready(Ok(0)) would indicate that there is nothing more to read, which
186-
// means that someone trying to read from a VecDeque that hasn't been written to yet
187-
// would get an Eof error (as I learned the hard way). Instead we should return Poll:Pending
188-
// to indicate that there could be more to read in the future.
189-
if bytes_read == 0 {
190-
read_from.read_waker = Some(cx.waker().clone());
191-
Poll::Pending
192-
} else {
193-
//read_from.read_waker = Some(cx.waker().clone());
194-
Poll::Ready(Ok(()))
195-
}
82+
impl AsyncRead for DuplexStreamWrapper {
83+
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<IoResult<()>> {
84+
Pin::new(&mut self.0).poll_read(cx, buf)
19685
}
19786
}
19887

199-
#[cfg(test)]
200-
mod tests {
201-
use super::chan;
202-
use tokio::io::{AsyncReadExt, AsyncWriteExt};
203-
204-
#[tokio::test]
205-
async fn ends_should_talk_to_each_other() {
206-
let (mut client, mut server) = chan();
207-
// Write ping to the side 1
208-
client.write_all(b"Ping").await.expect("Write should succeed");
209-
210-
// Verify we can read it from side 2
211-
let mut read_on_server = [0_u8; 4];
212-
server
213-
.read_exact(&mut read_on_server)
214-
.await
215-
.expect("Read should succeed");
216-
assert_eq!(&read_on_server, b"Ping");
88+
impl AsyncWrite for DuplexStreamWrapper {
89+
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, std::io::Error>> {
90+
Pin::new(&mut self.0).poll_write(cx, buf)
91+
}
21792

218-
// Write "Pong" to side 2
219-
server.write_all(b"Pong").await.expect("Write should succeed");
93+
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
94+
Pin::new(&mut self.0).poll_flush(cx)
95+
}
22096

221-
// Verify we can read it from side 1
222-
let mut read_on_client = [0_u8; 4];
223-
client
224-
.read_exact(&mut read_on_client)
225-
.await
226-
.expect("Read should succeed");
227-
assert_eq!(&read_on_client, b"Pong");
97+
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
98+
Pin::new(&mut self.0).poll_shutdown(cx)
22899
}
229100
}

0 commit comments

Comments
 (0)