diff --git a/clippy.toml b/clippy.toml
index 95ea56a5f1..7bd691c232 100644
--- a/clippy.toml
+++ b/clippy.toml
@@ -14,6 +14,8 @@ disallowed-macros = [
     { path = "futures::ready", reason = "use std::task::ready" },
     { path = "tracing::enabled", reason = "https://github.com/tokio-rs/tracing/issues/2519" },
     { path = "openhcl_boot::boot_logger::debug_log", reason = "only use in local debugging, use log! if you want a production log message"},
+    { path = "futures::select", reason = "use futures_concurrency instead for safer async patterns" },
+    { path = "futures::select_biased", reason = "use futures_concurrency instead for safer async patterns" },
 ]
 
 disallowed-methods = [
diff --git a/petri/pipette/src/agent.rs b/petri/pipette/src/agent.rs
index cb39e9e8b0..1d89382f2b 100644
--- a/petri/pipette/src/agent.rs
+++ b/petri/pipette/src/agent.rs
@@ -8,6 +8,7 @@
 use anyhow::Context;
 use futures::future::FutureExt;
 use futures_concurrency::future::RaceOk;
+use std::task::Poll;
 use mesh_remote::PointToPointMesh;
 use pal_async::DefaultDriver;
 use pal_async::socket::PolledSocket;
@@ -75,19 +76,37 @@ impl Agent {
     pub async fn run(mut self) -> anyhow::Result<()> {
         let mut tasks = FuturesUnordered::new();
         loop {
-            futures::select! {
-                req = self.request_recv.recv().fuse() => {
-                    match req {
-                        Ok(req) => {
-                            tasks.push(handle_request(&self.driver, req, self.diag_file_send.clone()));
-                        },
-                        Err(e) => {
-                            tracing::info!(?e, "request channel closed, shutting down");
-                            break;
-                        }
-                    }
-                }
-                _ = tasks.next() => {}
+            let mut recv_fut = std::pin::pin!(self.request_recv.recv());
+
+            // Wait for either a new request (Some) or a completed handler
+            // (None). The handler set is only polled while non-empty: an
+            // empty FuturesUnordered completes immediately, which would
+            // otherwise turn this loop into a busy wait.
+            let req = std::future::poll_fn(|cx| {
+                if let Poll::Ready(req) = recv_fut.as_mut().poll(cx) {
+                    return Poll::Ready(Some(req));
+                }
+                if !tasks.is_empty() && tasks.poll_next_unpin(cx).is_ready() {
+                    return Poll::Ready(None);
+                }
+                Poll::Pending
+            })
+            .await;
+
+            match req {
+                Some(Ok(req)) => {
+                    tasks.push(handle_request(&self.driver, req, self.diag_file_send.clone()));
+                }
+                Some(Err(e)) => {
+                    tracing::info!(?e, "request channel closed, shutting down");
+                    break;
+                }
+                // A handler finished; nothing more to do.
+                None => {}
             }
         }
         self.watch_send.send(());
diff --git a/support/inspect/src/lib.rs b/support/inspect/src/lib.rs
index cd1621eef3..0435d174ab 100644
--- a/support/inspect/src/lib.rs
+++ b/support/inspect/src/lib.rs
@@ -2176,6 +2176,7 @@ mod tests {
     use expect_test::Expect;
     use expect_test::expect;
     use futures::FutureExt;
+    use futures_concurrency::future::Race;
     use pal_async::DefaultDriver;
     use pal_async::async_test;
     use pal_async::timer::Instant;
@@ -2196,10 +2197,8 @@
         let deadline = Instant::now() + timeout;
         let mut result = InspectionBuilder::new(path).depth(depth).inspect(obj);
         let mut timer = PolledTimer::new(driver);
-        futures::select! { // race semantics
-            _ = result.resolve().fuse() => {}
-            _ = timer.sleep_until(deadline).fuse() => {}
-        };
+        // Race semantics: both sides output (), so race them directly.
+        (result.resolve(), timer.sleep_until(deadline)).race().await;
         result.results()
     }
diff --git a/support/mesh/mesh_process/src/lib.rs b/support/mesh/mesh_process/src/lib.rs
index c05bfe82e2..43f23889f1 100644
--- a/support/mesh/mesh_process/src/lib.rs
+++ b/support/mesh/mesh_process/src/lib.rs
@@ -14,6 +14,7 @@
 use futures::FutureExt;
 use futures::StreamExt;
 use futures::executor::block_on;
 use futures_concurrency::future::Race;
+use std::task::Poll;
 use inspect::Inspect;
 use inspect::SensitivityLevel;
 use mesh::MeshPayload;
@@ -514,10 +515,38 @@ impl MeshInner {
         }
 
         loop {
-            let event = futures::select! { // merge semantics
-                request = self.requests.select_next_some() => Event::Request(request),
-                n = self.waiters.select_next_some() => Event::Done(n.unwrap()),
-                complete => break,
-            };
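+            // Poll both streams directly; `select_next_some` must not be
+            // used with poll_fn, since it wakes itself once its stream
+            // terminates. Returning None when both streams are exhausted
+            // preserves the old `complete => break` semantics.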
+            let event = std::future::poll_fn(|cx| {
+                let mut all_done = true;
+                match self.requests.poll_next_unpin(cx) {
+                    Poll::Ready(Some(request)) => {
+                        return Poll::Ready(Some(Event::Request(request)));
+                    }
+                    Poll::Ready(None) => {}
+                    Poll::Pending => all_done = false,
+                }
+                match self.waiters.poll_next_unpin(cx) {
+                    Poll::Ready(Some(n)) => {
+                        return Poll::Ready(Some(Event::Done(n.unwrap())));
+                    }
+                    Poll::Ready(None) => {}
+                    Poll::Pending => all_done = false,
+                }
+                if all_done {
+                    // Both sources are exhausted; tell the loop to stop.
+                    Poll::Ready(None)
+                } else {
+                    Poll::Pending
+                }
+            })
+            .await;
+
+            let Some(event) = event else { break };
 
             match event {
diff --git a/support/mesh/mesh_rpc/src/server.rs b/support/mesh/mesh_rpc/src/server.rs
index fe795a9012..92fbd922f2 100644
--- a/support/mesh/mesh_rpc/src/server.rs
+++ b/support/mesh/mesh_rpc/src/server.rs
@@ -24,7 +24,8 @@
 use futures::Stream;
 use futures::StreamExt;
 use futures::stream::FusedStream;
 use futures_concurrency::future::TryJoin;
 use futures_concurrency::stream::Merge;
+use std::task::Poll;
 use mesh::CancelContext;
 use mesh::MeshPayload;
 use mesh::local_node::Port;
@@ -101,6 +102,14 @@ impl GenericRpc {
     }
 }
 
+/// The outcome of one turn of a listener accept loop.
+///
+/// Generic over the accept result so the same enum serves both the ttrpc
+/// and grpc servers below.
+enum ServerEvent<T> {
+    Cancel,
+    TaskCompleted,
+    Connection(T),
+}
+
 impl Server {
     /// Creates a new ttrpc server.
     pub fn new() -> Self {
@@ -126,12 +135,44 @@
     ) -> anyhow::Result<()> {
         let mut listener = PolledSocket::new(driver, listener)?;
         let mut tasks = FuturesUnordered::new();
-        let mut cancel = cancel.fuse();
+        // Pin the cancel future once, outside the loop, so that it is not
+        // moved (and invalidated) on every iteration.
+        let mut cancel = std::pin::pin!(cancel);
         loop {
-            let conn = futures::select! { // merge semantics
-                r = listener.accept().fuse() => r,
-                _ = tasks.next() => continue,
-                _ = cancel => break,
-            };
+            let event = {
+                let tasks_empty = tasks.is_empty();
+                let mut accept_fut = std::pin::pin!(listener.accept());
+                let mut tasks_next_fut = std::pin::pin!(tasks.next());
+                std::future::poll_fn(|cx| {
+                    // Check for cancellation first.
+                    if cancel.as_mut().poll(cx).is_ready() {
+                        return Poll::Ready(ServerEvent::Cancel);
+                    }
+                    // Drive connection tasks, but only while there are
+                    // some: polling an empty FuturesUnordered completes
+                    // immediately and would busy-loop.
+                    if !tasks_empty && tasks_next_fut.as_mut().poll(cx).is_ready() {
+                        return Poll::Ready(ServerEvent::TaskCompleted);
+                    }
+                    // Finally, check for new connections.
+                    if let Poll::Ready(r) = accept_fut.as_mut().poll(cx) {
+                        return Poll::Ready(ServerEvent::Connection(r));
+                    }
+                    Poll::Pending
+                })
+                .await
+            };
+
+            let conn = match event {
+                ServerEvent::Cancel => break,
+                ServerEvent::TaskCompleted => continue,
+                ServerEvent::Connection(r) => r,
+            };
             if let Ok(conn) = conn.and_then(|(conn, _)| PolledSocket::new(driver, conn)) {
                 tasks.push(async {
@@ -342,12 +383,46 @@
     ) -> anyhow::Result<()> {
         let mut listener = PolledSocket::new(driver, listener)?;
         let mut tasks = FuturesUnordered::new();
-        let mut cancel = cancel.fuse();
+        use std::task::Poll;
+
+        use super::ServerEvent;
+
+        let mut cancel = std::pin::pin!(cancel);
         loop {
-            let conn = futures::select! { // merge semantics
-                r = listener.accept().fuse() => r,
-                _ = tasks.next() => continue,
-                _ = cancel => break,
-            };
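+            // Same accept-loop pattern as the ttrpc listener above.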
+            let event = {
+                let tasks_empty = tasks.is_empty();
+                let mut accept_fut = std::pin::pin!(listener.accept());
+                let mut tasks_next_fut = std::pin::pin!(tasks.next());
+                std::future::poll_fn(|cx| {
+                    // Cancellation first, then task completions (only while
+                    // non-empty, to avoid busy-looping), then connections.
+                    if cancel.as_mut().poll(cx).is_ready() {
+                        return Poll::Ready(ServerEvent::Cancel);
+                    }
+                    if !tasks_empty && tasks_next_fut.as_mut().poll(cx).is_ready() {
+                        return Poll::Ready(ServerEvent::TaskCompleted);
+                    }
+                    if let Poll::Ready(r) = accept_fut.as_mut().poll(cx) {
+                        return Poll::Ready(ServerEvent::Connection(r));
+                    }
+                    Poll::Pending
+                })
+                .await
+            };
+
+            let conn = match event {
+                ServerEvent::Cancel => break,
+                ServerEvent::TaskCompleted => continue,
+                ServerEvent::Connection(r) => r,
+            };
             if let Ok(conn) = conn.and_then(|(conn, _)| PolledSocket::new(driver, conn)) {
                 tasks.push(async {
diff --git a/vm/devices/get/guest_emulation_device/src/lib.rs b/vm/devices/get/guest_emulation_device/src/lib.rs
index bd06568e7a..2c0fee30fe 100644
--- a/vm/devices/get/guest_emulation_device/src/lib.rs
+++ b/vm/devices/get/guest_emulation_device/src/lib.rs
@@ -21,6 +21,7 @@
 use core::mem::size_of;
 use disk_backend::Disk;
 use futures::FutureExt;
 use futures::StreamExt;
+use std::task::Poll;
 use get_protocol::BatteryStatusFlags;
 use get_protocol::BatteryStatusNotification;
 use get_protocol::GspCleartextContent;
@@ -125,6 +126,12 @@ impl From for Error {
     }
 }
 
+/// An event produced by one turn of the `GedState::Ready` loop.
+enum PipeEvent {
+    Input(usize),
+    GuestRequest(GuestEmulationRequest),
+}
+
 /// Settings to enable in the guest.
 #[derive(Debug, Clone, Inspect)]
 pub struct GuestConfig {
@@ -474,17 +481,49 @@ impl GedChannel {
                 }
                 GedState::Ready => {
                     let mut message_buf = [0; get_protocol::MAX_MESSAGE_SIZE];
-                    futures::select! { // merge semantics
-                        pipe_input = self.channel.recv(&mut message_buf).fuse() => {
-                            let bytes_read = pipe_input.map_err(Error::Vmbus)?;
-                            self.handle_pipe_input(&message_buf[..bytes_read], state).await?;
-                        },
-                        guest_request = state.guest_request_recv.select_next_some() => {
-                            self.handle_guest_request_input(state, guest_request)?;
-                        }
-                        _ = stop.fuse() => {
-                            return Err(Error::Cancelled(task_control::Cancelled));
-                        }
-                    }
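+                    // Wait for the next event. The pinned futures are
+                    // scoped to this block so their borrows are released
+                    // before the handlers below run.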
+                    let event = {
+                        let mut recv_fut =
+                            std::pin::pin!(self.channel.recv(&mut message_buf));
+                        let mut stop_fut = std::pin::pin!(stop);
+                        std::future::poll_fn(|cx| {
+                            // Check for cancellation first.
+                            if stop_fut.as_mut().poll(cx).is_ready() {
+                                return Poll::Ready(Err(Error::Cancelled(
+                                    task_control::Cancelled,
+                                )));
+                            }
+                            // Check for pipe input.
+                            if let Poll::Ready(pipe_result) = recv_fut.as_mut().poll(cx) {
+                                let bytes_read = pipe_result.map_err(Error::Vmbus)?;
+                                return Poll::Ready(Ok(PipeEvent::Input(bytes_read)));
+                            }
+                            // Check for a guest request. A terminated stream
+                            // is simply never ready, matching the skip
+                            // semantics of the old select!.
+                            if let Poll::Ready(Some(guest_request)) =
+                                state.guest_request_recv.poll_next_unpin(cx)
+                            {
+                                return Poll::Ready(Ok(PipeEvent::GuestRequest(
+                                    guest_request,
+                                )));
+                            }
+                            Poll::Pending
+                        })
+                        .await?
+                    };
+
+                    match event {
+                        PipeEvent::Input(bytes_read) => {
+                            self.handle_pipe_input(&message_buf[..bytes_read], state)
+                                .await?;
+                        }
+                        PipeEvent::GuestRequest(guest_request) => {
+                            self.handle_guest_request_input(state, guest_request)?;
+                        }
+                    }
                 }
                 GedState::SendingRestore { written } => {
diff --git a/vm/devices/net/netvsp/src/test.rs b/vm/devices/net/netvsp/src/test.rs
index f829955d7e..3d0f7d0bf8 100644
--- a/vm/devices/net/netvsp/src/test.rs
+++ b/vm/devices/net/netvsp/src/test.rs
@@ -16,6 +16,8 @@
 use futures::Future;
 use futures::FutureExt;
 use futures::StreamExt;
 use futures::TryFutureExt;
+use futures_concurrency::future::Join;
+use std::task::Poll;
 use guestmem::MemoryRead;
 use guestmem::MemoryWrite;
 use guestmem::ranges::PagedRanges;
@@ -710,49 +712,58 @@ impl TestNicDevice {
             .with_timeout(Duration::from_millis(1000))
             .until_cancelled(async {
-                let restore = std::pin::pin!(self.channel.restore(buffer));
-                let mut restore = restore.fuse();
-                loop {
-                    futures::select! {
-                        result = restore => break result,
-                        request = self.offer_input.server_request_recv.select_next_some() => {
-                            match request {
-                                vmbus_channel::bus::ChannelServerRequest::Restore(rpc) => {
-                                    let gpadls = gpadl_map_contents.iter().map(|(gpadl_id, pages)| {
-                                        let pages = pages.clone();
-                                        vmbus_channel::bus::RestoredGpadl {
-                                            request: GpadlRequest {
-                                                id: *gpadl_id,
-                                                count: 1,
-                                                buf: pages.into_buffer(),
-                                            },
-                                            accepted: true,
-                                        }
-                                    }).collect::<Vec<_>>();
-                                    rpc.handle_sync(|open| {
-                                        guest_to_host_interrupt = open.map(|open| open.guest_to_host_interrupt);
-                                        Ok(vmbus_channel::bus::RestoreResult {
-                                            open_request: Some(OpenRequest {
-                                                open_data: OpenData {
-                                                    target_vp: 0,
-                                                    ring_offset: 2,
-                                                    ring_gpadl_id,
-                                                    event_flag: 1,
-                                                    connection_id: 1,
-                                                    user_data: UserDefinedData::new_zeroed(),
-                                                },
-                                                interrupt: host_to_guest_interrupt.clone(),
-                                                use_confidential_external_memory: false,
-                                                use_confidential_ring: false,
-                                            }),
-                                            gpadls,
-                                        })
-                                    })
-                                }
-                                vmbus_channel::bus::ChannelServerRequest::Revoke(_) => (),
-                            }
-                        }
-                    }
-                }
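+                // Poll the restore and the server-request stream together;
+                // requests must be serviced for the restore to make
+                // progress.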
+                let mut restore = std::pin::pin!(self.channel.restore(buffer));
+                std::future::poll_fn(|cx| {
+                    // Complete once the restore itself finishes.
+                    if let Poll::Ready(result) = restore.as_mut().poll(cx) {
+                        return Poll::Ready(result);
+                    }
+
+                    // Service any incoming channel server requests. A
+                    // closed stream simply stops producing items.
+                    while let Poll::Ready(Some(request)) =
+                        self.offer_input.server_request_recv.poll_next_unpin(cx)
+                    {
+                        match request {
+                            vmbus_channel::bus::ChannelServerRequest::Restore(rpc) => {
+                                let gpadls = gpadl_map_contents.iter().map(|(gpadl_id, pages)| {
+                                    let pages = pages.clone();
+                                    vmbus_channel::bus::RestoredGpadl {
+                                        request: GpadlRequest {
+                                            id: *gpadl_id,
+                                            count: 1,
+                                            buf: pages.into_buffer(),
+                                        },
+                                        accepted: true,
+                                    }
+                                }).collect::<Vec<_>>();
+                                rpc.handle_sync(|open| {
+                                    guest_to_host_interrupt = open.map(|open| open.guest_to_host_interrupt);
+                                    Ok(vmbus_channel::bus::RestoreResult {
+                                        open_request: Some(OpenRequest {
+                                            open_data: OpenData {
+                                                target_vp: 0,
+                                                ring_offset: 2,
+                                                ring_gpadl_id,
+                                                event_flag: 1,
+                                                connection_id: 1,
+                                                user_data: UserDefinedData::new_zeroed(),
+                                            },
+                                            interrupt: host_to_guest_interrupt.clone(),
+                                            use_confidential_external_memory: false,
+                                            use_confidential_ring: false,
+                                        }),
+                                        gpadls,
+                                    })
+                                })
+                            }
+                            vmbus_channel::bus::ChannelServerRequest::Revoke(_) => (),
+                        }
+                    }
+
+                    Poll::Pending
+                })
+                .await
             })
             .await
             .unwrap()?;
@@ -3123,17 +3134,10 @@ async fn remove_vf_with_async_messages(
             .expect("completion message");
     };
 
-    let eject_vf = std::pin::pin!(eject_vf);
-    let mut fused_eject_vf = eject_vf.fuse();
-    // Remove VF
-    let update_id = std::pin::pin!(test_vf_state.update_id(None, None));
-    let mut fused_update_id = update_id.fuse();
-    // futures_concurrency::future::try_join seems promising, but unable to get it to work here
-    loop {
-        futures::select! {
-            _ = fused_eject_vf => {}
-            result = fused_update_id => result?,
-            complete => break,
-        }
-    }
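+    // Remove VF: run the ejection and the ID update concurrently and wait
+    // for both to finish, as the old `complete => break` loop did. `join`
+    // (rather than `try_join`) fits here because `eject_vf` is infallible.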
+    let ((), result) = (eject_vf, test_vf_state.update_id(None, None)).join().await;
+    result?;
     Ok(())
diff --git a/vm/devices/storage/storvsp/fuzz/fuzz_storvsp.rs b/vm/devices/storage/storvsp/fuzz/fuzz_storvsp.rs
index dd611e9b47..08ab84d0f0 100644
--- a/vm/devices/storage/storvsp/fuzz/fuzz_storvsp.rs
+++ b/vm/devices/storage/storvsp/fuzz/fuzz_storvsp.rs
@@ -7,7 +7,7 @@
 use arbitrary::Arbitrary;
 use arbitrary::Unstructured;
 use futures::FutureExt;
-use futures::select;
+use futures_concurrency::future::Race;
 use guestmem::GuestMemory;
 use guestmem::ranges::PagedRange;
 use pal_async::DefaultPool;
@@ -227,12 +227,22 @@ fn do_fuzz(u: &mut Unstructured<'_>) -> Result<(), anyhow::Error> {
             transaction_id: 0,
         };
 
-        let mut fuzz_loop = pin!(do_fuzz_loop(u, &mut guest).fuse());
-        let mut teardown = pin!(test_worker.teardown_ignore().fuse());
-
-        select! {
-            _r1 = fuzz_loop => xtask_fuzz::fuzz_eprintln!("test case exhausted arbitrary data"),
-            _r2 = teardown => xtask_fuzz::fuzz_eprintln!("test worker completed"),
-        }
+        // The two futures have different outputs, so map each side into a
+        // common enum before racing them.
+        enum Done {
+            Exhausted,
+            WorkerDone,
+        }
+        match (
+            do_fuzz_loop(u, &mut guest).map(|_| Done::Exhausted),
+            test_worker.teardown_ignore().map(|_| Done::WorkerDone),
+        )
+            .race()
+            .await
+        {
+            Done::Exhausted => xtask_fuzz::fuzz_eprintln!("test case exhausted arbitrary data"),
+            Done::WorkerDone => xtask_fuzz::fuzz_eprintln!("test worker completed"),
+        }
 
         Ok::<(), anyhow::Error>(())
diff --git a/vm/devices/storage/storvsp/src/lib.rs b/vm/devices/storage/storvsp/src/lib.rs
index 87004a9686..7bf78b658f 100644
--- a/vm/devices/storage/storvsp/src/lib.rs
+++ b/vm/devices/storage/storvsp/src/lib.rs
@@ -23,7 +23,7 @@
 use async_trait::async_trait;
 use fast_select::FastSelect;
 use futures::FutureExt;
 use futures::StreamExt;
-use futures::select_biased;
+use futures_concurrency::future::Race;
 use guestmem::AccessError;
 use guestmem::GuestMemory;
 use guestmem::MemoryRead;
@@ -873,9 +873,27 @@ impl Worker {
             match current_state {
                 ProtocolState::Ready { version, .. } => {
                     break loop {
-                        select_biased! {
-                            r = self.inner.process_ready(&mut self.queue, version).fuse() => break r,
-                            _ = self.fast_select.select((self.rescan_notification.select_next_some(),)).fuse() => {
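+                        // process_ready and the rescan notification have
+                        // different output types, so map each side into a
+                        // small event enum before racing them. (Unlike
+                        // select_biased!, race polls in random order; both
+                        // arms remain handled either way.)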
+                        enum ReadyEvent<R> {
+                            Process(R),
+                            Rescan,
+                        }
+                        let event = (
+                            self.inner
+                                .process_ready(&mut self.queue, version)
+                                .map(ReadyEvent::Process),
+                            self.fast_select
+                                .select((self.rescan_notification.select_next_some(),))
+                                .map(|_| ReadyEvent::Rescan),
+                        )
+                            .race()
+                            .await;
+                        match event {
+                            ReadyEvent::Process(r) => break r,
+                            ReadyEvent::Rescan => {
                                 if version >= Version::Win7 {
                                     self.inner.send_packet(&mut self.queue.split().1, storvsp_protocol::Operation::ENUMERATE_BUS, storvsp_protocol::NtStatus::SUCCESS, &())?;
diff --git a/vm/devices/virtio/virtio/src/common.rs b/vm/devices/virtio/virtio/src/common.rs
index 97271fc162..315bd44de5 100644
--- a/vm/devices/virtio/virtio/src/common.rs
+++ b/vm/devices/virtio/virtio/src/common.rs
@@ -9,6 +9,7 @@
 use async_trait::async_trait;
 use futures::FutureExt;
 use futures::Stream;
 use futures::StreamExt;
+use futures_concurrency::future::Race;
 use guestmem::DoorbellRegistration;
 use guestmem::GuestMemory;
 use guestmem::GuestMemoryError;
@@ -400,11 +401,25 @@ impl VirtioQueueWorker {
                 }
             }
             VirtioQueueStateInner::Running { queue, exit_event } => {
-                let mut exit = exit_event.fuse();
-                let mut queue_ready = queue.next().fuse();
-                let work = futures::select_biased! {
-                    _ = exit => return false,
-                    work = queue_ready => work.expect("queue will never complete").map_err(anyhow::Error::from),
-                };
+                // The exit signal and queue work have different output
+                // types; map each into an event before racing.
+                enum QueueEvent<W> {
+                    Exit,
+                    Work(W),
+                }
+                let event = (
+                    exit_event.map(|_| QueueEvent::Exit),
+                    queue.next().map(QueueEvent::Work),
+                )
+                    .race()
+                    .await;
+                let work = match event {
+                    QueueEvent::Exit => return false,
+                    QueueEvent::Work(work) => work
+                        .expect("queue will never complete")
+                        .map_err(anyhow::Error::from),
+                };
                 self.context.process_work(work).await
diff --git a/vm/vmcore/src/vmtime.rs b/vm/vmcore/src/vmtime.rs
index 464e8d53ff..869e84fc37 100644
--- a/vm/vmcore/src/vmtime.rs
+++ b/vm/vmcore/src/vmtime.rs
@@ -1000,17 +1000,31 @@ mod tests {
         // Test long timeout.
         access.set_timeout(access.now().wrapping_add(Duration::from_secs(1000)));
         let mut timer = PolledTimer::new(&driver);
-        futures::select! {
-            _ = timer.sleep(Duration::from_millis(50)).fuse() => {}
-            _ = poll_fn(|cx| access.poll_timeout(cx)).fuse() => panic!("unexpected wait completion"),
-        }
+        // An anonymous trait import keeps `race` available without touching
+        // the module imports; repeated `as _` imports do not collide.
+        use futures_concurrency::future::Race as _;
+
+        // The two sides have different outputs, so map the sleep to None
+        // and the timeout to Some before racing.
+        let res = (
+            timer.sleep(Duration::from_millis(50)).map(|_| None),
+            poll_fn(|cx| access.poll_timeout(cx)).map(Some),
+        )
+            .race()
+            .await;
+        if res.is_some() {
+            panic!("unexpected wait completion");
+        }
 
         // Test short timeout.
         let deadline = access.now().wrapping_add(Duration::from_millis(10));
         access.set_timeout(deadline);
-        futures::select! {
-            _ = timer.sleep(Duration::from_millis(1000)).fuse() => panic!("unexpected timeout"),
-            now = poll_fn(|cx| access.poll_timeout(cx)).fuse() => {
-                assert!(now.is_after(deadline));
-            }
-        }
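+        // Same race-with-mapping pattern; here the timeout should win.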
+        let now = (
+            timer.sleep(Duration::from_millis(1000)).map(|_| None),
+            poll_fn(|cx| access.poll_timeout(cx)).map(Some),
+        )
+            .race()
+            .await
+            .expect("unexpected timeout");
+        assert!(now.is_after(deadline));
@@ -1025,18 +1039,34 @@
         let now = access.now();
         let deadline = now.wrapping_add(Duration::from_millis(2000));
         access.set_timeout(deadline);
-        futures::select! {
-            _ = timer.sleep(Duration::from_millis(30)).fuse() => {
+        use futures_concurrency::future::Race as _;
+
+        let res = (
+            timer.sleep(Duration::from_millis(30)).map(|_| None),
+            poll_fn(|cx| access.poll_timeout(cx)).map(Some),
+        )
+            .race()
+            .await;
+        match res {
+            // The sleep won: set a short timeout and expect it to fire.
+            None => {
                 let deadline = now.wrapping_add(Duration::from_millis(50));
                 access.set_timeout(deadline);
-                futures::select! {
-                    _ = timer.sleep(Duration::from_millis(1000)).fuse() => panic!("unexpected timeout"),
-                    now = poll_fn(|cx| access.poll_timeout(cx)).fuse() => {
-                        assert!(now.is_after(deadline));
-                    }
-                }
+                let now = (
+                    timer.sleep(Duration::from_millis(1000)).map(|_| None),
+                    poll_fn(|cx| access.poll_timeout(cx)).map(Some),
+                )
+                    .race()
+                    .await
+                    .expect("unexpected timeout");
+                assert!(now.is_after(deadline));
             }
-            _ = poll_fn(|cx| access.poll_timeout(cx)).fuse() => panic!("unexpected wait completion"),
+            Some(_) => panic!("unexpected wait completion"),
         }
         keeper.stop().await;
     }
diff --git a/workers/debug_worker/src/lib.rs b/workers/debug_worker/src/lib.rs
index acc4965461..4a88262a6a 100644
--- a/workers/debug_worker/src/lib.rs
+++ b/workers/debug_worker/src/lib.rs
@@ -17,6 +17,8 @@
 use debug_worker_defs::DEBUGGER_WORKER;
 use debug_worker_defs::DebuggerParameters;
 use futures::AsyncReadExt;
 use futures::FutureExt;
+use futures_concurrency::future::Race;
+use std::task::Poll;
 use gdb::VmProxy;
 use gdb::targets::TargetArch;
 use gdb::targets::VmTarget;
@@ -59,6 +61,14 @@ enum State {
     Invalid,
 }
 
+/// The outcome of one turn of the worker loop: either the GDB server ran
+/// to completion, or an RPC arrived (`R` is the receive result type).
+enum ServerEvent<R> {
+    ServerCompleted(anyhow::Result<()>),
+    RpcReceived(R),
+}
+
 trait GdbListener: 'static + Send + Listener + Sized + MeshField {
     const ID: WorkerId<DebuggerParameters<Self>>;
 }
@@ -113,49 +123,52 @@
         };
 
         loop {
-            let r = futures::select! { // merge semantics
-                r = rpc_recv.recv().fuse() => r,
-                r = server.process(&driver).fuse() => {
-                    r?;
-                    return Ok(())
-                },
-            };
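+            // Poll the GDB server and the control channel together, then
+            // dispatch on whichever finished first. The pinned futures are
+            // scoped so `server` is borrowable again in the arms below.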
+            let r = {
+                let mut recv_fut = std::pin::pin!(rpc_recv.recv());
+                let mut server_fut = std::pin::pin!(server.process(&driver));
+                std::future::poll_fn(|cx| {
+                    if let Poll::Ready(r) = server_fut.as_mut().poll(cx) {
+                        return Poll::Ready(ServerEvent::ServerCompleted(r));
+                    }
+                    if let Poll::Ready(r) = recv_fut.as_mut().poll(cx) {
+                        return Poll::Ready(ServerEvent::RpcReceived(r));
+                    }
+                    Poll::Pending
+                })
+                .await
+            };
+            let r = match r {
+                ServerEvent::ServerCompleted(r) => {
+                    r?;
+                    return Ok(());
+                }
+                ServerEvent::RpcReceived(r) => r,
+            };
             match r {
                 Ok(message) => match message {
                     WorkerRpc::Stop => return Ok(()),
                     WorkerRpc::Inspect(deferred) => deferred.inspect(&mut server),
                     WorkerRpc::Restart(rpc) => {
                         let vm_proxy = match server.state {
                             State::Listening { vm_proxy } => vm_proxy,
                             State::Connected { task, abort, .. } => {
                                 drop(abort);
                                 task.await
                             }
                             State::Invalid => unreachable!(),
                         };
 
                         let state = {
                             let (req_chan, vp_count) = vm_proxy.into_params();
                             DebuggerParameters {
                                 listener: server.listener.into_inner(),
                                 req_chan,
                                 vp_count,
                                 target_arch: match server.architecture {
                                     Architecture::X86_64 => {
                                         debug_worker_defs::TargetArch::X86_64
                                     }
                                     Architecture::I8086 => debug_worker_defs::TargetArch::I8086,
                                     Architecture::Aarch64 => {
                                         debug_worker_defs::TargetArch::Aarch64
                                     }
                                 },
                             }
                         };
                         rpc.complete(Ok(state));
                         return Ok(());
                     }
                 },
                 Err(_) => return Ok(()),
             }
         }
     })
@@ -232,9 +245,14 @@
                 }
             };
 
-            let res = futures::select! { // race semantics
-                gdb_res = state_machine_fut.fuse() => Some(gdb_res),
-                _ = abort_recv.fuse() => None,
-            };
+            // Race the state machine against the abort signal. Mapping
+            // both sides to an Option gives them a common output type.
+            let res = (
+                state_machine_fut.map(Some),
+                abort_recv.map(|_| None),
+            )
+                .race()
+                .await;
 
             match res {
@@ -336,12 +354,29 @@ async fn run_state_machine(
                 let mut b = [0];
                 let incoming_data = gdb.borrow_conn().0.read_exact(&mut b);
 
-                let event = futures::select! { // race semantics
-                    r = stop_chan.fuse() => {
-                        let reason = r.map_err(|e| GdbStubError::TargetError(e.into()))?;
-                        Event::HaltReason(reason)
-                    },
-                    _ = incoming_data.fuse() => Event::IncomingData(b[0]),
-                };
+                // Race the stop channel against incoming socket data. `b`
+                // is read only after the race has dropped the read future
+                // and released its borrow.
+                enum Raced<T> {
+                    Halt(T),
+                    Data,
+                }
+                let raced = (
+                    stop_chan.map(Raced::Halt),
+                    incoming_data.map(|_| Raced::Data),
+                )
+                    .race()
+                    .await;
+                let event = match raced {
+                    Raced::Halt(r) => {
+                        let reason = r.map_err(|e| GdbStubError::TargetError(e.into()))?;
+                        Event::HaltReason(reason)
+                    }
+                    Raced::Data => Event::IncomingData(b[0]),
+                };
 
                 match event {
diff --git a/workers/vnc_worker/src/lib.rs b/workers/vnc_worker/src/lib.rs
index cf41383816..07bfcfdf64 100644
--- a/workers/vnc_worker/src/lib.rs
+++ b/workers/vnc_worker/src/lib.rs
@@ -8,6 +8,8 @@
 use anyhow::Context;
 use anyhow::anyhow;
 use futures::FutureExt;
+use futures_concurrency::future::Race;
+use std::task::Poll;
 use input_core::InputData;
 use input_core::KeyboardData;
 use input_core::MouseData;
@@ -47,6 +49,13 @@ enum State {
     Invalid,
 }
 
+/// The outcome of one turn of the worker loop: either the VNC server ran
+/// to completion, or an RPC arrived (`R` is the receive result type).
+enum VncServerEvent<R> {
+    ServerCompleted(anyhow::Result<()>),
+    RpcReceived(R),
+}
+
 impl Worker for VncWorker {
     type Parameters = VncParameters;
     type State = VncParameters;
@@ -119,17 +128,39 @@ impl VncWorker {
         };
 
         let rpc = loop {
-            let r = futures::select! { // merge semantics
-                r = rpc_recv.recv().fuse() => r,
-                r = server.process(&driver).fuse() => break r.map(|_| None)?,
-            };
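+            // Poll the VNC server and the control channel together, as in
+            // the debug worker above.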
+            let r = {
+                let mut recv_fut = std::pin::pin!(rpc_recv.recv());
+                let mut server_fut = std::pin::pin!(server.process(&driver));
+                std::future::poll_fn(|cx| {
+                    if let Poll::Ready(r) = server_fut.as_mut().poll(cx) {
+                        return Poll::Ready(VncServerEvent::ServerCompleted(r));
+                    }
+                    if let Poll::Ready(r) = recv_fut.as_mut().poll(cx) {
+                        return Poll::Ready(VncServerEvent::RpcReceived(r));
+                    }
+                    Poll::Pending
+                })
+                .await
+            };
+            let r = match r {
+                VncServerEvent::ServerCompleted(r) => break r.map(|_| None)?,
+                VncServerEvent::RpcReceived(r) => r,
+            };
             match r {
                 Ok(message) => match message {
                     WorkerRpc::Stop => break None,
                     WorkerRpc::Inspect(deferred) => deferred.inspect(&server),
                     WorkerRpc::Restart(response) => break Some(response),
                 },
                 Err(_) => break None,
             }
         };
         if let Some(rpc) = rpc {
@@ -196,10 +227,17 @@ impl Server {
                     updater.update();
                 }
             };
-            let r = futures::select! { // race semantics
-                r = vncserver.run().fuse() => r.context("VNC error"),
-                _ = abort_recv.fuse() => Err(anyhow!("VNC connection aborted")),
-                _ = update_task.fuse() => unreachable!(),
-            };
+            // All three arms produce the same Result type, so the race
+            // output can be used directly.
+            let r = (
+                vncserver.run().map(|r| r.context("VNC error")),
+                abort_recv.map(|_| Err(anyhow!("VNC connection aborted"))),
+                update_task.map(|_| unreachable!("update task never completes")),
+            )
+                .race()
+                .await;
             match r {
                 Ok(_) => {
diff --git a/workers/vnc_worker/vnc/src/lib.rs b/workers/vnc_worker/vnc/src/lib.rs
index 0a4ced7cb8..a86015be1b 100644
--- a/workers/vnc_worker/vnc/src/lib.rs
+++ b/workers/vnc_worker/vnc/src/lib.rs
@@ -12,7 +12,8 @@
 use futures::AsyncReadExt;
 use futures::AsyncWriteExt;
 use futures::FutureExt;
 use futures::StreamExt;
 use futures::channel::mpsc;
 use futures::future::OptionFuture;
 use pal_async::socket::PolledSocket;
+use std::task::Poll;
@@ -173,16 +174,42 @@ impl Server {
             let mut update_ready = false;
             let mut message_type = 0u8;
             let update_recv = &mut self.update_recv;
-            let mut update: OptionFuture<_> = ready_for_update
+            let update: OptionFuture<_> = ready_for_update
                 .then(|| update_recv.select_next_some())
                 .into();
-            futures::select! { // merge semantics
-                _ = update => update_ready = true,
-                r = socket.read(message_type.as_mut_bytes()).fuse() => {
-                    if r? == 0 {
-                        return Ok(())
-                    }
-                    socket_ready = true;
-                }
-            }
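+            // Poll the optional update future and the socket read
+            // together, preferring a ready framebuffer update.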
+            let mut update_fut = std::pin::pin!(update);
+            let mut socket_read_fut =
+                std::pin::pin!(socket.read(message_type.as_mut_bytes()));
+
+            let read_result = std::future::poll_fn(|cx| {
+                // The OptionFuture is None when updates are not currently
+                // permitted, so only poll it when one may be sent.
+                if ready_for_update {
+                    if let Poll::Ready(Some(_)) = update_fut.as_mut().poll(cx) {
+                        update_ready = true;
+                        return Poll::Ready(None);
+                    }
+                }
+
+                // Otherwise wait for the next message from the client.
+                if let Poll::Ready(r) = socket_read_fut.as_mut().poll(cx) {
+                    socket_ready = true;
+                    return Poll::Ready(Some(r));
+                }
+
+                Poll::Pending
+            })
+            .await;
+
+            // Handle the socket read result, if any.
+            if let Some(r) = read_result {
+                if r? == 0 {
+                    return Ok(());
+                }
+            }