Skip to content

Commit 280cfcd

Browse files
committed
[15/n] Pass monarch tensors to actor endpoints, part 1
Pull Request resolved: #518 This makes it possible to send a monarch tensor to an actor endpoint that is defined over the same proc mesh as the tensor. The send is done locally so the actor can do work with the tensors. The stream actor in the tensor engine is sent a SendResultOfActorCall message, which it will forward to the actor, binding to the message the real local tensors that were passed as arguments. The stream actor that owns the tensor waits for the called actor to finish, since the actor 'owns' the tensor for the duration of the call. Known limitation: this message to the actor can go out of order w.r.t. other messages sent from the owner of the tensor engine, because the real message is being sent from the stream actor. The next PR will fix this limitation by sending _both_ the tensor engine and the actor a message at the same time. The actor will get a 'wait for SendResultOfActorCall' message, at which point it will stop processing any messages except for the SendResultOfActorCall message it is supposed to be waiting for. This way the correct order is preserved from the perspective of the tensor engine stream and the actor. ghstack-source-id: 295768514 @exported-using-ghexport Differential Revision: [D78196701](https://our.internmc.facebook.com/intern/diff/D78196701/)
1 parent 6dd9be7 commit 280cfcd

File tree

15 files changed

+377
-55
lines changed

15 files changed

+377
-55
lines changed

monarch_extension/src/convert.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,35 @@ impl<'a> MessageParser<'a> {
176176
fn parseWorkerMessageList(&self, name: &str) -> PyResult<Vec<WorkerMessage>> {
177177
self.attr(name)?.try_iter()?.map(|x| convert(x?)).collect()
178178
}
179+
#[allow(non_snake_case)]
180+
fn parseLocalStateList(&self, name: &str) -> PyResult<Vec<worker::LocalState>> {
181+
let local_state_list = self.attr(name)?;
182+
let mut result = Vec::new();
183+
184+
// Get the PyMailbox class for type checking
185+
let mailbox_class = self
186+
.current
187+
.py()
188+
.import("monarch._rust_bindings.monarch_hyperactor.mailbox")?
189+
.getattr("Mailbox")?;
190+
191+
for item in local_state_list.try_iter()? {
192+
let item = item?;
193+
// Check if it's a Ref by trying to extract it
194+
if let Ok(ref_val) = create_ref(item.clone()) {
195+
result.push(worker::LocalState::Ref(ref_val));
196+
} else if item.is_instance(&mailbox_class)? {
197+
// It's a PyMailbox instance
198+
result.push(worker::LocalState::Mailbox);
199+
} else {
200+
return Err(PyValueError::new_err(format!(
201+
"Expected Ref or Mailbox in local_state, got: {}",
202+
item.get_type().name()?
203+
)));
204+
}
205+
}
206+
Ok(result)
207+
}
179208
fn parse_error_reason(&self, name: &str) -> PyResult<Option<(Option<ActorId>, String)>> {
180209
let err = self.attr(name)?;
181210
if err.is_none() {
@@ -432,6 +461,19 @@ fn create_map(py: Python) -> HashMap<u64, FnType> {
432461
p.parseWorkerMessageList("commands")?,
433462
))
434463
});
464+
m.insert(key("SendResultOfActorCall"), |p| {
465+
Ok(WorkerMessage::SendResultOfActorCall(
466+
worker::ActorCallParams {
467+
seq: p.parseSeq("seq")?,
468+
actor: p.parse("actor")?,
469+
index: p.parse("actor_index")?,
470+
python_message: p.parse("python_message")?,
471+
local_state: p.parseLocalStateList("local_state")?,
472+
mutates: p.parseRefList("mutates")?,
473+
stream: p.parseStreamRef("stream")?,
474+
},
475+
))
476+
});
435477
m
436478
}
437479

monarch_extension/src/tensor_worker.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1372,6 +1372,7 @@ pub(crate) fn worker_message_to_py(py: Python<'_>, message: &WorkerMessage) -> P
13721372
}
13731373
WorkerMessage::SetRefUnitTestsOnly { .. } => unimplemented!(),
13741374
WorkerMessage::GetRefUnitTestsOnly { .. } => unimplemented!(),
1375+
WorkerMessage::SendResultOfActorCall { .. } => unimplemented!(),
13751376
}
13761377
}
13771378

