Skip to content

create djls-tasks crate for easy background work #12

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ djls = { path = "crates/djls" }
djls-ast = { path = "crates/djls-ast" }
djls-django = { path = "crates/djls-django" }
djls-python = { path = "crates/djls-python" }
djls-worker = { path = "crates/djls-worker" }

anyhow = "1.0.94"
serde = { version = "1.0.215", features = ["derive"] }
serde_json = "1.0.133"
tokio = { version = "1.42.0", features = ["full"] }
10 changes: 10 additions & 0 deletions crates/djls-worker/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[package]
name = "djls-worker"
version = "0.0.0"
edition = "2021"

[dependencies]
anyhow = { workspace = true }
tokio = { workspace = true }

async-trait = "0.1.83"
279 changes: 279 additions & 0 deletions crates/djls-worker/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,279 @@
use anyhow::Result;
use std::sync::Arc;
use tokio::sync::{mpsc, oneshot};

pub trait Task: Send + 'static {
type Output: Send + 'static;
fn run(&self) -> Result<Self::Output>;
}

struct WorkerInner {
sender: mpsc::Sender<TaskMessage>,
shutdown_sender: Option<oneshot::Sender<()>>,
}

#[derive(Clone)]
pub struct Worker {
inner: Arc<WorkerInner>,
}

enum TaskMessage {
Execute(Box<dyn TaskTrait>),
WithResult(
Box<dyn TaskTrait>,
oneshot::Sender<Result<Box<dyn std::any::Any + Send>>>,
),
}

trait TaskTrait: Send {
fn run_boxed(self: Box<Self>) -> Result<Box<dyn std::any::Any + Send>>;
}

impl<T: Task> TaskTrait for T {
fn run_boxed(self: Box<Self>) -> Result<Box<dyn std::any::Any + Send>> {
self.run()
.map(|output| Box::new(output) as Box<dyn std::any::Any + Send>)
}
}

impl Worker {
pub fn new() -> Self {
let (sender, mut receiver) = mpsc::channel(32);
let (shutdown_tx, mut shutdown_rx) = oneshot::channel();

tokio::spawn(async move {
loop {
tokio::select! {
Some(msg) = receiver.recv() => {
match msg {
TaskMessage::Execute(task) => {
let _ = task.run_boxed();
}
TaskMessage::WithResult(task, sender) => {
let result = task.run_boxed();
let _ = sender.send(result);
}
}
}
_ = &mut shutdown_rx => break,
}
}
});

Self {
inner: Arc::new(WorkerInner {
sender,
shutdown_sender: Some(shutdown_tx),
}),
}
}

pub fn execute<T>(&self, task: T) -> Result<()>
where
T: Task + 'static,
{
self.inner
.sender
.try_send(TaskMessage::Execute(Box::new(task)))
.map_err(|e| anyhow::anyhow!("Failed to execute task: {}", e))
}

pub async fn submit<T>(&self, task: T) -> Result<()>
where
T: Task + 'static,
{
self.inner
.sender
.send(TaskMessage::Execute(Box::new(task)))
.await
.map_err(|e| anyhow::anyhow!("Failed to submit task: {}", e))
}

pub async fn wait_for<T>(&self, task: T) -> Result<T::Output>
where
T: Task + 'static,
{
let (tx, rx) = oneshot::channel();

self.inner
.sender
.send(TaskMessage::WithResult(Box::new(task), tx))
.await
.map_err(|e| anyhow::anyhow!("Failed to send task: {}", e))?;

let result = rx
.await
.map_err(|e| anyhow::anyhow!("Failed to receive result: {}", e))??;

result
.downcast()
.map(|b| *b)
.map_err(|_| anyhow::anyhow!("Failed to downcast result"))
}
}

impl Default for Worker {
fn default() -> Self {
Self::new()
}
}

