Skip to content

Commit 825ade0

Browse files
pzhan9facebook-github-bot
authored andcommitted
Add PythonActorMeshRef (#531)
Summary: Pull Request resolved: #531 `PythonActorMeshRef` is used to bind Rust `struct ActorMeshRef` to the python side. Reviewed By: shayne-fletcher Differential Revision: D78284705 fbshipit-source-id: ec5dc8f7595da32ef0646fb8f2a3299d8ea5f2c4
1 parent 665853f commit 825ade0

File tree

10 files changed

+242
-35
lines changed

10 files changed

+242
-35
lines changed

hyperactor_mesh/src/actor_mesh.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,6 @@ pub trait ActorMesh: Mesh<Id = ActorMeshId> {
144144
self.name().to_string(),
145145
),
146146
self.shape().clone(),
147-
self.proc_mesh().shape().clone(),
148147
self.proc_mesh().comm_actor().clone(),
149148
)
150149
}

hyperactor_mesh/src/reference.rs

Lines changed: 5 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,6 @@ pub struct ActorMeshId(pub ProcMeshId, pub String);
7575
pub struct ActorMeshRef<A: RemoteActor> {
7676
pub(crate) mesh_id: ActorMeshId,
7777
shape: Shape,
78-
/// The shape of the underlying Proc Mesh.
79-
proc_mesh_shape: Shape,
8078
/// The reference to the comm actor of the underlying Proc Mesh.
8179
comm_actor_ref: ActorRef<CommActor>,
8280
phantom: PhantomData<A>,
@@ -90,13 +88,11 @@ impl<A: RemoteActor> ActorMeshRef<A> {
9088
pub(crate) fn attest(
9189
mesh_id: ActorMeshId,
9290
shape: Shape,
93-
proc_mesh_shape: Shape,
9491
comm_actor_ref: ActorRef<CommActor>,
9592
) -> Self {
9693
Self {
9794
mesh_id,
9895
shape,
99-
proc_mesh_shape,
10096
comm_actor_ref,
10197
phantom: PhantomData,
10298
}
@@ -142,7 +138,6 @@ impl<A: RemoteActor> Clone for ActorMeshRef<A> {
142138
Self {
143139
mesh_id: self.mesh_id.clone(),
144140
shape: self.shape.clone(),
145-
proc_mesh_shape: self.proc_mesh_shape.clone(),
146141
comm_actor_ref: self.comm_actor_ref.clone(),
147142
phantom: PhantomData,
148143
}
@@ -161,12 +156,11 @@ impl<A: RemoteActor> Eq for ActorMeshRef<A> {}
161156
mod tests {
162157
use async_trait::async_trait;
163158
use hyperactor::Actor;
159+
use hyperactor::Bind;
164160
use hyperactor::Context;
165161
use hyperactor::Handler;
166162
use hyperactor::PortRef;
167-
use hyperactor::message::Bind;
168-
use hyperactor::message::Bindings;
169-
use hyperactor::message::Unbind;
163+
use hyperactor::Unbind;
170164
use hyperactor_mesh_macros::sel;
171165
use ndslice::shape;
172166

@@ -183,11 +177,11 @@ mod tests {
183177
shape! { replica = 4 }
184178
}
185179

186-
#[derive(Debug, Serialize, Deserialize, Named, Clone)]
180+
#[derive(Debug, Serialize, Deserialize, Named, Clone, Bind, Unbind)]
187181
struct MeshPingPongMessage(
188182
/*ttl:*/ u64,
189183
ActorMeshRef<MeshPingPongActor>,
190-
/*completed port:*/ PortRef<bool>,
184+
/*completed port:*/ #[binding(include)] PortRef<bool>,
191185
);
192186

193187
#[derive(Debug, Clone)]
@@ -203,7 +197,6 @@ mod tests {
203197
struct MeshPingPongActorParams {
204198
mesh_id: ActorMeshId,
205199
shape: Shape,
206-
proc_mesh_shape: Shape,
207200
comm_actor_ref: ActorRef<CommActor>,
208201
}
209202

@@ -213,12 +206,7 @@ mod tests {
213206

214207
async fn new(params: Self::Params) -> Result<Self, anyhow::Error> {
215208
Ok(Self {
216-
mesh_ref: ActorMeshRef::attest(
217-
params.mesh_id,
218-
params.shape,
219-
params.proc_mesh_shape,
220-
params.comm_actor_ref,
221-
),
209+
mesh_ref: ActorMeshRef::attest(params.mesh_id, params.shape, params.comm_actor_ref),
222210
})
223211
}
224212
}
@@ -240,18 +228,6 @@ mod tests {
240228
}
241229
}
242230

243-
impl Unbind for MeshPingPongMessage {
244-
fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
245-
self.2.unbind(bindings)
246-
}
247-
}
248-
249-
impl Bind for MeshPingPongMessage {
250-
fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
251-
self.2.bind(bindings)
252-
}
253-
}
254-
255231
#[tokio::test]
256232
async fn test_inter_mesh_ping_pong() {
257233
let alloc_ping = LocalAllocator
@@ -278,7 +254,6 @@ mod tests {
278254
"ping".to_string(),
279255
),
280256
shape: ping_proc_mesh.shape().clone(),
281-
proc_mesh_shape: ping_proc_mesh.shape().clone(),
282257
comm_actor_ref: ping_proc_mesh.comm_actor().clone(),
283258
},
284259
)
@@ -296,7 +271,6 @@ mod tests {
296271
"pong".to_string(),
297272
),
298273
shape: pong_proc_mesh.shape().clone(),
299-
proc_mesh_shape: pong_proc_mesh.shape().clone(),
300274
comm_actor_ref: pong_proc_mesh.comm_actor().clone(),
301275
},
302276
)

