From c4fb77b39d0db6d36853f6e5ae903a3375643634 Mon Sep 17 00:00:00 2001 From: Allen Wang <9057208+allenwang28@users.noreply.github.com> Date: Tue, 15 Jul 2025 07:38:01 -0700 Subject: [PATCH] Clean up of monarch_rdma/extension (#528) Summary: Pull Request resolved: https://github.com/pytorch-labs/monarch/pull/528 Follow up to D76937776 - Reverts some objects back to `pub(super)`, exposing the relevant needed APIs - Cleanups of monarch_rdma/extension/lib.rs Reviewed By: colin2328, vidhyav Differential Revision: D78276671 --- monarch_hyperactor/src/mailbox.rs | 8 +++- monarch_rdma/extension/lib.rs | 66 +++++++++++++++---------------- 2 files changed, 40 insertions(+), 34 deletions(-) diff --git a/monarch_hyperactor/src/mailbox.rs b/monarch_hyperactor/src/mailbox.rs index e025fe1a..8740b66d 100644 --- a/monarch_hyperactor/src/mailbox.rs +++ b/monarch_hyperactor/src/mailbox.rs @@ -53,7 +53,13 @@ use crate::shape::PyShape; module = "monarch._rust_bindings.monarch_hyperactor.mailbox" )] pub struct PyMailbox { - pub inner: Mailbox, + pub(super) inner: Mailbox, +} + +impl PyMailbox { + pub fn get_inner(&self) -> &Mailbox { + &self.inner + } } #[pymethods] diff --git a/monarch_rdma/extension/lib.rs b/monarch_rdma/extension/lib.rs index fe2e35a4..1dc4da81 100644 --- a/monarch_rdma/extension/lib.rs +++ b/monarch_rdma/extension/lib.rs @@ -29,14 +29,15 @@ use pyo3::types::PyType; use serde::Deserialize; use serde::Serialize; -macro_rules! setup_rdma_context { - ($self:ident, $local_proc_id:expr) => {{ - let proc_id: ProcId = $local_proc_id.parse().unwrap(); - let local_owner_id = ActorId(proc_id, "rdma_manager".to_string(), 0); - let local_owner_ref: ActorRef = ActorRef::attest(local_owner_id); - let buffer = $self.buffer.clone(); - (local_owner_ref, buffer) - }}; +fn setup_rdma_context( + rdma_buffer: &PyRdmaBuffer, + local_proc_id: String, +) -> (ActorRef, RdmaBuffer) { + let proc_id: ProcId = local_proc_id.parse().unwrap(); + let local_owner_id = ActorId(proc_id, "rdma_manager".to_string(), 0); + let local_owner_ref: ActorRef = ActorRef::attest(local_owner_id); + let buffer = rdma_buffer.buffer.clone(); + (local_owner_ref, buffer) } #[pyclass(name = "_RdmaBuffer", module = "monarch._rust_bindings.rdma")] @@ -49,16 +50,16 @@ struct PyRdmaBuffer { async fn create_rdma_buffer( addr: usize, size: usize, - proc_id: String, + proc_id: ProcId, client: PyMailbox, ) -> PyResult { // Get the owning RdmaManagerActor's ActorRef - let proc_id: ProcId = proc_id.parse().unwrap(); let owner_id = ActorId(proc_id, "rdma_manager".to_string(), 0); let owner_ref: ActorRef = ActorRef::attest(owner_id); + let caps = client.get_inner(); // Create the RdmaBuffer - let buffer = owner_ref.request_buffer(&client.inner, addr, size).await?; + let buffer = owner_ref.request_buffer(caps, addr, size).await?; Ok(PyRdmaBuffer { buffer, owner_ref }) } @@ -78,7 +79,10 @@ impl PyRdmaBuffer { "ibverbs is not supported on this system", )); } - signal_safe_block_on(py, create_rdma_buffer(addr, size, proc_id, client))? + signal_safe_block_on( + py, + create_rdma_buffer(addr, size, proc_id.parse().unwrap(), client), + )? } #[classmethod] @@ -97,7 +101,7 @@ impl PyRdmaBuffer { } pyo3_async_runtimes::tokio::future_into_py( py, - create_rdma_buffer(addr, size, proc_id, client), + create_rdma_buffer(addr, size, proc_id.parse().unwrap(), client), ) } @@ -133,13 +137,12 @@ impl PyRdmaBuffer { client: PyMailbox, timeout: u64, ) -> PyResult> { - let (local_owner_ref, buffer) = setup_rdma_context!(self, local_proc_id); + let (local_owner_ref, buffer) = setup_rdma_context(self, local_proc_id); pyo3_async_runtimes::tokio::future_into_py(py, async move { - let local_buffer = local_owner_ref - .request_buffer(&client.inner, addr, size) - .await?; + let caps = client.get_inner(); + let local_buffer = local_owner_ref.request_buffer(caps, addr, size).await?; let _result_ = local_buffer - .write_from(&client.inner, buffer, timeout) + .write_from(caps, buffer, timeout) .await .map_err(|e| PyException::new_err(format!("failed to read into buffer: {}", e)))?; Ok(()) @@ -170,13 +173,12 @@ impl PyRdmaBuffer { client: PyMailbox, timeout: u64, ) -> PyResult { - let (local_owner_ref, buffer) = setup_rdma_context!(self, local_proc_id); + let (local_owner_ref, buffer) = setup_rdma_context(self, local_proc_id); signal_safe_block_on(py, async move { - let local_buffer = local_owner_ref - .request_buffer(&client.inner, addr, size) - .await?; + let caps = client.get_inner(); + let local_buffer = local_owner_ref.request_buffer(caps, addr, size).await?; local_buffer - .write_from(&client.inner, buffer, timeout) + .write_from(caps, buffer, timeout) .await .map_err(|e| PyException::new_err(format!("failed to read into buffer: {}", e))) })? @@ -204,13 +206,12 @@ impl PyRdmaBuffer { client: PyMailbox, timeout: u64, ) -> PyResult> { - let (local_owner_ref, buffer) = setup_rdma_context!(self, local_proc_id); + let (local_owner_ref, buffer) = setup_rdma_context(self, local_proc_id); pyo3_async_runtimes::tokio::future_into_py(py, async move { - let local_buffer = local_owner_ref - .request_buffer(&client.inner, addr, size) - .await?; + let caps = client.get_inner(); + let local_buffer = local_owner_ref.request_buffer(caps, addr, size).await?; let _result_ = local_buffer - .read_into(&client.inner, buffer, timeout) + .read_into(caps, buffer, timeout) .await .map_err(|e| PyException::new_err(format!("failed to write from buffer: {}", e)))?; Ok(()) @@ -241,13 +242,12 @@ impl PyRdmaBuffer { client: PyMailbox, timeout: u64, ) -> PyResult { - let (local_owner_ref, buffer) = setup_rdma_context!(self, local_proc_id); + let (local_owner_ref, buffer) = setup_rdma_context(self, local_proc_id); signal_safe_block_on(py, async move { - let local_buffer = local_owner_ref - .request_buffer(&client.inner, addr, size) - .await?; + let caps = client.get_inner(); + let local_buffer = local_owner_ref.request_buffer(caps, addr, size).await?; local_buffer - .read_into(&client.inner, buffer, timeout) + .read_into(caps, buffer, timeout) .await .map_err(|e| PyException::new_err(format!("failed to write from buffer: {}", e))) })?