Skip to content

Expose RDMA support through Python APIs #462

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions monarch_hyperactor/src/actor_mesh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ use crate::shape::PyShape;
module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh"
)]
pub struct PythonActorMesh {
pub(super) inner: SharedCell<RootActorMesh<'static, PythonActor>>,
pub(super) client: PyMailbox,
pub inner: SharedCell<RootActorMesh<'static, PythonActor>>,
pub client: PyMailbox,
pub(super) _keepalive: Keepalive,
}

Expand Down Expand Up @@ -62,7 +62,7 @@ impl PythonActorMesh {
}

#[getter]
fn client(&self) -> PyMailbox {
pub fn client(&self) -> PyMailbox {
self.client.clone()
}

Expand Down
4 changes: 2 additions & 2 deletions monarch_hyperactor/src/mailbox.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ use crate::shape::PyShape;
name = "Mailbox",
module = "monarch._rust_bindings.monarch_hyperactor.mailbox"
)]
pub(super) struct PyMailbox {
pub(super) inner: Mailbox,
pub struct PyMailbox {
pub inner: Mailbox,
}

#[pymethods]
Expand Down
22 changes: 22 additions & 0 deletions monarch_rdma/extension/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# @generated by autocargo from //monarch/monarch_rdma/extension:monarch_rdma_extension

[package]
name = "monarch_rdma_extension"
version = "0.0.0"
authors = ["Meta"]
edition = "2021"
license = "BSD-3-Clause"

[lib]
path = "lib.rs"
test = false
doctest = false

[dependencies]
hyperactor = { version = "0.0.0", path = "../../hyperactor" }
monarch_hyperactor = { version = "0.0.0", path = "../../monarch_hyperactor" }
monarch_rdma = { version = "0.0.0", path = ".." }
pyo3 = { version = "0.24", features = ["anyhow", "multiple-pymethods"] }
pyo3-async-runtimes = { version = "0.24", features = ["attributes", "tokio-runtime"] }
serde = { version = "1.0.185", features = ["derive", "rc"] }
serde_json = { version = "1.0.140", features = ["alloc", "float_roundtrip", "unbounded_depth"] }
290 changes: 290 additions & 0 deletions monarch_rdma/extension/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,290 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#![allow(unsafe_op_in_unsafe_fn)]
use hyperactor::ActorId;
use hyperactor::ActorRef;
use hyperactor::Named;
use hyperactor::ProcId;
use monarch_hyperactor::mailbox::PyMailbox;
use monarch_hyperactor::runtime::signal_safe_block_on;
use monarch_rdma::RdmaBuffer;
use monarch_rdma::RdmaManagerActor;
use monarch_rdma::RdmaManagerMessageClient;
use monarch_rdma::ibverbs_supported;
use pyo3::BoundObject;
use pyo3::exceptions::PyException;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyTuple;
use pyo3::types::PyType;
use serde::Deserialize;
use serde::Serialize;

/// Builds the pieces every transfer method needs before going async.
///
/// Expands to a `(local_owner_ref, buffer)` tuple where:
/// * `local_owner_ref` is an `ActorRef<RdmaManagerActor>` addressing the actor named
///   `"rdma_manager"` (index 0) on the proc identified by `$local_proc_id`;
/// * `buffer` is a clone of `$self.buffer`.
///
/// The proc id comes straight from Python, so a malformed value is reported as a
/// `ValueError` instead of panicking. Because the expansion uses `?`, this macro may
/// only be invoked inside a function whose return type is `PyResult<_>`.
macro_rules! setup_rdma_context {
    ($self:ident, $local_proc_id:expr) => {{
        // Bind once so the expression is evaluated a single time.
        let raw_proc_id = &$local_proc_id;
        let proc_id = raw_proc_id.parse::<ProcId>().map_err(|e| {
            PyValueError::new_err(format!("invalid proc id {:?}: {:?}", raw_proc_id, e))
        })?;
        let local_owner_id = ActorId(proc_id, "rdma_manager".to_string(), 0);
        // NOTE(review): `attest` presumably does not verify that the actor exists — confirm.
        let local_owner_ref: ActorRef<RdmaManagerActor> = ActorRef::attest(local_owner_id);
        let buffer = $self.buffer.clone();
        (local_owner_ref, buffer)
    }};
}

/// Python-facing wrapper around an `RdmaBuffer`, exposed to Python as `_RdmaBuffer`.
///
/// The type derives `Serialize`/`Deserialize`, which the `__reduce__` / `#[new]` pair in
/// the `#[pymethods]` block uses to round-trip instances through JSON.
#[pyclass(name = "_RdmaBuffer", module = "monarch._rust_bindings.rdma")]
#[derive(Clone, Serialize, Deserialize, Named)]
struct PyRdmaBuffer {
/// The RDMA buffer handle returned by the owning manager's `request_buffer` call.
buffer: RdmaBuffer,
/// Reference to the `RdmaManagerActor` the buffer was requested from.
owner_ref: ActorRef<RdmaManagerActor>,
}

async fn create_rdma_buffer(
addr: usize,
size: usize,
proc_id: String,
client: PyMailbox,
) -> PyResult<PyRdmaBuffer> {
// Get the owning RdmaManagerActor's ActorRef
let proc_id: ProcId = proc_id.parse().unwrap();
let owner_id = ActorId(proc_id, "rdma_manager".to_string(), 0);
let owner_ref: ActorRef<RdmaManagerActor> = ActorRef::attest(owner_id);

// Create the RdmaBuffer
let buffer = owner_ref.request_buffer(&client.inner, addr, size).await?;
Ok(PyRdmaBuffer { buffer, owner_ref })
}

#[pymethods]
impl PyRdmaBuffer {
/// Creates an RDMA buffer, blocking the calling thread until the request completes.
///
/// Fails up front with an exception when `ibverbs_supported()` reports `false`;
/// otherwise drives `create_rdma_buffer` to completion via `signal_safe_block_on`.
#[classmethod]
fn create_rdma_buffer_blocking<'py>(
    _cls: &Bound<'_, PyType>,
    py: Python<'py>,
    addr: usize,
    size: usize,
    proc_id: String,
    client: PyMailbox,
) -> PyResult<PyRdmaBuffer> {
    if ibverbs_supported() {
        let request = create_rdma_buffer(addr, size, proc_id, client);
        // Outer `?` unwraps the block_on result; the inner `PyResult` is returned as-is.
        signal_safe_block_on(py, request)?
    } else {
        Err(PyException::new_err(
            "ibverbs is not supported on this system",
        ))
    }
}

/// Creates an RDMA buffer asynchronously, returning a Python awaitable.
///
/// Fails up front with an exception when `ibverbs_supported()` reports `false`;
/// otherwise hands the `create_rdma_buffer` future to the tokio/asyncio bridge.
#[classmethod]
fn create_rdma_buffer_nonblocking<'py>(
    _cls: &Bound<'_, PyType>,
    py: Python<'py>,
    addr: usize,
    size: usize,
    proc_id: String,
    client: PyMailbox,
) -> PyResult<Bound<'py, PyAny>> {
    if ibverbs_supported() {
        let request = create_rdma_buffer(addr, size, proc_id, client);
        pyo3_async_runtimes::tokio::future_into_py(py, request)
    } else {
        Err(PyException::new_err(
            "ibverbs is not supported on this system",
        ))
    }
}

