Skip to content

Fix failing doctests #235

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/cust/src/memory/device/device_box.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,8 @@ impl<T: DeviceCopy> DeviceBox<T> {
/// # let _context = cust::quick_init().unwrap();
/// use cust::memory::*;
/// let x = DeviceBox::new(&5).unwrap();
/// let ptr = DeviceBox::into_device(x).as_raw_mut();
/// let x = unsafe { DeviceBox::from_raw(ptr) };
/// let ptr = DeviceBox::into_device(x).as_raw();
/// let x: DeviceBox<i32> = unsafe { DeviceBox::from_raw(ptr) };
/// ```
pub unsafe fn from_raw(ptr: driver_sys::CUdeviceptr) -> Self {
DeviceBox {
Expand Down
2 changes: 1 addition & 1 deletion crates/cust/src/memory/device/device_slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ impl<T: DeviceCopy> DeviceSlice<T> {
/// # let _context = cust::quick_init().unwrap();
/// use cust::memory::*;
/// let a = DeviceBuffer::from_slice(&[1, 2, 3]).unwrap();
/// println!("{:p}", a.as_ptr());
/// println!("{:p}", a.as_slice().as_device_ptr());
/// ```
pub fn as_device_ptr(&self) -> DevicePointer<T> {
DevicePointer::from_raw(self as *const _ as *const () as usize as u64)
Expand Down
14 changes: 6 additions & 8 deletions crates/cust/src/memory/pointer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,8 @@ impl<T: DeviceCopy> DevicePointer<T> {
/// ```
/// # let _context = cust::quick_init().unwrap();
/// use cust::memory::*;
/// use std::ptr;
/// unsafe {
/// let null : *mut u64 = ptr::null_mut();
/// assert!(DevicePointer::wrap(null).is_null());
/// }
/// let null_ptr = DevicePointer::<u64>::null();
/// assert!(null_ptr.is_null());
/// ```
pub fn is_null(self) -> bool {
self.ptr == 0
Expand Down Expand Up @@ -245,14 +242,15 @@ impl<T: DeviceCopy> DevicePointer<T> {
///
/// # Examples
///
/// ```
/// ```no_run
/// # let _context = cust::quick_init().unwrap();
/// use cust::memory::*;
/// unsafe {
/// let mut dev_ptr = cuda_malloc::<u64>(5).unwrap();
/// let offset = dev_ptr.add(4).sub(3); // Points to the 2nd u64 in the buffer
/// let offset = dev_ptr.add(3).sub(2); // Points to the 2nd u64 in the buffer
/// cuda_free(dev_ptr); // Must free the buffer using the original pointer
/// }
/// ```
#[allow(clippy::should_implement_trait)]
pub unsafe fn sub(self, count: usize) -> Self
where
Expand Down Expand Up @@ -309,7 +307,7 @@ impl<T: DeviceCopy> DevicePointer<T> {
///
/// # Examples
///
/// ```
/// ```no_run
/// # let _context = cust::quick_init().unwrap();
/// use cust::memory::*;
/// unsafe {
Expand Down
4 changes: 2 additions & 2 deletions crates/cust/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ impl Module {
///
/// # Example
///
/// ```
/// ```no_run
/// # use cust::*;
/// # use std::error::Error;
/// # fn main() -> Result<(), Box<dyn Error>> {
Expand Down Expand Up @@ -290,7 +290,7 @@ impl Module {
/// # fn main() -> Result<(), Box<dyn Error>> {
/// # let _ctx = quick_init()?;
/// use cust::module::Module;
/// let ptx = std::fs::read("./resources/add.ptx")?;
/// let ptx = std::fs::read_to_string("./resources/add.ptx")?;
/// let module = Module::from_ptx(&ptx, &[])?;
/// # Ok(())
/// # }
Expand Down
10 changes: 5 additions & 5 deletions crates/cust_core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ pub mod _hidden {
/// There are two ways to implement DeviceCopy on your type. The simplest is to use `derive`:
///
/// ```
/// use cust::DeviceCopy;
/// use cust_core::DeviceCopy;
///
/// #[derive(Clone, DeviceCopy)]
/// #[derive(Clone, Copy, DeviceCopy)]
/// struct MyStruct(u64);
///
/// # fn main () {}
Expand All @@ -33,7 +33,7 @@ pub mod _hidden {
/// be copied to the device:
///
/// ```compile_fail
/// use cust::DeviceCopy;
/// use cust_core::DeviceCopy;
///
/// #[derive(Clone, DeviceCopy)]
/// struct MyStruct(Vec<u64>);
Expand All @@ -43,9 +43,9 @@ pub mod _hidden {
/// You can also implement `DeviceCopy` unsafely:
///
/// ```
/// use cust::memory::DeviceCopy;
/// use cust_core::DeviceCopy;
///
/// #[derive(Clone)]
/// #[derive(Clone, Copy)]
/// struct MyStruct(u64);
///
/// unsafe impl DeviceCopy for MyStruct { }
Expand Down
1 change: 1 addition & 0 deletions crates/cust_raw/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ build = "build/main.rs"
bindgen = "0.71.1"
bimap = "0.6.3"
cc = "1.2.17"
doxygen-bindgen = "0.1"

[package.metadata.docs.rs]
features = [
Expand Down
117 changes: 116 additions & 1 deletion crates/cust_raw/build/callbacks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,118 @@ use std::fs;
use std::path;
use std::sync;

use bindgen::callbacks::{ItemInfo, ItemKind, MacroParsingBehavior, ParseCallbacks};
use bindgen::callbacks::{DeriveInfo, ItemInfo, ItemKind, MacroParsingBehavior, ParseCallbacks};

/// Enum to handle different callback combinations
#[derive(Debug)]
pub(crate) enum BindgenCallbacks {
/// For bindings that need function renaming (driver, runtime, cublas)
WithFunctionRenames {
function_renames: Box<FunctionRenames>,
cargo_callbacks: bindgen::CargoCallbacks,
},
/// For bindings that only need comment processing (nvptx, nvvm)
Simple {
cargo_callbacks: bindgen::CargoCallbacks,
},
}

impl BindgenCallbacks {
pub fn with_function_renames(function_renames: FunctionRenames) -> Self {
Self::WithFunctionRenames {
function_renames: Box::new(function_renames),
cargo_callbacks: bindgen::CargoCallbacks::new(),
}
}

pub fn simple() -> Self {
Self::Simple {
cargo_callbacks: bindgen::CargoCallbacks::new(),
}
}
}

impl ParseCallbacks for BindgenCallbacks {
fn process_comment(&self, comment: &str) -> Option<String> {
// First replace backslashes with @ to avoid doctest parsing issues
let cleaned = comment.replace('\\', "@");
// Then transform doxygen syntax to rustdoc
match doxygen_bindgen::transform(&cleaned) {
Ok(res) => Some(res),
Err(err) => {
println!(
"cargo:warning=Problem processing doxygen comment: {}\n{}",
comment, err
);
None
}
}
}

fn will_parse_macro(&self, name: &str) -> MacroParsingBehavior {
match self {
Self::WithFunctionRenames {
function_renames, ..
} => function_renames.will_parse_macro(name),
Self::Simple { .. } => MacroParsingBehavior::Default,
}
}

fn item_name(&self, original_item_name: &str) -> Option<String> {
match self {
Self::WithFunctionRenames {
function_renames, ..
} => function_renames.item_name(original_item_name),
Self::Simple { .. } => None,
}
}

fn add_derives(&self, info: &DeriveInfo) -> Vec<String> {
match self {
Self::WithFunctionRenames {
function_renames, ..
} => ParseCallbacks::add_derives(function_renames.as_ref(), info),
Self::Simple { .. } => vec![],
}
}

fn generated_name_override(&self, item_info: ItemInfo<'_>) -> Option<String> {
match self {
Self::WithFunctionRenames {
function_renames, ..
} => ParseCallbacks::generated_name_override(function_renames.as_ref(), item_info),
Self::Simple { .. } => None,
}
}

fn generated_link_name_override(&self, item_info: ItemInfo<'_>) -> Option<String> {
match self {
Self::WithFunctionRenames {
function_renames, ..
} => ParseCallbacks::generated_link_name_override(function_renames.as_ref(), item_info),
Self::Simple { .. } => None,
}
}

// Delegate cargo callbacks
fn include_file(&self, filename: &str) {
match self {
Self::WithFunctionRenames {
cargo_callbacks, ..
}
| Self::Simple { cargo_callbacks } => cargo_callbacks.include_file(filename),
}
}

fn read_env_var(&self, var: &str) {
match self {
Self::WithFunctionRenames {
cargo_callbacks, ..
}
| Self::Simple { cargo_callbacks } => cargo_callbacks.read_env_var(var),
}
}
}

/// Struct to handle renaming of functions through macro expansion.
#[derive(Debug)]
Expand Down Expand Up @@ -123,4 +234,8 @@ impl ParseCallbacks for FunctionRenames {
_ => None,
}
}

fn add_derives(&self, _info: &DeriveInfo) -> Vec<String> {
vec![]
}
}
51 changes: 28 additions & 23 deletions crates/cust_raw/build/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,14 @@ fn create_cuda_driver_bindings(
println!("cargo::rerun-if-changed={}", header.display());
let bindings = bindgen::Builder::default()
.header(header.to_str().expect("header should be valid UTF-8"))
.parse_callbacks(Box::new(callbacks::FunctionRenames::new(
"cu",
outdir,
header,
sdk.cuda_include_paths().to_owned(),
)))
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
.parse_callbacks(Box::new(
callbacks::BindgenCallbacks::with_function_renames(callbacks::FunctionRenames::new(
"cu",
outdir,
header,
sdk.cuda_include_paths().to_owned(),
)),
))
.clang_args(
sdk.cuda_include_paths()
.iter()
Expand Down Expand Up @@ -167,13 +168,14 @@ fn create_cuda_runtime_bindings(
println!("cargo::rerun-if-changed={}", header.display());
let bindings = bindgen::Builder::default()
.header(header.to_str().expect("header should be valid UTF-8"))
.parse_callbacks(Box::new(callbacks::FunctionRenames::new(
"cuda",
outdir,
header,
sdk.cuda_include_paths().to_owned(),
)))
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
.parse_callbacks(Box::new(
callbacks::BindgenCallbacks::with_function_renames(callbacks::FunctionRenames::new(
"cuda",
outdir,
header,
sdk.cuda_include_paths().to_owned(),
)),
))
.clang_args(
sdk.cuda_include_paths()
.iter()
Expand Down Expand Up @@ -217,13 +219,16 @@ fn create_cublas_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path, manifest
println!("cargo::rerun-if-changed={}", header.display());
let bindings = bindgen::Builder::default()
.header(header.to_str().expect("header should be valid UTF-8"))
.parse_callbacks(Box::new(callbacks::FunctionRenames::new(
pkg,
outdir,
header,
sdk.cuda_include_paths().to_owned(),
)))
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
.parse_callbacks(Box::new(
callbacks::BindgenCallbacks::with_function_renames(
callbacks::FunctionRenames::new(
pkg,
outdir,
header.clone(),
sdk.cuda_include_paths().to_owned(),
),
),
))
.clang_args(
sdk.cuda_include_paths()
.iter()
Expand Down Expand Up @@ -263,7 +268,7 @@ fn create_nptx_compiler_bindings(
println!("cargo::rerun-if-changed={}", header.display());
let bindings = bindgen::Builder::default()
.header(header.to_str().expect("header should be valid UTF-8"))
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
.parse_callbacks(Box::new(callbacks::BindgenCallbacks::simple()))
.clang_args(
sdk.cuda_include_paths()
.iter()
Expand Down Expand Up @@ -298,7 +303,7 @@ fn create_nvvm_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path, manifest_d
println!("cargo::rerun-if-changed={}", header.display());
let bindings = bindgen::Builder::default()
.header(header.to_str().expect("header should be valid UTF-8"))
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
.parse_callbacks(Box::new(callbacks::BindgenCallbacks::simple()))
.clang_args(
sdk.nvvm_include_paths()
.iter()
Expand Down
Loading