Skip to content

Add PythonActorMeshRef #531

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion hyperactor_mesh/src/actor_mesh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ pub trait ActorMesh: Mesh<Id = ActorMeshId> {
self.name().to_string(),
),
self.shape().clone(),
self.proc_mesh().shape().clone(),
self.proc_mesh().comm_actor().clone(),
)
}
Expand Down
36 changes: 5 additions & 31 deletions hyperactor_mesh/src/reference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,6 @@ pub struct ActorMeshId(pub ProcMeshId, pub String);
pub struct ActorMeshRef<A: RemoteActor> {
pub(crate) mesh_id: ActorMeshId,
shape: Shape,
/// The shape of the underlying Proc Mesh.
proc_mesh_shape: Shape,
/// The reference to the comm actor of the underlying Proc Mesh.
comm_actor_ref: ActorRef<CommActor>,
phantom: PhantomData<A>,
Expand All @@ -90,13 +88,11 @@ impl<A: RemoteActor> ActorMeshRef<A> {
pub(crate) fn attest(
mesh_id: ActorMeshId,
shape: Shape,
proc_mesh_shape: Shape,
comm_actor_ref: ActorRef<CommActor>,
) -> Self {
Self {
mesh_id,
shape,
proc_mesh_shape,
comm_actor_ref,
phantom: PhantomData,
}
Expand Down Expand Up @@ -142,7 +138,6 @@ impl<A: RemoteActor> Clone for ActorMeshRef<A> {
Self {
mesh_id: self.mesh_id.clone(),
shape: self.shape.clone(),
proc_mesh_shape: self.proc_mesh_shape.clone(),
comm_actor_ref: self.comm_actor_ref.clone(),
phantom: PhantomData,
}
Expand All @@ -161,12 +156,11 @@ impl<A: RemoteActor> Eq for ActorMeshRef<A> {}
mod tests {
use async_trait::async_trait;
use hyperactor::Actor;
use hyperactor::Bind;
use hyperactor::Context;
use hyperactor::Handler;
use hyperactor::PortRef;
use hyperactor::message::Bind;
use hyperactor::message::Bindings;
use hyperactor::message::Unbind;
use hyperactor::Unbind;
use hyperactor_mesh_macros::sel;
use ndslice::shape;

Expand All @@ -183,11 +177,11 @@ mod tests {
shape! { replica = 4 }
}

#[derive(Debug, Serialize, Deserialize, Named, Clone)]
#[derive(Debug, Serialize, Deserialize, Named, Clone, Bind, Unbind)]
struct MeshPingPongMessage(
/*ttl:*/ u64,
ActorMeshRef<MeshPingPongActor>,
/*completed port:*/ PortRef<bool>,
/*completed port:*/ #[binding(include)] PortRef<bool>,
);

#[derive(Debug, Clone)]
Expand All @@ -203,7 +197,6 @@ mod tests {
struct MeshPingPongActorParams {
mesh_id: ActorMeshId,
shape: Shape,
proc_mesh_shape: Shape,
comm_actor_ref: ActorRef<CommActor>,
}

Expand All @@ -213,12 +206,7 @@ mod tests {

async fn new(params: Self::Params) -> Result<Self, anyhow::Error> {
Ok(Self {
mesh_ref: ActorMeshRef::attest(
params.mesh_id,
params.shape,
params.proc_mesh_shape,
params.comm_actor_ref,
),
mesh_ref: ActorMeshRef::attest(params.mesh_id, params.shape, params.comm_actor_ref),
})
}
}
Expand All @@ -240,18 +228,6 @@ mod tests {
}
}

// Hand-written `Unbind` for the ping-pong test message: only the third tuple
// field (the "completed port" `PortRef<bool>`) is unbound; the ttl and the
// mesh reference are not touched here.
impl Unbind for MeshPingPongMessage {
fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
// Delegate to the port reference's own `unbind`.
self.2.unbind(bindings)
}
}

// Hand-written `Bind` for the ping-pong test message: only the third tuple
// field (the "completed port" `PortRef<bool>`) is bound; the ttl and the mesh
// reference are not touched here.
impl Bind for MeshPingPongMessage {
fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
// Delegate to the port reference's own `bind`.
self.2.bind(bindings)
}
}

#[tokio::test]
async fn test_inter_mesh_ping_pong() {
let alloc_ping = LocalAllocator
Expand All @@ -278,7 +254,6 @@ mod tests {
"ping".to_string(),
),
shape: ping_proc_mesh.shape().clone(),
proc_mesh_shape: ping_proc_mesh.shape().clone(),
comm_actor_ref: ping_proc_mesh.comm_actor().clone(),
},
)
Expand All @@ -296,7 +271,6 @@ mod tests {
"pong".to_string(),
),
shape: pong_proc_mesh.shape().clone(),
proc_mesh_shape: pong_proc_mesh.shape().clone(),
comm_actor_ref: pong_proc_mesh.comm_actor().clone(),
},
)
Expand Down
77 changes: 77 additions & 0 deletions monarch_hyperactor/src/actor_mesh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,18 @@ use hyperactor_mesh::Mesh;
use hyperactor_mesh::RootActorMesh;
use hyperactor_mesh::actor_mesh::ActorMesh;
use hyperactor_mesh::actor_mesh::ActorSupervisionEvents;
use hyperactor_mesh::reference::ActorMeshRef;
use hyperactor_mesh::shared_cell::SharedCell;
use hyperactor_mesh::shared_cell::SharedCellRef;
use pyo3::exceptions::PyEOFError;
use pyo3::exceptions::PyException;
use pyo3::exceptions::PyNotImplementedError;
use pyo3::exceptions::PyRuntimeError;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyBytes;
use serde::Deserialize;
use serde::Serialize;
use tokio::sync::Mutex;

