Skip to content

Commit 8b0586c

Browse files
committed
[16/n] tensor engine: enum-ify PythonMessage
PythonMessage was getting a bit stringly typed. I've sectioned into 4 enum case (CallMethod, Result, Exception, Uninit) which capture when you have to pay attention to things like rank/response_port. I have to further add functionality in a follow-up to get actor ordering correct. Differential Revision: [D78295683](https://our.internmc.facebook.com/intern/diff/D78295683/) ghstack-source-id: 296136316 Pull Request resolved: #529
1 parent 12d813f commit 8b0586c

File tree

9 files changed

+264
-158
lines changed

9 files changed

+264
-158
lines changed

monarch_extension/src/mesh_controller.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ use hyperactor_mesh::actor_mesh::RootActorMesh;
3737
use hyperactor_mesh::shared_cell::SharedCell;
3838
use hyperactor_mesh::shared_cell::SharedCellRef;
3939
use monarch_hyperactor::actor::PythonMessage;
40+
use monarch_hyperactor::actor::PythonMessageKind;
4041
use monarch_hyperactor::mailbox::PyPortId;
4142
use monarch_hyperactor::ndslice::PySlice;
4243
use monarch_hyperactor::proc_mesh::PyProcMesh;
@@ -334,7 +335,7 @@ impl Invocation {
334335
Some(PortInfo { port, ranks }) => {
335336
*unreported_exception = None;
336337
for rank in ranks.iter() {
337-
let msg = exception.as_ref().clone().with_rank(rank);
338+
let msg = exception.as_ref().clone().into_rank(rank);
338339
port.send(sender, msg)?;
339340
}
340341
}
@@ -527,7 +528,7 @@ impl History {
527528
.call1((exception.backtrace, traceback, rank))
528529
.unwrap();
529530
let data: Vec<u8> = pickle.call1((exe,)).unwrap().extract().unwrap();
530-
PythonMessage::new_from_buf("exception".to_string(), data, None, Some(rank))
531+
PythonMessage::new_from_buf(PythonMessageKind::Exception { rank: Some(rank) }, data)
531532
}));
532533

533534
let mut invocation = invocation.lock().unwrap();
@@ -570,7 +571,10 @@ impl History {
570571
Some(exception) => exception.as_ref().clone(),
571572
None => {
572573
// the byte string is just a Python None
573-
PythonMessage::new("result".to_string(), b"\x80\x04N.", None, None)
574+
PythonMessage::new_from_buf(
575+
PythonMessageKind::Result { rank: None },
576+
b"\x80\x04N.".to_vec(),
577+
)
574578
}
575579
};
576580
port.send(sender, result)?;

monarch_hyperactor/src/actor.rs

Lines changed: 55 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -171,41 +171,60 @@ impl PickledMessageClientActor {
171171
}
172172
}
173173

174+
#[pyclass(module = "monarch._rust_bindings.monarch_hyperactor.actor")]
175+
#[derive(Clone, Debug, Serialize, Deserialize, Named, PartialEq)]
176+
pub enum PythonMessageKind {
177+
CallMethod {
178+
name: String,
179+
response_port: Option<EitherPortRef>,
180+
},
181+
Result {
182+
rank: Option<usize>,
183+
},
184+
Exception {
185+
rank: Option<usize>,
186+
},
187+
Uninit {},
188+
}
189+
190+
impl Default for PythonMessageKind {
191+
fn default() -> Self {
192+
PythonMessageKind::Uninit {}
193+
}
194+
}
195+
174196
#[pyclass(frozen, module = "monarch._rust_bindings.monarch_hyperactor.actor")]
175-
#[derive(Default, Clone, Serialize, Deserialize, Named, PartialEq)]
197+
#[derive(Clone, Serialize, Deserialize, Named, PartialEq, Default)]
176198
pub struct PythonMessage {
177-
pub method: String,
178-
pub message: ByteBuf,
179-
pub response_port: Option<EitherPortRef>,
180-
pub rank: Option<usize>,
199+
pub kind: PythonMessageKind,
200+
pub message: Vec<u8>,
181201
}
182202

