Skip to content

Commit 103d561

Browse files
authored
Merge pull request #995 from rust-lang/stdin
2 parents 4390ab1 + 7573f68 commit 103d561

30 files changed

+1414
-505
lines changed

compiler/base/orchestrator/src/coordinator.rs

Lines changed: 318 additions & 59 deletions
Large diffs are not rendered by default.

compiler/base/orchestrator/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ pub mod coordinator;
44
mod message;
55
pub mod worker;
66

7-
trait DropErrorDetailsExt<T> {
7+
pub trait DropErrorDetailsExt<T> {
88
fn drop_error_details(self) -> Result<T, tokio::sync::mpsc::error::SendError<()>>;
99
}
1010

compiler/base/orchestrator/src/message.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ pub enum CoordinatorMessage {
2626
ReadFile(ReadFileRequest),
2727
ExecuteCommand(ExecuteCommandRequest),
2828
StdinPacket(String),
29+
StdinClose,
30+
Kill,
2931
}
3032

3133
impl_narrow_to_broad!(

compiler/base/orchestrator/src/worker.rs

Lines changed: 159 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ use tokio::{
4646
sync::mpsc,
4747
task::JoinSet,
4848
};
49+
use tokio_util::sync::CancellationToken;
4950

5051
use crate::{
5152
bincode_input_closed,
@@ -57,25 +58,21 @@ use crate::{
5758
DropErrorDetailsExt,
5859
};
5960

60-
type CommandRequest = (Multiplexed<ExecuteCommandRequest>, MultiplexingSender);
61-
6261
pub async fn listen(project_dir: impl Into<PathBuf>) -> Result<(), Error> {
6362
let project_dir = project_dir.into();
6463

6564
let (coordinator_msg_tx, coordinator_msg_rx) = mpsc::channel(8);
6665
let (worker_msg_tx, worker_msg_rx) = mpsc::channel(8);
6766
let mut io_tasks = spawn_io_queue(coordinator_msg_tx, worker_msg_rx);
6867

69-
let (cmd_tx, cmd_rx) = mpsc::channel(8);
70-
let (stdin_tx, stdin_rx) = mpsc::channel(8);
71-
let process_task = tokio::spawn(manage_processes(stdin_rx, cmd_rx, project_dir.clone()));
68+
let (process_tx, process_rx) = mpsc::channel(8);
69+
let process_task = tokio::spawn(manage_processes(process_rx, project_dir.clone()));
7270

7371
let handler_task = tokio::spawn(handle_coordinator_message(
7472
coordinator_msg_rx,
7573
worker_msg_tx,
7674
project_dir,
77-
cmd_tx,
78-
stdin_tx,
75+
process_tx,
7976
));
8077

8178
select! {
@@ -122,8 +119,7 @@ async fn handle_coordinator_message(
122119
mut coordinator_msg_rx: mpsc::Receiver<Multiplexed<CoordinatorMessage>>,
123120
worker_msg_tx: mpsc::Sender<Multiplexed<WorkerMessage>>,
124121
project_dir: PathBuf,
125-
cmd_tx: mpsc::Sender<CommandRequest>,
126-
stdin_tx: mpsc::Sender<Multiplexed<String>>,
122+
process_tx: mpsc::Sender<Multiplexed<ProcessCommand>>,
127123
) -> Result<(), HandleCoordinatorMessageError> {
128124
use handle_coordinator_message_error::*;
129125

@@ -177,20 +173,36 @@ async fn handle_coordinator_message(
177173
}
178174

179175
CoordinatorMessage::ExecuteCommand(req) => {
180-
cmd_tx
181-
.send((Multiplexed(job_id, req), worker_msg_tx()))
176+
process_tx
177+
.send(Multiplexed(job_id, ProcessCommand::Start(req, worker_msg_tx())))
182178
.await
183179
.drop_error_details()
184180
.context(UnableToSendCommandExecutionRequestSnafu)?;
185181
}
186182

187183
CoordinatorMessage::StdinPacket(data) => {
188-
stdin_tx
189-
.send(Multiplexed(job_id, data))
184+
process_tx
185+
.send(Multiplexed(job_id, ProcessCommand::Stdin(data)))
190186
.await
191187
.drop_error_details()
192188
.context(UnableToSendStdinPacketSnafu)?;
193189
}
190+
191+
CoordinatorMessage::StdinClose => {
192+
process_tx
193+
.send(Multiplexed(job_id, ProcessCommand::StdinClose))
194+
.await
195+
.drop_error_details()
196+
.context(UnableToSendStdinCloseSnafu)?;
197+
}
198+
199+
CoordinatorMessage::Kill => {
200+
process_tx
201+
.send(Multiplexed(job_id, ProcessCommand::Kill))
202+
.await
203+
.drop_error_details()
204+
.context(UnableToSendKillSnafu)?;
205+
}
194206
}
195207
}
196208

@@ -221,6 +233,12 @@ pub enum HandleCoordinatorMessageError {
221233
#[snafu(display("Failed to send stdin packet to the command task"))]
222234
UnableToSendStdinPacket { source: mpsc::error::SendError<()> },
223235

236+
#[snafu(display("Failed to send stdin close request to the command task"))]
237+
UnableToSendStdinClose { source: mpsc::error::SendError<()> },
238+
239+
#[snafu(display("Failed to send kill request to the command task"))]
240+
UnableToSendKill { source: mpsc::error::SendError<()> },
241+
224242
#[snafu(display("A coordinator command handler background task panicked"))]
225243
TaskPanicked { source: tokio::task::JoinError },
226244
}
@@ -373,63 +391,144 @@ fn parse_working_dir(cwd: Option<String>, project_path: impl Into<PathBuf>) -> P
373391
final_path
374392
}
375393

394+
enum ProcessCommand {
395+
Start(ExecuteCommandRequest, MultiplexingSender),
396+
Stdin(String),
397+
StdinClose,
398+
Kill,
399+
}
400+
401+
struct ProcessState {
402+
project_path: PathBuf,
403+
processes: JoinSet<Result<(), ProcessError>>,
404+
stdin_senders: HashMap<JobId, mpsc::Sender<String>>,
405+
stdin_shutdown_tx: mpsc::Sender<JobId>,
406+
kill_tokens: HashMap<JobId, CancellationToken>,
407+
}
408+
409+
impl ProcessState {
410+
fn new(project_path: PathBuf, stdin_shutdown_tx: mpsc::Sender<JobId>) -> Self {
411+
Self {
412+
project_path,
413+
processes: Default::default(),
414+
stdin_senders: Default::default(),
415+
stdin_shutdown_tx,
416+
kill_tokens: Default::default(),
417+
}
418+
}
419+
420+
async fn start(
421+
&mut self,
422+
job_id: JobId,
423+
req: ExecuteCommandRequest,
424+
worker_msg_tx: MultiplexingSender,
425+
) -> Result<(), ProcessError> {
426+
use process_error::*;
427+
428+
let token = CancellationToken::new();
429+
430+
let RunningChild {
431+
child,
432+
stdin_rx,
433+
stdin,
434+
stdout,
435+
stderr,
436+
} = match process_begin(req, &self.project_path, &mut self.stdin_senders, job_id) {
437+
Ok(v) => v,
438+
Err(e) => {
439+
// Should we add a message for process started
440+
// in addition to the current message which
441+
// indicates that the process has ended?
442+
worker_msg_tx
443+
.send_err(e)
444+
.await
445+
.context(UnableToSendExecuteCommandStartedResponseSnafu)?;
446+
return Ok(());
447+
}
448+
};
449+
450+
let task_set = stream_stdio(worker_msg_tx.clone(), stdin_rx, stdin, stdout, stderr);
451+
452+
self.kill_tokens.insert(job_id, token.clone());
453+
454+
self.processes.spawn({
455+
let stdin_shutdown_tx = self.stdin_shutdown_tx.clone();
456+
async move {
457+
worker_msg_tx
458+
.send(process_end(token, child, task_set, stdin_shutdown_tx, job_id).await)
459+
.await
460+
.context(UnableToSendExecuteCommandResponseSnafu)
461+
}
462+
});
463+
464+
Ok(())
465+
}
466+
467+
async fn stdin(&mut self, job_id: JobId, packet: String) -> Result<(), ProcessError> {
468+
use process_error::*;
469+
470+
if let Some(stdin_tx) = self.stdin_senders.get(&job_id) {
471+
stdin_tx
472+
.send(packet)
473+
.await
474+
.drop_error_details()
475+
.context(UnableToSendStdinDataSnafu)?;
476+
}
477+
478+
Ok(())
479+
}
480+
481+
fn stdin_close(&mut self, job_id: JobId) {
482+
self.stdin_senders.remove(&job_id);
483+
// Should we care if we remove a sender that's already removed?
484+
}
485+
486+
async fn join_process(&mut self) -> Option<Result<(), ProcessError>> {
487+
use process_error::*;
488+
489+
let process = self.processes.join_next().await?;
490+
Some(process.context(ProcessTaskPanickedSnafu).and_then(|e| e))
491+
}
492+
493+
fn kill(&mut self, job_id: JobId) {
494+
if let Some(token) = self.kill_tokens.get(&job_id) {
495+
token.cancel();
496+
}
497+
}
498+
}
499+
376500
async fn manage_processes(
377-
mut stdin_rx: mpsc::Receiver<Multiplexed<String>>,
378-
mut cmd_rx: mpsc::Receiver<CommandRequest>,
501+
mut rx: mpsc::Receiver<Multiplexed<ProcessCommand>>,
379502
project_path: PathBuf,
380503
) -> Result<(), ProcessError> {
381504
use process_error::*;
382505

383-
let mut processes = JoinSet::new();
384-
let mut stdin_senders = HashMap::new();
385506
let (stdin_shutdown_tx, mut stdin_shutdown_rx) = mpsc::channel(8);
507+
let mut state = ProcessState::new(project_path, stdin_shutdown_tx);
386508

387509
loop {
388510
select! {
389-
cmd_req = cmd_rx.recv() => {
390-
let Some((Multiplexed(job_id, req), worker_msg_tx)) = cmd_req else { break };
391-
392-
let RunningChild { child, stdin_rx, stdin, stdout, stderr } = match process_begin(req, &project_path, &mut stdin_senders, job_id) {
393-
Ok(v) => v,
394-
Err(e) => {
395-
// Should we add a message for process started
396-
// in addition to the current message which
397-
// indicates that the process has ended?
398-
worker_msg_tx.send_err(e).await.context(UnableToSendExecuteCommandStartedResponseSnafu)?;
399-
continue;
400-
}
401-
};
511+
cmd = rx.recv() => {
512+
let Some(Multiplexed(job_id, cmd)) = cmd else { break };
402513

403-
let task_set = stream_stdio(worker_msg_tx.clone(), stdin_rx, stdin, stdout, stderr);
514+
match cmd {
515+
ProcessCommand::Start(req, worker_msg_tx) => state.start(job_id, req, worker_msg_tx).await?,
404516

405-
processes.spawn({
406-
let stdin_shutdown_tx = stdin_shutdown_tx.clone();
407-
async move {
408-
worker_msg_tx
409-
.send(process_end(child, task_set, stdin_shutdown_tx, job_id).await)
410-
.await
411-
.context(UnableToSendExecuteCommandResponseSnafu)
412-
}
413-
});
414-
}
517+
ProcessCommand::Stdin(packet) => state.stdin(job_id, packet).await?,
415518

416-
stdin_packet = stdin_rx.recv() => {
417-
// Dispatch stdin packet to different child by attached command id.
418-
let Some(Multiplexed(job_id, packet)) = stdin_packet else { break };
519+
ProcessCommand::StdinClose => state.stdin_close(job_id),
419520

420-
if let Some(stdin_tx) = stdin_senders.get(&job_id) {
421-
stdin_tx.send(packet).await.drop_error_details().context(UnableToSendStdinDataSnafu)?;
521+
ProcessCommand::Kill => state.kill(job_id),
422522
}
423523
}
424524

425525
job_id = stdin_shutdown_rx.recv() => {
426526
let job_id = job_id.context(StdinShutdownReceiverEndedSnafu)?;
427-
stdin_senders.remove(&job_id);
428-
// Should we care if we remove a sender that's already removed?
527+
state.stdin_close(job_id);
429528
}
430529

431-
Some(process) = processes.join_next() => {
432-
process.context(ProcessTaskPanickedSnafu)??;
530+
Some(process) = state.join_process() => {
531+
process?;
433532
}
434533
}
435534
}
@@ -488,13 +587,19 @@ fn process_begin(
488587
}
489588

490589
async fn process_end(
590+
token: CancellationToken,
491591
mut child: Child,
492592
mut task_set: JoinSet<Result<(), StdioError>>,
493593
stdin_shutdown_tx: mpsc::Sender<JobId>,
494594
job_id: JobId,
495595
) -> Result<ExecuteCommandResponse, ProcessError> {
496596
use process_error::*;
497597

598+
select! {
599+
() = token.cancelled() => child.kill().await.context(KillChildSnafu)?,
600+
_ = child.wait() => {},
601+
};
602+
498603
let status = child.wait().await.context(WaitChildSnafu)?;
499604

500605
stdin_shutdown_tx
@@ -634,6 +739,9 @@ pub enum ProcessError {
634739
#[snafu(display("Failed to send stdin data"))]
635740
UnableToSendStdinData { source: mpsc::error::SendError<()> },
636741

742+
#[snafu(display("Failed to kill the child process"))]
743+
KillChild { source: std::io::Error },
744+
637745
#[snafu(display("Failed to wait for child process exiting"))]
638746
WaitChild { source: std::io::Error },
639747

@@ -671,10 +779,7 @@ fn stream_stdio(
671779
let mut set = JoinSet::new();
672780

673781
set.spawn(async move {
674-
loop {
675-
let Some(data) = stdin_rx.recv().await else {
676-
break;
677-
};
782+
while let Some(data) = stdin_rx.recv().await {
678783
stdin
679784
.write_all(data.as_bytes())
680785
.await

0 commit comments

Comments
 (0)