diff --git a/monarch_rdma/extension/Cargo.toml b/monarch_rdma/extension/Cargo.toml index cbb88be3..b43d117f 100644 --- a/monarch_rdma/extension/Cargo.toml +++ b/monarch_rdma/extension/Cargo.toml @@ -14,9 +14,11 @@ doctest = false [dependencies] hyperactor = { version = "0.0.0", path = "../../hyperactor" } +hyperactor_mesh = { version = "0.0.0", path = "../../hyperactor_mesh" } monarch_hyperactor = { version = "0.0.0", path = "../../monarch_hyperactor" } monarch_rdma = { version = "0.0.0", path = ".." } pyo3 = { version = "0.24", features = ["anyhow", "multiple-pymethods"] } pyo3-async-runtimes = { version = "0.24", features = ["attributes", "tokio-runtime"] } serde = { version = "1.0.185", features = ["derive", "rc"] } serde_json = { version = "1.0.140", features = ["alloc", "float_roundtrip", "unbounded_depth"] } +tracing = { version = "0.1.41", features = ["attributes", "valuable"] } diff --git a/monarch_rdma/extension/lib.rs b/monarch_rdma/extension/lib.rs index 70ffeeb6..fe2e35a4 100644 --- a/monarch_rdma/extension/lib.rs +++ b/monarch_rdma/extension/lib.rs @@ -11,13 +11,16 @@ use hyperactor::ActorId; use hyperactor::ActorRef; use hyperactor::Named; use hyperactor::ProcId; +use hyperactor_mesh::RootActorMesh; +use hyperactor_mesh::shared_cell::SharedCell; use monarch_hyperactor::mailbox::PyMailbox; +use monarch_hyperactor::proc_mesh::PyProcMesh; use monarch_hyperactor::runtime::signal_safe_block_on; +use monarch_rdma::IbverbsConfig; use monarch_rdma::RdmaBuffer; use monarch_rdma::RdmaManagerActor; use monarch_rdma::RdmaManagerMessageClient; use monarch_rdma::ibverbs_supported; -use pyo3::BoundObject; use pyo3::exceptions::PyException; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; @@ -284,7 +287,93 @@ impl PyRdmaBuffer { } } +#[pyclass(name = "_RdmaManager", module = "monarch._rust_bindings.rdma")] +pub struct PyRdmaManager { + inner: SharedCell>, + device: String, +} + +#[pymethods] +impl PyRdmaManager { + #[pyo3(name = "__repr__")] + fn repr(&self) -> String { + format!("", self.device) + } + + #[getter] + fn device(&self) -> &str { + &self.device + } +} + +/// Creates an RDMA manager actor on the given ProcMesh. +/// Returns the actor mesh if RDMA is supported, None otherwise. +#[pyfunction] +fn create_rdma_manager_blocking<'py>( + py: Python<'py>, + proc_mesh: &PyProcMesh, +) -> PyResult> { + if !ibverbs_supported() { + tracing::info!("rdma is not enabled on this hardware"); + return Ok(None); + } + + // TODO - make this configurable + let config = IbverbsConfig::default(); + tracing::debug!("rdma is enabled, using device {}", config.device); + + let tracked_proc_mesh = proc_mesh.try_inner()?; + let device = config.device.to_string(); + + let actor_mesh = signal_safe_block_on(py, async move { + tracked_proc_mesh + .spawn("rdma_manager", &config) + .await + .map_err(|err| PyException::new_err(err.to_string())) + })??; + + Ok(Some(PyRdmaManager { + inner: actor_mesh, + device, + })) +} + +/// Creates an RDMA manager actor on the given ProcMesh (async version). +/// Returns the actor mesh if RDMA is supported, None otherwise. +#[pyfunction] +fn create_rdma_manager_nonblocking<'py>( + py: Python<'py>, + proc_mesh: &PyProcMesh, +) -> PyResult> { + if !ibverbs_supported() { + tracing::info!("rdma is not enabled on this hardware"); + return Ok(py.None().into_bound(py)); + } + + // TODO - make this configurable + let config = IbverbsConfig::default(); + tracing::debug!("rdma is enabled, using device {}", config.device); + + let tracked_proc_mesh = proc_mesh.try_inner()?; + let device = config.device.to_string(); + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let actor_mesh = tracked_proc_mesh + .spawn::("rdma_manager", &config) + .await + .map_err(|err| PyException::new_err(err.to_string()))?; + + Ok(Some(PyRdmaManager { + inner: actor_mesh, + device, + })) + }) +} + pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> { module.add_class::()?; + module.add_class::()?; + module.add_function(wrap_pyfunction!(create_rdma_manager_blocking, module)?)?; + module.add_function(wrap_pyfunction!(create_rdma_manager_nonblocking, module)?)?; Ok(()) } diff --git a/python/monarch/_rust_bindings/rdma/__init__.pyi b/python/monarch/_rust_bindings/rdma/__init__.pyi index e5d0f92b..45c53fc6 100644 --- a/python/monarch/_rust_bindings/rdma/__init__.pyi +++ b/python/monarch/_rust_bindings/rdma/__init__.pyi @@ -10,6 +10,13 @@ from typing import Any, final, Optional class _RdmaMemoryRegionView: def __init__(self, addr: int, size_in_bytes: int) -> None: ... +@final +class _RdmaManager: + device: str + def __repr__(self) -> str: ... + +def create_rdma_manager_blocking(proc_mesh: Any) -> Optional[_RdmaManager]: ... +async def create_rdma_manager_nonblocking(proc_mesh: Any) -> Optional[_RdmaManager]: ... @final class _RdmaBuffer: name: str