Skip to content

Commit 167620a

Browse files
committed
[14/n] tensor engine {handle_cast, handle} -> handle
Pull Request resolved: #491 We unified the message handling logic in rust, so we can keep that unification for the python bindings. We keep the handle_cast variant, rename it handle. We provide a definition for cast coordinates even when a message wasn't cast: the message is considered to be the rank of a zero dimension shape. ghstack-source-id: 295545965 @exported-using-ghexport Differential Revision: [D78061501](https://our.internmc.facebook.com/intern/diff/D78061501/)
1 parent e120236 commit 167620a

File tree

10 files changed

+34
-76
lines changed

10 files changed

+34
-76
lines changed

hyperactor_mesh/examples/dining_philosophers.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ impl Handler<PhilosopherMessage> for PhilosopherActor {
142142
cx: &Context<Self>,
143143
message: PhilosopherMessage,
144144
) -> Result<(), anyhow::Error> {
145-
let (rank, _) = cx.cast_info()?;
145+
let (rank, _) = cx.cast_info();
146146
self.rank = rank;
147147
match message {
148148
PhilosopherMessage::Start(waiter) => {

hyperactor_mesh/src/actor_mesh.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ pub(crate) mod test_util {
441441
cx: &Context<Self>,
442442
GetRank(ok, reply): GetRank,
443443
) -> Result<(), anyhow::Error> {
444-
let (rank, _) = cx.cast_info()?;
444+
let (rank, _) = cx.cast_info();
445445
reply.send(cx, rank)?;
446446
anyhow::ensure!(ok, "intentional error!"); // If `!ok` exit with `Err()`.
447447
Ok(())

hyperactor_mesh/src/comm/multicast.rs

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -215,32 +215,20 @@ pub fn set_cast_info_on_headers(headers: &mut Attrs, rank: usize, shape: Shape,
215215
}
216216

217217
pub trait CastInfo {
218-
/// Get the cast rank and cast shape, returning an error
219-
/// if the relevant info isn't available.
220-
fn cast_info(&self) -> anyhow::Result<(usize, Shape)>;
221-
222-
/// Get the cast rank and cast shape, returning None
223-
/// if the relevant info isn't available.
224-
fn maybe_cast_info(&self) -> Option<(usize, Shape)>;
218+
/// Get the cast rank and cast shape.
219+
/// If something wasn't explicitly sent via a cast, then
220+
/// we represent it as the only member of a 0-dimensonal cast shape,
221+
/// which is the same as a singleton.
222+
fn cast_info(&self) -> (usize, Shape);
225223
}
226224

227225
impl<A: Actor> CastInfo for Context<'_, A> {
228-
fn cast_info(&self) -> anyhow::Result<(usize, Shape)> {
226+
fn cast_info(&self) -> (usize, Shape) {
229227
let headers = self.headers();
230-
let rank = headers
231-
.get(CAST_RANK)
232-
.ok_or_else(|| anyhow::anyhow!("{} not found in headers", CAST_RANK.name()))?;
233-
let shape = headers
234-
.get(CAST_SHAPE)
235-
.ok_or_else(|| anyhow::anyhow!("{} not found in headers", CAST_SHAPE.name()))?
236-
.clone();
237-
Ok((*rank, shape))
238-
}
239-
240-
fn maybe_cast_info(&self) -> Option<(usize, Shape)> {
241-
let headers = self.headers();
242-
headers
243-
.get(CAST_RANK)
244-
.map(|rank| headers.get(CAST_SHAPE).map(|shape| (*rank, shape.clone())))?
228+
match (headers.get(CAST_RANK), headers.get(CAST_SHAPE)) {
229+
(Some(rank), Some(shape)) => (*rank, shape.clone()),
230+
(None, None) => (0, Shape::unity()),
231+
_ => panic!("Expected either both rank and shape or neither"),
232+
}
245233
}
246234
}

monarch_hyperactor/src/actor.rs

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ use hyperactor::message::Unbind;
2828
use hyperactor_mesh::comm::multicast::CastInfo;
2929
use monarch_types::PickledPyObject;
3030
use monarch_types::SerializablePyErr;
31+
use ndslice::Shape;
3132
use pyo3::conversion::IntoPyObjectExt;
3233
use pyo3::exceptions::PyBaseException;
3334
use pyo3::exceptions::PyRuntimeError;
@@ -408,34 +409,21 @@ impl Handler<PythonMessage> for PythonActor {
408409
let (sender, receiver) = oneshot::channel();
409410

410411
let future = Python::with_gil(|py| -> Result<_, SerializablePyErr> {
411-
let awaitable = match cx.maybe_cast_info() {
412-
Some((rank, shape)) => self.actor.call_method(
413-
py,
414-
"handle_cast",
415-
(
416-
mailbox,
417-
rank,
418-
PyShape::from(shape),
419-
message,
420-
PanicFlag {
421-
sender: Some(sender),
422-
},
423-
),
424-
None,
425-
)?,
426-
None => self.actor.call_method(
427-
py,
428-
"handle",
429-
(
430-
mailbox,
431-
message,
432-
PanicFlag {
433-
sender: Some(sender),
434-
},
435-
),
436-
None,
437-
)?,
438-
};
412+
let (rank, shape) = cx.cast_info();
413+
let awaitable = self.actor.call_method(
414+
py,
415+
"handle",
416+
(
417+
mailbox,
418+
rank,
419+
PyShape::from(shape),
420+
message,
421+
PanicFlag {
422+
sender: Some(sender),
423+
},
424+
),
425+
None,
426+
)?;
439427

440428
pyo3_async_runtimes::into_future_with_locals(
441429
self.get_task_locals(py),

monarch_rdma/examples/parameter_server.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ impl Handler<WorkerInit> for WorkerActor {
303303
cx: &Context<Self>,
304304
WorkerInit(ps_ref, rdma_managers): WorkerInit,
305305
) -> Result<(), anyhow::Error> {
306-
let (rank, _) = cx.cast_info()?;
306+
let (rank, _) = cx.cast_info();
307307

308308
println!("[worker_actor_{}] initializing", rank);
309309

@@ -336,7 +336,7 @@ impl Handler<WorkerStep> for WorkerActor {
336336
cx: &Context<Self>,
337337
WorkerStep(reply): WorkerStep,
338338
) -> Result<(), anyhow::Error> {
339-
let (rank, _) = cx.cast_info()?;
339+
let (rank, _) = cx.cast_info();
340340

341341
for (grad_value, weight) in self
342342
.local_gradients
@@ -385,7 +385,7 @@ impl Handler<WorkerUpdate> for WorkerActor {
385385
cx: &Context<Self>,
386386
WorkerUpdate(reply): WorkerUpdate,
387387
) -> Result<(), anyhow::Error> {
388-
let (rank, _) = cx.cast_info()?;
388+
let (rank, _) = cx.cast_info();
389389

390390
println!(
391391
"[worker_actor_{}] pulling new weights from parameter server (before: {:?})",
@@ -418,7 +418,7 @@ impl Handler<WorkerUpdate> for WorkerActor {
418418
impl Handler<Log> for WorkerActor {
419419
/// Logs the worker's weights
420420
async fn handle(&mut self, cx: &Context<Self>, _: Log) -> Result<(), anyhow::Error> {
421-
let (rank, _) = cx.cast_info()?;
421+
let (rank, _) = cx.cast_info();
422422
println!("[worker_actor_{}] weights: {:?}", rank, self.weights_data);
423423
Ok(())
424424
}

monarch_tensor_worker/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ impl Handler<AssignRankMessage> for WorkerActor {
266266
cx: &hyperactor::Context<Self>,
267267
_: AssignRankMessage,
268268
) -> anyhow::Result<()> {
269-
let (rank, shape) = cx.cast_info()?;
269+
let (rank, shape) = cx.cast_info();
270270
self.rank = rank;
271271
self.respond_with_python_message = true;
272272
Python::with_gil(|py| {

python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,6 @@ class PanicFlag:
179179

180180
class Actor(Protocol):
181181
async def handle(
182-
self, mailbox: Mailbox, message: PythonMessage, panic_flag: PanicFlag
183-
) -> None: ...
184-
async def handle_cast(
185182
self,
186183
mailbox: Mailbox,
187184
rank: int,

python/monarch/_src/actor/actor_mesh.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -580,11 +580,6 @@ def __init__(self) -> None:
580580
self.instance: object | None = None
581581

582582
async def handle(
583-
self, mailbox: Mailbox, message: PythonMessage, panic_flag: PanicFlag
584-
) -> None:
585-
return await self.handle_cast(mailbox, 0, singleton_shape, message, panic_flag)
586-
587-
async def handle_cast(
588583
self,
589584
mailbox: Mailbox,
590585
rank: int,

python/tests/_monarch/test_hyperactor.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,6 @@
2727

2828
class MyActor:
2929
async def handle(
30-
self, mailbox: Mailbox, message: PythonMessage, panic_flag: PanicFlag
31-
) -> None:
32-
raise NotImplementedError()
33-
34-
async def handle_cast(
3530
self,
3631
mailbox: Mailbox,
3732
rank: int,

python/tests/_monarch/test_mailbox.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,6 @@ async def recv_message() -> str:
118118

119119
class MyActor:
120120
async def handle(
121-
self, mailbox: Mailbox, message: PythonMessage, panic_flag: PanicFlag
122-
) -> None:
123-
return None
124-
125-
async def handle_cast(
126121
self,
127122
mailbox: Mailbox,
128123
rank: int,

0 commit comments

Comments
 (0)