Skip to content

feat(http1): customizable error messages #3911

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion src/proto/h1/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@ use std::future::Future;
use std::io;
use std::marker::{PhantomData, Unpin};
use std::pin::Pin;
#[cfg(feature = "server")]
use std::sync::Arc;
use std::task::{Context, Poll};
#[cfg(feature = "server")]
use std::time::{Duration, Instant};

use crate::rt::{Read, Write};
#[cfg(feature = "server")]
use crate::server::conn::http1::Http1ErrorResponder;
use bytes::{Buf, Bytes};
use futures_core::ready;
use http::header::{HeaderValue, CONNECTION, TE};
Expand Down Expand Up @@ -61,6 +65,8 @@ where
#[cfg(feature = "server")]
h1_header_read_timeout: None,
#[cfg(feature = "server")]
h1_error_responder: None,
#[cfg(feature = "server")]
h1_header_read_timeout_fut: None,
#[cfg(feature = "server")]
h1_header_read_timeout_running: false,
Expand Down Expand Up @@ -156,6 +162,11 @@ where
self.state.date_header = false;
}

#[cfg(feature = "server")]
pub(crate) fn set_error_responder(&mut self, val: Arc<dyn Http1ErrorResponder>) {
self.state.h1_error_responder = Some(val);
}

