Skip to content

Improve spec_v2 patterns #20114

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
122 changes: 13 additions & 109 deletions crates/bevy_render/macros/src/specializer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ const SPECIALIZE_ALL_IDENT: &str = "all";
const KEY_ATTR_IDENT: &str = "key";
const KEY_DEFAULT_IDENT: &str = "default";

const BASE_DESCRIPTOR_ATTR_IDENT: &str = "base_descriptor";

enum SpecializeImplTargets {
All,
Specific(Vec<Path>),
Expand Down Expand Up @@ -87,7 +85,6 @@ struct FieldInfo {
ty: Type,
member: Member,
key: Key,
use_base_descriptor: bool,
}

impl FieldInfo {
Expand Down Expand Up @@ -117,15 +114,6 @@ impl FieldInfo {
parse_quote!(#ty: #specialize_path::Specializer<#target_path>)
}
}

fn get_base_descriptor_predicate(
&self,
specialize_path: &Path,
target_path: &Path,
) -> WherePredicate {
let ty = &self.ty;
parse_quote!(#ty: #specialize_path::GetBaseDescriptor<#target_path>)
}
}

fn get_field_info(
Expand All @@ -151,12 +139,8 @@ fn get_field_info(

let mut use_key_field = true;
let mut key = Key::Index(key_index);
let mut use_base_descriptor = false;
for attr in &field.attrs {
match &attr.meta {
Meta::Path(path) if path.is_ident(&BASE_DESCRIPTOR_ATTR_IDENT) => {
use_base_descriptor = true;
}
Meta::List(MetaList { path, tokens, .. }) if path.is_ident(&KEY_ATTR_IDENT) => {
let owned_tokens = tokens.clone().into();
let Ok(parsed_key) = syn::parse::<Key>(owned_tokens) else {
Expand Down Expand Up @@ -190,7 +174,6 @@ fn get_field_info(
ty: field_ty,
member: field_member,
key,
use_base_descriptor,
});
}

Expand Down Expand Up @@ -261,41 +244,18 @@ pub fn impl_specializer(input: TokenStream) -> TokenStream {
})
.collect();

let base_descriptor_fields = field_info
.iter()
.filter(|field| field.use_base_descriptor)
.collect::<Vec<_>>();

if base_descriptor_fields.len() > 1 {
return syn::Error::new(
Span::call_site(),
"Too many #[base_descriptor] attributes found. It must be present on exactly one field",
)
.into_compile_error()
.into();
}

let base_descriptor_field = base_descriptor_fields.first().copied();

match targets {
SpecializeImplTargets::All => {
let specialize_impl = impl_specialize_all(
&specialize_path,
&ecs_path,
&ast,
&field_info,
&key_patterns,
&key_tuple_idents,
);
let get_base_descriptor_impl = base_descriptor_field
.map(|field_info| impl_get_base_descriptor_all(&specialize_path, &ast, field_info))
.unwrap_or_default();
[specialize_impl, get_base_descriptor_impl]
.into_iter()
.collect()
}
SpecializeImplTargets::Specific(targets) => {
let specialize_impls = targets.iter().map(|target| {
SpecializeImplTargets::All => impl_specialize_all(
&specialize_path,
&ecs_path,
&ast,
&field_info,
&key_patterns,
&key_tuple_idents,
),
SpecializeImplTargets::Specific(targets) => targets
.iter()
.map(|target| {
impl_specialize_specific(
&specialize_path,
&ecs_path,
Expand All @@ -305,14 +265,8 @@ pub fn impl_specializer(input: TokenStream) -> TokenStream {
&key_patterns,
&key_tuple_idents,
)
});
let get_base_descriptor_impls = targets.iter().filter_map(|target| {
base_descriptor_field.map(|field_info| {
impl_get_base_descriptor_specific(&specialize_path, &ast, field_info, target)
})
});
specialize_impls.chain(get_base_descriptor_impls).collect()
}
})
.collect(),
}
}

Expand Down Expand Up @@ -406,56 +360,6 @@ fn impl_specialize_specific(
})
}

fn impl_get_base_descriptor_specific(
specialize_path: &Path,
ast: &DeriveInput,
base_descriptor_field_info: &FieldInfo,
target_path: &Path,
) -> TokenStream {
let struct_name = &ast.ident;
let (impl_generics, type_generics, where_clause) = &ast.generics.split_for_impl();
let field_ty = &base_descriptor_field_info.ty;
let field_member = &base_descriptor_field_info.member;
TokenStream::from(quote!(
impl #impl_generics #specialize_path::GetBaseDescriptor<#target_path> for #struct_name #type_generics #where_clause {
fn get_base_descriptor(&self) -> <#target_path as #specialize_path::Specializable>::Descriptor {
<#field_ty as #specialize_path::GetBaseDescriptor<#target_path>>::get_base_descriptor(&self.#field_member)
}
}
))
}

fn impl_get_base_descriptor_all(
specialize_path: &Path,
ast: &DeriveInput,
base_descriptor_field_info: &FieldInfo,
) -> TokenStream {
let target_path = Path::from(format_ident!("T"));
let struct_name = &ast.ident;
let mut generics = ast.generics.clone();
generics.params.insert(
0,
parse_quote!(#target_path: #specialize_path::Specializable),
);

let where_clause = generics.make_where_clause();
where_clause.predicates.push(
base_descriptor_field_info.get_base_descriptor_predicate(specialize_path, &target_path),
);

let (_, type_generics, _) = ast.generics.split_for_impl();
let (impl_generics, _, where_clause) = &generics.split_for_impl();
let field_ty = &base_descriptor_field_info.ty;
let field_member = &base_descriptor_field_info.member;
TokenStream::from(quote! {
impl #impl_generics #specialize_path::GetBaseDescriptor<#target_path> for #struct_name #type_generics #where_clause {
fn get_base_descriptor(&self) -> <#target_path as #specialize_path::Specializable>::Descriptor {
<#field_ty as #specialize_path::GetBaseDescriptor<#target_path>>::get_base_descriptor(&self.#field_member)
}
}
})
}

