Skip to content

Replace futures::select! with futures_concurrency throughout codebase #1572

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions clippy.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ disallowed-macros = [
{ path = "futures::ready", reason = "use std::task::ready" },
{ path = "tracing::enabled", reason = "https://github.com/tokio-rs/tracing/issues/2519" },
{ path = "openhcl_boot::boot_logger::debug_log", reason = "only use in local debugging, use log! if you want a production log message"},
{ path = "futures::select", reason = "use futures_concurrency instead for safer async patterns" },
{ path = "futures::select_biased", reason = "use futures_concurrency instead for safer async patterns" },
]

disallowed-methods = [
Expand Down
24 changes: 20 additions & 4 deletions petri/pipette/src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
use anyhow::Context;
use futures::future::FutureExt;
use futures_concurrency::future::RaceOk;
use std::task::Poll;
use mesh_remote::PointToPointMesh;
use pal_async::DefaultDriver;
use pal_async::socket::PolledSocket;
Expand Down Expand Up @@ -75,19 +76,34 @@ impl Agent {
pub async fn run(mut self) -> anyhow::Result<()> {
let mut tasks = FuturesUnordered::new();
loop {
futures::select! {
req = self.request_recv.recv().fuse() => {
let recv_fut = std::pin::pin!(self.request_recv.recv());
let tasks_next_fut = std::pin::pin!(tasks.next());

let should_break = std::future::poll_fn(|cx| {
// Check for new requests first
if let Poll::Ready(req) = recv_fut.as_mut().poll(cx) {
match req {
Ok(req) => {
tasks.push(handle_request(&self.driver, req, self.diag_file_send.clone()));
return Poll::Ready(false); // Continue the loop
},
Err(e) => {
tracing::info!(?e, "request channel closed, shutting down");
break;
return Poll::Ready(true); // Break the loop
}
}
}
_ = tasks.next() => {}

// Check for completed tasks
if let Poll::Ready(_) = tasks_next_fut.as_mut().poll(cx) {
return Poll::Ready(false); // Continue the loop
}

Poll::Pending
}).await;

if should_break {
break;
}
}
self.watch_send.send(());
Expand Down
6 changes: 2 additions & 4 deletions support/inspect/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2176,6 +2176,7 @@ mod tests {
use expect_test::Expect;
use expect_test::expect;
use futures::FutureExt;
use futures_concurrency::future::Race;
use pal_async::DefaultDriver;
use pal_async::async_test;
use pal_async::timer::Instant;
Expand All @@ -2196,10 +2197,7 @@ mod tests {
let deadline = Instant::now() + timeout;
let mut result = InspectionBuilder::new(path).depth(depth).inspect(obj);
let mut timer = PolledTimer::new(driver);
futures::select! { // race semantics
_ = result.resolve().fuse() => {}
_ = timer.sleep_until(deadline).fuse() => {}
};
let _ = (result.resolve(), timer.sleep_until(deadline)).race().await;
result.results()
}

Expand Down
25 changes: 21 additions & 4 deletions support/mesh/mesh_process/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use futures::FutureExt;
use futures::StreamExt;
use futures::executor::block_on;
use futures_concurrency::future::Race;
use std::task::Poll;
use inspect::Inspect;
use inspect::SensitivityLevel;
use mesh::MeshPayload;
Expand Down Expand Up @@ -514,10 +515,26 @@ impl MeshInner {
}

loop {
let event = futures::select! { // merge semantics
request = self.requests.select_next_some() => Event::Request(request),
n = self.waiters.select_next_some() => Event::Done(n.unwrap()),
complete => break,
let requests_fut = std::pin::pin!(self.requests.select_next_some());
let waiters_fut = std::pin::pin!(self.waiters.select_next_some());

let event = std::future::poll_fn(|cx| {
// Check for requests
if let Poll::Ready(request) = requests_fut.as_mut().poll(cx) {
return Poll::Ready(Some(Event::Request(request)));
}

// Check for completed waiters
if let Poll::Ready(n) = waiters_fut.as_mut().poll(cx) {
return Poll::Ready(Some(Event::Done(n.unwrap())));
}

Poll::Pending
}).await;

let event = match event {
Some(e) => e,
None => break, // This shouldn't happen with the current polling logic
};

match event {
Expand Down
72 changes: 62 additions & 10 deletions support/mesh/mesh_rpc/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ use futures::Stream;
use futures::StreamExt;
use futures::stream::FusedStream;
use futures_concurrency::future::TryJoin;
use futures_concurrency::future::Race;
use futures_concurrency::stream::Merge;
use std::task::Poll;
use mesh::CancelContext;
use mesh::MeshPayload;
use mesh::local_node::Port;
Expand Down Expand Up @@ -101,6 +103,12 @@ impl GenericRpc {
}
}

enum ServerEvent<T, E> {
Cancel,
TaskCompleted,
Connection(Result<T, E>),
}

impl Server {
/// Creates a new ttrpc server.
pub fn new() -> Self {
Expand All @@ -126,12 +134,34 @@ impl Server {
) -> anyhow::Result<()> {
let mut listener = PolledSocket::new(driver, listener)?;
let mut tasks = FuturesUnordered::new();
let mut cancel = cancel.fuse();
loop {
let conn = futures::select! { // merge semantics
r = listener.accept().fuse() => r,
_ = tasks.next() => continue,
_ = cancel => break,
let accept_fut = std::pin::pin!(listener.accept());
let tasks_next_fut = std::pin::pin!(tasks.next());
let cancel_fut = std::pin::pin!(cancel);

let result = std::future::poll_fn(|cx| {
// Check for cancellation first
if let Poll::Ready(_) = cancel_fut.as_mut().poll(cx) {
return Poll::Ready(ServerEvent::Cancel);
}

// Check for completed tasks
if let Poll::Ready(_) = tasks_next_fut.as_mut().poll(cx) {
return Poll::Ready(ServerEvent::TaskCompleted);
}

// Check for new connections
if let Poll::Ready(r) = accept_fut.as_mut().poll(cx) {
return Poll::Ready(ServerEvent::Connection(r));
}

Poll::Pending
}).await;

let conn = match result {
ServerEvent::Cancel => break,
ServerEvent::TaskCompleted => continue,
ServerEvent::Connection(r) => r,
};
if let Ok(conn) = conn.and_then(|(conn, _)| PolledSocket::new(driver, conn)) {
tasks.push(async {
Expand Down Expand Up @@ -342,12 +372,34 @@ mod grpc {
) -> anyhow::Result<()> {
let mut listener = PolledSocket::new(driver, listener)?;
let mut tasks = FuturesUnordered::new();
let mut cancel = cancel.fuse();
loop {
let conn = futures::select! { // merge semantics
r = listener.accept().fuse() => r,
_ = tasks.next() => continue,
_ = cancel => break,
let accept_fut = std::pin::pin!(listener.accept());
let tasks_next_fut = std::pin::pin!(tasks.next());
let cancel_fut = std::pin::pin!(cancel);

let result = std::future::poll_fn(|cx| {
// Check for cancellation first
if let Poll::Ready(_) = cancel_fut.as_mut().poll(cx) {
return Poll::Ready(ServerEvent::Cancel);
}

// Check for completed tasks
if let Poll::Ready(_) = tasks_next_fut.as_mut().poll(cx) {
return Poll::Ready(ServerEvent::TaskCompleted);
}

// Check for new connections
if let Poll::Ready(r) = accept_fut.as_mut().poll(cx) {
return Poll::Ready(ServerEvent::Connection(r));
}

Poll::Pending
}).await;

let conn = match result {
ServerEvent::Cancel => break,
ServerEvent::TaskCompleted => continue,
ServerEvent::Connection(r) => r,
};
if let Ok(conn) = conn.and_then(|(conn, _)| PolledSocket::new(driver, conn)) {
tasks.push(async {
Expand Down
40 changes: 33 additions & 7 deletions vm/devices/get/guest_emulation_device/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use core::mem::size_of;
use disk_backend::Disk;
use futures::FutureExt;
use futures::StreamExt;
use std::task::Poll;
use get_protocol::BatteryStatusFlags;
use get_protocol::BatteryStatusNotification;
use get_protocol::GspCleartextContent;
Expand Down Expand Up @@ -125,6 +126,11 @@ impl From<task_control::Cancelled> for Error {
}
}

enum PipeEvent {
Input(usize),
GuestRequest(GuestEmulationRequest),
}

/// Settings to enable in the guest.
#[derive(Debug, Clone, Inspect)]
pub struct GuestConfig {
Expand Down Expand Up @@ -474,17 +480,37 @@ impl<T: RingMem + Unpin> GedChannel<T> {
}
GedState::Ready => {
let mut message_buf = [0; get_protocol::MAX_MESSAGE_SIZE];
futures::select! { // merge semantics
pipe_input = self.channel.recv(&mut message_buf).fuse() => {
let bytes_read = pipe_input.map_err(Error::Vmbus)?;
let recv_fut = std::pin::pin!(self.channel.recv(&mut message_buf));
let mut guest_request_recv = std::pin::pin!(state.guest_request_recv.select_next_some());
let stop_fut = std::pin::pin!(stop);

let result = std::future::poll_fn(|cx| {
// Check for cancellation first
if let Poll::Ready(_) = stop_fut.as_mut().poll(cx) {
return Poll::Ready(Err(Error::Cancelled(task_control::Cancelled)));
}

// Check for pipe input
if let Poll::Ready(pipe_result) = recv_fut.as_mut().poll(cx) {
let bytes_read = pipe_result.map_err(Error::Vmbus)?;
return Poll::Ready(Ok(PipeEvent::Input(bytes_read)));
}

// Check for guest request
if let Poll::Ready(guest_request) = guest_request_recv.as_mut().poll(cx) {
return Poll::Ready(Ok(PipeEvent::GuestRequest(guest_request)));
}

Poll::Pending
}).await?;

match result {
PipeEvent::Input(bytes_read) => {
self.handle_pipe_input(&message_buf[..bytes_read], state).await?;
},
guest_request = state.guest_request_recv.select_next_some() => {
PipeEvent::GuestRequest(guest_request) => {
self.handle_guest_request_input(state, guest_request)?;
}
_ = stop.fuse() => {
return Err(Error::Cancelled(task_control::Cancelled));
}
}
}
GedState::SendingRestore { written } => {
Expand Down
Loading