Skip to content

Commit cfb1946

Browse files
pzhan9facebook-github-bot
authored andcommitted
Add PythonActorMeshRef
Differential Revision: D78284705
1 parent 2d369e1 commit cfb1946

File tree

10 files changed

+244
-35
lines changed

10 files changed

+244
-35
lines changed

hyperactor_mesh/src/actor_mesh.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,6 @@ pub trait ActorMesh: Mesh<Id = ActorMeshId> {
145145
self.name().to_string(),
146146
),
147147
self.shape().clone(),
148-
self.proc_mesh().shape().clone(),
149148
self.proc_mesh().comm_actor().clone(),
150149
)
151150
}

hyperactor_mesh/src/reference.rs

Lines changed: 5 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,6 @@ pub struct ActorMeshId(pub ProcMeshId, pub String);
7575
pub struct ActorMeshRef<A: RemoteActor> {
7676
pub(crate) mesh_id: ActorMeshId,
7777
shape: Shape,
78-
/// The shape of the underlying Proc Mesh.
79-
proc_mesh_shape: Shape,
8078
/// The reference to the comm actor of the underlying Proc Mesh.
8179
comm_actor_ref: ActorRef<CommActor>,
8280
phantom: PhantomData<A>,
@@ -90,13 +88,11 @@ impl<A: RemoteActor> ActorMeshRef<A> {
9088
pub(crate) fn attest(
9189
mesh_id: ActorMeshId,
9290
shape: Shape,
93-
proc_mesh_shape: Shape,
9491
comm_actor_ref: ActorRef<CommActor>,
9592
) -> Self {
9693
Self {
9794
mesh_id,
9895
shape,
99-
proc_mesh_shape,
10096
comm_actor_ref,
10197
phantom: PhantomData,
10298
}
@@ -142,7 +138,6 @@ impl<A: RemoteActor> Clone for ActorMeshRef<A> {
142138
Self {
143139
mesh_id: self.mesh_id.clone(),
144140
shape: self.shape.clone(),
145-
proc_mesh_shape: self.proc_mesh_shape.clone(),
146141
comm_actor_ref: self.comm_actor_ref.clone(),
147142
phantom: PhantomData,
148143
}
@@ -161,12 +156,11 @@ impl<A: RemoteActor> Eq for ActorMeshRef<A> {}
161156
mod tests {
162157
use async_trait::async_trait;
163158
use hyperactor::Actor;
159+
use hyperactor::Bind;
164160
use hyperactor::Context;
165161
use hyperactor::Handler;
166162
use hyperactor::PortRef;
167-
use hyperactor::message::Bind;
168-
use hyperactor::message::Bindings;
169-
use hyperactor::message::Unbind;
163+
use hyperactor::Unbind;
170164
use hyperactor_mesh_macros::sel;
171165
use ndslice::shape;
172166

@@ -183,11 +177,11 @@ mod tests {
183177
shape! { replica = 4 }
184178
}
185179

186-
#[derive(Debug, Serialize, Deserialize, Named, Clone)]
180+
#[derive(Debug, Serialize, Deserialize, Named, Clone, Bind, Unbind)]
187181
struct MeshPingPongMessage(
188182
/*ttl:*/ u64,
189183
ActorMeshRef<MeshPingPongActor>,
190-
/*completed port:*/ PortRef<bool>,
184+
/*completed port:*/ #[binding(include)] PortRef<bool>,
191185
);
192186

193187
#[derive(Debug, Clone)]
@@ -203,7 +197,6 @@ mod tests {
203197
struct MeshPingPongActorParams {
204198
mesh_id: ActorMeshId,
205199
shape: Shape,
206-
proc_mesh_shape: Shape,
207200
comm_actor_ref: ActorRef<CommActor>,
208201
}
209202

@@ -213,12 +206,7 @@ mod tests {
213206

214207
async fn new(params: Self::Params) -> Result<Self, anyhow::Error> {
215208
Ok(Self {
216-
mesh_ref: ActorMeshRef::attest(
217-
params.mesh_id,
218-
params.shape,
219-
params.proc_mesh_shape,
220-
params.comm_actor_ref,
221-
),
209+
mesh_ref: ActorMeshRef::attest(params.mesh_id, params.shape, params.comm_actor_ref),
222210
})
223211
}
224212
}
@@ -240,18 +228,6 @@ mod tests {
240228
}
241229
}
242230

243-
impl Unbind for MeshPingPongMessage {
244-
fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
245-
self.2.unbind(bindings)
246-
}
247-
}
248-
249-
impl Bind for MeshPingPongMessage {
250-
fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
251-
self.2.bind(bindings)
252-
}
253-
}
254-
255231
#[tokio::test]
256232
async fn test_inter_mesh_ping_pong() {
257233
let alloc_ping = LocalAllocator
@@ -278,7 +254,6 @@ mod tests {
278254
"ping".to_string(),
279255
),
280256
shape: ping_proc_mesh.shape().clone(),
281-
proc_mesh_shape: ping_proc_mesh.shape().clone(),
282257
comm_actor_ref: ping_proc_mesh.comm_actor().clone(),
283258
},
284259
)
@@ -296,7 +271,6 @@ mod tests {
296271
"pong".to_string(),
297272
),
298273
shape: pong_proc_mesh.shape().clone(),
299-
proc_mesh_shape: pong_proc_mesh.shape().clone(),
300274
comm_actor_ref: pong_proc_mesh.comm_actor().clone(),
301275
},
302276
)

