From 393a3054249385a547fd61a7e630215be7da9c06 Mon Sep 17 00:00:00 2001 From: James Date: Wed, 29 Jan 2025 15:24:23 -0500 Subject: [PATCH 01/12] feat: task management --- Cargo.toml | 2 +- src/axum.rs | 42 ++++++++++++- src/lib.rs | 5 ++ src/pubsub/shared.rs | 95 ++++++++++++++++------------- src/pubsub/trait.rs | 34 ++++++++--- src/router.rs | 15 ++++- src/routes/ctx.rs | 55 ++++++++++++++--- src/tasks.rs | 141 +++++++++++++++++++++++++++++++++++++++++++ 8 files changed, 323 insertions(+), 66 deletions(-) create mode 100644 src/tasks.rs diff --git a/Cargo.toml b/Cargo.toml index 8ce47be..f9079dd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,7 +31,7 @@ tokio-stream = { version = "0.1.17", optional = true } # ipc interprocess = { version = "2.2.2", features = ["async", "tokio"], optional = true } -tokio-util = { version = "0.7.13", optional = true, features = ["io"] } +tokio-util = { version = "0.7.13", optional = true, features = ["io", "rt"] } # ws tokio-tungstenite = { version = "0.26.1", features = ["rustls-tls-webpki-roots"], optional = true } diff --git a/src/axum.rs b/src/axum.rs index 115613e..7cbf305 100644 --- a/src/axum.rs +++ b/src/axum.rs @@ -1,9 +1,44 @@ -use crate::types::{InboundData, Response}; +use crate::{ + types::{InboundData, Response}, + HandlerCtx, TaskSet, +}; use axum::{extract::FromRequest, response::IntoResponse}; use bytes::Bytes; use std::{future::Future, pin::Pin}; +use tokio::runtime::Handle; -impl axum::handler::Handler for crate::Router +/// A wrapper around an [`ajj::Router`] that implements the [`axum::handler::Handler`] trait. +#[derive(Debug, Clone)] +pub struct IntoAxum { + pub(crate) router: crate::Router, + pub(crate) task_set: TaskSet, +} + +impl From> for IntoAxum { + fn from(router: crate::Router) -> Self { + Self { + router, + task_set: Default::default(), + } + } +} + +impl IntoAxum { + /// Create a new `IntoAxum` from a router and task set. + pub(crate) fn new(router: crate::Router, handle: Handle) -> Self { + Self { + router, + task_set: handle.into(), + } + } + + /// Get a new context, built from the task set. + fn ctx(&self) -> HandlerCtx { + self.task_set.clone().into() + } +} + +impl axum::handler::Handler for IntoAxum where S: Clone + Send + Sync + 'static, { @@ -21,7 +56,8 @@ where let req = InboundData::try_from(bytes).unwrap_or_default(); if let Some(response) = self - .call_batch_with_state(Default::default(), req, state) + .router + .call_batch_with_state(self.ctx(), req, state) .await { Box::::from(response).into_response() diff --git a/src/lib.rs b/src/lib.rs index 73dfb67..c445917 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -130,6 +130,8 @@ pub(crate) mod macros; #[cfg(feature = "axum")] mod axum; +#[cfg(feature = "axum")] +pub use axum::IntoAxum; mod error; pub use error::RegistrationError; @@ -152,6 +154,9 @@ pub(crate) use routes::{BoxedIntoRoute, ErasedIntoRoute, Method, Route}; mod router; pub use router::Router; +mod tasks; +pub use tasks::TaskSet; + mod types; pub use types::{ErrorPayload, ResponsePayload}; diff --git a/src/pubsub/shared.rs b/src/pubsub/shared.rs index 7315d57..8cd531c 100644 --- a/src/pubsub/shared.rs +++ b/src/pubsub/shared.rs @@ -3,13 +3,10 @@ use core::fmt; use crate::{ pubsub::{In, JsonSink, Listener, Out}, types::InboundData, + HandlerCtx, TaskSet, }; use serde_json::value::RawValue; -use tokio::{ - select, - sync::{mpsc, oneshot, watch}, - task::JoinHandle, -}; +use tokio::{select, sync::mpsc, task::JoinHandle}; use tokio_stream::StreamExt; use tracing::{debug, debug_span, error, instrument, trace, Instrument}; @@ -22,12 +19,26 @@ pub type ConnectionId = u64; /// Holds the shutdown signal for some server. #[derive(Debug)] pub struct ServerShutdown { - pub(crate) _shutdown: watch::Sender<()>, + pub(crate) task_set: TaskSet, } -impl From> for ServerShutdown { - fn from(sender: watch::Sender<()>) -> Self { - Self { _shutdown: sender } +impl From for ServerShutdown { + fn from(task_set: TaskSet) -> Self { + Self::new(task_set) + } +} + +impl ServerShutdown { + /// Create a new `ServerShutdown` with the given shutdown signal and task + /// set. + pub fn new(task_set: TaskSet) -> Self { + Self { task_set } + } +} + +impl Drop for ServerShutdown { + fn drop(&mut self) { + self.task_set.cancel(); } } @@ -67,16 +78,17 @@ where } /// Spawn the future produced by [`Self::task_future`]. - pub(crate) fn spawn(self) -> JoinHandle<()> { + pub(crate) fn spawn(self) -> JoinHandle> { + let tasks = self.manager.root_tasks.clone(); let future = self.task_future(); - tokio::spawn(future) + tasks.spawn(future) } } /// The `ConnectionManager` provides connections with IDs, and handles spawning /// the [`RouteTask`] for each connection. pub(crate) struct ConnectionManager { - pub(crate) shutdown: watch::Receiver<()>, + pub(crate) root_tasks: TaskSet, pub(crate) next_id: ConnectionId, @@ -107,19 +119,18 @@ impl ConnectionManager { ) -> (RouteTask, WriteTask) { let (tx, rx) = mpsc::channel(self.notification_buffer_per_task); - let (gone_tx, gone_rx) = oneshot::channel(); + let tasks = self.root_tasks.child(); let rt = RouteTask { router: self.router(), conn_id, write_task: tx, requests, - gone: gone_tx, + tasks: tasks.clone(), }; let wt = WriteTask { - shutdown: self.shutdown.clone(), - gone: gone_rx, + tasks, conn_id, json: rx, connection, @@ -156,8 +167,8 @@ struct RouteTask { pub(crate) write_task: mpsc::Sender>, /// Stream of requests. pub(crate) requests: In, - /// Sender to the [`WriteTask`], to notify it that this task is done. - pub(crate) gone: oneshot::Sender<()>, + /// The task set for this connection + pub(crate) tasks: TaskSet, } impl fmt::Debug for RouteTask { @@ -184,7 +195,7 @@ where router, mut requests, write_task, - gone, + tasks, .. } = self; @@ -208,7 +219,11 @@ where let span = debug_span!("pubsub request handling", reqs = reqs.len()); - let ctx = write_task.clone().into(); + let ctx = + HandlerCtx::new( + Some(write_task.clone()), + tasks.clone(), + ); let fut = router.handle_request_batch(ctx, reqs); let write_task = write_task.clone(); @@ -239,27 +254,23 @@ where } } } - // No funny business. Drop the gone signal. - drop(gone); + tasks.cancel(); } /// Spawn the future produced by [`Self::task_future`]. - pub(crate) fn spawn(self) -> tokio::task::JoinHandle<()> { + pub(crate) fn spawn(self) -> tokio::task::JoinHandle> { + let tasks = self.tasks.clone(); + let future = self.task_future(); - tokio::spawn(future) + + tasks.spawn(future) } } /// The Write Task is responsible for writing JSON to the outbound connection. struct WriteTask { - /// Shutdown signal. - /// - /// Shutdowns bubble back up to [`RouteTask`] when the write task is - /// dropped, via the closed `json` channel. - pub(crate) shutdown: watch::Receiver<()>, - - /// Signal that the connection has gone away. - pub(crate) gone: oneshot::Receiver<()>, + /// Task set + pub(crate) tasks: TaskSet, /// ID of the connection. pub(crate) conn_id: ConnectionId, @@ -284,22 +295,18 @@ impl WriteTask { #[instrument(skip(self), fields(conn_id = self.conn_id))] pub(crate) async fn task_future(self) { let WriteTask { - mut shutdown, - mut gone, + tasks, mut json, mut connection, .. } = self; - shutdown.mark_unchanged(); + loop { select! { biased; - _ = &mut gone => { - debug!("Connection has gone away"); - break; - } - _ = shutdown.changed() => { - debug!("shutdown signal received"); + + _ = tasks.cancelled() => { + debug!("Shutdown signal received"); break; } json = json.recv() => { @@ -317,7 +324,9 @@ impl WriteTask { } /// Spawn the future produced by [`Self::task_future`]. - pub(crate) fn spawn(self) -> JoinHandle<()> { - tokio::spawn(self.task_future()) + pub(crate) fn spawn(self) -> tokio::task::JoinHandle> { + let tasks = self.tasks.clone(); + let future = self.task_future(); + tasks.spawn(future) } } diff --git a/src/pubsub/trait.rs b/src/pubsub/trait.rs index db54645..087965b 100644 --- a/src/pubsub/trait.rs +++ b/src/pubsub/trait.rs @@ -1,11 +1,14 @@ -use crate::pubsub::{ - shared::{ConnectionManager, ListenerTask, DEFAULT_NOTIFICATION_BUFFER_PER_CLIENT}, - ServerShutdown, +use crate::{ + pubsub::{ + shared::{ConnectionManager, ListenerTask, DEFAULT_NOTIFICATION_BUFFER_PER_CLIENT}, + ServerShutdown, + }, + TaskSet, }; use bytes::Bytes; use serde_json::value::RawValue; use std::future::Future; -use tokio::sync::watch; +use tokio::runtime::Handle; use tokio_stream::Stream; /// Convenience alias for naming stream halves. @@ -67,26 +70,41 @@ pub trait Connect: Send + Sync + Sized { /// We do not recommend overriding this method. Doing so will opt out of /// the library's pubsub task system. Users overriding this method must /// manually handle connection tasks. - fn serve( + fn serve_on_handle( self, router: crate::Router<()>, + handle: Handle, ) -> impl Future> + Send { async move { + let root_tasks: TaskSet = handle.into(); let notification_buffer_per_task = self.notification_buffer_size(); - let (tx, rx) = watch::channel(()); + ListenerTask { listener: self.make_listener().await?, manager: ConnectionManager { - shutdown: rx, next_id: 0, router, notification_buffer_per_task, + root_tasks: root_tasks.clone(), }, } .spawn(); - Ok(tx.into()) + Ok(root_tasks.into()) } } + + /// Instantiate and run a task to accept connections, returning a shutdown + /// signal. + /// + /// We do not recommend overriding this method. Doing so will opt out of + /// the library's pubsub task system. Users overriding this method must + /// manually handle connection tasks. + fn serve( + self, + router: crate::Router<()>, + ) -> impl Future> + Send { + self.serve_on_handle(router, Handle::current()) + } } /// A [`Listener`] accepts incoming connections and produces [`JsonSink`] and diff --git a/src/router.rs b/src/router.rs index 490e189..ac59cfe 100644 --- a/src/router.rs +++ b/src/router.rs @@ -331,7 +331,20 @@ where /// Nest this router into a new Axum router, with the specified path. #[cfg(feature = "axum")] pub fn into_axum(self, path: &str) -> axum::Router { - axum::Router::new().route(path, axum::routing::post(self)) + axum::Router::new().route(path, axum::routing::post(crate::axum::IntoAxum::from(self))) + } + + /// Nest this router into a new Axum router, with the specified path and + /// using the specified runtime handle. + pub fn into_axum_with_handle( + self, + path: &str, + handle: tokio::runtime::Handle, + ) -> axum::Router { + axum::Router::new().route( + path, + axum::routing::post(crate::axum::IntoAxum::new(self, handle)), + ) } } diff --git a/src/routes/ctx.rs b/src/routes/ctx.rs index dcc9513..f89c0e5 100644 --- a/src/routes/ctx.rs +++ b/src/routes/ctx.rs @@ -1,6 +1,8 @@ -use crate::{types::Request, RpcSend}; +use std::future::Future; + +use crate::{types::Request, RpcSend, TaskSet}; use serde_json::value::RawValue; -use tokio::sync::mpsc; +use tokio::{runtime::Handle, sync::mpsc}; use tracing::error; /// Errors that can occur when sending notifications. @@ -23,28 +25,44 @@ pub enum NotifyError { #[derive(Debug, Clone, Default)] pub struct HandlerCtx { pub(crate) notifications: Option>>, + + /// A task set on which to spawn tasks. This is used to coordinate + pub(crate) tasks: TaskSet, } -impl From>> for HandlerCtx { - fn from(notifications: mpsc::Sender>) -> Self { +impl From for HandlerCtx { + fn from(tasks: TaskSet) -> Self { Self { - notifications: Some(notifications), + notifications: None, + tasks, } } } -impl HandlerCtx { - /// Instantiate a new handler context. - pub const fn new() -> Self { +impl From for HandlerCtx { + fn from(handle: Handle) -> Self { Self { notifications: None, + tasks: handle.into(), } } +} - /// Instantiation a new handler context with notifications enabled. - pub const fn with_notifications(notifications: mpsc::Sender>) -> Self { +impl HandlerCtx { + /// Create a new handler context. + pub fn new(notifications: Option>>, tasks: TaskSet) -> Self { + Self { + notifications, + tasks, + } + } + + /// Instantiation a new handler context with notifications enabled and a + /// default [`TaskSet`]. + pub fn notifications_only(notifications: mpsc::Sender>) -> Self { Self { notifications: Some(notifications), + tasks: Default::default(), } } @@ -73,6 +91,23 @@ impl HandlerCtx { Ok(()) } + + /// Spawn a task on the task set. + pub fn spawn(&self, f: F) + where + F: Future + Send + 'static, + { + self.tasks.spawn(f); + } + + /// Spawn a task on the task set that may block. + pub fn spawn_blocking(&self, f: F) + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + self.tasks.spawn_blocking(f); + } } /// Arguments passed to a handler. diff --git a/src/tasks.rs b/src/tasks.rs new file mode 100644 index 0000000..dd5c892 --- /dev/null +++ b/src/tasks.rs @@ -0,0 +1,141 @@ +use std::future::Future; + +use tokio::{runtime::Handle, task::JoinHandle}; +use tokio_util::{ + sync::{CancellationToken, WaitForCancellationFuture, WaitForCancellationFutureOwned}, + task::{task_tracker::TaskTrackerWaitFuture, TaskTracker}, +}; + +/// This is a wrapper around a [`TaskTracker`] and a [`token`]. It is used to +/// manage a set of tasks, and to token them to shut down when the set is +/// dropped. +#[derive(Debug, Clone, Default)] +pub struct TaskSet { + tasks: TaskTracker, + token: CancellationToken, + handle: Option, +} + +impl From for TaskSet { + fn from(handle: Handle) -> Self { + Self::with_handle(handle) + } +} + +impl TaskSet { + /// Create a new [`TaskSet`]. + pub fn new() -> Self { + Self { + tasks: TaskTracker::new(), + token: CancellationToken::new(), + handle: None, + } + } + + /// Create a new [`TaskSet`] with a handle. + pub fn with_handle(handle: Handle) -> Self { + Self { + tasks: TaskTracker::new(), + token: CancellationToken::new(), + handle: Some(handle), + } + } + + /// Get a handle to the runtime that the task set is running on. + pub fn handle(&self) -> Handle { + self.handle + .clone() + .unwrap_or_else(tokio::runtime::Handle::current) + } + + /// Close the task set, preventing new tasks from being added. + /// + /// See [`TaskTracker::close`] + pub fn close(&self) { + self.tasks.close(); + } + + /// Reopen the task set, allowing new tasks to be added. + /// + /// See [`TaskTracker::reopen`] + pub fn reopen(&self) { + self.tasks.reopen(); + } + + /// Cancel the token, causing all tasks to be cancelled. + pub fn cancel(&self) { + self.token.cancel(); + } + + /// Get a future that resolves when the token is fired. + pub fn cancelled(&self) -> WaitForCancellationFuture<'_> { + self.token.cancelled() + } + + /// Get a clone of the cancellation token. + pub fn cancelled_owned(&self) -> WaitForCancellationFutureOwned { + self.token.clone().cancelled_owned() + } + + /// Returns a future that resolves when the token is fired. + /// + /// See [`TaskTracker::wait`] + pub fn wait(&self) -> TaskTrackerWaitFuture<'_> { + self.tasks.wait() + } + + /// Convenience function to both fire the token and wait for all tasks to + /// complete. + pub fn shutdown(&self) -> TaskTrackerWaitFuture<'_> { + self.cancel(); + self.wait() + } + + /// Get a child [`TaskSet`]. This set will be fired when the parent + /// set is fired, or may be fired independently. + pub fn child(&self) -> Self { + Self { + tasks: TaskTracker::new(), + token: self.token.child_token(), + handle: self.handle.clone(), + } + } + + /// Prepare a future to be added to the task set, by wrapping it with a + /// cancellation token. + fn prep_fut(&self, task: F) -> impl Future> + Send + 'static + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + let token = self.token.clone(); + async move { + tokio::select! { + _ = token.cancelled() => None, + result = task => Some(result), + } + } + } + + /// Spawn a future on the provided handle, and add it to the task set. + pub fn spawn(&self, task: F) -> JoinHandle> + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + self.tasks.spawn_on(self.prep_fut(task), &self.handle()) + } + + /// Spawn a blocking future on the provided handle, and add it to the task + /// set + pub fn spawn_blocking(&self, task: F) -> JoinHandle> + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + let h = self.handle(); + let task = self.prep_fut(task); + self.tasks + .spawn_blocking_on(move || h.block_on(task), &self.handle()) + } +} From 0cb3220f97f6894511b6a7f1fa534c3715d6e320 Mon Sep 17 00:00:00 2001 From: James Date: Wed, 29 Jan 2025 15:25:46 -0500 Subject: [PATCH 02/12] fix: get a handle --- src/pubsub/shared.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pubsub/shared.rs b/src/pubsub/shared.rs index 8cd531c..496e648 100644 --- a/src/pubsub/shared.rs +++ b/src/pubsub/shared.rs @@ -238,7 +238,7 @@ where }; // Run the future in a new task. - tokio::spawn( + tasks.spawn( async move { // Send the response to the write task. // we don't care if the receiver has gone away, From 348b859ab1339ea475dbf9b5aee738f0a3987c3a Mon Sep 17 00:00:00 2001 From: James Date: Wed, 29 Jan 2025 16:02:42 -0500 Subject: [PATCH 03/12] cleanup: docs and pub --- src/axum.rs | 10 ++++- src/lib.rs | 4 +- src/pubsub/shared.rs | 2 +- src/pubsub/trait.rs | 28 +++++++++++++- src/router.rs | 12 +++++- src/routes/ctx.rs | 61 +++++++++++++++++++----------- src/routes/handler.rs | 72 +++++++++++++++++++++++++++-------- src/tasks.rs | 87 +++++++++++++++---------------------------- tests/common/mod.rs | 2 +- 9 files changed, 175 insertions(+), 103 deletions(-) diff --git a/src/axum.rs b/src/axum.rs index 7cbf305..766881a 100644 --- a/src/axum.rs +++ b/src/axum.rs @@ -7,9 +7,15 @@ use bytes::Bytes; use std::{future::Future, pin::Pin}; use tokio::runtime::Handle; -/// A wrapper around an [`ajj::Router`] that implements the [`axum::handler::Handler`] trait. +/// A wrapper around an [`Router`] that implements the +/// [`axum::handler::Handler`] trait. This struct is an implementation detail +/// of the [`Router::into_axum`] and [`Router::into_axum_with_handle`] methods. +/// +/// [`Router`]: crate::Router +/// [`Router::into_axum`]: crate::Router::into_axum +/// [`Router::into_axum_with_handle`]: crate::Router::into_axum_with_handle #[derive(Debug, Clone)] -pub struct IntoAxum { +pub(crate) struct IntoAxum { pub(crate) router: crate::Router, pub(crate) task_set: TaskSet, } diff --git a/src/lib.rs b/src/lib.rs index c445917..1d5ba76 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -130,8 +130,6 @@ pub(crate) mod macros; #[cfg(feature = "axum")] mod axum; -#[cfg(feature = "axum")] -pub use axum::IntoAxum; mod error; pub use error::RegistrationError; @@ -155,7 +153,7 @@ mod router; pub use router::Router; mod tasks; -pub use tasks::TaskSet; +pub(crate) use tasks::TaskSet; mod types; pub use types::{ErrorPayload, ResponsePayload}; diff --git a/src/pubsub/shared.rs b/src/pubsub/shared.rs index 496e648..0687b94 100644 --- a/src/pubsub/shared.rs +++ b/src/pubsub/shared.rs @@ -31,7 +31,7 @@ impl From for ServerShutdown { impl ServerShutdown { /// Create a new `ServerShutdown` with the given shutdown signal and task /// set. - pub fn new(task_set: TaskSet) -> Self { + pub(crate) const fn new(task_set: TaskSet) -> Self { Self { task_set } } } diff --git a/src/pubsub/trait.rs b/src/pubsub/trait.rs index 087965b..39281e0 100644 --- a/src/pubsub/trait.rs +++ b/src/pubsub/trait.rs @@ -37,6 +37,24 @@ pub type In = ::ReqStream; /// may produce arbitrary response bodies, only the server developer can /// accurately set this value. We have provided a low default. Setting it too /// high may allow resource exhaustion attacks. +/// +/// ## Task management +/// +/// When using the default impls of [`Connect::serve`] and +/// [`Connect::serve_with_handle`], the library will manage task sets for +/// inbound connections. These follow a per-connection hierarchical task model. +/// The root task set is associated with the server, and is used to spawn a +/// task that listens for inbound connections. Each connection is then given +/// a child task set, which is used to spawn tasks for that connection. +/// +/// This task set is propagated to [`Handler`]s via the [`HandlerCtx`]. They may +/// then use it to spawn tasks that are themselves associated with the +/// connection. This ensures that, for properly-implemented [`Handler`]s, all +/// tasks associated with a connection are automatically cancelled and cleaned +/// up when the connection is closed. +/// +/// [`Handler`]: crate::Handler +/// [`HandlerCtx`]: crate::HandlerCtx pub trait Connect: Send + Sync + Sized { /// The listener type produced by the connect object. type Listener: Listener; @@ -70,7 +88,9 @@ pub trait Connect: Send + Sync + Sized { /// We do not recommend overriding this method. Doing so will opt out of /// the library's pubsub task system. Users overriding this method must /// manually handle connection tasks. - fn serve_on_handle( + /// + /// The provided handle will be used to spawn tasks. + fn serve_with_handle( self, router: crate::Router<()>, handle: Handle, @@ -99,11 +119,15 @@ pub trait Connect: Send + Sync + Sized { /// We do not recommend overriding this method. Doing so will opt out of /// the library's pubsub task system. Users overriding this method must /// manually handle connection tasks. + /// + /// ## Panics + /// + /// This will panic if called outside the context of a Tokio runtime. fn serve( self, router: crate::Router<()>, ) -> impl Future> + Send { - self.serve_on_handle(router, Handle::current()) + self.serve_with_handle(router, Handle::current()) } } diff --git a/src/router.rs b/src/router.rs index ac59cfe..cba63b8 100644 --- a/src/router.rs +++ b/src/router.rs @@ -328,7 +328,10 @@ where fut } - /// Nest this router into a new Axum router, with the specified path. + /// Nest this router into a new Axum router, with the specified path and the currently-running + /// + /// Users needing specific control over the runtime should use + /// [`Router::into_axum_with_handle`] instead. #[cfg(feature = "axum")] pub fn into_axum(self, path: &str) -> axum::Router { axum::Router::new().route(path, axum::routing::post(crate::axum::IntoAxum::from(self))) @@ -336,6 +339,13 @@ where /// Nest this router into a new Axum router, with the specified path and /// using the specified runtime handle. + /// + /// This method allows users to specify a runtime handle for the router to + /// use. This runtime is accessible to all handlers invoked by the router. + /// Handlers. + /// + /// Tasks spawned by the router will be spawned on the provided runtime, + /// and automatically cancelled when the returned `axum::Router` is dropped. pub fn into_axum_with_handle( self, path: &str, diff --git a/src/routes/ctx.rs b/src/routes/ctx.rs index f89c0e5..930d332 100644 --- a/src/routes/ctx.rs +++ b/src/routes/ctx.rs @@ -2,7 +2,7 @@ use std::future::Future; use crate::{types::Request, RpcSend, TaskSet}; use serde_json::value::RawValue; -use tokio::{runtime::Handle, sync::mpsc}; +use tokio::{runtime::Handle, sync::mpsc, task::JoinHandle}; use tracing::error; /// Errors that can occur when sending notifications. @@ -17,11 +17,14 @@ pub enum NotifyError { } /// A context for handler requests that allow the handler to send notifications -/// from long-running tasks (e.g. subscriptions). +/// and spawn long-running tasks (e.g. subscriptions). /// -/// This is primarily intended to enable subscriptions over pubsub transports -/// to send notifications to clients. It is expected that JSON sent via the -/// notification channel is a valid JSON-RPC 2.0 object. +/// The handler is used for two things: +/// - Spawning long-running tasks (e.g. subscriptions) via +/// [`HandlerCtx::spawn`] or [`HandlerCtx::spawn_blocking`]. +/// - Sending notifications to pubsub clients via [`HandlerCtx::notify`]. +/// Notifcations SHOULD be valid JSON-RPC objects, but this is +/// not enforced by the type system. #[derive(Debug, Clone, Default)] pub struct HandlerCtx { pub(crate) notifications: Option>>, @@ -50,22 +53,16 @@ impl From for HandlerCtx { impl HandlerCtx { /// Create a new handler context. - pub fn new(notifications: Option>>, tasks: TaskSet) -> Self { + pub(crate) const fn new( + notifications: Option>>, + tasks: TaskSet, + ) -> Self { Self { notifications, tasks, } } - /// Instantiation a new handler context with notifications enabled and a - /// default [`TaskSet`]. - pub fn notifications_only(notifications: mpsc::Sender>) -> Self { - Self { - notifications: Some(notifications), - tasks: Default::default(), - } - } - /// Get a reference to the notification sender. This is used to /// send notifications over pubsub transports. pub const fn notifications(&self) -> Option<&mpsc::Sender>> { @@ -93,20 +90,42 @@ impl HandlerCtx { } /// Spawn a task on the task set. - pub fn spawn(&self, f: F) + pub fn spawn(&self, f: F) -> JoinHandle> + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + self.tasks.spawn(f) + } + + /// Spawn a task on the task set with access to this context. + pub fn spawn_with_ctx(&self, f: F) -> JoinHandle> where - F: Future + Send + 'static, + F: FnOnce(HandlerCtx) -> Fut, + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, { - self.tasks.spawn(f); + self.tasks.spawn(f(self.clone())) } - /// Spawn a task on the task set that may block. - pub fn spawn_blocking(&self, f: F) + /// Spawn a task that may block on the task set. + pub fn spawn_blocking(&self, f: F) -> JoinHandle> where F: Future + Send + 'static, F::Output: Send + 'static, { - self.tasks.spawn_blocking(f); + self.tasks.spawn_blocking(f) + } + + /// Spawn a task that may block on the task set, with access to this + /// context. + pub fn spawn_blocking_with_ctx(&self, f: F) -> JoinHandle> + where + F: FnOnce(HandlerCtx) -> Fut, + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, + { + self.tasks.spawn_blocking(f(self.clone())) } } diff --git a/src/routes/handler.rs b/src/routes/handler.rs index 6b9bb62..ca9d974 100644 --- a/src/routes/handler.rs +++ b/src/routes/handler.rs @@ -54,6 +54,62 @@ pub struct PhantomParams(PhantomData); /// }; /// ``` /// +/// ### The [`HandlerCtx`]: tasks and notifications. +/// +/// Any handler may accept the [`HandlerCtx`] as its first argument. This +/// context object is used to context associated with the client connection. It +/// can be used to spawn tasks and (when `pubsub` is enabled) send notifications +/// to the client. +/// +/// Handlers **SHOULD NOT** use [`tokio::spawn`] or +/// [`tokio::task::spawn_blocking`] directly. Instead, they should use the +/// [`HandlerCtx::spawn`] or [`HandlerCtx::spawn_blocking`] methods. These +/// methods ensure that tasks are associated with the client connection, and +/// are cleaned up promptly when the connection is closed, and that the server's +/// runtime configuration is respected. +/// +/// When the task itself requires a context, [`HandlerCtx::spawn_with_ctx`] +/// and [`HandlerCtx::spawn_blocking_with_ctx`] can be used to provide a context +/// to the spawned task. This is a thin wrapper around cloning the context and +/// moving it into the spawned future. +/// +/// ``` +/// use ajj::{Router, HandlerCtx}; +/// +/// # fn test_fn() -> Router<()> { +/// Router::new() +/// .route("good citizenship", |ctx: HandlerCtx| async move { +/// // Properly implemented task management +/// ctx.spawn(async { +/// // do something +/// }); +/// Ok::<_, ()>(()) +/// }) +/// .route("bad citizenship", |ctx: HandlerCtx| async move { +/// // Incorrect task management +/// tokio::spawn(async { +/// // do something +/// }); +/// Ok::<_, ()>(()) +/// }) +/// # } +/// ``` +/// +/// When running on pubsub, handlers can send notifications to the client. This +/// is done by calling [`HandlerCtx::notify`]. Notifications are sent as JSON +/// objects, and are queued for sending to the client. If the client is not +/// reading from the connection, the notification will be queued in a buffer. +/// If the buffer is full, the handler will be backpressured until the buffer +/// has room. +/// +/// We recommend that handler tasks `await` on the result of +/// [`HandlerCtx::notify`], to ensure that they are backpressured when the +/// notification buffer is full. If many tasks are attempting to notify the +/// same client, the buffer may fill up, and backpressure the `RouteTask` from +/// reading requests from the connection. This can lead to delays in request +/// processing. +/// +/// /// ### Handler argument type inference /// /// When the following conditions are true, the compiler may fail to infer @@ -191,22 +247,6 @@ pub struct PhantomParams(PhantomData); /// // specify the Payload on your Failure /// let handler_d = || async { ResponsePayload::<(), _>::internal_error_with_obj(4) }; /// ``` -/// -/// ## Notifications -/// -/// When running on pubsub, handlers can send notifications to the client. This -/// is done by calling [`HandlerCtx::notify`]. Notifications are sent as JSON -/// objects, and are queued for sending to the client. If the client is not -/// reading from the connection, the notification will be queued in a buffer. -/// If the buffer is full, the handler will be backpressured until the buffer -/// has room. -/// -/// We recommend that handlers `await` on the result of [`HandlerCtx::notify`], -/// to ensure that they are backpressured when the notification buffer is full. -/// If many tasks are attempting to notify the same client, the buffer may fill -/// up, and backpressure the `RouteTask` from reading requests from the -/// connection. This can lead to delays in request processing. -/// #[cfg_attr( feature = "pubsub", doc = "see the [`Listener`] documetnation for more information." diff --git a/src/tasks.rs b/src/tasks.rs index dd5c892..00ec19f 100644 --- a/src/tasks.rs +++ b/src/tasks.rs @@ -2,15 +2,18 @@ use std::future::Future; use tokio::{runtime::Handle, task::JoinHandle}; use tokio_util::{ - sync::{CancellationToken, WaitForCancellationFuture, WaitForCancellationFutureOwned}, - task::{task_tracker::TaskTrackerWaitFuture, TaskTracker}, + sync::{CancellationToken, WaitForCancellationFuture}, + task::TaskTracker, }; -/// This is a wrapper around a [`TaskTracker`] and a [`token`]. It is used to -/// manage a set of tasks, and to token them to shut down when the set is -/// dropped. +/// This is a wrapper around a [`TaskTracker`] and a [`CancellationToken`]. It +/// is used to manage a set of tasks, and to token them to shut down when the +/// set is dropped. +/// +/// When a [`Handle`] is provided, tasks are spawned on that handle. Otherwise, +/// they are spawned on the current runtime. #[derive(Debug, Clone, Default)] -pub struct TaskSet { +pub(crate) struct TaskSet { tasks: TaskTracker, token: CancellationToken, handle: Option, @@ -23,17 +26,8 @@ impl From for TaskSet { } impl TaskSet { - /// Create a new [`TaskSet`]. - pub fn new() -> Self { - Self { - tasks: TaskTracker::new(), - token: CancellationToken::new(), - handle: None, - } - } - /// Create a new [`TaskSet`] with a handle. - pub fn with_handle(handle: Handle) -> Self { + pub(crate) fn with_handle(handle: Handle) -> Self { Self { tasks: TaskTracker::new(), token: CancellationToken::new(), @@ -42,58 +36,29 @@ impl TaskSet { } /// Get a handle to the runtime that the task set is running on. - pub fn handle(&self) -> Handle { + /// + /// ## Panics + /// + /// This will panic if called outside the context of a Tokio runtime. + pub(crate) fn handle(&self) -> Handle { self.handle .clone() .unwrap_or_else(tokio::runtime::Handle::current) } - /// Close the task set, preventing new tasks from being added. - /// - /// See [`TaskTracker::close`] - pub fn close(&self) { - self.tasks.close(); - } - - /// Reopen the task set, allowing new tasks to be added. - /// - /// See [`TaskTracker::reopen`] - pub fn reopen(&self) { - self.tasks.reopen(); - } - /// Cancel the token, causing all tasks to be cancelled. - pub fn cancel(&self) { + pub(crate) fn cancel(&self) { self.token.cancel(); } /// Get a future that resolves when the token is fired. - pub fn cancelled(&self) -> WaitForCancellationFuture<'_> { + pub(crate) fn cancelled(&self) -> WaitForCancellationFuture<'_> { self.token.cancelled() } - /// Get a clone of the cancellation token. - pub fn cancelled_owned(&self) -> WaitForCancellationFutureOwned { - self.token.clone().cancelled_owned() - } - - /// Returns a future that resolves when the token is fired. - /// - /// See [`TaskTracker::wait`] - pub fn wait(&self) -> TaskTrackerWaitFuture<'_> { - self.tasks.wait() - } - - /// Convenience function to both fire the token and wait for all tasks to - /// complete. - pub fn shutdown(&self) -> TaskTrackerWaitFuture<'_> { - self.cancel(); - self.wait() - } - /// Get a child [`TaskSet`]. This set will be fired when the parent /// set is fired, or may be fired independently. - pub fn child(&self) -> Self { + pub(crate) fn child(&self) -> Self { Self { tasks: TaskTracker::new(), token: self.token.child_token(), @@ -118,7 +83,12 @@ impl TaskSet { } /// Spawn a future on the provided handle, and add it to the task set. - pub fn spawn(&self, task: F) -> JoinHandle> + /// + /// ## Panics + /// + /// This will panic if called outside the context of a Tokio runtime when + /// `self.handle` is `None`. + pub(crate) fn spawn(&self, task: F) -> JoinHandle> where F: Future + Send + 'static, F::Output: Send + 'static, @@ -127,8 +97,13 @@ impl TaskSet { } /// Spawn a blocking future on the provided handle, and add it to the task - /// set - pub fn spawn_blocking(&self, task: F) -> JoinHandle> + /// set. + /// + /// ## Panics + /// + /// This will panic if called outside the context of a Tokio runtime when + /// `self.handle` is `None`. + pub(crate) fn spawn_blocking(&self, task: F) -> JoinHandle> where F: Future + Send + 'static, F::Output: Send + 'static, diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 8c5b909..8b4f9dd 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -13,7 +13,7 @@ pub fn test_router() -> ajj::Router<()> { |params: usize| async move { Ok::<_, ()>(params * 2) }, ) .route("call_me_later", |ctx: HandlerCtx| async move { - tokio::task::spawn(async move { + ctx.spawn_with_ctx(|ctx| async move { time::sleep(time::Duration::from_millis(100)).await; let _ = ctx From d1b2f1339522e71dd16a42c67604eb4df8fc3abf Mon Sep 17 00:00:00 2001 From: James Date: Wed, 29 Jan 2025 16:17:23 -0500 Subject: [PATCH 04/12] doc: expand slightly --- src/routes/handler.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/routes/handler.rs b/src/routes/handler.rs index ca9d974..fe9b934 100644 --- a/src/routes/handler.rs +++ b/src/routes/handler.rs @@ -79,14 +79,16 @@ pub struct PhantomParams(PhantomData); /// # fn test_fn() -> Router<()> { /// Router::new() /// .route("good citizenship", |ctx: HandlerCtx| async move { -/// // Properly implemented task management +/// // Properly implemented task management. This task will +/// // automatically be cleaned up when the connection is closed. /// ctx.spawn(async { /// // do something /// }); /// Ok::<_, ()>(()) /// }) /// .route("bad citizenship", |ctx: HandlerCtx| async move { -/// // Incorrect task management +/// // Incorrect task management. Will result in the task running for +/// // some amount of time after the connection is closed. /// tokio::spawn(async { /// // do something /// }); From 9a862d421c5b58e24ff9bf9bf46d2b6c65b3b935 Mon Sep 17 00:00:00 2001 From: James Date: Wed, 29 Jan 2025 16:25:03 -0500 Subject: [PATCH 05/12] refactor: clean up all tasks before returning --- src/pubsub/shared.rs | 3 ++- src/tasks.rs | 5 +++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/pubsub/shared.rs b/src/pubsub/shared.rs index 0687b94..93f3205 100644 --- a/src/pubsub/shared.rs +++ b/src/pubsub/shared.rs @@ -254,7 +254,7 @@ where } } } - tasks.cancel(); + tasks.shutdown().await; } /// Spawn the future produced by [`Self::task_future`]. @@ -321,6 +321,7 @@ impl WriteTask { } } } + tasks.shutdown().await; } /// Spawn the future produced by [`Self::task_future`]. diff --git a/src/tasks.rs b/src/tasks.rs index 00ec19f..03cb06d 100644 --- a/src/tasks.rs +++ b/src/tasks.rs @@ -51,6 +51,11 @@ impl TaskSet { self.token.cancel(); } + pub(crate) async fn shutdown(&self) { + self.cancel(); + self.tasks.wait().await + } + /// Get a future that resolves when the token is fired. pub(crate) fn cancelled(&self) -> WaitForCancellationFuture<'_> { self.token.cancelled() From b757ed7640379e2f7ca13ca7d20bcc0c8681cf80 Mon Sep 17 00:00:00 2001 From: James Date: Thu, 30 Jan 2025 08:38:54 -0500 Subject: [PATCH 06/12] refactor: use a future instead of the token --- src/pubsub/shared.rs | 29 ++++++++++++-- src/routes/ctx.rs | 94 ++++++++++++++++++++++++++++++++++++++++---- src/tasks.rs | 84 ++++++++++++++++++++++++++++++++++----- 3 files changed, 186 insertions(+), 21 deletions(-) diff --git a/src/pubsub/shared.rs b/src/pubsub/shared.rs index 93f3205..82c0dcb 100644 --- a/src/pubsub/shared.rs +++ b/src/pubsub/shared.rs @@ -8,6 +8,7 @@ use crate::{ use serde_json::value::RawValue; use tokio::{select, sync::mpsc, task::JoinHandle}; use tokio_stream::StreamExt; +use tokio_util::task::task_tracker::TaskTrackerWaitFuture; use tracing::{debug, debug_span, error, instrument, trace, Instrument}; /// Default notification buffer size per task. @@ -34,6 +35,26 @@ impl ServerShutdown { pub(crate) const fn new(task_set: TaskSet) -> Self { Self { task_set } } + + /// Wait for the tasks spawned by the server to complete. + /// + /// This future will not resolve until both of the following are true: + /// - [`ServerShutdown::close`] + pub fn wait(&self) -> TaskTrackerWaitFuture<'_> { + self.task_set.wait() + } + + /// Close the task tracker, allowing [`Self::wait`] futures to resolve. + pub fn close(&self) { + self.task_set.close(); + } + + /// Shutdown the server, and wait for all tasks to complete. + pub async fn shutdown(self) { + self.task_set.cancel(); + self.close(); + self.wait().await; + } } impl Drop for ServerShutdown { @@ -81,7 +102,7 @@ where pub(crate) fn spawn(self) -> JoinHandle> { let tasks = self.manager.root_tasks.clone(); let future = self.task_future(); - tasks.spawn(future) + tasks.spawn_cancellable(future) } } @@ -238,7 +259,7 @@ where }; // Run the future in a new task. - tasks.spawn( + tasks.spawn_cancellable( async move { // Send the response to the write task. // we don't care if the receiver has gone away, @@ -263,7 +284,7 @@ where let future = self.task_future(); - tasks.spawn(future) + tasks.spawn_cancellable(future) } } @@ -328,6 +349,6 @@ impl WriteTask { pub(crate) fn spawn(self) -> tokio::task::JoinHandle> { let tasks = self.tasks.clone(); let future = self.task_future(); - tasks.spawn(future) + tasks.spawn_cancellable(future) } } diff --git a/src/routes/ctx.rs b/src/routes/ctx.rs index 930d332..6a06dc4 100644 --- a/src/routes/ctx.rs +++ b/src/routes/ctx.rs @@ -3,6 +3,7 @@ use std::future::Future; use crate::{types::Request, RpcSend, TaskSet}; use serde_json::value::RawValue; use tokio::{runtime::Handle, sync::mpsc, task::JoinHandle}; +use tokio_util::sync::WaitForCancellationFutureOwned; use tracing::error; /// Errors that can occur when sending notifications. @@ -89,43 +90,120 @@ impl HandlerCtx { Ok(()) } - /// Spawn a task on the task set. + /// Spawn a task on the task set. This task will be cancelled if the + /// client disconnects. This is useful for long-running server tasks. + /// + /// The resulting [`JoinHandle`] will contain [`None`] if the task was + /// cancelled, and `Some` otherwise. pub fn spawn(&self, f: F) -> JoinHandle> where F: Future + Send + 'static, F::Output: Send + 'static, { - self.tasks.spawn(f) + self.tasks.spawn_cancellable(f) } - /// Spawn a task on the task set with access to this context. + /// Spawn a task on the task set with access to this context. This + /// task will be cancelled if the client disconnects. This is useful + /// for long-running tasks like subscriptions. + /// + /// The resulting [`JoinHandle`] will contain [`None`] if the task was + /// cancelled, and `Some` otherwise. pub fn spawn_with_ctx(&self, f: F) -> JoinHandle> where F: FnOnce(HandlerCtx) -> Fut, Fut: Future + Send + 'static, Fut::Output: Send + 'static, { - self.tasks.spawn(f(self.clone())) + self.tasks.spawn_cancellable(f(self.clone())) } - /// Spawn a task that may block on the task set. + /// Spawn a task that may block on the task set. This task may block, and + /// will be cancelled if the client disconnects. This is useful for + /// running expensive tasks that require blocking IO (e.g. database + /// queries). + /// + /// The resulting [`JoinHandle`] will contain [`None`] if the task was + /// cancelled, and `Some` otherwise. pub fn spawn_blocking(&self, f: F) -> JoinHandle> where F: Future + Send + 'static, F::Output: Send + 'static, { - self.tasks.spawn_blocking(f) + self.tasks.spawn_blocking_cancellable(f) } /// Spawn a task that may block on the task set, with access to this - /// context. + /// context. This task may block, and will be cancelled if the client + /// disconnects. This is useful for running expensive tasks that require + /// blocking IO (e.g. database queries). + /// + /// The resulting [`JoinHandle`] will contain [`None`] if the task was + /// cancelled, and `Some` otherwise. pub fn spawn_blocking_with_ctx(&self, f: F) -> JoinHandle> where F: FnOnce(HandlerCtx) -> Fut, Fut: Future + Send + 'static, Fut::Output: Send + 'static, { - self.tasks.spawn_blocking(f(self.clone())) + self.tasks.spawn_blocking_cancellable(f(self.clone())) + } + + /// Spawn a task on this task set. Unlike [`Self::spawn`], this task will + /// NOT be cancelled if the client disconnects. Instead, it + /// is given a future that resolves when client disconnects. This is useful + /// for tasks that need to clean up resources before completing. + pub fn spawn_graceful(&self, f: F) -> JoinHandle + where + F: FnOnce(WaitForCancellationFutureOwned) -> Fut + Send + 'static, + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, + { + self.tasks.spawn_graceful(f) + } + + /// Spawn a task on this task set with access to this context. Unlike + /// [`Self::spawn`], this task will NOT be cancelled if the client + /// disconnects. Instead, it is given a future that resolves when client + /// disconnects. This is useful for tasks that need to clean up resources + /// before completing. + pub fn spawn_graceful_with_ctx(&self, f: F) -> JoinHandle + where + F: FnOnce(HandlerCtx, WaitForCancellationFutureOwned) -> Fut + Send + 'static, + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, + { + let ctx = self.clone(); + self.tasks.spawn_graceful(move |token| f(ctx, token)) + } + + /// Spawn a blocking task on this task set. Unlike [`Self::spawn_blocking`], + /// this task will NOT be cancelled if the client disconnects. Instead, it + /// is given a future that resolves when client disconnects. This is useful + /// for tasks that need to clean up resources before completing. + pub fn spawn_blocking_graceful(&self, f: F) -> JoinHandle + where + F: FnOnce(WaitForCancellationFutureOwned) -> Fut + Send + 'static, + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, + { + self.tasks.spawn_blocking_graceful(f) + } + + /// Spawn a blocking task on this task set with access to this context. + /// Unlike [`Self::spawn_blocking`], this task will NOT be cancelled if the + /// client disconnects. Instead, it is given a future that resolves when + /// the client disconnects. This is useful for tasks that need to clean up + /// resources before completing. + pub fn spawn_blocking_graceful_with_ctx(&self, f: F) -> JoinHandle + where + F: FnOnce(HandlerCtx, WaitForCancellationFutureOwned) -> Fut + Send + 'static, + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, + { + let ctx = self.clone(); + self.tasks + .spawn_blocking_graceful(move |token| f(ctx, token)) } } diff --git a/src/tasks.rs b/src/tasks.rs index 03cb06d..6cce3bf 100644 --- a/src/tasks.rs +++ b/src/tasks.rs @@ -2,8 +2,8 @@ use std::future::Future; use tokio::{runtime::Handle, task::JoinHandle}; use tokio_util::{ - sync::{CancellationToken, WaitForCancellationFuture}, - task::TaskTracker, + sync::{CancellationToken, WaitForCancellationFuture, WaitForCancellationFutureOwned}, + task::{task_tracker::TaskTrackerWaitFuture, TaskTracker}, }; /// This is a wrapper around a [`TaskTracker`] and a [`CancellationToken`]. It @@ -49,6 +49,17 @@ impl TaskSet { /// Cancel the token, causing all tasks to be cancelled. pub(crate) fn cancel(&self) { self.token.cancel(); + self.close(); + } + + /// Close the task tracker, allowing [`Self::wait`] futures to resolve. + pub(crate) fn close(&self) { + self.tasks.close(); + } + + /// Get a future that resolves when all tasks in the set are complete. + pub(crate) fn wait(&self) -> TaskTrackerWaitFuture<'_> { + self.tasks.wait() } pub(crate) async fn shutdown(&self) { @@ -73,7 +84,10 @@ impl TaskSet { /// Prepare a future to be added to the task set, by wrapping it with a /// cancellation token. - fn prep_fut(&self, task: F) -> impl Future> + Send + 'static + fn prep_abortable_fut( + &self, + task: F, + ) -> impl Future> + Send + 'static where F: Future + Send + 'static, F::Output: Send + 'static, @@ -87,34 +101,86 @@ impl TaskSet { } } - /// Spawn a future on the provided handle, and add it to the task set. + /// Spawn a future on the provided handle, and add it to the task set. A + /// future spawned this way will be aborted when the [`TaskSet`] is + /// cancelled. + /// + /// If the future completes before the task set is cancelled, the result + /// will be returned. Otherwise, `None` will be returned. /// /// ## Panics /// /// This will panic if called outside the context of a Tokio runtime when /// `self.handle` is `None`. - pub(crate) fn spawn(&self, task: F) -> JoinHandle> + pub(crate) fn spawn_cancellable(&self, task: F) -> JoinHandle> where F: Future + Send + 'static, F::Output: Send + 'static, { - self.tasks.spawn_on(self.prep_fut(task), &self.handle()) + self.tasks + .spawn_on(self.prep_abortable_fut(task), &self.handle()) } /// Spawn a blocking future on the provided handle, and add it to the task - /// set. + /// set. A future spawned this way will be cancelled when the [`TaskSet`] + /// is cancelled. /// /// ## Panics /// /// This will panic if called outside the context of a Tokio runtime when /// `self.handle` is `None`. - pub(crate) fn spawn_blocking(&self, task: F) -> JoinHandle> + pub(crate) fn spawn_blocking_cancellable(&self, task: F) -> JoinHandle> where F: Future + Send + 'static, F::Output: Send + 'static, { let h = self.handle(); - let task = self.prep_fut(task); + let task = self.prep_abortable_fut(task); + self.tasks + .spawn_blocking_on(move || h.block_on(task), &self.handle()) + } + + /// Spawn a future on the provided handle, and add it to the task set. A + /// future spawned this way will not be aborted when the [`TaskSet`] is + /// cancelled, instead it will receive a notification via a + /// [`CancellationToken`]. This allows the future to complete gracefully. + /// This is useful for tasks that need to clean up resources before + /// completing. + /// + /// ## Panics + /// + /// This will panic if called outside the context of a Tokio runtime when + /// `self.handle` is `None`. + pub(crate) fn spawn_graceful(&self, task: F) -> JoinHandle + where + F: FnOnce(WaitForCancellationFutureOwned) -> Fut + Send + 'static, + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, + { + let token = self.token.clone().cancelled_owned(); + self.tasks.spawn_on(task(token), &self.handle()) + } + + /// Spawn a blocking future on the provided handle, and add it to the task + /// set. A future spawned this way will not be cancelled when the + /// [`TaskSet`] is cancelled, instead it will receive a notification via a + /// [`CancellationToken`]. This allows the future to complete gracefully. + /// This is useful for tasks that need to clean up resources before + /// completing. + /// + /// ## Panics + /// + /// This will panic if called outside the context of a Tokio runtime when + /// `self.handle` is `None`. + pub(crate) fn spawn_blocking_graceful(&self, task: F) -> JoinHandle + where + F: FnOnce(WaitForCancellationFutureOwned) -> Fut + Send + 'static, + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, + { + let h = self.handle(); + let token = self.token.clone().cancelled_owned(); + let task = task(token); self.tasks .spawn_blocking_on(move || h.block_on(task), &self.handle()) } From ae098c3e625ca091e70b8d6bd791a2b54ee5d7ad Mon Sep 17 00:00:00 2001 From: James Date: Thu, 30 Jan 2025 08:42:55 -0500 Subject: [PATCH 07/12] nit: remove newline --- src/tasks.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/tasks.rs b/src/tasks.rs index 6cce3bf..a6036f9 100644 --- a/src/tasks.rs +++ b/src/tasks.rs @@ -1,5 +1,4 @@ use std::future::Future; - use tokio::{runtime::Handle, task::JoinHandle}; use tokio_util::{ sync::{CancellationToken, WaitForCancellationFuture, WaitForCancellationFutureOwned}, From 31618df0c82c9ce18cc33a6343c06685f269bb0a Mon Sep 17 00:00:00 2001 From: James Date: Thu, 30 Jan 2025 08:49:30 -0500 Subject: [PATCH 08/12] doc: readme update --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 71f9e43..dfaa536 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,8 @@ ajj aims to provide simple, flexible, and ergonomic routing for JSON-RPC. - Support for pubsub-style notifications. - Built-in support for axum, and tower's middleware and service ecosystem. - Basic built-in pubsub server implementations for WS and IPC. +- Connection-oriented task management automatically cancels tasks on client + disconnect. ## Concepts From fa8c183b704ecab6c8045e999d20be87269907cf Mon Sep 17 00:00:00 2001 From: James Date: Thu, 30 Jan 2025 09:32:40 -0500 Subject: [PATCH 09/12] fix: remove unnecessary generic --- src/routes/ctx.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/routes/ctx.rs b/src/routes/ctx.rs index 6a06dc4..0eb39c2 100644 --- a/src/routes/ctx.rs +++ b/src/routes/ctx.rs @@ -125,7 +125,7 @@ impl HandlerCtx { /// /// The resulting [`JoinHandle`] will contain [`None`] if the task was /// cancelled, and `Some` otherwise. - pub fn spawn_blocking(&self, f: F) -> JoinHandle> + pub fn spawn_blocking(&self, f: F) -> JoinHandle> where F: Future + Send + 'static, F::Output: Send + 'static, From 3a9eb3fc558da37f34431649596107fcefc2e30f Mon Sep 17 00:00:00 2001 From: James Date: Fri, 31 Jan 2025 08:09:18 -0500 Subject: [PATCH 10/12] chore: more docs and bubble up more API to shutdown --- src/pubsub/mod.rs | 5 +- src/pubsub/shared.rs | 52 ++----------------- src/pubsub/shutdown.rs | 113 +++++++++++++++++++++++++++++++++++++++++ src/tasks.rs | 23 +++++++-- 4 files changed, 138 insertions(+), 55 deletions(-) create mode 100644 src/pubsub/shutdown.rs diff --git a/src/pubsub/mod.rs b/src/pubsub/mod.rs index febe8fb..abb68dd 100644 --- a/src/pubsub/mod.rs +++ b/src/pubsub/mod.rs @@ -95,7 +95,10 @@ mod ipc; pub use ipc::ReadJsonStream; mod shared; -pub use shared::{ConnectionId, ServerShutdown, DEFAULT_NOTIFICATION_BUFFER_PER_CLIENT}; +pub use shared::{ConnectionId, DEFAULT_NOTIFICATION_BUFFER_PER_CLIENT}; + +mod shutdown; +pub use shutdown::ServerShutdown; mod r#trait; pub use r#trait::{Connect, In, JsonReqStream, JsonSink, Listener, Out}; diff --git a/src/pubsub/shared.rs b/src/pubsub/shared.rs index 82c0dcb..60acb96 100644 --- a/src/pubsub/shared.rs +++ b/src/pubsub/shared.rs @@ -1,14 +1,12 @@ -use core::fmt; - use crate::{ pubsub::{In, JsonSink, Listener, Out}, types::InboundData, HandlerCtx, TaskSet, }; +use core::fmt; use serde_json::value::RawValue; use tokio::{select, sync::mpsc, task::JoinHandle}; use tokio_stream::StreamExt; -use tokio_util::task::task_tracker::TaskTrackerWaitFuture; use tracing::{debug, debug_span, error, instrument, trace, Instrument}; /// Default notification buffer size per task. @@ -17,52 +15,6 @@ pub const DEFAULT_NOTIFICATION_BUFFER_PER_CLIENT: usize = 16; /// Type alias for identifying connections. pub type ConnectionId = u64; -/// Holds the shutdown signal for some server. -#[derive(Debug)] -pub struct ServerShutdown { - pub(crate) task_set: TaskSet, -} - -impl From for ServerShutdown { - fn from(task_set: TaskSet) -> Self { - Self::new(task_set) - } -} - -impl ServerShutdown { - /// Create a new `ServerShutdown` with the given shutdown signal and task - /// set. - pub(crate) const fn new(task_set: TaskSet) -> Self { - Self { task_set } - } - - /// Wait for the tasks spawned by the server to complete. - /// - /// This future will not resolve until both of the following are true: - /// - [`ServerShutdown::close`] - pub fn wait(&self) -> TaskTrackerWaitFuture<'_> { - self.task_set.wait() - } - - /// Close the task tracker, allowing [`Self::wait`] futures to resolve. - pub fn close(&self) { - self.task_set.close(); - } - - /// Shutdown the server, and wait for all tasks to complete. - pub async fn shutdown(self) { - self.task_set.cancel(); - self.close(); - self.wait().await; - } -} - -impl Drop for ServerShutdown { - fn drop(&mut self) { - self.task_set.cancel(); - } -} - /// The `ListenerTask` listens for new connections, and spawns `RouteTask`s for /// each. pub(crate) struct ListenerTask { @@ -313,6 +265,8 @@ impl WriteTask { /// channel, and acts on them. It handles JSON messages, and going away /// instructions. It also listens for the global shutdown signal from the /// [`ServerShutdown`] struct. + /// + /// [`ServerShutdown`]: crate::pubsub::ServerShutdown #[instrument(skip(self), fields(conn_id = self.conn_id))] pub(crate) async fn task_future(self) { let WriteTask { diff --git a/src/pubsub/shutdown.rs b/src/pubsub/shutdown.rs new file mode 100644 index 0000000..660726e --- /dev/null +++ b/src/pubsub/shutdown.rs @@ -0,0 +1,113 @@ +use crate::TaskSet; +use tokio_util::{sync::WaitForCancellationFuture, task::task_tracker::TaskTrackerWaitFuture}; + +/// The shutdown signal for some server. When dropped, will cancel all tasks +/// associated with the running server. This includes all running [`Handler`]s, +/// as well as `pubsub` connection management tasks (if running). +/// +/// The shutdown wraps a [`TaskTracker`] and a [`CancellationToken`], and +/// exposes methods from those APIs. Please see the documentation for those +/// types for more information. +/// +/// [`TaskTracker`]: tokio_util::task::TaskTracker +/// [`CancellationToken`]: tokio_util::sync::CancellationToken +/// [`Handler`]: crate::Handler +#[derive(Debug)] +pub struct ServerShutdown { + pub(crate) task_set: TaskSet, +} + +impl From for ServerShutdown { + fn from(task_set: TaskSet) -> Self { + Self::new(task_set) + } +} + +impl ServerShutdown { + /// Create a new [`ServerShutdown`] with the given [`TaskSet`]. + pub(crate) const fn new(task_set: TaskSet) -> Self { + Self { task_set } + } + + /// Wait for the tasks spawned by the server to complete. This is a wrapper + /// for [`TaskTracker::wait`], and allows outside code to wait for the + /// server to signal that it has completely shut down. + /// + /// This future will not resolve until both of the following are true: + /// - [`Self::close`] has been called. + /// - All tasks spawned by the server have finished running. + /// + /// [`TaskTracker::wait`]: tokio_util::task::TaskTracker::wait + pub fn wait(&self) -> TaskTrackerWaitFuture<'_> { + self.task_set.wait() + } + + /// Close the intenal [`TaskTracker`], cancelling all running tasks, and + /// allowing [`Self::wait`] futures to resolve, provided all tasks are + /// complete. + /// + /// This will not cancel running tasks, and will not prevent new tasks from + /// being spawned. + /// + /// See [`TaskTracker::close`] for more information. + /// + /// [`TaskTracker`]: tokio_util::task::TaskTracker + /// [`TaskTracker::close`]: tokio_util::task::TaskTracker::close + pub fn close(&self) { + self.task_set.close(); + } + + /// Check if the server's internal [`TaskTracker`] has been closed. This + /// does not indicate that all tasks have completed, or that the server has + /// been cancelled. See [`TaskTracker::is_closed`] for more information. + /// + /// [`TaskTracker`]: tokio_util::task::TaskTracker + /// [`TaskTracker::is_closed`]: tokio_util::task::TaskTracker::is_closed + pub fn is_closed(&self) -> bool { + self.task_set.is_closed() + } + + /// Issue a cancellation signal to all tasks spawned by the server. This + /// will immediately cancel tasks spawned with the [`HandlerCtx::spawn`] + /// family of methods, and will issue cancellation signals to tasks + /// spawned with [`HandlerCtx::spawn_graceful`] family of methods. + /// + /// This will also cause new tasks spawned by the server to be immediately + /// cancelled, or notified of cancellation. + /// + /// [`HandlerCtx::spawn`]: crate::HandlerCtx::spawn + /// [`HandlerCtx::spawn_graceful`]: crate::HandlerCtx::spawn_graceful + pub fn cancel(&self) { + self.task_set.cancel(); + } + + /// Check if the server has been cancelled. `true` indicates that the + /// server has been instructed to shut down, and all tasks have either been + /// cancelled, or have received a notification that they should shut down. + pub fn is_cancelled(&self) -> bool { + self.task_set.is_cancelled() + } + + /// Get a future that resolves when the server has been cancelled. This + /// future will resolve when [`Self::cancel`] has been called, and all + /// tasks have been issued a cancellation signal. + pub fn cancelled(&self) -> WaitForCancellationFuture<'_> { + self.task_set.cancelled() + } + + /// Shutdown the server, and wait for all tasks to complete. + /// + /// This is equivalent to calling [`Self::cancel`], [`Self::close`] and + /// then awaiting [`Self::wait`]. + pub async fn shutdown(self) { + self.task_set.cancel(); + self.close(); + self.wait().await; + } +} + +impl Drop for ServerShutdown { + fn drop(&mut self) { + self.task_set.cancel(); + } +} diff --git a/src/tasks.rs b/src/tasks.rs index a6036f9..463814c 100644 --- a/src/tasks.rs +++ b/src/tasks.rs @@ -51,26 +51,39 @@ impl TaskSet { self.close(); } + /// True if the token is cancelled. + pub(crate) fn is_cancelled(&self) -> bool { + self.token.is_cancelled() + } + + /// Get a future that resolves when the token is fired. + pub(crate) fn cancelled(&self) -> WaitForCancellationFuture<'_> { + self.token.cancelled() + } + /// Close the task tracker, allowing [`Self::wait`] futures to resolve. pub(crate) fn close(&self) { self.tasks.close(); } + /// True if the task set is closed. + pub(crate) fn is_closed(&self) -> bool { + self.tasks.is_closed() + } + /// Get a future that resolves when all tasks in the set are complete. pub(crate) fn wait(&self) -> TaskTrackerWaitFuture<'_> { self.tasks.wait() } + /// Shutdown the task set. This will cancel all tasks and wait for them to + /// complete. pub(crate) async fn shutdown(&self) { self.cancel(); + self.close(); self.tasks.wait().await } - /// Get a future that resolves when the token is fired. - pub(crate) fn cancelled(&self) -> WaitForCancellationFuture<'_> { - self.token.cancelled() - } - /// Get a child [`TaskSet`]. This set will be fired when the parent /// set is fired, or may be fired independently. pub(crate) fn child(&self) -> Self { From 3a7ad5a33034f17afd91140b5dc9b580a33355c3 Mon Sep 17 00:00:00 2001 From: James Date: Fri, 31 Jan 2025 08:18:09 -0500 Subject: [PATCH 11/12] fix: prevent deadlock of task waiting on itself --- src/pubsub/shared.rs | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/src/pubsub/shared.rs b/src/pubsub/shared.rs index 60acb96..7ec5bbd 100644 --- a/src/pubsub/shared.rs +++ b/src/pubsub/shared.rs @@ -5,8 +5,9 @@ use crate::{ }; use core::fmt; use serde_json::value::RawValue; -use tokio::{select, sync::mpsc, task::JoinHandle}; +use tokio::{pin, select, sync::mpsc, task::JoinHandle}; use tokio_stream::StreamExt; +use tokio_util::sync::WaitForCancellationFutureOwned; use tracing::{debug, debug_span, error, instrument, trace, Instrument}; /// Default notification buffer size per task. @@ -163,7 +164,7 @@ where /// to handle the request, and given a sender to the [`WriteTask`]. This /// ensures that requests can be handled concurrently. #[instrument(name = "RouteTask", skip(self), fields(conn_id = self.conn_id))] - pub async fn task_future(self) { + pub async fn task_future(self, cancel: WaitForCancellationFutureOwned) { let RouteTask { router, mut requests, @@ -172,9 +173,18 @@ where .. } = self; + // The write task is responsible for waiting for its children + let children = tasks.child(); + + pin!(cancel); + loop { select! { biased; + _ = &mut cancel => { + debug!("RouteTask cancelled"); + break; + } _ = write_task.closed() => { debug!("WriteTask has gone away"); break; @@ -195,7 +205,7 @@ where let ctx = HandlerCtx::new( Some(write_task.clone()), - tasks.clone(), + children.clone(), ); let fut = router.handle_request_batch(ctx, reqs); @@ -211,7 +221,7 @@ where }; // Run the future in a new task. - tasks.spawn_cancellable( + children.spawn_cancellable( async move { // Send the response to the write task. // we don't care if the receiver has gone away, @@ -227,16 +237,16 @@ where } } } - tasks.shutdown().await; + children.shutdown().await; } /// Spawn the future produced by [`Self::task_future`]. - pub(crate) fn spawn(self) -> tokio::task::JoinHandle> { + pub(crate) fn spawn(self) -> tokio::task::JoinHandle<()> { let tasks = self.tasks.clone(); - let future = self.task_future(); + let future = move |cancel| self.task_future(cancel); - tasks.spawn_cancellable(future) + tasks.spawn_graceful(future) } } @@ -296,7 +306,6 @@ impl WriteTask { } } } - tasks.shutdown().await; } /// Spawn the future produced by [`Self::task_future`]. From 53608ec6381020ac35bd965c4c135006cf30ccd2 Mon Sep 17 00:00:00 2001 From: James Date: Fri, 31 Jan 2025 10:46:06 -0500 Subject: [PATCH 12/12] chore: bump version --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index f9079dd..332a5f3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,7 @@ description = "Simple, modern, ergonomic JSON-RPC 2.0 router built with tower an keywords = ["json-rpc", "jsonrpc", "json"] categories = ["web-programming::http-server", "web-programming::websocket"] -version = "0.2.0" +version = "0.3.0" edition = "2021" rust-version = "1.81" authors = ["init4", "James Prestwich"]