Skip to content

Fix require components depth #19976

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 8 additions & 40 deletions crates/bevy_ecs/macros/src/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,42 +236,20 @@ pub fn derive_component(input: TokenStream) -> TokenStream {
.push(parse_quote! { Self: Send + Sync + 'static });

let requires = &attrs.requires;
let mut register_required = Vec::with_capacity(attrs.requires.iter().len());
let mut register_recursive_requires = Vec::with_capacity(attrs.requires.iter().len());
let mut register_required = Vec::with_capacity(requires.as_ref().map_or(0, Punctuated::len));
if let Some(requires) = requires {
for require in requires {
let ident = &require.path;
register_recursive_requires.push(quote! {
<#ident as #bevy_ecs_path::component::Component>::register_required_components(
requiree,
components,
let constructor = match &require.func {
Some(func) => quote!(|| { let x: #ident = (#func)().into(); x }),
None => quote!(<#ident as Default>::default),
};
register_required.push(quote! {
components.register_required_components_manual::<Self, #ident>(
required_components,
inheritance_depth + 1,
recursion_check_stack
#constructor,
);
});
match &require.func {
Some(func) => {
register_required.push(quote! {
components.register_required_components_manual::<Self, #ident>(
required_components,
|| { let x: #ident = (#func)().into(); x },
inheritance_depth,
recursion_check_stack
);
});
}
None => {
register_required.push(quote! {
components.register_required_components_manual::<Self, #ident>(
required_components,
<#ident as Default>::default,
inheritance_depth,
recursion_check_stack
);
});
}
}
}
}
let struct_name = &ast.ident;
Expand Down Expand Up @@ -304,26 +282,16 @@ pub fn derive_component(input: TokenStream) -> TokenStream {
)
};

// This puts `register_required` before `register_recursive_requires` to ensure that the constructors of _all_ top
// level components are initialized first, giving them precedence over recursively defined constructors for the same component type
TokenStream::from(quote! {
#required_component_docs
impl #impl_generics #bevy_ecs_path::component::Component for #struct_name #type_generics #where_clause {
const STORAGE_TYPE: #bevy_ecs_path::component::StorageType = #storage;
type Mutability = #mutable_type;
fn register_required_components(
requiree: #bevy_ecs_path::component::ComponentId,
components: &mut #bevy_ecs_path::component::ComponentsRegistrator,
required_components: &mut #bevy_ecs_path::component::RequiredComponents,
inheritance_depth: u16,
recursion_check_stack: &mut #bevy_ecs_path::__macro_exports::Vec<#bevy_ecs_path::component::ComponentId>
) {
#bevy_ecs_path::component::enforce_no_required_components_recursion(components, recursion_check_stack);
let self_id = components.register_component::<Self>();
recursion_check_stack.push(self_id);
#(#register_required)*
#(#register_recursive_requires)*
recursion_check_stack.pop();
}

#on_add
Expand Down
9 changes: 1 addition & 8 deletions crates/bevy_ecs/src/bundle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,14 +283,7 @@ unsafe impl<C: Component> Bundle for C {
components: &mut ComponentsRegistrator,
required_components: &mut RequiredComponents,
) {
let component_id = components.register_component::<C>();
<C as Component>::register_required_components(
component_id,
components,
required_components,
0,
&mut Vec::new(),
);
<C as Component>::register_required_components(components, required_components);
}

fn get_component_ids(components: &Components, ids: &mut impl FnMut(Option<ComponentId>)) {
Expand Down
118 changes: 45 additions & 73 deletions crates/bevy_ecs/src/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -538,11 +538,8 @@ pub trait Component: Send + Sync + 'static {

/// Registers required components.
fn register_required_components(
_component_id: ComponentId,
_components: &mut ComponentsRegistrator,
_required_components: &mut RequiredComponents,
_inheritance_depth: u16,
_recursion_check_stack: &mut Vec<ComponentId>,
) {
}

Expand Down Expand Up @@ -1323,7 +1320,7 @@ impl<'w> ComponentsQueuedRegistrator<'w> {
// SAFETY: We just checked that this is not currently registered or queued, and if it was registered since, this would have been dropped from the queue.
#[expect(unused_unsafe, reason = "More precise to specify.")]
unsafe {
registrator.register_component_unchecked::<T>(&mut Vec::new(), id);
registrator.register_component_unchecked::<T>(id);
}
},
)
Expand Down Expand Up @@ -1442,6 +1439,7 @@ impl<'w> ComponentsQueuedRegistrator<'w> {
pub struct ComponentsRegistrator<'w> {
components: &'w mut Components,
ids: &'w mut ComponentIds,
recursion_check_stack: Vec<ComponentId>,
}

impl Deref for ComponentsRegistrator<'_> {
Expand All @@ -1466,7 +1464,11 @@ impl<'w> ComponentsRegistrator<'w> {
/// The [`Components`] and [`ComponentIds`] must match.
/// For example, they must be from the same world.
pub unsafe fn new(components: &'w mut Components, ids: &'w mut ComponentIds) -> Self {
Self { components, ids }
Self {
components,
ids,
recursion_check_stack: Vec::new(),
}
}

/// Converts this [`ComponentsRegistrator`] into a [`ComponentsQueuedRegistrator`].
Expand Down Expand Up @@ -1555,15 +1557,6 @@ impl<'w> ComponentsRegistrator<'w> {
/// * [`ComponentsRegistrator::register_component_with_descriptor()`]
#[inline]
pub fn register_component<T: Component>(&mut self) -> ComponentId {
self.register_component_checked::<T>(&mut Vec::new())
}

/// Same as [`Self::register_component_unchecked`] but keeps a checks for safety.
#[inline]
fn register_component_checked<T: Component>(
&mut self,
recursion_check_stack: &mut Vec<ComponentId>,
) -> ComponentId {
let type_id = TypeId::of::<T>();
if let Some(id) = self.indices.get(&type_id) {
return *id;
Expand All @@ -1585,7 +1578,7 @@ impl<'w> ComponentsRegistrator<'w> {
let id = self.ids.next_mut();
// SAFETY: The component is not currently registered, and the id is fresh.
unsafe {
self.register_component_unchecked::<T>(recursion_check_stack, id);
self.register_component_unchecked::<T>(id);
}
id
}
Expand All @@ -1594,11 +1587,7 @@ impl<'w> ComponentsRegistrator<'w> {
///
/// Neither this component, nor its id may be registered or queued. This must be a new registration.
#[inline]
unsafe fn register_component_unchecked<T: Component>(
&mut self,
recursion_check_stack: &mut Vec<ComponentId>,
id: ComponentId,
) {
unsafe fn register_component_unchecked<T: Component>(&mut self, id: ComponentId) {
// SAFETY: ensured by caller.
unsafe {
self.register_component_inner(id, ComponentDescriptor::new::<T>());
Expand All @@ -1607,14 +1596,11 @@ impl<'w> ComponentsRegistrator<'w> {
let prev = self.indices.insert(type_id, id);
debug_assert!(prev.is_none());

self.recursion_check_stack.push(id);
let mut required_components = RequiredComponents::default();
T::register_required_components(
id,
self,
&mut required_components,
0,
recursion_check_stack,
);
T::register_required_components(self, &mut required_components);
self.recursion_check_stack.pop();

// SAFETY: we just inserted it in `register_component_inner`
let info = unsafe {
&mut self
Expand Down Expand Up @@ -1661,13 +1647,6 @@ impl<'w> ComponentsRegistrator<'w> {
/// Registers the given component `R` and [required components] inherited from it as required by `T`,
/// and adds `T` to their lists of requirees.
///
/// The given `inheritance_depth` determines how many levels of inheritance deep the requirement is.
/// A direct requirement has a depth of `0`, and each level of inheritance increases the depth by `1`.
/// Lower depths are more specific requirements, and can override existing less specific registrations.
///
/// The `recursion_check_stack` allows checking whether this component tried to register itself as its
/// own (indirect) required component.
///
/// This method does *not* register any components as required by components that require `T`.
///
/// Only use this method if you know what you are doing. In most cases, you should instead use [`World::register_required_components`],
Expand All @@ -1679,11 +1658,15 @@ impl<'w> ComponentsRegistrator<'w> {
&mut self,
required_components: &mut RequiredComponents,
constructor: fn() -> R,
inheritance_depth: u16,
recursion_check_stack: &mut Vec<ComponentId>,
) {
let requiree = self.register_component_checked::<T>(recursion_check_stack);
let required = self.register_component_checked::<R>(recursion_check_stack);
let requiree = self.register_component::<T>();
let required = self.register_component::<R>();

enforce_no_required_components_recursion(
required,
self.components,
&self.recursion_check_stack,
);

// SAFETY: We just created the components.
unsafe {
Expand All @@ -1692,7 +1675,6 @@ impl<'w> ComponentsRegistrator<'w> {
required,
required_components,
constructor,
inheritance_depth,
);
}
}
Expand Down Expand Up @@ -2132,10 +2114,6 @@ impl Components {
/// Registers the given component `R` and [required components] inherited from it as required by `T`,
/// and adds `T` to their lists of requirees.
///
/// The given `inheritance_depth` determines how many levels of inheritance deep the requirement is.
/// A direct requirement has a depth of `0`, and each level of inheritance increases the depth by `1`.
/// Lower depths are more specific requirements, and can override existing less specific registrations.
///
/// This method does *not* register any components as required by components that require `T`.
///
/// [required component]: Component#required-components
Expand All @@ -2149,15 +2127,14 @@ impl Components {
required: ComponentId,
required_components: &mut RequiredComponents,
constructor: fn() -> R,
inheritance_depth: u16,
) {
// Components cannot require themselves.
if required == requiree {
return;
}

// Register the required component `R` for the requiree.
required_components.register_by_id(required, constructor, inheritance_depth);
required_components.register_by_id(required, constructor, 0);

// Add the requiree to the list of components that require `R`.
// SAFETY: The caller ensures that the component ID is valid.
Expand Down Expand Up @@ -2852,37 +2829,32 @@ impl RequiredComponents {
}
}

// NOTE: This should maybe be private, but it is currently public so that `bevy_ecs_macros` can use it.
// This exists as a standalone function instead of being inlined into the component derive macro so as
// to reduce the amount of generated code.
#[doc(hidden)]
pub fn enforce_no_required_components_recursion(
fn enforce_no_required_components_recursion(
requiree: ComponentId,
components: &Components,
recursion_check_stack: &[ComponentId],
) {
if let Some((&requiree, check)) = recursion_check_stack.split_last() {
if let Some(direct_recursion) = check
.iter()
.position(|&id| id == requiree)
.map(|index| index == check.len() - 1)
{
panic!(
"Recursive required components detected: {}\nhelp: {}",
recursion_check_stack
.iter()
.map(|id| format!("{}", components.get_name(*id).unwrap().shortname()))
.collect::<Vec<_>>()
.join(" → "),
if direct_recursion {
format!(
"Remove require({}).",
components.get_name(requiree).unwrap().shortname()
)
} else {
"If this is intentional, consider merging the components.".into()
}
);
}
if let Some(direct_recursion) = recursion_check_stack
.iter()
.position(|&id| id == requiree)
.map(|index| index == recursion_check_stack.len() - 1)
{
panic!(
"Recursive required components detected: {}\nhelp: {}",
recursion_check_stack
.iter()
.map(|id| format!("{}", components.get_name(*id).unwrap().shortname()))
.collect::<Vec<_>>()
.join(" → "),
if direct_recursion {
format!(
"Remove require({}).",
components.get_name(requiree).unwrap().shortname()
)
} else {
"If this is intentional, consider merging the components.".into()
}
);
}
}

Expand Down
31 changes: 19 additions & 12 deletions crates/bevy_ecs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2479,16 +2479,16 @@ mod tests {
//
// Requirements with `require` attribute:
//
// A -> B -> C
// 0 1
// A -> B -> C -> D
// 0 1 2
//
// Runtime requirements:
//
// X -> A -> B -> C
// 0 1 2
//
// X -> Y -> Z -> B -> C
// X -> A -> B -> C -> D
// 0 1 2 3
//
// X -> Y -> Z -> B -> C -> D
// 0 1 2 3 4

#[derive(Component, Default)]
#[require(B)]
Expand All @@ -2499,8 +2499,12 @@ mod tests {
struct B;

#[derive(Component, Default)]
#[require(D)]
struct C;

#[derive(Component, Default)]
struct D;

#[derive(Component, Default)]
struct X;

Expand All @@ -2515,6 +2519,7 @@ mod tests {
let a = world.register_component::<A>();
let b = world.register_component::<B>();
let c = world.register_component::<C>();
let d = world.register_component::<D>();
let y = world.register_component::<Y>();
let z = world.register_component::<Z>();

Expand All @@ -2528,6 +2533,7 @@ mod tests {
let required_a = world.get_required_components::<A>().unwrap();
let required_b = world.get_required_components::<B>().unwrap();
let required_c = world.get_required_components::<C>().unwrap();
let required_d = world.get_required_components::<D>().unwrap();
let required_x = world.get_required_components::<X>().unwrap();
let required_y = world.get_required_components::<Y>().unwrap();
let required_z = world.get_required_components::<Z>().unwrap();
Expand All @@ -2545,15 +2551,16 @@ mod tests {
}

// Check that the inheritance depths are correct for each component.
assert_eq!(to_vec(required_a), vec![(b, 0), (c, 1)]);
assert_eq!(to_vec(required_b), vec![(c, 0)]);
assert_eq!(to_vec(required_c), vec![]);
assert_eq!(to_vec(required_a), vec![(b, 0), (c, 1), (d, 2)]);
assert_eq!(to_vec(required_b), vec![(c, 0), (d, 1)]);
assert_eq!(to_vec(required_c), vec![(d, 0)]);
assert_eq!(to_vec(required_d), vec![]);
assert_eq!(
to_vec(required_x),
vec![(a, 0), (b, 1), (c, 2), (y, 0), (z, 1)]
vec![(a, 0), (b, 1), (c, 2), (d, 3), (y, 0), (z, 1)]
);
assert_eq!(to_vec(required_y), vec![(b, 1), (c, 2), (z, 0)]);
assert_eq!(to_vec(required_z), vec![(b, 0), (c, 1)]);
assert_eq!(to_vec(required_y), vec![(b, 1), (c, 2), (d, 3), (z, 0)]);
assert_eq!(to_vec(required_z), vec![(b, 0), (c, 1), (d, 2)]);
}

#[test]
Expand Down
Loading