Skip to content

Commit 55c34d4

Browse files
allenwang28facebook-github-bot
authored andcommitted
(4/N) Add Cargo support to cuda-sys and rdmacore-sys (#548)
Summary: Pull Request resolved: #548 This diff adds Cargo support for `cuda-sys` and `rdmacore-sys`, which will be necessary for the open source build. Note - I found that `arc autocargo` does not support `rust_bindgen_library` so these Cargo.tomls were created manually. Reviewed By: dstaay-fb Differential Revision: D78414608 fbshipit-source-id: 41c91978cec09ec2ce9e9a57e39281bea878b61c
1 parent 7f6c2b9 commit 55c34d4

File tree

6 files changed

+431
-195
lines changed

6 files changed

+431
-195
lines changed

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
resolver = "2"
33
members = [
44
"controller",
5+
"cuda-sys",
56
"hyper",
67
"hyperactor",
78
"hyperactor_macros",
@@ -12,5 +13,6 @@ members = [
1213
"monarch_extension",
1314
"monarch_tensor_worker",
1415
"nccl-sys",
16+
"rdmacore-sys",
1517
"torch-sys",
1618
]

cuda-sys/build.rs

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
use std::env;
10+
use std::path::Path;
11+
use std::path::PathBuf;
12+
13+
use glob::glob;
14+
use which::which;
15+
16+
// Translated from torch/utils/cpp_extension.py
17+
fn find_cuda_home() -> Option<String> {
18+
// Guess #1
19+
let mut cuda_home = env::var("CUDA_HOME")
20+
.ok()
21+
.or_else(|| env::var("CUDA_PATH").ok());
22+
23+
if cuda_home.is_none() {
24+
// Guess #2
25+
if let Ok(nvcc_path) = which("nvcc") {
26+
// Get parent directory twice
27+
if let Some(cuda_dir) = nvcc_path.parent().and_then(|p| p.parent()) {
28+
cuda_home = Some(cuda_dir.to_string_lossy().into_owned());
29+
}
30+
} else {
31+
// Guess #3
32+
if cfg!(windows) {
33+
// Windows code
34+
let pattern = r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v*.*";
35+
let cuda_homes: Vec<_> = glob(pattern).unwrap().filter_map(Result::ok).collect();
36+
if !cuda_homes.is_empty() {
37+
cuda_home = Some(cuda_homes[0].to_string_lossy().into_owned());
38+
} else {
39+
cuda_home = None;
40+
}
41+
} else {
42+
// Not Windows
43+
let cuda_candidate = "/usr/local/cuda";
44+
if Path::new(cuda_candidate).exists() {
45+
cuda_home = Some(cuda_candidate.to_string());
46+
} else {
47+
cuda_home = None;
48+
}
49+
}
50+
}
51+
}
52+
cuda_home
53+
}
54+
55+
fn main() {
56+
let cuda_home = find_cuda_home().expect("Could not find CUDA installation");
57+
58+
// Tell cargo to look for shared libraries in the CUDA directory
59+
println!("cargo:rustc-link-search={}/lib64", cuda_home);
60+
println!("cargo:rustc-link-search={}/lib", cuda_home);
61+
62+
// Link against the CUDA libraries
63+
println!("cargo:rustc-link-lib=cuda");
64+
println!("cargo:rustc-link-lib=cudart");
65+
66+
// Tell cargo to invalidate the built crate whenever the wrapper changes
67+
println!("cargo:rerun-if-changed=src/wrapper.h");
68+
69+
// Add cargo metadata
70+
println!("cargo:rustc-cfg=cargo");
71+
println!("cargo:rustc-check-cfg=cfg(cargo)");
72+
73+
// The bindgen::Builder is the main entry point to bindgen
74+
let bindings = bindgen::Builder::default()
75+
// The input header we would like to generate bindings for
76+
.header("src/wrapper.h")
77+
// Add the CUDA include directory
78+
.clang_arg(format!("-I{}/include", cuda_home))
79+
// Parse as C++
80+
.clang_arg("-x")
81+
.clang_arg("c++")
82+
.clang_arg("-std=gnu++20")
83+
// Allow the specified functions and types
84+
.allowlist_function("cu.*")
85+
.allowlist_function("CU.*")
86+
.allowlist_type("cu.*")
87+
.allowlist_type("CU.*")
88+
// Use newtype enum style
89+
.default_enum_style(bindgen::EnumVariation::NewType {
90+
is_bitfield: false,
91+
is_global: false,
92+
})
93+
// Finish the builder and generate the bindings
94+
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
95+
.generate()
96+
// Unwrap the Result and panic on failure
97+
.expect("Unable to generate bindings");
98+
99+
// Write the bindings to the $OUT_DIR/bindings.rs file
100+
let out_path = PathBuf::from(env::var("OUT_DIR").unwrap());
101+
bindings
102+
.write_to_file(out_path.join("bindings.rs"))
103+
.expect("Couldn't write bindings!");
104+
}

cuda-sys/src/lib.rs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,49 @@
55
* This source code is licensed under the BSD-style license found in the
66
* LICENSE file in the root directory of this source tree.
77
*/
8+
9+
/*
10+
* Copyright (c) Meta Platforms, Inc. and affiliates.
11+
* All rights reserved.
12+
*
13+
* This source code is licensed under the BSD-style license found in the
14+
* LICENSE file in the root directory of this source tree.
15+
*/
16+
17+
use cxx::ExternType;
18+
use cxx::type_id;
19+
20+
/// SAFETY: bindings
21+
unsafe impl ExternType for CUstream_st {
22+
type Id = type_id!("CUstream_st");
23+
type Kind = cxx::kind::Opaque;
24+
}
25+
26+
// When building with cargo, this is actually the lib.rs file for a crate.
27+
// Include the generated bindings.rs and suppress lints.
28+
#[allow(non_camel_case_types)]
29+
#[allow(non_upper_case_globals)]
30+
#[allow(non_snake_case)]
31+
mod inner {
32+
#[cfg(cargo)]
33+
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
34+
}
35+
36+
pub use inner::*;
37+
38+
#[cfg(test)]
39+
mod tests {
40+
use std::mem::MaybeUninit;
41+
42+
use super::*;
43+
44+
#[test]
45+
fn sanity() {
46+
// SAFETY: testing bindings
47+
unsafe {
48+
let mut version = MaybeUninit::<i32>::uninit();
49+
let result = cuDriverGetVersion(version.as_mut_ptr());
50+
assert_eq!(result, cudaError_enum(0));
51+
}
52+
}
53+
}

monarch_rdma/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@ license = "BSD-3-Clause"
1010
[dependencies]
1111
anyhow = "1.0.98"
1212
async-trait = "0.1.86"
13+
cuda-sys = { path = "../cuda-sys" }
1314
hyperactor = { version = "0.0.0", path = "../hyperactor" }
1415
rand = { version = "0.8", features = ["small_rng"] }
16+
rdmacore-sys = { path = "../rdmacore-sys" }
1517
serde = { version = "1.0.185", features = ["derive", "rc"] }
1618
tracing = { version = "0.1.41", features = ["attributes", "valuable"] }
1719

rdmacore-sys/build.rs

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
use std::env;
10+
use std::path::PathBuf;
11+
12+
fn main() {
13+
// Tell cargo to look for shared libraries in the specified directory
14+
println!("cargo:rustc-link-search=/usr/lib");
15+
println!("cargo:rustc-link-search=/usr/lib64");
16+
17+
// Link against the ibverbs library
18+
println!("cargo:rustc-link-lib=ibverbs");
19+
20+
// Link against the mlx5 library
21+
println!("cargo:rustc-link-lib=mlx5");
22+
23+
// Tell cargo to invalidate the built crate whenever the wrapper changes
24+
println!("cargo:rerun-if-changed=src/wrapper.h");
25+
26+
// Add cargo metadata
27+
println!("cargo:rustc-cfg=cargo");
28+
println!("cargo:rustc-check-cfg=cfg(cargo)");
29+
30+
// The bindgen::Builder is the main entry point to bindgen
31+
let bindings = bindgen::Builder::default()
32+
// The input header we would like to generate bindings for
33+
.header("src/wrapper.h")
34+
// Allow the specified functions, types, and variables
35+
.allowlist_function("ibv_.*")
36+
.allowlist_function("mlx5dv_.*")
37+
.allowlist_function("mlx5_wqe_.*")
38+
.allowlist_type("ibv_.*")
39+
.allowlist_type("mlx5dv_.*")
40+
.allowlist_type("mlx5_wqe_.*")
41+
.allowlist_var("MLX5_.*")
42+
// Block specific types that are manually defined in lib.rs
43+
.blocklist_type("ibv_wc")
44+
.blocklist_type("mlx5_wqe_ctrl_seg")
45+
// Apply the same bindgen flags as in the BUCK file
46+
.bitfield_enum("ibv_access_flags")
47+
.bitfield_enum("ibv_qp_attr_mask")
48+
.bitfield_enum("ibv_wc_flags")
49+
.bitfield_enum("ibv_send_flags")
50+
.bitfield_enum("ibv_port_cap_flags")
51+
.constified_enum_module("ibv_qp_type")
52+
.constified_enum_module("ibv_qp_state")
53+
.constified_enum_module("ibv_port_state")
54+
.constified_enum_module("ibv_wc_opcode")
55+
.constified_enum_module("ibv_wr_opcode")
56+
.constified_enum_module("ibv_wc_status")
57+
.derive_default(true)
58+
.prepend_enum_name(false)
59+
// Finish the builder and generate the bindings
60+
.generate()
61+
// Unwrap the Result and panic on failure
62+
.expect("Unable to generate bindings");
63+
64+
// Write the bindings to the $OUT_DIR/bindings.rs file
65+
let out_path = PathBuf::from(env::var("OUT_DIR").unwrap());
66+
bindings
67+
.write_to_file(out_path.join("bindings.rs"))
68+
.expect("Couldn't write bindings!");
69+
}

0 commit comments

Comments
 (0)