Skip to content

Commit d8a732f

Browse files
committed
Improve kernel checking + added cubin dump lint
1 parent 1ab8b47 commit d8a732f

File tree

6 files changed

+278
-185
lines changed

6 files changed

+278
-185
lines changed

rust-cuda-derive/src/kernel/link/config.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,18 +67,21 @@ impl syn::parse::Parse for LinkKernelConfig {
6767

6868
#[allow(clippy::module_name_repetitions)]
6969
pub(super) struct CheckKernelConfig {
70+
pub(super) kernel_hash: syn::Ident,
7071
pub(super) args: syn::Ident,
7172
pub(super) crate_name: String,
7273
pub(super) crate_path: PathBuf,
7374
}
7475

7576
impl syn::parse::Parse for CheckKernelConfig {
7677
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
78+
let kernel_hash: syn::Ident = input.parse()?;
7779
let args: syn::Ident = input.parse()?;
7880
let name: syn::LitStr = input.parse()?;
7981
let path: syn::LitStr = input.parse()?;
8082

8183
Ok(Self {
84+
kernel_hash,
8285
args,
8386
crate_name: name.value(),
8487
crate_path: PathBuf::from(path.value()),

rust-cuda-derive/src/kernel/link/mod.rs

Lines changed: 155 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -32,31 +32,38 @@ use error::emit_ptx_build_error;
3232
use ptx_compiler_sys::NvptxError;
3333

3434
pub fn check_kernel(tokens: TokenStream) -> TokenStream {
35-
proc_macro_error::set_dummy(quote! {
36-
"ERROR in this PTX compilation"
37-
});
35+
proc_macro_error::set_dummy(quote! {::core::result::Result::Err(())});
3836

3937
let CheckKernelConfig {
38+
kernel_hash,
4039
args,
4140
crate_name,
4241
crate_path,
4342
} = match syn::parse_macro_input::parse(tokens) {
4443
Ok(config) => config,
4544
Err(err) => {
4645
abort_call_site!(
47-
"check_kernel!(ARGS NAME PATH) expects ARGS identifier, NAME and PATH string \
48-
literals: {:?}",
46+
"check_kernel!(HASH ARGS NAME PATH) expects HASH and ARGS identifiers, and NAME \
47+
and PATH string literals: {:?}",
4948
err
5049
)
5150
},
5251
};
5352

5453
let kernel_ptx = compile_kernel(&args, &crate_name, &crate_path, Specialisation::Check);
5554

56-
match kernel_ptx {
57-
Some(kernel_ptx) => quote!(#kernel_ptx).into(),
58-
None => quote!("ERROR in this PTX compilation").into(),
59-
}
55+
let Some(kernel_ptx) = kernel_ptx else {
56+
return quote!(::core::result::Result::Err(())).into()
57+
};
58+
59+
check_kernel_ptx_and_report(
60+
&kernel_ptx,
61+
Specialisation::Check,
62+
&kernel_hash,
63+
&HashMap::new(),
64+
);
65+
66+
quote!(::core::result::Result::Ok(())).into()
6067
}
6168

6269
#[allow(clippy::module_name_repetitions, clippy::too_many_lines)]
@@ -77,9 +84,9 @@ pub fn link_kernel(tokens: TokenStream) -> TokenStream {
7784
Ok(config) => config,
7885
Err(err) => {
7986
abort_call_site!(
80-
"link_kernel!(KERNEL ARGS NAME PATH SPECIALISATION LINTS,*) expects KERNEL and \
81-
ARGS identifiers, NAME and PATH string literals, SPECIALISATION and LINTS \
82-
tokens: {:?}",
87+
"link_kernel!(KERNEL HASH ARGS NAME PATH SPECIALISATION LINTS,*) expects KERNEL, \
88+
HASH, and ARGS identifiers, NAME and PATH string literals, and SPECIALISATION \
89+
and LINTS tokens: {:?}",
8390
err
8491
)
8592
},
@@ -213,88 +220,162 @@ pub fn link_kernel(tokens: TokenStream) -> TokenStream {
213220
kernel_ptx.replace_range(type_layout_start..type_layout_end, "");
214221
}
215222

216-
let (result, error_log, info_log, version, drop) =
217-
check_kernel_ptx(&kernel_ptx, &specialisation, &kernel_hash, &ptx_lint_levels);
223+
check_kernel_ptx_and_report(
224+
&kernel_ptx,
225+
Specialisation::Link(&specialisation),
226+
&kernel_hash,
227+
&ptx_lint_levels,
228+
);
229+
230+
(quote! { const PTX_STR: &'static str = #kernel_ptx; #(#type_layouts)* }).into()
231+
}
232+
233+
#[allow(clippy::too_many_lines)]
234+
fn check_kernel_ptx_and_report(
235+
kernel_ptx: &str,
236+
specialisation: Specialisation,
237+
kernel_hash: &proc_macro2::Ident,
238+
ptx_lint_levels: &HashMap<PtxLint, LintLevel>,
239+
) {
240+
let (result, error_log, info_log, binary, version, drop) =
241+
check_kernel_ptx(kernel_ptx, specialisation, kernel_hash, ptx_lint_levels);
218242

219243
let ptx_compiler = match &version {
220244
Ok((major, minor)) => format!("PTX compiler v{major}.{minor}"),
221245
Err(_) => String::from("PTX compiler"),
222246
};
223247

224-
// TODO: allow user to select
225-
// - warn on double
226-
// - warn on float
227-
// - warn on spills
228-
// - verbose warn
229-
// - warnings as errors
230-
// - show PTX source if warning or error
231-
232248
let mut errors = String::new();
249+
233250
if let Err(err) = drop {
234251
let _ = errors.write_fmt(format_args!("Error dropping the {ptx_compiler}: {err}\n"));
235252
}
253+
236254
if let Err(err) = version {
237255
let _ = errors.write_fmt(format_args!(
238256
"Error fetching the version of the {ptx_compiler}: {err}\n"
239257
));
240258
}
241-
if let (Ok(Some(_)), _) | (_, Ok(Some(_))) = (&info_log, &error_log) {
259+
260+
let ptx_source_code = {
242261
let mut max_lines = kernel_ptx.chars().filter(|c| *c == '\n').count() + 1;
243262
let mut indent = 0;
244263
while max_lines > 0 {
245264
max_lines /= 10;
246265
indent += 1;
247266
}
248267

249-
emit_call_site_warning!(
268+
format!(
250269
"PTX source code:\n{}",
251270
kernel_ptx
252271
.lines()
253272
.enumerate()
254273
.map(|(i, l)| format!("{:indent$}| {l}", i + 1))
255274
.collect::<Vec<_>>()
256275
.join("\n")
257-
);
276+
)
277+
};
278+
279+
match binary {
280+
Ok(None) => (),
281+
Ok(Some(binary)) => {
282+
if ptx_lint_levels
283+
.get(&PtxLint::DumpBinary)
284+
.map_or(false, |level| *level > LintLevel::Allow)
285+
{
286+
const HEX: [char; 16] = [
287+
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f',
288+
];
289+
290+
let mut binary_hex = String::with_capacity(binary.len() * 2);
291+
for byte in binary {
292+
binary_hex.push(HEX[usize::from(byte >> 4)]);
293+
binary_hex.push(HEX[usize::from(byte & 0x0F)]);
294+
}
295+
296+
if ptx_lint_levels
297+
.get(&PtxLint::DumpBinary)
298+
.map_or(false, |level| *level > LintLevel::Warn)
299+
{
300+
emit_call_site_error!(
301+
"{} compiled binary:\n{}\n\n{}",
302+
ptx_compiler,
303+
binary_hex,
304+
ptx_source_code
305+
);
306+
} else {
307+
emit_call_site_warning!(
308+
"{} compiled binary:\n{}\n\n{}",
309+
ptx_compiler,
310+
binary_hex,
311+
ptx_source_code
312+
);
313+
}
314+
}
315+
},
316+
Err(err) => {
317+
let _ = errors.write_fmt(format_args!(
318+
"Error fetching the compiled binary from {ptx_compiler}: {err}\n"
319+
));
320+
},
258321
}
322+
259323
match info_log {
260324
Ok(None) => (),
261-
Ok(Some(info_log)) => emit_call_site_warning!("{ptx_compiler} info log:\n{}", info_log),
325+
Ok(Some(info_log)) => emit_call_site_warning!(
326+
"{} info log:\n{}\n{}",
327+
ptx_compiler,
328+
info_log,
329+
ptx_source_code
330+
),
262331
Err(err) => {
263332
let _ = errors.write_fmt(format_args!(
264333
"Error fetching the info log of the {ptx_compiler}: {err}\n"
265334
));
266335
},
267336
};
268-
match error_log {
269-
Ok(None) => (),
270-
Ok(Some(error_log)) => emit_call_site_error!("{ptx_compiler} error log:\n{}", error_log),
337+
338+
let error_log = match error_log {
339+
Ok(None) => String::new(),
340+
Ok(Some(error_log)) => {
341+
format!("{ptx_compiler} error log:\n{error_log}\n{ptx_source_code}")
342+
},
271343
Err(err) => {
272344
let _ = errors.write_fmt(format_args!(
273345
"Error fetching the error log of the {ptx_compiler}: {err}\n"
274346
));
347+
String::new()
275348
},
276349
};
350+
277351
if let Err(err) = result {
278352
let _ = errors.write_fmt(format_args!("Error compiling the PTX source code: {err}\n"));
279353
}
280-
if !errors.is_empty() {
281-
abort_call_site!("{}", errors);
282-
}
283354

284-
(quote! { const PTX_STR: &'static str = #kernel_ptx; #(#type_layouts)* }).into()
355+
if !error_log.is_empty() || !errors.is_empty() {
356+
abort_call_site!(
357+
"{error_log}{}{errors}",
358+
if !error_log.is_empty() && !errors.is_empty() {
359+
"\n\n"
360+
} else {
361+
""
362+
}
363+
);
364+
}
285365
}
286366

287367
#[allow(clippy::type_complexity)]
288368
#[allow(clippy::too_many_lines)]
289369
fn check_kernel_ptx(
290370
kernel_ptx: &str,
291-
specialisation: &str,
371+
specialisation: Specialisation,
292372
kernel_hash: &proc_macro2::Ident,
293373
ptx_lint_levels: &HashMap<PtxLint, LintLevel>,
294374
) -> (
295375
Result<(), NvptxError>,
296376
Result<Option<String>, NvptxError>,
297377
Result<Option<String>, NvptxError>,
378+
Result<Option<Vec<u8>>, NvptxError>,
298379
Result<(u32, u32), NvptxError>,
299380
Result<(), NvptxError>,
300381
) {
@@ -313,14 +394,15 @@ fn check_kernel_ptx(
313394
};
314395

315396
let result = (|| {
316-
let kernel_name = if specialisation.is_empty() {
317-
format!("{kernel_hash}_kernel")
318-
} else {
319-
format!(
397+
let kernel_name = match specialisation {
398+
Specialisation::Check => format!("{kernel_hash}_chECK"),
399+
Specialisation::Link("") => format!("{kernel_hash}_kernel"),
400+
Specialisation::Link(specialisation) => format!(
320401
"{kernel_hash}_kernel_{:016x}",
321402
seahash::hash(specialisation.as_bytes())
322-
)
403+
),
323404
};
405+
324406
let mut options = vec![
325407
CString::new("--entry").unwrap(),
326408
CString::new(kernel_name).unwrap(),
@@ -457,6 +539,39 @@ fn check_kernel_ptx(
457539
Ok(Some(String::from_utf8_lossy(&info_log).into_owned()))
458540
})();
459541

542+
let binary = (|| {
543+
if result.is_err() {
544+
return Ok(None);
545+
}
546+
547+
let mut binary_size = 0;
548+
549+
NvptxError::try_err_from(unsafe {
550+
ptx_compiler_sys::nvPTXCompilerGetCompiledProgramSize(
551+
compiler,
552+
addr_of_mut!(binary_size),
553+
)
554+
})?;
555+
556+
if binary_size == 0 {
557+
return Ok(None);
558+
}
559+
560+
#[allow(clippy::cast_possible_truncation)]
561+
let mut binary: Vec<u8> = Vec::with_capacity(binary_size as usize);
562+
563+
NvptxError::try_err_from(unsafe {
564+
ptx_compiler_sys::nvPTXCompilerGetCompiledProgram(compiler, binary.as_mut_ptr().cast())
565+
})?;
566+
567+
#[allow(clippy::cast_possible_truncation)]
568+
unsafe {
569+
binary.set_len(binary_size as usize);
570+
}
571+
572+
Ok(Some(binary))
573+
})();
574+
460575
let version = (|| {
461576
let mut major = 0;
462577
let mut minor = 0;
@@ -475,7 +590,7 @@ fn check_kernel_ptx(
475590
})
476591
};
477592

478-
(result, error_log, info_log, version, drop)
593+
(result, error_log, info_log, binary, version, drop)
479594
}
480595

481596
fn compile_kernel(

0 commit comments

Comments
 (0)