Skip to content

Commit cb6dbb8

Browse files
authored
Vulkan Sampler Cache (#6847)
1 parent d291571 commit cb6dbb8

File tree

8 files changed

+225
-12
lines changed

8 files changed

+225
-12
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,10 @@ By @wumpf in [#6849](https://github.com/gfx-rs/wgpu/pull/6849).
183183

184184
- Avoid using FXC as fallback when the DXC container was passed at instance creation. Paths to `dxcompiler.dll` & `dxil.dll` are also now required. By @teoxoy in [#6643](https://github.com/gfx-rs/wgpu/pull/6643).
185185

186+
##### Vulkan
187+
188+
- Add a cache for samplers, deduplicating any samplers, allowing more programs to stay within the global sampler limit. By @cwfitzgerald in [#6847](https://github.com/gfx-rs/wgpu/pull/6847)
189+
186190
##### HAL
187191

188192
- Replace `usage: Range<T>`, for `BufferUses`, `TextureUses`, and `AccelerationStructureBarrier` with a new `StateTransition<T>`. By @atlv24 in [#6703](https://github.com/gfx-rs/wgpu/pull/6703)

Cargo.lock

Lines changed: 10 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ noise = { version = "0.8", git = "https://github.com/Razaekel/noise-rs.git", rev
109109
nv-flip = "0.1"
110110
obj = "0.10"
111111
once_cell = "1.20.2"
112+
# Firefox has 3.4.0 vendored, so we allow that version in our dependencies
113+
ordered-float = ">=3,<=4.6"
112114
parking_lot = "0.12.1"
113115
pico-args = { version = "0.5.0", features = [
114116
"eq-separator",

wgpu-hal/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ vulkan = [
5353
"dep:libloading",
5454
"dep:smallvec",
5555
"dep:android_system_properties",
56+
"dep:ordered-float",
5657
]
5758
gles = [
5859
"naga/glsl-out",
@@ -125,6 +126,7 @@ profiling = { workspace = true, default-features = false }
125126
raw-window-handle.workspace = true
126127
thiserror.workspace = true
127128
once_cell.workspace = true
129+
ordered-float = { workspace = true, optional = true }
128130

129131
# backends common
130132
arrayvec.workspace = true

wgpu-hal/src/vulkan/adapter.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1559,6 +1559,10 @@ impl super::Instance {
15591559
.is_some_and(|ext| ext.shader_zero_initialize_workgroup_memory == vk::TRUE),
15601560
image_format_list: phd_capabilities.device_api_version >= vk::API_VERSION_1_2
15611561
|| phd_capabilities.supports_extension(khr::image_format_list::NAME),
1562+
maximum_samplers: phd_capabilities
1563+
.properties
1564+
.limits
1565+
.max_sampler_allocation_count,
15621566
};
15631567
let capabilities = crate::Capabilities {
15641568
limits: phd_capabilities.to_wgpu_limits(),
@@ -1907,6 +1911,9 @@ impl super::Adapter {
19071911
workarounds: self.workarounds,
19081912
render_passes: Mutex::new(Default::default()),
19091913
framebuffers: Mutex::new(Default::default()),
1914+
sampler_cache: Mutex::new(super::sampler::SamplerCache::new(
1915+
self.private_caps.maximum_samplers,
1916+
)),
19101917
memory_allocations_counter: Default::default(),
19111918
});
19121919

wgpu-hal/src/vulkan/device.rs

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1304,7 +1304,7 @@ impl crate::Device for super::Device {
13041304
&self,
13051305
desc: &crate::SamplerDescriptor,
13061306
) -> Result<super::Sampler, crate::DeviceError> {
1307-
let mut vk_info = vk::SamplerCreateInfo::default()
1307+
let mut create_info = vk::SamplerCreateInfo::default()
13081308
.flags(vk::SamplerCreateFlags::empty())
13091309
.mag_filter(conv::map_filter_mode(desc.mag_filter))
13101310
.min_filter(conv::map_filter_mode(desc.min_filter))
@@ -1316,40 +1316,46 @@ impl crate::Device for super::Device {
13161316
.max_lod(desc.lod_clamp.end);
13171317

13181318
if let Some(fun) = desc.compare {
1319-
vk_info = vk_info
1319+
create_info = create_info
13201320
.compare_enable(true)
13211321
.compare_op(conv::map_comparison(fun));
13221322
}
13231323

13241324
if desc.anisotropy_clamp != 1 {
13251325
// We only enable anisotropy if it is supported, and wgpu-hal interface guarantees
13261326
// the clamp is in the range [1, 16] which is always supported if anisotropy is.
1327-
vk_info = vk_info
1327+
create_info = create_info
13281328
.anisotropy_enable(true)
13291329
.max_anisotropy(desc.anisotropy_clamp as f32);
13301330
}
13311331

13321332
if let Some(color) = desc.border_color {
1333-
vk_info = vk_info.border_color(conv::map_border_color(color));
1333+
create_info = create_info.border_color(conv::map_border_color(color));
13341334
}
13351335

1336-
let raw = unsafe {
1337-
self.shared
1338-
.raw
1339-
.create_sampler(&vk_info, None)
1340-
.map_err(super::map_host_device_oom_and_ioca_err)?
1341-
};
1336+
let raw = self
1337+
.shared
1338+
.sampler_cache
1339+
.lock()
1340+
.create_sampler(&self.shared.raw, create_info)?;
13421341

1342+
// Note: Cached samplers will just continually overwrite the label
1343+
//
1344+
// https://github.com/gfx-rs/wgpu/issues/6867
13431345
if let Some(label) = desc.label {
13441346
unsafe { self.shared.set_object_name(raw, label) };
13451347
}
13461348

13471349
self.counters.samplers.add(1);
13481350

1349-
Ok(super::Sampler { raw })
1351+
Ok(super::Sampler { raw, create_info })
13501352
}
13511353
unsafe fn destroy_sampler(&self, sampler: super::Sampler) {
1352-
unsafe { self.shared.raw.destroy_sampler(sampler.raw, None) };
1354+
self.shared.sampler_cache.lock().destroy_sampler(
1355+
&self.shared.raw,
1356+
sampler.create_info,
1357+
sampler.raw,
1358+
);
13531359

13541360
self.counters.samplers.sub(1);
13551361
}

wgpu-hal/src/vulkan/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ mod command;
2929
mod conv;
3030
mod device;
3131
mod instance;
32+
mod sampler;
3233

3334
use std::{
3435
borrow::Borrow,
@@ -532,6 +533,7 @@ struct PrivateCapabilities {
532533
robust_image_access2: bool,
533534
zero_initialize_workgroup_memory: bool,
534535
image_format_list: bool,
536+
maximum_samplers: u32,
535537
}
536538

537539
bitflags::bitflags!(
@@ -641,6 +643,7 @@ struct DeviceShared {
641643
features: wgt::Features,
642644
render_passes: Mutex<rustc_hash::FxHashMap<RenderPassKey, vk::RenderPass>>,
643645
framebuffers: Mutex<rustc_hash::FxHashMap<FramebufferKey, vk::Framebuffer>>,
646+
sampler_cache: Mutex<sampler::SamplerCache>,
644647
memory_allocations_counter: InternalCounter,
645648
}
646649

@@ -828,6 +831,7 @@ impl TextureView {
828831
#[derive(Debug)]
829832
pub struct Sampler {
830833
raw: vk::Sampler,
834+
create_info: vk::SamplerCreateInfo<'static>,
831835
}
832836

833837
impl crate::DynSampler for Sampler {}

wgpu-hal/src/vulkan/sampler.rs

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
//! Sampler cache for Vulkan backend.
2+
//!
3+
//! Nearly identical to the DX12 sampler cache, without descriptor heap management.
4+
5+
use std::collections::{hash_map::Entry, HashMap};
6+
7+
use ash::vk;
8+
use ordered_float::OrderedFloat;
9+
10+
/// If the allowed sampler count is above this value, the sampler cache is disabled.
11+
const ENABLE_SAMPLER_CACHE_CUTOFF: u32 = 1 << 20;
12+
13+
/// [`vk::SamplerCreateInfo`] is not hashable, so we wrap it in a newtype that is.
14+
///
15+
/// We use [`OrderedFloat`] to allow for floating point values to be compared and
16+
/// hashed in a defined way.
17+
#[derive(Copy, Clone)]
18+
struct HashableSamplerCreateInfo(vk::SamplerCreateInfo<'static>);
19+
20+
impl PartialEq for HashableSamplerCreateInfo {
21+
fn eq(&self, other: &Self) -> bool {
22+
self.0.flags == other.0.flags
23+
&& self.0.mag_filter == other.0.mag_filter
24+
&& self.0.min_filter == other.0.min_filter
25+
&& self.0.mipmap_mode == other.0.mipmap_mode
26+
&& self.0.address_mode_u == other.0.address_mode_u
27+
&& self.0.address_mode_v == other.0.address_mode_v
28+
&& self.0.address_mode_w == other.0.address_mode_w
29+
&& OrderedFloat(self.0.mip_lod_bias) == OrderedFloat(other.0.mip_lod_bias)
30+
&& self.0.anisotropy_enable == other.0.anisotropy_enable
31+
&& OrderedFloat(self.0.max_anisotropy) == OrderedFloat(other.0.max_anisotropy)
32+
&& self.0.compare_enable == other.0.compare_enable
33+
&& self.0.compare_op == other.0.compare_op
34+
&& OrderedFloat(self.0.min_lod) == OrderedFloat(other.0.min_lod)
35+
&& OrderedFloat(self.0.max_lod) == OrderedFloat(other.0.max_lod)
36+
&& self.0.border_color == other.0.border_color
37+
&& self.0.unnormalized_coordinates == other.0.unnormalized_coordinates
38+
}
39+
}
40+
41+
impl Eq for HashableSamplerCreateInfo {}
42+
43+
impl std::hash::Hash for HashableSamplerCreateInfo {
44+
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
45+
self.0.flags.hash(state);
46+
self.0.mag_filter.hash(state);
47+
self.0.min_filter.hash(state);
48+
self.0.mipmap_mode.hash(state);
49+
self.0.address_mode_u.hash(state);
50+
self.0.address_mode_v.hash(state);
51+
self.0.address_mode_w.hash(state);
52+
OrderedFloat(self.0.mip_lod_bias).hash(state);
53+
self.0.anisotropy_enable.hash(state);
54+
OrderedFloat(self.0.max_anisotropy).hash(state);
55+
self.0.compare_enable.hash(state);
56+
self.0.compare_op.hash(state);
57+
OrderedFloat(self.0.min_lod).hash(state);
58+
OrderedFloat(self.0.max_lod).hash(state);
59+
self.0.border_color.hash(state);
60+
self.0.unnormalized_coordinates.hash(state);
61+
}
62+
}
63+
64+
/// Entry in the sampler cache.
65+
struct CacheEntry {
66+
sampler: vk::Sampler,
67+
ref_count: u32,
68+
}
69+
70+
/// Global sampler cache.
71+
///
72+
/// As some devices have a low limit (4000) on the number of unique samplers that can be created,
73+
/// we need to cache samplers to avoid running out if people eagerly create duplicate samplers.
74+
pub(crate) struct SamplerCache {
75+
/// Mapping from the sampler description to sampler and reference count.
76+
samplers: HashMap<HashableSamplerCreateInfo, CacheEntry>,
77+
/// Maximum number of unique samplers that can be created.
78+
total_capacity: u32,
79+
/// If true, the sampler cache is disabled and all samplers are created on demand.
80+
passthrough: bool,
81+
}
82+
83+
impl SamplerCache {
84+
pub fn new(total_capacity: u32) -> Self {
85+
let passthrough = total_capacity >= ENABLE_SAMPLER_CACHE_CUTOFF;
86+
Self {
87+
samplers: HashMap::new(),
88+
total_capacity,
89+
passthrough,
90+
}
91+
}
92+
93+
/// Create a sampler, or return an existing one if it already exists.
94+
///
95+
/// If the sampler already exists, the reference count is incremented.
96+
///
97+
/// If the sampler does not exist, a new sampler is created and inserted into the cache.
98+
///
99+
/// If the cache is full, an error is returned.
100+
pub fn create_sampler(
101+
&mut self,
102+
device: &ash::Device,
103+
create_info: vk::SamplerCreateInfo<'static>,
104+
) -> Result<vk::Sampler, crate::DeviceError> {
105+
if self.passthrough {
106+
return unsafe { device.create_sampler(&create_info, None) }
107+
.map_err(super::map_host_device_oom_and_ioca_err);
108+
};
109+
110+
// Get the number of used samplers. Needs to be done before to appease the borrow checker.
111+
let used_samplers = self.samplers.len();
112+
113+
match self.samplers.entry(HashableSamplerCreateInfo(create_info)) {
114+
Entry::Occupied(occupied_entry) => {
115+
// We have found a match, so increment the refcount and return the index.
116+
let value = occupied_entry.into_mut();
117+
value.ref_count += 1;
118+
Ok(value.sampler)
119+
}
120+
Entry::Vacant(vacant_entry) => {
121+
// We need to create a new sampler.
122+
123+
// We need to check if we can create more samplers.
124+
if used_samplers >= self.total_capacity as usize {
125+
log::error!("There is no more room in the global sampler heap for more unique samplers. Your device supports a maximum of {} unique samplers.", self.samplers.len());
126+
return Err(crate::DeviceError::OutOfMemory);
127+
}
128+
129+
// Create the sampler.
130+
let sampler = unsafe { device.create_sampler(&create_info, None) }
131+
.map_err(super::map_host_device_oom_and_ioca_err)?;
132+
133+
// Insert the new sampler into the mapping.
134+
vacant_entry.insert(CacheEntry {
135+
sampler,
136+
ref_count: 1,
137+
});
138+
139+
Ok(sampler)
140+
}
141+
}
142+
}
143+
144+
/// Decrease the reference count of a sampler and destroy it if the reference count reaches 0.
145+
///
146+
/// The provided sampler is checked against the sampler in the cache to ensure there is no clerical error.
147+
pub fn destroy_sampler(
148+
&mut self,
149+
device: &ash::Device,
150+
create_info: vk::SamplerCreateInfo<'static>,
151+
provided_sampler: vk::Sampler,
152+
) {
153+
if self.passthrough {
154+
unsafe { device.destroy_sampler(provided_sampler, None) };
155+
return;
156+
};
157+
158+
let Entry::Occupied(mut hash_map_entry) =
159+
self.samplers.entry(HashableSamplerCreateInfo(create_info))
160+
else {
161+
log::error!("Trying to destroy a sampler that does not exist.");
162+
return;
163+
};
164+
let cache_entry = hash_map_entry.get_mut();
165+
166+
assert_eq!(
167+
cache_entry.sampler, provided_sampler,
168+
"Provided sampler does not match the sampler in the cache."
169+
);
170+
171+
cache_entry.ref_count -= 1;
172+
173+
if cache_entry.ref_count == 0 {
174+
unsafe { device.destroy_sampler(cache_entry.sampler, None) };
175+
hash_map_entry.remove();
176+
}
177+
}
178+
}

0 commit comments

Comments
 (0)