pub fn impl_specializer_key(input: TokenStream) -> TokenStream {
let bevy_render_path: Path = crate::bevy_render_path();
let specialize_path = {
Expand Down
3 changes: 3 additions & 0 deletions crates/bevy_render/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ use render_asset::{
extract_render_asset_bytes_per_frame, reset_render_asset_bytes_per_frame,
RenderAssetBytesPerFrame, RenderAssetBytesPerFrameLimiter,
};
use render_resource::init_empty_bind_group_layout;
use renderer::{RenderAdapter, RenderDevice, RenderQueue};
use settings::RenderResources;
use sync_world::{
Expand Down Expand Up @@ -465,6 +466,8 @@ impl Plugin for RenderPlugin {
Render,
reset_render_asset_bytes_per_frame.in_set(RenderSystems::Cleanup),
);

render_app.add_systems(RenderStartup, init_empty_bind_group_layout);
}

app.register_type::<alpha::AlphaMode>()
Expand Down
20 changes: 19 additions & 1 deletion crates/bevy_render/src/render_resource/bind_group_layout.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use crate::define_atomic_id;
use crate::{define_atomic_id, renderer::RenderDevice};
use bevy_ecs::system::Res;
use bevy_platform::sync::OnceLock;
use bevy_utils::WgpuWrapper;
use core::ops::Deref;

Expand Down Expand Up @@ -62,3 +64,19 @@ impl Deref for BindGroupLayout {
&self.value
}
}

static EMPTY_BIND_GROUP_LAYOUT: OnceLock<BindGroupLayout> = OnceLock::new();

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, so we now use an empty layout in a few places post wgpu 25, but I'm a little wary about this. I guess the set_at api is nicer than just doing a push like we do in some places, but I'm not sure I understand why it should ever be necessary to implicitly fill in the holes in a layout?

Copy link
Contributor Author

@ecoskey ecoskey Jul 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why it should ever be necessary to implicitly fill in the holes in a layout?

Only with the assumption that it'll be properly filled in by other logic, like by composing specializers.

There's probably a better solution than filling_set_at, but since specializers can't really rely on their input having a certain shape (like what the current length of the vec is) both push and direct indexing cause problems without bounds checks or padding.

pub(crate) fn init_empty_bind_group_layout(render_device: Res<RenderDevice>) {
let layout = render_device.create_bind_group_layout(Some("empty_bind_group_layout"), &[]);
EMPTY_BIND_GROUP_LAYOUT
.set(layout)
.expect("init_empty_bind_group_layout was called more than once");
}

pub fn empty_bind_group_layout() -> BindGroupLayout {
EMPTY_BIND_GROUP_LAYOUT
.get()
.expect("init_empty_bind_group_layout was not called")
.clone()
}
32 changes: 31 additions & 1 deletion crates/bevy_render/src/render_resource/pipeline.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::ShaderDefVal;
use super::{empty_bind_group_layout, ShaderDefVal};
use crate::mesh::VertexBufferLayout;
use crate::{
define_atomic_id,
Expand All @@ -7,7 +7,9 @@ use crate::{
use alloc::borrow::Cow;
use bevy_asset::Handle;
use bevy_utils::WgpuWrapper;
use core::iter;
use core::ops::Deref;
use thiserror::Error;
use wgpu::{
ColorTargetState, DepthStencilState, MultisampleState, PrimitiveState, PushConstantRange,
};
Expand Down Expand Up @@ -112,6 +114,20 @@ pub struct RenderPipelineDescriptor {
pub zero_initialize_workgroup_memory: bool,
}

#[derive(Copy, Clone, Debug, Error)]
#[error("RenderPipelineDescriptor has no FragmentState configured")]
pub struct NoFragmentStateError;

impl RenderPipelineDescriptor {
pub fn fragment_mut(&mut self) -> Result<&mut FragmentState, NoFragmentStateError> {
self.fragment.as_mut().ok_or(NoFragmentStateError)
}

pub fn set_layout(&mut self, index: usize, layout: BindGroupLayout) {
filling_set_at(&mut self.layout, index, empty_bind_group_layout(), layout);
}
}

#[derive(Clone, Debug, Eq, PartialEq, Default)]
pub struct VertexState {
/// The compiled shader module for this stage.
Expand All @@ -137,6 +153,12 @@ pub struct FragmentState {
pub targets: Vec<Option<ColorTargetState>>,
}

impl FragmentState {
pub fn set_target(&mut self, index: usize, target: ColorTargetState) {
filling_set_at(&mut self.targets, index, None, Some(target));
}
}

/// Describes a compute pipeline.
#[derive(Clone, Debug, PartialEq, Eq, Default)]
pub struct ComputePipelineDescriptor {
Expand All @@ -153,3 +175,11 @@ pub struct ComputePipelineDescriptor {
/// If this is false, reading from workgroup variables before writing to them will result in garbage values.
pub zero_initialize_workgroup_memory: bool,
}

// utility function to set a value at the specified index, extending with
// a filler value if the index is out of bounds.
fn filling_set_at<T: Clone>(vec: &mut Vec<T>, index: usize, filler: T, value: T) {
let num_to_fill = (index + 1).saturating_sub(vec.len());
vec.extend(iter::repeat_n(filler, num_to_fill));
vec[index] = value;
}
Loading
Loading