Skip to content

Commit 1e9a395

Browse files
committed
send all requests as a whole to worker
1 parent 232050c commit 1e9a395

File tree

3 files changed

+102
-133
lines changed

3 files changed

+102
-133
lines changed

ui/src/coordinator.rs

Lines changed: 54 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,18 @@ use std::collections::HashMap;
22
use std::convert::{TryFrom, TryInto};
33

44
use std::str::from_utf8;
5+
use std::time::SystemTime;
56

6-
use tokio::sync::{broadcast, mpsc};
7-
use worker_message::{CoordinatorMessage, ExecuteOutput};
7+
8+
use worker_message::{ExecuteOutput, Job};
89

910
use crate::sandbox::{
1011
self, Channel, CompileRequest, CrateType, Edition, ExecuteRequest, FormatRequest, Mode,
1112
};
1213
use crate::{parse_channel, parse_crate_type, parse_edition, parse_mode, parse_target, Error};
1314

1415
struct JobBatch {
15-
jobs: Vec<worker_message::Request>,
16+
job: Job,
1617
extra: serde_json::Value
1718
}
1819

@@ -21,14 +22,14 @@ fn split_request(req: WSRequest) -> JobBatch {
2122
WSRequest::Compile(req) => {
2223
let (req, extra) = req.try_into().unwrap();
2324
JobBatch {
24-
jobs: compile_request_to_batch(req),
25+
job: worker_message::Job { reqs: compile_request_to_batch(req) },
2526
extra
2627
}
2728
}
2829
WSRequest::Format(req) => {
2930
let (req, extra) = req.try_into().unwrap();
3031
JobBatch {
31-
jobs: format_request_to_batch(req),
32+
job: worker_message::Job { reqs: format_request_to_batch(req) },
3233
extra
3334
}
3435
}
@@ -356,17 +357,16 @@ fn worker_response_to_websocket_response(resp: BatchResponse) -> WSResponse {
356357
// }
357358
// response
358359
// }
359-
360-
async fn work(
361-
req: worker_message::Request,
362-
sender: &mpsc::Sender<CoordinatorMessage>,
363-
receiver: &mut broadcast::Receiver<worker_message::Response>,
364-
) -> worker_message::Response {
365-
let coordinator_msg = CoordinatorMessage::Request(0, req);
366-
sender.send(coordinator_msg).await.unwrap();
367-
receiver.recv().await.unwrap()
360+
//
361+
// It's safe unless a leap second happens.
362+
fn now_as_uid() -> u128 {
363+
SystemTime::now()
364+
.duration_since(SystemTime::UNIX_EPOCH)
365+
.unwrap()
366+
.as_nanos()
368367
}
369368

369+
370370
#[derive(serde::Deserialize)]
371371
#[serde(rename_all = "camelCase")]
372372
struct WSExecuteRequest {
@@ -566,6 +566,7 @@ enum WSResponse {
566566
#[cfg(test)]
567567
mod tests {
568568

569+
use std::collections::HashMap;
569570
use std::process::Stdio;
570571

571572
use std::sync::Arc;
@@ -574,17 +575,17 @@ mod tests {
574575

575576
const WORKER_FILEPATH: &str = "../worker-message/target/debug/worker";
576577
use crate::coordinator::{
577-
split_request, work, worker_response_to_websocket_response, BatchResponse,
578-
PlaygroundMessage, ResponseKind, WSMessage, WSRequest, WSResponse,
578+
split_request, worker_response_to_websocket_response, BatchResponse,
579+
PlaygroundMessage, ResponseKind, WSMessage, WSRequest, WSResponse, now_as_uid,
579580
};
580581

581582
use serde_json::json;
582583
use tokio::io::{AsyncReadExt, AsyncWriteExt};
583584
use tokio::process::{ChildStdin, ChildStdout, Command};
584-
use tokio::sync::{broadcast, mpsc};
585+
use tokio::sync::{mpsc, oneshot};
585586
use tokio::task::JoinHandle;
586587
use tokio::time::error::Elapsed;
587-
use worker_message::{CoordinatorMessage, WorkerMessage};
588+
use worker_message::{CoordinatorMessage, WorkerMessage, JobReport};
588589

589590
use super::{WSCompileRequest, WSFormatRequest};
590591

@@ -717,8 +718,9 @@ mod tests {
717718
worker_sender: mpsc::Sender<CoordinatorMessage>,
718719
mut worker_receiver: mpsc::Receiver<WorkerMessage>,
719720
) {
720-
let mut current_task: Option<JoinHandle<()>> = None;
721-
let (worker_response_tx, _worker_response_rx) = broadcast::channel(32);
721+
let mut worker_response_senders = HashMap::new();
722+
let (worker_response_sender_tx, mut worker_response_sender_rx) = mpsc::channel(32);
723+
let mut current_job: Option<JoinHandle<()>> = None;
722724
loop {
723725
tokio::select! {
724726
ws_request = ws_receiver.recv() => {
@@ -730,54 +732,60 @@ mod tests {
730732
let ws_msg = serde_json::from_str(&txt).expect("Failed to deserialize websocket message from json string");
731733
match ws_msg {
732734
WSMessage::Request(ws_request) => {
733-
// Abort current task.
735+
// Abort current job.
734736
// Lower request into low-level operations.
735737
// Execute operations in sequence until a failure is reached.
736-
if let Some(task) = current_task {
737-
task.abort();
738+
if let Some(job) = current_job {
739+
job.abort();
738740
}
739-
let mut worker_response_rx = worker_response_tx.subscribe();
740741
let worker_sender = worker_sender.clone();
741742
let ws_sender = ws_sender.clone();
742-
let kind = ResponseKind::kind(&ws_request);
743-
let batch = split_request(ws_request);
744-
let jobs = batch.jobs;
745-
let extra = batch.extra;
746-
current_task = Some(tokio::spawn(async move {
747-
assert!(jobs.len() != 0);
748-
let mut responses = Vec::with_capacity(jobs.len());
749-
for job in jobs {
750-
let resp = work(job, &worker_sender, &mut worker_response_rx).await;
751-
if !resp.is_ok() {
752-
responses.push(resp);
753-
break;
754-
}
755-
responses.push(resp);
756-
}
743+
let worker_response_sender_tx = worker_response_sender_tx.clone();
744+
current_job = Some(tokio::spawn(async move {
745+
let kind = ResponseKind::kind(&ws_request);
746+
let batch = split_request(ws_request);
747+
let job = batch.job;
748+
let extra = batch.extra;
749+
let id = now_as_uid();
750+
let (resp_tx, resp_rx) = oneshot::channel();
751+
worker_response_sender_tx.send((id, resp_tx)).await.unwrap();
752+
let coordinator_msg = CoordinatorMessage::Request(id, job);
753+
worker_sender.send(coordinator_msg).await.unwrap();
754+
let job_report: JobReport = resp_rx.await.unwrap();
755+
let worker_response = job_report.resps;
757756
let batch_response = BatchResponse {
758757
kind,
759-
responses,
758+
responses: worker_response,
760759
extra
761760
};
762761
let ws_response = worker_response_to_websocket_response(batch_response);
763762
ws_sender.send(serde_json::to_string(&PlaygroundMessage::Response(ws_response)).expect("Failed to serialize websocket response")).await.expect("WebSocket failed to send response to user");
764763
}));
764+
765765
}
766766
WSMessage::StdinPacket(packet) => {}
767767
}
768768
}
769769
}
770770
},
771+
worker_response_sender = worker_response_sender_rx.recv() => {
772+
if let Some((uid, resp_tx)) = worker_response_sender {
773+
worker_response_senders.insert(uid, resp_tx);
774+
} else {
775+
break;
776+
}
777+
},
771778
worker_msg = worker_receiver.recv() => {
772779
match worker_msg {
773780
None => {
774781
break;
775782
}
776783
Some(msg) => {
777784
match msg {
778-
WorkerMessage::Response(uid, resp) => {
779-
// TODO: drop response that belongs to aborted job.
780-
worker_response_tx.send(resp).unwrap();
785+
WorkerMessage::Response(id, resp) => {
786+
if let Some(tx) = worker_response_senders.remove(&id) {
787+
tx.send(resp).unwrap();
788+
}
781789
}
782790
WorkerMessage::StdoutPacket(pid, packet) => {}
783791
WorkerMessage::StderrPacket(pid, packet) => {}
@@ -791,12 +799,12 @@ mod tests {
791799

792800
async fn setup_coordinator() -> (mpsc::Sender<String>, mpsc::Receiver<String>) {
793801
let (worker_sender, worker_receiver) = make_worker_channel().await;
794-
let (tx, ws_receiver) = mpsc::channel(32);
795-
let (ws_sender, rx) = mpsc::channel(32);
802+
let (client_tx, ws_receiver) = mpsc::channel(32);
803+
let (ws_sender, client_rx) = mpsc::channel(32);
796804
tokio::spawn(async move {
797805
pair_websocket_worker(ws_sender, ws_receiver, worker_sender, worker_receiver).await;
798806
});
799-
(tx, rx)
807+
(client_tx, client_rx)
800808
}
801809

802810
async fn check_websocket_request_response(

worker-message/src/message.rs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,32 @@ use serde::{Deserialize, Serialize};
22
use std::{collections::HashMap, fmt};
33

44
pub type Pid = u32;
5-
pub type RequestId = u64;
5+
pub type JobId = u128;
66
pub type Path = String;
77
pub type ResponseError = String;
88
pub type Result<T> = std::result::Result<T, ResponseError>;
99

10+
11+
#[derive(Debug, Serialize, Deserialize)]
12+
pub struct Job {
13+
pub reqs: Vec<Request>
14+
}
15+
16+
#[derive(Debug, Serialize, Deserialize)]
17+
pub struct JobReport {
18+
pub resps: Vec<Response>
19+
}
20+
21+
1022
#[derive(Debug, Serialize, Deserialize)]
1123
pub enum CoordinatorMessage {
12-
Request(RequestId, Request),
24+
Request(JobId, Job),
1325
StdinPacket(Pid, Vec<u8>),
1426
}
1527

1628
#[derive(Debug, Serialize, Deserialize)]
1729
pub enum WorkerMessage {
18-
Response(RequestId, Response),
30+
Response(JobId, JobReport),
1931
StdoutPacket(Pid, Vec<u8>),
2032
StderrPacket(Pid, Vec<u8>),
2133
}

0 commit comments

Comments
 (0)