use crate::Actor;
use crate::ChannelError;
use crate::DynSender;
use crate::LoggingReceiver;
use crate::Message;
use crate::MessageReceiver;
use crate::RuntimeError;
use crate::RuntimeRequest;
use crate::Sender;
use crate::Server;
use async_trait::async_trait;
use futures::channel::oneshot;
use futures::StreamExt;
use log::error;
use std::fmt::Debug;
use std::ops::ControlFlow;

/// Wrap a request with a [Sender] to which the response should be sent
///
/// Requests are sent to server actors in such envelopes, which tell the server where to send its responses.
pub struct RequestEnvelope<Request, Response> {
    pub request: Request,
    pub reply_to: Box<dyn Sender<Response>>,
}

impl<Request: Debug, Response> Debug for RequestEnvelope<Request, Response> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.request.fmt(f)
    }
}
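
// Illustrative sketch, not part of the original change: how an envelope pairs a
// request with a return address. A `futures` oneshot channel plays the role of the
// reply `Sender`, mirroring what `ClientMessageBox::await_response` does below; the
// caller keeps the matching receiver to await the response.
#[allow(dead_code)]
fn request_with_reply_channel<Request: Message, Response: Message>(
    request: Request,
) -> (RequestEnvelope<Request, Response>, oneshot::Receiver<Response>) {
    let (sender, receiver) = oneshot::channel();
    let envelope = RequestEnvelope {
        request,
        // An `Option<oneshot::Sender<_>>` acts as a one-shot reply `Sender`,
        // exactly as relied upon by `await_response` below.
        reply_to: Box::new(Some(sender)),
    };
    (envelope, receiver)
}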

/// A message box used by a client to send requests to a server and await the responses.
pub struct ClientMessageBox<Request, Response> {
    sender: DynSender<RequestEnvelope<Request, Response>>,
}

impl<Request: Message, Response: Message> ClientMessageBox<Request, Response> {
    /// Send the request and await the response
    pub async fn await_response(&mut self, request: Request) -> Result<Response, ChannelError> {
        let (sender, receiver) = oneshot::channel::<Response>();
        let reply_to = Box::new(Some(sender));
        self.sender
            .send(RequestEnvelope { request, reply_to })
            .await?;
        let response = receiver.await;
        response.map_err(|_| ChannelError::ReceiveError())
    }
}
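
// Illustrative sketch, not part of the original change: from the caller's point of
// view, `ClientMessageBox` turns the actor exchange into a plain async call. How a
// `ClientMessageBox` is built and connected to a server actor is not shown in this
// file and is assumed to happen elsewhere.
#[allow(dead_code)]
async fn query_server<Request: Message, Response: Message>(
    client: &mut ClientMessageBox<Request, Response>,
    request: Request,
) -> Result<Response, ChannelError> {
    client.await_response(request).await
}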

/// A [Sender] used by a client to send requests to a server,
/// redirecting the responses to another recipient.
#[derive(Clone)]
pub struct RequestSender<Request: 'static, Response: 'static> {
    sender: DynSender<RequestEnvelope<Request, Response>>,
    reply_to: DynSender<Response>,
}

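// In contrast with `ClientMessageBox`, a `RequestSender` never awaits the responses:
// each `send` below puts a fresh handle on `reply_to` into the envelope, so the
// server's responses to all requests sent this way are delivered to that single
// recipient while the caller carries on.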
#[async_trait]
impl<Request: Message, Response: Message> Sender<Request> for RequestSender<Request, Response> {
    async fn send(&mut self, request: Request) -> Result<(), ChannelError> {
        let reply_to = self.reply_to.sender();
        self.sender
            .send(RequestEnvelope { request, reply_to })
            .await
    }
}

/* Adding this prevents deriving Clone for RequestSender! Why?
impl<Request: Message, Response: Message> From<RequestSender<Request,Response>> for DynSender<Request> {
    fn from(sender: RequestSender<Request,Response>) -> Self {
        Box::new(sender)
    }
}*/

/// An actor that wraps a request-response server
///
/// Requests are processed in turn, leading either to a response or an error.
pub struct ServerActor<S: Server> {
    server: S,
    requests: LoggingReceiver<RequestEnvelope<S::Request, S::Response>>,
}

#[async_trait]
impl<S: Server> Actor for ServerActor<S> {
    fn name(&self) -> &str {
        self.server.name()
    }

    async fn run(mut self) -> Result<(), RuntimeError> {
        let server = &mut self.server;
        while let Some(RequestEnvelope {
            request,
            mut reply_to,
        }) = self.requests.recv().await
        {
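            // Handle the request, while staying responsive to a runtime shutdown:
            // if a shutdown signal arrives first, the in-flight request is dropped.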
            tokio::select! {
                response = server.handle(request) => {
                    let _ = reply_to.send(response).await;
                }
                Some(RuntimeRequest::Shutdown) = self.requests.recv_signal() => {
                    break;
                }
            }
        }
        Ok(())
    }
}
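
// Illustrative sketch, not part of the original change: the kind of `Server` that a
// `ServerActor` can wrap. The trait shape below is assumed from how it is used in
// this file (a `name` plus an async `handle` turning each request into a response);
// see the `Server` trait definition for the authoritative signature.
//
//     #[derive(Clone)]
//     struct EchoServer;
//
//     #[async_trait]
//     impl Server for EchoServer {
//         type Request = String;
//         type Response = String;
//
//         fn name(&self) -> &str {
//             "Echo"
//         }
//
//         async fn handle(&mut self, request: String) -> String {
//             request
//         }
//     }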

/// An actor that wraps a request-response server
///
/// Requests are processed concurrently (up to some max concurrency level).
///
/// The server must be `Clone` to create a fresh server handle for each request.
pub struct ConcurrentServerActor<S: Server + Clone> {
    server: S,
    messages: ConcurrentServerMessageBox<S::Request, S::Response>,
}

impl<S: Server + Clone> ConcurrentServerActor<S> {
    pub fn new(server: S, messages: ConcurrentServerMessageBox<S::Request, S::Response>) -> Self {
        ConcurrentServerActor { server, messages }
    }
}

#[async_trait]
impl<S: Server + Clone> Actor for ConcurrentServerActor<S> {
    fn name(&self) -> &str {
        self.server.name()
    }

    async fn run(mut self) -> Result<(), RuntimeError> {
        while let Some(RequestEnvelope {
            request,
            mut reply_to,
        }) = self.messages.next_request().await
        {
            // Spawn a task handling the request on a clone of the server
            let mut server = self.server.clone();
            let pending_result = tokio::spawn(async move {
                let result = server.handle(request).await;
                let _ = reply_to.send(result).await;
            });

            // Register the pending task: the response is sent back to the client
            // by the spawned task itself, once the request has been handled
            self.messages.send_response_once_done(pending_result);
        }

        Ok(())
    }
}
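
// The `max_concurrency` bound of the actor above is enforced by the message box below:
// `next_request` only yields a new request once the number of in-flight tasks has
// dropped below `max_concurrency` (see `await_idle_processor`).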

/// A message box for a service that handles requests concurrently
pub struct ConcurrentServerMessageBox<Request: Debug, Response> {
    /// Max concurrent requests
    max_concurrency: usize,

    /// Message box to interact with clients of this service
    requests: LoggingReceiver<RequestEnvelope<Request, Response>>,

    /// Pending responses
    pending_responses: futures::stream::FuturesUnordered<PendingResult>,
}

type PendingResult = tokio::task::JoinHandle<()>;

impl<Request: Message, Response: Message> ConcurrentServerMessageBox<Request, Response> {
    pub(crate) fn new(
        max_concurrency: usize,
        requests: LoggingReceiver<RequestEnvelope<Request, Response>>,
    ) -> Self {
        ConcurrentServerMessageBox {
            max_concurrency,
            requests,
            pending_responses: futures::stream::FuturesUnordered::new(),
        }
    }

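    // Awaits the next request to process while also driving the tasks spawned for
    // previous requests: completed tasks are drained here, and when the concurrency
    // limit has been reached no new request is accepted until a slot frees up.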
    async fn next_request(&mut self) -> Option<RequestEnvelope<Request, Response>> {
        if self.await_idle_processor().await.is_break() {
            return None;
        }

        loop {
            tokio::select! {
                Some(request) = self.requests.recv() => {
                    return Some(request);
                }
                Some(result) = self.pending_responses.next() => {
                    if let Err(err) = result {
                        error!("Request failed with: {err}");
                    }
                }
                else => {
                    return None
                }
            }
        }
    }

    async fn await_idle_processor(&mut self) -> ControlFlow<(), ()> {
        if self.pending_responses.len() < self.max_concurrency {
            return ControlFlow::Continue(());
        }

        tokio::select! {
            Some(result) = self.pending_responses.next() => {
                if let Err(err) = result {
                    error!("Request failed with: {err}");
                }
                ControlFlow::Continue(())
            },
            // recv consumes the message from the channel, so we can't just use
            // a regular return, because then next_request wouldn't see it.
            //
            // A better approach would be to do the select at the top-level entry
            // point; then we'd be sure we're able to cancel when anything happens,
            // not just when waiting for pending_responses.
            Some(RuntimeRequest::Shutdown) = self.requests.recv_signal() => {
                ControlFlow::Break(())
            }
            else => ControlFlow::Break(())
        }
    }

    pub fn send_response_once_done(&mut self, pending_result: PendingResult) {
        self.pending_responses.push(pending_result);
    }
}
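
// Illustrative sketch, not part of the original change: wiring a concurrent server
// actor by hand. The receiver of request envelopes is taken as an argument here,
// since how a `LoggingReceiver` is built is not shown in this file.
#[allow(dead_code)]
fn spawnable_concurrent_actor<S>(
    server: S,
    requests: LoggingReceiver<RequestEnvelope<S::Request, S::Response>>,
) -> ConcurrentServerActor<S>
where
    S: Server + Clone,
    S::Request: Message,
    S::Response: Message,
{
    // Process at most 4 requests in parallel; beyond that, `next_request` makes the
    // actor wait for one of the pending responses before accepting new work.
    let messages = ConcurrentServerMessageBox::new(4, requests);
    ConcurrentServerActor::new(server, messages)
}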