Skip to content

Commit 5cde139

Browse files
dstaay-fb authored and facebook-github-bot committed
basic CUDA <> CPU or CUDA <> CUDA rdma Support (#372)
Summary:
Pull Request resolved: #372

RDMA support for CUDA <> CUDA and CUDA <> CPU comms.

Key changes:
- Using CUDA APIs, we can detect whether a given pointer is mapped to a CUDA device or to the CPU.
- If the data pointer is CUDA, the code leverages DMA registration to register with the NIC; by using cuMemGetHandleForAddressRange we avoid passing CUDA allocation handles around directly.
- If the data pointer is CPU, we use a standard ibv MR; note that I transitioned to standard registration rather than registering the entire memory space (security concern raised by mariusae).
- Refactored test infra to support named NIC devices and different compute targets (cuda:X or cpu).

This implementation is relatively naive, and I will iterate accordingly.

To do: add a unit test for cuda/cuda.

Reviewed By: allenwang28

Differential Revision: D77408653

fbshipit-source-id: c6118516b109a184f2fd63196bcdd5fc4e5667cc
1 parent 5a2d0ee commit 5cde139

File tree

9 files changed

+853
-347
lines changed

9 files changed

+853
-347
lines changed

cuda-sys/src/lib.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/

monarch_rdma/Cargo.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ license = "BSD-3-Clause"
1111
anyhow = "1.0.98"
1212
async-trait = "0.1.86"
1313
hyperactor = { version = "0.0.0", path = "../hyperactor" }
14-
ibverbs = "0.7.1"
1514
rand = { version = "0.8", features = ["small_rng"] }
1615
serde = { version = "1.0.185", features = ["derive", "rc"] }
1716
tracing = { version = "0.1.41", features = ["attributes", "valuable"] }
@@ -21,3 +20,7 @@ hyperactor_mesh = { version = "0.0.0", path = "../hyperactor_mesh" }
2120
ndslice = { version = "0.0.0", path = "../ndslice" }
2221
timed_test = { version = "0.0.0", path = "../timed_test" }
2322
tokio = { version = "1.45.0", features = ["full", "test-util", "tracing"] }
23+
24+
[features]
25+
cuda = []
26+
default = ["cuda"]

monarch_rdma/examples/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@ path = "parameter_server.rs"
1313
[[bin]]
1414
name = "parameter_server_bootstrap"
1515
path = "bootstrap.rs"
16+
test = false
1617

1718
[[bin]]
1819
name = "parameter_server_example"
1920
path = "main.rs"
21+
test = false
2022

2123
[dependencies]
2224
anyhow = "1.0.98"

monarch_rdma/src/ibverbs_primitives.rs

Lines changed: 102 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
/*
2-
* Copyright (c) Meta Platforms, Inc. and affiliates.
2+
* Portions Copyright (c) Meta Platforms, Inc. and affiliates.
33
* All rights reserved.
44
*
55
* This source code is licensed under the BSD-style license found in the
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
/*
10+
* Sections of code adapted from
11+
* Copyright (c) 2016 Jon Gjengset under MIT License (MIT)
12+
*/
13+
914
//! This file contains primitive data structures for interacting with ibverbs.
1015
//!
1116
//! Primitives:
@@ -25,10 +30,62 @@ use std::ffi::CStr;
2530
use std::fmt;
2631

2732
use hyperactor::Named;
28-
use ibverbs::Gid;
2933
use serde::Deserialize;
3034
use serde::Serialize;
3135

36+
/// Owned, serializable counterpart of `rdmacore_sys::ibv_gid`, stored as 16 raw bytes.
///
/// `#[repr(transparent)]` gives this type exactly the layout of `[u8; 16]`; the
/// conversions defined for it in this file copy or reinterpret those bytes as
/// `rdmacore_sys::ibv_gid`. The default value is all zeroes (derived `Default`).
#[derive(
37+
Default,
38+
Copy,
39+
Clone,
40+
Debug,
41+
Eq,
42+
PartialEq,
43+
Hash,
44+
serde::Serialize,
45+
serde::Deserialize
46+
)]
47+
#[repr(transparent)]
48+
pub struct Gid {
49+
// The 16 GID bytes; the accessors on this type read them as big-endian.
raw: [u8; 16],
50+
}
51+
52+
impl Gid {
53+
/// Returns bytes 0..8 of the GID interpreted as a big-endian `u64`.
// Currently unused in this crate, hence the lint allowance.
#[allow(dead_code)]
54+
fn subnet_prefix(&self) -> u64 {
55+
// `raw[..8]` is always exactly 8 bytes of a 16-byte array, so the
// slice-to-array conversion cannot fail and `unwrap` never panics.
u64::from_be_bytes(self.raw[..8].try_into().unwrap())
56+
}
57+
58+
/// Returns bytes 8..16 of the GID interpreted as a big-endian `u64`.
// Currently unused in this crate, hence the lint allowance.
#[allow(dead_code)]
59+
fn interface_id(&self) -> u64 {
60+
// `raw[8..]` is always exactly 8 bytes of a 16-byte array, so the
// slice-to-array conversion cannot fail and `unwrap` never panics.
u64::from_be_bytes(self.raw[8..].try_into().unwrap())
61+
}
62+
}
63+
impl From<rdmacore_sys::ibv_gid> for Gid {
64+
/// Builds a [`Gid`] by copying the `raw` bytes out of an `ibv_gid`.
fn from(gid: rdmacore_sys::ibv_gid) -> Self {
65+
Self {
66+
// SAFETY: `raw` is a plain `[u8; 16]`, for which every bit pattern is a
// valid value. (The field read needs `unsafe`, so `ibv_gid` is presumably
// a bindgen-generated union — confirm against the rdmacore_sys bindings.)
raw: unsafe { gid.raw },
67+
}
68+
}
69+
}
70+
71+
impl From<Gid> for rdmacore_sys::ibv_gid {
72+
/// Converts a [`Gid`] into an `ibv_gid` by copying its 16 raw bytes.
fn from(gid: Gid) -> Self {
73+
// Build the union by value through its `raw` member. The previous
// `*gid.as_mut()` reinterpreted the `[u8; 16]` storage (alignment 1) as a
// `&mut ibv_gid`, which may not satisfy the union's alignment; constructing
// the value directly needs neither `unsafe` nor a pointer cast.
rdmacore_sys::ibv_gid { raw: gid.raw }
74+
}
75+
}
76+
77+
impl AsRef<rdmacore_sys::ibv_gid> for Gid {
78+
/// Reinterprets the 16 stored bytes in place as a shared `ibv_gid` reference.
fn as_ref(&self) -> &rdmacore_sys::ibv_gid {
79+
// SAFETY (partial): the pointer comes from `self.raw`, so it is non-null and
// valid for reads of 16 bytes for the lifetime of `&self`.
// NOTE(review): `Gid` is `repr(transparent)` over `[u8; 16]` and so has
// alignment 1; if `ibv_gid` requires stricter alignment (e.g. because it
// has 64-bit members), this reference can be misaligned — TODO confirm
// against the rdmacore_sys definition.
unsafe { &*self.raw.as_ptr().cast::<rdmacore_sys::ibv_gid>() }
80+
}
81+
}
82+
83+
impl AsMut<rdmacore_sys::ibv_gid> for Gid {
84+
/// Reinterprets the 16 stored bytes in place as an exclusive `ibv_gid` reference.
fn as_mut(&mut self) -> &mut rdmacore_sys::ibv_gid {
85+
// SAFETY (partial): the pointer comes from `self.raw`, so it is non-null,
// valid for reads and writes of 16 bytes, and uniquely borrowed via `&mut self`.
// NOTE(review): `Gid` is `repr(transparent)` over `[u8; 16]` and so has
// alignment 1; if `ibv_gid` requires stricter alignment (e.g. because it
// has 64-bit members), this reference can be misaligned — TODO confirm
// against the rdmacore_sys definition.
unsafe { &mut *self.raw.as_mut_ptr().cast::<rdmacore_sys::ibv_gid>() }
86+
}
87+
}
88+
3289
/// Represents ibverbs specific configurations.
3390
///
3491
/// This struct holds various parameters required to establish and manage an RDMA connection.
@@ -86,7 +143,7 @@ impl Default for IbverbsConfig {
86143
max_recv_wr: 1,
87144
max_send_sge: 1,
88145
max_recv_sge: 1,
89-
path_mtu: ffi::IBV_MTU_1024,
146+
path_mtu: rdmacore_sys::IBV_MTU_1024,
90147
retry_cnt: 7,
91148
rnr_retry: 7,
92149
qp_timeout: 14, // 4.096 μs * 2^14 = ~67 ms
@@ -144,7 +201,7 @@ impl std::fmt::Display for IbverbsConfig {
144201
#[derive(Debug, Clone, Serialize, Deserialize)]
145202
pub struct RdmaDevice {
146203
/// `name` - The name of the RDMA device (e.g., "mlx5_0").
147-
name: String,
204+
pub name: String,
148205
/// `vendor_id` - The vendor ID of the device.
149206
vendor_id: u32,
150207
/// `vendor_part_id` - The vendor part ID of the device.
@@ -330,10 +387,10 @@ impl fmt::Display for RdmaPort {
330387
/// # Returns
331388
///
332389
/// A string representation of the port state.
333-
pub fn get_port_state_str(state: ffi::ibv_port_state::Type) -> String {
390+
pub fn get_port_state_str(state: rdmacore_sys::ibv_port_state::Type) -> String {
334391
// SAFETY: We are calling a C function that returns a C string.
335392
unsafe {
336-
let c_str = ffi::ibv_port_state_str(state);
393+
let c_str = rdmacore_sys::ibv_port_state_str(state);
337394
if c_str.is_null() {
338395
return "Unknown".to_string();
339396
}
@@ -428,7 +485,7 @@ pub fn get_all_devices() -> Vec<RdmaDevice> {
428485
// SAFETY: We are calling several C functions from libibverbs.
429486
unsafe {
430487
let mut num_devices = 0;
431-
let device_list = ffi::ibv_get_device_list(&mut num_devices);
488+
let device_list = rdmacore_sys::ibv_get_device_list(&mut num_devices);
432489
if device_list.is_null() || num_devices == 0 {
433490
return devices;
434491
}
@@ -439,18 +496,18 @@ pub fn get_all_devices() -> Vec<RdmaDevice> {
439496
continue;
440497
}
441498

442-
let context = ffi::ibv_open_device(device);
499+
let context = rdmacore_sys::ibv_open_device(device);
443500
if context.is_null() {
444501
continue;
445502
}
446503

447-
let device_name = CStr::from_ptr(ffi::ibv_get_device_name(device))
504+
let device_name = CStr::from_ptr(rdmacore_sys::ibv_get_device_name(device))
448505
.to_string_lossy()
449506
.into_owned();
450507

451-
let mut device_attr = ffi::ibv_device_attr::default();
452-
if ffi::ibv_query_device(context, &mut device_attr) != 0 {
453-
ffi::ibv_close_device(context);
508+
let mut device_attr = rdmacore_sys::ibv_device_attr::default();
509+
if rdmacore_sys::ibv_query_device(context, &mut device_attr) != 0 {
510+
rdmacore_sys::ibv_close_device(context);
454511
continue;
455512
}
456513

@@ -475,11 +532,11 @@ pub fn get_all_devices() -> Vec<RdmaDevice> {
475532
};
476533

477534
for port_num in 1..=device_attr.phys_port_cnt {
478-
let mut port_attr = ffi::ibv_port_attr::default();
479-
if ffi::ibv_query_port(
535+
let mut port_attr = rdmacore_sys::ibv_port_attr::default();
536+
if rdmacore_sys::ibv_query_port(
480537
context,
481538
port_num,
482-
&mut port_attr as *mut ffi::ibv_port_attr as *mut _,
539+
&mut port_attr as *mut rdmacore_sys::ibv_port_attr as *mut _,
483540
) != 0
484541
{
485542
continue;
@@ -489,8 +546,8 @@ pub fn get_all_devices() -> Vec<RdmaDevice> {
489546

490547
let link_layer = get_link_layer_str(port_attr.link_layer);
491548

492-
let mut gid = ffi::ibv_gid::default();
493-
let gid_str = if ffi::ibv_query_gid(context, port_num, 0, &mut gid) == 0 {
549+
let mut gid = rdmacore_sys::ibv_gid::default();
550+
let gid_str = if rdmacore_sys::ibv_query_gid(context, port_num, 0, &mut gid) == 0 {
494551
format_gid(&gid.raw)
495552
} else {
496553
"N/A".to_string()
@@ -513,10 +570,10 @@ pub fn get_all_devices() -> Vec<RdmaDevice> {
513570
}
514571

515572
devices.push(rdma_device);
516-
ffi::ibv_close_device(context);
573+
rdmacore_sys::ibv_close_device(context);
517574
}
518575

519-
ffi::ibv_free_device_list(device_list);
576+
rdmacore_sys::ibv_free_device_list(device_list);
520577
}
521578

522579
devices
@@ -535,9 +592,9 @@ pub fn ibverbs_supported() -> bool {
535592
// SAFETY: We are calling a C function from libibverbs.
536593
unsafe {
537594
let mut num_devices = 0;
538-
let device_list = ffi::ibv_get_device_list(&mut num_devices);
595+
let device_list = rdmacore_sys::ibv_get_device_list(&mut num_devices);
539596
if !device_list.is_null() {
540-
ffi::ibv_free_device_list(device_list);
597+
rdmacore_sys::ibv_free_device_list(device_list);
541598
return true;
542599
}
543600
false
@@ -557,6 +614,7 @@ pub fn ibverbs_supported() -> bool {
557614
/// RDMA operations are in progress.
558615
#[derive(Debug, PartialEq, Eq, std::hash::Hash, Serialize, Deserialize, Clone)]
559616
pub struct RdmaMemoryRegionView {
617+
pub id: u32,
560618
pub addr: usize,
561619
pub size: usize,
562620
pub lkey: u32,
@@ -582,8 +640,9 @@ unsafe impl Sync for RdmaMemoryRegionView {}
582640

583641
impl RdmaMemoryRegionView {
584642
/// Creates a new `RdmaMemoryRegionView` with the given address and size.
585-
pub fn new(addr: usize, size: usize, lkey: u32, rkey: u32) -> Self {
643+
pub fn new(id: u32, addr: usize, size: usize, lkey: u32, rkey: u32) -> Self {
586644
Self {
645+
id,
587646
addr,
588647
size,
589648
lkey,
@@ -612,20 +671,20 @@ pub enum RdmaOperation {
612671
Read,
613672
}
614673

615-
impl From<RdmaOperation> for ffi::ibv_wr_opcode::Type {
674+
impl From<RdmaOperation> for rdmacore_sys::ibv_wr_opcode::Type {
616675
fn from(op: RdmaOperation) -> Self {
617676
match op {
618-
RdmaOperation::Write => ffi::ibv_wr_opcode::IBV_WR_RDMA_WRITE,
619-
RdmaOperation::Read => ffi::ibv_wr_opcode::IBV_WR_RDMA_READ,
677+
RdmaOperation::Write => rdmacore_sys::ibv_wr_opcode::IBV_WR_RDMA_WRITE,
678+
RdmaOperation::Read => rdmacore_sys::ibv_wr_opcode::IBV_WR_RDMA_READ,
620679
}
621680
}
622681
}
623682

624-
impl From<ffi::ibv_wc_opcode::Type> for RdmaOperation {
625-
fn from(op: ffi::ibv_wc_opcode::Type) -> Self {
683+
impl From<rdmacore_sys::ibv_wc_opcode::Type> for RdmaOperation {
684+
fn from(op: rdmacore_sys::ibv_wc_opcode::Type) -> Self {
626685
match op {
627-
ffi::ibv_wc_opcode::IBV_WC_RDMA_WRITE => RdmaOperation::Write,
628-
ffi::ibv_wc_opcode::IBV_WC_RDMA_READ => RdmaOperation::Read,
686+
rdmacore_sys::ibv_wc_opcode::IBV_WC_RDMA_WRITE => RdmaOperation::Write,
687+
rdmacore_sys::ibv_wc_opcode::IBV_WC_RDMA_READ => RdmaOperation::Read,
629688
_ => panic!("Unsupported operation type"),
630689
}
631690
}
@@ -660,7 +719,7 @@ impl std::fmt::Debug for RdmaQpInfo {
660719

661720
/// Wrapper around ibv_wc (ibverbs work completion).
662721
///
663-
/// This exposes only the public fields of ffi::ibv_wc, allowing us to more easily
722+
/// This exposes only the public fields of rdmacore_sys::ibv_wc, allowing us to more easily
664723
/// interact with it from Rust. Work completions are used to track the status of
665724
/// RDMA operations and are generated when an operation completes.
666725
#[derive(Debug, Named, Clone, serde::Serialize, serde::Deserialize)]
@@ -672,9 +731,9 @@ pub struct IbvWc {
672731
/// `valid` - Whether the work completion is valid
673732
valid: bool,
674733
/// `error` - Error information if the operation failed
675-
error: Option<(ffi::ibv_wc_status::Type, u32)>,
734+
error: Option<(rdmacore_sys::ibv_wc_status::Type, u32)>,
676735
/// `opcode` - Type of operation that completed (read, write, etc.)
677-
opcode: ffi::ibv_wc_opcode::Type,
736+
opcode: rdmacore_sys::ibv_wc_opcode::Type,
678737
/// `bytes` - Immediate data (if any)
679738
bytes: Option<u32>,
680739
/// `qp_num` - Queue Pair Number
@@ -691,8 +750,8 @@ pub struct IbvWc {
691750
dlid_path_bits: u8,
692751
}
693752

694-
impl From<ffi::ibv_wc> for IbvWc {
695-
fn from(wc: ffi::ibv_wc) -> Self {
753+
impl From<rdmacore_sys::ibv_wc> for IbvWc {
754+
fn from(wc: rdmacore_sys::ibv_wc) -> Self {
696755
IbvWc {
697756
wr_id: wc.wr_id(),
698757
len: wc.len(),
@@ -804,21 +863,21 @@ mod tests {
804863
#[test]
805864
fn test_rdma_operation_conversion() {
806865
assert_eq!(
807-
ffi::ibv_wr_opcode::IBV_WR_RDMA_WRITE,
808-
ffi::ibv_wr_opcode::Type::from(RdmaOperation::Write)
866+
rdmacore_sys::ibv_wr_opcode::IBV_WR_RDMA_WRITE,
867+
rdmacore_sys::ibv_wr_opcode::Type::from(RdmaOperation::Write)
809868
);
810869
assert_eq!(
811-
ffi::ibv_wr_opcode::IBV_WR_RDMA_READ,
812-
ffi::ibv_wr_opcode::Type::from(RdmaOperation::Read)
870+
rdmacore_sys::ibv_wr_opcode::IBV_WR_RDMA_READ,
871+
rdmacore_sys::ibv_wr_opcode::Type::from(RdmaOperation::Read)
813872
);
814873

815874
assert_eq!(
816875
RdmaOperation::Write,
817-
RdmaOperation::from(ffi::ibv_wc_opcode::IBV_WC_RDMA_WRITE)
876+
RdmaOperation::from(rdmacore_sys::ibv_wc_opcode::IBV_WC_RDMA_WRITE)
818877
);
819878
assert_eq!(
820879
RdmaOperation::Read,
821-
RdmaOperation::from(ffi::ibv_wc_opcode::IBV_WC_RDMA_READ)
880+
RdmaOperation::from(rdmacore_sys::ibv_wc_opcode::IBV_WC_RDMA_READ)
822881
);
823882
}
824883

@@ -839,18 +898,18 @@ mod tests {
839898

840899
#[test]
841900
fn test_ibv_wc() {
842-
let mut wc = ffi::ibv_wc::default();
901+
let mut wc = rdmacore_sys::ibv_wc::default();
843902

844903
// SAFETY: modifies private fields through pointer manipulation
845904
unsafe {
846905
// Cast to pointer and modify the fields directly
847-
let wc_ptr = &mut wc as *mut ffi::ibv_wc as *mut u8;
906+
let wc_ptr = &mut wc as *mut rdmacore_sys::ibv_wc as *mut u8;
848907

849908
// Set wr_id (at offset 0, u64)
850909
*(wc_ptr as *mut u64) = 42;
851910

852911
// Set status to SUCCESS (at offset 8, u32)
853-
*(wc_ptr.add(8) as *mut i32) = ffi::ibv_wc_status::IBV_WC_SUCCESS as i32;
912+
*(wc_ptr.add(8) as *mut i32) = rdmacore_sys::ibv_wc_status::IBV_WC_SUCCESS as i32;
854913
}
855914
let ibv_wc = IbvWc::from(wc);
856915
assert_eq!(ibv_wc.wr_id(), 42);

monarch_rdma/src/lib.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,17 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
// RDMA requires frequent unsafe code blocks
10+
#![allow(clippy::undocumented_unsafe_blocks)]
11+
912
mod ibverbs_primitives;
1013
mod rdma_components;
1114
mod rdma_manager_actor;
1215
mod test_utils;
1316

17+
#[macro_use]
18+
mod macros;
19+
1420
pub use ibverbs_primitives::*;
1521
pub use rdma_components::*;
1622
pub use rdma_manager_actor::*;

0 commit comments

Comments
 (0)