Skip to content

Commit 12d813f

Browse files
committed
[15/n] Pass monarch tensors to actor endpoints, part 1
Pull Request resolved: #518 This makes it possible to send a monarch tensor to an actor endpoint that is defiend over the same proc mesh as the tensor. The send is done locally so the actor can do work with the tensors. The stream actor in the tensor engine is sent a SendResultOfActorCall message, which it will forward to the actor, binding to the message the real local tensors that were passed as arguments. The stream actor that owns the tensor waits for called the actor to finish since the actor 'owns' the tensor through the duration of the call. Known limitation: this message to the actor can go out of order w.r.t to other messages sent from the owner of the tensor engine because the real message is being sent from the stream actor. The next PR will fix this limitation by sending _both_ the tensor engine and the actor a message at the same time. The actor will get a 'wait for SendResultOfActorCall' message, at which point it will stop processing any messages except for the SentResultOfActorCall message it is suppose to be waiting for. This way the correct order is preserved from the perspective of the tensor engine stream and the actor. ghstack-source-id: 296136314 @exported-using-ghexport Differential Revision: [D78196701](https://our.internmc.facebook.com/intern/diff/D78196701/)
1 parent 771a3d0 commit 12d813f

File tree

15 files changed

+385
-59
lines changed

15 files changed

+385
-59
lines changed

monarch_extension/src/convert.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,35 @@ impl<'a> MessageParser<'a> {
176176
fn parseWorkerMessageList(&self, name: &str) -> PyResult<Vec<WorkerMessage>> {
177177
self.attr(name)?.try_iter()?.map(|x| convert(x?)).collect()
178178
}
179+
#[allow(non_snake_case)]
180+
fn parseLocalStateList(&self, name: &str) -> PyResult<Vec<worker::LocalState>> {
181+
let local_state_list = self.attr(name)?;
182+
let mut result = Vec::new();
183+
184+
// Get the PyMailbox class for type checking
185+
let mailbox_class = self
186+
.current
187+
.py()
188+
.import("monarch._rust_bindings.monarch_hyperactor.mailbox")?
189+
.getattr("Mailbox")?;
190+
191+
for item in local_state_list.try_iter()? {
192+
let item = item?;
193+
// Check if it's a Ref by trying to extract it
194+
if let Ok(ref_val) = create_ref(item.clone()) {
195+
result.push(worker::LocalState::Ref(ref_val));
196+
} else if item.is_instance(&mailbox_class)? {
197+
// It's a PyMailbox instance
198+
result.push(worker::LocalState::Mailbox);
199+
} else {
200+
return Err(PyValueError::new_err(format!(
201+
"Expected Ref or Mailbox in local_state, got: {}",
202+
item.get_type().name()?
203+
)));
204+
}
205+
}
206+
Ok(result)
207+
}
179208
fn parse_error_reason(&self, name: &str) -> PyResult<Option<(Option<ActorId>, String)>> {
180209
let err = self.attr(name)?;
181210
if err.is_none() {
@@ -432,6 +461,20 @@ fn create_map(py: Python) -> HashMap<u64, FnType> {
432461
p.parseWorkerMessageList("commands")?,
433462
))
434463
});
464+
m.insert(key("SendResultOfActorCall"), |p| {
465+
Ok(WorkerMessage::SendResultOfActorCall(
466+
worker::ActorCallParams {
467+
seq: p.parseSeq("seq")?,
468+
actor: p.parse("actor")?,
469+
index: p.parse("actor_index")?,
470+
method: p.parse("method")?,
471+
args_kwargs_tuple: p.parse("args_kwargs_tuple")?,
472+
local_state: p.parseLocalStateList("local_state")?,
473+
mutates: p.parseRefList("mutates")?,
474+
stream: p.parseStreamRef("stream")?,
475+
},
476+
))
477+
});
435478
m
436479
}
437480

monarch_extension/src/tensor_worker.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1372,6 +1372,7 @@ pub(crate) fn worker_message_to_py(py: Python<'_>, message: &WorkerMessage) -> P
13721372
}
13731373
WorkerMessage::SetRefUnitTestsOnly { .. } => unimplemented!(),
13741374
WorkerMessage::GetRefUnitTestsOnly { .. } => unimplemented!(),
1375+
WorkerMessage::SendResultOfActorCall { .. } => unimplemented!(),
13751376
}
13761377
}
13771378