pub(crate) fn into_inner(self) -> (I, Bytes) {
self.io.into_inner()
}
Expand Down Expand Up @@ -810,10 +821,16 @@ where
if self.has_h2_prefix() {
return Err(crate::Error::new_version_h2());
}
if let Some(msg) = T::on_error(&err) {

if let Some(msg) = T::on_error(
&err,
#[cfg(feature = "server")]
&self.state.h1_error_responder,
) {
// Drop the cached headers so as to not trigger a debug
// assert in `write_head`...
self.state.cached_headers.take();
debug!("writing head");
self.write_head(msg, None);
self.state.error = Some(err);
return Ok(());
Expand Down Expand Up @@ -927,6 +944,8 @@ struct State {
#[cfg(feature = "server")]
h1_header_read_timeout: Option<Duration>,
#[cfg(feature = "server")]
h1_error_responder: Option<Arc<dyn Http1ErrorResponder>>,
#[cfg(feature = "server")]
h1_header_read_timeout_fut: Option<Pin<Box<dyn Sleep>>>,
#[cfg(feature = "server")]
h1_header_read_timeout_running: bool,
Expand Down
10 changes: 9 additions & 1 deletion src/proto/h1/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ use httparse::ParserConfig;
use crate::body::DecodedLength;
use crate::proto::{BodyLength, MessageHead};

#[cfg(feature = "server")]
use crate::server::conn::http1::Http1ErrorResponder;
#[cfg(feature = "server")]
use std::sync::Arc;

pub(crate) use self::conn::Conn;
pub(crate) use self::decode::Decoder;
pub(crate) use self::dispatch::Dispatcher;
Expand Down Expand Up @@ -35,7 +40,10 @@ pub(crate) trait Http1Transaction {
fn parse(bytes: &mut BytesMut, ctx: ParseContext<'_>) -> ParseResult<Self::Incoming>;
fn encode(enc: Encode<'_, Self::Outgoing>, dst: &mut Vec<u8>) -> crate::Result<Encoder>;

fn on_error(err: &crate::Error) -> Option<MessageHead<Self::Outgoing>>;
fn on_error(
err: &crate::Error,
#[cfg(feature = "server")] responder: &Option<Arc<dyn Http1ErrorResponder>>,
) -> Option<MessageHead<Self::Outgoing>>;

fn is_client() -> bool {
!Self::is_server()
Expand Down
66 changes: 47 additions & 19 deletions src/proto/h1/role.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use std::mem::MaybeUninit;

#[cfg(feature = "client")]
use std::fmt::{self, Write as _};
#[cfg(feature = "server")]
use std::sync::Arc;

use bytes::Bytes;
use bytes::BytesMut;
Expand All @@ -16,6 +18,8 @@ use smallvec::{smallvec, smallvec_inline, SmallVec};
use crate::body::DecodedLength;
#[cfg(feature = "server")]
use crate::common::date;
#[cfg(feature = "server")]
use crate::error::Kind;
use crate::error::Parse;
use crate::ext::HeaderCaseMap;
#[cfg(feature = "ffi")]
Expand All @@ -27,6 +31,8 @@ use crate::proto::h1::{
#[cfg(feature = "client")]
use crate::proto::RequestHead;
use crate::proto::{BodyLength, MessageHead, RequestLine};
#[cfg(feature = "server")]
use crate::server::conn::http1::Http1ErrorResponder;

pub(crate) const DEFAULT_MAX_HEADERS: usize = 100;
const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific
Expand Down Expand Up @@ -127,6 +133,30 @@ pub(crate) enum Client {}
#[cfg(feature = "server")]
pub(crate) enum Server {}

#[cfg(feature = "server")]
pub(crate) fn default_error_response(kind: &Kind) -> Option<crate::Response<()>> {
use crate::error::Kind;
use crate::error::Parse;
use http::StatusCode;
let status = match kind {
Kind::Parse(Parse::Method)
| Kind::Parse(Parse::Header(_))
| Kind::Parse(Parse::Uri)
| Kind::Parse(Parse::Version) => StatusCode::BAD_REQUEST,
Kind::Parse(Parse::TooLarge) => StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE,
Kind::Parse(Parse::UriTooLong) => StatusCode::URI_TOO_LONG,
_ => return None,
};

debug!("building automatic response ({}) for parse error", status);
let msg = MessageHead {
subject: status,
..Default::default()
}
.into_response(());
Some(msg)
}

#[cfg(feature = "server")]
impl Http1Transaction for Server {
type Incoming = RequestLine;
Expand Down Expand Up @@ -460,24 +490,19 @@ impl Http1Transaction for Server {
ret.map(|()| encoder)
}

fn on_error(err: &crate::Error) -> Option<MessageHead<Self::Outgoing>> {
use crate::error::Kind;
let status = match *err.kind() {
Kind::Parse(Parse::Method)
| Kind::Parse(Parse::Header(_))
| Kind::Parse(Parse::Uri)
| Kind::Parse(Parse::Version) => StatusCode::BAD_REQUEST,
Kind::Parse(Parse::TooLarge) => StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE,
Kind::Parse(Parse::UriTooLong) => StatusCode::URI_TOO_LONG,
_ => return None,
};

debug!("sending automatic response ({}) for parse error", status);
let msg = MessageHead {
subject: status,
..Default::default()
};
Some(msg)
fn on_error(
err: &crate::Error,
responder: &Option<Arc<dyn Http1ErrorResponder>>,
) -> Option<MessageHead<Self::Outgoing>> {
use crate::server::conn::http1::Http1ErrorReason;
let reason = Http1ErrorReason::from_kind(err.kind());
responder
.as_ref()
.map_or_else(
|| default_error_response(err.kind()),
|er| er.respond(&reason),
)
.map(|rsp| MessageHead::from_response(rsp))
}

fn is_server() -> bool {
Expand Down Expand Up @@ -1216,7 +1241,10 @@ impl Http1Transaction for Client {
Ok(body)
}

fn on_error(_err: &crate::Error) -> Option<MessageHead<Self::Outgoing>> {
fn on_error(
_err: &crate::Error,
#[cfg(feature = "server")] _responder: &Option<Arc<dyn Http1ErrorResponder>>,
) -> Option<MessageHead<Self::Outgoing>> {
// we can't tell the server about any errors it creates
None
}
Expand Down
13 changes: 12 additions & 1 deletion src/proto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ pub(crate) enum Dispatched {
Upgrade(crate::upgrade::Pending),
}

#[cfg(all(feature = "client", feature = "http1"))]
#[cfg(all(any(feature = "server", feature = "client"), feature = "http1"))]
impl MessageHead<http::StatusCode> {
fn into_response<B>(self, body: B) -> http::Response<B> {
let mut res = http::Response::new(body);
Expand All @@ -70,4 +70,15 @@ impl MessageHead<http::StatusCode> {
*res.extensions_mut() = self.extensions;
res
}

#[cfg(feature = "server")]
fn from_response(response: http::Response<()>) -> Self {
let (parts, _) = response.into_parts();
Self {
version: parts.version,
subject: parts.status,
headers: parts.headers,
extensions: parts.extensions,
}
}
}
60 changes: 60 additions & 0 deletions src/server/conn/http1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;

use crate::error::Kind;
use crate::rt::{Read, Write};
use crate::upgrade::Upgraded;
use bytes::Bytes;
Expand Down Expand Up @@ -70,6 +71,7 @@ pin_project_lite::pin_project! {
#[derive(Clone, Debug)]
pub struct Builder {
h1_parser_config: httparse::ParserConfig,
h1_error_responder: Option<Arc<dyn Http1ErrorResponder>>,
timer: Time,
h1_half_close: bool,
h1_keep_alive: bool,
Expand All @@ -83,6 +85,49 @@ pub struct Builder {
date_header: bool,
}

/// Reason an error arose duing client request parsing
#[non_exhaustive]
#[derive(Debug)]
pub enum Http1ErrorReason {
/// Method in the request was invalid or malformed
InvalidMethod,
/// URI was invalid or malformed
InvalidUri,
/// Version was invalid or malformed
InvalidVersion,
/// Header line was invalid or malformed
InvalidHeader,
/// URI exceeded the server's maximum allowed length
UriTooLong,
/// Headers exceeded the server's maximum allowed size
HeadersTooLarge,
/// Internal hyper error occured during processing
InternalError,
}

impl Http1ErrorReason {
pub(crate) fn from_kind(kind: &Kind) -> Self {
use crate::error::Kind;
use crate::error::Parse;
match kind {
Kind::Parse(Parse::Method) => Http1ErrorReason::InvalidMethod,
Kind::Parse(Parse::Header(_)) => Http1ErrorReason::InvalidHeader,
Kind::Parse(Parse::Uri) => Http1ErrorReason::InvalidUri,
Kind::Parse(Parse::Version) => Http1ErrorReason::InvalidVersion,
Kind::Parse(Parse::TooLarge) => Http1ErrorReason::HeadersTooLarge,
Kind::Parse(Parse::UriTooLong) => Http1ErrorReason::UriTooLong,
_ => Http1ErrorReason::InternalError,
}
}
}

/// Customizable error responses for overring server defaults
///
pub trait Http1ErrorResponder: Send + Sync + std::fmt::Debug {
/// Respond to some [`Http1ErrorReason`]
fn respond(&self, cause: &Http1ErrorReason) -> Option<crate::Response<()>>;
}

/// Deconstructed parts of a `Connection`.
///
/// This allows taking apart a `Connection` at a later time, in order to
Expand Down Expand Up @@ -233,6 +278,7 @@ impl Builder {
pub fn new() -> Self {
Self {
h1_parser_config: Default::default(),
h1_error_responder: None,
timer: Time::Empty,
h1_half_close: false,
h1_keep_alive: true,
Expand Down Expand Up @@ -276,6 +322,16 @@ impl Builder {
self
}

/// Set error responder for this connection.
///
/// The error responder is used to generate custom error responses when the server encounters
/// an error during request processing.
///
pub fn error_responder(&mut self, responder: Arc<dyn Http1ErrorResponder>) -> &mut Self {
self.h1_error_responder = Some(responder);
self
}

/// Set whether HTTP/1 connections will silently ignored malformed header lines.
///
/// If this is enabled and a header line does not start with a valid header
Expand Down Expand Up @@ -471,6 +527,10 @@ impl Builder {
conn.set_write_strategy_flatten();
}
}
if let Some(responder) = &self.h1_error_responder {
conn.set_error_responder(responder.clone());
}

conn.set_flush_pipeline(self.pipeline_flush);
if let Some(max) = self.max_buf_size {
conn.set_max_buf_size(max);
Expand Down
Loading