From 0c5d1832ae443e93589d85e4ddb5c516bd7d0cec Mon Sep 17 00:00:00 2001 From: Floris Bruynooghe Date: Thu, 30 Mar 2023 10:32:08 +0200 Subject: [PATCH 1/4] ref: make Context::alloc_ongoing return a guard This guard can be awaited and will resolve when Context::stop_ongoing is called, i.e. the ongoing process is cancelled. The guard will also free the ongoing process when dropped, making it RAII and easier to not accidentally free the ongoing process. --- src/configure.rs | 2 +- src/context.rs | 152 +++++++++++++++++++++++++++++++++---------- src/imex.rs | 2 +- src/imex/transfer.rs | 13 ++-- 4 files changed, 125 insertions(+), 44 deletions(-) diff --git a/src/configure.rs b/src/configure.rs index eb69ad25cd..d6589975a5 100644 --- a/src/configure.rs +++ b/src/configure.rs @@ -70,7 +70,7 @@ impl Context { let res = self .inner_configure() - .race(cancel_channel.recv().map(|_| { + .race(cancel_channel.map(|_| { progress!(self, 0); Ok(()) })) diff --git a/src/context.rs b/src/context.rs index a9bbd01ad1..8835aeb519 100644 --- a/src/context.rs +++ b/src/context.rs @@ -2,16 +2,19 @@ use std::collections::{BTreeMap, HashMap}; use std::ffi::OsString; +use std::future::Future; use std::ops::Deref; use std::path::{Path, PathBuf}; +use std::pin::Pin; use std::sync::atomic::AtomicBool; use std::sync::Arc; +use std::task::Poll; use std::time::{Duration, Instant, SystemTime}; use anyhow::{bail, ensure, Context as _, Result}; -use async_channel::{self as channel, Receiver, Sender}; +use async_channel::Sender; use ratelimit::Ratelimit; -use tokio::sync::{Mutex, RwLock}; +use tokio::sync::{oneshot, Mutex, RwLock}; use tokio::task; use crate::chat::{get_chat_cnt, ChatId}; @@ -257,7 +260,7 @@ pub(crate) struct DebugLogging { #[derive(Debug)] enum RunningState { /// Ongoing process is allocated. - Running { cancel_sender: Sender<()> }, + Running { cancel_sender: oneshot::Sender<()> }, /// Cancel signal has been sent, waiting for ongoing process to be freed. 
ShallStop, @@ -511,21 +514,35 @@ impl Context { /// This is for modal operations during which no other user actions are allowed. Only /// one such operation is allowed at any given time. /// - /// The return value is a cancel token, which will release the ongoing mutex when - /// dropped. - pub(crate) async fn alloc_ongoing(&self) -> Result<Receiver<()>> { + /// The return value is a guard which does two things: + /// + /// - It is a Future which will complete when the ongoing process is cancelled using + /// [`Context::stop_ongoing`] and must stop. + /// - It will free the ongoing process, aka release the mutex, when dropped. + pub(crate) async fn alloc_ongoing(&self) -> Result<OngoingGuard> { let mut s = self.running_state.write().await; ensure!( matches!(*s, RunningState::Stopped), "There is already another ongoing process running." ); - let (sender, receiver) = channel::bounded(1); + let (cancel_tx, cancel_rx) = oneshot::channel(); *s = RunningState::Running { - cancel_sender: sender, + cancel_sender: cancel_tx, }; + let (drop_tx, drop_rx) = oneshot::channel(); + let context = self.clone(); + + tokio::spawn(async move { + drop_rx.await.ok(); + let mut s = context.running_state.write().await; + *s = RunningState::Stopped; + }); - Ok(receiver) + Ok(OngoingGuard { + cancel_rx, + drop_tx: Some(drop_tx), + }) } pub(crate) async fn free_ongoing(&self) { @@ -536,21 +553,24 @@ impl Context { /// Signal an ongoing process to stop. pub async fn stop_ongoing(&self) { let mut s = self.running_state.write().await; - match &*s { - RunningState::Running { cancel_sender } => { - if let Err(err) = cancel_sender.send(()).await { - warn!(self, "could not cancel ongoing: {:#}", err); - } - info!(self, "Signaling the ongoing process to stop ASAP.",); - *s = RunningState::ShallStop; - } + + // Take out the state so we can call the oneshot sender (which takes ownership). 
+ let current_state = std::mem::replace(&mut *s, RunningState::ShallStop); + + match current_state { + RunningState::Running { cancel_sender } => match cancel_sender.send(()) { + Ok(()) => info!(self, "Signaling the ongoing process to stop ASAP."), + Err(()) => warn!(self, "could not cancel ongoing"), + }, RunningState::ShallStop | RunningState::Stopped => { + // Put back the current state + *s = current_state; info!(self, "No ongoing process to stop.",); } } } - #[allow(unused)] + #[cfg(test)] pub(crate) async fn shall_stop_ongoing(&self) -> bool { match &*self.running_state.read().await { RunningState::Running { .. } => false, @@ -945,6 +965,54 @@ impl Context { } } +/// Guard received when calling [`Context::alloc_ongoing`]. +/// +/// While holding this guard the ongoing mutex is held; dropping this guard frees the +/// ongoing process. +/// +/// The ongoing process can also be cancelled by unrelated code calling +/// [`Context::stop_ongoing`]. This guard implements [`Future`] and the future will +/// complete when the ongoing process is cancelled and must be aborted. Freeing the ongoing +/// process works as usual in this case: when this guard is dropped. So if you need to do +/// some more work before freeing make sure to keep ownership of the guard, e.g.: +/// +/// ```no_compile +/// let mut guard = context.alloc_ongoing().await?; +/// tokio::select!{ +/// biased; +/// _ = &mut guard => (), // guard is not moved, so we keep ownership. +/// _ = do_work() => (), +/// }; +/// do_cleanup().await; +/// drop(guard); +/// ``` +pub(crate) struct OngoingGuard { + /// Receives a message when the ongoing process should be cancelled. + cancel_rx: oneshot::Receiver<()>, + /// Used by `Drop` to send a message which will free the ongoing process. 
+ drop_tx: Option<oneshot::Sender<()>>, +} + +impl Future for OngoingGuard { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> { + match Pin::new(&mut self.cancel_rx).poll(cx) { + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => Poll::Pending, + } + } +} + +impl Drop for OngoingGuard { + fn drop(&mut self) { + if let Some(sender) = self.drop_tx.take() { + // TODO: Maybe this should log? But we'd need to have a context. + sender.send(()).ok(); + } + } +} + /// Returns core version as a string. pub fn get_version_str() -> &'static str { &DC_VERSION_STR @@ -1409,38 +1477,52 @@ mod tests { async fn test_ongoing() -> Result<()> { let context = TestContext::new().await; - // No ongoing process allocated. + println!("No ongoing process allocated."); assert!(context.shall_stop_ongoing().await); - let receiver = context.alloc_ongoing().await?; + let mut guard = context.alloc_ongoing().await?; - // Cannot allocate another ongoing process while the first one is running. + println!("Cannot allocate another ongoing process while the first one is running."); assert!(context.alloc_ongoing().await.is_err()); - // Stop signal is not sent yet. - assert!(receiver.try_recv().is_err()); + println!("Stop signal is not sent yet."); + assert!(matches!(futures::poll!(&mut guard), Poll::Pending)); assert!(!context.shall_stop_ongoing().await); - // Send the stop signal. + println!("Send the stop signal."); context.stop_ongoing().await; - // Receive stop signal. - receiver.recv().await?; + println!("Receive stop signal."); + (&mut guard).await; assert!(context.shall_stop_ongoing().await); - // Ongoing process is still running even though stop signal was received, - // so another one cannot be allocated. + println!("Ongoing process still running even though stop signal was received"); assert!(context.alloc_ongoing().await.is_err()); - context.free_ongoing().await; - - // No ongoing process allocated, should have been stopped already. 
- assert!(context.shall_stop_ongoing().await); - - // Another ongoing process can be allocated now. - let _receiver = context.alloc_ongoing().await?; + println!("free the ongoing process"); + // context.free_ongoing().await; + drop(guard); + + println!("re-acquire the ongoing process"); + // Since the drop guard needs to send a message and the receiving task must run and + // acquire a lock this needs some time so won't succeed immediately. + #[allow(clippy::async_yields_async)] + let _guard = tokio::time::timeout(Duration::from_secs(10), async { + loop { + match context.alloc_ongoing().await { + Ok(guard) => break guard, + Err(_) => { + // tokio::task::yield_now() results in a lot hotter loop, it takes a + // lot of yields. + tokio::time::sleep(Duration::from_millis(1)).await; + } + } + } + }) + .await + .expect("timeout"); Ok(()) } diff --git a/src/imex.rs b/src/imex.rs index 885187cae9..0dc519d52a 100644 --- a/src/imex.rs +++ b/src/imex.rs @@ -94,7 +94,7 @@ pub async fn imex( let _guard = context.scheduler.pause(context.clone()).await; imex_inner(context, what, path, passphrase) .race(async { - cancel.recv().await.ok(); + cancel.await; Err(format_err!("canceled")) }) .await diff --git a/src/imex/transfer.rs b/src/imex/transfer.rs index d93cf4f6d6..3ef7a16fec 100644 --- a/src/imex/transfer.rs +++ b/src/imex/transfer.rs @@ -31,7 +31,6 @@ use std::pin::Pin; use std::task::Poll; use anyhow::{anyhow, bail, ensure, format_err, Context as _, Result}; -use async_channel::Receiver; use futures_lite::StreamExt; use iroh::get::{DataStream, Options}; use iroh::progress::ProgressEmitter; @@ -47,7 +46,7 @@ use tokio_stream::wrappers::ReadDirStream; use crate::blob::BlobDirContents; use crate::chat::delete_and_reset_all_device_msgs; -use crate::context::Context; +use crate::context::{Context, OngoingGuard}; use crate::qr::Qr; use crate::{e2ee, EventType}; @@ -91,7 +90,7 @@ impl BackupProvider { .context("Private key not available, aborting backup export")?; // Acquire global 
"ongoing" mutex. - let cancel_token = context.alloc_ongoing().await?; + let mut cancel_token = context.alloc_ongoing().await?; let paused_guard = context.scheduler.pause(context.clone()).await; let context_dir = context .get_blobdir() @@ -114,7 +113,7 @@ impl BackupProvider { }, } }, - _ = cancel_token.recv() => Err(format_err!("cancelled")), + _ = &mut cancel_token => Err(format_err!("cancelled")), }; let (provider, ticket) = match res { Ok((provider, ticket)) => (provider, ticket), @@ -188,7 +187,7 @@ impl BackupProvider { async fn watch_provider( context: &Context, mut provider: Provider, - cancel_token: Receiver<()>, + mut cancel_token: OngoingGuard, ) -> Result<()> { // _dbfile exists so we can clean up the file once it is no longer needed let mut events = provider.subscribe(); @@ -248,7 +247,7 @@ impl BackupProvider { } } }, - _ = cancel_token.recv() => { + _ = &mut cancel_token => { provider.shutdown(); break Err(anyhow!("BackupSender cancelled")); }, @@ -381,7 +380,7 @@ pub async fn get_backup(context: &Context, qr: Qr) -> Result<()> { context.free_ongoing().await; res } - _ = cancel_token.recv() => Err(format_err!("cancelled")), + _ = cancel_token => Err(format_err!("cancelled")), }; res } From 201d05d4fa069b2b9fb7a9d1243688ae5bd01464 Mon Sep 17 00:00:00 2001 From: Floris Bruynooghe Date: Thu, 30 Mar 2023 10:52:47 +0200 Subject: [PATCH 2/4] Remove Context::free_ongoing function This is now handled better by the Drop from the OngoingGuard returned by Context::alloc_ongoing. --- src/configure.rs | 6 ++---- src/context.rs | 5 ----- src/imex.rs | 5 ++--- src/imex/transfer.rs | 33 ++++++++++----------------------- 4 files changed, 14 insertions(+), 35 deletions(-) diff --git a/src/configure.rs b/src/configure.rs index d6589975a5..504454b4a2 100644 --- a/src/configure.rs +++ b/src/configure.rs @@ -66,18 +66,16 @@ impl Context { self.sql.is_open().await, "cannot configure, database not opened." 
); - let cancel_channel = self.alloc_ongoing().await?; + let ongoing_guard = self.alloc_ongoing().await?; let res = self .inner_configure() - .race(cancel_channel.map(|_| { + .race(ongoing_guard.map(|_| { progress!(self, 0); Ok(()) })) .await; - self.free_ongoing().await; - if let Err(err) = res.as_ref() { progress!( self, diff --git a/src/context.rs b/src/context.rs index 8835aeb519..59879a1797 100644 --- a/src/context.rs +++ b/src/context.rs @@ -545,11 +545,6 @@ impl Context { }) } - pub(crate) async fn free_ongoing(&self) { - let mut s = self.running_state.write().await; - *s = RunningState::Stopped; - } - /// Signal an ongoing process to stop. pub async fn stop_ongoing(&self) { let mut s = self.running_state.write().await; diff --git a/src/imex.rs b/src/imex.rs index 0dc519d52a..94f79b8f56 100644 --- a/src/imex.rs +++ b/src/imex.rs @@ -88,18 +88,17 @@ pub async fn imex( path: &Path, passphrase: Option, ) -> Result<()> { - let cancel = context.alloc_ongoing().await?; + let ongoing_guard = context.alloc_ongoing().await?; let res = { let _guard = context.scheduler.pause(context.clone()).await; imex_inner(context, what, path, passphrase) .race(async { - cancel.await; + ongoing_guard.await; Err(format_err!("canceled")) }) .await }; - context.free_ongoing().await; if let Err(err) = res.as_ref() { // We are using Anyhow's .context() and to show the inner error, too, we need the {:#}: diff --git a/src/imex/transfer.rs b/src/imex/transfer.rs index 3ef7a16fec..2a93bf2681 100644 --- a/src/imex/transfer.rs +++ b/src/imex/transfer.rs @@ -90,7 +90,7 @@ impl BackupProvider { .context("Private key not available, aborting backup export")?; // Acquire global "ongoing" mutex. 
- let mut cancel_token = context.alloc_ongoing().await?; + let mut ongoing_guard = context.alloc_ongoing().await?; let paused_guard = context.scheduler.pause(context.clone()).await; let context_dir = context .get_blobdir() @@ -102,7 +102,7 @@ impl BackupProvider { warn!(context, "Previous database export deleted"); } let dbfile = TempPathGuard::new(dbfile); - let res = tokio::select! { + let (provider, ticket) = tokio::select! { biased; res = Self::prepare_inner(context, &dbfile) => { match res { @@ -113,20 +113,12 @@ impl BackupProvider { }, } }, - _ = &mut cancel_token => Err(format_err!("cancelled")), - }; - let (provider, ticket) = match res { - Ok((provider, ticket)) => (provider, ticket), - Err(err) => { - context.free_ongoing().await; - return Err(err); - } - }; + _ = &mut ongoing_guard => Err(format_err!("cancelled")), + }?; let handle = { let context = context.clone(); tokio::spawn(async move { - let res = Self::watch_provider(&context, provider, cancel_token).await; - context.free_ongoing().await; + let res = Self::watch_provider(&context, provider, ongoing_guard).await; // Explicit drop to move the guards into this future drop(paused_guard); @@ -189,7 +181,6 @@ impl BackupProvider { mut provider: Provider, mut cancel_token: OngoingGuard, ) -> Result<()> { - // _dbfile exists so we can clean up the file once it is no longer needed let mut events = provider.subscribe(); let mut total_size = 0; let mut current_size = 0; @@ -373,16 +364,12 @@ pub async fn get_backup(context: &Context, qr: Qr) -> Result<()> { let _guard = context.scheduler.pause(context.clone()).await; // Acquire global "ongoing" mutex. - let cancel_token = context.alloc_ongoing().await?; - let res = tokio::select! { + let mut cancel_token = context.alloc_ongoing().await?; + tokio::select! 
{ biased; - res = get_backup_inner(context, qr) => { - context.free_ongoing().await; - res - } - _ = cancel_token => Err(format_err!("cancelled")), - }; - res + res = get_backup_inner(context, qr) => res, + _ = &mut cancel_token => Err(format_err!("cancelled")), + } } async fn get_backup_inner(context: &Context, qr: Qr) -> Result<()> { From 61b00f99916ff086685100821a3c511bf7e396b2 Mon Sep 17 00:00:00 2001 From: Floris Bruynooghe Date: Tue, 4 Apr 2023 14:44:29 +0200 Subject: [PATCH 3/4] typo --- src/imex.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imex.rs b/src/imex.rs index 9a1d4cb065..d749072423 100644 --- a/src/imex.rs +++ b/src/imex.rs @@ -95,7 +95,7 @@ pub async fn imex( imex_inner(context, what, path, passphrase) .race(async { ongoing_guard.await; - Err(format_err!("canceled")) + Err(format_err!("cancelled")) }) .await }; From b9fd5296bb0ff0e580c15faf3d77a3c35375b752 Mon Sep 17 00:00:00 2001 From: Floris Bruynooghe Date: Sat, 16 Dec 2023 17:00:51 +0100 Subject: [PATCH 4/4] Re-add info message using elapsed stopping time --- src/context.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/context.rs b/src/context.rs index 3298742d4e..58883aefd8 100644 --- a/src/context.rs +++ b/src/context.rs @@ -526,6 +526,9 @@ impl Context { tokio::spawn(async move { drop_rx.await.ok(); let mut s = context.running_state.write().await; + if let RunningState::ShallStop { request } = *s { + info!(context, "Ongoing stopped in {:?}", request.elapsed()); + } *s = RunningState::Stopped; });