/// Reports whether RDMA can be used, by forwarding the result of `ibverbs_supported()`.
///
/// Neither the class object nor the GIL token is used.
#[classmethod]
fn rdma_supported<'py>(_cls: &Bound<'_, PyType>, _py: Python<'py>) -> bool {
ibverbs_supported()
}

/// Python `__repr__`: embeds the `Debug` rendering of the wrapped `RdmaBuffer`.
#[pyo3(name = "__repr__")]
fn repr(&self) -> String {
format!("<RdmaBuffer'{:?}'>", self.buffer)
}

/// Copies data out of a local buffer and into this remote RDMA buffer (async).
///
/// From the caller's side this is a "read_into": local memory is read and lands in the
/// remote buffer. Under the hood a buffer is requested for the local region and
/// `write_from` is invoked on it, because the data flows from local to remote.
///
/// # Arguments
/// * `addr` - The address of the local buffer to read from
/// * `size` - The size of the data to transfer
/// * `local_proc_id` - The process ID where the local buffer resides
/// * `client` - The mailbox for communication
/// * `timeout` - Maximum time in milliseconds to wait for the operation
///
/// # Returns
/// A Python awaitable; the transfer's completion flag is discarded.
#[pyo3(signature = (addr, size, local_proc_id, client, timeout))]
fn read_into<'py>(
    &self,
    py: Python<'py>,
    addr: usize,
    size: usize,
    local_proc_id: String,
    client: PyMailbox,
    timeout: u64,
) -> PyResult<Bound<'py, PyAny>> {
    let (local_owner_ref, buffer) = setup_rdma_context!(self, local_proc_id);
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        let mailbox = &client.inner;
        // Request a buffer for the local region from the local manager first.
        let local_buffer = local_owner_ref.request_buffer(mailbox, addr, size).await?;
        match local_buffer.write_from(mailbox, buffer, timeout).await {
            Ok(_) => Ok(()),
            Err(e) => Err(PyException::new_err(format!(
                "failed to read into buffer: {}",
                e
            ))),
        }
    })
}

