Skip to content

Commit 7ee3afb

Browse files
pzhan9facebook-github-bot
authored andcommitted
Add selection parameter to PythonActorMesh.cast (#431)
Summary: Pull Request resolved: #431 we want to allow python side to pass a selection parameter, rather than always `select all`. For example, the most common use cases is `def call_one`, which we want to pass `select any`. Reviewed By: mariusae Differential Revision: D77747932 fbshipit-source-id: a7e07f266e319e564c9517384cb46d32da1c5fdc
1 parent ab275d4 commit 7ee3afb

File tree

4 files changed

+17
-9
lines changed

4 files changed

+17
-9
lines changed

monarch_hyperactor/src/actor_mesh.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use crate::actor::PythonMessage;
2121
use crate::mailbox::PyMailbox;
2222
use crate::proc::PyActorId;
2323
use crate::proc_mesh::Keepalive;
24+
use crate::selection::PySelection;
2425
use crate::shape::PyShape;
2526

2627
#[pyclass(
@@ -43,10 +44,9 @@ impl PythonActorMesh {
4344

4445
#[pymethods]
4546
impl PythonActorMesh {
46-
fn cast(&self, message: &PythonMessage) -> PyResult<()> {
47-
use ndslice::selection::dsl::*;
47+
fn cast(&self, selection: &PySelection, message: &PythonMessage) -> PyResult<()> {
4848
self.try_inner()?
49-
.cast(all(true_()), message.clone())
49+
.cast(selection.inner().clone(), message.clone())
5050
.map_err(|err| PyException::new_err(err.to_string()))?;
5151
Ok(())
5252
}

monarch_hyperactor/src/selection.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@ pub struct PySelection {
2020
inner: Selection,
2121
}
2222

23+
impl PySelection {
24+
pub(crate) fn inner(&self) -> &Selection {
25+
&self.inner
26+
}
27+
}
28+
2329
impl From<Selection> for PySelection {
2430
fn from(inner: Selection) -> Self {
2531
Self { inner }

python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,15 @@ from typing import final
1010

1111
from monarch._rust_bindings.monarch_hyperactor.actor import PythonMessage
1212
from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox
13-
1413
from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
15-
14+
from monarch._rust_bindings.monarch_hyperactor.selection import Selection
1615
from monarch._rust_bindings.monarch_hyperactor.shape import Shape
1716

1817
@final
1918
class PythonActorMesh:
20-
def cast(self, message: PythonMessage) -> None:
19+
def cast(self, selection: Selection, message: PythonMessage) -> None:
2120
"""
22-
Cast a message to this mesh.
21+
Cast a message to the selected actors in the mesh.
2322
"""
2423

2524
def get(self, rank: int) -> ActorId | None:

python/tests/_monarch/test_mailbox.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525
PortRef,
2626
)
2727
from monarch._rust_bindings.monarch_hyperactor.proc_mesh import ProcMesh
28+
from monarch._rust_bindings.monarch_hyperactor.selection import Selection
2829
from monarch._rust_bindings.monarch_hyperactor.shape import Shape
2930

30-
3131
S = TypeVar("S")
3232
U = TypeVar("U")
3333

@@ -157,7 +157,10 @@ def my_reduce(state: str, update: str) -> str:
157157
handle, receiver = proc_mesh.client.open_accum_port(accumulator)
158158
port_ref = handle.bind()
159159

160-
actor_mesh.cast(PythonMessage("echo", pickle.dumps("start"), port_ref, None))
160+
actor_mesh.cast(
161+
Selection.from_string("*"),
162+
PythonMessage("echo", pickle.dumps("start"), port_ref, None),
163+
)
161164

162165
messge = await asyncio.wait_for(receiver.recv(), timeout=5)
163166
value = cast(str, pickle.loads(messge.message))

0 commit comments

Comments
 (0)