Skip to content

Commit c4fb77b

Browse files
committed
Clean up of monarch_rdma/extension (#528)
Summary: Pull Request resolved: #528 Follow up to D76937776 - Reverts some objects back to `pub(super)`, exposing the relevant needed APIs - Cleanups of monarch_rdma/extension/lib.rs Reviewed By: colin2328, vidhyav Differential Revision: D78276671
1 parent 7a98d6e commit c4fb77b

File tree

2 files changed

+40
-34
lines changed

2 files changed

+40
-34
lines changed

monarch_hyperactor/src/mailbox.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,13 @@ use crate::shape::PyShape;
5353
module = "monarch._rust_bindings.monarch_hyperactor.mailbox"
5454
)]
5555
pub struct PyMailbox {
56-
pub inner: Mailbox,
56+
pub(super) inner: Mailbox,
57+
}
58+
59+
impl PyMailbox {
60+
pub fn get_inner(&self) -> &Mailbox {
61+
&self.inner
62+
}
5763
}
5864

5965
#[pymethods]

monarch_rdma/extension/lib.rs

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,15 @@ use pyo3::types::PyType;
2929
use serde::Deserialize;
3030
use serde::Serialize;
3131

32-
macro_rules! setup_rdma_context {
33-
($self:ident, $local_proc_id:expr) => {{
34-
let proc_id: ProcId = $local_proc_id.parse().unwrap();
35-
let local_owner_id = ActorId(proc_id, "rdma_manager".to_string(), 0);
36-
let local_owner_ref: ActorRef<RdmaManagerActor> = ActorRef::attest(local_owner_id);
37-
let buffer = $self.buffer.clone();
38-
(local_owner_ref, buffer)
39-
}};
32+
fn setup_rdma_context(
33+
rdma_buffer: &PyRdmaBuffer,
34+
local_proc_id: String,
35+
) -> (ActorRef<RdmaManagerActor>, RdmaBuffer) {
36+
let proc_id: ProcId = local_proc_id.parse().unwrap();
37+
let local_owner_id = ActorId(proc_id, "rdma_manager".to_string(), 0);
38+
let local_owner_ref: ActorRef<RdmaManagerActor> = ActorRef::attest(local_owner_id);
39+
let buffer = rdma_buffer.buffer.clone();
40+
(local_owner_ref, buffer)
4041
}
4142

4243
#[pyclass(name = "_RdmaBuffer", module = "monarch._rust_bindings.rdma")]
@@ -49,16 +50,16 @@ struct PyRdmaBuffer {
4950
async fn create_rdma_buffer(
5051
addr: usize,
5152
size: usize,
52-
proc_id: String,
53+
proc_id: ProcId,
5354
client: PyMailbox,
5455
) -> PyResult<PyRdmaBuffer> {
5556
// Get the owning RdmaManagerActor's ActorRef
56-
let proc_id: ProcId = proc_id.parse().unwrap();
5757
let owner_id = ActorId(proc_id, "rdma_manager".to_string(), 0);
5858
let owner_ref: ActorRef<RdmaManagerActor> = ActorRef::attest(owner_id);
5959

60+
let caps = client.get_inner();
6061
// Create the RdmaBuffer
61-
let buffer = owner_ref.request_buffer(&client.inner, addr, size).await?;
62+
let buffer = owner_ref.request_buffer(caps, addr, size).await?;
6263
Ok(PyRdmaBuffer { buffer, owner_ref })
6364
}
6465

@@ -78,7 +79,10 @@ impl PyRdmaBuffer {
7879
"ibverbs is not supported on this system",
7980
));
8081
}
81-
signal_safe_block_on(py, create_rdma_buffer(addr, size, proc_id, client))?
82+
signal_safe_block_on(
83+
py,
84+
create_rdma_buffer(addr, size, proc_id.parse().unwrap(), client),
85+
)?
8286
}
8387

8488
#[classmethod]
@@ -97,7 +101,7 @@ impl PyRdmaBuffer {
97101
}
98102
pyo3_async_runtimes::tokio::future_into_py(
99103
py,
100-
create_rdma_buffer(addr, size, proc_id, client),
104+
create_rdma_buffer(addr, size, proc_id.parse().unwrap(), client),
101105
)
102106
}
103107

