Skip to content

Commit bfac79b

Browse files
zdevitofacebook-github-bot
authored andcommitted
Fix python port serialization (#560)
Summary: Pull Request resolved: #560 This was regressed when we switched from serializable PortId objects into rust-bound PythonPortRef objects. Restore serialization by defining new/__reduce__ pair. ghstack-source-id: 296655361 exported-using-ghexport Reviewed By: suo Differential Revision: D78438687 fbshipit-source-id: ba1377b02a500489a0b43b1e37499f2c5661e9c0
1 parent 5e92e8a commit bfac79b

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

monarch_hyperactor/src/mailbox.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
use std::hash::DefaultHasher;
1010
use std::hash::Hash;
1111
use std::hash::Hasher;
12+
use std::ops::Deref;
1213
use std::sync::Arc;
1314

1415
use hyperactor::Mailbox;
@@ -35,11 +36,13 @@ use hyperactor::message::Bindings;
3536
use hyperactor::message::Unbind;
3637
use hyperactor_mesh::comm::multicast::set_cast_info_on_headers;
3738
use monarch_types::PickledPyObject;
39+
use pyo3::IntoPyObjectExt;
3840
use pyo3::exceptions::PyEOFError;
3941
use pyo3::exceptions::PyRuntimeError;
4042
use pyo3::exceptions::PyValueError;
4143
use pyo3::prelude::*;
4244
use pyo3::types::PyTuple;
45+
use pyo3::types::PyType;
4346
use serde::Deserialize;
4447
use serde::Serialize;
4548

@@ -326,6 +329,19 @@ pub struct PythonPortRef {
326329

327330
#[pymethods]
328331
impl PythonPortRef {
332+
#[new]
333+
fn new(port: PyPortId) -> Self {
334+
Self {
335+
inner: PortRef::attest(port.into()),
336+
}
337+
}
338+
fn __reduce__<'py>(
339+
slf: Bound<'py, PythonPortRef>,
340+
) -> PyResult<(Bound<'py, PyType>, (PyPortId,))> {
341+
let id: PyPortId = (*slf.borrow()).inner.port_id().clone().into();
342+
Ok((slf.get_type(), (id,)))
343+
}
344+
329345
fn send(&self, mailbox: &PyMailbox, message: PythonMessage) -> PyResult<()> {
330346
self.inner
331347
.send(&mailbox.inner, message)
@@ -472,6 +488,22 @@ pub struct PythonOncePortRef {
472488

473489
#[pymethods]
474490
impl PythonOncePortRef {
491+
#[new]
492+
fn new(port: Option<PyPortId>) -> Self {
493+
Self {
494+
inner: port.map(|port| PortRef::attest(port.inner).into_once()),
495+
}
496+
}
497+
fn __reduce__<'py>(
498+
slf: Bound<'py, PythonOncePortRef>,
499+
) -> PyResult<(Bound<'py, PyType>, (Option<PyPortId>,))> {
500+
let id: Option<PyPortId> = (*slf.borrow())
501+
.inner
502+
.as_ref()
503+
.map(|x| x.port_id().clone().into());
504+
Ok((slf.get_type(), (id,)))
505+
}
506+
475507
fn send(&mut self, mailbox: &PyMailbox, message: PythonMessage) -> PyResult<()> {
476508
let Some(port_ref) = self.inner.take() else {
477509
return Err(PyErr::new::<PyValueError, _>("OncePortRef is already used"));

python/tests/test_python_actors.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
import torch
2222

23+
from monarch._src.actor.actor_mesh import Port, PortTuple
24+
2325
from monarch.actor import (
2426
Accumulator,
2527
Actor,
@@ -706,6 +708,24 @@ async def test_actor_log_streaming() -> None:
706708
pass
707709

708710

711+
class SendAlot(Actor):
712+
@endpoint
713+
async def send(self, port: Port[int]):
714+
for i in range(100):
715+
port.send(i)
716+
717+
718+
def test_port_as_argument():
719+
proc_mesh = local_proc_mesh(gpus=1).get()
720+
s = proc_mesh.spawn("send_alot", SendAlot).get()
721+
send, recv = PortTuple.create(proc_mesh._mailbox, None)
722+
723+
s.send.broadcast(send)
724+
725+
for i in range(100):
726+
assert i == recv.recv().get()
727+
728+
709729
@pytest.mark.timeout(15)
710730
async def test_same_actor_twice() -> None:
711731
pm = await proc_mesh(gpus=1)

0 commit comments

Comments
 (0)