diff --git a/crates/base/src/rt_worker/worker_ctx.rs b/crates/base/src/rt_worker/worker_ctx.rs index bc58fa2fb..9a2f7d06f 100644 --- a/crates/base/src/rt_worker/worker_ctx.rs +++ b/crates/base/src/rt_worker/worker_ctx.rs @@ -1,6 +1,6 @@ use crate::deno_runtime::DenoRuntime; use crate::inspector_server::Inspector; -use crate::timeout; +use crate::timeout::{self, CancelOnWriteTimeout, ReadTimeoutStream}; use crate::utils::send_event_if_event_worker_available; use crate::utils::units::bytes_to_display; @@ -36,6 +36,7 @@ use tokio::io::{self, copy_bidirectional}; use tokio::net::TcpStream; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; use tokio::sync::{mpsc, oneshot, Mutex}; +use tokio::time::sleep; use tokio_rustls::server::TlsStream; use tokio_util::sync::CancellationToken; use uuid::Uuid; @@ -96,6 +97,7 @@ async fn handle_request( worker_kind: WorkerKind, duplex_stream_tx: mpsc::UnboundedSender, msg: WorkerRequestMsg, + maybe_request_idle_timeout: Option, ) -> Result<(), Error> { let (ours, theirs) = io::duplex(1024); let WorkerRequestMsg { @@ -135,6 +137,7 @@ async fn handle_request( tokio::spawn(relay_upgraded_request_and_response( req_upgrade, parts, + maybe_request_idle_timeout, )); return; @@ -152,7 +155,23 @@ async fn handle_request( tokio::task::yield_now().await; - let res = request_sender.send_request(req).await; + let maybe_cancel_fut = async move { + if let Some(timeout_ms) = maybe_request_idle_timeout { + sleep(Duration::from_millis(timeout_ms)).await; + } else { + pending::<()>().await; + unreachable!() + } + }; + + let res = tokio::select! { + resp = request_sender.send_request(req) => resp, + _ = maybe_cancel_fut => { + // XXX(Nyannyacha): Should we add a more detailed message? 
+ Ok(emit_status_code(http::StatusCode::REQUEST_TIMEOUT, None, true)) + } + }; + let Ok(res) = res else { drop(res_tx.send(res)); return Ok(()); @@ -165,12 +184,29 @@ async fn handle_request( match res_upgrade_type { Some(accepted) if accepted == requested => {} _ => { - drop(res_tx.send(Ok(emit_status_code(StatusCode::BAD_GATEWAY)))); + drop(res_tx.send(Ok(emit_status_code(StatusCode::BAD_GATEWAY, None, true)))); return Ok(()); } } } + if let Some(timeout_ms) = maybe_request_idle_timeout { + let headers = res.headers(); + let is_streamed_response = !headers.contains_key(http::header::CONTENT_LENGTH); + + if is_streamed_response { + let duration = Duration::from_millis(timeout_ms); + let (parts, body) = res.into_parts(); + + drop(res_tx.send(Ok(Response::from_parts( + parts, + Body::wrap_stream(CancelOnWriteTimeout::new(body, duration)), + )))); + + return Ok(()); + } + } + drop(res_tx.send(Ok(res))); Ok(()) } @@ -178,13 +214,21 @@ async fn handle_request( async fn relay_upgraded_request_and_response( downstream: OnUpgrade, parts: http1::Parts, + maybe_idle_timeout: Option, ) { - let mut upstream = Upgraded2::new(parts.io, parts.read_buf); + let upstream = Upgraded2::new(parts.io, parts.read_buf); + let mut upstream = if let Some(timeout_ms) = maybe_idle_timeout { + ReadTimeoutStream::with_timeout(upstream, Duration::from_millis(timeout_ms)) + } else { + ReadTimeoutStream::with_bypass(upstream) + }; + let mut downstream = downstream.await.expect("failed to upgrade request"); match copy_bidirectional(&mut upstream, &mut downstream).await { Ok(_) => {} - Err(err) if err.kind() == ErrorKind::UnexpectedEof => { + Err(err) if matches!(err.kind(), ErrorKind::TimedOut | ErrorKind::BrokenPipe) => {} + Err(err) if matches!(err.kind(), ErrorKind::UnexpectedEof) => { let Ok(_) = downstream.downcast::>>() else { // TODO(Nyannyacha): It would be better if we send // `close_notify` before shutdown an upstream if downstream is a @@ -196,8 +240,8 @@ async fn 
relay_upgraded_request_and_response( }; } - _ => { - unreachable!("coping between upgraded connections failed"); + value => { + unreachable!("copying between upgraded connections failed: {:?}", value); } } @@ -512,6 +556,7 @@ impl CreateWorkerArgs { pub async fn create_worker>( init_opts: Opt, inspector: Option, + maybe_request_idle_timeout: Option, ) -> Result<(MetricSource, mpsc::UnboundedSender), Error> { let (duplex_stream_tx, duplex_stream_rx) = mpsc::unbounded_channel::(); let (worker_boot_result_tx, worker_boot_result_rx) = @@ -553,8 +598,13 @@ pub async fn create_worker>( tokio::task::spawn({ let stream_tx_inner = stream_tx.clone(); async move { - if let Err(err) = - handle_request(worker_kind, stream_tx_inner, msg).await + if let Err(err) = handle_request( + worker_kind, + stream_tx_inner, + msg, + maybe_request_idle_timeout, + ) + .await { error!("worker failed to handle request: {:?}", err); } @@ -666,6 +716,7 @@ pub async fn create_main_worker( termination_token, ), inspector, + None, ) .await .map_err(|err| anyhow!("main worker boot error: {}", err))?; @@ -713,6 +764,7 @@ pub async fn create_events_worker( termination_token, ), None, + None, ) .await .map_err(|err| anyhow!("events worker boot error: {}", err))?; @@ -726,6 +778,7 @@ pub async fn create_user_worker_pool( termination_token: Option, static_patterns: Vec, inspector: Option, + request_idle_timeout: Option, ) -> Result<(SharedMetricSource, mpsc::UnboundedSender), Error> { let metric_src = SharedMetricSource::default(); let (user_worker_msgs_tx, mut user_worker_msgs_rx) = @@ -744,6 +797,7 @@ pub async fn create_user_worker_pool( worker_event_sender, user_worker_msgs_tx_clone, inspector, + request_idle_timeout, ); // Note: Keep this loop non-blocking. Spawn a task to run blocking calls. 
diff --git a/crates/base/src/rt_worker/worker_pool.rs b/crates/base/src/rt_worker/worker_pool.rs index ecb317651..579efcd9c 100644 --- a/crates/base/src/rt_worker/worker_pool.rs +++ b/crates/base/src/rt_worker/worker_pool.rs @@ -1,5 +1,6 @@ use crate::inspector_server::Inspector; use crate::rt_worker::worker_ctx::{create_worker, send_user_worker_request}; +use crate::server::ServerFlags; use anyhow::{anyhow, bail, Context, Error}; use enum_as_inner::EnumAsInner; use event_worker::events::WorkerEventWithMetadata; @@ -88,15 +89,15 @@ impl WorkerPoolPolicy { pub fn new( supervisor: impl Into>, max_parallelism: impl Into>, - request_wait_timeout_ms: impl Into>, + server_flags: ServerFlags, ) -> Self { let default = Self::default(); Self { supervisor_policy: supervisor.into().unwrap_or(default.supervisor_policy), max_parallelism: max_parallelism.into().unwrap_or(default.max_parallelism), - request_wait_timeout_ms: request_wait_timeout_ms - .into() + request_wait_timeout_ms: server_flags + .request_wait_timeout_ms .unwrap_or(default.request_wait_timeout_ms), } } @@ -211,6 +212,7 @@ pub struct WorkerPool { pub active_workers: HashMap, pub worker_pool_msgs_tx: mpsc::UnboundedSender, pub maybe_inspector: Option, + pub maybe_request_idle_timeout: Option, // TODO: refactor this out of worker pool pub worker_event_sender: Option>, @@ -223,6 +225,7 @@ impl WorkerPool { worker_event_sender: Option>, worker_pool_msgs_tx: mpsc::UnboundedSender, inspector: Option, + request_idle_timeout: Option, ) -> Self { Self { policy, @@ -231,6 +234,7 @@ impl WorkerPool { user_workers: HashMap::new(), active_workers: HashMap::new(), maybe_inspector: inspector, + maybe_request_idle_timeout: request_idle_timeout, worker_pool_msgs_tx, } } @@ -249,6 +253,7 @@ impl WorkerPool { let is_oneshot_policy = self.policy.supervisor_policy.is_oneshot(); let inspector = self.maybe_inspector.clone(); + let request_idle_timeout = self.maybe_request_idle_timeout; let force_create = worker_options .conf @@ -418,6 
+423,7 @@ impl WorkerPool { match create_worker( (worker_options, supervisor_policy, termination_token.clone()), inspector, + request_idle_timeout, ) .await { diff --git a/crates/base/src/server.rs b/crates/base/src/server.rs index 885e6ff57..ab40a429a 100644 --- a/crates/base/src/server.rs +++ b/crates/base/src/server.rs @@ -244,6 +244,8 @@ pub struct ServerFlags { pub tcp_nodelay: bool, pub graceful_exit_deadline_sec: u64, pub graceful_exit_keepalive_deadline_ms: Option, + pub request_wait_timeout_ms: Option, + pub request_idle_timeout_ms: Option, pub request_read_timeout_ms: Option, } @@ -379,6 +381,7 @@ impl Server { Some(termination_tokens.pool.clone()), static_patterns, inspector.clone(), + flags.request_idle_timeout_ms, ) .await?; diff --git a/crates/base/src/timeout.rs b/crates/base/src/timeout.rs index 4fac82bff..96a075735 100644 --- a/crates/base/src/timeout.rs +++ b/crates/base/src/timeout.rs @@ -7,6 +7,7 @@ use std::{ time::Duration, }; +use enum_as_inner::EnumAsInner; use futures_util::Future; use pin_project::pin_project; use tokio::{ @@ -278,3 +279,164 @@ where self.inner.size_hint() } } + +pub(crate) struct CancelOnWriteTimeout { + inner: S, + duration: Duration, + sleep: Pin>, +} + +impl futures_util::Stream for CancelOnWriteTimeout { + type Item = S::Item; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + match Pin::new(&mut self.inner).poll_next(cx) { + Poll::Ready(v) => { + let deadline = Instant::now() + self.duration; + + self.sleep.as_mut().reset(deadline); + + Poll::Ready(v) + } + + Poll::Pending => { + if let Poll::Ready(()) = self.sleep.as_mut().poll(cx) { + return Poll::Ready(None); + } + + Poll::Pending + } + } + } + + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} + +impl CancelOnWriteTimeout { + pub(crate) fn new(inner: S, duration: Duration) -> Self { + Self { + inner, + duration, + sleep: Box::pin(sleep(duration)), + } + } +} + +#[derive(EnumAsInner)] 
+pub(crate) enum ReadTimeoutOp { + UseTimeout { + duration: Duration, + sleep: Pin>, + }, + + Bypass, +} + +pub(crate) struct ReadTimeoutStream { + inner: S, + op: ReadTimeoutOp, +} + +impl AsyncRead for ReadTimeoutStream +where + S: AsyncRead + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + match Pin::new(&mut self.inner).poll_read(cx, buf) { + Poll::Ready(v) => { + if self.op.is_bypass() { + return Poll::Ready(v); + } + + let (duration, sleep) = self.op.as_use_timeout_mut().unwrap(); + + let deadline = Instant::now() + *duration; + + sleep.as_mut().reset(deadline); + + Poll::Ready(v) + } + + Poll::Pending => { + if let Some((_, sleep)) = self.op.as_use_timeout_mut() { + if let Poll::Ready(()) = sleep.as_mut().poll(cx) { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::TimedOut, + "socket timed out", + ))); + } + } + + Poll::Pending + } + } + } +} + +impl AsyncWrite for ReadTimeoutStream +where + S: AsyncWrite + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.inner).poll_write(cx, buf) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut self.inner).poll_write_vectored(cx, bufs) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + Pin::new(&mut self.inner).poll_shutdown(cx) + } + + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } +} + +impl ReadTimeoutStream { + pub(crate) fn new(inner: S, kind: ReadTimeoutOp) -> Self { + Self { inner, op: kind } + } + + pub(crate) fn with_timeout(inner: S, duration: Duration) -> Self { + Self::new( + 
inner, + ReadTimeoutOp::UseTimeout { + duration, + sleep: Box::pin(sleep(duration)), + }, + ) + } + + pub(crate) fn with_bypass(inner: S) -> Self { + Self::new(inner, ReadTimeoutOp::Bypass) + } +} diff --git a/crates/base/src/utils/integration_test_helper.rs b/crates/base/src/utils/integration_test_helper.rs index ca113966e..dc11d1c21 100644 --- a/crates/base/src/utils/integration_test_helper.rs +++ b/crates/base/src/utils/integration_test_helper.rs @@ -10,9 +10,12 @@ use std::{ }; use anyhow::{bail, Context, Error}; -use base::rt_worker::{ - worker_ctx::{create_user_worker_pool, create_worker, CreateWorkerArgs, TerminationToken}, - worker_pool::{SupervisorPolicy, WorkerPoolPolicy}, +use base::{ + rt_worker::{ + worker_ctx::{create_user_worker_pool, create_worker, CreateWorkerArgs, TerminationToken}, + worker_pool::{SupervisorPolicy, WorkerPoolPolicy}, + }, + server::ServerFlags, }; use futures_util::{future::BoxFuture, Future, FutureExt}; use http::{Request, Response}; @@ -129,6 +132,7 @@ pub struct TestBedBuilder { main_service_path: PathBuf, worker_pool_policy: Option, main_worker_init_opts: Option, + request_idle_timeout: Option, } impl TestBedBuilder { @@ -140,6 +144,7 @@ impl TestBedBuilder { main_service_path: main_service_path.into(), worker_pool_policy: None, main_worker_init_opts: None, + request_idle_timeout: None, } } @@ -152,7 +157,10 @@ impl TestBedBuilder { self.worker_pool_policy = Some(WorkerPoolPolicy::new( SupervisorPolicy::oneshot(), 1, - Some(request_wait_timeout_ms), + ServerFlags { + request_wait_timeout_ms: Some(request_wait_timeout_ms), + ..Default::default() + }, )); self @@ -162,7 +170,10 @@ impl TestBedBuilder { self.worker_pool_policy = Some(WorkerPoolPolicy::new( SupervisorPolicy::PerWorker, 1, - Some(request_wait_timeout_ms), + ServerFlags { + request_wait_timeout_ms: Some(request_wait_timeout_ms), + ..Default::default() + }, )); self @@ -172,7 +183,10 @@ impl TestBedBuilder { self.worker_pool_policy = Some(WorkerPoolPolicy::new( 
SupervisorPolicy::PerRequest { oneshot: false }, 1, - Some(request_wait_timeout_ms), + ServerFlags { + request_wait_timeout_ms: Some(request_wait_timeout_ms), + ..Default::default() + }, )); self @@ -186,6 +200,11 @@ impl TestBedBuilder { self } + pub fn with_request_idle_timeout(mut self, request_idle_timeout: u64) -> Self { + self.request_idle_timeout = Some(request_idle_timeout); + self + } + pub async fn build(self) -> TestBed { let ((_, worker_pool_tx), pool_termination_token) = { let token = TerminationToken::new(); @@ -197,6 +216,7 @@ impl TestBedBuilder { Some(token.clone()), vec![], None, + self.request_idle_timeout, ) .await .unwrap(), @@ -227,6 +247,7 @@ impl TestBedBuilder { let (_, main_worker_msg_tx) = create_worker( (main_worker_init_opts, main_termination_token.clone()), None, + None, ) .await .unwrap(); @@ -308,6 +329,7 @@ pub async fn create_test_user_worker>( opts.with_policy(policy) .with_termination_token(termination_token.clone()), None, + None, ) .await?; @@ -325,7 +347,14 @@ pub async fn create_test_user_worker>( } pub fn test_user_worker_pool_policy() -> WorkerPoolPolicy { - WorkerPoolPolicy::new(SupervisorPolicy::oneshot(), 1, 4 * 1000 * 3600) + WorkerPoolPolicy::new( + SupervisorPolicy::oneshot(), + 1, + ServerFlags { + request_wait_timeout_ms: Some(4 * 1000 * 3600), + ..Default::default() + }, + ) } pub fn test_user_runtime_opts() -> UserWorkerRuntimeOpts { diff --git a/crates/base/test_cases/chunked-char-first-6000ms/index.ts b/crates/base/test_cases/chunked-char-first-6000ms/index.ts new file mode 100644 index 000000000..9125612e7 --- /dev/null +++ b/crates/base/test_cases/chunked-char-first-6000ms/index.ts @@ -0,0 +1,32 @@ +async function sleep(ms: number) { + return new Promise(res => { + setTimeout(() => { + res(void 0); + }, ms) + }); +} + +Deno.serve(() => { + const encoder = new TextEncoder(); + const stream = new ReadableStream({ + async start(controller) { + const input: [string, number][] = [ + ["m", 6000], + ["e", 100], + ]; 
+ + for (const [char, wait] of input) { + await sleep(wait); + controller.enqueue(encoder.encode(char)); + } + + controller.close(); + }, + }); + + return new Response(stream, { + headers: { + "Content-Type": "text/plain" + } + }); +}); \ No newline at end of file diff --git a/crates/base/test_cases/chunked-char-variable-delay-max-6000ms/index.ts b/crates/base/test_cases/chunked-char-variable-delay-max-6000ms/index.ts new file mode 100644 index 000000000..cd40e7157 --- /dev/null +++ b/crates/base/test_cases/chunked-char-variable-delay-max-6000ms/index.ts @@ -0,0 +1,38 @@ +async function sleep(ms: number) { + return new Promise(res => { + setTimeout(() => { + res(void 0); + }, ms) + }); +} + +Deno.serve(() => { + const encoder = new TextEncoder(); + const stream = new ReadableStream({ + async start(controller) { + const input: [string, number][] = [ + ["m", 500], + ["e", 1000], + ["o", 1000], + ["w", 6000], + ["m", 100], + ["e", 100], + ["o", 100], + ["w", 600], + ]; + + for (const [char, wait] of input) { + await sleep(wait); + controller.enqueue(encoder.encode(char)); + } + + controller.close(); + }, + }); + + return new Response(stream, { + headers: { + "Content-Type": "text/plain" + } + }); +}); \ No newline at end of file diff --git a/crates/base/test_cases/sleep-5000ms/index.ts b/crates/base/test_cases/sleep-5000ms/index.ts new file mode 100644 index 000000000..5569f3e6f --- /dev/null +++ b/crates/base/test_cases/sleep-5000ms/index.ts @@ -0,0 +1,12 @@ +async function sleep(ms: number) { + return new Promise(res => { + setTimeout(() => { + res(void 0); + }, ms) + }); +} + +Deno.serve(async () => { + await sleep(5000); + return new Response("meow"); +}); diff --git a/crates/base/test_cases/websocket-upgrade-no-send-node/index.ts b/crates/base/test_cases/websocket-upgrade-no-send-node/index.ts new file mode 100644 index 000000000..4db2937ed --- /dev/null +++ b/crates/base/test_cases/websocket-upgrade-no-send-node/index.ts @@ -0,0 +1,19 @@ +import { createServer } 
from "node:http"; +import { WebSocketServer } from "npm:ws"; + +const server = createServer(); +const wss = new WebSocketServer({ noServer: true }); + +wss.on("connection", ws => { + ws.on("message", data => { + ws.send(data.toString()); + }); +}); + +server.on("upgrade", (req, socket, head) => { + wss.handleUpgrade(req, socket, head, ws => { + wss.emit("connection", ws, req); + }); +}); + +server.listen(8080); \ No newline at end of file diff --git a/crates/base/test_cases/websocket-upgrade-no-send/index.ts b/crates/base/test_cases/websocket-upgrade-no-send/index.ts new file mode 100644 index 000000000..fc268500e --- /dev/null +++ b/crates/base/test_cases/websocket-upgrade-no-send/index.ts @@ -0,0 +1,9 @@ +Deno.serve(async (req: Request) => { + const { socket, response } = Deno.upgradeWebSocket(req); + + socket.onmessage = ev => { + socket.send(ev.data); + }; + + return response; +}); diff --git a/crates/base/tests/integration_tests.rs b/crates/base/tests/integration_tests.rs index e24f82fbb..0127edc2c 100644 --- a/crates/base/tests/integration_tests.rs +++ b/crates/base/tests/integration_tests.rs @@ -14,7 +14,7 @@ use std::{ use anyhow::Context; use async_tungstenite::WebSocketStream; use base::{ - integration_test, integration_test_listen_fut, + integration_test, integration_test_listen_fut, integration_test_with_server_flag, rt_worker::worker_ctx::{create_user_worker_pool, create_worker, TerminationToken}, server::{ServerEvent, ServerFlags, ServerHealth, Tls}, DecoratorType, @@ -190,6 +190,7 @@ async fn test_not_trigger_pku_sigsegv_due_to_jit_compilation_non_cli() { Some(pool_termination_token.clone()), vec![], None, + None, ) .await .unwrap(); @@ -213,7 +214,7 @@ async fn test_not_trigger_pku_sigsegv_due_to_jit_compilation_non_cli() { static_patterns: vec![], }; - let (_, worker_req_tx) = create_worker((opts, main_termination_token.clone()), None) + let (_, worker_req_tx) = create_worker((opts, main_termination_token.clone()), None, None) .await .unwrap(); @@ 
-346,6 +347,7 @@ async fn test_main_worker_boot_error() { Some(pool_termination_token.clone()), vec![], None, + None, ) .await .unwrap(); @@ -369,7 +371,7 @@ async fn test_main_worker_boot_error() { static_patterns: vec![], }; - let result = create_worker((opts, main_termination_token.clone()), None).await; + let result = create_worker((opts, main_termination_token.clone()), None, None).await; assert!(result.is_err()); assert!(result @@ -1636,6 +1638,291 @@ async fn test_slowloris_slow_header_timedout_secure_inverted() { test_slowloris_slow_header_timedout(new_localhost_tls(true), true).await; } +async fn test_request_idle_timeout_no_streamed_response(maybe_tls: Option) { + let client = maybe_tls.client(); + let req = client + .request( + Method::GET, + format!( + "{}://localhost:{}/sleep-5000ms", + maybe_tls.schema(), + maybe_tls.port(), + ), + ) + .build() + .unwrap(); + + let original = RequestBuilder::from_parts(client, req); + let request_builder = Some(original); + + integration_test_with_server_flag!( + ServerFlags { + request_idle_timeout_ms: Some(1000), + ..Default::default() + }, + "./test_cases/main", + NON_SECURE_PORT, + "", + None, + None, + request_builder, + maybe_tls, + (|resp| async { + assert_eq!(resp.unwrap().status().as_u16(), StatusCode::REQUEST_TIMEOUT); + }), + TerminationToken::new() + ); +} + +#[tokio::test] +#[serial] +async fn test_request_idle_timeout_no_streamed_response_non_secure() { + test_request_idle_timeout_no_streamed_response(new_localhost_tls(false)).await; +} + +#[tokio::test] +#[serial] +async fn test_request_idle_timeout_no_streamed_response_secure() { + test_request_idle_timeout_no_streamed_response(new_localhost_tls(true)).await; +} + +async fn test_request_idle_timeout_streamed_response(maybe_tls: Option) { + let client = maybe_tls.client(); + let req = client + .request( + Method::GET, + format!( + "{}://localhost:{}/chunked-char-variable-delay-max-6000ms", + maybe_tls.schema(), + maybe_tls.port(), + ), + ) + .build() + 
.unwrap(); + + let original = RequestBuilder::from_parts(client, req); + let request_builder = Some(original); + + integration_test_with_server_flag!( + ServerFlags { + request_idle_timeout_ms: Some(2000), + ..Default::default() + }, + "./test_cases/main", + NON_SECURE_PORT, + "", + None, + None, + request_builder, + maybe_tls, + (|resp| async { + let resp = resp.unwrap(); + + assert_eq!(resp.status().as_u16(), StatusCode::OK); + assert!(resp.content_length().is_none()); + + let mut buf = Vec::::new(); + let mut bytes_stream = resp.bytes_stream(); + + loop { + match bytes_stream.next().await { + Some(Ok(v)) => { + buf.extend(v); + } + + Some(Err(_)) => { + break; + } + + None => { + break; + } + } + } + + assert_eq!(&buf, b"meo"); + }), + TerminationToken::new() + ); +} + +#[tokio::test] +#[serial] +async fn test_request_idle_timeout_streamed_response_non_secure() { + test_request_idle_timeout_streamed_response(new_localhost_tls(false)).await; +} + +#[tokio::test] +#[serial] +async fn test_request_idle_timeout_streamed_response_secure() { + test_request_idle_timeout_streamed_response(new_localhost_tls(true)).await; +} + +async fn test_request_idle_timeout_streamed_response_first_chunk_timeout(maybe_tls: Option) { + let client = maybe_tls.client(); + let req = client + .request( + Method::GET, + format!( + "{}://localhost:{}/chunked-char-first-6000ms", + maybe_tls.schema(), + maybe_tls.port(), + ), + ) + .build() + .unwrap(); + + let original = RequestBuilder::from_parts(client, req); + let request_builder = Some(original); + + integration_test_with_server_flag!( + ServerFlags { + request_idle_timeout_ms: Some(1000), + ..Default::default() + }, + "./test_cases/main", + NON_SECURE_PORT, + "", + None, + None, + request_builder, + maybe_tls, + (|resp| async { + let resp = resp.unwrap(); + + assert_eq!(resp.status().as_u16(), StatusCode::OK); + assert!(resp.content_length().is_none()); + + let mut buf = Vec::::new(); + let mut bytes_stream = resp.bytes_stream(); + + 
loop { + match bytes_stream.next().await { + Some(Ok(v)) => { + buf.extend(v); + } + + Some(Err(_)) => { + break; + } + + None => { + break; + } + } + } + + assert_eq!(&buf, b""); + }), + TerminationToken::new() + ); +} + +#[tokio::test] +#[serial] +async fn test_request_idle_timeout_streamed_response_first_chunk_timeout_non_secure() { + test_request_idle_timeout_streamed_response_first_chunk_timeout(new_localhost_tls(false)).await; +} + +#[tokio::test] +#[serial] +async fn test_request_idle_timeout_streamed_response_first_chunk_timeout_secure() { + test_request_idle_timeout_streamed_response_first_chunk_timeout(new_localhost_tls(true)).await; +} + +async fn test_request_idle_timeout_websocket_deno(maybe_tls: Option, use_node_ws: bool) { + let nonce = tungstenite::handshake::client::generate_key(); + let client = maybe_tls.client(); + let req = client + .request( + Method::GET, + format!( + "{}://localhost:{}/websocket-upgrade-no-send{}", + maybe_tls.schema(), + maybe_tls.port(), + if use_node_ws { "-node" } else { "" } + ), + ) + .header(header::CONNECTION, "upgrade") + .header(header::UPGRADE, "websocket") + .header(header::SEC_WEBSOCKET_KEY, &nonce) + .header(header::SEC_WEBSOCKET_VERSION, "13") + .build() + .unwrap(); + + let original = RequestBuilder::from_parts(client, req); + let request_builder = Some(original); + + integration_test_with_server_flag!( + ServerFlags { + request_idle_timeout_ms: Some(1000), + ..Default::default() + }, + "./test_cases/main", + NON_SECURE_PORT, + "", + None, + None, + request_builder, + maybe_tls, + (|resp| async { + let res = resp.unwrap(); + let accepted = get_upgrade_type(res.headers()); + + assert!(res.status().as_u16() == 101); + assert!(accepted.is_some()); + assert_eq!(accepted.as_ref().unwrap(), "websocket"); + + let upgraded = res.upgrade().await.unwrap(); + let mut ws = WebSocketStream::from_raw_socket( + upgraded.compat(), + tungstenite::protocol::Role::Client, + None, + ) + .await; + + 
sleep(Duration::from_secs(3)).await; + + ws.send(Message::Text("meow!!".into())).await.unwrap(); + + let err = ws.next().await.unwrap().unwrap_err(); + + use tungstenite::error::ProtocolError; + use tungstenite::Error; + + assert!(matches!( + err, + Error::Protocol(ProtocolError::ResetWithoutClosingHandshake) + )); + }), + TerminationToken::new() + ); +} + +#[tokio::test] +#[serial] +async fn test_request_idle_timeout_websocket_deno_non_secure() { + test_request_idle_timeout_websocket_deno(new_localhost_tls(false), false).await; +} + +#[tokio::test] +#[serial] +async fn test_request_idle_timeout_websocket_deno_secure() { + test_request_idle_timeout_websocket_deno(new_localhost_tls(true), false).await; +} + +#[tokio::test] +#[serial] +async fn test_request_idle_timeout_websocket_node_non_secure() { + test_request_idle_timeout_websocket_deno(new_localhost_tls(false), true).await; +} + +#[tokio::test] +#[serial] +async fn test_request_idle_timeout_websocket_node_secure() { + test_request_idle_timeout_websocket_deno(new_localhost_tls(true), true).await; +} + trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + Unpin {} impl AsyncReadWrite for T where T: AsyncRead + AsyncWrite + Send + Unpin {} diff --git a/crates/cli/src/flags.rs b/crates/cli/src/flags.rs index 298bad2b2..9e193dd75 100644 --- a/crates/cli/src/flags.rs +++ b/crates/cli/src/flags.rs @@ -140,6 +140,11 @@ fn get_start_command() -> Command { .help("Maximum time in milliseconds that can wait to establish a connection with a worker") .value_parser(value_parser!(u64)), ) + .arg( + arg!(--"request-idle-timeout" ) + .help("Maximum time in milliseconds that can be waited from when a worker takes over the request") + .value_parser(value_parser!(u64)), + ) .arg( arg!(--"request-read-timeout" ) .help("Maximum time in milliseconds that can be waited from when the connection is accepted until the request body is fully read") diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs index e44687577..c5dc958d2 
100644 --- a/crates/cli/src/main.rs +++ b/crates/cli/src/main.rs @@ -122,6 +122,8 @@ fn main() -> Result<(), anyhow::Error> { sub_matches.get_one::("max-parallelism").cloned(); let maybe_request_wait_timeout = sub_matches.get_one::("request-wait-timeout").cloned(); + let maybe_request_idle_timeout = + sub_matches.get_one::("request-idle-timeout").cloned(); let maybe_request_read_timeout = sub_matches.get_one::("request-read-timeout").cloned(); let static_patterns = @@ -157,6 +159,16 @@ fn main() -> Result<(), anyhow::Error> { }; let tcp_nodelay = sub_matches.get_one::("tcp-nodelay").copied().unwrap(); + let flags = ServerFlags { + no_module_cache, + allow_main_inspector, + tcp_nodelay, + graceful_exit_deadline_sec, + graceful_exit_keepalive_deadline_ms, + request_wait_timeout_ms: maybe_request_wait_timeout, + request_idle_timeout_ms: maybe_request_idle_timeout, + request_read_timeout_ms: maybe_request_read_timeout, + }; start_server( ip.as_str(), @@ -187,17 +199,10 @@ fn main() -> Result<(), anyhow::Error> { } else { maybe_max_parallelism }, - maybe_request_wait_timeout, + flags, )), import_map_path, - ServerFlags { - no_module_cache, - allow_main_inspector, - tcp_nodelay, - graceful_exit_deadline_sec, - graceful_exit_keepalive_deadline_ms, - request_read_timeout_ms: maybe_request_read_timeout, - }, + flags, None, WorkerEntrypoints { main: maybe_main_entrypoint, diff --git a/crates/http_utils/src/utils.rs b/crates/http_utils/src/utils.rs index e15bf3a36..fb22b2596 100644 --- a/crates/http_utils/src/utils.rs +++ b/crates/http_utils/src/utils.rs @@ -1,4 +1,4 @@ -use http::{header, response, HeaderMap, Response, StatusCode}; +use http::{header, response, HeaderMap, HeaderValue, Response, StatusCode}; use hyper::Body; pub fn get_upgrade_type(headers: &HeaderMap) -> Option { @@ -21,9 +21,25 @@ pub fn get_upgrade_type(headers: &HeaderMap) -> Option { None } -pub fn emit_status_code(status: StatusCode) -> Response { - response::Builder::new() - .status(status) - 
.body(Body::empty()) - .unwrap() +pub fn emit_status_code( + status: StatusCode, + body: Option, + connection_close: bool, +) -> Response { + let builder = response::Builder::new().status(status); + + let builder = if connection_close { + builder.header(header::CONNECTION, HeaderValue::from_static("close")) + } else { + builder + }; + + if let Some(body) = body { + builder.body(body) + } else { + builder + .header(http::header::CONTENT_LENGTH, 0) + .body(Body::empty()) + } + .unwrap() }