Skip to content

Commit 7573f68

Browse files
committed
Allow killing long-running processes
1 parent b930fc2 commit 7573f68

File tree

9 files changed

+199
-18
lines changed

9 files changed

+199
-18
lines changed

compiler/base/orchestrator/src/coordinator.rs

Lines changed: 89 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -456,14 +456,15 @@ where
456456

457457
pub async fn begin_execute(
458458
&self,
459+
token: CancellationToken,
459460
request: ExecuteRequest,
460461
) -> Result<ActiveExecution, ExecuteError> {
461462
use execute_error::*;
462463

463464
self.select_channel(request.channel)
464465
.await
465466
.context(CouldNotStartContainerSnafu)?
466-
.begin_execute(request)
467+
.begin_execute(token, request)
467468
.await
468469
}
469470

@@ -482,14 +483,15 @@ where
482483

483484
pub async fn begin_compile(
484485
&self,
486+
token: CancellationToken,
485487
request: CompileRequest,
486488
) -> Result<ActiveCompilation, CompileError> {
487489
use compile_error::*;
488490

489491
self.select_channel(request.channel)
490492
.await
491493
.context(CouldNotStartContainerSnafu)?
492-
.begin_compile(request)
494+
.begin_compile(token, request)
493495
.await
494496
}
495497

@@ -603,12 +605,14 @@ impl Container {
603605
&self,
604606
request: ExecuteRequest,
605607
) -> Result<WithOutput<ExecuteResponse>, ExecuteError> {
608+
let token = Default::default();
609+
606610
let ActiveExecution {
607611
task,
608612
stdin_tx,
609613
stdout_rx,
610614
stderr_rx,
611-
} = self.begin_execute(request).await?;
615+
} = self.begin_execute(token, request).await?;
612616

