Skip to content

Commit 1f1a5a2

Browse files
andrewjcgfacebook-github-bot
authored andcommitted
Add a preemptible RWLock utility (#409)
Summary: Pull Request resolved: #409 This adds a `tokio::sync::RWLock` wrapper struct which allows waiting writers to "preempt" and interrupt readers that hold the lock. This is supported by a `async fn preempt(&self)` helper on the read guard that readers can use in a `tokio::select!` to poll for interruptions and vacate their lock guard. Reviewed By: suo Differential Revision: D77635672 fbshipit-source-id: 5c1f195b9984f4ee33717019627bec98d082a0c0
1 parent 79b07c2 commit 1f1a5a2

File tree

2 files changed

+251
-0
lines changed

2 files changed

+251
-0
lines changed

preempt_rwlock/Cargo.toml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# @generated by autocargo from //monarch/preempt_rwlock:preempt_rwlock
2+
3+
[package]
4+
name = "preempt_rwlock"
5+
version = "0.0.0"
6+
authors = ["Meta"]
7+
edition = "2021"
8+
license = "BSD-3-Clause"
9+
10+
[dependencies]
11+
tokio = { version = "1.45.0", features = ["full", "test-util", "tracing"] }
12+
13+
[dev-dependencies]
14+
anyhow = "1.0.98"

preempt_rwlock/src/lib.rs

Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#![feature(future_join)]
10+
11+
use std::sync::Arc;
12+
13+
use tokio::sync::OwnedRwLockReadGuard;
14+
use tokio::sync::RwLock;
15+
use tokio::sync::RwLockReadGuard;
16+
use tokio::sync::RwLockWriteGuard;
17+
use tokio::sync::TryLockError;
18+
use tokio::sync::watch;
19+
20+
pub struct PreemptibleRwLockReadGuard<'a, T: Sized> {
21+
preemptor: &'a watch::Receiver<usize>,
22+
guard: RwLockReadGuard<'a, T>,
23+
}
24+
25+
impl<'a, T: Sized> PreemptibleRwLockReadGuard<'a, T> {
26+
pub async fn preempted(&self) {
27+
let mut preemptor = self.preemptor.clone();
28+
// Wait for pending writers.
29+
preemptor.wait_for(|&v| v > 0).await.expect("wait_for fail");
30+
}
31+
}
32+
33+
impl<T: Sized> std::ops::Deref for PreemptibleRwLockReadGuard<'_, T> {
34+
type Target = T;
35+
fn deref(&self) -> &T {
36+
self.guard.deref()
37+
}
38+
}
39+
40+
pub struct OwnedPreemptibleRwLockReadGuard<T: ?Sized, U: ?Sized = T> {
41+
preemptor: watch::Receiver<usize>,
42+
guard: OwnedRwLockReadGuard<T, U>,
43+
}
44+
45+
impl<T: ?Sized, U: ?Sized> OwnedPreemptibleRwLockReadGuard<T, U> {
46+
pub async fn preempted(&self) {
47+
let mut preemptor = self.preemptor.clone();
48+
// Wait for pending writers.
49+
preemptor.wait_for(|&v| v > 0).await.expect("wait_for fail");
50+
}
51+
52+
/// Maps the data guarded by this lock with a function.
53+
///
54+
/// This is similar to the `map` method on `OwnedRwLockReadGuard`, but preserves
55+
/// the preemption capability.
56+
pub fn map<F, V>(self, f: F) -> OwnedPreemptibleRwLockReadGuard<T, V>
57+
where
58+
F: FnOnce(&U) -> &V,
59+
{
60+
OwnedPreemptibleRwLockReadGuard {
61+
preemptor: self.preemptor,
62+
guard: OwnedRwLockReadGuard::map(self.guard, f),
63+
}
64+
}
65+
}
66+
67+
impl<T: ?Sized, U: ?Sized> std::ops::Deref for OwnedPreemptibleRwLockReadGuard<T, U> {
68+
type Target = U;
69+
fn deref(&self) -> &U {
70+
self.guard.deref()
71+
}
72+
}
73+
74+
pub struct PreemptibleRwLockWriteGuard<'a, T: Sized> {
75+
preemptor: &'a watch::Sender<usize>,
76+
guard: RwLockWriteGuard<'a, T>,
77+
preempt_readers: bool,
78+
}
79+
80+
impl<'a, T: Sized> Drop for PreemptibleRwLockWriteGuard<'a, T> {
81+
fn drop(&mut self) {
82+
if self.preempt_readers {
83+
self.preemptor.send_if_modified(|v| {
84+
*v -= 1;
85+
// No need to send a change event when decrementing.
86+
false
87+
});
88+
}
89+
}
90+
}
91+
92+
impl<T: Sized> std::ops::Deref for PreemptibleRwLockWriteGuard<'_, T> {
93+
type Target = T;
94+
fn deref(&self) -> &T {
95+
self.guard.deref()
96+
}
97+
}
98+
99+
impl<T: Sized> std::ops::DerefMut for PreemptibleRwLockWriteGuard<'_, T> {
100+
fn deref_mut(&mut self) -> &mut Self::Target {
101+
self.guard.deref_mut()
102+
}
103+
}
104+
105+
/// A RW-lock which also supports a way for pending writers to request that
106+
/// readers get preempted, via `preempted()` method on the read guard that
107+
/// readers can `tokio::select!` on.
108+
#[derive(Debug)]
109+
pub struct PreemptibleRwLock<T: Sized> {
110+
lock: Arc<RwLock<T>>,
111+
preemptor_lock: RwLock<()>,
112+
// Used to track the number of writers waiting to acquire the lock and
113+
// allows readers to `await` on updates to this value to support preemption.
114+
preemptor: (watch::Sender<usize>, watch::Receiver<usize>),
115+
}
116+
117+
impl<T: Sized> PreemptibleRwLock<T> {
118+
pub fn new(item: T) -> Self {
119+
PreemptibleRwLock {
120+
lock: Arc::new(RwLock::new(item)),
121+
preemptor_lock: RwLock::new(()),
122+
preemptor: watch::channel(0),
123+
}
124+
}
125+
126+
pub async fn read<'a>(&'a self) -> PreemptibleRwLockReadGuard<'a, T> {
127+
let _guard = self.preemptor_lock.read().await;
128+
PreemptibleRwLockReadGuard {
129+
preemptor: &self.preemptor.1,
130+
guard: self.lock.read().await,
131+
}
132+
}
133+
134+
pub async fn read_owned(self: Arc<Self>) -> OwnedPreemptibleRwLockReadGuard<T> {
135+
let _guard = self.preemptor_lock.read().await;
136+
OwnedPreemptibleRwLockReadGuard {
137+
preemptor: self.preemptor.1.clone(),
138+
guard: self.lock.clone().read_owned().await,
139+
}
140+
}
141+
142+
pub fn try_read_owned(
143+
self: Arc<Self>,
144+
) -> Result<OwnedPreemptibleRwLockReadGuard<T>, TryLockError> {
145+
let _guard = self.preemptor_lock.try_read()?;
146+
Ok(OwnedPreemptibleRwLockReadGuard {
147+
preemptor: self.preemptor.1.clone(),
148+
guard: self.lock.clone().try_read_owned()?,
149+
})
150+
}
151+
152+
pub async fn write<'a>(&'a self, preempt_readers: bool) -> PreemptibleRwLockWriteGuard<'a, T> {
153+
let _guard = self.preemptor_lock.write().await;
154+
if preempt_readers {
155+
self.preemptor.0.send_if_modified(|v| {
156+
*v += 1;
157+
// Only send a change event if we're the first pending writer.
158+
*v == 1
159+
});
160+
}
161+
PreemptibleRwLockWriteGuard {
162+
preemptor: &self.preemptor.0,
163+
guard: self.lock.write().await,
164+
preempt_readers,
165+
}
166+
}
167+
pub fn try_write<'a>(
168+
&'a self,
169+
preempt_readers: bool,
170+
) -> Result<PreemptibleRwLockWriteGuard<'a, T>, TryLockError> {
171+
let _guard = self.preemptor_lock.try_write()?;
172+
if preempt_readers {
173+
self.preemptor.0.send_if_modified(|v| {
174+
*v += 1;
175+
// Only send a change event if we're the first pending writer.
176+
*v == 1
177+
});
178+
}
179+
Ok(PreemptibleRwLockWriteGuard {
180+
preemptor: &self.preemptor.0,
181+
guard: self.lock.try_write()?,
182+
preempt_readers,
183+
})
184+
}
185+
186+
pub fn blocking_write<'a>(
187+
&'a self,
188+
preempt_readers: bool,
189+
) -> PreemptibleRwLockWriteGuard<'a, T> {
190+
let _guard = self.preemptor_lock.blocking_write();
191+
if preempt_readers {
192+
self.preemptor.0.send_if_modified(|v| {
193+
*v += 1;
194+
// Only send a change event if we're the first pending writer.
195+
*v == 1
196+
});
197+
}
198+
PreemptibleRwLockWriteGuard {
199+
preemptor: &self.preemptor.0,
200+
guard: self.lock.blocking_write(),
201+
preempt_readers,
202+
}
203+
}
204+
}
205+
206+
#[cfg(test)]
207+
mod tests {
208+
use std::future::join;
209+
use std::time::Duration;
210+
211+
use anyhow::Result;
212+
213+
use super::*;
214+
215+
#[tokio::test]
216+
#[allow(clippy::disallowed_methods)]
217+
async fn test_preempt_reader() -> Result<()> {
218+
let lock = PreemptibleRwLock::new(());
219+
220+
let reader = lock.read().await;
221+
222+
join!(
223+
async move {
224+
loop {
225+
tokio::select!(
226+
_ = reader.preempted() => break,
227+
_ = tokio::time::sleep(Duration::from_secs(100)) => (),
228+
)
229+
}
230+
},
231+
lock.write(/* preempt_readers */ true),
232+
)
233+
.await;
234+
235+
Ok(())
236+
}
237+
}

0 commit comments

Comments
 (0)