monarch_hyperactor/src/actor_mesh.rs

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,17 @@ use hyperactor::ActorRef;
1010
use hyperactor_mesh::Mesh;
1111
use hyperactor_mesh::RootActorMesh;
1212
use hyperactor_mesh::actor_mesh::ActorMesh;
13+
use hyperactor_mesh::reference::ActorMeshRef;
1314
use hyperactor_mesh::shared_cell::SharedCell;
1415
use hyperactor_mesh::shared_cell::SharedCellRef;
1516
use pyo3::exceptions::PyException;
17+
use pyo3::exceptions::PyNotImplementedError;
1618
use pyo3::exceptions::PyRuntimeError;
19+
use pyo3::exceptions::PyValueError;
1720
use pyo3::prelude::*;
21+
use pyo3::types::PyBytes;
22+
use serde::Deserialize;
23+
use serde::Serialize;
1824

1925
use crate::actor::PythonActor;
2026
use crate::actor::PythonMessage;
@@ -40,6 +46,14 @@ impl PythonActorMesh {
4046
.borrow()
4147
.map_err(|_| PyRuntimeError::new_err("`PythonActorMesh` has already been stopped"))
4248
}
49+
50+
fn pickling_err(&self) -> PyErr {
51+
PyErr::new::<PyNotImplementedError, _>(
52+
"PythonActorMesh cannot be pickled. If applicable, use bind() \
53+
to get a PythonActorMeshRef, and use that instead."
54+
.to_string(),
55+
)
56+
}
4357
}
4458

4559
#[pymethods]
@@ -51,6 +65,11 @@ impl PythonActorMesh {
5165
Ok(())
5266
}
5367

68+
fn bind(&self) -> PyResult<PythonActorMeshRef> {
69+
let mesh = self.try_inner()?;
70+
Ok(PythonActorMeshRef { inner: mesh.bind() })
71+
}
72+
5473
// Consider defining a "PythonActorRef", which carries specifically
5574
// a reference to python message actors.
5675
fn get(&self, rank: usize) -> PyResult<Option<PyActorId>> {
@@ -70,8 +89,68 @@ impl PythonActorMesh {
7089
fn shape(&self) -> PyResult<PyShape> {
7190
Ok(PyShape::from(self.try_inner()?.shape().clone()))
7291
}
92+
93+
// Override the pickling methods to provide a meaningful error message.
94+
fn __reduce__(&self) -> PyResult<()> {
95+
Err(self.pickling_err())
96+
}
97+
98+
fn __reduce_ex__(&self, _proto: u8) -> PyResult<()> {
99+
Err(self.pickling_err())
100+
}
101+
}
102+
103+
#[pyclass(
104+
frozen,
105+
name = "PythonActorMeshRef",
106+
module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh"
107+
)]
108+
#[derive(Debug, Serialize, Deserialize)]
109+
pub(super) struct PythonActorMeshRef {
110+
inner: ActorMeshRef<PythonActor>,
111+
}
112+
113+
#[pymethods]
114+
impl PythonActorMeshRef {
115+
fn cast(
116+
&self,
117+
client: &PyMailbox,
118+
selection: &PySelection,
119+
message: &PythonMessage,
120+
) -> PyResult<()> {
121+
self.inner
122+
.cast(&client.inner, selection.inner().clone(), message.clone())
123+
.map_err(|err| PyException::new_err(err.to_string()))?;
124+
Ok(())
125+
}
126+
127+
#[getter]
128+
fn shape(&self) -> PyShape {
129+
PyShape::from(self.inner.shape().clone())
130+
}
131+
132+
#[staticmethod]
133+
fn from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<Self> {
134+
bincode::deserialize(bytes.as_bytes())
135+
.map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))
136+
}
137+
138+
fn __reduce__<'py>(
139+
slf: &Bound<'py, Self>,
140+
) -> PyResult<(Bound<'py, PyAny>, (Bound<'py, PyBytes>,))> {
141+
let bytes = bincode::serialize(&*slf.borrow())
142+
.map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
143+
let py_bytes = PyBytes::new(slf.py(), &bytes);
144+
Ok((slf.as_any().getattr("from_bytes")?, (py_bytes,)))
145+
}
146+
147+
fn __repr__(&self) -> String {
148+
format!("{:?}", self)
149+
}
73150
}
151+
74152
pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
75153
hyperactor_mod.add_class::<PythonActorMesh>()?;
154+
hyperactor_mod.add_class::<PythonActorMeshRef>()?;
76155
Ok(())
77156
}

