Skip to content

Commit f31f027

Browse files
committed
[14/n] tensor engine {handle_cast, handle} -> handle
We unified the message handling logic in rust, so we can keep that unification for the python bindings. We keep the handle_cast variant, rename it handle. We provide a definition for cast coordinates even when a message wasn't cast: the message is considered to be the rank of a zero dimension shape. Differential Revision: [D78061501](https://our.internmc.facebook.com/intern/diff/D78061501/) ghstack-source-id: 295293115 Pull Request resolved: #484
1 parent 81390a7 commit f31f027

File tree

8 files changed

+29
-71
lines changed

8 files changed

+29
-71
lines changed

hyperactor_mesh/src/actor_mesh.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,7 @@ pub(crate) mod test_util {
451451
cx: &Context<Self>,
452452
GetRank(ok, reply): GetRank,
453453
) -> Result<(), anyhow::Error> {
454-
let (rank, _) = cx.cast_info()?;
454+
let (rank, _) = cx.cast_info();
455455
reply.send(cx, rank)?;
456456
anyhow::ensure!(ok, "intentional error!"); // If `!ok` exit with `Err()`.
457457
Ok(())

hyperactor_mesh/src/comm/multicast.rs

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -196,32 +196,20 @@ pub fn set_cast_info_on_headers(headers: &mut Attrs, rank: usize, shape: Shape,
196196
}
197197

198198
pub trait CastInfo {
199-
/// Get the cast rank and cast shape, returning an error
200-
/// if the relevant info isn't available.
201-
fn cast_info(&self) -> anyhow::Result<(usize, Shape)>;
202-
203-
/// Get the cast rank and cast shape, returning None
204-
/// if the relevant info isn't available.
205-
fn maybe_cast_info(&self) -> Option<(usize, Shape)>;
199+
/// Get the cast rank and cast shape.
200+
/// If something wasn't explicitly sent via a cast, then
201+
/// we represent it as the only member of a 0-dimensonal cast shape,
202+
/// which is the same as a singleton.
203+
fn cast_info(&self) -> (usize, Shape);
206204
}
207205

208206
impl<A: Actor> CastInfo for Context<'_, A> {
209-
fn cast_info(&self) -> anyhow::Result<(usize, Shape)> {
207+
fn cast_info(&self) -> (usize, Shape) {
210208
let headers = self.headers();
211-
let rank = headers
212-
.get(CAST_RANK)
213-
.ok_or_else(|| anyhow::anyhow!("{} not found in headers", CAST_RANK.name()))?;
214-
let shape = headers
215-
.get(CAST_SHAPE)
216-
.ok_or_else(|| anyhow::anyhow!("{} not found in headers", CAST_SHAPE.name()))?
217-
.clone();
218-
Ok((*rank, shape))
219-
}
220-
221-
fn maybe_cast_info(&self) -> Option<(usize, Shape)> {
222-
let headers = self.headers();
223-
headers
224-
.get(CAST_RANK)
225-
.map(|rank| headers.get(CAST_SHAPE).map(|shape| (*rank, shape.clone())))?
209+
match (headers.get(CAST_RANK), headers.get(CAST_SHAPE)) {
210+
(Some(rank), Some(shape)) => (*rank, shape.clone()),
211+
(None, None) => (0, Shape::unity()),
212+
_ => panic!("Expected either both rank and shape or neither"),
213+
}
226214
}
227215
}

monarch_hyperactor/src/actor.rs

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ use hyperactor::message::Unbind;
2828
use hyperactor_mesh::comm::multicast::CastInfo;
2929
use monarch_types::PickledPyObject;
3030
use monarch_types::SerializablePyErr;
31+
use ndslice::Shape;
3132
use pyo3::conversion::IntoPyObjectExt;
3233
use pyo3::exceptions::PyBaseException;
3334
use pyo3::exceptions::PyRuntimeError;
@@ -408,34 +409,21 @@ impl Handler<PythonMessage> for PythonActor {
408409
let (sender, receiver) = oneshot::channel();
409410

410411
let future = Python::with_gil(|py| -> Result<_, SerializablePyErr> {
411-
let awaitable = match cx.maybe_cast_info() {
412-
Some((rank, shape)) => self.actor.call_method(
413-
py,
414-
"handle_cast",
415-
(
416-
mailbox,
417-
rank,
418-
PyShape::from(shape),
419-
message,
420-
PanicFlag {
421-
sender: Some(sender),
422-
},
423-
),
424-
None,
425-
)?,
426-
None => self.actor.call_method(
427-
py,
428-
"handle",
429-
(
430-
mailbox,
431-
message,
432-
PanicFlag {
433-
sender: Some(sender),
434-
},
435-
),
436-
None,
437-
)?,
438-
};
412+
let (rank, shape) = cx.cast_info();
413+
let awaitable = self.actor.call_method(
414+
py,
415+
"handle",
416+
(
417+
mailbox,
418+
rank,
419+
PyShape::from(shape),
420+
message,
421+
PanicFlag {
422+
sender: Some(sender),
423+
},
424+
),
425+
None,
426+
)?;
439427

440428
pyo3_async_runtimes::into_future_with_locals(
441429
self.get_task_locals(py),

monarch_tensor_worker/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ impl Handler<AssignRankMessage> for WorkerActor {
266266
cx: &hyperactor::Context<Self>,
267267
_: AssignRankMessage,
268268
) -> anyhow::Result<()> {
269-
let (rank, shape) = cx.cast_info()?;
269+
let (rank, shape) = cx.cast_info();
270270
self.rank = rank;
271271
self.respond_with_python_message = true;
272272
Python::with_gil(|py| {

python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,6 @@ class PanicFlag:
179179

180180
class Actor(Protocol):
181181
async def handle(
182-
self, mailbox: Mailbox, message: PythonMessage, panic_flag: PanicFlag
183-
) -> None: ...
184-
async def handle_cast(
185182
self,
186183
mailbox: Mailbox,
187184
rank: int,

python/monarch/_src/actor/actor_mesh.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -571,11 +571,6 @@ def __init__(self) -> None:
571571
self.instance: object | None = None
572572

573573
async def handle(
574-
self, mailbox: Mailbox, message: PythonMessage, panic_flag: PanicFlag
575-
) -> None:
576-
return await self.handle_cast(mailbox, 0, singleton_shape, message, panic_flag)
577-
578-
async def handle_cast(
579574
self,
580575
mailbox: Mailbox,
581576
rank: int,

python/tests/_monarch/test_hyperactor.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,6 @@
2727

2828
class MyActor:
2929
async def handle(
30-
self, mailbox: Mailbox, message: PythonMessage, panic_flag: PanicFlag
31-
) -> None:
32-
raise NotImplementedError()
33-
34-
async def handle_cast(
3530
self,
3631
mailbox: Mailbox,
3732
rank: int,

python/tests/_monarch/test_mailbox.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,11 +121,6 @@ async def recv_message() -> str:
121121

122122
class MyActor:
123123
async def handle(
124-
self, mailbox: Mailbox, message: PythonMessage, panic_flag: PanicFlag
125-
) -> None:
126-
return None
127-
128-
async def handle_cast(
129124
self,
130125
mailbox: Mailbox,
131126
rank: int,

0 commit comments

Comments
 (0)