Skip to content

Use bit ops instead of integer modulo and divide in shaders #19994

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ fn downsample_depth_first(
@builtin(workgroup_id) workgroup_id: vec3u,
@builtin(local_invocation_index) local_invocation_index: u32,
) {
let sub_xy = remap_for_wave_reduction(local_invocation_index % 64u);
let x = sub_xy.x + 8u * ((local_invocation_index >> 6u) % 2u);
let sub_xy = remap_for_wave_reduction(local_invocation_index & 63u);
let x = sub_xy.x + 8u * ((local_invocation_index >> 6u) & 1u);
let y = sub_xy.y + 8u * (local_invocation_index >> 7u);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • 8u is << 3u. Or does naga deal with that properly?


downsample_mips_0_and_1(x, y, workgroup_id.xy, local_invocation_index);
Expand All @@ -54,8 +54,8 @@ fn downsample_depth_first(
@compute
@workgroup_size(256, 1, 1)
fn downsample_depth_second(@builtin(local_invocation_index) local_invocation_index: u32) {
let sub_xy = remap_for_wave_reduction(local_invocation_index % 64u);
let x = sub_xy.x + 8u * ((local_invocation_index >> 6u) % 2u);
let sub_xy = remap_for_wave_reduction(local_invocation_index & 63u);
let x = sub_xy.x + 8u * ((local_invocation_index >> 6u) & 1u);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • 8u is << 3u

let y = sub_xy.y + 8u * (local_invocation_index >> 7u);

downsample_mips_6_and_7(x, y);
Expand Down Expand Up @@ -99,8 +99,8 @@ fn downsample_mips_0_and_1(x: u32, y: u32, workgroup_id: vec2u, local_invocation
intermediate_memory[x * 2u + 1u][y * 2u + 1u],
));
pix = (workgroup_id * 16u) + vec2(
x + (i % 2u) * 8u,
y + (i / 2u) * 8u,
x + (i & 1u) * 8u,
y + (i >> 1u) * 8u,
);
textureStore(mip_2, pix, vec4(v[i]));
}
Expand Down Expand Up @@ -142,7 +142,7 @@ fn downsample_mip_2(x: u32, y: u32, workgroup_id: vec2u, local_invocation_index:
intermediate_memory[x * 2u + 1u][y * 2u + 1u],
));
textureStore(mip_3, (workgroup_id * 8u) + vec2(x, y), vec4(v));
intermediate_memory[x * 2u + y % 2u][y * 2u] = v;
intermediate_memory[x * 2u + (y & 1u)][y * 2u] = v;
}
}

Expand Down Expand Up @@ -241,7 +241,7 @@ fn downsample_mip_8(x: u32, y: u32, local_invocation_index: u32) {
intermediate_memory[x * 2u + 1u][y * 2u + 1u],
));
textureStore(mip_9, vec2(x, y), vec4(v));
intermediate_memory[x * 2u + y % 2u][y * 2u] = v;
intermediate_memory[x * 2u + (y & 1u)][y * 2u] = v;
}
}

Expand Down
12 changes: 6 additions & 6 deletions crates/bevy_pbr/src/meshlet/meshlet_bindings.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,8 @@ var<push_constant> constants: Constants;

// TODO: Load only twice, instead of 3x in cases where you load 3 indices per thread?
fn get_meshlet_vertex_id(index_id: u32) -> u32 {
let packed_index = meshlet_indices[index_id / 4u];
let bit_offset = (index_id % 4u) * 8u;
let packed_index = meshlet_indices[index_id >> 2u];
let bit_offset = (index_id & 3u) * 8u;
return extractBits(packed_index, bit_offset, 8u);
}

