@@ -26,7 +26,7 @@ use crate::{
26
26
} ,
27
27
dx12:: {
28
28
borrow_optional_interface_temporarily, shader_compilation, suballocation,
29
- DynamicStorageBufferOffsets , Event ,
29
+ DynamicStorageBufferOffsets , Event , ShaderCacheKey , ShaderCacheValue ,
30
30
} ,
31
31
AccelerationStructureEntries , TlasInstance ,
32
32
} ;
@@ -203,6 +203,7 @@ impl super::Device {
203
203
null_rtv_handle,
204
204
mem_allocator,
205
205
compiler_container,
206
+ shader_cache : Default :: default ( ) ,
206
207
counters : Default :: default ( ) ,
207
208
} )
208
209
}
@@ -304,50 +305,85 @@ impl super::Device {
304
305
} ;
305
306
306
307
//TODO: reuse the writer
307
- let mut source = String :: new ( ) ;
308
- let mut writer = hlsl:: Writer :: new ( & mut source, naga_options, & pipeline_options) ;
309
- let reflection_info = {
308
+ let ( source, entry_point) = {
309
+ let mut source = String :: new ( ) ;
310
+ let mut writer = hlsl:: Writer :: new ( & mut source, naga_options, & pipeline_options) ;
311
+
310
312
profiling:: scope!( "naga::back::hlsl::write" ) ;
311
- writer
313
+ let mut reflection_info = writer
312
314
. write ( & module, & info, frag_ep. as_ref ( ) )
313
- . map_err ( |e| crate :: PipelineError :: Linkage ( stage_bit, format ! ( "HLSL: {e:?}" ) ) ) ?
315
+ . map_err ( |e| crate :: PipelineError :: Linkage ( stage_bit, format ! ( "HLSL: {e:?}" ) ) ) ?;
316
+
317
+ assert_eq ! ( reflection_info. entry_point_names. len( ) , 1 ) ;
318
+
319
+ let entry_point = reflection_info
320
+ . entry_point_names
321
+ . pop ( )
322
+ . unwrap ( )
323
+ . map_err ( |e| crate :: PipelineError :: Linkage ( stage_bit, format ! ( "{e}" ) ) ) ?;
324
+
325
+ ( source, entry_point)
314
326
} ;
315
327
328
+ log:: info!(
329
+ "Naga generated shader for {:?} at {:?}:\n {}" ,
330
+ entry_point,
331
+ naga_stage,
332
+ source
333
+ ) ;
334
+
335
+ let key = ShaderCacheKey {
336
+ source,
337
+ entry_point,
338
+ stage : naga_stage,
339
+ shader_model : naga_options. shader_model ,
340
+ } ;
341
+
342
+ {
343
+ let mut shader_cache = self . shader_cache . lock ( ) ;
344
+ let nr_of_shaders_compiled = shader_cache. nr_of_shaders_compiled ;
345
+ if let Some ( value) = shader_cache. entries . get_mut ( & key) {
346
+ value. last_used = nr_of_shaders_compiled;
347
+ return Ok ( value. shader . clone ( ) ) ;
348
+ }
349
+ }
350
+
351
+ let source_name = stage. module . raw_name . as_deref ( ) ;
352
+
316
353
let full_stage = format ! (
317
354
"{}_{}" ,
318
355
naga_stage. to_hlsl_str( ) ,
319
356
naga_options. shader_model. to_str( )
320
357
) ;
321
358
322
- let raw_ep = reflection_info. entry_point_names [ 0 ]
323
- . as_ref ( )
324
- . map_err ( |e| crate :: PipelineError :: Linkage ( stage_bit, format ! ( "{e}" ) ) ) ?;
325
-
326
- let source_name = stage. module . raw_name . as_deref ( ) ;
327
-
328
- let result = self . compiler_container . compile (
359
+ let compiled_shader = self . compiler_container . compile (
329
360
self ,
330
- & source,
361
+ & key . source ,
331
362
source_name,
332
- raw_ep ,
363
+ & key . entry_point ,
333
364
stage_bit,
334
365
& full_stage,
335
- ) ;
366
+ ) ? ;
336
367
337
- let log_level = if result. is_ok ( ) {
338
- log:: Level :: Info
339
- } else {
340
- log:: Level :: Error
341
- } ;
368
+ {
369
+ let mut shader_cache = self . shader_cache . lock ( ) ;
370
+ shader_cache. nr_of_shaders_compiled += 1 ;
371
+ let nr_of_shaders_compiled = shader_cache. nr_of_shaders_compiled ;
372
+ let value = ShaderCacheValue {
373
+ last_used : nr_of_shaders_compiled,
374
+ shader : compiled_shader. clone ( ) ,
375
+ } ;
376
+ shader_cache. entries . insert ( key, value) ;
342
377
343
- log:: log!(
344
- log_level,
345
- "Naga generated shader for {:?} at {:?}:\n {}" ,
346
- raw_ep,
347
- naga_stage,
348
- source
349
- ) ;
350
- result
378
+ // Retain all entries that have been used since we compiled the last 100 shaders.
379
+ if shader_cache. entries . len ( ) > 200 {
380
+ shader_cache
381
+ . entries
382
+ . retain ( |_, v| v. last_used >= nr_of_shaders_compiled - 100 ) ;
383
+ }
384
+ }
385
+
386
+ Ok ( compiled_shader)
351
387
}
352
388
353
389
pub fn raw_device ( & self ) -> & Direct3D12 :: ID3D12Device {
@@ -1818,11 +1854,6 @@ impl crate::Device for super::Device {
1818
1854
}
1819
1855
. map_err ( |err| crate :: PipelineError :: Linkage ( shader_stages, err. to_string ( ) ) ) ?;
1820
1856
1821
- unsafe { blob_vs. destroy ( ) } ;
1822
- if let Some ( blob_fs) = blob_fs {
1823
- unsafe { blob_fs. destroy ( ) } ;
1824
- } ;
1825
-
1826
1857
if let Some ( label) = desc. label {
1827
1858
raw. set_name ( label) ?;
1828
1859
}
@@ -1880,8 +1911,6 @@ impl crate::Device for super::Device {
1880
1911
}
1881
1912
} ;
1882
1913
1883
- unsafe { blob_cs. destroy ( ) } ;
1884
-
1885
1914
let raw: Direct3D12 :: ID3D12PipelineState = pair. map_err ( |err| {
1886
1915
crate :: PipelineError :: Linkage ( wgt:: ShaderStages :: COMPUTE , err. to_string ( ) )
1887
1916
} ) ?;
0 commit comments