monarch_hyperactor/src/shape.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use crate::ndslice::PySlice;
2121
module = "monarch._rust_bindings.monarch_hyperactor.shape",
2222
frozen
2323
)]
24+
#[derive(Clone)]
2425
pub struct PyShape {
2526
pub(super) inner: Shape,
2627
}
@@ -112,6 +113,14 @@ impl PyShape {
112113
self.inner.slice().len()
113114
}
114115

116+
fn __eq__(&self, other: &Bound<'_, PyAny>) -> PyResult<bool> {
117+
if let Ok(other) = other.extract::<PyShape>() {
118+
Ok(self.inner == other.inner)
119+
} else {
120+
Ok(false)
121+
}
122+
}
123+
115124
#[staticmethod]
116125
fn unity() -> PyShape {
117126
Shape::unity().into()

python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
# pyre-strict
88

9+
from collections.abc import Mapping
910
from typing import final
1011

1112
from monarch._rust_bindings.monarch_hyperactor.actor import PythonMessage
@@ -14,8 +15,36 @@ from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
1415
from monarch._rust_bindings.monarch_hyperactor.selection import Selection
1516
from monarch._rust_bindings.monarch_hyperactor.shape import Shape
1617

18+
@final
19+
class PythonActorMeshRef:
20+
"""
21+
A reference to a remote actor mesh over which PythonMessages can be sent.
22+
"""
23+
24+
def cast(
25+
self, mailbox: Mailbox, selection: Selection, message: PythonMessage
26+
) -> None:
27+
"""Cast a message to the selected actors in the mesh."""
28+
...
29+
30+
@property
31+
def shape(self) -> Shape:
32+
"""
33+
The Shape object that describes how the rank of an actor
34+
retrieved with get corresponds to coordinates in the
35+
mesh.
36+
"""
37+
...
38+
1739
@final
1840
class PythonActorMesh:
41+
def bind(self) -> PythonActorMeshRef:
42+
"""
43+
Bind this actor mesh. The returned mesh ref can be used to reach the
44+
mesh remotely.
45+
"""
46+
...
47+
1948
def cast(self, selection: Selection, message: PythonMessage) -> None:
2049
"""
2150
Cast a message to the selected actors in the mesh.

python/monarch/_rust_bindings/monarch_hyperactor/bootstrap.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,6 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
# pyre-strict
8+
79
def bootstrap_main() -> None: ...

python/monarch/_rust_bindings/monarch_hyperactor/selection.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
# pyre-strict
8+
79
from typing import final
810

911
@final

python/monarch/_rust_bindings/monarch_hyperactor/shape.pyi

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
# pyre-strict
8+
79
import collections.abc
8-
from typing import Dict, final, Iterator, List, overload, Sequence
10+
from typing import Any, Dict, final, Iterator, List, overload, Sequence
911

1012
@final
1113
class Slice:
@@ -71,11 +73,11 @@ class Slice:
7173

7274
def __eq__(self, value: object) -> bool: ...
7375
def __hash__(self) -> int: ...
74-
def __getnewargs_ex__(self) -> tuple[tuple, dict]: ...
76+
def __getnewargs_ex__(self) -> tuple[tuple[Any], dict[Any, Any]]: ...
7577
@overload
7678
def __getitem__(self, i: int) -> int: ...
7779
@overload
78-
def __getitem__(self, i: slice) -> tuple[int, ...]: ...
80+
def __getitem__(self, i: slice[Any, Any, Any]) -> tuple[int, ...]: ...
7981
def __len__(self) -> int:
8082
"""Returns the complete size of the slice."""
8183
...
@@ -136,6 +138,7 @@ class Shape:
136138
"""
137139
...
138140
def __len__(self) -> int: ...
141+
def __eq__(self, value: object) -> bool: ...
139142
@staticmethod
140143
def unity() -> "Shape": ...
141144

python/monarch/_rust_bindings/monarch_hyperactor/telemetry.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
# pyre-strict
8+
79
import logging
810

911
def forward_to_tracing(record: logging.LogRecord) -> None:

0 commit comments

Comments
 (0)