monarch_hyperactor/src/actor.rs

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -174,10 +174,10 @@ impl PickledMessageClientActor {
174174
#[pyclass(frozen, module = "monarch._rust_bindings.monarch_hyperactor.actor")]
175175
#[derive(Default, Clone, Serialize, Deserialize, Named, PartialEq)]
176176
pub struct PythonMessage {
177-
pub(crate) method: String,
178-
pub(crate) message: ByteBuf,
179-
response_port: Option<EitherPortRef>,
180-
rank: Option<usize>,
177+
pub method: String,
178+
pub message: ByteBuf,
179+
pub response_port: Option<EitherPortRef>,
180+
pub rank: Option<usize>,
181181
}
182182

183183
impl PythonMessage {
@@ -288,7 +288,7 @@ impl PythonActorHandle {
288288
PythonMessage { cast = true },
289289
],
290290
)]
291-
pub(super) struct PythonActor {
291+
pub struct PythonActor {
292292
/// The Python object that we delegate message handling to. An instance of
293293
/// `monarch.actor_mesh._Actor`.
294294
pub(super) actor: PyObject,
@@ -398,16 +398,30 @@ impl PanicFlag {
398398
}
399399

400400
#[async_trait]
401-
impl Handler<PythonMessage> for PythonActor {
402-
async fn handle(&mut self, cx: &Context<Self>, message: PythonMessage) -> anyhow::Result<()> {
403-
let mailbox = PyMailbox {
401+
impl Handler<LocalPythonMessage> for PythonActor {
402+
async fn handle(
403+
&mut self,
404+
cx: &Context<Self>,
405+
message: LocalPythonMessage,
406+
) -> anyhow::Result<()> {
407+
let mailbox: PyMailbox = PyMailbox {
404408
inner: cx.mailbox_for_py().clone(),
405409
};
406410
// Create a channel for signaling panics in async endpoints.
407411
// See [Panics in async endpoints].
408412
let (sender, receiver) = oneshot::channel();
409413

410414
let future = Python::with_gil(|py| -> Result<_, SerializablePyErr> {
415+
let mailbox = mailbox.into_bound_py_any(py).unwrap();
416+
let local_state: Option<Vec<Bound<'_, PyAny>>> = message.local_state.map(|state| {
417+
state
418+
.into_iter()
419+
.map(|e| match e {
420+
LocalState::Mailbox => mailbox.clone(),
421+
LocalState::PyObject(obj) => obj.into_bound_py_any(py).unwrap(),
422+
})
423+
.collect()
424+
});
411425
let (rank, shape) = cx.cast_info();
412426
let awaitable = self.actor.call_method(
413427
py,
@@ -416,10 +430,11 @@ impl Handler<PythonMessage> for PythonActor {
416430
mailbox,
417431
rank,
418432
PyShape::from(shape),
419-
message,
433+
message.message,
420434
PanicFlag {
421435
sender: Some(sender),
422436
},
437+
local_state,
423438
),
424439
None,
425440
)?;
@@ -438,6 +453,20 @@ impl Handler<PythonMessage> for PythonActor {
438453
}
439454
}
440455

456+
#[async_trait]
457+
impl Handler<PythonMessage> for PythonActor {
458+
async fn handle(&mut self, cx: &Context<Self>, message: PythonMessage) -> anyhow::Result<()> {
459+
self.handle(
460+
cx,
461+
LocalPythonMessage {
462+
message,
463+
local_state: None,
464+
},
465+
)
466+
.await
467+
}
468+
}
469+
441470
/// Helper struct to make a Python future passable in an actor message.
442471
///
443472
/// Also so that we don't have to write this massive type signature everywhere
@@ -537,6 +566,17 @@ impl AsyncEndpointInvocationHandler for AsyncEndpointTask {
537566
Ok(())
538567
}
539568
}
569+
#[derive(Debug)]
570+
pub enum LocalState {
571+
Mailbox,
572+
PyObject(PyObject),
573+
}
574+
575+
#[derive(Debug)]
576+
pub struct LocalPythonMessage {
577+
pub message: PythonMessage,
578+
pub local_state: Option<Vec<LocalState>>,
579+
}
540580

541581
pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
542582
hyperactor_mod.add_class::<PickledMessage>()?;

monarch_messages/src/worker.rs

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,36 @@ pub struct CallFunctionParams {
402402
pub remote_process_groups: Vec<Ref>,
403403
}
404404

405+
/// The local state that has to be restored into the python_message during
406+
/// its unpickling.
407+
#[derive(Serialize, Deserialize, Debug, Clone)]
408+
pub enum LocalState {
409+
Ref(Ref),
410+
Mailbox,
411+
}
412+
413+
#[derive(Serialize, Deserialize, Debug, Clone)]
414+
pub struct ActorCallParams {
415+
pub seq: Seq,
416+
/// The actor to call is (proc_id of stream worker, 'actor', 'index')
417+
pub actor: String,
418+
pub index: usize,
419+
420+
// method name to call
421+
pub method: String,
422+
// pickled arguments, that will need to be patched with local state
423+
pub args_kwargs_tuple: Vec<u8>,
424+
/// Referenceable objects to pass to the actor,
425+
/// these will be put into the PythonMessage
426+
/// during its unpickling. Because unpickling also needs to
427+
/// restore mailboxes, we also have to keep track of which
428+
/// members should just be the mailbox object.
429+
pub local_state: Vec<LocalState>,
430+
431+
/// Tensors that will be mutated by the call.
432+
pub mutates: Vec<Ref>,
433+
pub stream: StreamRef,
434+
}
405435
/// Type of reduction for [`WorkerMessage::Reduce`].
406436
#[derive(Debug, Clone, Serialize, Deserialize)]
407437
pub enum Reduction {
@@ -642,21 +672,30 @@ pub enum WorkerMessage {
642672

643673
/// First use of the borrow on the receiving stream. This is a marker for
644674
/// synchronization.
645-
BorrowFirstUse { borrow: u64 },
675+
BorrowFirstUse {
676+
borrow: u64,
677+
},
646678

647679
/// Last use of the borrow on the receiving stream. This is a marker for
648680
/// synchronization.
649-
BorrowLastUse { borrow: u64 },
681+
BorrowLastUse {
682+
borrow: u64,
683+
},
650684

651685
/// Drop the borrow and free the resources associated with it.
652-
BorrowDrop { borrow: u64 },
686+
BorrowDrop {
687+
borrow: u64,
688+
},
653689

654690
/// Delete these refs from the worker state.
655691
DeleteRefs(Vec<Ref>),
656692

657693
/// A [`ControllerMessage::Status`] will be send to the controller
658694
/// when all streams have processed all the message sent before this one.
659-
RequestStatus { seq: Seq, controller: bool },
695+
RequestStatus {
696+
seq: Seq,
697+
controller: bool,
698+
},
660699

661700
/// Perform a reduction operation, using an efficient communication backend.
662701
/// Only NCCL is supported for now.
@@ -758,6 +797,7 @@ pub enum WorkerMessage {
758797
stream: StreamRef,
759798
},
760799

800+
SendResultOfActorCall(ActorCallParams),
761801
PipeRecv {
762802
seq: Seq,
763803
/// Result refs.

monarch_simulator/src/worker.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,14 @@ impl WorkerMessageHandler for WorkerActor {
311311
Ok(())
312312
}
313313

314+
async fn send_result_of_actor_call(
315+
&mut self,
316+
cx: &hyperactor::Context<Self>,
317+
params: ActorCallParams,
318+
) -> Result<()> {
319+
bail!("unimplemented: send_result_of_actor_call");
320+
}
321+
314322
async fn command_group(
315323
&mut self,
316324
cx: &hyperactor::Context<Self>,

monarch_tensor_worker/src/lib.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ use monarch_messages::controller::ControllerActor;
7171
use monarch_messages::controller::ControllerMessageClient;
7272
use monarch_messages::controller::Seq;
7373
use monarch_messages::wire_value::WireValue;
74+
use monarch_messages::worker::ActorCallParams;
7475
use monarch_messages::worker::CallFunctionError;
7576
use monarch_messages::worker::CallFunctionParams;
7677
use monarch_messages::worker::Factory;
@@ -859,6 +860,17 @@ impl WorkerMessageHandler for WorkerActor {
859860
.await
860861
}
861862

863+
async fn send_result_of_actor_call(
864+
&mut self,
865+
cx: &hyperactor::Context<Self>,
866+
params: ActorCallParams,
867+
) -> Result<()> {
868+
let stream = self.try_get_stream(params.stream)?;
869+
stream
870+
.send_result_of_actor_call(cx, cx.self_id().clone(), params)
871+
.await?;
872+
Ok(())
873+
}
862874
async fn split_comm(
863875
&mut self,
864876
cx: &hyperactor::Context<Self>,

monarch_tensor_worker/src/stream.rs

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,17 @@ use hyperactor::mailbox::Mailbox;
3636
use hyperactor::mailbox::OncePortHandle;
3737
use hyperactor::mailbox::PortReceiver;
3838
use hyperactor::proc::Proc;
39+
use monarch_hyperactor::actor::LocalPythonMessage;
40+
use monarch_hyperactor::actor::PythonActor;
3941
use monarch_hyperactor::actor::PythonMessage;
42+
use monarch_hyperactor::mailbox::EitherPortRef;
4043
use monarch_messages::controller::ControllerMessageClient;
4144
use monarch_messages::controller::Seq;
4245
use monarch_messages::controller::WorkerError;
46+
use monarch_messages::worker::ActorCallParams;
4347
use monarch_messages::worker::CallFunctionError;
4448
use monarch_messages::worker::CallFunctionParams;
49+
use monarch_messages::worker::LocalState;
4550
use monarch_messages::worker::StreamRef;
4651
use monarch_types::PyTree;
4752
use monarch_types::SerializablePyErr;
@@ -233,6 +238,8 @@ pub enum StreamMessage {
233238
),
234239

235240
GetTensorRefUnitTestsOnly(Ref, #[reply] OncePortHandle<Option<TensorCellResult>>),
241+
242+
SendResultOfActorCall(ActorId, ActorCallParams),
236243
}
237244

238245
impl StreamMessage {
@@ -1666,6 +1673,75 @@ impl StreamMessageHandler for StreamActor {
16661673
Ok(())
16671674
}
16681675

1676+
async fn send_result_of_actor_call(
1677+
&mut self,
1678+
cx: &Context<Self>,
1679+
worker_actor_id: ActorId,
1680+
params: ActorCallParams,
1681+
) -> anyhow::Result<()> {
1682+
// TODO: handle mutates
1683+
let actor_id = ActorId(cx.proc().proc_id().clone(), params.actor, params.index);
1684+
let actor_ref: ActorRef<PythonActor> = ActorRef::attest(actor_id);
1685+
let actor_handle = actor_ref.downcast_handle(cx).unwrap();
1686+
let (send, recv) = cx.open_once_port();
1687+
let send = send.bind();
1688+
let send = EitherPortRef::Once(send.into());
1689+
let message = PythonMessage {
1690+
method: params.method,
1691+
message: params.args_kwargs_tuple.into(),
1692+
response_port: Some(send),
1693+
rank: None,
1694+
};
1695+
let local_state: Result<Vec<monarch_hyperactor::actor::LocalState>> =
1696+
Python::with_gil(|py| {
1697+
params
1698+
.local_state
1699+
.into_iter()
1700+
.map(|elem| match elem {
1701+
LocalState::Mailbox => Ok(monarch_hyperactor::actor::LocalState::Mailbox),
1702+
// SAFETY: python is gonna make unsafe copies of this stuff anyway
1703+
LocalState::Ref(r) => unsafe {
1704+
let x = self.ref_to_rvalue(&r)?.try_to_object_unsafe(py)?.into();
1705+
Ok(monarch_hyperactor::actor::LocalState::PyObject(x))
1706+
},
1707+
})
1708+
.collect()
1709+
});
1710+
// including making the PyMailbox
1711+
let message = LocalPythonMessage {
1712+
message,
1713+
local_state: Some(local_state?),
1714+
};
1715+
actor_handle.send(message)?;
1716+
let result = recv.recv().await?.with_rank(worker_actor_id.rank());
1717+
if result.method == "exception" {
1718+
// If result has "exception" as its kind, then
1719+
// we need to unpickle and turn it into a WorkerError
1720+
// and call remote_function_failed otherwise the
1721+
// controller assumes the object is correct and doesn't handle
1722+
// dependency tracking correctly.
1723+
let err = Python::with_gil(|py| -> Result<WorkerError, SerializablePyErr> {
1724+
let err = py
1725+
.import("pickle")
1726+
.unwrap()
1727+
.call_method1("loads", (result.message.into_vec(),))?;
1728+
Ok(WorkerError {
1729+
worker_actor_id,
1730+
backtrace: err.to_string(),
1731+
})
1732+
})?;
1733+
self.controller_actor
1734+
.remote_function_failed(cx, params.seq, err)
1735+
.await?;
1736+
} else {
1737+
let result = Serialized::serialize(&result).unwrap();
1738+
self.controller_actor
1739+
.fetch_result(cx, params.seq, Ok(result))
1740+
.await?;
1741+
}
1742+
Ok(())
1743+
}
1744+
16691745
async fn set_value(
16701746
&mut self,
16711747
cx: &Context<Self>,

0 commit comments

Comments
 (0)