Skip to content

[17/n] Pass monarch tensors to actor endpoints, part 2, actor messages sent in order #539

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: gh/zdevito/40/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 2 additions & 34 deletions monarch_extension/src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,35 +176,6 @@ impl<'a> MessageParser<'a> {
fn parseWorkerMessageList(&self, name: &str) -> PyResult<Vec<WorkerMessage>> {
self.attr(name)?.try_iter()?.map(|x| convert(x?)).collect()
}
#[allow(non_snake_case)]
fn parseLocalStateList(&self, name: &str) -> PyResult<Vec<worker::LocalState>> {
let local_state_list = self.attr(name)?;
let mut result = Vec::new();

// Get the PyMailbox class for type checking
let mailbox_class = self
.current
.py()
.import("monarch._rust_bindings.monarch_hyperactor.mailbox")?
.getattr("Mailbox")?;

for item in local_state_list.try_iter()? {
let item = item?;
// Check if it's a Ref by trying to extract it
if let Ok(ref_val) = create_ref(item.clone()) {
result.push(worker::LocalState::Ref(ref_val));
} else if item.is_instance(&mailbox_class)? {
// It's a PyMailbox instance
result.push(worker::LocalState::Mailbox);
} else {
return Err(PyValueError::new_err(format!(
"Expected Ref or Mailbox in local_state, got: {}",
item.get_type().name()?
)));
}
}
Ok(result)
}
fn parse_error_reason(&self, name: &str) -> PyResult<Option<(Option<ActorId>, String)>> {
let err = self.attr(name)?;
if err.is_none() {
Expand Down Expand Up @@ -465,11 +436,8 @@ fn create_map(py: Python) -> HashMap<u64, FnType> {
Ok(WorkerMessage::SendResultOfActorCall(
worker::ActorCallParams {
seq: p.parseSeq("seq")?,
actor: p.parse("actor")?,
index: p.parse("actor_index")?,
method: p.parse("method")?,
args_kwargs_tuple: p.parse("args_kwargs_tuple")?,
local_state: p.parseLocalStateList("local_state")?,
broker_id: p.parse("broker_id")?,
local_state: p.parseRefList("local_state")?,
mutates: p.parseRefList("mutates")?,
stream: p.parseStreamRef("stream")?,
},
Expand Down
16 changes: 16 additions & 0 deletions monarch_extension/src/mesh_controller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ use hyperactor_mesh::shared_cell::SharedCell;
use hyperactor_mesh::shared_cell::SharedCellRef;
use monarch_hyperactor::actor::PythonMessage;
use monarch_hyperactor::actor::PythonMessageKind;
use monarch_hyperactor::local_state_broker::LocalStateBrokerActor;
use monarch_hyperactor::mailbox::PyPortId;
use monarch_hyperactor::ndslice::PySlice;
use monarch_hyperactor::proc_mesh::PyProcMesh;
Expand Down Expand Up @@ -78,6 +79,7 @@ pub(crate) fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult
struct _Controller {
controller_handle: Arc<Mutex<ActorHandle<MeshControllerActor>>>,
all_ranks: Slice,
broker_id: (String, usize),
}

static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
Expand Down Expand Up @@ -123,9 +125,17 @@ impl _Controller {
Ok(Self {
controller_handle,
all_ranks,
// note that 0 is the _pid_ of the broker, which will be 0 for
// top-level spawned actors.
broker_id: (format!("tensor_engine_brokers_{}", id), 0),
})
}

#[getter]
fn broker_id(&self) -> (String, usize) {
self.broker_id.clone()
}

#[pyo3(signature = (seq, defs, uses, response_port, tracebacks))]
fn node<'py>(
&mut self,
Expand Down Expand Up @@ -626,6 +636,7 @@ enum ClientToControllerMessage {
struct MeshControllerActor {
proc_mesh: SharedCell<TrackedProcMesh>,
workers: Option<SharedCell<RootActorMesh<'static, WorkerActor>>>,
brokers: Option<SharedCell<RootActorMesh<'static, LocalStateBrokerActor>>>,
history: History,
id: usize,
debugger_active: Option<ActorRef<DebuggerActor>>,
Expand Down Expand Up @@ -730,6 +741,7 @@ impl Actor for MeshControllerActor {
Ok(MeshControllerActor {
proc_mesh: proc_mesh.clone(),
workers: None,
brokers: None,
history: History::new(world_size),
id,
debugger_active: None,
Expand Down Expand Up @@ -758,6 +770,10 @@ impl Actor for MeshControllerActor {
.cast(selection::dsl::true_(), AssignRankMessage::AssignRank())?;

self.workers = Some(workers);
let brokers = proc_mesh
.spawn(&format!("tensor_engine_brokers_{}", self.id), &())
.await?;
self.brokers = Some(brokers);
Ok(())
}
}
Expand Down
125 changes: 80 additions & 45 deletions monarch_hyperactor/src/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use hyperactor::Context;
use hyperactor::HandleClient;
use hyperactor::Handler;
use hyperactor::Named;
use hyperactor::cap::CanSend;
use hyperactor::forward;
use hyperactor::message::Bind;
use hyperactor::message::Bindings;
Expand All @@ -43,6 +44,8 @@ use tokio::sync::Mutex;
use tokio::sync::oneshot;

use crate::config::SHARED_ASYNCIO_RUNTIME;
use crate::local_state_broker::BrokerId;
use crate::local_state_broker::LocalStateBrokerMessage;
use crate::mailbox::EitherPortRef;
use crate::mailbox::PyMailbox;
use crate::proc::InstanceWrapper;
Expand Down Expand Up @@ -171,6 +174,13 @@ impl PickledMessageClientActor {
}
}

#[pyclass(module = "monarch._rust_bindings.monarch_hyperactor.actor")]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum UnflattenArg {
Mailbox,
PyObject,
}

#[pyclass(module = "monarch._rust_bindings.monarch_hyperactor.actor")]
#[derive(Clone, Debug, Serialize, Deserialize, Named, PartialEq)]
pub enum PythonMessageKind {
Expand All @@ -185,6 +195,14 @@ pub enum PythonMessageKind {
rank: Option<usize>,
},
Uninit {},
CallMethodIndirect {
name: String,
local_state_broker: (String, usize),
id: usize,
// specify whether the argument to unflatten the local mailbox,
// or the next argument of the local state.
unflatten_args: Vec<UnflattenArg>,
},
}

impl Default for PythonMessageKind {
Expand All @@ -193,6 +211,11 @@ impl Default for PythonMessageKind {
}
}

fn mailbox<'py, T: Actor>(py: Python<'py>, cx: &Context<'_, T>) -> Bound<'py, PyAny> {
let mailbox: PyMailbox = cx.mailbox_for_py().clone().into();
mailbox.into_bound_py_any(py).unwrap()
}

#[pyclass(frozen, module = "monarch._rust_bindings.monarch_hyperactor.actor")]
#[derive(Clone, Serialize, Deserialize, Named, PartialEq, Default)]
pub struct PythonMessage {
Expand All @@ -219,6 +242,56 @@ impl PythonMessage {
_ => panic!("PythonMessage is not a response but {:?}", self),
}
}

pub async fn resolve_indirect_call<T: Actor>(
mut self,
cx: &Context<'_, T>,
) -> anyhow::Result<(Self, PyObject)> {
let local_state: PyObject;
match self.kind {
PythonMessageKind::CallMethodIndirect {
name,
local_state_broker,
id,
unflatten_args,
} => {
let broker = BrokerId::new(local_state_broker).resolve(cx).unwrap();
let (send, recv) = cx.open_once_port();
broker.send(LocalStateBrokerMessage::Get(id, send))?;
let state = recv.recv().await?;
let mut state_it = state.state.into_iter();
local_state = Python::with_gil(|py| {
let mailbox = mailbox(py, cx);
PyList::new(
py,
unflatten_args.into_iter().map(|x| -> Bound<'_, PyAny> {
match x {
UnflattenArg::Mailbox => mailbox.clone(),
UnflattenArg::PyObject => state_it.next().unwrap().into_bound(py),
}
}),
)
.unwrap()
.into()
});
self.kind = PythonMessageKind::CallMethod {
name,
response_port: Some(state.response_port),
}
}
_ => {
local_state = Python::with_gil(|py| {
let mailbox = mailbox(py, cx);
py.import("itertools")
.unwrap()
.call_method1("repeat", (mailbox.clone(),))
.unwrap()
.unbind()
});
}
};
Ok((self, local_state))
}
}

impl std::fmt::Debug for PythonMessage {
Expand Down Expand Up @@ -408,30 +481,16 @@ impl PanicFlag {
}

#[async_trait]
impl Handler<LocalPythonMessage> for PythonActor {
async fn handle(
&mut self,
cx: &Context<Self>,
message: LocalPythonMessage,
) -> anyhow::Result<()> {
let mailbox: PyMailbox = PyMailbox {
inner: cx.mailbox_for_py().clone(),
};
impl Handler<PythonMessage> for PythonActor {
async fn handle(&mut self, cx: &Context<Self>, message: PythonMessage) -> anyhow::Result<()> {
let (message, local_state) = message.resolve_indirect_call(cx).await?;

// Create a channel for signaling panics in async endpoints.
// See [Panics in async endpoints].
let (sender, receiver) = oneshot::channel();

let future = Python::with_gil(|py| -> Result<_, SerializablePyErr> {
let mailbox = mailbox.into_bound_py_any(py).unwrap();
let local_state: Option<Vec<Bound<'_, PyAny>>> = message.local_state.map(|state| {
state
.into_iter()
.map(|e| match e {
LocalState::Mailbox => mailbox.clone(),
LocalState::PyObject(obj) => obj.into_bound_py_any(py).unwrap(),
})
.collect()
});
let mailbox = mailbox(py, cx);
let (rank, shape) = cx.cast_info();
let awaitable = self.actor.call_method(
py,
Expand All @@ -440,7 +499,7 @@ impl Handler<LocalPythonMessage> for PythonActor {
mailbox,
rank,
PyShape::from(shape),
message.message,
message,
PanicFlag {
sender: Some(sender),
},
Expand All @@ -463,20 +522,6 @@ impl Handler<LocalPythonMessage> for PythonActor {
}
}

#[async_trait]
impl Handler<PythonMessage> for PythonActor {
async fn handle(&mut self, cx: &Context<Self>, message: PythonMessage) -> anyhow::Result<()> {
self.handle(
cx,
LocalPythonMessage {
message,
local_state: None,
},
)
.await
}
}

/// Helper struct to make a Python future passable in an actor message.
///
/// Also so that we don't have to write this massive type signature everywhere
Expand Down Expand Up @@ -576,24 +621,14 @@ impl AsyncEndpointInvocationHandler for AsyncEndpointTask {
Ok(())
}
}
#[derive(Debug)]
pub enum LocalState {
Mailbox,
PyObject(PyObject),
}

#[derive(Debug)]
pub struct LocalPythonMessage {
pub message: PythonMessage,
pub local_state: Option<Vec<LocalState>>,
}

pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
hyperactor_mod.add_class::<PickledMessage>()?;
hyperactor_mod.add_class::<PickledMessageClientActor>()?;
hyperactor_mod.add_class::<PythonActorHandle>()?;
hyperactor_mod.add_class::<PythonMessage>()?;
hyperactor_mod.add_class::<PythonMessageKind>()?;
hyperactor_mod.add_class::<UnflattenArg>()?;
hyperactor_mod.add_class::<PanicFlag>()?;
Ok(())
}
Expand Down
1 change: 1 addition & 0 deletions monarch_hyperactor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub mod alloc;
pub mod bootstrap;
pub mod channel;
pub mod config;
pub mod local_state_broker;
pub mod mailbox;
pub mod ndslice;
pub mod proc;
Expand Down
Loading
Loading