Skip to content

Commit abdb071

Browse files
committed
Fix failing doctests
1 parent 62e4ac0 commit abdb071

File tree

8 files changed

+161
-42
lines changed

8 files changed

+161
-42
lines changed

crates/cust/src/memory/device/device_box.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,8 +290,8 @@ impl<T: DeviceCopy> DeviceBox<T> {
290290
/// # let _context = cust::quick_init().unwrap();
291291
/// use cust::memory::*;
292292
/// let x = DeviceBox::new(&5).unwrap();
293-
/// let ptr = DeviceBox::into_device(x).as_raw_mut();
294-
/// let x = unsafe { DeviceBox::from_raw(ptr) };
293+
/// let ptr = DeviceBox::into_device(x).as_raw();
294+
/// let x: DeviceBox<i32> = unsafe { DeviceBox::from_raw(ptr) };
295295
/// ```
296296
pub unsafe fn from_raw(ptr: driver_sys::CUdeviceptr) -> Self {
297297
DeviceBox {

crates/cust/src/memory/device/device_slice.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ impl<T: DeviceCopy> DeviceSlice<T> {
8383
/// # let _context = cust::quick_init().unwrap();
8484
/// use cust::memory::*;
8585
/// let a = DeviceBuffer::from_slice(&[1, 2, 3]).unwrap();
86-
/// println!("{:p}", a.as_ptr());
86+
/// println!("{:p}", a.as_slice().as_device_ptr());
8787
/// ```
8888
pub fn as_device_ptr(&self) -> DevicePointer<T> {
8989
DevicePointer::from_raw(self as *const _ as *const () as usize as u64)

crates/cust/src/memory/pointer.rs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,8 @@ impl<T: DeviceCopy> DevicePointer<T> {
6969
/// ```
7070
/// # let _context = cust::quick_init().unwrap();
7171
/// use cust::memory::*;
72-
/// use std::ptr;
73-
/// unsafe {
74-
/// let null : *mut u64 = ptr::null_mut();
75-
/// assert!(DevicePointer::wrap(null).is_null());
76-
/// }
72+
/// let null_ptr = DevicePointer::<u64>::null();
73+
/// assert!(null_ptr.is_null());
7774
/// ```
7875
pub fn is_null(self) -> bool {
7976
self.ptr == 0
@@ -245,14 +242,15 @@ impl<T: DeviceCopy> DevicePointer<T> {
245242
///
246243
/// # Examples
247244
///
248-
/// ```
245+
/// ```no_run
249246
/// # let _context = cust::quick_init().unwrap();
250247
/// use cust::memory::*;
251248
/// unsafe {
252249
/// let mut dev_ptr = cuda_malloc::<u64>(5).unwrap();
253-
/// let offset = dev_ptr.add(4).sub(3); // Points to the 2nd u64 in the buffer
250+
/// let offset = dev_ptr.add(3).sub(2); // Points to the 2nd u64 in the buffer
254251
/// cuda_free(dev_ptr); // Must free the buffer using the original pointer
255252
/// }
253+
/// ```
256254
#[allow(clippy::should_implement_trait)]
257255
pub unsafe fn sub(self, count: usize) -> Self
258256
where
@@ -309,7 +307,7 @@ impl<T: DeviceCopy> DevicePointer<T> {
309307
///
310308
/// # Examples
311309
///
312-
/// ```
310+
/// ```no_run
313311
/// # let _context = cust::quick_init().unwrap();
314312
/// use cust::memory::*;
315313
/// unsafe {

crates/cust/src/module.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ impl Module {
227227
///
228228
/// # Example
229229
///
230-
/// ```
230+
/// ```no_run
231231
/// # use cust::*;
232232
/// # use std::error::Error;
233233
/// # fn main() -> Result<(), Box<dyn Error>> {
@@ -290,7 +290,7 @@ impl Module {
290290
/// # fn main() -> Result<(), Box<dyn Error>> {
291291
/// # let _ctx = quick_init()?;
292292
/// use cust::module::Module;
293-
/// let ptx = std::fs::read("./resources/add.ptx")?;
293+
/// let ptx = std::fs::read_to_string("./resources/add.ptx")?;
294294
/// let module = Module::from_ptx(&ptx, &[])?;
295295
/// # Ok(())
296296
/// # }

crates/cust_core/src/lib.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ pub mod _hidden {
2020
/// There are two ways to implement DeviceCopy on your type. The simplest is to use `derive`:
2121
///
2222
/// ```
23-
/// use cust::DeviceCopy;
23+
/// use cust_core::DeviceCopy;
2424
///
25-
/// #[derive(Clone, DeviceCopy)]
25+
/// #[derive(Clone, Copy, DeviceCopy)]
2626
/// struct MyStruct(u64);
2727
///
2828
/// # fn main () {}
@@ -33,7 +33,7 @@ pub mod _hidden {
3333
/// be copied to the device:
3434
///
3535
/// ```compile_fail
36-
/// use cust::DeviceCopy;
36+
/// use cust_core::DeviceCopy;
3737
///
3838
/// #[derive(Clone, DeviceCopy)]
3939
/// struct MyStruct(Vec<u64>);
@@ -43,9 +43,9 @@ pub mod _hidden {
4343
/// You can also implement `DeviceCopy` unsafely:
4444
///
4545
/// ```
46-
/// use cust::memory::DeviceCopy;
46+
/// use cust_core::DeviceCopy;
4747
///
48-
/// #[derive(Clone)]
48+
/// #[derive(Clone, Copy)]
4949
/// struct MyStruct(u64);
5050
///
5151
/// unsafe impl DeviceCopy for MyStruct { }

crates/cust_raw/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ build = "build/main.rs"
1313
bindgen = "0.71.1"
1414
bimap = "0.6.3"
1515
cc = "1.2.17"
16+
doxygen-bindgen = "0.1"
1617

1718
[package.metadata.docs.rs]
1819
features = [

crates/cust_raw/build/callbacks.rs

Lines changed: 116 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,118 @@ use std::fs;
33
use std::path;
44
use std::sync;
55

6-
use bindgen::callbacks::{ItemInfo, ItemKind, MacroParsingBehavior, ParseCallbacks};
6+
use bindgen::callbacks::{DeriveInfo, ItemInfo, ItemKind, MacroParsingBehavior, ParseCallbacks};
7+
8+
/// Enum to handle different callback combinations
9+
#[derive(Debug)]
10+
pub(crate) enum BindgenCallbacks {
11+
/// For bindings that need function renaming (driver, runtime, cublas)
12+
WithFunctionRenames {
13+
function_renames: FunctionRenames,
14+
cargo_callbacks: bindgen::CargoCallbacks,
15+
},
16+
/// For bindings that only need comment processing (nvptx, nvvm)
17+
Simple {
18+
cargo_callbacks: bindgen::CargoCallbacks,
19+
},
20+
}
21+
22+
impl BindgenCallbacks {
23+
pub fn with_function_renames(function_renames: FunctionRenames) -> Self {
24+
Self::WithFunctionRenames {
25+
function_renames,
26+
cargo_callbacks: bindgen::CargoCallbacks::new(),
27+
}
28+
}
29+
30+
pub fn simple() -> Self {
31+
Self::Simple {
32+
cargo_callbacks: bindgen::CargoCallbacks::new(),
33+
}
34+
}
35+
}
36+
37+
impl ParseCallbacks for BindgenCallbacks {
38+
fn process_comment(&self, comment: &str) -> Option<String> {
39+
// First replace backslashes with @ to avoid doctest parsing issues
40+
let cleaned = comment.replace('\\', "@");
41+
// Then transform doxygen syntax to rustdoc
42+
match doxygen_bindgen::transform(&cleaned) {
43+
Ok(res) => Some(res),
44+
Err(err) => {
45+
println!(
46+
"cargo:warning=Problem processing doxygen comment: {}\n{}",
47+
comment, err
48+
);
49+
None
50+
}
51+
}
52+
}
53+
54+
fn will_parse_macro(&self, name: &str) -> MacroParsingBehavior {
55+
match self {
56+
Self::WithFunctionRenames {
57+
function_renames, ..
58+
} => function_renames.will_parse_macro(name),
59+
Self::Simple { .. } => MacroParsingBehavior::Default,
60+
}
61+
}
62+
63+
fn item_name(&self, original_item_name: &str) -> Option<String> {
64+
match self {
65+
Self::WithFunctionRenames {
66+
function_renames, ..
67+
} => function_renames.item_name(original_item_name),
68+
Self::Simple { .. } => None,
69+
}
70+
}
71+
72+
fn add_derives(&self, info: &DeriveInfo) -> Vec<String> {
73+
match self {
74+
Self::WithFunctionRenames {
75+
function_renames, ..
76+
} => ParseCallbacks::add_derives(function_renames, info),
77+
Self::Simple { .. } => vec![],
78+
}
79+
}
80+
81+
fn generated_name_override(&self, item_info: ItemInfo<'_>) -> Option<String> {
82+
match self {
83+
Self::WithFunctionRenames {
84+
function_renames, ..
85+
} => ParseCallbacks::generated_name_override(function_renames, item_info),
86+
Self::Simple { .. } => None,
87+
}
88+
}
89+
90+
fn generated_link_name_override(&self, item_info: ItemInfo<'_>) -> Option<String> {
91+
match self {
92+
Self::WithFunctionRenames {
93+
function_renames, ..
94+
} => ParseCallbacks::generated_link_name_override(function_renames, item_info),
95+
Self::Simple { .. } => None,
96+
}
97+
}
98+
99+
// Delegate cargo callbacks
100+
fn include_file(&self, filename: &str) {
101+
match self {
102+
Self::WithFunctionRenames {
103+
cargo_callbacks, ..
104+
}
105+
| Self::Simple { cargo_callbacks } => cargo_callbacks.include_file(filename),
106+
}
107+
}
108+
109+
fn read_env_var(&self, var: &str) {
110+
match self {
111+
Self::WithFunctionRenames {
112+
cargo_callbacks, ..
113+
}
114+
| Self::Simple { cargo_callbacks } => cargo_callbacks.read_env_var(var),
115+
}
116+
}
117+
}
7118

8119
/// Struct to handle renaming of functions through macro expansion.
9120
#[derive(Debug)]
@@ -123,4 +234,8 @@ impl ParseCallbacks for FunctionRenames {
123234
_ => None,
124235
}
125236
}
237+
238+
fn add_derives(&self, _info: &DeriveInfo) -> Vec<String> {
239+
vec![]
240+
}
126241
}

crates/cust_raw/build/main.rs

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -118,13 +118,14 @@ fn create_cuda_driver_bindings(
118118
println!("cargo::rerun-if-changed={}", header.display());
119119
let bindings = bindgen::Builder::default()
120120
.header(header.to_str().expect("header should be valid UTF-8"))
121-
.parse_callbacks(Box::new(callbacks::FunctionRenames::new(
122-
"cu",
123-
outdir,
124-
header,
125-
sdk.cuda_include_paths().to_owned(),
126-
)))
127-
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
121+
.parse_callbacks(Box::new(
122+
callbacks::BindgenCallbacks::with_function_renames(callbacks::FunctionRenames::new(
123+
"cu",
124+
outdir,
125+
header,
126+
sdk.cuda_include_paths().to_owned(),
127+
)),
128+
))
128129
.clang_args(
129130
sdk.cuda_include_paths()
130131
.iter()
@@ -167,13 +168,14 @@ fn create_cuda_runtime_bindings(
167168
println!("cargo::rerun-if-changed={}", header.display());
168169
let bindings = bindgen::Builder::default()
169170
.header(header.to_str().expect("header should be valid UTF-8"))
170-
.parse_callbacks(Box::new(callbacks::FunctionRenames::new(
171-
"cuda",
172-
outdir,
173-
header,
174-
sdk.cuda_include_paths().to_owned(),
175-
)))
176-
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
171+
.parse_callbacks(Box::new(
172+
callbacks::BindgenCallbacks::with_function_renames(callbacks::FunctionRenames::new(
173+
"cuda",
174+
outdir,
175+
header,
176+
sdk.cuda_include_paths().to_owned(),
177+
)),
178+
))
177179
.clang_args(
178180
sdk.cuda_include_paths()
179181
.iter()
@@ -217,13 +219,16 @@ fn create_cublas_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path, manifest
217219
println!("cargo::rerun-if-changed={}", header.display());
218220
let bindings = bindgen::Builder::default()
219221
.header(header.to_str().expect("header should be valid UTF-8"))
220-
.parse_callbacks(Box::new(callbacks::FunctionRenames::new(
221-
pkg,
222-
outdir,
223-
header,
224-
sdk.cuda_include_paths().to_owned(),
225-
)))
226-
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
222+
.parse_callbacks(Box::new(
223+
callbacks::BindgenCallbacks::with_function_renames(
224+
callbacks::FunctionRenames::new(
225+
pkg,
226+
outdir,
227+
header.clone(),
228+
sdk.cuda_include_paths().to_owned(),
229+
),
230+
),
231+
))
227232
.clang_args(
228233
sdk.cuda_include_paths()
229234
.iter()
@@ -263,7 +268,7 @@ fn create_nptx_compiler_bindings(
263268
println!("cargo::rerun-if-changed={}", header.display());
264269
let bindings = bindgen::Builder::default()
265270
.header(header.to_str().expect("header should be valid UTF-8"))
266-
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
271+
.parse_callbacks(Box::new(callbacks::BindgenCallbacks::simple()))
267272
.clang_args(
268273
sdk.cuda_include_paths()
269274
.iter()
@@ -298,7 +303,7 @@ fn create_nvvm_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path, manifest_d
298303
println!("cargo::rerun-if-changed={}", header.display());
299304
let bindings = bindgen::Builder::default()
300305
.header(header.to_str().expect("header should be valid UTF-8"))
301-
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
306+
.parse_callbacks(Box::new(callbacks::BindgenCallbacks::simple()))
302307
.clang_args(
303308
sdk.nvvm_include_paths()
304309
.iter()

0 commit comments

Comments
 (0)