Skip to content

Commit a10e443

Browse files
davidpdrsncratelyn
authored andcommitted
Add channel body
1 parent 29642e8 commit a10e443

File tree

3 files changed

+173
-0
lines changed

3 files changed

+173
-0
lines changed

http-body-util/Cargo.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,21 @@ keywords = ["http"]
2626
categories = ["web-programming"]
2727
rust-version = "1.61"
2828

29+
[features]
30+
default = []
31+
channel = ["dep:tokio"]
32+
full = ["channel"]
33+
2934
[dependencies]
3035
bytes = "1"
3136
futures-core = { version = "0.3", default-features = false }
3237
http = "1"
3338
http-body = { version = "1", path = "../http-body" }
3439
pin-project-lite = "0.2"
3540

41+
# optional dependencies
42+
tokio = { version = "1", features = ["sync"], optional = true }
43+
3644
[dev-dependencies]
3745
futures-util = { version = "0.3", default-features = false }
3846
tokio = { version = "1", features = ["macros", "rt", "sync", "rt-multi-thread"] }

http-body-util/src/channel.rs

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
//! A body backed by a channel.
2+
3+
use std::{
4+
fmt::Display,
5+
pin::Pin,
6+
task::{Context, Poll},
7+
};
8+
9+
use bytes::Buf;
10+
use http::HeaderMap;
11+
use http_body::{Body, Frame};
12+
use tokio::sync::mpsc;
13+
14+
/// A body backed by a channel.
15+
pub struct Channel<D, E = std::convert::Infallible> {
16+
rx_frame: mpsc::Receiver<Frame<D>>,
17+
rx_error: mpsc::Receiver<E>,
18+
}
19+
20+
impl<D, E> Channel<D, E> {
21+
/// Create a new channel body.
22+
///
23+
/// The channel will buffer up to the provided number of messages. Once the buffer is full,
24+
/// attempts to send new messages will wait until a message is received from the channel. The
25+
/// provided buffer capacity must be at least 1.
26+
pub fn new(buffer: usize) -> (Sender<D, E>, Self) {
27+
let (tx_frame, rx_frame) = mpsc::channel(buffer);
28+
let (tx_error, rx_error) = mpsc::channel(1);
29+
(Sender { tx_frame, tx_error }, Self { rx_frame, rx_error })
30+
}
31+
}
32+
33+
impl<D, E> Body for Channel<D, E>
34+
where
35+
D: Buf,
36+
{
37+
type Data = D;
38+
type Error = E;
39+
40+
fn poll_frame(
41+
mut self: Pin<&mut Self>,
42+
cx: &mut Context<'_>,
43+
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
44+
match self.rx_frame.poll_recv(cx) {
45+
Poll::Ready(frame) => return Poll::Ready(frame.map(Ok)),
46+
Poll::Pending => {}
47+
}
48+
49+
match self.rx_error.poll_recv(cx) {
50+
Poll::Ready(err) => return Poll::Ready(err.map(Err)),
51+
Poll::Pending => {}
52+
}
53+
54+
Poll::Pending
55+
}
56+
}
57+
58+
impl<D, E> std::fmt::Debug for Channel<D, E> {
59+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60+
f.debug_struct("Channel")
61+
.field("rx_frame", &self.rx_frame)
62+
.field("rx_error", &self.rx_error)
63+
.finish()
64+
}
65+
}
66+
67+
/// A sender half created through [`Channel::new`].
68+
pub struct Sender<D, E = std::convert::Infallible> {
69+
tx_frame: mpsc::Sender<Frame<D>>,
70+
tx_error: mpsc::Sender<E>,
71+
}
72+
73+
impl<D, E> Sender<D, E> {
74+
/// Send a frame on the channel.
75+
pub async fn send(&self, frame: Frame<D>) -> Result<(), SendError> {
76+
self.tx_frame.send(frame).await.map_err(|_| SendError)
77+
}
78+
79+
/// Send data on data channel.
80+
pub async fn send_data(&self, buf: D) -> Result<(), SendError> {
81+
self.send(Frame::data(buf)).await
82+
}
83+
84+
/// Send trailers on trailers channel.
85+
pub async fn send_trailers(&self, trailers: HeaderMap) -> Result<(), SendError> {
86+
self.send(Frame::trailers(trailers)).await
87+
}
88+
89+
/// Aborts the body in an abnormal fashion.
90+
pub fn abort(self, error: E) {
91+
match self.tx_error.try_send(error) {
92+
Ok(_) => {}
93+
Err(err) => {
94+
match err {
95+
mpsc::error::TrySendError::Full(_) => {
96+
// Channel::new creates the error channel with space for 1 message and we
97+
// only send once because this method consumes `self`. So the receiver
98+
// can't be full.
99+
unreachable!("error receiver should never be full")
100+
}
101+
mpsc::error::TrySendError::Closed(_) => {}
102+
}
103+
}
104+
}
105+
}
106+
}
107+
108+
impl<D, E> std::fmt::Debug for Sender<D, E> {
109+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110+
f.debug_struct("Sender")
111+
.field("tx_frame", &self.tx_frame)
112+
.field("tx_error", &self.tx_error)
113+
.finish()
114+
}
115+
}
116+
117+
/// The error returned if [`Sender`] fails to send because the receiver is closed.
118+
#[derive(Debug)]
119+
#[non_exhaustive]
120+
pub struct SendError;
121+
122+
impl Display for SendError {
123+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124+
write!(f, "failed to send frame")
125+
}
126+
}
127+
128+
impl std::error::Error for SendError {}
129+
130+
#[cfg(test)]
131+
mod tests {
132+
use bytes::Bytes;
133+
use http::{HeaderName, HeaderValue};
134+
135+
use crate::BodyExt;
136+
137+
use super::*;
138+
139+
#[tokio::test]
140+
async fn works() {
141+
let (tx, body) = Channel::<Bytes>::new(1024);
142+
143+
tokio::spawn(async move {
144+
tx.send_data(Bytes::from("Hel")).await.unwrap();
145+
tx.send_data(Bytes::from("lo!")).await.unwrap();
146+
147+
let mut trailers = HeaderMap::new();
148+
trailers.insert(
149+
HeaderName::from_static("foo"),
150+
HeaderValue::from_static("bar"),
151+
);
152+
tx.send_trailers(trailers).await.unwrap();
153+
});
154+
155+
let collected = body.collect().await.unwrap();
156+
assert_eq!(collected.trailers().unwrap()["foo"], "bar");
157+
assert_eq!(collected.to_bytes(), "Hello!");
158+
}
159+
}

http-body-util/src/lib.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ mod full;
1515
mod limited;
1616
mod stream;
1717

18+
#[cfg(feature = "channel")]
19+
pub mod channel;
20+
1821
mod util;
1922

2023
use self::combinators::{BoxBody, MapErr, MapFrame, UnsyncBoxBody};
@@ -26,6 +29,9 @@ pub use self::full::Full;
2629
pub use self::limited::{LengthLimitError, Limited};
2730
pub use self::stream::{BodyDataStream, BodyStream, StreamBody};
2831

32+
#[cfg(feature = "channel")]
33+
pub use self::channel::Channel;
34+
2935
/// An extension trait for [`http_body::Body`] adding various combinators and adapters
3036
pub trait BodyExt: http_body::Body {
3137
/// Returns a future that resolves to the next [`Frame`], if any.

0 commit comments

Comments
 (0)