Skip to content

[15/n] Pass monarch tensors to actor endpoints, part 1 #518

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: gh/zdevito/38/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions monarch_extension/src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,35 @@ impl<'a> MessageParser<'a> {
/// Reads the Python attribute `name`, iterates over it, and converts every
/// element with `convert`.
///
/// Stops at the first element that fails to iterate or convert and returns
/// that error.
fn parseWorkerMessageList(&self, name: &str) -> PyResult<Vec<WorkerMessage>> {
    let entries = self.attr(name)?.try_iter()?;
    entries.map(|entry| convert(entry?)).collect()
}
/// Parses the Python attribute `name` as a list of local-state entries.
///
/// Elements are classified in order:
/// * anything `create_ref` accepts becomes `worker::LocalState::Ref`;
/// * otherwise, an instance of the Python `Mailbox` class becomes
///   `worker::LocalState::Mailbox`.
///
/// # Errors
///
/// Returns a `PyValueError` naming the offending Python type for any other
/// element. Python errors raised by the attribute lookup, the module import,
/// iteration, or the `isinstance` check are propagated unchanged.
#[allow(non_snake_case)]
fn parseLocalStateList(&self, name: &str) -> PyResult<Vec<worker::LocalState>> {
    let entries = self.attr(name)?;

    // Look up the Python `Mailbox` class once so every element can be
    // checked against it.
    let mailbox_module = self
        .current
        .py()
        .import("monarch._rust_bindings.monarch_hyperactor.mailbox")?;
    let mailbox_class = mailbox_module.getattr("Mailbox")?;

    entries
        .try_iter()?
        .map(|entry| -> PyResult<worker::LocalState> {
            let entry = entry?;
            // A failed conversion only means "not a Ref"; the conversion
            // error itself is intentionally dropped.
            if let Ok(reference) = create_ref(entry.clone()) {
                return Ok(worker::LocalState::Ref(reference));
            }
            if entry.is_instance(&mailbox_class)? {
                return Ok(worker::LocalState::Mailbox);
            }
            Err(PyValueError::new_err(format!(
                "Expected Ref or Mailbox in local_state, got: {}",
                entry.get_type().name()?
            )))
        })
        .collect()
}
fn parse_error_reason(&self, name: &str) -> PyResult<Option<(Option<ActorId>, String)>> {
let err = self.attr(name)?;
if err.is_none() {
Expand Down Expand Up @@ -432,6 +461,19 @@ fn create_map(py: Python) -> HashMap<u64, FnType> {
p.parseWorkerMessageList("commands")?,
))
});
// Register the converter for the Python `SendResultOfActorCall` message.
// Each field of `worker::ActorCallParams` is read from the attribute of the
// same name on the Python object, except where noted below. Fields are parsed
// in the order written, so the first failing attribute determines the error.
m.insert(key("SendResultOfActorCall"), |p| {
Ok(WorkerMessage::SendResultOfActorCall(
worker::ActorCallParams {
seq: p.parseSeq("seq")?,
actor: p.parse("actor")?,
// The Python attribute is `actor_index`; the Rust field is `index`.
index: p.parse("actor_index")?,
python_message: p.parse("python_message")?,
// Entries are classified as `Ref` or `Mailbox` by `parseLocalStateList`.
local_state: p.parseLocalStateList("local_state")?,
mutates: p.parseRefList("mutates")?,
stream: p.parseStreamRef("stream")?,
},
))
});
m
}

Expand Down
1 change: 1 addition & 0 deletions monarch_extension/src/tensor_worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1372,6 +1372,7 @@ pub(crate) fn worker_message_to_py(py: Python<'_>, message: &WorkerMessage) -> P
}
WorkerMessage::SetRefUnitTestsOnly { .. } => unimplemented!(),
WorkerMessage::GetRefUnitTestsOnly { .. } => unimplemented!(),
WorkerMessage::SendResultOfActorCall { .. } => unimplemented!(),
}
}

