Skip to content

Commit 8ccf12e

Browse files
andrewjcg and facebook-github-bot
authored and committed
Add SharedCell helper for wrapping types used in Python (#391)
Summary: Pull Request resolved: #391. Adds a helper class, implemented as a thin wrapper around `tokio::sync::RwLock`, to help manage Rust structs that get shared via Python wrappers. Access to the wrapped struct is supported through `.borrow()`, and the object can be consumed via `.take()` when there are no outstanding borrows, leaving the cell in an unusable state. Reviewed By: mariusae, shayne-fletcher. Differential Revision: D77244960. fbshipit-source-id: a7479b2df6189a328722c25be3d97d30887c05cf
1 parent 1f1a5a2 commit 8ccf12e

File tree

3 files changed

+265
-0
lines changed

3 files changed

+265
-0
lines changed

hyperactor_mesh/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ libc = "0.2.139"
4040
mockall = "0.13.1"
4141
ndslice = { version = "0.0.0", path = "../ndslice" }
4242
nix = { version = "0.29.0", features = ["dir", "event", "hostname", "inotify", "ioctl", "mman", "mount", "net", "poll", "ptrace", "reboot", "resource", "sched", "signal", "term", "time", "user", "zerocopy"] }
43+
preempt_rwlock = { version = "0.0.0", path = "../preempt_rwlock" }
4344
rand = { version = "0.8", features = ["small_rng"] }
4445
serde = { version = "1.0.185", features = ["derive", "rc"] }
4546
serde_bytes = "0.11"

hyperactor_mesh/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ pub mod mesh_selection;
2424
mod metrics;
2525
pub mod proc_mesh;
2626
pub mod reference;
27+
pub mod shared_cell;
2728
pub mod shortuuid;
2829
pub mod test_utils;
2930

hyperactor_mesh/src/shared_cell.rs

Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
use std::fmt::Debug;
10+
use std::ops::Deref;
11+
use std::sync::Arc;
12+
use std::sync::atomic::AtomicUsize;
13+
use std::sync::atomic::Ordering;
14+
15+
use async_trait::async_trait;
16+
use dashmap::DashMap;
17+
use futures::future::try_join_all;
18+
use preempt_rwlock::OwnedPreemptibleRwLockReadGuard;
19+
use preempt_rwlock::PreemptibleRwLock;
20+
use tokio::sync::TryLockError;
21+
22+
#[derive(thiserror::Error, Debug)]
23+
pub struct EmptyCellError {}
24+
25+
impl From<TryLockError> for EmptyCellError {
26+
fn from(_err: TryLockError) -> Self {
27+
Self {}
28+
}
29+
}
30+
31+
impl std::fmt::Display for EmptyCellError {
32+
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
33+
write!(f, "already taken")
34+
}
35+
}
36+
37+
/// Error returned by the non-waiting `try_take` / `try_discard` operations.
#[derive(thiserror::Error, Debug)]
38+
pub enum TryTakeError {
39+
/// The lock was acquired, but the cell no longer holds a value.
#[error("already taken")]
40+
Empty,
41+
/// The `try_write` lock attempt failed; wraps the underlying lock error.
#[error("cannot lock: {0}")]
42+
TryLockError(#[from] TryLockError),
43+
}
44+
45+
/// Back-reference from a pooled cell to the pool's map and the cell's key in it.
struct PoolRef {
46+
/// Shared handle to the map in which `SharedCellPool` stores its cells.
map: Arc<DashMap<usize, Arc<dyn SharedCellDiscard + Send + Sync>>>,
47+
/// Key under which the owning cell was registered in `map`.
key: usize,
48+
}
49+
50+
impl Debug for PoolRef {
51+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52+
f.debug_struct("PoolRef").field("key", &self.key).finish()
53+
}
54+
}
55+
56+
/// State stored behind the cell's lock.
#[derive(Debug)]
57+
struct Inner<T> {
58+
/// The wrapped item; `None` once it has been taken out of the cell.
value: Option<T>,
59+
/// Pool registration for cells created via `SharedCell::with_pool`; `None` otherwise.
pool: Option<PoolRef>,
60+
}
61+
62+
impl<T> Drop for Inner<T> {
63+
fn drop(&mut self) {
64+
if let Some(pool) = &self.pool {
65+
pool.map.remove(&pool.key);
66+
}
67+
}
68+
}
69+
70+
/// A wrapper class that facilitates sharing an item across different users, supporting:
71+
/// - Ability to grab a reference-counted reference to the item
72+
/// - Ability to consume the item, leaving the cell in an unusable state
73+
#[derive(Debug)]
74+
pub struct SharedCell<T> {
75+
/// Shared, lock-protected storage; every clone of the cell points at the same `Inner`.
inner: Arc<PreemptibleRwLock<Inner<T>>>,
76+
}
77+
78+
impl<T> Clone for SharedCell<T> {
79+
fn clone(&self) -> Self {
80+
Self {
81+
inner: self.inner.clone(),
82+
}
83+
}
84+
}
85+
86+
impl<T> From<T> for SharedCell<T> {
87+
fn from(value: T) -> Self {
88+
Self {
89+
inner: Arc::new(PreemptibleRwLock::new(Inner {
90+
value: Some(value),
91+
pool: None,
92+
})),
93+
}
94+
}
95+
}
96+
97+
impl<T> SharedCell<T> {
98+
fn with_pool(value: T, pool: PoolRef) -> Self {
99+
Self {
100+
inner: Arc::new(PreemptibleRwLock::new(Inner {
101+
value: Some(value),
102+
pool: Some(pool),
103+
})),
104+
}
105+
}
106+
}
107+
108+
/// A read borrow of a `SharedCell<T>` that dereferences to a `U` (by default `T` itself).
pub struct SharedCellRef<T, U = T> {
109+
/// Owned read guard on the cell's `Inner<T>`, mapped to the borrowed `U`.
guard: OwnedPreemptibleRwLockReadGuard<Inner<T>, U>,
110+
}
111+
112+
impl<T> SharedCellRef<T> {
113+
fn from(guard: OwnedPreemptibleRwLockReadGuard<Inner<T>>) -> Result<Self, EmptyCellError> {
114+
if guard.value.is_none() {
115+
return Err(EmptyCellError {});
116+
}
117+
Ok(Self {
118+
guard: OwnedPreemptibleRwLockReadGuard::map(guard, |guard| {
119+
guard.value.as_ref().unwrap()
120+
}),
121+
})
122+
}
123+
124+
pub fn map<F, U>(self, f: F) -> SharedCellRef<T, U>
125+
where
126+
F: FnOnce(&T) -> &U,
127+
{
128+
SharedCellRef {
129+
guard: OwnedPreemptibleRwLockReadGuard::map(self.guard, f),
130+
}
131+
}
132+
133+
pub async fn preempted(&self) {
134+
self.guard.preempted().await
135+
}
136+
}
137+
138+
impl<T, U: Debug> Debug for SharedCellRef<T, U> {
139+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140+
Debug::fmt(&**self, f)
141+
}
142+
}
143+
144+
impl<T, U> Deref for SharedCellRef<T, U> {
145+
type Target = U;
146+
147+
fn deref(&self) -> &Self::Target {
148+
&self.guard
149+
}
150+
}
151+
152+
impl<T> SharedCell<T> {
153+
/// Borrow the cell, returning a reference to the item. If the cell is empty, returns an error.
154+
/// While references are held, the cell cannot be taken below.
155+
pub fn borrow(&self) -> Result<SharedCellRef<T>, EmptyCellError> {
156+
SharedCellRef::from(self.inner.clone().try_read_owned()?)
157+
}
158+
159+
/// Take the item out of the cell, leaving it in an unusable state.
160+
pub async fn take(&self) -> Result<T, EmptyCellError> {
161+
let mut inner = self.inner.write(true).await;
162+
inner.value.take().ok_or(EmptyCellError {})
163+
}
164+
165+
pub fn blocking_take(&self) -> Result<T, EmptyCellError> {
166+
let mut inner = self.inner.blocking_write(true);
167+
inner.value.take().ok_or(EmptyCellError {})
168+
}
169+
170+
pub fn try_take(&self) -> Result<T, TryTakeError> {
171+
let mut inner = self.inner.try_write(true)?;
172+
inner.value.take().ok_or(TryTakeError::Empty)
173+
}
174+
}
175+
176+
/// A pool of `SharedCell`s which can be used to mass `take()` and discard them all at once.
177+
pub struct SharedCellPool {
178+
/// Cells registered via `insert`, keyed by the token assigned at insertion.
map: Arc<DashMap<usize, Arc<dyn SharedCellDiscard + Send + Sync>>>,
179+
/// Counter that supplies the key for the next `insert`.
token: AtomicUsize,
180+
}
181+
182+
impl SharedCellPool {
183+
pub fn new() -> Self {
184+
Self {
185+
map: Arc::new(DashMap::new()),
186+
token: AtomicUsize::new(0),
187+
}
188+
}
189+
190+
pub fn insert<T>(&self, value: T) -> SharedCell<T>
191+
where
192+
T: Send + Sync + 'static,
193+
{
194+
let map = self.map.clone();
195+
let key = self.token.fetch_add(1, Ordering::Relaxed);
196+
let pool = PoolRef { map, key };
197+
let value: SharedCell<_> = SharedCell::with_pool(value, pool);
198+
self.map.entry(key).insert(Arc::new(value.clone()));
199+
value
200+
}
201+
202+
/// Run `take` on all cells in the pool and immediately drop them.
203+
pub async fn discard_all(self) -> Result<(), EmptyCellError> {
204+
try_join_all(
205+
self.map
206+
.iter()
207+
.map(|r| async move { r.value().discard().await }),
208+
)
209+
.await?;
210+
Ok(())
211+
}
212+
}
213+
214+
/// Trait to facilitate storing `SharedCell`s of different types in a single pool.
215+
#[async_trait]
216+
pub trait SharedCellDiscard {
217+
/// Asynchronously takes the value out of the cell and drops it.
async fn discard(&self) -> Result<(), EmptyCellError>;
218+
/// Blocking counterpart of `discard`.
fn blocking_discard(&self) -> Result<(), EmptyCellError>;
219+
/// Non-waiting counterpart of `discard`; reports lock failure via `TryTakeError`.
fn try_discard(&self) -> Result<(), TryTakeError>;
220+
}
221+
222+
#[async_trait]
223+
impl<T: Send + Sync> SharedCellDiscard for SharedCell<T> {
224+
fn try_discard(&self) -> Result<(), TryTakeError> {
225+
self.try_take()?;
226+
Ok(())
227+
}
228+
229+
async fn discard(&self) -> Result<(), EmptyCellError> {
230+
self.take().await?;
231+
Ok(())
232+
}
233+
234+
fn blocking_discard(&self) -> Result<(), EmptyCellError> {
235+
self.blocking_take()?;
236+
Ok(())
237+
}
238+
}
239+
240+
#[cfg(test)]
mod tests {
    use anyhow::Result;

    use super::*;

    /// Once the value has been taken, borrowing the cell must fail.
    #[tokio::test]
    async fn borrow_after_take() -> Result<()> {
        let cell = SharedCell::from(0);
        drop(cell.take().await);
        assert!(cell.borrow().is_err());
        Ok(())
    }

    /// `try_take` fails while a borrow is outstanding and succeeds once it is released.
    #[tokio::test]
    async fn take_after_borrow() -> Result<()> {
        let cell = SharedCell::from(0);
        let outstanding = cell.borrow()?;
        assert!(cell.try_take().is_err());
        drop(outstanding);
        cell.try_take()?;
        Ok(())
    }
}

0 commit comments

Comments
 (0)