Skip to content

Commit 6056d13

Browse files
committed
[17/n] Pass monarch tensors to actor endpoints, part 2, actor messages sent in order
When you send a monarch.Tensor to an actor, what actually happens is that you instruct the stream managing the tensor to send the tensor to the actor. In part 1, that stream directly told the actor what method to call, and where to send the result. The issue with that is ordering: from the perspective of the original caller it is as if we are sending a message to the actor directly. If the stream sends the message it is possible that it arrives out of order with respect to a message that goes directly to the actor (e.g. a message that contains no monarch.Tensors). This resolves ordering issue: now the client sends a message to the actor 'CallMethodIndirect' that provides the pickled args/kwargs and method to invoke. The local python values of the torch.Tensors still need to come from Stream actor, so CallMethodIndirect looks them up by messaging the (new) LocalStateBrokerActor which can transfer PyObjects between local actors without serialization. Ordering is guarenteed for both tensors and actors because a message is sent to both the receiving actor _and_ the stream actor at the same time. Differential Revision: [D78314012](https://our.internmc.facebook.com/intern/diff/D78314012/) ghstack-source-id: 296356330 Pull Request resolved: #539
1 parent 21a70d3 commit 6056d13

File tree

17 files changed

+320
-168
lines changed

17 files changed

+320
-168
lines changed

monarch_extension/src/convert.rs

Lines changed: 2 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -176,35 +176,6 @@ impl<'a> MessageParser<'a> {
176176
fn parseWorkerMessageList(&self, name: &str) -> PyResult<Vec<WorkerMessage>> {
177177
self.attr(name)?.try_iter()?.map(|x| convert(x?)).collect()
178178
}
179-
#[allow(non_snake_case)]
180-
fn parseLocalStateList(&self, name: &str) -> PyResult<Vec<worker::LocalState>> {
181-
let local_state_list = self.attr(name)?;
182-
let mut result = Vec::new();
183-
184-
// Get the PyMailbox class for type checking
185-
let mailbox_class = self
186-
.current
187-
.py()
188-
.import("monarch._rust_bindings.monarch_hyperactor.mailbox")?
189-
.getattr("Mailbox")?;
190-
191-
for item in local_state_list.try_iter()? {
192-
let item = item?;
193-
// Check if it's a Ref by trying to extract it
194-
if let Ok(ref_val) = create_ref(item.clone()) {
195-
result.push(worker::LocalState::Ref(ref_val));
196-
} else if item.is_instance(&mailbox_class)? {
197-
// It's a PyMailbox instance
198-
result.push(worker::LocalState::Mailbox);
199-
} else {
200-
return Err(PyValueError::new_err(format!(
201-
"Expected Ref or Mailbox in local_state, got: {}",
202-
item.get_type().name()?
203-
)));
204-
}
205-
}
206-
Ok(result)
207-
}
208179
fn parse_error_reason(&self, name: &str) -> PyResult<Option<(Option<ActorId>, String)>> {
209180
let err = self.attr(name)?;
210181
if err.is_none() {
@@ -465,11 +436,8 @@ fn create_map(py: Python) -> HashMap<u64, FnType> {
465436
Ok(WorkerMessage::SendResultOfActorCall(
466437
worker::ActorCallParams {
467438
seq: p.parseSeq("seq")?,
468-
actor: p.parse("actor")?,
469-
index: p.parse("actor_index")?,
470-
method: p.parse("method")?,
471-
args_kwargs_tuple: p.parse("args_kwargs_tuple")?,
472-
local_state: p.parseLocalStateList("local_state")?,
439+
broker_id: p.parse("broker_id")?,
440+
local_state: p.parseRefList("local_state")?,
473441
mutates: p.parseRefList("mutates")?,
474442
stream: p.parseStreamRef("stream")?,
475443
},

monarch_extension/src/mesh_controller.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ use hyperactor_mesh::shared_cell::SharedCell;
3838
use hyperactor_mesh::shared_cell::SharedCellRef;
3939
use monarch_hyperactor::actor::PythonMessage;
4040
use monarch_hyperactor::actor::PythonMessageKind;
41+
use monarch_hyperactor::local_state_broker::LocalStateBrokerActor;
4142
use monarch_hyperactor::mailbox::PyPortId;
4243
use monarch_hyperactor::ndslice::PySlice;
4344
use monarch_hyperactor::proc_mesh::PyProcMesh;
@@ -78,6 +79,7 @@ pub(crate) fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult
7879
struct _Controller {
7980
controller_handle: Arc<Mutex<ActorHandle<MeshControllerActor>>>,
8081
all_ranks: Slice,
82+
broker_id: (String, usize),
8183
}
8284

8385
static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
@@ -123,9 +125,17 @@ impl _Controller {
123125
Ok(Self {
124126
controller_handle,
125127
all_ranks,
128+
// note that 0 is the _pid_ of the broker, which will be 0 for
129+
// top-level spawned actors.
130+
broker_id: (format!("tensor_engine_brokers_{}", id), 0),
126131
})
127132
}
128133

134+
#[getter]
135+
fn broker_id(&self) -> (String, usize) {
136+
self.broker_id.clone()
137+
}
138+
129139
#[pyo3(signature = (seq, defs, uses, response_port, tracebacks))]
130140
fn node<'py>(
131141
&mut self,
@@ -626,6 +636,7 @@ enum ClientToControllerMessage {
626636
struct MeshControllerActor {
627637
proc_mesh: SharedCell<TrackedProcMesh>,
628638
workers: Option<SharedCell<RootActorMesh<'static, WorkerActor>>>,
639+
brokers: Option<SharedCell<RootActorMesh<'static, LocalStateBrokerActor>>>,
629640
history: History,
630641
id: usize,
631642
debugger_active: Option<ActorRef<DebuggerActor>>,
@@ -730,6 +741,7 @@ impl Actor for MeshControllerActor {
730741
Ok(MeshControllerActor {
731742
proc_mesh: proc_mesh.clone(),
732743
workers: None,
744+
brokers: None,
733745
history: History::new(world_size),
734746
id,
735747
debugger_active: None,
@@ -758,6 +770,10 @@ impl Actor for MeshControllerActor {
758770
.cast(selection::dsl::true_(), AssignRankMessage::AssignRank())?;
759771

760772
self.workers = Some(workers);
773+
let brokers = proc_mesh
774+
.spawn(&format!("tensor_engine_brokers_{}", self.id), &())
775+
.await?;
776+
self.brokers = Some(brokers);
761777
Ok(())
762778
}
763779
}

monarch_hyperactor/src/actor.rs

Lines changed: 80 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use hyperactor::Context;
2121
use hyperactor::HandleClient;
2222
use hyperactor::Handler;
2323
use hyperactor::Named;
24+
use hyperactor::cap::CanSend;
2425
use hyperactor::forward;
2526
use hyperactor::message::Bind;
2627
use hyperactor::message::Bindings;
@@ -43,6 +44,8 @@ use tokio::sync::Mutex;
4344
use tokio::sync::oneshot;
4445

4546
use crate::config::SHARED_ASYNCIO_RUNTIME;
47+
use crate::local_state_broker::BrokerId;
48+
use crate::local_state_broker::LocalStateBrokerMessage;
4649
use crate::mailbox::EitherPortRef;
4750
use crate::mailbox::PyMailbox;
4851
use crate::proc::InstanceWrapper;
@@ -171,6 +174,13 @@ impl PickledMessageClientActor {
171174
}
172175
}
173176

177+
#[pyclass(module = "monarch._rust_bindings.monarch_hyperactor.actor")]
178+
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
179+
pub enum UnflattenArg {
180+
Mailbox,
181+
PyObject,
182+
}
183+
174184
#[pyclass(module = "monarch._rust_bindings.monarch_hyperactor.actor")]
175185
#[derive(Clone, Debug, Serialize, Deserialize, Named, PartialEq)]
176186
pub enum PythonMessageKind {
@@ -185,6 +195,14 @@ pub enum PythonMessageKind {
185195
rank: Option<usize>,
186196
},
187197
Uninit {},
198+
CallMethodIndirect {
199+
name: String,
200+
local_state_broker: (String, usize),
201+
id: usize,
202+
// specify whether the argument to unflatten the local mailbox,
203+
// or the next argument of the local state.
204+
unflatten_args: Vec<UnflattenArg>,
205+
},
188206
}
189207

190208
impl Default for PythonMessageKind {
@@ -193,6 +211,11 @@ impl Default for PythonMessageKind {
193211
}
194212
}
195213

214+
fn mailbox<'py, T: Actor>(py: Python<'py>, cx: &Context<'_, T>) -> Bound<'py, PyAny> {
215+
let mailbox: PyMailbox = cx.mailbox_for_py().clone().into();
216+
mailbox.into_bound_py_any(py).unwrap()
217+
}
218+
196219
#[pyclass(frozen, module = "monarch._rust_bindings.monarch_hyperactor.actor")]
197220
#[derive(Clone, Serialize, Deserialize, Named, PartialEq, Default)]
198221
pub struct PythonMessage {
@@ -219,6 +242,56 @@ impl PythonMessage {
219242
_ => panic!("PythonMessage is not a response but {:?}", self),
220243
}
221244
}
245+
246+
pub async fn resolve_indirect_call<T: Actor>(
247+
mut self,
248+
cx: &Context<'_, T>,
249+
) -> anyhow::Result<(Self, PyObject)> {
250+
let local_state: PyObject;
251+
match self.kind {
252+
PythonMessageKind::CallMethodIndirect {
253+
name,
254+
local_state_broker,
255+
id,
256+
unflatten_args,
257+
} => {
258+
let broker = BrokerId::new(local_state_broker).resolve(cx).unwrap();
259+
let (send, recv) = cx.open_once_port();
260+
broker.send(LocalStateBrokerMessage::Get(id, send))?;
261+
let state = recv.recv().await?;
262+
let mut state_it = state.state.into_iter();
263+
local_state = Python::with_gil(|py| {
264+
let mailbox = mailbox(py, cx);
265+
PyList::new(
266+
py,
267+
unflatten_args.into_iter().map(|x| -> Bound<'_, PyAny> {
268+
match x {
269+
UnflattenArg::Mailbox => mailbox.clone(),
270+
UnflattenArg::PyObject => state_it.next().unwrap().into_bound(py),
271+
}
272+
}),
273+
)
274+
.unwrap()
275+
.into()
276+
});
277+
self.kind = PythonMessageKind::CallMethod {
278+
name,
279+
response_port: Some(state.response_port),
280+
}
281+
}
282+
_ => {
283+
local_state = Python::with_gil(|py| {
284+
let mailbox = mailbox(py, cx);
285+
py.import("itertools")
286+
.unwrap()
287+
.call_method1("repeat", (mailbox.clone(),))
288+
.unwrap()
289+
.unbind()
290+
});
291+
}
292+
};
293+
Ok((self, local_state))
294+
}
222295
}
223296

224297
impl std::fmt::Debug for PythonMessage {
@@ -408,30 +481,16 @@ impl PanicFlag {
408481
}
409482

410483
#[async_trait]
411-
impl Handler<LocalPythonMessage> for PythonActor {
412-
async fn handle(
413-
&mut self,
414-
cx: &Context<Self>,
415-
message: LocalPythonMessage,
416-
) -> anyhow::Result<()> {
417-
let mailbox: PyMailbox = PyMailbox {
418-
inner: cx.mailbox_for_py().clone(),
419-
};
484+
impl Handler<PythonMessage> for PythonActor {
485+
async fn handle(&mut self, cx: &Context<Self>, message: PythonMessage) -> anyhow::Result<()> {
486+
let (message, local_state) = message.resolve_indirect_call(cx).await?;
487+
420488
// Create a channel for signaling panics in async endpoints.
421489
// See [Panics in async endpoints].
422490
let (sender, receiver) = oneshot::channel();
423491

424492
let future = Python::with_gil(|py| -> Result<_, SerializablePyErr> {
425-
let mailbox = mailbox.into_bound_py_any(py).unwrap();
426-
let local_state: Option<Vec<Bound<'_, PyAny>>> = message.local_state.map(|state| {
427-
state
428-
.into_iter()
429-
.map(|e| match e {
430-
LocalState::Mailbox => mailbox.clone(),
431-
LocalState::PyObject(obj) => obj.into_bound_py_any(py).unwrap(),
432-
})
433-
.collect()
434-
});
493+
let mailbox = mailbox(py, cx);
435494
let (rank, shape) = cx.cast_info();
436495
let awaitable = self.actor.call_method(
437496
py,
@@ -440,7 +499,7 @@ impl Handler<LocalPythonMessage> for PythonActor {
440499
mailbox,
441500
rank,
442501
PyShape::from(shape),
443-
message.message,
502+
message,
444503
PanicFlag {
445504
sender: Some(sender),
446505
},
@@ -463,20 +522,6 @@ impl Handler<LocalPythonMessage> for PythonActor {
463522
}
464523
}
465524

466-
#[async_trait]
467-
impl Handler<PythonMessage> for PythonActor {
468-
async fn handle(&mut self, cx: &Context<Self>, message: PythonMessage) -> anyhow::Result<()> {
469-
self.handle(
470-
cx,
471-
LocalPythonMessage {
472-
message,
473-
local_state: None,
474-
},
475-
)
476-
.await
477-
}
478-
}
479-
480525
/// Helper struct to make a Python future passable in an actor message.
481526
///
482527
/// Also so that we don't have to write this massive type signature everywhere
@@ -576,24 +621,14 @@ impl AsyncEndpointInvocationHandler for AsyncEndpointTask {
576621
Ok(())
577622
}
578623
}
579-
#[derive(Debug)]
580-
pub enum LocalState {
581-
Mailbox,
582-
PyObject(PyObject),
583-
}
584-
585-
#[derive(Debug)]
586-
pub struct LocalPythonMessage {
587-
pub message: PythonMessage,
588-
pub local_state: Option<Vec<LocalState>>,
589-
}
590624

591625
pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
592626
hyperactor_mod.add_class::<PickledMessage>()?;
593627
hyperactor_mod.add_class::<PickledMessageClientActor>()?;
594628
hyperactor_mod.add_class::<PythonActorHandle>()?;
595629
hyperactor_mod.add_class::<PythonMessage>()?;
596630
hyperactor_mod.add_class::<PythonMessageKind>()?;
631+
hyperactor_mod.add_class::<UnflattenArg>()?;
597632
hyperactor_mod.add_class::<PanicFlag>()?;
598633
Ok(())
599634
}

monarch_hyperactor/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ pub mod alloc;
1414
pub mod bootstrap;
1515
pub mod channel;
1616
pub mod config;
17+
pub mod local_state_broker;
1718
pub mod mailbox;
1819
pub mod ndslice;
1920
pub mod proc;

0 commit comments

Comments
 (0)