@@ -133,13 +137,12 @@ impl PyRdmaBuffer {
133137
client: PyMailbox,
134138
timeout: u64,
135139
) -> PyResult<Bound<'py, PyAny>> {
136-
let (local_owner_ref, buffer) = setup_rdma_context!(self, local_proc_id);
140+
let (local_owner_ref, buffer) = setup_rdma_context(self, local_proc_id);
137141
pyo3_async_runtimes::tokio::future_into_py(py, async move {
138-
let local_buffer = local_owner_ref
139-
.request_buffer(&client.inner, addr, size)
140-
.await?;
142+
let caps = client.get_inner();
143+
let local_buffer = local_owner_ref.request_buffer(caps, addr, size).await?;
141144
let _result_ = local_buffer
142-
.write_from(&client.inner, buffer, timeout)
145+
.write_from(caps, buffer, timeout)
143146
.await
144147
.map_err(|e| PyException::new_err(format!("failed to read into buffer: {}", e)))?;
145148
Ok(())
@@ -170,13 +173,12 @@ impl PyRdmaBuffer {
170173
client: PyMailbox,
171174
timeout: u64,
172175
) -> PyResult<bool> {
173-
let (local_owner_ref, buffer) = setup_rdma_context!(self, local_proc_id);
176+
let (local_owner_ref, buffer) = setup_rdma_context(self, local_proc_id);
174177
signal_safe_block_on(py, async move {
175-
let local_buffer = local_owner_ref
176-
.request_buffer(&client.inner, addr, size)
177-
.await?;
178+
let caps = client.get_inner();
179+
let local_buffer = local_owner_ref.request_buffer(caps, addr, size).await?;
178180
local_buffer
179-
.write_from(&client.inner, buffer, timeout)
181+
.write_from(caps, buffer, timeout)
180182
.await
181183
.map_err(|e| PyException::new_err(format!("failed to read into buffer: {}", e)))
182184
})?
@@ -204,13 +206,12 @@ impl PyRdmaBuffer {
204206
client: PyMailbox,
205207
timeout: u64,
206208
) -> PyResult<Bound<'py, PyAny>> {
207-
let (local_owner_ref, buffer) = setup_rdma_context!(self, local_proc_id);
209+
let (local_owner_ref, buffer) = setup_rdma_context(self, local_proc_id);
208210
pyo3_async_runtimes::tokio::future_into_py(py, async move {
209-
let local_buffer = local_owner_ref
210-
.request_buffer(&client.inner, addr, size)
211-
.await?;
211+
let caps = client.get_inner();
212+
let local_buffer = local_owner_ref.request_buffer(caps, addr, size).await?;
212213
let _result_ = local_buffer
213-
.read_into(&client.inner, buffer, timeout)
214+
.read_into(caps, buffer, timeout)
214215
.await
215216
.map_err(|e| PyException::new_err(format!("failed to write from buffer: {}", e)))?;
216217
Ok(())
@@ -241,13 +242,12 @@ impl PyRdmaBuffer {
241242
client: PyMailbox,
242243
timeout: u64,
243244
) -> PyResult<bool> {
244-
let (local_owner_ref, buffer) = setup_rdma_context!(self, local_proc_id);
245+
let (local_owner_ref, buffer) = setup_rdma_context(self, local_proc_id);
245246
signal_safe_block_on(py, async move {
246-
let local_buffer = local_owner_ref
247-
.request_buffer(&client.inner, addr, size)
248-
.await?;
247+
let caps = client.get_inner();
248+
let local_buffer = local_owner_ref.request_buffer(caps, addr, size).await?;
249249
local_buffer
250-
.read_into(&client.inner, buffer, timeout)
250+
.read_into(caps, buffer, timeout)
251251
.await
252252
.map_err(|e| PyException::new_err(format!("failed to write from buffer: {}", e)))
253253
})?

0 commit comments

Comments
 (0)