@@ -46,6 +46,7 @@ use tokio::{
46
46
sync:: mpsc,
47
47
task:: JoinSet ,
48
48
} ;
49
+ use tokio_util:: sync:: CancellationToken ;
49
50
50
51
use crate :: {
51
52
bincode_input_closed,
@@ -57,25 +58,21 @@ use crate::{
57
58
DropErrorDetailsExt ,
58
59
} ;
59
60
60
- type CommandRequest = ( Multiplexed < ExecuteCommandRequest > , MultiplexingSender ) ;
61
-
62
61
pub async fn listen ( project_dir : impl Into < PathBuf > ) -> Result < ( ) , Error > {
63
62
let project_dir = project_dir. into ( ) ;
64
63
65
64
let ( coordinator_msg_tx, coordinator_msg_rx) = mpsc:: channel ( 8 ) ;
66
65
let ( worker_msg_tx, worker_msg_rx) = mpsc:: channel ( 8 ) ;
67
66
let mut io_tasks = spawn_io_queue ( coordinator_msg_tx, worker_msg_rx) ;
68
67
69
- let ( cmd_tx, cmd_rx) = mpsc:: channel ( 8 ) ;
70
- let ( stdin_tx, stdin_rx) = mpsc:: channel ( 8 ) ;
71
- let process_task = tokio:: spawn ( manage_processes ( stdin_rx, cmd_rx, project_dir. clone ( ) ) ) ;
68
+ let ( process_tx, process_rx) = mpsc:: channel ( 8 ) ;
69
+ let process_task = tokio:: spawn ( manage_processes ( process_rx, project_dir. clone ( ) ) ) ;
72
70
73
71
let handler_task = tokio:: spawn ( handle_coordinator_message (
74
72
coordinator_msg_rx,
75
73
worker_msg_tx,
76
74
project_dir,
77
- cmd_tx,
78
- stdin_tx,
75
+ process_tx,
79
76
) ) ;
80
77
81
78
select ! {
@@ -122,8 +119,7 @@ async fn handle_coordinator_message(
122
119
mut coordinator_msg_rx : mpsc:: Receiver < Multiplexed < CoordinatorMessage > > ,
123
120
worker_msg_tx : mpsc:: Sender < Multiplexed < WorkerMessage > > ,
124
121
project_dir : PathBuf ,
125
- cmd_tx : mpsc:: Sender < CommandRequest > ,
126
- stdin_tx : mpsc:: Sender < Multiplexed < String > > ,
122
+ process_tx : mpsc:: Sender < Multiplexed < ProcessCommand > > ,
127
123
) -> Result < ( ) , HandleCoordinatorMessageError > {
128
124
use handle_coordinator_message_error:: * ;
129
125
@@ -177,20 +173,36 @@ async fn handle_coordinator_message(
177
173
}
178
174
179
175
CoordinatorMessage :: ExecuteCommand ( req) => {
180
- cmd_tx
181
- . send( ( Multiplexed ( job_id, req) , worker_msg_tx( ) ) )
176
+ process_tx
177
+ . send( Multiplexed ( job_id, ProcessCommand :: Start ( req, worker_msg_tx( ) ) ) )
182
178
. await
183
179
. drop_error_details( )
184
180
. context( UnableToSendCommandExecutionRequestSnafu ) ?;
185
181
}
186
182
187
183
CoordinatorMessage :: StdinPacket ( data) => {
188
- stdin_tx
189
- . send( Multiplexed ( job_id, data) )
184
+ process_tx
185
+ . send( Multiplexed ( job_id, ProcessCommand :: Stdin ( data) ) )
190
186
. await
191
187
. drop_error_details( )
192
188
. context( UnableToSendStdinPacketSnafu ) ?;
193
189
}
190
+
191
+ CoordinatorMessage :: StdinClose => {
192
+ process_tx
193
+ . send( Multiplexed ( job_id, ProcessCommand :: StdinClose ) )
194
+ . await
195
+ . drop_error_details( )
196
+ . context( UnableToSendStdinCloseSnafu ) ?;
197
+ }
198
+
199
+ CoordinatorMessage :: Kill => {
200
+ process_tx
201
+ . send( Multiplexed ( job_id, ProcessCommand :: Kill ) )
202
+ . await
203
+ . drop_error_details( )
204
+ . context( UnableToSendKillSnafu ) ?;
205
+ }
194
206
}
195
207
}
196
208
@@ -221,6 +233,12 @@ pub enum HandleCoordinatorMessageError {
221
233
#[ snafu( display( "Failed to send stdin packet to the command task" ) ) ]
222
234
UnableToSendStdinPacket { source : mpsc:: error:: SendError < ( ) > } ,
223
235
236
+ #[ snafu( display( "Failed to send stdin close request to the command task" ) ) ]
237
+ UnableToSendStdinClose { source : mpsc:: error:: SendError < ( ) > } ,
238
+
239
+ #[ snafu( display( "Failed to send kill request to the command task" ) ) ]
240
+ UnableToSendKill { source : mpsc:: error:: SendError < ( ) > } ,
241
+
224
242
#[ snafu( display( "A coordinator command handler background task panicked" ) ) ]
225
243
TaskPanicked { source : tokio:: task:: JoinError } ,
226
244
}
@@ -373,63 +391,144 @@ fn parse_working_dir(cwd: Option<String>, project_path: impl Into<PathBuf>) -> P
373
391
final_path
374
392
}
375
393
394
+ enum ProcessCommand {
395
+ Start ( ExecuteCommandRequest , MultiplexingSender ) ,
396
+ Stdin ( String ) ,
397
+ StdinClose ,
398
+ Kill ,
399
+ }
400
+
401
+ struct ProcessState {
402
+ project_path : PathBuf ,
403
+ processes : JoinSet < Result < ( ) , ProcessError > > ,
404
+ stdin_senders : HashMap < JobId , mpsc:: Sender < String > > ,
405
+ stdin_shutdown_tx : mpsc:: Sender < JobId > ,
406
+ kill_tokens : HashMap < JobId , CancellationToken > ,
407
+ }
408
+
409
+ impl ProcessState {
410
+ fn new ( project_path : PathBuf , stdin_shutdown_tx : mpsc:: Sender < JobId > ) -> Self {
411
+ Self {
412
+ project_path,
413
+ processes : Default :: default ( ) ,
414
+ stdin_senders : Default :: default ( ) ,
415
+ stdin_shutdown_tx,
416
+ kill_tokens : Default :: default ( ) ,
417
+ }
418
+ }
419
+
420
+ async fn start (
421
+ & mut self ,
422
+ job_id : JobId ,
423
+ req : ExecuteCommandRequest ,
424
+ worker_msg_tx : MultiplexingSender ,
425
+ ) -> Result < ( ) , ProcessError > {
426
+ use process_error:: * ;
427
+
428
+ let token = CancellationToken :: new ( ) ;
429
+
430
+ let RunningChild {
431
+ child,
432
+ stdin_rx,
433
+ stdin,
434
+ stdout,
435
+ stderr,
436
+ } = match process_begin ( req, & self . project_path , & mut self . stdin_senders , job_id) {
437
+ Ok ( v) => v,
438
+ Err ( e) => {
439
+ // Should we add a message for process started
440
+ // in addition to the current message which
441
+ // indicates that the process has ended?
442
+ worker_msg_tx
443
+ . send_err ( e)
444
+ . await
445
+ . context ( UnableToSendExecuteCommandStartedResponseSnafu ) ?;
446
+ return Ok ( ( ) ) ;
447
+ }
448
+ } ;
449
+
450
+ let task_set = stream_stdio ( worker_msg_tx. clone ( ) , stdin_rx, stdin, stdout, stderr) ;
451
+
452
+ self . kill_tokens . insert ( job_id, token. clone ( ) ) ;
453
+
454
+ self . processes . spawn ( {
455
+ let stdin_shutdown_tx = self . stdin_shutdown_tx . clone ( ) ;
456
+ async move {
457
+ worker_msg_tx
458
+ . send ( process_end ( token, child, task_set, stdin_shutdown_tx, job_id) . await )
459
+ . await
460
+ . context ( UnableToSendExecuteCommandResponseSnafu )
461
+ }
462
+ } ) ;
463
+
464
+ Ok ( ( ) )
465
+ }
466
+
467
+ async fn stdin ( & mut self , job_id : JobId , packet : String ) -> Result < ( ) , ProcessError > {
468
+ use process_error:: * ;
469
+
470
+ if let Some ( stdin_tx) = self . stdin_senders . get ( & job_id) {
471
+ stdin_tx
472
+ . send ( packet)
473
+ . await
474
+ . drop_error_details ( )
475
+ . context ( UnableToSendStdinDataSnafu ) ?;
476
+ }
477
+
478
+ Ok ( ( ) )
479
+ }
480
+
481
+ fn stdin_close ( & mut self , job_id : JobId ) {
482
+ self . stdin_senders . remove ( & job_id) ;
483
+ // Should we care if we remove a sender that's already removed?
484
+ }
485
+
486
+ async fn join_process ( & mut self ) -> Option < Result < ( ) , ProcessError > > {
487
+ use process_error:: * ;
488
+
489
+ let process = self . processes . join_next ( ) . await ?;
490
+ Some ( process. context ( ProcessTaskPanickedSnafu ) . and_then ( |e| e) )
491
+ }
492
+
493
+ fn kill ( & mut self , job_id : JobId ) {
494
+ if let Some ( token) = self . kill_tokens . get ( & job_id) {
495
+ token. cancel ( ) ;
496
+ }
497
+ }
498
+ }
499
+
376
500
async fn manage_processes (
377
- mut stdin_rx : mpsc:: Receiver < Multiplexed < String > > ,
378
- mut cmd_rx : mpsc:: Receiver < CommandRequest > ,
501
+ mut rx : mpsc:: Receiver < Multiplexed < ProcessCommand > > ,
379
502
project_path : PathBuf ,
380
503
) -> Result < ( ) , ProcessError > {
381
504
use process_error:: * ;
382
505
383
- let mut processes = JoinSet :: new ( ) ;
384
- let mut stdin_senders = HashMap :: new ( ) ;
385
506
let ( stdin_shutdown_tx, mut stdin_shutdown_rx) = mpsc:: channel ( 8 ) ;
507
+ let mut state = ProcessState :: new ( project_path, stdin_shutdown_tx) ;
386
508
387
509
loop {
388
510
select ! {
389
- cmd_req = cmd_rx. recv( ) => {
390
- let Some ( ( Multiplexed ( job_id, req) , worker_msg_tx) ) = cmd_req else { break } ;
391
-
392
- let RunningChild { child, stdin_rx, stdin, stdout, stderr } = match process_begin( req, & project_path, & mut stdin_senders, job_id) {
393
- Ok ( v) => v,
394
- Err ( e) => {
395
- // Should we add a message for process started
396
- // in addition to the current message which
397
- // indicates that the process has ended?
398
- worker_msg_tx. send_err( e) . await . context( UnableToSendExecuteCommandStartedResponseSnafu ) ?;
399
- continue ;
400
- }
401
- } ;
511
+ cmd = rx. recv( ) => {
512
+ let Some ( Multiplexed ( job_id, cmd) ) = cmd else { break } ;
402
513
403
- let task_set = stream_stdio( worker_msg_tx. clone( ) , stdin_rx, stdin, stdout, stderr) ;
514
+ match cmd {
515
+ ProcessCommand :: Start ( req, worker_msg_tx) => state. start( job_id, req, worker_msg_tx) . await ?,
404
516
405
- processes. spawn( {
406
- let stdin_shutdown_tx = stdin_shutdown_tx. clone( ) ;
407
- async move {
408
- worker_msg_tx
409
- . send( process_end( child, task_set, stdin_shutdown_tx, job_id) . await )
410
- . await
411
- . context( UnableToSendExecuteCommandResponseSnafu )
412
- }
413
- } ) ;
414
- }
517
+ ProcessCommand :: Stdin ( packet) => state. stdin( job_id, packet) . await ?,
415
518
416
- stdin_packet = stdin_rx. recv( ) => {
417
- // Dispatch stdin packet to different child by attached command id.
418
- let Some ( Multiplexed ( job_id, packet) ) = stdin_packet else { break } ;
519
+ ProcessCommand :: StdinClose => state. stdin_close( job_id) ,
419
520
420
- if let Some ( stdin_tx) = stdin_senders. get( & job_id) {
421
- stdin_tx. send( packet) . await . drop_error_details( ) . context( UnableToSendStdinDataSnafu ) ?;
521
+ ProcessCommand :: Kill => state. kill( job_id) ,
422
522
}
423
523
}
424
524
425
525
job_id = stdin_shutdown_rx. recv( ) => {
426
526
let job_id = job_id. context( StdinShutdownReceiverEndedSnafu ) ?;
427
- stdin_senders. remove( & job_id) ;
428
- // Should we care if we remove a sender that's already removed?
527
+ state. stdin_close( job_id) ;
429
528
}
430
529
431
- Some ( process) = processes . join_next ( ) => {
432
- process. context ( ProcessTaskPanickedSnafu ) ? ?;
530
+ Some ( process) = state . join_process ( ) => {
531
+ process?;
433
532
}
434
533
}
435
534
}
@@ -488,13 +587,19 @@ fn process_begin(
488
587
}
489
588
490
589
async fn process_end (
590
+ token : CancellationToken ,
491
591
mut child : Child ,
492
592
mut task_set : JoinSet < Result < ( ) , StdioError > > ,
493
593
stdin_shutdown_tx : mpsc:: Sender < JobId > ,
494
594
job_id : JobId ,
495
595
) -> Result < ExecuteCommandResponse , ProcessError > {
496
596
use process_error:: * ;
497
597
598
+ select ! {
599
+ ( ) = token. cancelled( ) => child. kill( ) . await . context( KillChildSnafu ) ?,
600
+ _ = child. wait( ) => { } ,
601
+ } ;
602
+
498
603
let status = child. wait ( ) . await . context ( WaitChildSnafu ) ?;
499
604
500
605
stdin_shutdown_tx
@@ -634,6 +739,9 @@ pub enum ProcessError {
634
739
#[ snafu( display( "Failed to send stdin data" ) ) ]
635
740
UnableToSendStdinData { source : mpsc:: error:: SendError < ( ) > } ,
636
741
742
+ #[ snafu( display( "Failed to kill the child process" ) ) ]
743
+ KillChild { source : std:: io:: Error } ,
744
+
637
745
#[ snafu( display( "Failed to wait for child process exiting" ) ) ]
638
746
WaitChild { source : std:: io:: Error } ,
639
747
@@ -671,10 +779,7 @@ fn stream_stdio(
671
779
let mut set = JoinSet :: new ( ) ;
672
780
673
781
set. spawn ( async move {
674
- loop {
675
- let Some ( data) = stdin_rx. recv ( ) . await else {
676
- break ;
677
- } ;
782
+ while let Some ( data) = stdin_rx. recv ( ) . await {
678
783
stdin
679
784
. write_all ( data. as_bytes ( ) )
680
785
. await
0 commit comments