use crate::actor::PythonActor;
Expand Down Expand Up @@ -101,6 +106,14 @@ impl PythonActorMesh {
.borrow()
.map_err(|_| PyRuntimeError::new_err("`PythonActorMesh` has already been stopped"))
}

fn pickling_err(&self) -> PyErr {
PyErr::new::<PyNotImplementedError, _>(
"PythonActorMesh cannot be pickled. If applicable, use bind() \
to get a PythonActorMeshRef, and use that instead."
.to_string(),
)
}
}

#[pymethods]
Expand All @@ -123,6 +136,11 @@ impl PythonActorMesh {
Ok(())
}

// Wraps the result of binding the inner mesh in a `PythonActorMeshRef`.
// Any error returned by `try_inner` is propagated unchanged.
fn bind(&self) -> PyResult<PythonActorMeshRef> {
    let inner = self.try_inner()?.bind();
    Ok(PythonActorMeshRef { inner })
}

fn get_supervision_event(&self) -> PyResult<Option<PyActorSupervisionEvent>> {
let unhealthy_event = self
.unhealthy_event
Expand Down Expand Up @@ -169,6 +187,64 @@ impl PythonActorMesh {
// Returns a clone of the inner mesh's shape converted to `PyShape`.
// Any error returned by `try_inner` is propagated unchanged.
fn shape(&self) -> PyResult<PyShape> {
    let mesh = self.try_inner()?;
    let shape = mesh.shape().clone();
    Ok(PyShape::from(shape))
}

// Pickling hooks are overridden so that an attempt to pickle this type
// produces a meaningful error message. This one always returns the error
// built by `pickling_err`.
fn __reduce__(&self) -> PyResult<()> {
    let err = self.pickling_err();
    Err(err)
}

// Protocol-aware pickling hook. The protocol argument is ignored and the
// call always returns the error built by `pickling_err`.
fn __reduce_ex__(&self, _proto: u8) -> PyResult<()> {
    let err = self.pickling_err();
    Err(err)
}
}

// Wrapper exposing an `ActorMeshRef<PythonActor>` to Python as the frozen
// class `PythonActorMeshRef`. It derives serde `Serialize`/`Deserialize`
// in addition to `Debug`.
// (Plain `//` comments are used so the Python-visible class docstring is
// left unchanged.)
#[pyclass(
frozen,
name = "PythonActorMeshRef",
module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh"
)]
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct PythonActorMeshRef {
// The wrapped mesh reference.
inner: ActorMeshRef<PythonActor>,
}

#[pymethods]
impl PythonActorMeshRef {
fn cast(
&self,
client: &PyMailbox,
selection: &PySelection,
message: &PythonMessage,
) -> PyResult<()> {
self.inner
.cast(&client.inner, selection.inner().clone(), message.clone())
.map_err(|err| PyException::new_err(err.to_string()))?;
Ok(())
}

#[getter]
fn shape(&self) -> PyShape {
PyShape::from(self.inner.shape().clone())
}

#[staticmethod]
fn from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<Self> {
bincode::deserialize(bytes.as_bytes())
.map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))
}

