Skip to content

feat: task management #20

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jan 31, 2025
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ tokio-stream = { version = "0.1.17", optional = true }

# ipc
interprocess = { version = "2.2.2", features = ["async", "tokio"], optional = true }
tokio-util = { version = "0.7.13", optional = true, features = ["io"] }
tokio-util = { version = "0.7.13", optional = true, features = ["io", "rt"] }

# ws
tokio-tungstenite = { version = "0.26.1", features = ["rustls-tls-webpki-roots"], optional = true }
Expand Down
48 changes: 45 additions & 3 deletions src/axum.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,50 @@
use crate::types::{InboundData, Response};
use crate::{
types::{InboundData, Response},
HandlerCtx, TaskSet,
};
use axum::{extract::FromRequest, response::IntoResponse};
use bytes::Bytes;
use std::{future::Future, pin::Pin};
use tokio::runtime::Handle;

impl<S> axum::handler::Handler<Bytes, S> for crate::Router<S>
/// A wrapper around a [`Router`] that implements the
/// [`axum::handler::Handler`] trait. This struct is an implementation detail
/// of the [`Router::into_axum`] and [`Router::into_axum_with_handle`] methods.
///
/// [`Router`]: crate::Router
/// [`Router::into_axum`]: crate::Router::into_axum
/// [`Router::into_axum_with_handle`]: crate::Router::into_axum_with_handle
#[derive(Debug, Clone)]
pub(crate) struct IntoAxum<S> {
// The wrapped router that handles decoded request batches.
pub(crate) router: crate::Router<S>,
// Task set onto which handler work is spawned; cloned into a
// `HandlerCtx` for each request batch (see `IntoAxum::ctx`).
pub(crate) task_set: TaskSet,
}

impl<S> From<crate::Router<S>> for IntoAxum<S> {
fn from(router: crate::Router<S>) -> Self {
Self {
router,
task_set: Default::default(),
}
}
}

impl<S> IntoAxum<S> {
/// Create a new `IntoAxum` from a router and task set.
pub(crate) fn new(router: crate::Router<S>, handle: Handle) -> Self {
Self {
router,
task_set: handle.into(),
}
}

/// Get a new context, built from the task set.
fn ctx(&self) -> HandlerCtx {
self.task_set.clone().into()
}
}