613617
drop(stdin_tx);
614618
WithOutput::try_absorb(task, stdout_rx, stderr_rx).await
@@ -617,6 +621,7 @@ impl Container {
617621
#[instrument(skip_all)]
618622
async fn begin_execute(
619623
&self,
624+
token: CancellationToken,
620625
request: ExecuteRequest,
621626
) -> Result<ActiveExecution, ExecuteError> {
622627
use execute_error::*;
@@ -642,7 +647,7 @@ impl Container {
642647
stdout_rx,
643648
stderr_rx,
644649
} = self
645-
.spawn_cargo_task(execute_cargo)
650+
.spawn_cargo_task(token, execute_cargo)
646651
.await
647652
.context(CouldNotStartCargoSnafu)?;
648653

@@ -673,18 +678,21 @@ impl Container {
673678
&self,
674679
request: CompileRequest,
675680
) -> Result<WithOutput<CompileResponse>, CompileError> {
681+
let token = Default::default();
682+
676683
let ActiveCompilation {
677684
task,
678685
stdout_rx,
679686
stderr_rx,
680-
} = self.begin_compile(request).await?;
687+
} = self.begin_compile(token, request).await?;
681688

682689
WithOutput::try_absorb(task, stdout_rx, stderr_rx).await
683690
}
684691

685692
#[instrument(skip_all)]
686693
async fn begin_compile(
687694
&self,
695+
token: CancellationToken,
688696
request: CompileRequest,
689697
) -> Result<ActiveCompilation, CompileError> {
690698
use compile_error::*;
@@ -715,7 +723,7 @@ impl Container {
715723
stdout_rx,
716724
stderr_rx,
717725
} = self
718-
.spawn_cargo_task(execute_cargo)
726+
.spawn_cargo_task(token, execute_cargo)
719727
.await
720728
.context(CouldNotStartCargoSnafu)?;
721729

@@ -761,6 +769,7 @@ impl Container {
761769

762770
async fn spawn_cargo_task(
763771
&self,
772+
token: CancellationToken,
764773
execute_cargo: ExecuteCommandRequest,
765774
) -> Result<SpawnCargo, SpawnCargoError> {
766775
use spawn_cargo_error::*;
@@ -777,10 +786,19 @@ impl Container {
777786

778787
let task = tokio::spawn({
779788
async move {
789+
let mut already_cancelled = false;
780790
let mut stdin_open = true;
781791

782792
loop {
783793
select! {
794+
() = token.cancelled(), if !already_cancelled => {
795+
already_cancelled = true;
796+
797+
let msg = CoordinatorMessage::Kill;
798+
trace!("processing {msg:?}");
799+
to_worker_tx.send(msg).await.context(KillSnafu)?;
800+
},
801+
784802
stdin = stdin_rx.recv(), if stdin_open => {
785803
let msg = match stdin {
786804
Some(stdin) => {
@@ -952,6 +970,9 @@ pub enum SpawnCargoError {
952970

953971
#[snafu(display("Unable to send stdin message"))]
954972
Stdin { source: MultiplexedSenderError },
973+
974+
#[snafu(display("Unable to send kill message"))]
975+
Kill { source: MultiplexedSenderError },
955976
}
956977

957978
#[derive(Debug, Clone)]
@@ -1787,12 +1808,13 @@ mod tests {
17871808
..ARBITRARY_EXECUTE_REQUEST
17881809
};
17891810

1811+
let token = Default::default();
17901812
let ActiveExecution {
17911813
task,
17921814
stdin_tx,
17931815
stdout_rx,
17941816
stderr_rx,
1795-
} = coordinator.begin_execute(request).await.unwrap();
1817+
} = coordinator.begin_execute(token, request).await.unwrap();
17961818

17971819
stdin_tx.send("this is stdin\n".into()).await.unwrap();
17981820
// Purposefully not dropping stdin_tx early -- a user might forget
@@ -1836,12 +1858,13 @@ mod tests {
18361858
..ARBITRARY_EXECUTE_REQUEST
18371859
};
18381860

1861+
let token = Default::default();
18391862
let ActiveExecution {
18401863
task,
18411864
stdin_tx,
18421865
stdout_rx,
18431866
stderr_rx,
1844-
} = coordinator.begin_execute(request).await.unwrap();
1867+
} = coordinator.begin_execute(token, request).await.unwrap();
18451868

18461869
for i in 0..3 {
18471870
stdin_tx.send(format!("line {i}\n")).await.unwrap();
@@ -1870,6 +1893,62 @@ mod tests {
18701893
Ok(())
18711894
}
18721895

1896+
#[tokio::test]
1897+
#[snafu::report]
1898+
async fn execute_kill() -> Result<()> {
1899+
let coordinator = new_coordinator().await;
1900+
1901+
let request = ExecuteRequest {
1902+
code: r#"
1903+
fn main() {
1904+
println!("Before");
1905+
loop {
1906+
std::thread::sleep(std::time::Duration::from_secs(1));
1907+
}
1908+
println!("After");
1909+
}
1910+
"#
1911+
.into(),
1912+
..ARBITRARY_EXECUTE_REQUEST
1913+
};
1914+
1915+
let token = CancellationToken::new();
1916+
let ActiveExecution {
1917+
task,
1918+
stdin_tx: _,
1919+
mut stdout_rx,
1920+
stderr_rx,
1921+
} = coordinator
1922+
.begin_execute(token.clone(), request)
1923+
.await
1924+
.unwrap();
1925+
1926+
// Wait for some output before killing
1927+
let early_stdout = stdout_rx.recv().await.unwrap();
1928+
1929+
token.cancel();
1930+
1931+
let WithOutput {
1932+
response,
1933+
stdout,
1934+
stderr,
1935+
} = WithOutput::try_absorb(task, stdout_rx, stderr_rx)
1936+
.with_timeout()
1937+
.await
1938+
.unwrap();
1939+
1940+
assert!(!response.success, "{stderr}");
1941+
assert_contains!(response.exit_detail, "kill");
1942+
1943+
assert_contains!(early_stdout, "Before");
1944+
assert_not_contains!(stdout, "Before");
1945+
assert_not_contains!(stdout, "After");
1946+
1947+
coordinator.shutdown().await?;
1948+
1949+
Ok(())
1950+
}
1951+
18731952
const HELLO_WORLD_CODE: &str = r#"fn main() { println!("Hello World!"); }"#;
18741953

18751954
const ARBITRARY_COMPILE_REQUEST: CompileRequest = CompileRequest {
@@ -1914,11 +1993,12 @@ mod tests {
19141993
..ARBITRARY_COMPILE_REQUEST
19151994
};
19161995

1996+
let token = Default::default();
19171997
let ActiveCompilation {
19181998
task,
19191999
stdout_rx,
19202000
stderr_rx,
1921-
} = coordinator.begin_compile(req).await.unwrap();
2001+
} = coordinator.begin_compile(token, req).await.unwrap();
19222002

19232003
let WithOutput {
19242004
response,

compiler/base/orchestrator/src/message.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ pub enum CoordinatorMessage {
2727
ExecuteCommand(ExecuteCommandRequest),
2828
StdinPacket(String),
2929
StdinClose,
30+
Kill,
3031
}
3132

3233
impl_narrow_to_broad!(

compiler/base/orchestrator/src/worker.rs

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ use tokio::{
4646
sync::mpsc,
4747
task::JoinSet,
4848
};
49+
use tokio_util::sync::CancellationToken;
4950

5051
use crate::{
5152
bincode_input_closed,
@@ -194,6 +195,14 @@ async fn handle_coordinator_message(
194195
.drop_error_details()
195196
.context(UnableToSendStdinCloseSnafu)?;
196197
}
198+
199+
CoordinatorMessage::Kill => {
200+
process_tx
201+
.send(Multiplexed(job_id, ProcessCommand::Kill))
202+
.await
203+
.drop_error_details()
204+
.context(UnableToSendKillSnafu)?;
205+
}
197206
}
198207
}
199208

@@ -227,6 +236,9 @@ pub enum HandleCoordinatorMessageError {
227236
#[snafu(display("Failed to send stdin close request to the command task"))]
228237
UnableToSendStdinClose { source: mpsc::error::SendError<()> },
229238

239+
#[snafu(display("Failed to send kill request to the command task"))]
240+
UnableToSendKill { source: mpsc::error::SendError<()> },
241+
230242
#[snafu(display("A coordinator command handler background task panicked"))]
231243
TaskPanicked { source: tokio::task::JoinError },
232244
}
@@ -383,13 +395,15 @@ enum ProcessCommand {
383395
Start(ExecuteCommandRequest, MultiplexingSender),
384396
Stdin(String),
385397
StdinClose,
398+
Kill,
386399
}
387400

388401
struct ProcessState {
389402
project_path: PathBuf,
390403
processes: JoinSet<Result<(), ProcessError>>,
391404
stdin_senders: HashMap<JobId, mpsc::Sender<String>>,
392405
stdin_shutdown_tx: mpsc::Sender<JobId>,
406+
kill_tokens: HashMap<JobId, CancellationToken>,
393407
}
394408

395409
impl ProcessState {
@@ -399,6 +413,7 @@ impl ProcessState {
399413
processes: Default::default(),
400414
stdin_senders: Default::default(),
401415
stdin_shutdown_tx,
416+
kill_tokens: Default::default(),
402417
}
403418
}
404419

@@ -410,6 +425,8 @@ impl ProcessState {
410425
) -> Result<(), ProcessError> {
411426
use process_error::*;
412427

428+
let token = CancellationToken::new();
429+
413430
let RunningChild {
414431
child,
415432
stdin_rx,
@@ -432,11 +449,13 @@ impl ProcessState {
432449

433450
let task_set = stream_stdio(worker_msg_tx.clone(), stdin_rx, stdin, stdout, stderr);
434451

452+
self.kill_tokens.insert(job_id, token.clone());
453+
435454
self.processes.spawn({
436455
let stdin_shutdown_tx = self.stdin_shutdown_tx.clone();
437456
async move {
438457
worker_msg_tx
439-
.send(process_end(child, task_set, stdin_shutdown_tx, job_id).await)
458+
.send(process_end(token, child, task_set, stdin_shutdown_tx, job_id).await)
440459
.await
441460
.context(UnableToSendExecuteCommandResponseSnafu)
442461
}
@@ -470,6 +489,12 @@ impl ProcessState {
470489
let process = self.processes.join_next().await?;
471490
Some(process.context(ProcessTaskPanickedSnafu).and_then(|e| e))
472491
}
492+
493+
fn kill(&mut self, job_id: JobId) {
494+
if let Some(token) = self.kill_tokens.get(&job_id) {
495+
token.cancel();
496+
}
497+
}
473498
}
474499

475500
async fn manage_processes(
@@ -492,6 +517,8 @@ async fn manage_processes(
492517
ProcessCommand::Stdin(packet) => state.stdin(job_id, packet).await?,
493518

494519
ProcessCommand::StdinClose => state.stdin_close(job_id),
520+
521+
ProcessCommand::Kill => state.kill(job_id),
495522
}
496523
}
497524

@@ -560,13 +587,19 @@ fn process_begin(
560587
}
561588

562589
async fn process_end(
590+
token: CancellationToken,
563591
mut child: Child,
564592
mut task_set: JoinSet<Result<(), StdioError>>,
565593
stdin_shutdown_tx: mpsc::Sender<JobId>,
566594
job_id: JobId,
567595
) -> Result<ExecuteCommandResponse, ProcessError> {
568596
use process_error::*;
569597

598+
select! {
599+
() = token.cancelled() => child.kill().await.context(KillChildSnafu)?,
600+
_ = child.wait() => {},
601+
};
602+
570603
let status = child.wait().await.context(WaitChildSnafu)?;
571604

572605
stdin_shutdown_tx
@@ -706,6 +739,9 @@ pub enum ProcessError {
706739
#[snafu(display("Failed to send stdin data"))]
707740
UnableToSendStdinData { source: mpsc::error::SendError<()> },
708741

742+
#[snafu(display("Failed to kill the child process"))]
743+
KillChild { source: std::io::Error },
744+
709745
#[snafu(display("Failed to wait for child process exiting"))]
710746
WaitChild { source: std::io::Error },
711747

0 commit comments

Comments
 (0)