From 3ea504a3f1348e3c93ad6aeba78445f17e551860 Mon Sep 17 00:00:00 2001 From: Sam Praneis Date: Mon, 7 Jul 2025 11:41:18 -0500 Subject: [PATCH] feat(http1): customizable error messages --- src/proto/h1/conn.rs | 21 ++++++++++++- src/proto/h1/mod.rs | 10 +++++- src/proto/h1/role.rs | 66 ++++++++++++++++++++++++++++------------ src/proto/mod.rs | 13 +++++++- src/server/conn/http1.rs | 60 ++++++++++++++++++++++++++++++++++++ 5 files changed, 148 insertions(+), 22 deletions(-) diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index 7c168e005e..6d7925a2f0 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -4,11 +4,15 @@ use std::future::Future; use std::io; use std::marker::{PhantomData, Unpin}; use std::pin::Pin; +#[cfg(feature = "server")] +use std::sync::Arc; use std::task::{Context, Poll}; #[cfg(feature = "server")] use std::time::{Duration, Instant}; use crate::rt::{Read, Write}; +#[cfg(feature = "server")] +use crate::server::conn::http1::Http1ErrorResponder; use bytes::{Buf, Bytes}; use futures_core::ready; use http::header::{HeaderValue, CONNECTION, TE}; @@ -61,6 +65,8 @@ where #[cfg(feature = "server")] h1_header_read_timeout: None, #[cfg(feature = "server")] + h1_error_responder: None, + #[cfg(feature = "server")] h1_header_read_timeout_fut: None, #[cfg(feature = "server")] h1_header_read_timeout_running: false, @@ -156,6 +162,11 @@ where self.state.date_header = false; } + #[cfg(feature = "server")] + pub(crate) fn set_error_responder(&mut self, val: Arc) { + self.state.h1_error_responder = Some(val); + } + pub(crate) fn into_inner(self) -> (I, Bytes) { self.io.into_inner() } @@ -810,10 +821,16 @@ where if self.has_h2_prefix() { return Err(crate::Error::new_version_h2()); } - if let Some(msg) = T::on_error(&err) { + + if let Some(msg) = T::on_error( + &err, + #[cfg(feature = "server")] + &self.state.h1_error_responder, + ) { // Drop the cached headers so as to not trigger a debug // assert in `write_head`... self.state.cached_headers.take(); + debug!("writing head"); self.write_head(msg, None); self.state.error = Some(err); return Ok(()); @@ -927,6 +944,8 @@ struct State { #[cfg(feature = "server")] h1_header_read_timeout: Option, #[cfg(feature = "server")] + h1_error_responder: Option>, + #[cfg(feature = "server")] h1_header_read_timeout_fut: Option>>, #[cfg(feature = "server")] h1_header_read_timeout_running: bool, diff --git a/src/proto/h1/mod.rs b/src/proto/h1/mod.rs index a8f36f5fd9..3030425d26 100644 --- a/src/proto/h1/mod.rs +++ b/src/proto/h1/mod.rs @@ -5,6 +5,11 @@ use httparse::ParserConfig; use crate::body::DecodedLength; use crate::proto::{BodyLength, MessageHead}; +#[cfg(feature = "server")] +use crate::server::conn::http1::Http1ErrorResponder; +#[cfg(feature = "server")] +use std::sync::Arc; + pub(crate) use self::conn::Conn; pub(crate) use self::decode::Decoder; pub(crate) use self::dispatch::Dispatcher; @@ -35,7 +40,10 @@ pub(crate) trait Http1Transaction { fn parse(bytes: &mut BytesMut, ctx: ParseContext<'_>) -> ParseResult; fn encode(enc: Encode<'_, Self::Outgoing>, dst: &mut Vec) -> crate::Result; - fn on_error(err: &crate::Error) -> Option>; + fn on_error( + err: &crate::Error, + #[cfg(feature = "server")] responder: &Option>, + ) -> Option>; fn is_client() -> bool { !Self::is_server() diff --git a/src/proto/h1/role.rs b/src/proto/h1/role.rs index 1674e26bc6..d9928d5fce 100644 --- a/src/proto/h1/role.rs +++ b/src/proto/h1/role.rs @@ -2,6 +2,8 @@ use std::mem::MaybeUninit; #[cfg(feature = "client")] use std::fmt::{self, Write as _}; +#[cfg(feature = "server")] +use std::sync::Arc; use bytes::Bytes; use bytes::BytesMut; @@ -16,6 +18,8 @@ use smallvec::{smallvec, smallvec_inline, SmallVec}; use crate::body::DecodedLength; #[cfg(feature = "server")] use crate::common::date; +#[cfg(feature = "server")] +use crate::error::Kind; use crate::error::Parse; use crate::ext::HeaderCaseMap; #[cfg(feature = "ffi")] @@ -27,6 +31,8 @@ use crate::proto::h1::{ #[cfg(feature = "client")] use crate::proto::RequestHead; use crate::proto::{BodyLength, MessageHead, RequestLine}; +#[cfg(feature = "server")] +use crate::server::conn::http1::Http1ErrorResponder; pub(crate) const DEFAULT_MAX_HEADERS: usize = 100; const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific @@ -127,6 +133,30 @@ pub(crate) enum Client {} #[cfg(feature = "server")] pub(crate) enum Server {} +#[cfg(feature = "server")] +pub(crate) fn default_error_response(kind: &Kind) -> Option> { + use crate::error::Kind; + use crate::error::Parse; + use http::StatusCode; + let status = match kind { + Kind::Parse(Parse::Method) + | Kind::Parse(Parse::Header(_)) + | Kind::Parse(Parse::Uri) + | Kind::Parse(Parse::Version) => StatusCode::BAD_REQUEST, + Kind::Parse(Parse::TooLarge) => StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE, + Kind::Parse(Parse::UriTooLong) => StatusCode::URI_TOO_LONG, + _ => return None, + }; + + debug!("building automatic response ({}) for parse error", status); + let msg = MessageHead { + subject: status, + ..Default::default() + } + .into_response(()); + Some(msg) +} + #[cfg(feature = "server")] impl Http1Transaction for Server { type Incoming = RequestLine; @@ -460,24 +490,19 @@ impl Http1Transaction for Server { ret.map(|()| encoder) } - fn on_error(err: &crate::Error) -> Option> { - use crate::error::Kind; - let status = match *err.kind() { - Kind::Parse(Parse::Method) - | Kind::Parse(Parse::Header(_)) - | Kind::Parse(Parse::Uri) - | Kind::Parse(Parse::Version) => StatusCode::BAD_REQUEST, - Kind::Parse(Parse::TooLarge) => StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE, - Kind::Parse(Parse::UriTooLong) => StatusCode::URI_TOO_LONG, - _ => return None, - }; - - debug!("sending automatic response ({}) for parse error", status); - let msg = MessageHead { - subject: status, - ..Default::default() - }; - Some(msg) + fn on_error( + err: &crate::Error, + responder: &Option>, + ) -> Option> { + use crate::server::conn::http1::Http1ErrorReason; + let reason = Http1ErrorReason::from_kind(err.kind()); + responder + .as_ref() + .map_or_else( + || default_error_response(err.kind()), + |er| er.respond(&reason), + ) + .map(|rsp| MessageHead::from_response(rsp)) } fn is_server() -> bool { @@ -1216,7 +1241,10 @@ impl Http1Transaction for Client { Ok(body) } - fn on_error(_err: &crate::Error) -> Option> { + fn on_error( + _err: &crate::Error, + #[cfg(feature = "server")] _responder: &Option>, + ) -> Option> { // we can't tell the server about any errors it creates None } diff --git a/src/proto/mod.rs b/src/proto/mod.rs index fcdf2b97c0..7c9535b5ba 100644 --- a/src/proto/mod.rs +++ b/src/proto/mod.rs @@ -60,7 +60,7 @@ pub(crate) enum Dispatched { Upgrade(crate::upgrade::Pending), } -#[cfg(all(feature = "client", feature = "http1"))] +#[cfg(all(any(feature = "server", feature = "client"), feature = "http1"))] impl MessageHead { fn into_response(self, body: B) -> http::Response { let mut res = http::Response::new(body); @@ -70,4 +70,15 @@ impl MessageHead { *res.extensions_mut() = self.extensions; res } + + #[cfg(feature = "server")] + fn from_response(response: http::Response<()>) -> Self { + let (parts, _) = response.into_parts(); + Self { + version: parts.version, + subject: parts.status, + headers: parts.headers, + extensions: parts.extensions, + } + } } diff --git a/src/server/conn/http1.rs b/src/server/conn/http1.rs index 881c29a4be..58398e54d6 100644 --- a/src/server/conn/http1.rs +++ b/src/server/conn/http1.rs @@ -8,6 +8,7 @@ use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; +use crate::error::Kind; use crate::rt::{Read, Write}; use crate::upgrade::Upgraded; use bytes::Bytes; @@ -70,6 +71,7 @@ pin_project_lite::pin_project! { #[derive(Clone, Debug)] pub struct Builder { h1_parser_config: httparse::ParserConfig, + h1_error_responder: Option>, timer: Time, h1_half_close: bool, h1_keep_alive: bool, @@ -83,6 +85,49 @@ pub struct Builder { date_header: bool, } +/// Reason an error arose duing client request parsing +#[non_exhaustive] +#[derive(Debug)] +pub enum Http1ErrorReason { + /// Method in the request was invalid or malformed + InvalidMethod, + /// URI was invalid or malformed + InvalidUri, + /// Version was invalid or malformed + InvalidVersion, + /// Header line was invalid or malformed + InvalidHeader, + /// URI exceeded the server's maximum allowed length + UriTooLong, + /// Headers exceeded the server's maximum allowed size + HeadersTooLarge, + /// Internal hyper error occured during processing + InternalError, +} + +impl Http1ErrorReason { + pub(crate) fn from_kind(kind: &Kind) -> Self { + use crate::error::Kind; + use crate::error::Parse; + match kind { + Kind::Parse(Parse::Method) => Http1ErrorReason::InvalidMethod, + Kind::Parse(Parse::Header(_)) => Http1ErrorReason::InvalidHeader, + Kind::Parse(Parse::Uri) => Http1ErrorReason::InvalidUri, + Kind::Parse(Parse::Version) => Http1ErrorReason::InvalidVersion, + Kind::Parse(Parse::TooLarge) => Http1ErrorReason::HeadersTooLarge, + Kind::Parse(Parse::UriTooLong) => Http1ErrorReason::UriTooLong, + _ => Http1ErrorReason::InternalError, + } + } +} + +/// Customizable error responses for overring server defaults +/// +pub trait Http1ErrorResponder: Send + Sync + std::fmt::Debug { + /// Respond to some [`Http1ErrorReason`] + fn respond(&self, cause: &Http1ErrorReason) -> Option>; +} + /// Deconstructed parts of a `Connection`. /// /// This allows taking apart a `Connection` at a later time, in order to @@ -233,6 +278,7 @@ impl Builder { pub fn new() -> Self { Self { h1_parser_config: Default::default(), + h1_error_responder: None, timer: Time::Empty, h1_half_close: false, h1_keep_alive: true, @@ -276,6 +322,16 @@ impl Builder { self } + /// Set error responder for this connection. + /// + /// The error responder is used to generate custom error responses when the server encounters + /// an error during request processing. + /// + pub fn error_responder(&mut self, responder: Arc) -> &mut Self { + self.h1_error_responder = Some(responder); + self + } + /// Set whether HTTP/1 connections will silently ignored malformed header lines. /// /// If this is enabled and a header line does not start with a valid header @@ -471,6 +527,10 @@ impl Builder { conn.set_write_strategy_flatten(); } } + if let Some(responder) = &self.h1_error_responder { + conn.set_error_responder(responder.clone()); + } + conn.set_flush_pipeline(self.pipeline_flush); if let Some(max) = self.max_buf_size { conn.set_max_buf_size(max);