monarch_hyperactor/src/actor_mesh.rs

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,18 @@ use hyperactor_mesh::Mesh;
1717
use hyperactor_mesh::RootActorMesh;
1818
use hyperactor_mesh::actor_mesh::ActorMesh;
1919
use hyperactor_mesh::actor_mesh::ActorSupervisionEvents;
20+
use hyperactor_mesh::reference::ActorMeshRef;
2021
use hyperactor_mesh::shared_cell::SharedCell;
2122
use hyperactor_mesh::shared_cell::SharedCellRef;
2223
use pyo3::exceptions::PyEOFError;
2324
use pyo3::exceptions::PyException;
25+
use pyo3::exceptions::PyNotImplementedError;
2426
use pyo3::exceptions::PyRuntimeError;
2527
use pyo3::exceptions::PyValueError;
2628
use pyo3::prelude::*;
29+
use pyo3::types::PyBytes;
30+
use serde::Deserialize;
31+
use serde::Serialize;
2732
use tokio::sync::Mutex;
2833

2934
use crate::actor::PythonActor;
@@ -101,6 +106,14 @@ impl PythonActorMesh {
101106
.borrow()
102107
.map_err(|_| PyRuntimeError::new_err("`PythonActorMesh` has already been stopped"))
103108
}
109+
110+
fn pickling_err(&self) -> PyErr {
111+
PyErr::new::<PyNotImplementedError, _>(
112+
"PythonActorMesh cannot be pickled. If applicable, use bind() \
113+
to get a PythonActorMeshRef, and use that instead."
114+
.to_string(),
115+
)
116+
}
104117
}
105118

106119
#[pymethods]
@@ -123,6 +136,11 @@ impl PythonActorMesh {
123136
Ok(())
124137
}
125138

139+
fn bind(&self) -> PyResult<PythonActorMeshRef> {
140+
let mesh = self.try_inner()?;
141+
Ok(PythonActorMeshRef { inner: mesh.bind() })
142+
}
143+
126144
fn get_supervision_event(&self) -> PyResult<Option<PyActorSupervisionEvent>> {
127145
let unhealthy_event = self
128146
.unhealthy_event
@@ -169,6 +187,64 @@ impl PythonActorMesh {
169187
fn shape(&self) -> PyResult<PyShape> {
170188
Ok(PyShape::from(self.try_inner()?.shape().clone()))
171189
}
190+
191+
// Override the pickling methods to provide a meaningful error message.
192+
fn __reduce__(&self) -> PyResult<()> {
193+
Err(self.pickling_err())
194+
}
195+
196+
fn __reduce_ex__(&self, _proto: u8) -> PyResult<()> {
197+
Err(self.pickling_err())
198+
}
199+
}
200+
201+
#[pyclass(
202+
frozen,
203+
name = "PythonActorMeshRef",
204+
module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh"
205+
)]
206+
#[derive(Debug, Serialize, Deserialize)]
207+
pub(super) struct PythonActorMeshRef {
208+
inner: ActorMeshRef<PythonActor>,
209+
}
210+
211+
#[pymethods]
212+
impl PythonActorMeshRef {
213+
fn cast(
214+
&self,
215+
client: &PyMailbox,
216+
selection: &PySelection,
217+
message: &PythonMessage,
218+
) -> PyResult<()> {
219+
self.inner
220+
.cast(&client.inner, selection.inner().clone(), message.clone())
221+
.map_err(|err| PyException::new_err(err.to_string()))?;
222+
Ok(())
223+
}
224+
225+
#[getter]
226+
fn shape(&self) -> PyShape {
227+
PyShape::from(self.inner.shape().clone())
228+
}
229+
230+
#[staticmethod]
231+
fn from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<Self> {
232+
bincode::deserialize(bytes.as_bytes())
233+
.map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))
234+
}
235+
236+
fn __reduce__<'py>(
237+
slf: &Bound<'py, Self>,
238+
) -> PyResult<(Bound<'py, PyAny>, (Bound<'py, PyBytes>,))> {
239+
let bytes = bincode::serialize(&*slf.borrow())
240+
.map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
241+
let py_bytes = PyBytes::new(slf.py(), &bytes);
242+
Ok((slf.as_any().getattr("from_bytes")?, (py_bytes,)))
243+
}
244+
245+
fn __repr__(&self) -> String {
246+
format!("{:?}", self)
247+
}
172248
}
173249