fn __reduce__<'py>(
slf: &Bound<'py, Self>,
) -> PyResult<(Bound<'py, PyAny>, (Bound<'py, PyBytes>,))> {
let bytes = bincode::serialize(&*slf.borrow())
.map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
let py_bytes = PyBytes::new(slf.py(), &bytes);
Ok((slf.as_any().getattr("from_bytes")?, (py_bytes,)))
}

fn __repr__(&self) -> String {
format!("{:?}", self)
}
}

impl Drop for PythonActorMesh {
Expand Down Expand Up @@ -379,6 +455,7 @@ impl From<ActorSupervisionEvent> for PyActorSupervisionEvent {

pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
hyperactor_mod.add_class::<PythonActorMesh>()?;
hyperactor_mod.add_class::<PythonActorMeshRef>()?;
hyperactor_mod.add_class::<PyActorMeshMonitor>()?;
hyperactor_mod.add_class::<MonitoredPythonPortReceiver>()?;
hyperactor_mod.add_class::<MonitoredPythonOncePortReceiver>()?;
Expand Down
9 changes: 9 additions & 0 deletions monarch_hyperactor/src/shape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use crate::ndslice::PySlice;
module = "monarch._rust_bindings.monarch_hyperactor.shape",
frozen
)]
#[derive(Clone)]
pub struct PyShape {
pub(super) inner: Shape,
}
Expand Down Expand Up @@ -112,6 +113,14 @@ impl PyShape {
self.inner.slice().len()
}

// Equality against an arbitrary Python object: if `other` extracts as a
// `PyShape`, compare the inner shapes; otherwise the result is `false`.
// An extraction failure is never raised.
fn __eq__(&self, other: &Bound<'_, PyAny>) -> PyResult<bool> {
    match other.extract::<PyShape>() {
        Ok(rhs) => Ok(self.inner == rhs.inner),
        Err(_) => Ok(false),
    }
}

#[staticmethod]
fn unity() -> PyShape {
Shape::unity().into()
Expand Down
29 changes: 29 additions & 0 deletions python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

# pyre-strict

from collections.abc import Mapping
from typing import AsyncIterator, final

from monarch._rust_bindings.monarch_hyperactor.actor import PythonMessage
Expand All @@ -18,8 +19,36 @@ from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
from monarch._rust_bindings.monarch_hyperactor.selection import Selection
from monarch._rust_bindings.monarch_hyperactor.shape import Shape

@final
class PythonActorMeshRef:
"""
A reference to a remote actor mesh over which PythonMessages can be sent.
"""

def cast(
self, mailbox: Mailbox, selection: Selection, message: PythonMessage
) -> None:
"""Cast a message to the selected actors in the mesh."""
...

@property
def shape(self) -> Shape:
"""
The Shape object that describes how the rank of an actor
retrieved with get corresponds to coordinates in the
mesh.
"""
...

@final
class PythonActorMesh:
def bind(self) -> PythonActorMeshRef:
"""
Bind this actor mesh. The returned mesh ref can be used to reach the
mesh remotely.
"""
...

def cast(self, selection: Selection, message: PythonMessage) -> None:
"""
Cast a message to the selected actors in the mesh.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

def bootstrap_main() -> None: ...
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import final

@final
Expand Down
9 changes: 6 additions & 3 deletions python/monarch/_rust_bindings/monarch_hyperactor/shape.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import collections.abc
from typing import Dict, final, Iterator, List, overload, Sequence
from typing import Any, Dict, final, Iterator, List, overload, Sequence

@final
class Slice:
Expand Down Expand Up @@ -71,11 +73,11 @@ class Slice:

def __eq__(self, value: object) -> bool: ...
def __hash__(self) -> int: ...
def __getnewargs_ex__(self) -> tuple[tuple, dict]: ...
def __getnewargs_ex__(self) -> tuple[tuple[Any], dict[Any, Any]]: ...
@overload
def __getitem__(self, i: int) -> int: ...
@overload
def __getitem__(self, i: slice) -> tuple[int, ...]: ...
def __getitem__(self, i: slice[Any, Any, Any]) -> tuple[int, ...]: ...
def __len__(self) -> int:
"""Returns the complete size of the slice."""
...
Expand Down Expand Up @@ -136,6 +138,7 @@ class Shape:
"""
...
def __len__(self) -> int: ...
def __eq__(self, value: object) -> bool: ...
@staticmethod
def unity() -> "Shape": ...

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import logging

def forward_to_tracing(record: logging.LogRecord) -> None:
Expand Down
Loading
Loading