Skip to content

feat: task management #20

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jan 31, 2025
Merged
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ tokio-stream = { version = "0.1.17", optional = true }

# ipc
interprocess = { version = "2.2.2", features = ["async", "tokio"], optional = true }
tokio-util = { version = "0.7.13", optional = true, features = ["io"] }
tokio-util = { version = "0.7.13", optional = true, features = ["io", "rt"] }

# ws
tokio-tungstenite = { version = "0.26.1", features = ["rustls-tls-webpki-roots"], optional = true }
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ ajj aims to provide simple, flexible, and ergonomic routing for JSON-RPC.
- Support for pubsub-style notifications.
- Built-in support for axum, and tower's middleware and service ecosystem.
- Basic built-in pubsub server implementations for WS and IPC.
- Connection-oriented task management automatically cancels tasks on client
disconnect.

## Concepts

Expand Down
48 changes: 45 additions & 3 deletions src/axum.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,50 @@
use crate::types::{InboundData, Response};
use crate::{
types::{InboundData, Response},
HandlerCtx, TaskSet,
};
use axum::{extract::FromRequest, response::IntoResponse};
use bytes::Bytes;
use std::{future::Future, pin::Pin};
use tokio::runtime::Handle;

impl<S> axum::handler::Handler<Bytes, S> for crate::Router<S>
/// A wrapper around an [`Router`] that implements the
/// [`axum::handler::Handler`] trait. This struct is an implementation detail
/// of the [`Router::into_axum`] and [`Router::into_axum_with_handle`] methods.
///
/// [`Router`]: crate::Router
/// [`Router::into_axum`]: crate::Router::into_axum
/// [`Router::into_axum_with_handle`]: crate::Router::into_axum_with_handle
#[derive(Debug, Clone)]
pub(crate) struct IntoAxum<S> {
pub(crate) router: crate::Router<S>,
pub(crate) task_set: TaskSet,
}

impl<S> From<crate::Router<S>> for IntoAxum<S> {
fn from(router: crate::Router<S>) -> Self {
Self {
router,
task_set: Default::default(),
}
}
}

impl<S> IntoAxum<S> {
/// Create a new `IntoAxum` from a router and task set.
pub(crate) fn new(router: crate::Router<S>, handle: Handle) -> Self {
Self {
router,
task_set: handle.into(),
}
}

/// Get a new context, built from the task set.
fn ctx(&self) -> HandlerCtx {
self.task_set.clone().into()
}
}