174250
impl Drop for PythonActorMesh {
@@ -379,6 +455,7 @@ impl From<ActorSupervisionEvent> for PyActorSupervisionEvent {
379455

380456
pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
381457
hyperactor_mod.add_class::<PythonActorMesh>()?;
458+
hyperactor_mod.add_class::<PythonActorMeshRef>()?;
382459
hyperactor_mod.add_class::<PyActorMeshMonitor>()?;
383460
hyperactor_mod.add_class::<MonitoredPythonPortReceiver>()?;
384461
hyperactor_mod.add_class::<MonitoredPythonOncePortReceiver>()?;

monarch_hyperactor/src/shape.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use crate::ndslice::PySlice;
2121
module = "monarch._rust_bindings.monarch_hyperactor.shape",
2222
frozen
2323
)]
24+
#[derive(Clone)]
2425
pub struct PyShape {
2526
pub(super) inner: Shape,
2627
}
@@ -112,6 +113,14 @@ impl PyShape {
112113
self.inner.slice().len()
113114
}
114115

116+
fn __eq__(&self, other: &Bound<'_, PyAny>) -> PyResult<bool> {
117+
if let Ok(other) = other.extract::<PyShape>() {
118+
Ok(self.inner == other.inner)
119+
} else {
120+
Ok(false)
121+
}
122+
}
123+
115124
#[staticmethod]
116125
fn unity() -> PyShape {
117126
Shape::unity().into()

python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
# pyre-strict
88

9+
from collections.abc import Mapping
910
from typing import AsyncIterator, final
1011

1112
from monarch._rust_bindings.monarch_hyperactor.actor import PythonMessage
@@ -18,8 +19,36 @@ from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
1819
from monarch._rust_bindings.monarch_hyperactor.selection import Selection
1920
from monarch._rust_bindings.monarch_hyperactor.shape import Shape
2021

22+
@final
23+
class PythonActorMeshRef:
24+
"""
25+
A reference to a remote actor mesh over which PythonMessages can be sent.
26+
"""
27+
28+
def cast(
29+
self, mailbox: Mailbox, selection: Selection, message: PythonMessage
30+
) -> None:
31+
"""Cast a message to the selected actors in the mesh."""
32+
...
33+
34+
@property
35+
def shape(self) -> Shape:
36+
"""
37+
The Shape object that describes how the rank of an actor
38+
retrieved with get corresponds to coordinates in the
39+
mesh.
40+
"""
41+
...
42+
2143
@final
2244
class PythonActorMesh:
45+
def bind(self) -> PythonActorMeshRef:
46+
"""
47+
Bind this actor mesh. The returned mesh ref can be used to reach the
48+
mesh remotely.
49+
"""
50+
...
51+
2352
def cast(self, selection: Selection, message: PythonMessage) -> None:
2453
"""
2554
Cast a message to the selected actors in the mesh.

python/monarch/_rust_bindings/monarch_hyperactor/bootstrap.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,6 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
# pyre-strict
8+
79
def bootstrap_main() -> None: ...

python/monarch/_rust_bindings/monarch_hyperactor/selection.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
# pyre-strict
8+
79
from typing import final
810

911
@final

python/monarch/_rust_bindings/monarch_hyperactor/shape.pyi

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
# pyre-strict
8+
79
import collections.abc
8-
from typing import Dict, final, Iterator, List, overload, Sequence
10+
from typing import Any, Dict, final, Iterator, List, overload, Sequence
911

1012
@final
1113
class Slice:
@@ -71,11 +73,11 @@ class Slice:
7173

7274
def __eq__(self, value: object) -> bool: ...
7375
def __hash__(self) -> int: ...
74-
def __getnewargs_ex__(self) -> tuple[tuple, dict]: ...
76+
def __getnewargs_ex__(self) -> tuple[tuple[Any], dict[Any, Any]]: ...
7577
@overload
7678
def __getitem__(self, i: int) -> int: ...
7779
@overload
78-
def __getitem__(self, i: slice) -> tuple[int, ...]: ...
80+
def __getitem__(self, i: slice[Any, Any, Any]) -> tuple[int, ...]: ...
7981
def __len__(self) -> int:
8082
"""Returns the complete size of the slice."""
8183
...
@@ -136,6 +138,7 @@ class Shape:
136138
"""
137139
...
138140
def __len__(self) -> int: ...
141+
def __eq__(self, value: object) -> bool: ...
139142
@staticmethod
140143
def unity() -> "Shape": ...
141144

python/monarch/_rust_bindings/monarch_hyperactor/telemetry.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
# pyre-strict
8+
79
import logging
810

911
def forward_to_tracing(record: logging.LogRecord) -> None:

0 commit comments

Comments
 (0)