impl<S> axum::handler::Handler<Bytes, S> for IntoAxum<S>
where
S: Clone + Send + Sync + 'static,
{
Expand All @@ -21,7 +62,8 @@ where
let req = InboundData::try_from(bytes).unwrap_or_default();

if let Some(response) = self
.call_batch_with_state(Default::default(), req, state)
.router
.call_batch_with_state(self.ctx(), req, state)
.await
{
Box::<str>::from(response).into_response()
Expand Down
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ pub(crate) use routes::{BoxedIntoRoute, ErasedIntoRoute, Method, Route};
mod router;
pub use router::Router;

mod tasks;
pub(crate) use tasks::TaskSet;

mod types;
pub use types::{ErrorPayload, ResponsePayload};

Expand Down
97 changes: 53 additions & 44 deletions src/pubsub/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,10 @@ use core::fmt;
use crate::{
pubsub::{In, JsonSink, Listener, Out},
types::InboundData,
HandlerCtx, TaskSet,
};
use serde_json::value::RawValue;
use tokio::{
select,
sync::{mpsc, oneshot, watch},
task::JoinHandle,
};
use tokio::{select, sync::mpsc, task::JoinHandle};
use tokio_stream::StreamExt;
use tracing::{debug, debug_span, error, instrument, trace, Instrument};

Expand All @@ -22,12 +19,26 @@ pub type ConnectionId = u64;
/// Holds the shutdown signal for some server.
#[derive(Debug)]
pub struct ServerShutdown {
pub(crate) _shutdown: watch::Sender<()>,
pub(crate) task_set: TaskSet,
}

impl From<watch::Sender<()>> for ServerShutdown {
fn from(sender: watch::Sender<()>) -> Self {
Self { _shutdown: sender }
impl From<TaskSet> for ServerShutdown {
    /// Wrap the task set in a `ServerShutdown` guard.
    fn from(task_set: TaskSet) -> Self {
        Self { task_set }
    }
}

impl ServerShutdown {
/// Create a new `ServerShutdown` wrapping the given task set.
///
/// Dropping the returned value cancels the task set, shutting the
/// server's tasks down (see the `Drop` impl).
pub(crate) const fn new(task_set: TaskSet) -> Self {
Self { task_set }
}
}

impl Drop for ServerShutdown {
fn drop(&mut self) {
// Cancel every task in the server's task set when the shutdown
// guard goes out of scope.
self.task_set.cancel();
}
}

Expand Down Expand Up @@ -67,16 +78,17 @@ where
}

/// Spawn the future produced by [`Self::task_future`].
pub(crate) fn spawn(self) -> JoinHandle<()> {
pub(crate) fn spawn(self) -> JoinHandle<Option<()>> {
let tasks = self.manager.root_tasks.clone();
let future = self.task_future();
tokio::spawn(future)
tasks.spawn(future)
}
}

/// The `ConnectionManager` provides connections with IDs, and handles spawning
/// the [`RouteTask`] for each connection.
pub(crate) struct ConnectionManager {
pub(crate) shutdown: watch::Receiver<()>,
pub(crate) root_tasks: TaskSet,

pub(crate) next_id: ConnectionId,

Expand Down Expand Up @@ -107,19 +119,18 @@ impl ConnectionManager {
) -> (RouteTask<T>, WriteTask<T>) {
let (tx, rx) = mpsc::channel(self.notification_buffer_per_task);

let (gone_tx, gone_rx) = oneshot::channel();
let tasks = self.root_tasks.child();
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Connections each get their own task set, onto which work is spawned. This work is automatically cancelled when the connection goes away 😎


let rt = RouteTask {
router: self.router(),
conn_id,
write_task: tx,
requests,
gone: gone_tx,
tasks: tasks.clone(),
};

let wt = WriteTask {
shutdown: self.shutdown.clone(),
gone: gone_rx,
tasks,
conn_id,
json: rx,
connection,
Expand Down Expand Up @@ -156,8 +167,8 @@ struct RouteTask<T: crate::pubsub::Listener> {
pub(crate) write_task: mpsc::Sender<Box<RawValue>>,
/// Stream of requests.
pub(crate) requests: In<T>,
/// Sender to the [`WriteTask`], to notify it that this task is done.
pub(crate) gone: oneshot::Sender<()>,
/// The task set for this connection
pub(crate) tasks: TaskSet,
}

impl<T: crate::pubsub::Listener> fmt::Debug for RouteTask<T> {
Expand All @@ -184,7 +195,7 @@ where
router,
mut requests,
write_task,
gone,
tasks,
..
} = self;

Expand All @@ -208,7 +219,11 @@ where

let span = debug_span!("pubsub request handling", reqs = reqs.len());

let ctx = write_task.clone().into();
let ctx =
HandlerCtx::new(
Some(write_task.clone()),
tasks.clone(),
);

let fut = router.handle_request_batch(ctx, reqs);
let write_task = write_task.clone();
Expand All @@ -223,7 +238,7 @@ where
};

// Run the future in a new task.
tokio::spawn(
tasks.spawn(
async move {
// Send the response to the write task.
// we don't care if the receiver has gone away,
Expand All @@ -239,27 +254,23 @@ where
}
}
}
// No funny business. Drop the gone signal.
drop(gone);
tasks.cancel();
}

/// Spawn the future produced by [`Self::task_future`].
pub(crate) fn spawn(self) -> tokio::task::JoinHandle<()> {
pub(crate) fn spawn(self) -> tokio::task::JoinHandle<Option<()>> {
let tasks = self.tasks.clone();

let future = self.task_future();
tokio::spawn(future)

tasks.spawn(future)
}
}

/// The Write Task is responsible for writing JSON to the outbound connection.
struct WriteTask<T: Listener> {
/// Shutdown signal.
///
/// Shutdowns bubble back up to [`RouteTask`] when the write task is
/// dropped, via the closed `json` channel.
pub(crate) shutdown: watch::Receiver<()>,

/// Signal that the connection has gone away.
pub(crate) gone: oneshot::Receiver<()>,
/// Task set
pub(crate) tasks: TaskSet,

/// ID of the connection.
pub(crate) conn_id: ConnectionId,
Expand All @@ -284,22 +295,18 @@ impl<T: Listener> WriteTask<T> {
#[instrument(skip(self), fields(conn_id = self.conn_id))]
pub(crate) async fn task_future(self) {
let WriteTask {
mut shutdown,
mut gone,
tasks,
mut json,
mut connection,
..
} = self;
shutdown.mark_unchanged();

loop {
select! {
biased;
_ = &mut gone => {
debug!("Connection has gone away");
break;
}
_ = shutdown.changed() => {
debug!("shutdown signal received");

_ = tasks.cancelled() => {
debug!("Shutdown signal received");
break;
}
json = json.recv() => {
Expand All @@ -317,7 +324,9 @@ impl<T: Listener> WriteTask<T> {
}

/// Spawn the future produced by [`Self::task_future`].
pub(crate) fn spawn(self) -> JoinHandle<()> {
tokio::spawn(self.task_future())
pub(crate) fn spawn(self) -> tokio::task::JoinHandle<Option<()>> {
// Clone the task set up front, because `task_future` consumes `self`.
let tasks = self.tasks.clone();
let future = self.task_future();
tasks.spawn(future)
}
}
58 changes: 50 additions & 8 deletions src/pubsub/trait.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use crate::pubsub::{
shared::{ConnectionManager, ListenerTask, DEFAULT_NOTIFICATION_BUFFER_PER_CLIENT},
ServerShutdown,
use crate::{
pubsub::{
shared::{ConnectionManager, ListenerTask, DEFAULT_NOTIFICATION_BUFFER_PER_CLIENT},
ServerShutdown,
},
TaskSet,
};
use bytes::Bytes;
use serde_json::value::RawValue;
use std::future::Future;
use tokio::sync::watch;
use tokio::runtime::Handle;
use tokio_stream::Stream;

/// Convenience alias for naming stream halves.
Expand Down Expand Up @@ -34,6 +37,24 @@ pub type In<T> = <T as Listener>::ReqStream;
/// may produce arbitrary response bodies, only the server developer can
/// accurately set this value. We have provided a low default. Setting it too
/// high may allow resource exhaustion attacks.
///
/// ## Task management
///
/// When using the default impls of [`Connect::serve`] and
/// [`Connect::serve_with_handle`], the library will manage task sets for
/// inbound connections. These follow a per-connection hierarchical task model.
/// The root task set is associated with the server, and is used to spawn a
/// task that listens for inbound connections. Each connection is then given
/// a child task set, which is used to spawn tasks for that connection.
///
/// This task set is propagated to [`Handler`]s via the [`HandlerCtx`]. They may
/// then use it to spawn tasks that are themselves associated with the
/// connection. This ensures that, for properly-implemented [`Handler`]s, all
/// tasks associated with a connection are automatically cancelled and cleaned
/// up when the connection is closed.
///
/// [`Handler`]: crate::Handler
/// [`HandlerCtx`]: crate::HandlerCtx
pub trait Connect: Send + Sync + Sized {
/// The listener type produced by the connect object.
type Listener: Listener;
Expand Down Expand Up @@ -67,26 +88,47 @@ pub trait Connect: Send + Sync + Sized {
/// We do not recommend overriding this method. Doing so will opt out of
/// the library's pubsub task system. Users overriding this method must
/// manually handle connection tasks.
fn serve(
///
/// The provided handle will be used to spawn tasks.
fn serve_with_handle(
self,
router: crate::Router<()>,
handle: Handle,
) -> impl Future<Output = Result<ServerShutdown, Self::Error>> + Send {
async move {
let root_tasks: TaskSet = handle.into();
let notification_buffer_per_task = self.notification_buffer_size();
let (tx, rx) = watch::channel(());

ListenerTask {
listener: self.make_listener().await?,
manager: ConnectionManager {
shutdown: rx,
next_id: 0,
router,
notification_buffer_per_task,
root_tasks: root_tasks.clone(),
},
}
.spawn();
Ok(tx.into())
Ok(root_tasks.into())
}
}

/// Instantiate and run a task to accept connections, returning a shutdown
/// signal.
///
/// We do not recommend overriding this method. Doing so will opt out of
/// the library's pubsub task system. Users overriding this method must
/// manually handle connection tasks.
///
/// ## Panics
///
/// This will panic if called outside the context of a Tokio runtime.
fn serve(
self,
router: crate::Router<()>,
) -> impl Future<Output = Result<ServerShutdown, Self::Error>> + Send {
// Delegate to `serve_with_handle` using the ambient runtime's handle.
// `Handle::current()` panics outside a Tokio runtime context, per the
// Panics section above.
self.serve_with_handle(router, Handle::current())
}
}

/// A [`Listener`] accepts incoming connections and produces [`JsonSink`] and
Expand Down
Loading