183203
impl PythonMessage {
184-
pub fn with_rank(self, rank: usize) -> PythonMessage {
185-
PythonMessage {
186-
rank: Some(rank),
187-
..self
188-
}
204+
pub fn new_from_buf(kind: PythonMessageKind, message: Vec<u8>) -> Self {
205+
Self { kind, message }
189206
}
190-
pub fn new_from_buf(
191-
method: String,
192-
message: Vec<u8>,
193-
response_port: Option<EitherPortRef>,
194-
rank: Option<usize>,
195-
) -> Self {
196-
Self {
197-
method,
198-
message: message.into(),
199-
response_port,
200-
rank,
207+
208+
pub fn into_rank(self, rank: usize) -> Self {
209+
let rank = Some(rank);
210+
match self.kind {
211+
PythonMessageKind::Result { .. } => PythonMessage {
212+
kind: PythonMessageKind::Result { rank },
213+
message: self.message,
214+
},
215+
PythonMessageKind::Exception { .. } => PythonMessage {
216+
kind: PythonMessageKind::Exception { rank },
217+
message: self.message,
218+
},
219+
_ => panic!("PythonMessage is not a response but {:?}", self),
201220
}
202221
}
203222
}
204223

205224
impl std::fmt::Debug for PythonMessage {
206225
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
207226
f.debug_struct("PythonMessage")
208-
.field("method", &self.method)
227+
.field("kind", &self.kind)
209228
.field(
210229
"message",
211230
&hyperactor::data::HexFmt(self.message.as_slice()).to_string(),
@@ -216,48 +235,39 @@ impl std::fmt::Debug for PythonMessage {
216235

217236
impl Unbind for PythonMessage {
218237
fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
219-
self.response_port.unbind(bindings)
238+
match &self.kind {
239+
PythonMessageKind::CallMethod { response_port, .. } => response_port.unbind(bindings),
240+
_ => Ok(()),
241+
}
220242
}
221243
}
222244

223245
impl Bind for PythonMessage {
224246
fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
225-
self.response_port.bind(bindings)
247+
match &mut self.kind {
248+
PythonMessageKind::CallMethod { response_port, .. } => response_port.bind(bindings),
249+
_ => Ok(()),
250+
}
226251
}
227252
}
228253

229254
#[pymethods]
230255
impl PythonMessage {
231256
#[new]
232-
#[pyo3(signature = (method, message, response_port, rank))]
233-
pub fn new(
234-
method: String,
235-
message: &[u8],
236-
response_port: Option<EitherPortRef>,
237-
rank: Option<usize>,
238-
) -> Self {
239-
Self::new_from_buf(method, message.into(), response_port, rank)
257+
#[pyo3(signature = (kind, message))]
258+
pub fn new(kind: PythonMessageKind, message: &[u8]) -> Self {
259+
PythonMessage::new_from_buf(kind, message.to_vec())
240260
}
241261

242262
#[getter]
243-
fn method(&self) -> &String {
244-
&self.method
263+
fn kind(&self) -> PythonMessageKind {
264+
self.kind.clone()
245265
}
246266

247267
#[getter]
248268
fn message<'a>(&self, py: Python<'a>) -> Bound<'a, PyBytes> {
249269
PyBytes::new(py, self.message.as_ref())
250270
}
251-
252-
#[getter]
253-
fn response_port(&self) -> Option<EitherPortRef> {
254-
self.response_port.clone()
255-
}
256-
257-
#[getter]
258-
fn rank(&self) -> Option<usize> {
259-
self.rank
260-
}
261271
}
262272

263273
#[pyclass(module = "monarch._rust_bindings.monarch_hyperactor.actor")]
@@ -583,6 +593,7 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul
583593
hyperactor_mod.add_class::<PickledMessageClientActor>()?;
584594
hyperactor_mod.add_class::<PythonActorHandle>()?;
585595
hyperactor_mod.add_class::<PythonMessage>()?;
596+
hyperactor_mod.add_class::<PythonMessageKind>()?;
586597
hyperactor_mod.add_class::<PanicFlag>()?;
587598
Ok(())
588599
}

monarch_hyperactor/src/mailbox.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ use serde::Deserialize;
4444
use serde::Serialize;
4545

4646
use crate::actor::PythonMessage;
47+
use crate::actor::PythonMessageKind;
4748
use crate::proc::PyActorId;
4849
use crate::runtime::signal_safe_block_on;
4950
use crate::shape::PyShape;
@@ -528,7 +529,8 @@ impl PythonOncePortReceiver {
528529
Named,
529530
PartialEq,
530531
FromPyObject,
531-
IntoPyObject
532+
IntoPyObject,
533+
Debug
532534
)]
533535
pub enum EitherPortRef {
534536
Unbounded(PythonPortRef),
@@ -610,7 +612,7 @@ impl Accumulator for PythonAccumulator {
610612
fn accumulate(&self, state: &mut Self::State, update: Self::Update) -> anyhow::Result<()> {
611613
Python::with_gil(|py: Python<'_>| {
612614
// Initialize state if it is empty.
613-
if state.message.is_empty() && state.method.is_empty() {
615+
if matches!(state.kind, PythonMessageKind::Uninit {}) {
614616
*state = self
615617
.accumulator
616618
.getattr(py, "initial_state")?

monarch_tensor_worker/src/stream.rs

Lines changed: 49 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ use hyperactor::proc::Proc;
3939
use monarch_hyperactor::actor::LocalPythonMessage;
4040
use monarch_hyperactor::actor::PythonActor;
4141
use monarch_hyperactor::actor::PythonMessage;
42+
use monarch_hyperactor::actor::PythonMessageKind;
4243
use monarch_hyperactor::mailbox::EitherPortRef;
4344
use monarch_messages::controller::ControllerMessageClient;
4445
use monarch_messages::controller::Seq;
@@ -1009,10 +1010,10 @@ impl StreamActor {
10091010
.extract()
10101011
.unwrap();
10111012
Ok(PythonMessage::new_from_buf(
1012-
"result".to_string(),
1013+
PythonMessageKind::Result {
1014+
rank: Some(worker_actor_id.rank()),
1015+
},
10131016
data,
1014-
None,
1015-
Some(worker_actor_id.rank()),
10161017
))
10171018
})
10181019
});
@@ -1686,12 +1687,13 @@ impl StreamMessageHandler for StreamActor {
16861687
let (send, recv) = cx.open_once_port();
16871688
let send = send.bind();
16881689
let send = EitherPortRef::Once(send.into());
1689-
let message = PythonMessage {
1690-
method: params.method,
1691-
message: params.args_kwargs_tuple.into(),
1692-
response_port: Some(send),
1693-
rank: None,
1694-
};
1690+
let message = PythonMessage::new_from_buf(
1691+
PythonMessageKind::CallMethod {
1692+
name: params.method,
1693+
response_port: Some(send),
1694+
},
1695+
params.args_kwargs_tuple.into(),
1696+
);
16951697
let local_state: Result<Vec<monarch_hyperactor::actor::LocalState>> =
16961698
Python::with_gil(|py| {
16971699
params
@@ -1713,31 +1715,44 @@ impl StreamMessageHandler for StreamActor {
17131715
local_state: Some(local_state?),
17141716
};
17151717
actor_handle.send(message)?;
1716-
let result = recv.recv().await?.with_rank(worker_actor_id.rank());
1717-
if result.method == "exception" {
1718-
// If result has "exception" as its kind, then
1719-
// we need to unpickle and turn it into a WorkerError
1720-
// and call remote_function_failed otherwise the
1721-
// controller assumes the object is correct and doesn't handle
1722-
// dependency tracking correctly.
1723-
let err = Python::with_gil(|py| -> Result<WorkerError, SerializablePyErr> {
1724-
let err = py
1725-
.import("pickle")
1726-
.unwrap()
1727-
.call_method1("loads", (result.message.into_vec(),))?;
1728-
Ok(WorkerError {
1729-
worker_actor_id,
1730-
backtrace: err.to_string(),
1731-
})
1732-
})?;
1733-
self.controller_actor
1734-
.remote_function_failed(cx, params.seq, err)
1735-
.await?;
1736-
} else {
1737-
let result = Serialized::serialize(&result).unwrap();
1738-
self.controller_actor
1739-
.fetch_result(cx, params.seq, Ok(result))
1740-
.await?;
1718+
let result = recv.recv().await?;
1719+
match result.kind {
1720+
PythonMessageKind::Exception { .. } => {
1721+
// If result has "exception" as its kind, then
1722+
// we need to unpickle and turn it into a WorkerError
1723+
// and call remote_function_failed otherwise the
1724+
// controller assumes the object is correct and doesn't handle
1725+
// dependency tracking correctly.
1726+
let err = Python::with_gil(|py| -> Result<WorkerError, SerializablePyErr> {
1727+
let err = py
1728+
.import("pickle")
1729+
.unwrap()
1730+
.call_method1("loads", (result.message,))?;
1731+
Ok(WorkerError {
1732+
worker_actor_id,
1733+
backtrace: err.to_string(),
1734+
})
1735+
})?;
1736+
self.controller_actor
1737+
.remote_function_failed(cx, params.seq, err)
1738+
.await?;
1739+
}
1740+
PythonMessageKind::Result { .. } => {
1741+
let result = PythonMessage::new_from_buf(
1742+
PythonMessageKind::Result {
1743+
rank: Some(worker_actor_id.rank()),
1744+
},
1745+
result.message,
1746+
);
1747+
let result = Serialized::serialize(&result).unwrap();
1748+
self.controller_actor
1749+
.fetch_result(cx, params.seq, Ok(result))
1750+
.await?;
1751+
}
1752+
_ => panic!(
1753+
"Unexpected response kind from PythonActor: {:?}",
1754+
result.kind
1755+
),
17411756
}
17421757
Ok(())
17431758
}

0 commit comments

Comments
 (0)