impl<S> axum::handler::Handler<Bytes, S> for IntoAxum<S>
where
S: Clone + Send + Sync + 'static,
{
Expand All @@ -21,7 +62,8 @@ where
let req = InboundData::try_from(bytes).unwrap_or_default();

if let Some(response) = self
.call_batch_with_state(Default::default(), req, state)
.router
.call_batch_with_state(self.ctx(), req, state)
.await
{
Box::<str>::from(response).into_response()
Expand Down
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ pub(crate) use routes::{BoxedIntoRoute, ErasedIntoRoute, Method, Route};
mod router;
pub use router::Router;

mod tasks;
pub(crate) use tasks::TaskSet;

mod types;
pub use types::{ErrorPayload, ResponsePayload};

Expand Down
119 changes: 75 additions & 44 deletions src/pubsub/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@ use core::fmt;
use crate::{
pubsub::{In, JsonSink, Listener, Out},
types::InboundData,
HandlerCtx, TaskSet,
};
use serde_json::value::RawValue;
use tokio::{
select,
sync::{mpsc, oneshot, watch},
task::JoinHandle,
};
use tokio::{select, sync::mpsc, task::JoinHandle};
use tokio_stream::StreamExt;
use tokio_util::task::task_tracker::TaskTrackerWaitFuture;
use tracing::{debug, debug_span, error, instrument, trace, Instrument};

/// Default notification buffer size per task.
Expand All @@ -22,12 +20,46 @@ pub type ConnectionId = u64;
/// Holds the shutdown signal for some server.
#[derive(Debug)]
pub struct ServerShutdown {
pub(crate) _shutdown: watch::Sender<()>,
pub(crate) task_set: TaskSet,
}

impl From<TaskSet> for ServerShutdown {
fn from(task_set: TaskSet) -> Self {
Self::new(task_set)
}
}

impl ServerShutdown {
/// Create a new `ServerShutdown` with the given shutdown signal and task
/// set.
pub(crate) const fn new(task_set: TaskSet) -> Self {
Self { task_set }
}

/// Wait for the tasks spawned by the server to complete.
///
/// This future will not resolve until both of the following are true:
/// - [`ServerShutdown::close`]
pub fn wait(&self) -> TaskTrackerWaitFuture<'_> {
self.task_set.wait()
}

/// Close the task tracker, allowing [`Self::wait`] futures to resolve.
pub fn close(&self) {
self.task_set.close();
}

/// Shutdown the server, and wait for all tasks to complete.
pub async fn shutdown(self) {
self.task_set.cancel();
self.close();
self.wait().await;
}
}

impl From<watch::Sender<()>> for ServerShutdown {
fn from(sender: watch::Sender<()>) -> Self {
Self { _shutdown: sender }
impl Drop for ServerShutdown {
fn drop(&mut self) {
self.task_set.cancel();
}
}

Expand Down Expand Up @@ -67,16 +99,17 @@ where
}

/// Spawn the future produced by [`Self::task_future`].
pub(crate) fn spawn(self) -> JoinHandle<()> {
pub(crate) fn spawn(self) -> JoinHandle<Option<()>> {
let tasks = self.manager.root_tasks.clone();
let future = self.task_future();
tokio::spawn(future)
tasks.spawn_cancellable(future)
}
}

/// The `ConnectionManager` provides connections with IDs, and handles spawning
/// the [`RouteTask`] for each connection.
pub(crate) struct ConnectionManager {
pub(crate) shutdown: watch::Receiver<()>,
pub(crate) root_tasks: TaskSet,

pub(crate) next_id: ConnectionId,

Expand Down Expand Up @@ -107,19 +140,18 @@ impl ConnectionManager {
) -> (RouteTask<T>, WriteTask<T>) {
let (tx, rx) = mpsc::channel(self.notification_buffer_per_task);

let (gone_tx, gone_rx) = oneshot::channel();
let tasks = self.root_tasks.child();
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

connections each get their own task set, which work is spawned onto. this work is automatically cancelled when the connection goes away 😎


let rt = RouteTask {
router: self.router(),
conn_id,
write_task: tx,
requests,
gone: gone_tx,
tasks: tasks.clone(),
};

let wt = WriteTask {
shutdown: self.shutdown.clone(),
gone: gone_rx,
tasks,
conn_id,
json: rx,
connection,
Expand Down Expand Up @@ -156,8 +188,8 @@ struct RouteTask<T: crate::pubsub::Listener> {
pub(crate) write_task: mpsc::Sender<Box<RawValue>>,
/// Stream of requests.
pub(crate) requests: In<T>,
/// Sender to the [`WriteTask`], to notify it that this task is done.
pub(crate) gone: oneshot::Sender<()>,
/// The task set for this connection
pub(crate) tasks: TaskSet,
}

impl<T: crate::pubsub::Listener> fmt::Debug for RouteTask<T> {
Expand All @@ -184,7 +216,7 @@ where
router,
mut requests,
write_task,
gone,
tasks,
..
} = self;

Expand All @@ -208,7 +240,11 @@ where

let span = debug_span!("pubsub request handling", reqs = reqs.len());

let ctx = write_task.clone().into();
let ctx =
HandlerCtx::new(
Some(write_task.clone()),
tasks.clone(),
);

let fut = router.handle_request_batch(ctx, reqs);
let write_task = write_task.clone();
Expand All @@ -223,7 +259,7 @@ where
};

// Run the future in a new task.
tokio::spawn(
tasks.spawn_cancellable(
async move {
// Send the response to the write task.
// we don't care if the receiver has gone away,
Expand All @@ -239,27 +275,23 @@ where
}
}
}
// No funny business. Drop the gone signal.
drop(gone);
tasks.shutdown().await;
}

/// Spawn the future produced by [`Self::task_future`].
pub(crate) fn spawn(self) -> tokio::task::JoinHandle<()> {
pub(crate) fn spawn(self) -> tokio::task::JoinHandle<Option<()>> {
let tasks = self.tasks.clone();

let future = self.task_future();
tokio::spawn(future)

tasks.spawn_cancellable(future)
}
}

/// The Write Task is responsible for writing JSON to the outbound connection.
struct WriteTask<T: Listener> {
/// Shutdown signal.
///
/// Shutdowns bubble back up to [`RouteTask`] when the write task is
/// dropped, via the closed `json` channel.
pub(crate) shutdown: watch::Receiver<()>,

/// Signal that the connection has gone away.
pub(crate) gone: oneshot::Receiver<()>,
/// Task set
pub(crate) tasks: TaskSet,

/// ID of the connection.
pub(crate) conn_id: ConnectionId,
Expand All @@ -284,22 +316,18 @@ impl<T: Listener> WriteTask<T> {
#[instrument(skip(self), fields(conn_id = self.conn_id))]
pub(crate) async fn task_future(self) {
let WriteTask {
mut shutdown,
mut gone,
tasks,
mut json,
mut connection,
..
} = self;
shutdown.mark_unchanged();

loop {
select! {
biased;
_ = &mut gone => {
debug!("Connection has gone away");
break;
}
_ = shutdown.changed() => {
debug!("shutdown signal received");

_ = tasks.cancelled() => {
debug!("Shutdown signal received");
break;
}
json = json.recv() => {
Expand All @@ -314,10 +342,13 @@ impl<T: Listener> WriteTask<T> {
}
}
}
tasks.shutdown().await;
}

/// Spawn the future produced by [`Self::task_future`].
pub(crate) fn spawn(self) -> JoinHandle<()> {
tokio::spawn(self.task_future())
pub(crate) fn spawn(self) -> tokio::task::JoinHandle<Option<()>> {
let tasks = self.tasks.clone();
let future = self.task_future();
tasks.spawn_cancellable(future)
}
}
Loading