monarch_hyperactor/src/actor.rs

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -175,10 +175,10 @@ impl PickledMessageClientActor {
175175
#[pyclass(frozen, module = "monarch._rust_bindings.monarch_hyperactor.actor")]
176176
#[derive(Default, Clone, Serialize, Deserialize, Named, PartialEq)]
177177
pub struct PythonMessage {
178-
pub(crate) method: String,
179-
pub(crate) message: ByteBuf,
180-
response_port: Option<EitherPortRef>,
181-
rank: Option<usize>,
178+
pub method: String,
179+
pub message: ByteBuf,
180+
pub response_port: Option<EitherPortRef>,
181+
pub rank: Option<usize>,
182182
}
183183

184184
impl PythonMessage {
@@ -289,7 +289,7 @@ impl PythonActorHandle {
289289
PythonMessage { cast = true },
290290
],
291291
)]
292-
pub(super) struct PythonActor {
292+
pub struct PythonActor {
293293
/// The Python object that we delegate message handling to. An instance of
294294
/// `monarch.actor_mesh._Actor`.
295295
pub(super) actor: PyObject,
@@ -399,16 +399,30 @@ impl PanicFlag {
399399
}
400400

401401
#[async_trait]
402-
impl Handler<PythonMessage> for PythonActor {
403-
async fn handle(&mut self, cx: &Context<Self>, message: PythonMessage) -> anyhow::Result<()> {
404-
let mailbox = PyMailbox {
402+
impl Handler<LocalPythonMessage> for PythonActor {
403+
async fn handle(
404+
&mut self,
405+
cx: &Context<Self>,
406+
message: LocalPythonMessage,
407+
) -> anyhow::Result<()> {
408+
let mailbox: PyMailbox = PyMailbox {
405409
inner: cx.mailbox_for_py().clone(),
406410
};
407411
// Create a channel for signaling panics in async endpoints.
408412
// See [Panics in async endpoints].
409413
let (sender, receiver) = oneshot::channel();
410414

411415
let future = Python::with_gil(|py| -> Result<_, SerializablePyErr> {
416+
let mailbox = mailbox.into_bound_py_any(py).unwrap();
417+
let local_state: Option<Vec<Bound<'_, PyAny>>> = message.local_state.map(|state| {
418+
state
419+
.into_iter()
420+
.map(|e| match e {
421+
LocalState::Mailbox => mailbox.clone(),
422+
LocalState::PyObject(obj) => obj.into_bound_py_any(py).unwrap(),
423+
})
424+
.collect()
425+
});
412426
let (rank, shape) = cx.cast_info();
413427
let awaitable = self.actor.call_method(
414428
py,
@@ -417,10 +431,11 @@ impl Handler<PythonMessage> for PythonActor {
417431
mailbox,
418432
rank,
419433
PyShape::from(shape),
420-
message,
434+
message.message,
421435
PanicFlag {
422436
sender: Some(sender),
423437
},
438+
local_state,
424439
),
425440
None,
426441
)?;
@@ -439,6 +454,20 @@ impl Handler<PythonMessage> for PythonActor {
439454
}
440455
}
441456

457+
#[async_trait]
458+
impl Handler<PythonMessage> for PythonActor {
459+
async fn handle(&mut self, cx: &Context<Self>, message: PythonMessage) -> anyhow::Result<()> {
460+
self.handle(
461+
cx,
462+
LocalPythonMessage {
463+
message,
464+
local_state: None,
465+
},
466+
)
467+
.await
468+
}
469+
}
470+
442471
/// Helper struct to make a Python future passable in an actor message.
443472
///
444473
/// Also so that we don't have to write this massive type signature everywhere
@@ -538,6 +567,17 @@ impl AsyncEndpointInvocationHandler for AsyncEndpointTask {
538567
Ok(())
539568
}
540569
}
570+
#[derive(Debug)]
571+
pub enum LocalState {
572+
Mailbox,
573+
PyObject(PyObject),
574+
}
575+
576+
#[derive(Debug)]
577+
pub struct LocalPythonMessage {
578+
pub message: PythonMessage,
579+
pub local_state: Option<Vec<LocalState>>,
580+
}
541581

542582
pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
543583
hyperactor_mod.add_class::<PickledMessage>()?;

monarch_hyperactor/src/mailbox.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ impl PythonUndeliverablePortHandle {
313313
name = "PortRef",
314314
module = "monarch._rust_bindings.monarch_hyperactor.mailbox"
315315
)]
316-
pub(super) struct PythonPortRef {
316+
pub struct PythonPortRef {
317317
pub(crate) inner: PortRef<PythonMessage>,
318318
}
319319

@@ -453,7 +453,7 @@ impl PythonOncePortHandle {
453453
name = "OncePortRef",
454454
module = "monarch._rust_bindings.monarch_hyperactor.mailbox"
455455
)]
456-
pub(crate) struct PythonOncePortRef {
456+
pub struct PythonOncePortRef {
457457
pub(crate) inner: Option<OncePortRef<PythonMessage>>,
458458
}
459459

monarch_messages/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ anyhow = "1.0.98"
1212
derive_more = { version = "1.0.0", features = ["full"] }
1313
enum-as-inner = "0.6.0"
1414
hyperactor = { version = "0.0.0", path = "../hyperactor" }
15+
monarch_hyperactor = { version = "0.0.0", path = "../monarch_hyperactor" }
1516
monarch_types = { version = "0.0.0", path = "../monarch_types" }
1617
ndslice = { version = "0.0.0", path = "../ndslice" }
1718
pyo3 = { version = "0.24", features = ["anyhow", "multiple-pymethods"] }

monarch_messages/src/worker.rs

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ use hyperactor::Named;
2727
use hyperactor::RefClient;
2828
use hyperactor::Unbind;
2929
use hyperactor::reference::ActorId;
30+
use monarch_hyperactor::actor::PythonMessage;
3031
use monarch_types::SerializablePyErr;
3132
use ndslice::Slice;
3233
use pyo3::exceptions::PyValueError;
@@ -402,6 +403,33 @@ pub struct CallFunctionParams {
402403
pub remote_process_groups: Vec<Ref>,
403404
}
404405

406+
/// The local state that has to be restored into the python_message during
407+
/// its unpickling.
408+
#[derive(Serialize, Deserialize, Debug, Clone)]
409+
pub enum LocalState {
410+
Ref(Ref),
411+
Mailbox,
412+
}
413+
414+
#[derive(Serialize, Deserialize, Debug, Clone)]
415+
pub struct ActorCallParams {
416+
pub seq: Seq,
417+
/// The actor to call is (proc_id of stream worker, 'actor', 'index')
418+
pub actor: String,
419+
pub index: usize,
420+
/// The PythonMessage object to send to the actor.
421+
pub python_message: PythonMessage,
422+
/// Referenceable objects to pass to the actor,
423+
/// these will be put into the PythonMessage
424+
/// during its unpickling. Because unpickling also needs to
425+
/// restore mailboxes, we also have to keep track of which
426+
/// members should just be the mailbox object.
427+
pub local_state: Vec<LocalState>,
428+
429+
/// Tensors that will be mutated by the call.
430+
pub mutates: Vec<Ref>,
431+
pub stream: StreamRef,
432+
}
405433
/// Type of reduction for [`WorkerMessage::Reduce`].
406434
#[derive(Debug, Clone, Serialize, Deserialize)]
407435
pub enum Reduction {
@@ -642,21 +670,30 @@ pub enum WorkerMessage {
642670

643671
/// First use of the borrow on the receiving stream. This is a marker for
644672
/// synchronization.
645-
BorrowFirstUse { borrow: u64 },
673+
BorrowFirstUse {
674+
borrow: u64,
675+
},
646676

647677
/// Last use of the borrow on the receiving stream. This is a marker for
648678
/// synchronization.
649-
BorrowLastUse { borrow: u64 },
679+
BorrowLastUse {
680+
borrow: u64,
681+
},
650682

651683
/// Drop the borrow and free the resources associated with it.
652-
BorrowDrop { borrow: u64 },
684+
BorrowDrop {
685+
borrow: u64,
686+
},
653687

654688
/// Delete these refs from the worker state.
655689
DeleteRefs(Vec<Ref>),
656690

657691
/// A [`ControllerMessage::Status`] will be send to the controller
658692
/// when all streams have processed all the message sent before this one.
659-
RequestStatus { seq: Seq, controller: bool },
693+
RequestStatus {
694+
seq: Seq,
695+
controller: bool,
696+
},
660697

661698
/// Perform a reduction operation, using an efficient communication backend.
662699
/// Only NCCL is supported for now.
@@ -758,6 +795,7 @@ pub enum WorkerMessage {
758795
stream: StreamRef,
759796
},
760797

798+
SendResultOfActorCall(ActorCallParams),
761799
PipeRecv {
762800
seq: Seq,
763801
/// Result refs.

monarch_simulator/src/worker.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,14 @@ impl WorkerMessageHandler for WorkerActor {
312312
Ok(())
313313
}
314314

315+
async fn send_result_of_actor_call(
316+
&mut self,
317+
cx: &hyperactor::Context<Self>,
318+
params: ActorCallParams,
319+
) -> Result<()> {
320+
bail!("unimplemented: send_result_of_actor_call");
321+
}
322+
315323
async fn command_group(
316324
&mut self,
317325
cx: &hyperactor::Context<Self>,

monarch_tensor_worker/src/lib.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ use monarch_messages::controller::ControllerActor;
7272
use monarch_messages::controller::ControllerMessageClient;
7373
use monarch_messages::controller::Seq;
7474
use monarch_messages::wire_value::WireValue;
75+
use monarch_messages::worker::ActorCallParams;
7576
use monarch_messages::worker::CallFunctionError;
7677
use monarch_messages::worker::CallFunctionParams;
7778
use monarch_messages::worker::Factory;
@@ -860,6 +861,17 @@ impl WorkerMessageHandler for WorkerActor {
860861
.await
861862
}
862863

864+
async fn send_result_of_actor_call(
865+
&mut self,
866+
cx: &hyperactor::Context<Self>,
867+
params: ActorCallParams,
868+
) -> Result<()> {
869+
let stream = self.try_get_stream(params.stream)?;
870+
stream
871+
.send_result_of_actor_call(cx, cx.self_id().clone(), params)
872+
.await?;
873+
Ok(())
874+
}
863875
async fn split_comm(
864876
&mut self,
865877
cx: &hyperactor::Context<Self>,

0 commit comments

Comments
 (0)