Expand All @@ -219,7 +219,7 @@ fn get_meshlet_vertex_position(meshlet: ptr<function, Meshlet>, vertex_id: u32)
// Read each vertex channel from the bitstream
var vertex_position_packed = vec3(0u);
for (var i = 0u; i < 3u; i++) {
let lower_word_index = start_bit / 32u;
let lower_word_index = start_bit >> 5u;
let lower_word_bit_offset = start_bit & 31u;
var next_32_bits = meshlet_vertex_positions[lower_word_index] >> lower_word_bit_offset;
if lower_word_bit_offset + bits_per_channel[i] > 32u {
Expand Down Expand Up @@ -256,8 +256,8 @@ fn get_meshlet_vertex_position(meshlet: ptr<function, Meshlet>, vertex_id: u32)

// TODO: Load only twice, instead of 3x in cases where you load 3 indices per thread?
fn get_meshlet_vertex_id(index_id: u32) -> u32 {
let packed_index = meshlet_indices[index_id / 4u];
let bit_offset = (index_id % 4u) * 8u;
let packed_index = meshlet_indices[index_id >> 2u];
let bit_offset = (index_id & 3u) * 8u;
return extractBits(packed_index, bit_offset, 8u);
}

Expand All @@ -271,7 +271,7 @@ fn get_meshlet_vertex_position(meshlet: ptr<function, Meshlet>, vertex_id: u32)
// Read each vertex channel from the bitstream
var vertex_position_packed = vec3(0u);
for (var i = 0u; i < 3u; i++) {
let lower_word_index = start_bit / 32u;
let lower_word_index = start_bit >> 5u;
let lower_word_bit_offset = start_bit & 31u;
var next_32_bits = meshlet_vertex_positions[lower_word_index] >> lower_word_bit_offset;
if lower_word_bit_offset + bits_per_channel[i] > 32u {
Expand Down
2 changes: 1 addition & 1 deletion crates/bevy_pbr/src/meshlet/meshlet_mesh_material.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

@vertex
fn vertex(@builtin(vertex_index) vertex_input: u32) -> @builtin(position) vec4<f32> {
let vertex_index = vertex_input % 3u;
let material_id = vertex_input / 3u;
let vertex_index = vertex_input - material_id * 3u;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might be worse. I thought compilers can see consecutive % and / and combine the instructions into one thing(?)

Copy link
Contributor Author

@atlv24 atlv24 Jul 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not when its split into a function call that looks like

int naga_mod(int lhs, int rhs) {
    int divisor = ((lhs == int(-2147483647 - 1) & rhs == -1) | (rhs == 0)) ? 1 : rhs;
    return lhs - (lhs / divisor) * divisor;
}
// ...
let vertex_index = naga_mod(vertex_input, 3u);
let material_id = vertex_input / 3u;

let material_depth = f32(material_id) / 65535.0;
let uv = vec2<f32>(vec2(vertex_index >> 1u, vertex_index & 1u)) * 2.0;
return vec4(uv_to_ndc(uv), material_depth, 1.0);
Expand Down
6 changes: 3 additions & 3 deletions crates/bevy_pbr/src/render/clustered_forward.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,9 @@ fn get_clusterable_object_id(index: u32) -> u32 {
// The index is correct but in clusterable_object_index_lists we pack 4 u8s into a u32
// This means the index into clusterable_object_index_lists is index / 4
let indices = bindings::clusterable_object_index_lists.data[index >> 4u][(index >> 2u) &
((1u << 2u) - 1u)];
3u];
// And index % 4 gives the sub-index of the u8 within the u32 so we shift by 8 * sub-index
return (indices >> (8u * (index & ((1u << 2u) - 1u)))) & ((1u << 8u) - 1u);
return (indices >> (8u * (index & 3u))) & 255u;
#endif
}

Expand All @@ -151,7 +151,7 @@ fn cluster_debug_visualization(
var z_slice: u32 = view_z_to_z_slice(view_z, is_orthographic);
// A hack to make the colors alternate a bit more
if (z_slice & 1u) == 1u {
z_slice = z_slice + bindings::lights.cluster_dimensions.z / 2u;
z_slice = z_slice + bindings::lights.cluster_dimensions.z >> 1u;
}
let slice_color_hsv = vec3(
f32(z_slice) / f32(bindings::lights.cluster_dimensions.z + 1u) * PI_2,
Expand Down
2 changes: 1 addition & 1 deletion crates/bevy_pbr/src/render/mesh_preprocess.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ fn main(@builtin(global_invocation_id) global_invocation_id: vec3<u32>) {
// this frame.
let output_work_item_index = atomicAdd(&late_preprocess_work_item_indirect_parameters[
push_constants.late_preprocess_work_item_indirect_offset].work_item_count, 1u);
if (output_work_item_index % 64u == 0u) {
if ((output_work_item_index & 63u) == 0u) {
// Our workgroup size is 64, and the indirect parameters for the
// late mesh preprocessing phase are counted in workgroups, so if
// we're the first thread in this workgroup, bump the workgroup
Expand Down
4 changes: 2 additions & 2 deletions crates/bevy_pbr/src/render/morph.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ fn component_texture_coord(vertex_index: u32, component_offset: u32) -> vec2<u32
}
fn weight_at(weight_index: u32) -> f32 {
let i = weight_index;
return morph_weights.weights[i / 4u][i % 4u];
return morph_weights.weights[i >> 2u][i & 3u];
}
fn prev_weight_at(weight_index: u32) -> f32 {
let i = weight_index;
return prev_morph_weights.weights[i / 4u][i % 4u];
return prev_morph_weights.weights[i >> 2u][i & 3u];
}
fn morph_pixel(vertex: u32, component: u32, weight: u32) -> f32 {
let coord = component_texture_coord(vertex, component);
Expand Down
2 changes: 1 addition & 1 deletion crates/bevy_pbr/src/render/pbr_functions.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ fn visibility_range_dither(frag_coord: vec4<f32>, dither: i32) {
}

// Otherwise, check the dither pattern.
let coords = vec2<u32>(floor(frag_coord.xy)) % 4u;
let coords = vec2<u32>(floor(frag_coord.xy)) & 3u;
let threshold = i32((DITHER_THRESHOLD_MAP[coords.y] >> (coords.x * 8)) & 0xff);
if ((dither >= 0 && dither + threshold >= 16) || (dither < 0 && 1 + dither + threshold <= 0)) {
discard;
Expand Down
4 changes: 2 additions & 2 deletions crates/bevy_pbr/src/render/pbr_transmission.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,10 @@ fn fetch_transmissive_background(offset_position: vec2<f32>, frag_coord: vec3<f3
let pixel_checkboard = (
#ifdef TEMPORAL_JITTER
// 0 or 1 on even/odd pixels, alternates every frame
(i32(frag_coord.x) + i32(frag_coord.y) + i32(view_bindings::globals.frame_count)) % 2
(i32(frag_coord.x) + i32(frag_coord.y) + i32(view_bindings::globals.frame_count)) & 1
#else
// 0 or 1 on even/odd pixels
(i32(frag_coord.x) + i32(frag_coord.y)) % 2
(i32(frag_coord.x) + i32(frag_coord.y)) & 1
#endif
);

Expand Down
2 changes: 1 addition & 1 deletion crates/bevy_pbr/src/render/utils.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ fn octahedral_decode_signed(v: vec2<f32>) -> vec3<f32> {

// https://blog.demofox.org/2022/01/01/interleaved-gradient-noise-a-different-kind-of-low-discrepancy-sequence
fn interleaved_gradient_noise(pixel_coordinates: vec2<f32>, frame: u32) -> f32 {
let xy = pixel_coordinates + 5.588238 * f32(frame % 64u);
let xy = pixel_coordinates + 5.588238 * f32(frame & 63u);
return fract(52.9829189 * fract(0.06711056 * xy.x + 0.00583715 * xy.y));
}

Expand Down
12 changes: 6 additions & 6 deletions crates/bevy_pbr/src/ssao/preprocess_depth.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -65,38 +65,38 @@ fn preprocess_depth(@builtin(global_invocation_id) global_id: vec3<u32>, @builti
workgroupBarrier();

// MIP 2 - Weighted average of MIP 1's depth values (per invocation, 4x4 invocations per workgroup)
if all(local_id.xy % vec2<u32>(2u) == vec2<u32>(0u)) {
if all((local_id.xy & vec2<u32>(1u)) == vec2<u32>(0u)) {
let depth0 = previous_mip_depth[local_id.x + 0u][local_id.y + 0u];
let depth1 = previous_mip_depth[local_id.x + 1u][local_id.y + 0u];
let depth2 = previous_mip_depth[local_id.x + 0u][local_id.y + 1u];
let depth3 = previous_mip_depth[local_id.x + 1u][local_id.y + 1u];
let depth_mip2 = weighted_average(depth0, depth1, depth2, depth3);
textureStore(preprocessed_depth_mip2, base_coordinates / 2i, vec4<f32>(depth_mip2, 0.0, 0.0, 0.0));
textureStore(preprocessed_depth_mip2, base_coordinates >> vec2<u32>(1u), vec4<f32>(depth_mip2, 0.0, 0.0, 0.0));
previous_mip_depth[local_id.x][local_id.y] = depth_mip2;
}

workgroupBarrier();

// MIP 3 - Weighted average of MIP 2's depth values (per invocation, 2x2 invocations per workgroup)
if all(local_id.xy % vec2<u32>(4u) == vec2<u32>(0u)) {
if all((local_id.xy & vec2<u32>(3u)) == vec2<u32>(0u)) {
let depth0 = previous_mip_depth[local_id.x + 0u][local_id.y + 0u];
let depth1 = previous_mip_depth[local_id.x + 2u][local_id.y + 0u];
let depth2 = previous_mip_depth[local_id.x + 0u][local_id.y + 2u];
let depth3 = previous_mip_depth[local_id.x + 2u][local_id.y + 2u];
let depth_mip3 = weighted_average(depth0, depth1, depth2, depth3);
textureStore(preprocessed_depth_mip3, base_coordinates / 4i, vec4<f32>(depth_mip3, 0.0, 0.0, 0.0));
textureStore(preprocessed_depth_mip3, base_coordinates >> vec2<u32>(2u), vec4<f32>(depth_mip3, 0.0, 0.0, 0.0));
previous_mip_depth[local_id.x][local_id.y] = depth_mip3;
}

workgroupBarrier();

// MIP 4 - Weighted average of MIP 3's depth values (per invocation, 1 invocation per workgroup)
if all(local_id.xy % vec2<u32>(8u) == vec2<u32>(0u)) {
if all((local_id.xy & vec2<u32>(7u)) == vec2<u32>(0u)) {
let depth0 = previous_mip_depth[local_id.x + 0u][local_id.y + 0u];
let depth1 = previous_mip_depth[local_id.x + 4u][local_id.y + 0u];
let depth2 = previous_mip_depth[local_id.x + 0u][local_id.y + 4u];
let depth3 = previous_mip_depth[local_id.x + 4u][local_id.y + 4u];
let depth_mip4 = weighted_average(depth0, depth1, depth2, depth3);
textureStore(preprocessed_depth_mip4, base_coordinates / 8i, vec4<f32>(depth_mip4, 0.0, 0.0, 0.0));
textureStore(preprocessed_depth_mip4, base_coordinates >> vec2<u32>(3u), vec4<f32>(depth_mip4, 0.0, 0.0, 0.0));
}
}
4 changes: 2 additions & 2 deletions crates/bevy_pbr/src/ssao/ssao.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@
@group(1) @binding(2) var<uniform> view: View;

fn load_noise(pixel_coordinates: vec2<i32>) -> vec2<f32> {
var index = textureLoad(hilbert_index_lut, pixel_coordinates % 64, 0).r;
var index = textureLoad(hilbert_index_lut, pixel_coordinates & vec2<i32>(63), 0).r;

#ifdef TEMPORAL_JITTER
index += 288u * (globals.frame_count % 64u);
index += 288u * (globals.frame_count & 63u);
#endif

// R2 sequence - http://extremelearning.com.au/unreasonable-effectiveness-of-quasirandom-sequences
Expand Down
11 changes: 9 additions & 2 deletions crates/bevy_pbr/src/volumetric_fog/volumetric_fog.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,13 @@ fn henyey_greenstein(neg_LdotV: f32) -> f32 {
return FRAC_4_PI * (1.0 - g * g) / (denom * sqrt(denom));
}

fn simple_wrap_3(index: i32) -> i32 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can index be >=6? If so then this is wrong, if not then I think this function should be named/commented to indicate it is special purpose.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

its only used in this file, and its only called with numbers in range 0-5. I called it simple_wrap_3, not implying its modulo, because it is not an implementation of modulo, just something that works for this specific case

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a comment saying that it only works for values 0 to 5, which is fine for its use in this file?

if (index >= 3) {
return index - 3;
}
return index;
}

@fragment
fn fragment(@builtin(position) position: vec4<f32>) -> @location(0) vec4<f32> {
// Unpack the `volumetric_fog` settings.
Expand Down Expand Up @@ -140,8 +147,8 @@ fn fragment(@builtin(position) position: vec4<f32>) -> @location(0) vec4<f32> {
var end_depth_view = 0.0;
for (var plane_index = 0; plane_index < 3; plane_index += 1) {
let plane = volumetric_fog.far_planes[plane_index];
let other_plane_a = volumetric_fog.far_planes[(plane_index + 1) % 3];
let other_plane_b = volumetric_fog.far_planes[(plane_index + 2) % 3];
let other_plane_a = volumetric_fog.far_planes[simple_wrap_3(plane_index + 1)];
let other_plane_b = volumetric_fog.far_planes[simple_wrap_3(plane_index + 2)];

// Calculate the intersection of the ray and the plane. The ray must
// intersect in front of us (t > 0).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ fn vertex(vertex: Vertex) -> VertexOutput {

out.position = mesh_functions::mesh2d_position_world_to_clip(world_position);
out.uv = vertex.uv;
out.tile_index = vertex.vertex_index / 4u;
out.tile_index = vertex.vertex_index >> 2u;

return out;
}
Expand Down