Expand Down
58 changes: 49 additions & 9 deletions monarch_hyperactor/src/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,10 @@ impl PickledMessageClientActor {
#[pyclass(frozen, module = "monarch._rust_bindings.monarch_hyperactor.actor")]
#[derive(Default, Clone, Serialize, Deserialize, Named, PartialEq)]
pub struct PythonMessage {
pub(crate) method: String,
pub(crate) message: ByteBuf,
response_port: Option<EitherPortRef>,
rank: Option<usize>,
pub method: String,
pub message: ByteBuf,
pub response_port: Option<EitherPortRef>,
pub rank: Option<usize>,
}

impl PythonMessage {
Expand Down Expand Up @@ -289,7 +289,7 @@ impl PythonActorHandle {
PythonMessage { cast = true },
],
)]
pub(super) struct PythonActor {
pub struct PythonActor {
/// The Python object that we delegate message handling to. An instance of
/// `monarch.actor_mesh._Actor`.
pub(super) actor: PyObject,
Expand Down Expand Up @@ -399,16 +399,30 @@ impl PanicFlag {
}

#[async_trait]
impl Handler<PythonMessage> for PythonActor {
async fn handle(&mut self, cx: &Context<Self>, message: PythonMessage) -> anyhow::Result<()> {
let mailbox = PyMailbox {
impl Handler<LocalPythonMessage> for PythonActor {
async fn handle(
&mut self,
cx: &Context<Self>,
message: LocalPythonMessage,
) -> anyhow::Result<()> {
let mailbox: PyMailbox = PyMailbox {
inner: cx.mailbox_for_py().clone(),
};
// Create a channel for signaling panics in async endpoints.
// See [Panics in async endpoints].
let (sender, receiver) = oneshot::channel();

let future = Python::with_gil(|py| -> Result<_, SerializablePyErr> {
let mailbox = mailbox.into_bound_py_any(py).unwrap();
let local_state: Option<Vec<Bound<'_, PyAny>>> = message.local_state.map(|state| {
state
.into_iter()
.map(|e| match e {
LocalState::Mailbox => mailbox.clone(),
LocalState::PyObject(obj) => obj.into_bound_py_any(py).unwrap(),
})
.collect()
});
let (rank, shape) = cx.cast_info();
let awaitable = self.actor.call_method(
py,
Expand All @@ -417,10 +431,11 @@ impl Handler<PythonMessage> for PythonActor {
mailbox,
rank,
PyShape::from(shape),
message,
message.message,
PanicFlag {
sender: Some(sender),
},
local_state,
),
None,
)?;
Expand All @@ -439,6 +454,20 @@ impl Handler<PythonMessage> for PythonActor {
}
}

#[async_trait]
impl Handler<PythonMessage> for PythonActor {
async fn handle(&mut self, cx: &Context<Self>, message: PythonMessage) -> anyhow::Result<()> {
self.handle(
cx,
LocalPythonMessage {
message,
local_state: None,
},
)
.await
}
}

/// Helper struct to make a Python future passable in an actor message.
///
/// Also so that we don't have to write this massive type signature everywhere
Expand Down Expand Up @@ -538,6 +567,17 @@ impl AsyncEndpointInvocationHandler for AsyncEndpointTask {
Ok(())
}
}
/// One entry of the local state delivered to a Python actor together with a
/// message (see `LocalPythonMessage::local_state`).
#[derive(Debug)]
pub enum LocalState {
/// Placeholder for the mailbox: the `LocalPythonMessage` handler replaces it
/// with the mailbox object it builds for the call.
Mailbox,
/// A Python object that the handler passes through to the actor.
PyObject(PyObject),
}

/// A `PythonMessage` paired with optional local state.
///
/// Unlike `PythonMessage`, this type derives only `Debug` (no
/// `Serialize`/`Deserialize`).
#[derive(Debug)]
pub struct LocalPythonMessage {
/// The wrapped message that is forwarded to the Python actor.
pub message: PythonMessage,
/// Extra objects handed to the actor alongside the message. It is `None`
/// when a plain `PythonMessage` is handled.
pub local_state: Option<Vec<LocalState>>,
}

pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
hyperactor_mod.add_class::<PickledMessage>()?;
Expand Down
4 changes: 2 additions & 2 deletions monarch_hyperactor/src/mailbox.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ impl PythonUndeliverablePortHandle {
name = "PortRef",
module = "monarch._rust_bindings.monarch_hyperactor.mailbox"
)]
pub(super) struct PythonPortRef {
pub struct PythonPortRef {
pub(crate) inner: PortRef<PythonMessage>,
}

Expand Down Expand Up @@ -453,7 +453,7 @@ impl PythonOncePortHandle {
name = "OncePortRef",
module = "monarch._rust_bindings.monarch_hyperactor.mailbox"
)]
pub(crate) struct PythonOncePortRef {
pub struct PythonOncePortRef {
pub(crate) inner: Option<OncePortRef<PythonMessage>>,
}

Expand Down
1 change: 1 addition & 0 deletions monarch_messages/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ anyhow = "1.0.98"
derive_more = { version = "1.0.0", features = ["full"] }
enum-as-inner = "0.6.0"
hyperactor = { version = "0.0.0", path = "../hyperactor" }
monarch_hyperactor = { version = "0.0.0", path = "../monarch_hyperactor" }
monarch_types = { version = "0.0.0", path = "../monarch_types" }
ndslice = { version = "0.0.0", path = "../ndslice" }
pyo3 = { version = "0.24", features = ["anyhow", "multiple-pymethods"] }
Expand Down
46 changes: 42 additions & 4 deletions monarch_messages/src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use hyperactor::Named;
use hyperactor::RefClient;
use hyperactor::Unbind;
use hyperactor::reference::ActorId;
use monarch_hyperactor::actor::PythonMessage;
use monarch_types::SerializablePyErr;
use ndslice::Slice;
use pyo3::exceptions::PyValueError;
Expand Down Expand Up @@ -402,6 +403,33 @@ pub struct CallFunctionParams {
pub remote_process_groups: Vec<Ref>,
}

/// The local state that has to be restored into the python_message during
/// its unpickling.
///
/// This is the serializable description of one entry of
/// [`ActorCallParams::local_state`].
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum LocalState {
/// A referenceable object, identified by its [`Ref`].
Ref(Ref),
/// Marks an entry that should just be the mailbox object.
Mailbox,
}

/// Parameters of [`WorkerMessage::SendResultOfActorCall`].
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ActorCallParams {
/// Sequence number of this message.
// NOTE(review): presumably used for ordering/error attribution like the
// `seq` of other worker messages — confirm.
pub seq: Seq,
/// Name of the actor to call. The full actor identity is
/// (proc_id of the stream worker, `actor`, `index`).
pub actor: String,
/// Index component of the actor identity described on `actor`.
pub index: usize,
/// The PythonMessage object to send to the actor.
pub python_message: PythonMessage,
/// Referenceable objects to pass to the actor,
/// these will be put into the PythonMessage
/// during its unpickling. Because unpickling also needs to
/// restore mailboxes, we also have to keep track of which
/// members should just be the mailbox object.
pub local_state: Vec<LocalState>,

/// Tensors that will be mutated by the call.
pub mutates: Vec<Ref>,
/// The stream that the call is dispatched to.
pub stream: StreamRef,
}
/// Type of reduction for [`WorkerMessage::Reduce`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Reduction {
Expand Down Expand Up @@ -642,21 +670,30 @@ pub enum WorkerMessage {

/// First use of the borrow on the receiving stream. This is a marker for
/// synchronization.
BorrowFirstUse { borrow: u64 },
BorrowFirstUse {
borrow: u64,
},

/// Last use of the borrow on the receiving stream. This is a marker for
/// synchronization.
BorrowLastUse { borrow: u64 },
BorrowLastUse {
borrow: u64,
},

/// Drop the borrow and free the resources associated with it.
BorrowDrop { borrow: u64 },
BorrowDrop {
borrow: u64,
},

/// Delete these refs from the worker state.
DeleteRefs(Vec<Ref>),

/// A [`ControllerMessage::Status`] will be sent to the controller
/// when all streams have processed all the messages sent before this one.
RequestStatus { seq: Seq, controller: bool },
RequestStatus {
seq: Seq,
controller: bool,
},

/// Perform a reduction operation, using an efficient communication backend.
/// Only NCCL is supported for now.
Expand Down Expand Up @@ -758,6 +795,7 @@ pub enum WorkerMessage {
stream: StreamRef,
},

SendResultOfActorCall(ActorCallParams),
PipeRecv {
seq: Seq,
/// Result refs.
Expand Down
8 changes: 8 additions & 0 deletions monarch_simulator/src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,14 @@ impl WorkerMessageHandler for WorkerActor {
Ok(())
}

/// Not supported by this worker: always returns an error and ignores both
/// `cx` and `params`.
async fn send_result_of_actor_call(
&mut self,
cx: &hyperactor::Context<Self>,
params: ActorCallParams,
) -> Result<()> {
bail!("unimplemented: send_result_of_actor_call");
}

async fn command_group(
&mut self,
cx: &hyperactor::Context<Self>,
Expand Down
12 changes: 12 additions & 0 deletions monarch_tensor_worker/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ use monarch_messages::controller::ControllerActor;
use monarch_messages::controller::ControllerMessageClient;
use monarch_messages::controller::Seq;
use monarch_messages::wire_value::WireValue;
use monarch_messages::worker::ActorCallParams;
use monarch_messages::worker::CallFunctionError;
use monarch_messages::worker::CallFunctionParams;
use monarch_messages::worker::Factory;
Expand Down Expand Up @@ -860,6 +861,17 @@ impl WorkerMessageHandler for WorkerActor {
.await
}

/// Forwards an actor call to the stream named in `params.stream`.
///
/// # Errors
///
/// Fails if the stream cannot be looked up, or if the stream's
/// `send_result_of_actor_call` returns an error.
async fn send_result_of_actor_call(
    &mut self,
    cx: &hyperactor::Context<Self>,
    params: ActorCallParams,
) -> Result<()> {
    let target_stream = self.try_get_stream(params.stream)?;
    // The stream is handed a clone of this worker's own id along with the
    // call parameters.
    let worker_id = cx.self_id().clone();
    target_stream
        .send_result_of_actor_call(cx, worker_id, params)
        .await?;
    Ok(())
}
async fn split_comm(
&mut self,
cx: &hyperactor::Context<Self>,
Expand Down
Loading
Loading