@@ -32,31 +32,38 @@ use error::emit_ptx_build_error;
32
32
use ptx_compiler_sys:: NvptxError ;
33
33
34
34
pub fn check_kernel ( tokens : TokenStream ) -> TokenStream {
35
- proc_macro_error:: set_dummy ( quote ! {
36
- "ERROR in this PTX compilation"
37
- } ) ;
35
+ proc_macro_error:: set_dummy ( quote ! { :: core:: result:: Result :: Err ( ( ) ) } ) ;
38
36
39
37
let CheckKernelConfig {
38
+ kernel_hash,
40
39
args,
41
40
crate_name,
42
41
crate_path,
43
42
} = match syn:: parse_macro_input:: parse ( tokens) {
44
43
Ok ( config) => config,
45
44
Err ( err) => {
46
45
abort_call_site ! (
47
- "check_kernel!(ARGS NAME PATH) expects ARGS identifier, NAME and PATH string \
48
- literals: {:?}",
46
+ "check_kernel!(HASH ARGS NAME PATH) expects HASH and ARGS identifiers, annd NAME \
47
+ and PATH string literals: {:?}",
49
48
err
50
49
)
51
50
} ,
52
51
} ;
53
52
54
53
let kernel_ptx = compile_kernel ( & args, & crate_name, & crate_path, Specialisation :: Check ) ;
55
54
56
- match kernel_ptx {
57
- Some ( kernel_ptx) => quote ! ( #kernel_ptx) . into ( ) ,
58
- None => quote ! ( "ERROR in this PTX compilation" ) . into ( ) ,
59
- }
55
+ let Some ( kernel_ptx) = kernel_ptx else {
56
+ return quote ! ( :: core:: result:: Result :: Err ( ( ) ) ) . into ( )
57
+ } ;
58
+
59
+ check_kernel_ptx_and_report (
60
+ & kernel_ptx,
61
+ Specialisation :: Check ,
62
+ & kernel_hash,
63
+ & HashMap :: new ( ) ,
64
+ ) ;
65
+
66
+ quote ! ( :: core:: result:: Result :: Ok ( ( ) ) ) . into ( )
60
67
}
61
68
62
69
#[ allow( clippy:: module_name_repetitions, clippy:: too_many_lines) ]
@@ -77,9 +84,9 @@ pub fn link_kernel(tokens: TokenStream) -> TokenStream {
77
84
Ok ( config) => config,
78
85
Err ( err) => {
79
86
abort_call_site ! (
80
- "link_kernel!(KERNEL ARGS NAME PATH SPECIALISATION LINTS,*) expects KERNEL and \
81
- ARGS identifiers, NAME and PATH string literals, SPECIALISATION and LINTS \
82
- tokens: {:?}",
87
+ "link_kernel!(KERNEL HASH ARGS NAME PATH SPECIALISATION LINTS,*) expects KERNEL, \
88
+ HASH, and ARGS identifiers, NAME and PATH string literals, and SPECIALISATION \
89
+ and LINTS tokens: {:?}",
83
90
err
84
91
)
85
92
} ,
@@ -213,88 +220,162 @@ pub fn link_kernel(tokens: TokenStream) -> TokenStream {
213
220
kernel_ptx. replace_range ( type_layout_start..type_layout_end, "" ) ;
214
221
}
215
222
216
- let ( result, error_log, info_log, version, drop) =
217
- check_kernel_ptx ( & kernel_ptx, & specialisation, & kernel_hash, & ptx_lint_levels) ;
223
+ check_kernel_ptx_and_report (
224
+ & kernel_ptx,
225
+ Specialisation :: Link ( & specialisation) ,
226
+ & kernel_hash,
227
+ & ptx_lint_levels,
228
+ ) ;
229
+
230
+ ( quote ! { const PTX_STR : & ' static str = #kernel_ptx; #( #type_layouts) * } ) . into ( )
231
+ }
232
+
233
+ #[ allow( clippy:: too_many_lines) ]
234
+ fn check_kernel_ptx_and_report (
235
+ kernel_ptx : & str ,
236
+ specialisation : Specialisation ,
237
+ kernel_hash : & proc_macro2:: Ident ,
238
+ ptx_lint_levels : & HashMap < PtxLint , LintLevel > ,
239
+ ) {
240
+ let ( result, error_log, info_log, binary, version, drop) =
241
+ check_kernel_ptx ( kernel_ptx, specialisation, kernel_hash, ptx_lint_levels) ;
218
242
219
243
let ptx_compiler = match & version {
220
244
Ok ( ( major, minor) ) => format ! ( "PTX compiler v{major}.{minor}" ) ,
221
245
Err ( _) => String :: from ( "PTX compiler" ) ,
222
246
} ;
223
247
224
- // TODO: allow user to select
225
- // - warn on double
226
- // - warn on float
227
- // - warn on spills
228
- // - verbose warn
229
- // - warnings as errors
230
- // - show PTX source if warning or error
231
-
232
248
let mut errors = String :: new ( ) ;
249
+
233
250
if let Err ( err) = drop {
234
251
let _ = errors. write_fmt ( format_args ! ( "Error dropping the {ptx_compiler}: {err}\n " ) ) ;
235
252
}
253
+
236
254
if let Err ( err) = version {
237
255
let _ = errors. write_fmt ( format_args ! (
238
256
"Error fetching the version of the {ptx_compiler}: {err}\n "
239
257
) ) ;
240
258
}
241
- if let ( Ok ( Some ( _) ) , _) | ( _, Ok ( Some ( _) ) ) = ( & info_log, & error_log) {
259
+
260
+ let ptx_source_code = {
242
261
let mut max_lines = kernel_ptx. chars ( ) . filter ( |c| * c == '\n' ) . count ( ) + 1 ;
243
262
let mut indent = 0 ;
244
263
while max_lines > 0 {
245
264
max_lines /= 10 ;
246
265
indent += 1 ;
247
266
}
248
267
249
- emit_call_site_warning ! (
268
+ format ! (
250
269
"PTX source code:\n {}" ,
251
270
kernel_ptx
252
271
. lines( )
253
272
. enumerate( )
254
273
. map( |( i, l) | format!( "{:indent$}| {l}" , i + 1 ) )
255
274
. collect:: <Vec <_>>( )
256
275
. join( "\n " )
257
- ) ;
276
+ )
277
+ } ;
278
+
279
+ match binary {
280
+ Ok ( None ) => ( ) ,
281
+ Ok ( Some ( binary) ) => {
282
+ if ptx_lint_levels
283
+ . get ( & PtxLint :: DumpBinary )
284
+ . map_or ( false , |level| * level > LintLevel :: Allow )
285
+ {
286
+ const HEX : [ char ; 16 ] = [
287
+ '0' , '1' , '2' , '3' , '4' , '5' , '6' , '7' , '8' , '9' , 'a' , 'b' , 'c' , 'd' , 'e' , 'f' ,
288
+ ] ;
289
+
290
+ let mut binary_hex = String :: with_capacity ( binary. len ( ) * 2 ) ;
291
+ for byte in binary {
292
+ binary_hex. push ( HEX [ usize:: from ( byte >> 4 ) ] ) ;
293
+ binary_hex. push ( HEX [ usize:: from ( byte & 0x0F ) ] ) ;
294
+ }
295
+
296
+ if ptx_lint_levels
297
+ . get ( & PtxLint :: DumpBinary )
298
+ . map_or ( false , |level| * level > LintLevel :: Warn )
299
+ {
300
+ emit_call_site_error ! (
301
+ "{} compiled binary:\n {}\n \n {}" ,
302
+ ptx_compiler,
303
+ binary_hex,
304
+ ptx_source_code
305
+ ) ;
306
+ } else {
307
+ emit_call_site_warning ! (
308
+ "{} compiled binary:\n {}\n \n {}" ,
309
+ ptx_compiler,
310
+ binary_hex,
311
+ ptx_source_code
312
+ ) ;
313
+ }
314
+ }
315
+ } ,
316
+ Err ( err) => {
317
+ let _ = errors. write_fmt ( format_args ! (
318
+ "Error fetching the compiled binary from {ptx_compiler}: {err}\n "
319
+ ) ) ;
320
+ } ,
258
321
}
322
+
259
323
match info_log {
260
324
Ok ( None ) => ( ) ,
261
- Ok ( Some ( info_log) ) => emit_call_site_warning ! ( "{ptx_compiler} info log:\n {}" , info_log) ,
325
+ Ok ( Some ( info_log) ) => emit_call_site_warning ! (
326
+ "{} info log:\n {}\n {}" ,
327
+ ptx_compiler,
328
+ info_log,
329
+ ptx_source_code
330
+ ) ,
262
331
Err ( err) => {
263
332
let _ = errors. write_fmt ( format_args ! (
264
333
"Error fetching the info log of the {ptx_compiler}: {err}\n "
265
334
) ) ;
266
335
} ,
267
336
} ;
268
- match error_log {
269
- Ok ( None ) => ( ) ,
270
- Ok ( Some ( error_log) ) => emit_call_site_error ! ( "{ptx_compiler} error log:\n {}" , error_log) ,
337
+
338
+ let error_log = match error_log {
339
+ Ok ( None ) => String :: new ( ) ,
340
+ Ok ( Some ( error_log) ) => {
341
+ format ! ( "{ptx_compiler} error log:\n {error_log}\n {ptx_source_code}" )
342
+ } ,
271
343
Err ( err) => {
272
344
let _ = errors. write_fmt ( format_args ! (
273
345
"Error fetching the error log of the {ptx_compiler}: {err}\n "
274
346
) ) ;
347
+ String :: new ( )
275
348
} ,
276
349
} ;
350
+
277
351
if let Err ( err) = result {
278
352
let _ = errors. write_fmt ( format_args ! ( "Error compiling the PTX source code: {err}\n " ) ) ;
279
353
}
280
- if !errors. is_empty ( ) {
281
- abort_call_site ! ( "{}" , errors) ;
282
- }
283
354
284
- ( quote ! { const PTX_STR : & ' static str = #kernel_ptx; #( #type_layouts) * } ) . into ( )
355
+ if !error_log. is_empty ( ) || !errors. is_empty ( ) {
356
+ abort_call_site ! (
357
+ "{error_log}{}{errors}" ,
358
+ if !error_log. is_empty( ) && !errors. is_empty( ) {
359
+ "\n \n "
360
+ } else {
361
+ ""
362
+ }
363
+ ) ;
364
+ }
285
365
}
286
366
287
367
#[ allow( clippy:: type_complexity) ]
288
368
#[ allow( clippy:: too_many_lines) ]
289
369
fn check_kernel_ptx (
290
370
kernel_ptx : & str ,
291
- specialisation : & str ,
371
+ specialisation : Specialisation ,
292
372
kernel_hash : & proc_macro2:: Ident ,
293
373
ptx_lint_levels : & HashMap < PtxLint , LintLevel > ,
294
374
) -> (
295
375
Result < ( ) , NvptxError > ,
296
376
Result < Option < String > , NvptxError > ,
297
377
Result < Option < String > , NvptxError > ,
378
+ Result < Option < Vec < u8 > > , NvptxError > ,
298
379
Result < ( u32 , u32 ) , NvptxError > ,
299
380
Result < ( ) , NvptxError > ,
300
381
) {
@@ -313,14 +394,15 @@ fn check_kernel_ptx(
313
394
} ;
314
395
315
396
let result = ( || {
316
- let kernel_name = if specialisation. is_empty ( ) {
317
- format ! ( "{kernel_hash}_kernel" )
318
- } else {
319
- format ! (
397
+ let kernel_name = match specialisation {
398
+ Specialisation :: Check => format ! ( "{kernel_hash}_chECK" ) ,
399
+ Specialisation :: Link ( "" ) => format ! ( "{kernel_hash}_kernel" ) ,
400
+ Specialisation :: Link ( specialisation ) => format ! (
320
401
"{kernel_hash}_kernel_{:016x}" ,
321
402
seahash:: hash( specialisation. as_bytes( ) )
322
- )
403
+ ) ,
323
404
} ;
405
+
324
406
let mut options = vec ! [
325
407
CString :: new( "--entry" ) . unwrap( ) ,
326
408
CString :: new( kernel_name) . unwrap( ) ,
@@ -457,6 +539,39 @@ fn check_kernel_ptx(
457
539
Ok ( Some ( String :: from_utf8_lossy ( & info_log) . into_owned ( ) ) )
458
540
} ) ( ) ;
459
541
542
+ let binary = ( || {
543
+ if result. is_err ( ) {
544
+ return Ok ( None ) ;
545
+ }
546
+
547
+ let mut binary_size = 0 ;
548
+
549
+ NvptxError :: try_err_from ( unsafe {
550
+ ptx_compiler_sys:: nvPTXCompilerGetCompiledProgramSize (
551
+ compiler,
552
+ addr_of_mut ! ( binary_size) ,
553
+ )
554
+ } ) ?;
555
+
556
+ if binary_size == 0 {
557
+ return Ok ( None ) ;
558
+ }
559
+
560
+ #[ allow( clippy:: cast_possible_truncation) ]
561
+ let mut binary: Vec < u8 > = Vec :: with_capacity ( binary_size as usize ) ;
562
+
563
+ NvptxError :: try_err_from ( unsafe {
564
+ ptx_compiler_sys:: nvPTXCompilerGetCompiledProgram ( compiler, binary. as_mut_ptr ( ) . cast ( ) )
565
+ } ) ?;
566
+
567
+ #[ allow( clippy:: cast_possible_truncation) ]
568
+ unsafe {
569
+ binary. set_len ( binary_size as usize ) ;
570
+ }
571
+
572
+ Ok ( Some ( binary) )
573
+ } ) ( ) ;
574
+
460
575
let version = ( || {
461
576
let mut major = 0 ;
462
577
let mut minor = 0 ;
@@ -475,7 +590,7 @@ fn check_kernel_ptx(
475
590
} )
476
591
} ;
477
592
478
- ( result, error_log, info_log, version, drop)
593
+ ( result, error_log, info_log, binary , version, drop)
479
594
}
480
595
481
596
fn compile_kernel (
0 commit comments