/// Copies data out of a local buffer and into this remote RDMA buffer (blocking).
///
/// From the caller's side this is a "read_into": local memory is read and lands in the
/// remote buffer. Under the hood a buffer is requested for the local region and
/// `write_from` is invoked on it, because the data flows from local to remote.
///
/// Blocking counterpart of `read_into`, usable from non-asyncio Python code.
///
/// # Arguments
/// * `addr` - The address of the local buffer to read from
/// * `size` - The size of the data to transfer
/// * `local_proc_id` - The process ID where the local buffer resides
/// * `client` - The mailbox for communication
/// * `timeout` - Maximum time in milliseconds to wait for the operation
///
/// # Returns
/// The boolean produced by the underlying `write_from` call.
#[pyo3(signature = (addr, size, local_proc_id, client, timeout))]
fn read_into_blocking<'py>(
    &self,
    py: Python<'py>,
    addr: usize,
    size: usize,
    local_proc_id: String,
    client: PyMailbox,
    timeout: u64,
) -> PyResult<bool> {
    let (local_owner_ref, buffer) = setup_rdma_context!(self, local_proc_id);
    // Outer `?` unwraps the block_on result; the inner `PyResult<bool>` is returned as-is.
    signal_safe_block_on(py, async move {
        let mailbox = &client.inner;
        // Request a buffer for the local region from the local manager first.
        let local_buffer = local_owner_ref.request_buffer(mailbox, addr, size).await?;
        match local_buffer.write_from(mailbox, buffer, timeout).await {
            Ok(done) => Ok(done),
            Err(e) => Err(PyException::new_err(format!(
                "failed to read into buffer: {}",
                e
            ))),
        }
    })?
}

/// Copies data out of this remote RDMA buffer and into a local buffer (async).
///
/// From the caller's side this is a "write_from": the remote buffer's contents are
/// written to local memory. Under the hood a buffer is requested for the local region
/// and `read_into` is invoked on it, because the data flows from remote to local.
///
/// # Arguments
/// * `addr` - The address of the local buffer to write to
/// * `size` - The size of the data to transfer
/// * `local_proc_id` - The process ID where the local buffer resides
/// * `client` - The mailbox for communication
/// * `timeout` - Maximum time in milliseconds to wait for the operation
///
/// # Returns
/// A Python awaitable; the transfer's completion flag is discarded.
#[pyo3(signature = (addr, size, local_proc_id, client, timeout))]
fn write_from<'py>(
    &self,
    py: Python<'py>,
    addr: usize,
    size: usize,
    local_proc_id: String,
    client: PyMailbox,
    timeout: u64,
) -> PyResult<Bound<'py, PyAny>> {
    let (local_owner_ref, buffer) = setup_rdma_context!(self, local_proc_id);
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        let mailbox = &client.inner;
        // Request a buffer for the local region from the local manager first.
        let local_buffer = local_owner_ref.request_buffer(mailbox, addr, size).await?;
        match local_buffer.read_into(mailbox, buffer, timeout).await {
            Ok(_) => Ok(()),
            Err(e) => Err(PyException::new_err(format!(
                "failed to write from buffer: {}",
                e
            ))),
        }
    })
}

/// Copies data out of this remote RDMA buffer and into a local buffer (blocking).
///
/// From the caller's side this is a "write_from": the remote buffer's contents are
/// written to local memory. Under the hood a buffer is requested for the local region
/// and `read_into` is invoked on it, because the data flows from remote to local.
///
/// Blocking counterpart of `write_from`, usable from non-asyncio Python code.
///
/// # Arguments
/// * `addr` - The address of the local buffer to write to
/// * `size` - The size of the data to transfer
/// * `local_proc_id` - The process ID where the local buffer resides
/// * `client` - The mailbox for communication
/// * `timeout` - Maximum time in milliseconds to wait for the operation
///
/// # Returns
/// The boolean produced by the underlying `read_into` call.
#[pyo3(signature = (addr, size, local_proc_id, client, timeout))]
fn write_from_blocking<'py>(
    &self,
    py: Python<'py>,
    addr: usize,
    size: usize,
    local_proc_id: String,
    client: PyMailbox,
    timeout: u64,
) -> PyResult<bool> {
    let (local_owner_ref, buffer) = setup_rdma_context!(self, local_proc_id);
    // Outer `?` unwraps the block_on result; the inner `PyResult<bool>` is returned as-is.
    signal_safe_block_on(py, async move {
        let mailbox = &client.inner;
        // Request a buffer for the local region from the local manager first.
        let local_buffer = local_owner_ref.request_buffer(mailbox, addr, size).await?;
        match local_buffer.read_into(mailbox, buffer, timeout).await {
            Ok(done) => Ok(done),
            Err(e) => Err(PyException::new_err(format!(
                "failed to write from buffer: {}",
                e
            ))),
        }
    })?
}