impl Drop for WorkerInner {
fn drop(&mut self) {
if let Some(sender) = self.shutdown_sender.take() {
sender.send(()).ok();
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use anyhow::anyhow;
use std::time::Duration;
use tokio::time::sleep;

struct TestTask(i32);

impl Task for TestTask {
type Output = i32;

fn run(&self) -> Result<Self::Output> {
Ok(self.0 * 2)
}
}

// Basic functionality tests
#[tokio::test]
async fn test_wait_for() {
let worker = Worker::new();
let result = worker.wait_for(TestTask(21)).await.unwrap();
assert_eq!(result, 42);
}

#[tokio::test]
async fn test_submit() {
let worker = Worker::new();
for i in 0..32 {
assert!(worker.execute(TestTask(i)).is_ok());
}
assert!(worker.execute(TestTask(33)).is_err());
assert!(worker.submit(TestTask(33)).await.is_ok());
sleep(Duration::from_millis(50)).await;
}

#[tokio::test]
async fn test_execute() {
let worker = Worker::new();
assert!(worker.execute(TestTask(21)).is_ok());
sleep(Duration::from_millis(50)).await;
}

// Test channel backpressure
#[tokio::test]
async fn test_channel_backpressure() {
let worker = Worker::new();

// Fill the channel (channel size is 32)
for i in 0..32 {
assert!(worker.execute(TestTask(i)).is_ok());
}

// Next execute should fail
assert!(worker.execute(TestTask(33)).is_err());

// But wait_for should eventually succeed
let result = worker.wait_for(TestTask(33)).await.unwrap();
assert_eq!(result, 66);
}

// Test concurrent tasks
#[tokio::test]
async fn test_concurrent_tasks() {
let worker = Worker::new();
let mut handles = Vec::new();

// Spawn multiple concurrent tasks
for i in 0..10 {
let worker = worker.clone();
let handle = tokio::spawn(async move {
let result = worker.wait_for(TestTask(i)).await.unwrap();
assert_eq!(result, i * 2);
});
handles.push(handle);
}

// Wait for all tasks to complete
for handle in handles {
handle.await.unwrap();
}
}

// Test shutdown behavior
#[tokio::test]
async fn test_shutdown() {
{
let worker = Worker::new();
worker.execute(TestTask(1)).unwrap();
worker.wait_for(TestTask(2)).await.unwrap();
// Worker will be dropped here, triggering shutdown
}
sleep(Duration::from_millis(50)).await;
}

// Test error handling
struct ErrorTask;

impl Task for ErrorTask {
type Output = (); // Unit type for error test

fn run(&self) -> Result<Self::Output> {
Err(anyhow!("Task failed"))
}
}

#[tokio::test]
async fn test_error_handling() {
let worker = Worker::new();

// Test error propagation
assert!(worker.wait_for(ErrorTask).await.is_err());

// Test that worker continues to function after error
let result = worker.wait_for(TestTask(21)).await.unwrap();
assert_eq!(result, 42);
}

#[tokio::test]
async fn test_worker_cloning() {
let worker = Worker::new();
let worker2 = worker.clone();

let (result1, result2) = tokio::join!(
worker.wait_for(TestTask(21)),
worker2.wait_for(TestTask(42))
);

assert_eq!(result1.unwrap(), 42);
assert_eq!(result2.unwrap(), 84);
}

#[tokio::test]
async fn test_multiple_workers() {
let worker = Worker::new();
let mut handles = Vec::new();

for i in 0..10 {
let worker = worker.clone();
let handle = tokio::spawn(async move {
let result = worker.wait_for(TestTask(i)).await.unwrap();
assert_eq!(result, i * 2);
});
handles.push(handle);
}

for handle in handles {
handle.await.unwrap();
}
}
}
3 changes: 2 additions & 1 deletion crates/djls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ edition = "2021"
[dependencies]
djls-django = { workspace = true }
djls-python = { workspace = true }
djls-worker = { workspace = true }

anyhow = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tokio = { workspace = true }

tokio = { version = "1.42.0", features = ["full"] }
tower-lsp = { version = "0.20.0", features = ["proposed"] }
lsp-types = "0.97.0"
1 change: 1 addition & 0 deletions crates/djls/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mod documents;
mod notifier;
mod server;
mod tasks;

use crate::notifier::TowerLspNotifier;
use crate::server::{DjangoLanguageServer, LspNotification, LspRequest};
Expand Down
40 changes: 39 additions & 1 deletion crates/djls/src/server.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use crate::documents::Store;
use crate::notifier::Notifier;
use crate::tasks::DebugTask;
use anyhow::Result;
use djls_django::DjangoProject;
use djls_worker::Worker;
use std::sync::Arc;
use std::time::Duration;
use tower_lsp::lsp_types::*;

const SERVER_NAME: &str = "Django Language Server";
Expand All @@ -21,16 +25,20 @@ pub enum LspNotification {

pub struct DjangoLanguageServer {
django: DjangoProject,
notifier: Box<dyn Notifier>,
notifier: Arc<Box<dyn Notifier>>,
documents: Store,
worker: Worker,
}

impl DjangoLanguageServer {
pub fn new(django: DjangoProject, notifier: Box<dyn Notifier>) -> Self {
let notifier = Arc::new(notifier);

Self {
django,
notifier,
documents: Store::new(),
worker: Worker::new(),
}
}

Expand Down Expand Up @@ -66,6 +74,36 @@ impl DjangoLanguageServer {
MessageType::INFO,
&format!("Opened document: {}", params.text_document.uri),
)?;

// Execute - still sync
self.worker.execute(DebugTask::new(
"Quick task".to_string(),
Duration::from_millis(100),
self.notifier.clone(),
))?;

// Submit - spawn async task
let worker = self.worker.clone();
let task = DebugTask::new(
"Important task".to_string(),
Duration::from_secs(1),
self.notifier.clone(),
);
tokio::spawn(async move {
let _ = worker.submit(task).await;
});

// Wait for result - spawn async task
let worker = self.worker.clone();
let task = DebugTask::new(
"Task with result".to_string(),
Duration::from_secs(2),
self.notifier.clone(),
);
tokio::spawn(async move {
let _ = worker.wait_for(task).await;
});

Ok(())
}
LspNotification::DidChangeTextDocument(params) => {
Expand Down
Loading
Loading