Skip to content

Commit 3bd98c8

Browse files
andrewjcgfacebook-github-bot
authored andcommitted
Make SharedSpawnable support any non-borrow deref (#393)
Summary: Pull Request resolved: #393 This makes `SharedSpawnable` more generic and implements it for any type that can deref to a `ProcMesh`. The only side-effect here is that it requires the caller to do the `Arc::clone()`, rather than implicitly doing it in the impl. Reviewed By: mariusae Differential Revision: D77378974 fbshipit-source-id: 7db36cf6b83c29e317b80ad49ca34747aa60bc73
1 parent c2d1088 commit 3bd98c8

File tree

5 files changed

+20
-17
lines changed

5 files changed

+20
-17
lines changed

hyperactor_mesh/src/actor_mesh.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#![allow(dead_code)] // until used publically
1010

1111
use std::ops::Deref;
12-
use std::sync::Arc;
1312

1413
use async_trait::async_trait;
1514
use hyperactor::Actor;
@@ -164,8 +163,8 @@ pub trait ActorMesh: Mesh {
164163
/// Given a shared ProcMesh, we can obtain a [`ActorMesh<'static, _>`]
165164
/// for it, useful when lifetime must be managed dynamically.
166165
enum ProcMeshRef<'a> {
167-
/// The reference is shared with an [`Arc`].
168-
Shared(Arc<ProcMesh>),
166+
/// The reference is shared without requiring a reference.
167+
Shared(Box<dyn Deref<Target = ProcMesh> + Sync + Send>),
169168
/// The reference is borrowed with a parameterized
170169
/// lifetime.
171170
Borrowed(&'a ProcMesh),
@@ -206,14 +205,14 @@ impl<'a, A: RemoteActor> RootActorMesh<'a, A> {
206205
}
207206
}
208207

209-
pub(crate) fn new_shared(
210-
proc_mesh: Arc<ProcMesh>,
208+
pub(crate) fn new_shared<D: Deref<Target = ProcMesh> + Send + Sync + 'static>(
209+
proc_mesh: D,
211210
name: String,
212211
actor_supervision_rx: mpsc::UnboundedReceiver<ActorSupervisionEvent>,
213212
ranks: Vec<ActorRef<A>>,
214213
) -> Self {
215214
Self {
216-
proc_mesh: ProcMeshRef::Shared(proc_mesh),
215+
proc_mesh: ProcMeshRef::Shared(Box::new(proc_mesh)),
217216
name,
218217
ranks,
219218
actor_supervision_rx,
@@ -499,6 +498,7 @@ pub(crate) mod test_util {
499498

500499
#[cfg(test)]
501500
mod tests {
501+
use std::sync::Arc;
502502

503503
use hyperactor::ActorId;
504504
use hyperactor::PortRef;

hyperactor_mesh/src/comm.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -738,6 +738,7 @@ mod tests {
738738
forward_port: tx.bind(),
739739
};
740740
let actor_mesh = proc_mesh
741+
.clone()
741742
.spawn::<TestActor>(dest_actor_name, &params)
742743
.await
743744
.unwrap();

hyperactor_mesh/src/proc_mesh.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
use std::collections::HashMap;
1010
use std::collections::HashSet;
1111
use std::fmt;
12+
use std::ops::Deref;
1213
use std::sync::Arc;
1314

1415
use async_trait::async_trait;
@@ -532,7 +533,7 @@ impl ProcEvents {
532533
#[async_trait]
533534
pub trait SharedSpawnable {
534535
async fn spawn<A: Actor + RemoteActor>(
535-
&self,
536+
self,
536537
actor_name: &str,
537538
params: &A::Params,
538539
) -> Result<RootActorMesh<'static, A>, anyhow::Error>
@@ -541,9 +542,9 @@ pub trait SharedSpawnable {
541542
}
542543

543544
#[async_trait]
544-
impl SharedSpawnable for Arc<ProcMesh> {
545+
impl<D: Deref<Target = ProcMesh> + Send + Sync + 'static> SharedSpawnable for D {
545546
async fn spawn<A: Actor + RemoteActor>(
546-
&self,
547+
self,
547548
actor_name: &str,
548549
params: &A::Params,
549550
) -> Result<RootActorMesh<'static, A>, anyhow::Error>
@@ -555,11 +556,13 @@ impl SharedSpawnable for Arc<ProcMesh> {
555556
// Instantiate supervision routing BEFORE spawning the actor mesh.
556557
self.actor_event_router.insert(actor_name.to_string(), tx);
557558
}
559+
let ranks =
560+
ProcMesh::spawn_on_procs::<A>(&self.client, self.agents(), actor_name, params).await?;
558561
Ok(RootActorMesh::new_shared(
559-
Arc::clone(self),
562+
self,
560563
actor_name.to_string(),
561564
rx,
562-
ProcMesh::spawn_on_procs::<A>(&self.client, self.agents(), actor_name, params).await?,
565+
ranks,
563566
))
564567
}
565568
}

monarch_extension/src/mesh_controller.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ impl _Controller {
184184
let workers: anyhow::Result<SharedCell<RootActorMesh<'_, WorkerActor>>> =
185185
signal_safe_block_on(py, async move {
186186
let workers = py_proc_mesh
187+
.clone()
187188
.spawn(&format!("tensor_engine_workers_{}", id), &param)
188189
.await?;
189190
//workers.cast(ndslice::Selection::True, )?;

monarch_hyperactor/src/proc_mesh.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -228,12 +228,11 @@ impl PyProcMesh {
228228
let proc_mesh = Arc::clone(&self.inner);
229229
let keepalive = self.keepalive.clone();
230230
pyo3_async_runtimes::tokio::future_into_py(py, async move {
231+
let mailbox = proc_mesh.client().clone();
231232
let actor_mesh = proc_mesh.spawn(&name, &pickled_type).await?;
232233
let python_actor_mesh = PythonActorMesh {
233234
inner: actor_mesh,
234-
client: PyMailbox {
235-
inner: proc_mesh.client().clone(),
236-
},
235+
client: PyMailbox { inner: mailbox },
237236
_keepalive: keepalive,
238237
};
239238
Python::with_gil(|py| python_actor_mesh.into_py_any(py))
@@ -250,12 +249,11 @@ impl PyProcMesh {
250249
let proc_mesh = Arc::clone(&self.inner);
251250
let keepalive = self.keepalive.clone();
252251
signal_safe_block_on(py, async move {
252+
let mailbox = proc_mesh.client().clone();
253253
let actor_mesh = proc_mesh.spawn(&name, &pickled_type).await?;
254254
let python_actor_mesh = PythonActorMesh {
255255
inner: actor_mesh,
256-
client: PyMailbox {
257-
inner: proc_mesh.client().clone(),
258-
},
256+
client: PyMailbox { inner: mailbox },
259257
_keepalive: keepalive,
260258
};
261259
Python::with_gil(|py| python_actor_mesh.into_py_any(py))

0 commit comments

Comments
 (0)