/// Pickle support: returns `(constructor, args)` where the constructor is this class
/// and `args` is a one-element tuple holding the JSON serialization of `self`.
///
/// The `#[new]` constructor of this class accepts that JSON string, so the pair
/// round-trips an instance.
///
/// # Errors
/// Raises `ValueError` if JSON serialization fails.
// NOTE(review): `to_object`, `PyTuple::new_bound` and `into_py` are believed to be
// deprecated in recent PyO3 releases — confirm against the pinned version before migrating.
fn __reduce__(&self) -> PyResult<(PyObject, PyObject)> {
Python::with_gil(|py| {
// The callable used to rebuild the object is the Python type itself.
let ctor = py.get_type::<PyRdmaBuffer>().to_object(py);
let json = serde_json::to_string(self).map_err(|e| {
PyErr::new::<PyValueError, _>(format!("Serialization failed: {}", e))
})?;

// Single positional argument for the constructor: the JSON payload.
let args = PyTuple::new_bound(py, [json]).into_py(py);
Ok((ctor, args))
})
}

/// Python constructor: rebuilds a buffer wrapper from its JSON serialization.
///
/// # Errors
/// Raises `ValueError` if `json` cannot be deserialized into a `PyRdmaBuffer`.
#[new]
fn new_from_json(json: &str) -> PyResult<Self> {
    serde_json::from_str::<PyRdmaBuffer>(json)
        .map_err(|e| PyErr::new::<PyValueError, _>(format!("Deserialization failed: {}", e)))
}

/// Async release hook for the buffer. Currently a stub: the returned awaitable wraps a
/// future that does no work and completes with `Ok(())`.
fn drop<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
// no op with CPUs, currently a stub.
// TODO - replace with correct GPU behavior.
pyo3_async_runtimes::tokio::future_into_py(py, async move { Ok(()) })
}

/// Blocking release hook for the buffer. Currently a stub: it runs a future that does
/// no work through `signal_safe_block_on` and returns its `Ok(())`.
fn drop_blocking<'py>(&self, py: Python<'py>) -> PyResult<()> {
// The trailing `?` unwraps the block_on result; the inner `Ok(())` is the return value.
signal_safe_block_on(py, async move {
// no op with CPUs, currently a stub.
// TODO - replace with correct GPU behavior.
Ok(())
})?
}
}

pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> {
module.add_class::<PyRdmaBuffer>()?;
Ok(())
}
3 changes: 1 addition & 2 deletions monarch_rdma/src/ibverbs_primitives.rs
Original file line number Diff line number Diff line change
Expand Up @@ -595,9 +595,8 @@ pub fn ibverbs_supported() -> bool {
let device_list = rdmacore_sys::ibv_get_device_list(&mut num_devices);
if !device_list.is_null() {
rdmacore_sys::ibv_free_device_list(device_list);
return true;
}
false
num_devices > 0
}
}

Expand Down
6 changes: 3 additions & 3 deletions monarch_rdma/src/rdma_components.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ impl RdmaBuffer {
/// # Returns
/// `Ok(bool)` indicating if the operation completed successfully.
pub async fn read_into(
&mut self,
&self,
client: &Mailbox,
remote: RdmaBuffer,
timeout: u64,
Expand Down Expand Up @@ -124,7 +124,7 @@ impl RdmaBuffer {
/// # Returns
/// `Ok(bool)` indicating if the operation completed successfully.
pub async fn write_from(
&mut self,
&self,
client: &Mailbox,
remote: RdmaBuffer,
timeout: u64,
Expand Down Expand Up @@ -477,7 +477,7 @@ impl RdmaQueuePair {
pd: *mut rdmacore_sys::ibv_pd,
config: IbverbsConfig,
) -> Result<Self, anyhow::Error> {
tracing::info!("creating an RdmaQueuePair from config {}", config);
tracing::debug!("creating an RdmaQueuePair from config {}", config);
// SAFETY:
// This code uses unsafe rdmacore_sys calls to interact with the RDMA device, but is safe because:
// - All pointers are properly initialized and checked for null before use
Expand Down
Loading