From c18293261893f675c7a0e1d68aefc19d117ea8eb Mon Sep 17 00:00:00 2001 From: Giacomo Stevanato Date: Sat, 12 Jul 2025 17:26:27 +0200 Subject: [PATCH 1/9] Move recursion_check_stack parameter into ComponentsRegistrator --- crates/bevy_ecs/macros/src/component.rs | 16 ++---- crates/bevy_ecs/src/component/mod.rs | 2 - crates/bevy_ecs/src/component/register.rs | 34 ++++++------- crates/bevy_ecs/src/component/required.rs | 61 ++++++++++------------- 4 files changed, 45 insertions(+), 68 deletions(-) diff --git a/crates/bevy_ecs/macros/src/component.rs b/crates/bevy_ecs/macros/src/component.rs index 1322022581f2f..0265f337fb170 100644 --- a/crates/bevy_ecs/macros/src/component.rs +++ b/crates/bevy_ecs/macros/src/component.rs @@ -246,8 +246,7 @@ pub fn derive_component(input: TokenStream) -> TokenStream { requiree, components, required_components, - inheritance_depth + 1, - recursion_check_stack + inheritance_depth + 1 ); }); match &require.func { @@ -256,8 +255,7 @@ pub fn derive_component(input: TokenStream) -> TokenStream { components.register_required_components_manual::( required_components, || { let x: #ident = (#func)().into(); x }, - inheritance_depth, - recursion_check_stack + inheritance_depth ); }); } @@ -266,8 +264,7 @@ pub fn derive_component(input: TokenStream) -> TokenStream { components.register_required_components_manual::( required_components, <#ident as Default>::default, - inheritance_depth, - recursion_check_stack + inheritance_depth ); }); } @@ -312,18 +309,13 @@ pub fn derive_component(input: TokenStream) -> TokenStream { const STORAGE_TYPE: #bevy_ecs_path::component::StorageType = #storage; type Mutability = #mutable_type; fn register_required_components( - requiree: #bevy_ecs_path::component::ComponentId, + _requiree: #bevy_ecs_path::component::ComponentId, components: &mut #bevy_ecs_path::component::ComponentsRegistrator, required_components: &mut #bevy_ecs_path::component::RequiredComponents, inheritance_depth: u16, - recursion_check_stack: &mut #bevy_ecs_path::__macro_exports::Vec<#bevy_ecs_path::component::ComponentId> ) { - #bevy_ecs_path::component::enforce_no_required_components_recursion(components, recursion_check_stack); - let self_id = components.register_component::(); - recursion_check_stack.push(self_id); #(#register_required)* #(#register_recursive_requires)* - recursion_check_stack.pop(); } #on_add diff --git a/crates/bevy_ecs/src/component/mod.rs b/crates/bevy_ecs/src/component/mod.rs index 8b57829243eed..5c9a01e58f30e 100644 --- a/crates/bevy_ecs/src/component/mod.rs +++ b/crates/bevy_ecs/src/component/mod.rs @@ -18,7 +18,6 @@ use crate::{ system::{Local, SystemParam}, world::{FromWorld, World}, }; -use alloc::vec::Vec; pub use bevy_ecs_macros::Component; use core::{fmt::Debug, marker::PhantomData, ops::Deref}; @@ -528,7 +527,6 @@ pub trait Component: Send + Sync + 'static { _components: &mut ComponentsRegistrator, _required_components: &mut RequiredComponents, _inheritance_depth: u16, - _recursion_check_stack: &mut Vec, ) { } diff --git a/crates/bevy_ecs/src/component/register.rs b/crates/bevy_ecs/src/component/register.rs index bf1720b0053a2..1a50bbbb05137 100644 --- a/crates/bevy_ecs/src/component/register.rs +++ b/crates/bevy_ecs/src/component/register.rs @@ -64,6 +64,7 @@ impl ComponentIds { pub struct ComponentsRegistrator<'w> { components: &'w mut Components, ids: &'w mut ComponentIds, + pub(super) recursion_check_stack: Vec, } impl Deref for ComponentsRegistrator<'_> { @@ -88,7 +89,11 @@ impl<'w> ComponentsRegistrator<'w> { /// The [`Components`] and [`ComponentIds`] must match. /// For example, they must be from the same world. pub unsafe fn new(components: &'w mut Components, ids: &'w mut ComponentIds) -> Self { - Self { components, ids } + Self { + components, + ids, + recursion_check_stack: Vec::new(), + } } /// Converts this [`ComponentsRegistrator`] into a [`ComponentsQueuedRegistrator`]. @@ -177,15 +182,12 @@ impl<'w> ComponentsRegistrator<'w> { /// * [`ComponentsRegistrator::register_component_with_descriptor()`] #[inline] pub fn register_component(&mut self) -> ComponentId { - self.register_component_checked::(&mut Vec::new()) + self.register_component_checked::() } /// Same as [`Self::register_component_unchecked`] but keeps a checks for safety. #[inline] - pub(super) fn register_component_checked( - &mut self, - recursion_check_stack: &mut Vec, - ) -> ComponentId { + pub(super) fn register_component_checked(&mut self) -> ComponentId { let type_id = TypeId::of::(); if let Some(id) = self.indices.get(&type_id) { return *id; @@ -207,7 +209,7 @@ impl<'w> ComponentsRegistrator<'w> { let id = self.ids.next_mut(); // SAFETY: The component is not currently registered, and the id is fresh. unsafe { - self.register_component_unchecked::(recursion_check_stack, id); + self.register_component_unchecked::(id); } id } @@ -216,11 +218,7 @@ impl<'w> ComponentsRegistrator<'w> { /// /// Neither this component, nor its id may be registered or queued. This must be a new registration. #[inline] - unsafe fn register_component_unchecked( - &mut self, - recursion_check_stack: &mut Vec, - id: ComponentId, - ) { + unsafe fn register_component_unchecked(&mut self, id: ComponentId) { // SAFETY: ensured by caller. unsafe { self.register_component_inner(id, ComponentDescriptor::new::()); @@ -229,14 +227,10 @@ impl<'w> ComponentsRegistrator<'w> { let prev = self.indices.insert(type_id, id); debug_assert!(prev.is_none()); + self.recursion_check_stack.push(id); let mut required_components = RequiredComponents::default(); - T::register_required_components( - id, - self, - &mut required_components, - 0, - recursion_check_stack, - ); + T::register_required_components(id, self, &mut required_components, 0); + self.recursion_check_stack.pop(); // SAFETY: we just inserted it in `register_component_inner` let info = unsafe { &mut self @@ -563,7 +557,7 @@ impl<'w> ComponentsQueuedRegistrator<'w> { // SAFETY: We just checked that this is not currently registered or queued, and if it was registered since, this would have been dropped from the queue. #[expect(unused_unsafe, reason = "More precise to specify.")] unsafe { - registrator.register_component_unchecked::(&mut Vec::new(), id); + registrator.register_component_unchecked::(id); } }, ) diff --git a/crates/bevy_ecs/src/component/required.rs b/crates/bevy_ecs/src/component/required.rs index d46b6b61ce131..6cac5025b0494 100644 --- a/crates/bevy_ecs/src/component/required.rs +++ b/crates/bevy_ecs/src/component/required.rs @@ -231,9 +231,6 @@ impl<'w> ComponentsRegistrator<'w> { /// A direct requirement has a depth of `0`, and each level of inheritance increases the depth by `1`. /// Lower depths are more specific requirements, and can override existing less specific registrations. /// - /// The `recursion_check_stack` allows checking whether this component tried to register itself as its - /// own (indirect) required component. - /// /// This method does *not* register any components as required by components that require `T`. /// /// Only use this method if you know what you are doing. In most cases, you should instead use [`World::register_required_components`], @@ -246,10 +243,11 @@ impl<'w> ComponentsRegistrator<'w> { required_components: &mut RequiredComponents, constructor: fn() -> R, inheritance_depth: u16, - recursion_check_stack: &mut Vec, ) { - let requiree = self.register_component_checked::(recursion_check_stack); - let required = self.register_component_checked::(recursion_check_stack); + let requiree = self.register_component_checked::(); + let required = self.register_component_checked::(); + + enforce_no_required_components_recursion(self, &self.recursion_check_stack, required); // SAFETY: We just created the components. unsafe { @@ -501,36 +499,31 @@ impl RequiredComponents { } } -// NOTE: This should maybe be private, but it is currently public so that `bevy_ecs_macros` can use it. -// This exists as a standalone function instead of being inlined into the component derive macro so as -// to reduce the amount of generated code. -#[doc(hidden)] -pub fn enforce_no_required_components_recursion( +fn enforce_no_required_components_recursion( components: &Components, recursion_check_stack: &[ComponentId], + required: ComponentId, ) { - if let Some((&requiree, check)) = recursion_check_stack.split_last() { - if let Some(direct_recursion) = check - .iter() - .position(|&id| id == requiree) - .map(|index| index == check.len() - 1) - { - panic!( - "Recursive required components detected: {}\nhelp: {}", - recursion_check_stack - .iter() - .map(|id| format!("{}", components.get_name(*id).unwrap().shortname())) - .collect::>() - .join(" → "), - if direct_recursion { - format!( - "Remove require({}).", - components.get_name(requiree).unwrap().shortname() - ) - } else { - "If this is intentional, consider merging the components.".into() - } - ); - } + if let Some(direct_recursion) = recursion_check_stack + .iter() + .position(|&id| id == required) + .map(|index| index == recursion_check_stack.len() - 1) + { + panic!( + "Recursive required components detected: {}\nhelp: {}", + recursion_check_stack + .iter() + .map(|id| format!("{}", components.get_name(*id).unwrap().shortname())) + .collect::>() + .join(" → "), + if direct_recursion { + format!( + "Remove require({}).", + components.get_name(required).unwrap().shortname() + ) + } else { + "If this is intentional, consider merging the components.".into() + } + ); } } From e335244112e8103cb73e9faafab33acecc950077 Mon Sep 17 00:00:00 2001 From: Giacomo Stevanato Date: Sun, 13 Jul 2025 15:28:37 +0200 Subject: [PATCH 2/9] Rewrite required components --- crates/bevy_ecs/macros/src/component.rs | 39 +- crates/bevy_ecs/src/bundle/info.rs | 79 ++- crates/bevy_ecs/src/bundle/insert.rs | 4 +- crates/bevy_ecs/src/bundle/spawner.rs | 2 +- crates/bevy_ecs/src/component/info.rs | 19 +- crates/bevy_ecs/src/component/mod.rs | 7 +- crates/bevy_ecs/src/component/register.rs | 15 +- crates/bevy_ecs/src/component/required.rs | 829 +++++++++++----------- 8 files changed, 503 insertions(+), 491 deletions(-) diff --git a/crates/bevy_ecs/macros/src/component.rs b/crates/bevy_ecs/macros/src/component.rs index 0265f337fb170..c6d530dc95374 100644 --- a/crates/bevy_ecs/macros/src/component.rs +++ b/crates/bevy_ecs/macros/src/component.rs @@ -237,38 +237,17 @@ pub fn derive_component(input: TokenStream) -> TokenStream { let requires = &attrs.requires; let mut register_required = Vec::with_capacity(attrs.requires.iter().len()); - let mut register_recursive_requires = Vec::with_capacity(attrs.requires.iter().len()); if let Some(requires) = requires { for require in requires { let ident = &require.path; - register_recursive_requires.push(quote! { - <#ident as #bevy_ecs_path::component::Component>::register_required_components( - requiree, - components, - required_components, - inheritance_depth + 1 - ); + let constructor = match &require.func { + Some(func) => quote! { || { let x: #ident = (#func)().into(); x } }, + None => quote! { <#ident as Default>::default }, + }; + register_required.push(quote! { + // SAFETY: we registered all components with the same instance of components. + unsafe { required_components.register::<#ident>(components, #constructor) }; }); - match &require.func { - Some(func) => { - register_required.push(quote! { - components.register_required_components_manual::( - required_components, - || { let x: #ident = (#func)().into(); x }, - inheritance_depth - ); - }); - } - None => { - register_required.push(quote! { - components.register_required_components_manual::( - required_components, - <#ident as Default>::default, - inheritance_depth - ); - }); - } - } } } let struct_name = &ast.ident; @@ -308,14 +287,12 @@ pub fn derive_component(input: TokenStream) -> TokenStream { impl #impl_generics #bevy_ecs_path::component::Component for #struct_name #type_generics #where_clause { const STORAGE_TYPE: #bevy_ecs_path::component::StorageType = #storage; type Mutability = #mutable_type; - fn register_required_components( + unsafe fn register_required_components( _requiree: #bevy_ecs_path::component::ComponentId, components: &mut #bevy_ecs_path::component::ComponentsRegistrator, required_components: &mut #bevy_ecs_path::component::RequiredComponents, - inheritance_depth: u16, ) { #(#register_required)* - #(#register_recursive_requires)* } #on_add diff --git a/crates/bevy_ecs/src/bundle/info.rs b/crates/bevy_ecs/src/bundle/info.rs index a4093f8a889a1..3ac341fd73b3c 100644 --- a/crates/bevy_ecs/src/bundle/info.rs +++ b/crates/bevy_ecs/src/bundle/info.rs @@ -3,14 +3,15 @@ use bevy_platform::collections::{HashMap, HashSet}; use bevy_ptr::OwningPtr; use bevy_utils::TypeIdMap; use core::{any::TypeId, ptr::NonNull}; +use indexmap::{IndexMap, IndexSet}; use crate::{ archetype::{Archetype, BundleComponentStatus, ComponentStatus}, bundle::{Bundle, DynamicBundle}, change_detection::MaybeLocation, component::{ - ComponentId, Components, ComponentsRegistrator, RequiredComponentConstructor, - RequiredComponents, StorageType, Tick, + ComponentId, Components, ComponentsRegistrator, RequiredComponentConstructor, StorageType, + Tick, }, entity::Entity, query::DebugCheckedUnwrap as _, @@ -59,6 +60,7 @@ pub enum InsertMode { /// [`World`]: crate::world::World pub struct BundleInfo { pub(super) id: BundleId, + /// The list of all components contributed by the bundle (including Required Components). This is in /// the order `[EXPLICIT_COMPONENTS][REQUIRED_COMPONENTS]` /// @@ -67,9 +69,10 @@ pub struct BundleInfo { /// must have its storage initialized (i.e. columns created in tables, sparse set created), /// and the range (0..`explicit_components_len`) must be in the same order as the source bundle /// type writes its components in. - pub(super) component_ids: Vec, - pub(super) required_components: Vec, - pub(super) explicit_components_len: usize, + pub(super) contributed_components: Vec, + + /// The list of constructors for all required components indirectly contributed by this bundle. + pub(super) required_component_constructors: Vec, } impl BundleInfo { @@ -86,11 +89,10 @@ impl BundleInfo { mut component_ids: Vec, id: BundleId, ) -> BundleInfo { + let explicit_component_ids = component_ids.iter().copied().collect::>(); + // check for duplicates - let mut deduped = component_ids.clone(); - deduped.sort_unstable(); - deduped.dedup(); - if deduped.len() != component_ids.len() { + if explicit_component_ids.len() != component_ids.len() { // TODO: Replace with `Vec::partition_dedup` once https://github.com/rust-lang/rust/issues/54279 is stabilized let mut seen = >::default(); let mut dups = Vec::new(); @@ -111,31 +113,30 @@ impl BundleInfo { panic!("Bundle {bundle_type_name} has duplicate components: {names:?}"); } - // handle explicit components - let explicit_components_len = component_ids.len(); - let mut required_components = RequiredComponents::default(); - for component_id in component_ids.iter().copied() { + let mut depth_first_components = IndexMap::new(); + for &component_id in &component_ids { // SAFETY: caller has verified that all ids are valid let info = unsafe { components.get_info_unchecked(component_id) }; - required_components.merge(info.required_components()); + + for (&required_id, required_component) in &info.required_components().all { + depth_first_components + .entry(required_id) + .or_insert_with(|| required_component.clone()); + } + storages.prepare_component(info); } - required_components.remove_explicit_components(&component_ids); - - // handle required components - let required_components = required_components - .0 - .into_iter() - .map(|(component_id, v)| { - // Safety: These ids came out of the passed `components`, so they must be valid. - let info = unsafe { components.get_info_unchecked(component_id) }; - storages.prepare_component(info); - // This adds required components to the component_ids list _after_ using that list to remove explicitly provided - // components. This ordering is important! - component_ids.push(component_id); - v.constructor + + let required_components = depth_first_components + .iter() + .filter(|&(required_id, _)| !explicit_component_ids.contains(required_id)) + .inspect(|&(&required_id, _)| { + // SAFETY: These ids came out of the passed `components`, so they must be valid. + storages.prepare_component(unsafe { components.get_info_unchecked(required_id) }); + component_ids.push(required_id); }) - .collect(); + .map(|(_, required_component)| required_component.constructor.clone()) + .collect::>(); // SAFETY: The caller ensures that component_ids: // - is valid for the associated world @@ -143,9 +144,8 @@ impl BundleInfo { // - is in the same order as the source bundle type BundleInfo { id, - component_ids, - required_components, - explicit_components_len, + contributed_components: component_ids, + required_component_constructors: required_components, } } @@ -155,19 +155,24 @@ impl BundleInfo { self.id } + /// Returns the length of the explicit components part of the [contributed_components](Self::contributed_components) list. + pub(super) fn explicit_components_len(&self) -> usize { + self.contributed_components.len() - self.required_component_constructors.len() + } + /// Returns the [ID](ComponentId) of each component explicitly defined in this bundle (ex: Required Components are excluded). /// /// For all components contributed by this bundle (including Required Components), see [`BundleInfo::contributed_components`] #[inline] pub fn explicit_components(&self) -> &[ComponentId] { - &self.component_ids[0..self.explicit_components_len] + &self.contributed_components[0..self.explicit_components_len()] } /// Returns the [ID](ComponentId) of each Required Component needed by this bundle. This _does not include_ Required Components that are /// explicitly provided by the bundle. #[inline] pub fn required_components(&self) -> &[ComponentId] { - &self.component_ids[self.explicit_components_len..] + &self.contributed_components[self.explicit_components_len()..] } /// Returns the [ID](ComponentId) of each component contributed by this bundle. This includes Required Components. @@ -175,7 +180,7 @@ impl BundleInfo { /// For only components explicitly defined in this bundle, see [`BundleInfo::explicit_components`] #[inline] pub fn contributed_components(&self) -> &[ComponentId] { - &self.component_ids + &self.contributed_components } /// Returns an iterator over the [ID](ComponentId) of each component explicitly defined in this bundle (ex: this excludes Required Components). @@ -190,7 +195,7 @@ impl BundleInfo { /// To iterate only components explicitly defined in this bundle, see [`BundleInfo::iter_explicit_components`] #[inline] pub fn iter_contributed_components(&self) -> impl Iterator + Clone + '_ { - self.component_ids.iter().copied() + self.contributed_components().iter().copied() } /// Returns an iterator over the [ID](ComponentId) of each Required Component needed by this bundle. This _does not include_ Required Components that are @@ -236,7 +241,7 @@ impl BundleInfo { // bundle_info.component_ids are also in "bundle order" let mut bundle_component = 0; let after_effect = bundle.get_components(&mut |storage_type, component_ptr| { - let component_id = *self.component_ids.get_unchecked(bundle_component); + let component_id = *self.contributed_components.get_unchecked(bundle_component); // SAFETY: bundle_component is a valid index for this bundle let status = unsafe { bundle_component_status.get_status(bundle_component) }; match storage_type { diff --git a/crates/bevy_ecs/src/bundle/insert.rs b/crates/bevy_ecs/src/bundle/insert.rs index bf8a99a7b1a98..0388b5e6fd87c 100644 --- a/crates/bevy_ecs/src/bundle/insert.rs +++ b/crates/bevy_ecs/src/bundle/insert.rs @@ -433,7 +433,7 @@ impl BundleInfo { } let mut new_table_components = Vec::new(); let mut new_sparse_set_components = Vec::new(); - let mut bundle_status = Vec::with_capacity(self.explicit_components_len); + let mut bundle_status = Vec::with_capacity(self.explicit_components_len()); let mut added_required_components = Vec::new(); let mut added = Vec::new(); let mut existing = Vec::new(); @@ -457,7 +457,7 @@ impl BundleInfo { for (index, component_id) in self.iter_required_components().enumerate() { if !current_archetype.contains(component_id) { - added_required_components.push(self.required_components[index].clone()); + added_required_components.push(self.required_component_constructors[index].clone()); added.push(component_id); // SAFETY: component_id exists let component_info = unsafe { components.get_info_unchecked(component_id) }; diff --git a/crates/bevy_ecs/src/bundle/spawner.rs b/crates/bevy_ecs/src/bundle/spawner.rs index 05e8cd956d5a8..407bfda8facc4 100644 --- a/crates/bevy_ecs/src/bundle/spawner.rs +++ b/crates/bevy_ecs/src/bundle/spawner.rs @@ -108,7 +108,7 @@ impl<'w> BundleSpawner<'w> { table, sparse_sets, &SpawnBundleStatus, - bundle_info.required_components.iter(), + bundle_info.required_component_constructors.iter(), entity, table_row, self.change_tick, diff --git a/crates/bevy_ecs/src/component/info.rs b/crates/bevy_ecs/src/component/info.rs index 5a1bf96e1685e..9deeac9b4d28b 100644 --- a/crates/bevy_ecs/src/component/info.rs +++ b/crates/bevy_ecs/src/component/info.rs @@ -1,5 +1,5 @@ use alloc::{borrow::Cow, vec::Vec}; -use bevy_platform::{collections::HashSet, sync::PoisonError}; +use bevy_platform::sync::PoisonError; use bevy_ptr::OwningPtr; #[cfg(feature = "bevy_reflect")] use bevy_reflect::Reflect; @@ -10,6 +10,7 @@ use core::{ fmt::Debug, mem::needs_drop, }; +use indexmap::IndexSet; use crate::{ archetype::ArchetypeFlags, @@ -30,7 +31,10 @@ pub struct ComponentInfo { pub(super) descriptor: ComponentDescriptor, pub(super) hooks: ComponentHooks, pub(super) required_components: RequiredComponents, - pub(super) required_by: HashSet, + /// The set of components that require this components. + /// Invariant: this is stored in a depth-first order, that is components are stored after the components + /// that they depend on. + pub(super) required_by: IndexSet, } impl ComponentInfo { @@ -505,6 +509,13 @@ impl Components { .and_then(|info| info.as_mut().map(|info| &mut info.hooks)) } + #[inline] + pub(crate) fn get_required_components(&self, id: ComponentId) -> Option<&RequiredComponents> { + self.components + .get(id.0) + .and_then(|info| info.as_ref().map(|info| &info.required_components)) + } + #[inline] pub(crate) fn get_required_components_mut( &mut self, @@ -516,7 +527,7 @@ impl Components { } #[inline] - pub(crate) fn get_required_by(&self, id: ComponentId) -> Option<&HashSet> { + pub(crate) fn get_required_by(&self, id: ComponentId) -> Option<&IndexSet> { self.components .get(id.0) .and_then(|info| info.as_ref().map(|info| &info.required_by)) @@ -526,7 +537,7 @@ impl Components { pub(crate) fn get_required_by_mut( &mut self, id: ComponentId, - ) -> Option<&mut HashSet> { + ) -> Option<&mut IndexSet> { self.components .get_mut(id.0) .and_then(|info| info.as_mut().map(|info| &mut info.required_by)) diff --git a/crates/bevy_ecs/src/component/mod.rs b/crates/bevy_ecs/src/component/mod.rs index 5c9a01e58f30e..eeb4b2fdf85d0 100644 --- a/crates/bevy_ecs/src/component/mod.rs +++ b/crates/bevy_ecs/src/component/mod.rs @@ -522,11 +522,14 @@ pub trait Component: Send + Sync + 'static { } /// Registers required components. - fn register_required_components( + /// + /// # Safety + /// + /// - `_required_components` must only contain components valid in `_components`. + unsafe fn register_required_components( _component_id: ComponentId, _components: &mut ComponentsRegistrator, _required_components: &mut RequiredComponents, - _inheritance_depth: u16, ) { } diff --git a/crates/bevy_ecs/src/component/register.rs b/crates/bevy_ecs/src/component/register.rs index 1a50bbbb05137..6c4efae0c2b93 100644 --- a/crates/bevy_ecs/src/component/register.rs +++ b/crates/bevy_ecs/src/component/register.rs @@ -5,6 +5,7 @@ use core::any::Any; use core::ops::DerefMut; use core::{any::TypeId, fmt::Debug, ops::Deref}; +use crate::component::enforce_no_required_components_recursion; use crate::query::DebugCheckedUnwrap as _; use crate::{ component::{ @@ -189,8 +190,9 @@ impl<'w> ComponentsRegistrator<'w> { #[inline] pub(super) fn register_component_checked(&mut self) -> ComponentId { let type_id = TypeId::of::(); - if let Some(id) = self.indices.get(&type_id) { - return *id; + if let Some(&id) = self.indices.get(&type_id) { + enforce_no_required_components_recursion(self, &self.recursion_check_stack, id); + return id; } if let Some(registrator) = self @@ -229,8 +231,15 @@ impl<'w> ComponentsRegistrator<'w> { self.recursion_check_stack.push(id); let mut required_components = RequiredComponents::default(); - T::register_required_components(id, self, &mut required_components, 0); + // SAFETY: `required_components` is empty + unsafe { T::register_required_components(id, self, &mut required_components) }; + // SAFETY: + // - `id` was just registered in `self` + // - `register_required_components` have been given `self` to register components in + // (TODO: this is not really true... but the alternative would be making `Component` `unsafe`...) + unsafe { self.register_required_by(id, &required_components) }; self.recursion_check_stack.pop(); + // SAFETY: we just inserted it in `register_component_inner` let info = unsafe { &mut self diff --git a/crates/bevy_ecs/src/component/required.rs b/crates/bevy_ecs/src/component/required.rs index 6cac5025b0494..18586ef0cff7e 100644 --- a/crates/bevy_ecs/src/component/required.rs +++ b/crates/bevy_ecs/src/component/required.rs @@ -1,8 +1,8 @@ use alloc::{format, vec::Vec}; -use bevy_platform::{collections::HashMap, sync::Arc}; +use bevy_platform::sync::Arc; use bevy_ptr::OwningPtr; use core::fmt::Debug; -use smallvec::SmallVec; +use indexmap::{IndexMap, IndexSet}; use thiserror::Error; use crate::{ @@ -14,273 +14,80 @@ use crate::{ storage::{SparseSets, Table, TableRow}, }; -impl Components { - /// Registers the given component `R` and [required components] inherited from it as required by `T`. - /// - /// When `T` is added to an entity, `R` will also be added if it was not already provided. - /// The given `constructor` will be used for the creation of `R`. - /// - /// [required components]: Component#required-components - /// - /// # Safety - /// - /// The given component IDs `required` and `requiree` must be valid. - /// - /// # Errors - /// - /// Returns a [`RequiredComponentsError`] if the `required` component is already a directly required component for the `requiree`. - /// - /// Indirect requirements through other components are allowed. In those cases, the more specific - /// registration will be used. - pub(crate) unsafe fn register_required_components( - &mut self, - requiree: ComponentId, - required: ComponentId, - constructor: fn() -> R, - ) -> Result<(), RequiredComponentsError> { - // SAFETY: The caller ensures that the `requiree` is valid. - let required_components = unsafe { - self.get_required_components_mut(requiree) - .debug_checked_unwrap() - }; - - // Cannot directly require the same component twice. - if required_components - .0 - .get(&required) - .is_some_and(|c| c.inheritance_depth == 0) - { - return Err(RequiredComponentsError::DuplicateRegistration( - requiree, required, - )); - } - - // Register the required component for the requiree. - // This is a direct requirement with a depth of `0`. - required_components.register_by_id(required, constructor, 0); - - // Add the requiree to the list of components that require the required component. - // SAFETY: The component is in the list of required components, so it must exist already. - let required_by = unsafe { self.get_required_by_mut(required).debug_checked_unwrap() }; - required_by.insert(requiree); - - let mut required_components_tmp = RequiredComponents::default(); - // SAFETY: The caller ensures that the `requiree` and `required` components are valid. - let inherited_requirements = unsafe { - self.register_inherited_required_components( - requiree, - required, - &mut required_components_tmp, - ) - }; - - // SAFETY: The caller ensures that the `requiree` is valid. - let required_components = unsafe { - self.get_required_components_mut(requiree) - .debug_checked_unwrap() - }; - required_components.0.extend(required_components_tmp.0); - - // Propagate the new required components up the chain to all components that require the requiree. - if let Some(required_by) = self - .get_required_by(requiree) - .map(|set| set.iter().copied().collect::>()) - { - // `required` is now required by anything that `requiree` was required by. - self.get_required_by_mut(required) - .unwrap() - .extend(required_by.iter().copied()); - for &required_by_id in required_by.iter() { - // SAFETY: The component is in the list of required components, so it must exist already. - let required_components = unsafe { - self.get_required_components_mut(required_by_id) - .debug_checked_unwrap() - }; - - // Register the original required component in the "parent" of the requiree. - // The inheritance depth is 1 deeper than the `requiree` wrt `required_by_id`. - let depth = required_components.0.get(&requiree).expect("requiree is required by required_by_id, so its required_components must include requiree").inheritance_depth; - required_components.register_by_id(required, constructor, depth + 1); - - for (component_id, component) in inherited_requirements.iter() { - // Register the required component. - // The inheritance depth of inherited components is whatever the requiree's - // depth is relative to `required_by_id`, plus the inheritance depth of the - // inherited component relative to the requiree, plus 1 to account for the - // requiree in between. - // SAFETY: Component ID and constructor match the ones on the original requiree. - // The original requiree is responsible for making sure the registration is safe. - unsafe { - required_components.register_dynamic_with( - *component_id, - component.inheritance_depth + depth + 1, - || component.constructor.clone(), - ); - }; - } - } - } - - Ok(()) - } - - /// Registers the components inherited from `required` for the given `requiree`, - /// returning the requirements in a list. - /// - /// # Safety - /// - /// The given component IDs `requiree` and `required` must be valid. - unsafe fn register_inherited_required_components( - &mut self, - requiree: ComponentId, - required: ComponentId, - required_components: &mut RequiredComponents, - ) -> Vec<(ComponentId, RequiredComponent)> { - // Get required components inherited from the `required` component. - // SAFETY: The caller ensures that the `required` component is valid. - let required_component_info = unsafe { self.get_info(required).debug_checked_unwrap() }; - let inherited_requirements: Vec<(ComponentId, RequiredComponent)> = required_component_info - .required_components() - .0 - .iter() - .map(|(component_id, required_component)| { - ( - *component_id, - RequiredComponent { - constructor: required_component.constructor.clone(), - // Add `1` to the inheritance depth since this will be registered - // for the component that requires `required`. - inheritance_depth: required_component.inheritance_depth + 1, - }, - ) - }) - .collect(); - - // Register the new required components. - for (component_id, component) in inherited_requirements.iter() { - // Register the required component for the requiree. - // SAFETY: Component ID and constructor match the ones on the original requiree. - unsafe { - required_components.register_dynamic_with( - *component_id, - component.inheritance_depth, - || component.constructor.clone(), - ); - }; - - // Add the requiree to the list of components that require the required component. - // SAFETY: The caller ensures that the required components are valid. - let required_by = unsafe { - self.get_required_by_mut(*component_id) - .debug_checked_unwrap() - }; - required_by.insert(requiree); - } +/// Metadata associated with a required component. See [`Component`] for details. +#[derive(Clone)] +pub struct RequiredComponent { + /// The constructor used for the required component. + pub constructor: RequiredComponentConstructor, +} - inherited_requirements - } +/// A Required Component constructor. See [`Component`] for details. +#[derive(Clone)] +pub struct RequiredComponentConstructor( + // Note: this function makes `unsafe` assumptions, so it cannot be public. + Arc, +); - /// Registers the given component `R` and [required components] inherited from it as required by `T`, - /// and adds `T` to their lists of requirees. - /// - /// The given `inheritance_depth` determines how many levels of inheritance deep the requirement is. - /// A direct requirement has a depth of `0`, and each level of inheritance increases the depth by `1`. - /// Lower depths are more specific requirements, and can override existing less specific registrations. - /// - /// This method does *not* register any components as required by components that require `T`. - /// - /// [required component]: Component#required-components +impl RequiredComponentConstructor { + /// Creates a new instance of `RequiredComponentConstructor` for the given type /// /// # Safety /// - /// The given component IDs `required` and `requiree` must be valid. - pub(crate) unsafe fn register_required_components_manual_unchecked( - &mut self, - requiree: ComponentId, - required: ComponentId, - required_components: &mut RequiredComponents, - constructor: fn() -> R, - inheritance_depth: u16, - ) { - // Components cannot require themselves. - if required == requiree { - return; - } - - // Register the required component `R` for the requiree. - required_components.register_by_id(required, constructor, inheritance_depth); - - // Add the requiree to the list of components that require `R`. - // SAFETY: The caller ensures that the component ID is valid. - // Assuming it is valid, the component is in the list of required components, so it must exist already. - let required_by = unsafe { self.get_required_by_mut(required).debug_checked_unwrap() }; - required_by.insert(requiree); - - self.register_inherited_required_components(requiree, required, required_components); - } -} - -impl<'w> ComponentsRegistrator<'w> { - // NOTE: This should maybe be private, but it is currently public so that `bevy_ecs_macros` can use it. - // We can't directly move this there either, because this uses `Components::get_required_by_mut`, - // which is private, and could be equally risky to expose to users. - /// Registers the given component `R` and [required components] inherited from it as required by `T`, - /// and adds `T` to their lists of requirees. - /// - /// The given `inheritance_depth` determines how many levels of inheritance deep the requirement is. - /// A direct requirement has a depth of `0`, and each level of inheritance increases the depth by `1`. - /// Lower depths are more specific requirements, and can override existing less specific registrations. - /// - /// This method does *not* register any components as required by components that require `T`. - /// - /// Only use this method if you know what you are doing. In most cases, you should instead use [`World::register_required_components`], - /// or the equivalent method in `bevy_app::App`. - /// - /// [required component]: Component#required-components - #[doc(hidden)] - pub fn register_required_components_manual( - &mut self, - required_components: &mut RequiredComponents, - constructor: fn() -> R, - inheritance_depth: u16, - ) { - let requiree = self.register_component_checked::(); - let required = self.register_component_checked::(); - - enforce_no_required_components_recursion(self, &self.recursion_check_stack, required); - - // SAFETY: We just created the components. - unsafe { - self.register_required_components_manual_unchecked::( - requiree, - required, - required_components, - constructor, - inheritance_depth, + /// - `component_id` must be a valid component for type `C`. + pub unsafe fn new(component_id: ComponentId, constructor: fn() -> C) -> Self { + RequiredComponentConstructor({ + // `portable-atomic-util` `Arc` is not able to coerce an unsized + // type like `std::sync::Arc` can. Creating a `Box` first does the + // coercion. + // + // This would be resolved by https://github.com/rust-lang/rust/issues/123430 + + #[cfg(not(target_has_atomic = "ptr"))] + use alloc::boxed::Box; + + type Constructor = dyn for<'a, 'b> Fn( + &'a mut Table, + &'b mut SparseSets, + Tick, + TableRow, + Entity, + MaybeLocation, ); - } - } -} -/// An error returned when the registration of a required component fails. -#[derive(Error, Debug)] -#[non_exhaustive] -pub enum RequiredComponentsError { - /// The component is already a directly required component for the requiree. - #[error("Component {0:?} already directly requires component {1:?}")] - DuplicateRegistration(ComponentId, ComponentId), - /// An archetype with the component that requires other components already exists - #[error("An archetype with the component {0:?} that requires other components already exists")] - ArchetypeExists(ComponentId), -} + #[cfg(not(target_has_atomic = "ptr"))] + type Intermediate = Box; + + #[cfg(target_has_atomic = "ptr")] + type Intermediate = Arc; + + let boxed: Intermediate = Intermediate::new( + move |table, sparse_sets, change_tick, table_row, entity, caller| { + OwningPtr::make(constructor(), |ptr| { + // SAFETY: This will only be called in the context of `BundleInfo::write_components`, which will + // pass in a valid table_row and entity requiring a C constructor + // C::STORAGE_TYPE is the storage type associated with `component_id` / `C` + // `ptr` points to valid `C` data, which matches the type associated with `component_id` + unsafe { + BundleInfo::initialize_required_component( + table, + sparse_sets, + change_tick, + table_row, + entity, + component_id, + C::STORAGE_TYPE, + ptr, + caller, + ); + } + }); + }, + ); -/// A Required Component constructor. See [`Component`] for details. -#[derive(Clone)] -pub struct RequiredComponentConstructor( - pub Arc, -); + Arc::from(boxed) + }) + } -impl RequiredComponentConstructor { /// # Safety /// This is intended to only be called in the context of [`BundleInfo::write_components`] to initialized required components. /// Calling it _anywhere else_ should be considered unsafe. @@ -303,203 +110,403 @@ impl RequiredComponentConstructor { } } -/// Metadata associated with a required component. See [`Component`] for details. -#[derive(Clone)] -pub struct RequiredComponent { - /// The constructor used for the required component. - pub constructor: RequiredComponentConstructor, - - /// The depth of the component requirement in the requirement hierarchy for this component. - /// This is used for determining which constructor is used in cases where there are duplicate requires. - /// - /// For example, consider the inheritance tree `X -> Y -> Z`, where `->` indicates a requirement. - /// `X -> Y` and `Y -> Z` are direct requirements with a depth of 0, while `Z` is only indirectly - /// required for `X` with a depth of `1`. - /// - /// In cases where there are multiple conflicting requirements with the same depth, a higher priority - /// will be given to components listed earlier in the `require` attribute, or to the latest added requirement - /// if registered at runtime. - pub inheritance_depth: u16, -} - /// The collection of metadata for components that are required for a given component. /// /// For more information, see the "Required Components" section of [`Component`]. #[derive(Default, Clone)] -pub struct RequiredComponents(pub(crate) HashMap); +pub struct RequiredComponents { + /// The components that are directly required (i.e. excluding inherited ones), in the order of their precedence. + /// + /// # Safety + /// The [`RequiredComponent`] instance associated to each ID must be valid for its component. + pub(crate) direct: IndexMap, + /// All the components that are required (i.e. including inherited ones), in depth-first order. Most importantly, + /// components in this list always appear after all the components that they require. + /// + /// Note that the direct components are not necessarily at the end of this list, for example if A and C are directly + /// requires, and A requires B requires C, then `all` will hold [C, B, A]. + /// + /// # Safety + /// The [`RequiredComponent`] instance associated to each ID must be valid for its component. + pub(crate) all: IndexMap, +} impl Debug for RequiredComponents { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.debug_tuple("RequiredComponents") - .field(&self.0.keys()) + f.debug_struct("RequiredComponents") + .field("direct", &self.direct.keys()) + .field("all", &self.all.keys()) .finish() } } impl RequiredComponents { - /// Registers a required component. + /// Registers the [`Component`] `C` as an explicitly required component. /// - /// If the component is already registered, it will be overwritten if the given inheritance depth - /// is smaller than the depth of the existing registration. Otherwise, the new registration will be ignored. + /// If the component was not already registered as an explicit required component then it is added + /// as one, potentially overriding the constructor of a inherited required component, and `true` is returned. + /// Otherwise `false` is returned. /// /// # Safety /// - /// `component_id` must match the type initialized by `constructor`. - /// `constructor` _must_ initialize a component for `component_id` in such a way that - /// matches the storage type of the component. It must only use the given `table_row` or `Entity` to - /// initialize the storage for `component_id` corresponding to the given entity. - pub unsafe fn register_dynamic_with( + /// - all other components in this [`RequiredComponents`] instance must have been registrated in `components`. + pub unsafe fn register( &mut self, - component_id: ComponentId, - inheritance_depth: u16, - constructor: impl FnOnce() -> RequiredComponentConstructor, - ) { - let entry = self.0.entry(component_id); - match entry { - bevy_platform::collections::hash_map::Entry::Occupied(mut occupied) => { - let current = occupied.get_mut(); - if current.inheritance_depth > inheritance_depth { - *current = RequiredComponent { - constructor: constructor(), - inheritance_depth, - } - } - } - bevy_platform::collections::hash_map::Entry::Vacant(vacant) => { - vacant.insert(RequiredComponent { - constructor: constructor(), - inheritance_depth, - }); - } - } + components: &mut ComponentsRegistrator<'_>, + constructor: fn() -> C, + ) -> bool { + let id = components.register_component::(); + // SAFETY: + // - `id` was just registered in `components`; + // - the caller guarantees all other components were registered in `components`. + unsafe { self.register_by_id::(id, components, constructor) } } - /// Registers a required component. + /// Registers the [`Component`] with the given `component_id` ID as an explicitly required component. + /// + /// If the component was not already registered as an explicit required component then it is added + /// as one, potentially overriding the constructor of a inherited required component, and `true` is returned. + /// Otherwise `false` is returned. + /// + /// # Safety /// - /// If the component is already registered, it will be overwritten if the given inheritance depth - /// is smaller than the depth of the existing registration. Otherwise, the new registration will be ignored. - pub fn register( + /// - `component_id` must be a valid component in `components` for the type `C`; + /// - all other components in this [`RequiredComponents`] instance must have been registrated in `components`. + pub unsafe fn register_by_id( &mut self, - components: &mut ComponentsRegistrator, + component_id: ComponentId, + components: &Components, constructor: fn() -> C, - inheritance_depth: u16, - ) { - let component_id = components.register_component::(); - self.register_by_id(component_id, constructor, inheritance_depth); + ) -> bool { + // SAFETY: the caller guarantees that `component_id` is valid for the type `C`. + let constructor = + || unsafe { RequiredComponentConstructor::new(component_id, constructor) }; + + // SAFETY: + // - the caller guarantees that `component_id` is valid in `components` + // - the caller guarantees all other components were registered in `components`; + // - constructor is guaranteed to create a valid constructor for the component with id `component_id`. + unsafe { self.register_dynamic_with(component_id, components, constructor) } } - /// Registers the [`Component`] with the given ID as required if it exists. + /// Registers the [`Component`] with the given `component_id` ID as an explicitly required component. /// - /// If the component is already registered, it will be overwritten if the given inheritance depth - /// is smaller than the depth of the existing registration. Otherwise, the new registration will be ignored. - pub fn register_by_id( + /// If the component was not already registered as an explicit required component then it is added + /// as one, potentially overriding the constructor of a inherited required component, and `true` is returned. + /// Otherwise `false` is returned. + /// + /// # Safety + /// + /// - `component_id` must be a valid component in `components`; + /// - all other components in this [`RequiredComponents`] instance must have been registrated in `components`; + /// - `constructor` must return a [`RequiredComponentConstructor`] that constructs a valid instance for the + /// component with ID `component_id`. + pub unsafe fn register_dynamic_with( &mut self, component_id: ComponentId, - constructor: fn() -> C, - inheritance_depth: u16, + components: &Components, + constructor: impl FnOnce() -> RequiredComponentConstructor, + ) -> bool { + // If already registered as a direct required component then bail. + let entry = match self.direct.entry(component_id) { + indexmap::map::Entry::Vacant(entry) => entry, + indexmap::map::Entry::Occupied(_) => return false, + }; + + // Insert into `direct`. + let constructor = constructor(); + let required_component = RequiredComponent { constructor }; + entry.insert(required_component.clone()); + + // Register inherited required components. + unsafe { + Self::register_inherited_required_components_unchecked( + &mut self.all, + component_id, + required_component, + components, + ) + }; + + true + } + + /// Rebuild the `all` list + /// + /// # Safety + /// + /// - all components in this [`RequiredComponents`] instance must have been registrated in `components`. + unsafe fn rebuild_inherited_required_components(&mut self, components: &Components) { + // Clear `all`, we are re-initializing it. + self.all.clear(); + + // Register all inherited components as if we just registered all components in `direct` one-by-one. + for (&required_id, required_component) in &self.direct { + // SAFETY: + // - the caller guarantees that all components in this instance have been registered in `components`, + // meaning both `all` and `required_id` have been registered in `components`; + // - `required_component` was associated to `required_id`, so it must hold a constructor valid for it. + unsafe { + Self::register_inherited_required_components_unchecked( + &mut self.all, + required_id, + required_component.clone(), + components, + ) + } + } + } + + /// Registers all the inherited required components from `required_id`. + /// + /// # Safety + /// + /// - all components in `all` must have been registered in `components`; + /// - `required_id` must have been registered in `components`; + /// - `required_component` must hold a valid constructor for the component with id `required_id`. + unsafe fn register_inherited_required_components_unchecked( + all: &mut IndexMap, + required_id: ComponentId, + required_component: RequiredComponent, + components: &Components, ) { - let erased = || { - RequiredComponentConstructor({ - // `portable-atomic-util` `Arc` is not able to coerce an unsized - // type like `std::sync::Arc` can. Creating a `Box` first does the - // coercion. + // SAFETY: the caller guarantees that `required_id` is valid in `components`. + let info = unsafe { components.get_info(required_id).debug_checked_unwrap() }; + + // Now we need to "recursively" register the + // Small optimization: if the current required component was already required recursively + // by an earlier direct required component then all its inherited components have all already + // been inserted, so let's not try to reinsert them. + if !all.contains_key(&required_id) { + for (&inherited_id, inherited_required) in &info.required_components().all { + // This is an inherited required component: insert it only if not already present. + // By the invariants of `RequiredComponents`, `info.required_components().all` holds the required + // components in a depth-first order, and this makes us store teh components in `self.all` also + // in depth-first order, as long as we don't overwrite existing ones. // - // This would be resolved by https://github.com/rust-lang/rust/issues/123430 - - #[cfg(not(target_has_atomic = "ptr"))] - use alloc::boxed::Box; - - type Constructor = dyn for<'a, 'b> Fn( - &'a mut Table, - &'b mut SparseSets, - Tick, - TableRow, - Entity, - MaybeLocation, - ); - - #[cfg(not(target_has_atomic = "ptr"))] - type Intermediate = Box; - - #[cfg(target_has_atomic = "ptr")] - type Intermediate = Arc; - - let boxed: Intermediate = Intermediate::new( - move |table, sparse_sets, change_tick, table_row, entity, caller| { - OwningPtr::make(constructor(), |ptr| { - // SAFETY: This will only be called in the context of `BundleInfo::write_components`, which will - // pass in a valid table_row and entity requiring a C constructor - // C::STORAGE_TYPE is the storage type associated with `component_id` / `C` - // `ptr` points to valid `C` data, which matches the type associated with `component_id` - unsafe { - BundleInfo::initialize_required_component( - table, - sparse_sets, - change_tick, - table_row, - entity, - component_id, - C::STORAGE_TYPE, - ptr, - caller, - ); - } - }); - }, - ); - - Arc::from(boxed) - }) - }; + // SAFETY: + // `inherited_required` was associated to `inherited_id`, so it must have been valid for its component. + all.entry(inherited_id) + .or_insert_with(|| inherited_required.clone()); + } + } - // SAFETY: - // `component_id` matches the type initialized by the `erased` constructor above. - // `erased` initializes a component for `component_id` in such a way that - // matches the storage type of the component. It only uses the given `table_row` or `Entity` to - // initialize the storage corresponding to the given entity. - unsafe { self.register_dynamic_with(component_id, inheritance_depth, erased) }; + // For direct required components: + // - insert them after inherited components to follow the depth-first order; + // - insert them unconditionally in order to make their constructor the one that's used. + // Note that `insert` does not change the order of components, meaning `component_id` will still appear + // before any other component that requires it. + // + // SAFETY: the caller guaranees that `required_component` is valid for the component with ID `required_id`. + all.insert(required_id, required_component); } /// Iterates the ids of all required components. This includes recursive required components. pub fn iter_ids(&self) -> impl Iterator + '_ { - self.0.keys().copied() + self.all.keys().copied() } +} - /// Removes components that are explicitly provided in a given [`Bundle`]. These components should - /// be logically treated as normal components, not "required components". +impl Components { + /// Registers the components in `required_components` as required by `requiree`. + /// + /// # Safety /// - /// [`Bundle`]: crate::bundle::Bundle - pub(crate) fn remove_explicit_components(&mut self, components: &[ComponentId]) { - for component in components { - self.0.remove(component); + /// - `requiree` must have been registered in `self` + /// - all components in `required_components` must have been registered in `self`. + pub(crate) unsafe fn register_required_by( + &mut self, + requiree: ComponentId, + required_components: &RequiredComponents, + ) { + for &required in required_components.all.keys() { + // SAFETY: the caller guarantees that all components in `required_components` have been registered in `self`. + let required_by = unsafe { self.get_required_by_mut(required).debug_checked_unwrap() }; + required_by.insert(requiree); } } +} - /// Merges `required_components` into this collection. This only inserts a required component - /// if it _did not already exist_ *or* if the required component is more specific than the existing one - /// (in other words, if the inheritance depth is smaller). +impl Components { + /// Registers the given component `R` and [required components] inherited from it as required by `T`. /// - /// See [`register_dynamic_with`](Self::register_dynamic_with) for details. - pub(crate) fn merge(&mut self, required_components: &RequiredComponents) { - for ( - component_id, - RequiredComponent { - constructor, - inheritance_depth, - }, - ) in required_components.0.iter() - { - // SAFETY: This exact registration must have been done on `required_components`, so safety is ensured by that caller. - unsafe { - self.register_dynamic_with(*component_id, *inheritance_depth, || { - constructor.clone() - }); + /// When `T` is added to an entity, `R` will also be added if it was not already provided. + /// The given `constructor` will be used for the creation of `R`. + /// + /// [required components]: Component#required-components + /// + /// # Safety + /// + /// - the given component IDs `required` and `requiree` must be valid in `self`; + /// - the given component ID `required` must be valid for the component type `R`. + /// + /// + /// # Errors + /// + /// Returns a [`RequiredComponentsError`] if either of these are true: + /// - the `required` component is already a *directly* required component for the `requiree`; indirect + /// requirements through other components are allowed. In those cases, the more specific + /// registration will be used. + /// - the `requiree` component is already a (possibly indirect) required component for the `required` component. + pub(crate) unsafe fn register_required_components( + &mut self, + requiree: ComponentId, + required: ComponentId, + constructor: fn() -> R, + ) -> Result<(), RequiredComponentsError> { + // First step: validate inputs and return errors. + + // SAFETY: The caller ensures that the `required` is valid. + let required_required_components = unsafe { + self.get_required_components(required) + .debug_checked_unwrap() + }; + + // Cannot create cyclic requirements. + if required_required_components.all.contains_key(&requiree) { + return Err(RequiredComponentsError::CyclicRequirement( + requiree, required, + )); + } + + // SAFETY: The caller ensures that the `requiree` is valid. + let required_components = unsafe { + self.get_required_components_mut(requiree) + .debug_checked_unwrap() + }; + + // Cannot directly require the same component twice. + if required_components.direct.contains_key(&required) { + return Err(RequiredComponentsError::DuplicateRegistration( + requiree, required, + )); + } + + // Second step: register the single requirement requiree->required + + // Store the old count of (all) required components. This will help determine which ones are new. + let old_required_count = required_components.all.len(); + + // SAFETY: the caller guarantees that `requiree` and `required` are valid in `self`, with `required` valid for R. + unsafe { self.register_required_component_single(requiree, required, constructor) }; + + // Third step: update the required components and required_by of all the indirect requirements/requirees. + + // Borrow again otherwise it conflicts with the `self.register_required_component_single` call. + // SAFETY: The caller ensures that the `requiree` is valid. + let required_components = unsafe { + self.get_required_components_mut(requiree) + .debug_checked_unwrap() + }; + + // Optimization: get all the new required components, i.e. those that were appended. + // Other components that might be inherited when requiring `required` can be safely ignored because + // any component requiring `requiree` will already transitively require them. + // Note: the only small exception is for `required` itself, for which we cannot ignore the value of the + // constructor. But for simplicity we will rebuild any `RequiredComponents` + let new_required_components = required_components.all[old_required_count..] + .keys() + .copied() + .collect::>(); + + // Get all the new requiree components, i.e. `requiree` and all the components that `requiree` is required by. + // SAFETY: The caller ensures that the `requiree` is valid. + let new_requiree_components = + unsafe { self.get_required_by(requiree).debug_checked_unwrap() }.clone(); + + // We now need to update the required and required_by components of all the components + // directly or indirectly involved. + // Important: we need to be careful about the order we do these operations in. + // Since computing the required components of some component depends on the required components of + // other components, and while we do this operations not all required components are up-to-date, we need + // to ensure we update components in such a way that we update a component after the components it depends on. + // Luckily this is exactly the depth-first order, which is guaranteed to be the order of `new_requiree_components`. + + // Update the inherited required components of all requiree components (directly or indirectly). + for &indirect_requiree in &new_requiree_components { + // Extract the required components to avoid conflicting borrows. Remember to put this back before continuing! + // SAFETY: `indirect_requiree` comes from `self`, so it must be valid. + let mut required_components = std::mem::take(unsafe { + self.get_required_components_mut(indirect_requiree) + .debug_checked_unwrap() + }); + + // Rebuild the inherited required components. + // SAFETY: `required_components` comes from `self`, so all its components must have be valid in `self`. + unsafe { required_components.rebuild_inherited_required_components(self) }; + + // Let's not forget to put back `required_components`! + // SAFETY: `indirect_requiree` comes from `self`, so it must be valid. + *unsafe { + self.get_required_components_mut(indirect_requiree) + .debug_checked_unwrap() + } = required_components; + } + + // Update the `required_by` of all the components that were newly required (directly or indirectly). + for &indirect_required in &new_required_components { + // SAFETY: `indirect_required` comes from `self`, so it must be valid. + let required_by = unsafe { + self.get_required_by_mut(indirect_required) + .debug_checked_unwrap() + }; + + for &requiree in [&requiree].into_iter().chain(&new_requiree_components) { + required_by.insert_before(required_by.len(), requiree); } } + + Ok(()) } + + /// Register the `required` as a required component in the [`RequiredComponents`] for `requiree`. + /// This function does not update any other metadata, such as required components of components requiring `requiree`. + /// + /// # Safety + /// + /// - `requiree` and `required` must be defined in `self`. + /// - `required` must be a valid component ID for the type `R`. + unsafe fn register_required_component_single( + &mut self, + requiree: ComponentId, + required: ComponentId, + constructor: fn() -> R, + ) { + // Extract the required components to avoid conflicting borrows. Remember to put this back before returning! + // SAFETY: The caller ensures that the `requiree` is valid. + let mut required_components = std::mem::take(unsafe { + self.get_required_components_mut(requiree) + .debug_checked_unwrap() + }); + + // Register the required component for the requiree. + required_components.register_by_id(required, self, constructor); + + // Let's not forget to put back `required_components`! + // SAFETY: The caller ensures that the `requiree` is valid. + *unsafe { + self.get_required_components_mut(requiree) + .debug_checked_unwrap() + } = required_components; + } +} + +/// An error returned when the registration of a required component fails. +#[derive(Error, Debug)] +#[non_exhaustive] +pub enum RequiredComponentsError { + /// The component is already a directly required component for the requiree. + #[error("Component {0:?} already directly requires component {1:?}")] + DuplicateRegistration(ComponentId, ComponentId), + /// Adding the given requirement would create a cycle. + #[error("Cyclic requirement found: the requiree component {0:?} is required by the required component {1:?}")] + CyclicRequirement(ComponentId, ComponentId), + /// An archetype with the component that requires other components already exists + #[error("An archetype with the component {0:?} that requires other components already exists")] + ArchetypeExists(ComponentId), } -fn enforce_no_required_components_recursion( +pub(super) fn enforce_no_required_components_recursion( components: &Components, recursion_check_stack: &[ComponentId], required: ComponentId, From dfb290a8987758ce84055da79ba5de46652c6860 Mon Sep 17 00:00:00 2001 From: Giacomo Stevanato Date: Sun, 13 Jul 2025 15:34:45 +0200 Subject: [PATCH 3/9] Move required components tests --- crates/bevy_ecs/src/component/required.rs | 801 ++++++++++++++++++++++ crates/bevy_ecs/src/lib.rs | 794 +-------------------- 2 files changed, 803 insertions(+), 792 deletions(-) diff --git a/crates/bevy_ecs/src/component/required.rs b/crates/bevy_ecs/src/component/required.rs index 18586ef0cff7e..e7e986d9090df 100644 --- a/crates/bevy_ecs/src/component/required.rs +++ b/crates/bevy_ecs/src/component/required.rs @@ -534,3 +534,804 @@ pub(super) fn enforce_no_required_components_recursion( ); } } + +#[cfg(test)] +mod tests { + use std::{ + string::{String, ToString}, + vec, + vec::Vec, + }; + + use crate::{ + bundle::Bundle, + component::{Component, ComponentId, RequiredComponents, RequiredComponentsError}, + prelude::Resource, + world::World, + }; + + #[test] + fn required_components() { + #[derive(Component)] + #[require(Y)] + struct X; + + #[derive(Component)] + #[require(Z = new_z())] + struct Y { + value: String, + } + + #[derive(Component)] + struct Z(u32); + + impl Default for Y { + fn default() -> Self { + Self { + value: "hello".to_string(), + } + } + } + + fn new_z() -> Z { + Z(7) + } + + let mut world = World::new(); + let id = world.spawn(X).id(); + assert_eq!( + "hello", + world.entity(id).get::().unwrap().value, + "Y should have the default value" + ); + assert_eq!( + 7, + world.entity(id).get::().unwrap().0, + "Z should have the value provided by the constructor defined in Y" + ); + + let id = world + .spawn(( + X, + Y { + value: "foo".to_string(), + }, + )) + .id(); + assert_eq!( + "foo", + world.entity(id).get::().unwrap().value, + "Y should have the manually provided value" + ); + assert_eq!( + 7, + world.entity(id).get::().unwrap().0, + "Z should have the value provided by the constructor defined in Y" + ); + + let id = world.spawn((X, Z(8))).id(); + assert_eq!( + "hello", + world.entity(id).get::().unwrap().value, + "Y should have the default value" + ); + assert_eq!( + 8, + world.entity(id).get::().unwrap().0, + "Z should have the manually provided value" + ); + } + + #[test] + fn generic_required_components() { + #[derive(Component)] + #[require(Y)] + struct X; + + #[derive(Component, Default)] + struct Y { + value: T, + } + + let mut world = World::new(); + let id = world.spawn(X).id(); + assert_eq!( + 0, + world.entity(id).get::>().unwrap().value, + "Y should have the default value" + ); + } + + #[test] + fn required_components_spawn_nonexistent_hooks() { + #[derive(Component)] + #[require(Y)] + struct X; + + #[derive(Component, Default)] + struct Y; + + #[derive(Resource)] + struct A(usize); + + #[derive(Resource)] + struct I(usize); + + let mut world = World::new(); + world.insert_resource(A(0)); + world.insert_resource(I(0)); + world + .register_component_hooks::() + .on_add(|mut world, _| world.resource_mut::().0 += 1) + .on_insert(|mut world, _| world.resource_mut::().0 += 1); + + // Spawn entity and ensure Y was added + assert!(world.spawn(X).contains::()); + + assert_eq!(world.resource::().0, 1); + assert_eq!(world.resource::().0, 1); + } + + #[test] + fn required_components_insert_existing_hooks() { + #[derive(Component)] + #[require(Y)] + struct X; + + #[derive(Component, Default)] + struct Y; + + #[derive(Resource)] + struct A(usize); + + #[derive(Resource)] + struct I(usize); + + let mut world = World::new(); + world.insert_resource(A(0)); + world.insert_resource(I(0)); + world + .register_component_hooks::() + .on_add(|mut world, _| world.resource_mut::().0 += 1) + .on_insert(|mut world, _| world.resource_mut::().0 += 1); + + // Spawn entity and ensure Y was added + assert!(world.spawn_empty().insert(X).contains::()); + + assert_eq!(world.resource::().0, 1); + assert_eq!(world.resource::().0, 1); + } + + #[test] + fn required_components_take_leaves_required() { + #[derive(Component)] + #[require(Y)] + struct X; + + #[derive(Component, Default)] + struct Y; + + let mut world = World::new(); + let e = world.spawn(X).id(); + let _ = world.entity_mut(e).take::().unwrap(); + assert!(world.entity_mut(e).contains::()); + } + + #[test] + fn required_components_retain_keeps_required() { + #[derive(Component)] + #[require(Y)] + struct X; + + #[derive(Component, Default)] + struct Y; + + #[derive(Component, Default)] + struct Z; + + let mut world = World::new(); + let e = world.spawn((X, Z)).id(); + world.entity_mut(e).retain::(); + assert!(world.entity_mut(e).contains::()); + assert!(world.entity_mut(e).contains::()); + assert!(!world.entity_mut(e).contains::()); + } + + #[test] + fn required_components_spawn_then_insert_no_overwrite() { + #[derive(Component)] + #[require(Y)] + struct X; + + #[derive(Component, Default)] + struct Y(usize); + + let mut world = World::new(); + let id = world.spawn((X, Y(10))).id(); + world.entity_mut(id).insert(X); + + assert_eq!( + 10, + world.entity(id).get::().unwrap().0, + "Y should still have the manually provided value" + ); + } + + #[test] + fn dynamic_required_components() { + #[derive(Component)] + #[require(Y)] + struct X; + + #[derive(Component, Default)] + struct Y; + + let mut world = World::new(); + let x_id = world.register_component::(); + + let mut e = world.spawn_empty(); + + // SAFETY: x_id is a valid component id + bevy_ptr::OwningPtr::make(X, |ptr| unsafe { + e.insert_by_id(x_id, ptr); + }); + + assert!(e.contains::()); + } + + #[test] + fn remove_component_and_its_runtime_required_components() { + #[derive(Component)] + struct X; + + #[derive(Component, Default)] + struct Y; + + #[derive(Component, Default)] + struct Z; + + #[derive(Component)] + struct V; + + let mut world = World::new(); + world.register_required_components::(); + world.register_required_components::(); + + let e = world.spawn((X, V)).id(); + assert!(world.entity(e).contains::()); + assert!(world.entity(e).contains::()); + assert!(world.entity(e).contains::()); + assert!(world.entity(e).contains::()); + + //check that `remove` works as expected + world.entity_mut(e).remove::(); + assert!(!world.entity(e).contains::()); + assert!(world.entity(e).contains::()); + assert!(world.entity(e).contains::()); + assert!(world.entity(e).contains::()); + + world.entity_mut(e).insert(X); + assert!(world.entity(e).contains::()); + assert!(world.entity(e).contains::()); + assert!(world.entity(e).contains::()); + assert!(world.entity(e).contains::()); + + //remove `X` again and ensure that `Y` and `Z` was removed too + world.entity_mut(e).remove_with_requires::(); + assert!(!world.entity(e).contains::()); + assert!(!world.entity(e).contains::()); + assert!(!world.entity(e).contains::()); + assert!(world.entity(e).contains::()); + } + + #[test] + fn remove_component_and_its_required_components() { + #[derive(Component)] + #[require(Y)] + struct X; + + #[derive(Component, Default)] + #[require(Z)] + struct Y; + + #[derive(Component, Default)] + struct Z; + + #[derive(Component)] + struct V; + + let mut world = World::new(); + + let e = world.spawn((X, V)).id(); + assert!(world.entity(e).contains::()); + assert!(world.entity(e).contains::()); + assert!(world.entity(e).contains::()); + assert!(world.entity(e).contains::()); + + //check that `remove` works as expected + world.entity_mut(e).remove::(); + assert!(!world.entity(e).contains::()); + assert!(world.entity(e).contains::()); + assert!(world.entity(e).contains::()); + assert!(world.entity(e).contains::()); + + world.entity_mut(e).insert(X); + assert!(world.entity(e).contains::()); + assert!(world.entity(e).contains::()); + assert!(world.entity(e).contains::()); + assert!(world.entity(e).contains::()); + + //remove `X` again and ensure that `Y` and `Z` was removed too + world.entity_mut(e).remove_with_requires::(); + assert!(!world.entity(e).contains::()); + assert!(!world.entity(e).contains::()); + assert!(!world.entity(e).contains::()); + assert!(world.entity(e).contains::()); + } + + #[test] + fn remove_bundle_and_his_required_components() { + #[derive(Component, Default)] + #[require(Y)] + struct X; + + #[derive(Component, Default)] + struct Y; + + #[derive(Component, Default)] + #[require(W)] + struct Z; + + #[derive(Component, Default)] + struct W; + + #[derive(Component)] + struct V; + + #[derive(Bundle, Default)] + struct TestBundle { + x: X, + z: Z, + } + + let mut world = World::new(); + let e = world.spawn((TestBundle::default(), V)).id(); + + assert!(world.entity(e).contains::()); + assert!(world.entity(e).contains::()); + assert!(world.entity(e).contains::()); + assert!(world.entity(e).contains::()); + assert!(world.entity(e).contains::()); + + world.entity_mut(e).remove_with_requires::(); + assert!(!world.entity(e).contains::()); + assert!(!world.entity(e).contains::()); + assert!(!world.entity(e).contains::()); + assert!(!world.entity(e).contains::()); + assert!(world.entity(e).contains::()); + } + + #[test] + fn runtime_required_components() { + // Same as `required_components` test but with runtime registration + + #[derive(Component)] + struct X; + + #[derive(Component)] + struct Y { + value: String, + } + + #[derive(Component)] + struct Z(u32); + + impl Default for Y { + fn default() -> Self { + Self { + value: "hello".to_string(), + } + } + } + + let mut world = World::new(); + + world.register_required_components::(); + world.register_required_components_with::(|| Z(7)); + + let id = world.spawn(X).id(); + + assert_eq!( + "hello", + world.entity(id).get::().unwrap().value, + "Y should have the default value" + ); + assert_eq!( + 7, + world.entity(id).get::().unwrap().0, + "Z should have the value provided by the constructor defined in Y" + ); + + let id = world + .spawn(( + X, + Y { + value: "foo".to_string(), + }, + )) + .id(); + assert_eq!( + "foo", + world.entity(id).get::().unwrap().value, + "Y should have the manually provided value" + ); + assert_eq!( + 7, + world.entity(id).get::().unwrap().0, + "Z should have the value provided by the constructor defined in Y" + ); + + let id = world.spawn((X, Z(8))).id(); + assert_eq!( + "hello", + world.entity(id).get::().unwrap().value, + "Y should have the default value" + ); + assert_eq!( + 8, + world.entity(id).get::().unwrap().0, + "Z should have the manually provided value" + ); + } + + #[test] + fn runtime_required_components_override_1() { + #[derive(Component)] + struct X; + + #[derive(Component, Default)] + struct Y; + + #[derive(Component)] + struct Z(u32); + + let mut world = World::new(); + + // - X requires Y with default constructor + // - Y requires Z with custom constructor + // - X requires Z with custom constructor (more specific than X -> Y -> Z) + world.register_required_components::(); + world.register_required_components_with::(|| Z(5)); + world.register_required_components_with::(|| Z(7)); + + let id = world.spawn(X).id(); + + assert_eq!( + 7, + world.entity(id).get::().unwrap().0, + "Z should have the value provided by the constructor defined in X" + ); + } + + #[test] + fn runtime_required_components_override_2() { + // Same as `runtime_required_components_override_1` test but with different registration order + + #[derive(Component)] + struct X; + + #[derive(Component, Default)] + struct Y; + + #[derive(Component)] + struct Z(u32); + + let mut world = World::new(); + + // - X requires Y with default constructor + // - X requires Z with custom constructor (more specific than X -> Y -> Z) + // - Y requires Z with custom constructor + world.register_required_components::(); + world.register_required_components_with::(|| Z(7)); + world.register_required_components_with::(|| Z(5)); + + let id = world.spawn(X).id(); + + assert_eq!( + 7, + world.entity(id).get::().unwrap().0, + "Z should have the value provided by the constructor defined in X" + ); + } + + #[test] + fn runtime_required_components_propagate_up() { + // `A` requires `B` directly. + #[derive(Component)] + #[require(B)] + struct A; + + #[derive(Component, Default)] + struct B; + + #[derive(Component, Default)] + struct C; + + let mut world = World::new(); + + // `B` requires `C` with a runtime registration. + // `A` should also require `C` because it requires `B`. + world.register_required_components::(); + + let id = world.spawn(A).id(); + + assert!(world.entity(id).get::().is_some()); + } + + #[test] + fn runtime_required_components_propagate_up_even_more() { + #[derive(Component)] + struct A; + + #[derive(Component, Default)] + struct B; + + #[derive(Component, Default)] + struct C; + + #[derive(Component, Default)] + struct D; + + let mut world = World::new(); + + world.register_required_components::(); + world.register_required_components::(); + world.register_required_components::(); + + let id = world.spawn(A).id(); + + assert!(world.entity(id).get::().is_some()); + } + + #[test] + fn runtime_required_components_deep_require_does_not_override_shallow_require() { + #[derive(Component)] + struct A; + #[derive(Component, Default)] + struct B; + #[derive(Component, Default)] + struct C; + #[derive(Component)] + struct Counter(i32); + #[derive(Component, Default)] + struct D; + + let mut world = World::new(); + + world.register_required_components::(); + world.register_required_components::(); + world.register_required_components::(); + world.register_required_components_with::(|| Counter(2)); + // This should replace the require constructor in A since it is + // shallower. + world.register_required_components_with::(|| Counter(1)); + + let id = world.spawn(A).id(); + + // The "shallower" of the two components is used. + assert_eq!(world.entity(id).get::().unwrap().0, 1); + } + + #[test] + fn runtime_required_components_deep_require_does_not_override_shallow_require_deep_subtree_after_shallow( + ) { + #[derive(Component)] + struct A; + #[derive(Component, Default)] + struct B; + #[derive(Component, Default)] + struct C; + #[derive(Component, Default)] + struct D; + #[derive(Component, Default)] + struct E; + #[derive(Component)] + struct Counter(i32); + #[derive(Component, Default)] + struct F; + + let mut world = World::new(); + + world.register_required_components::(); + world.register_required_components::(); + world.register_required_components::(); + world.register_required_components::(); + world.register_required_components_with::(|| Counter(1)); + world.register_required_components_with::(|| Counter(2)); + world.register_required_components::(); + + let id = world.spawn(A).id(); + + // The "shallower" of the two components is used. + assert_eq!(world.entity(id).get::().unwrap().0, 1); + } + + #[test] + fn runtime_required_components_existing_archetype() { + #[derive(Component)] + struct X; + + #[derive(Component, Default)] + struct Y; + + let mut world = World::new(); + + // Registering required components after the archetype has already been created should panic. + // This may change in the future. + world.spawn(X); + assert!(matches!( + world.try_register_required_components::(), + Err(RequiredComponentsError::ArchetypeExists(_)) + )); + } + + #[test] + fn runtime_required_components_fail_with_duplicate() { + #[derive(Component)] + #[require(Y)] + struct X; + + #[derive(Component, Default)] + struct Y; + + let mut world = World::new(); + + // This should fail: Tried to register Y as a requirement for X, but the requirement already exists. + assert!(matches!( + world.try_register_required_components::(), + Err(RequiredComponentsError::DuplicateRegistration(_, _)) + )); + } + + #[test] + fn required_components_inheritance_depth() { + // Test that inheritance depths are computed correctly for requirements. + // + // Requirements with `require` attribute: + // + // A -> B -> C + // 0 1 + // + // Runtime requirements: + // + // X -> A -> B -> C + // 0 1 2 + // + // X -> Y -> Z -> B -> C + // 0 1 2 3 + + #[derive(Component, Default)] + #[require(B)] + struct A; + + #[derive(Component, Default)] + #[require(C)] + struct B; + + #[derive(Component, Default)] + struct C; + + #[derive(Component, Default)] + struct X; + + #[derive(Component, Default)] + struct Y; + + #[derive(Component, Default)] + struct Z; + + let mut world = World::new(); + + let a = world.register_component::(); + let b = world.register_component::(); + let c = world.register_component::(); + let y = world.register_component::(); + let z = world.register_component::(); + + world.register_required_components::(); + world.register_required_components::(); + world.register_required_components::(); + world.register_required_components::(); + + world.spawn(X); + + let required_a = world.get_required_components::().unwrap(); + let required_b = world.get_required_components::().unwrap(); + let required_c = world.get_required_components::().unwrap(); + let required_x = world.get_required_components::().unwrap(); + let required_y = world.get_required_components::().unwrap(); + let required_z = world.get_required_components::().unwrap(); + + /// Returns the component IDs and inheritance depths of the required components + /// in ascending order based on the component ID. + fn to_vec(required: &RequiredComponents) -> Vec<(ComponentId, u16)> { + let mut vec = required + .0 + .iter() + .map(|(id, component)| (*id, component.inheritance_depth)) + .collect::>(); + vec.sort_by_key(|(id, _)| *id); + vec + } + + // Check that the inheritance depths are correct for each component. + assert_eq!(to_vec(required_a), vec![(b, 0), (c, 1)]); + assert_eq!(to_vec(required_b), vec![(c, 0)]); + assert_eq!(to_vec(required_c), vec![]); + assert_eq!( + to_vec(required_x), + vec![(a, 0), (b, 1), (c, 2), (y, 0), (z, 1)] + ); + assert_eq!(to_vec(required_y), vec![(b, 1), (c, 2), (z, 0)]); + assert_eq!(to_vec(required_z), vec![(b, 0), (c, 1)]); + } + + #[test] + fn required_components_inheritance_depth_bias() { + #[derive(Component, PartialEq, Eq, Clone, Copy, Debug)] + struct MyRequired(bool); + + #[derive(Component, Default)] + #[require(MyRequired(false))] + struct MiddleMan; + + #[derive(Component, Default)] + #[require(MiddleMan)] + struct ConflictingRequire; + + #[derive(Component, Default)] + #[require(MyRequired(true))] + struct MyComponent; + + let mut world = World::new(); + let order_a = world + .spawn((ConflictingRequire, MyComponent)) + .get::() + .cloned(); + let order_b = world + .spawn((MyComponent, ConflictingRequire)) + .get::() + .cloned(); + + assert_eq!(order_a, Some(MyRequired(true))); + assert_eq!(order_b, Some(MyRequired(true))); + } + + #[test] + #[should_panic] + fn required_components_recursion_errors() { + #[derive(Component, Default)] + #[require(B)] + struct A; + + #[derive(Component, Default)] + #[require(C)] + struct B; + + #[derive(Component, Default)] + #[require(B)] + struct C; + + World::new().register_component::(); + } + + #[test] + #[should_panic] + fn required_components_self_errors() { + #[derive(Component, Default)] + #[require(A)] + struct A; + + World::new().register_component::(); + } +} diff --git a/crates/bevy_ecs/src/lib.rs b/crates/bevy_ecs/src/lib.rs index 8a07cdc8e1b92..abb30e973dac5 100644 --- a/crates/bevy_ecs/src/lib.rs +++ b/crates/bevy_ecs/src/lib.rs @@ -148,7 +148,7 @@ mod tests { use crate::{ bundle::Bundle, change_detection::Ref, - component::{Component, ComponentId, RequiredComponents, RequiredComponentsError}, + component::{Component, ComponentId}, entity::{Entity, EntityMapper}, entity_disabling::DefaultQueryFilters, prelude::Or, @@ -156,12 +156,7 @@ mod tests { resource::Resource, world::{EntityMut, EntityRef, Mut, World}, }; - use alloc::{ - string::{String, ToString}, - sync::Arc, - vec, - vec::Vec, - }; + use alloc::{string::String, sync::Arc, vec, vec::Vec}; use bevy_platform::collections::HashSet; use bevy_tasks::{ComputeTaskPool, TaskPool}; use core::{ @@ -1830,791 +1825,6 @@ mod tests { ); } - #[test] - fn required_components() { - #[derive(Component)] - #[require(Y)] - struct X; - - #[derive(Component)] - #[require(Z = new_z())] - struct Y { - value: String, - } - - #[derive(Component)] - struct Z(u32); - - impl Default for Y { - fn default() -> Self { - Self { - value: "hello".to_string(), - } - } - } - - fn new_z() -> Z { - Z(7) - } - - let mut world = World::new(); - let id = world.spawn(X).id(); - assert_eq!( - "hello", - world.entity(id).get::().unwrap().value, - "Y should have the default value" - ); - assert_eq!( - 7, - world.entity(id).get::().unwrap().0, - "Z should have the value provided by the constructor defined in Y" - ); - - let id = world - .spawn(( - X, - Y { - value: "foo".to_string(), - }, - )) - .id(); - assert_eq!( - "foo", - world.entity(id).get::().unwrap().value, - "Y should have the manually provided value" - ); - assert_eq!( - 7, - world.entity(id).get::().unwrap().0, - "Z should have the value provided by the constructor defined in Y" - ); - - let id = world.spawn((X, Z(8))).id(); - assert_eq!( - "hello", - world.entity(id).get::().unwrap().value, - "Y should have the default value" - ); - assert_eq!( - 8, - world.entity(id).get::().unwrap().0, - "Z should have the manually provided value" - ); - } - - #[test] - fn generic_required_components() { - #[derive(Component)] - #[require(Y)] - struct X; - - #[derive(Component, Default)] - struct Y { - value: T, - } - - let mut world = World::new(); - let id = world.spawn(X).id(); - assert_eq!( - 0, - world.entity(id).get::>().unwrap().value, - "Y should have the default value" - ); - } - - #[test] - fn required_components_spawn_nonexistent_hooks() { - #[derive(Component)] - #[require(Y)] - struct X; - - #[derive(Component, Default)] - struct Y; - - #[derive(Resource)] - struct A(usize); - - #[derive(Resource)] - struct I(usize); - - let mut world = World::new(); - world.insert_resource(A(0)); - world.insert_resource(I(0)); - world - .register_component_hooks::() - .on_add(|mut world, _| world.resource_mut::().0 += 1) - .on_insert(|mut world, _| world.resource_mut::().0 += 1); - - // Spawn entity and ensure Y was added - assert!(world.spawn(X).contains::()); - - assert_eq!(world.resource::().0, 1); - assert_eq!(world.resource::().0, 1); - } - - #[test] - fn required_components_insert_existing_hooks() { - #[derive(Component)] - #[require(Y)] - struct X; - - #[derive(Component, Default)] - struct Y; - - #[derive(Resource)] - struct A(usize); - - #[derive(Resource)] - struct I(usize); - - let mut world = World::new(); - world.insert_resource(A(0)); - world.insert_resource(I(0)); - world - .register_component_hooks::() - .on_add(|mut world, _| world.resource_mut::().0 += 1) - .on_insert(|mut world, _| world.resource_mut::().0 += 1); - - // Spawn entity and ensure Y was added - assert!(world.spawn_empty().insert(X).contains::()); - - assert_eq!(world.resource::().0, 1); - assert_eq!(world.resource::().0, 1); - } - - #[test] - fn required_components_take_leaves_required() { - #[derive(Component)] - #[require(Y)] - struct X; - - #[derive(Component, Default)] - struct Y; - - let mut world = World::new(); - let e = world.spawn(X).id(); - let _ = world.entity_mut(e).take::().unwrap(); - assert!(world.entity_mut(e).contains::()); - } - - #[test] - fn required_components_retain_keeps_required() { - #[derive(Component)] - #[require(Y)] - struct X; - - #[derive(Component, Default)] - struct Y; - - #[derive(Component, Default)] - struct Z; - - let mut world = World::new(); - let e = world.spawn((X, Z)).id(); - world.entity_mut(e).retain::(); - assert!(world.entity_mut(e).contains::()); - assert!(world.entity_mut(e).contains::()); - assert!(!world.entity_mut(e).contains::()); - } - - #[test] - fn required_components_spawn_then_insert_no_overwrite() { - #[derive(Component)] - #[require(Y)] - struct X; - - #[derive(Component, Default)] - struct Y(usize); - - let mut world = World::new(); - let id = world.spawn((X, Y(10))).id(); - world.entity_mut(id).insert(X); - - assert_eq!( - 10, - world.entity(id).get::().unwrap().0, - "Y should still have the manually provided value" - ); - } - - #[test] - fn dynamic_required_components() { - #[derive(Component)] - #[require(Y)] - struct X; - - #[derive(Component, Default)] - struct Y; - - let mut world = World::new(); - let x_id = world.register_component::(); - - let mut e = world.spawn_empty(); - - // SAFETY: x_id is a valid component id - bevy_ptr::OwningPtr::make(X, |ptr| unsafe { - e.insert_by_id(x_id, ptr); - }); - - assert!(e.contains::()); - } - - #[test] - fn remove_component_and_its_runtime_required_components() { - #[derive(Component)] - struct X; - - #[derive(Component, Default)] - struct Y; - - #[derive(Component, Default)] - struct Z; - - #[derive(Component)] - struct V; - - let mut world = World::new(); - world.register_required_components::(); - world.register_required_components::(); - - let e = world.spawn((X, V)).id(); - assert!(world.entity(e).contains::()); - assert!(world.entity(e).contains::()); - assert!(world.entity(e).contains::()); - assert!(world.entity(e).contains::()); - - //check that `remove` works as expected - world.entity_mut(e).remove::(); - assert!(!world.entity(e).contains::()); - assert!(world.entity(e).contains::()); - assert!(world.entity(e).contains::()); - assert!(world.entity(e).contains::()); - - world.entity_mut(e).insert(X); - assert!(world.entity(e).contains::()); - assert!(world.entity(e).contains::()); - assert!(world.entity(e).contains::()); - assert!(world.entity(e).contains::()); - - //remove `X` again and ensure that `Y` and `Z` was removed too - world.entity_mut(e).remove_with_requires::(); - assert!(!world.entity(e).contains::()); - assert!(!world.entity(e).contains::()); - assert!(!world.entity(e).contains::()); - assert!(world.entity(e).contains::()); - } - - #[test] - fn remove_component_and_its_required_components() { - #[derive(Component)] - #[require(Y)] - struct X; - - #[derive(Component, Default)] - #[require(Z)] - struct Y; - - #[derive(Component, Default)] - struct Z; - - #[derive(Component)] - struct V; - - let mut world = World::new(); - - let e = world.spawn((X, V)).id(); - assert!(world.entity(e).contains::()); - assert!(world.entity(e).contains::()); - assert!(world.entity(e).contains::()); - assert!(world.entity(e).contains::()); - - //check that `remove` works as expected - world.entity_mut(e).remove::(); - assert!(!world.entity(e).contains::()); - assert!(world.entity(e).contains::()); - assert!(world.entity(e).contains::()); - assert!(world.entity(e).contains::()); - - world.entity_mut(e).insert(X); - assert!(world.entity(e).contains::()); - assert!(world.entity(e).contains::()); - assert!(world.entity(e).contains::()); - assert!(world.entity(e).contains::()); - - //remove `X` again and ensure that `Y` and `Z` was removed too - world.entity_mut(e).remove_with_requires::(); - assert!(!world.entity(e).contains::()); - assert!(!world.entity(e).contains::()); - assert!(!world.entity(e).contains::()); - assert!(world.entity(e).contains::()); - } - - #[test] - fn remove_bundle_and_his_required_components() { - #[derive(Component, Default)] - #[require(Y)] - struct X; - - #[derive(Component, Default)] - struct Y; - - #[derive(Component, Default)] - #[require(W)] - struct Z; - - #[derive(Component, Default)] - struct W; - - #[derive(Component)] - struct V; - - #[derive(Bundle, Default)] - struct TestBundle { - x: X, - z: Z, - } - - let mut world = World::new(); - let e = world.spawn((TestBundle::default(), V)).id(); - - assert!(world.entity(e).contains::()); - assert!(world.entity(e).contains::()); - assert!(world.entity(e).contains::()); - assert!(world.entity(e).contains::()); - assert!(world.entity(e).contains::()); - - world.entity_mut(e).remove_with_requires::(); - assert!(!world.entity(e).contains::()); - assert!(!world.entity(e).contains::()); - assert!(!world.entity(e).contains::()); - assert!(!world.entity(e).contains::()); - assert!(world.entity(e).contains::()); - } - - #[test] - fn runtime_required_components() { - // Same as `required_components` test but with runtime registration - - #[derive(Component)] - struct X; - - #[derive(Component)] - struct Y { - value: String, - } - - #[derive(Component)] - struct Z(u32); - - impl Default for Y { - fn default() -> Self { - Self { - value: "hello".to_string(), - } - } - } - - let mut world = World::new(); - - world.register_required_components::(); - world.register_required_components_with::(|| Z(7)); - - let id = world.spawn(X).id(); - - assert_eq!( - "hello", - world.entity(id).get::().unwrap().value, - "Y should have the default value" - ); - assert_eq!( - 7, - world.entity(id).get::().unwrap().0, - "Z should have the value provided by the constructor defined in Y" - ); - - let id = world - .spawn(( - X, - Y { - value: "foo".to_string(), - }, - )) - .id(); - assert_eq!( - "foo", - world.entity(id).get::().unwrap().value, - "Y should have the manually provided value" - ); - assert_eq!( - 7, - world.entity(id).get::().unwrap().0, - "Z should have the value provided by the constructor defined in Y" - ); - - let id = world.spawn((X, Z(8))).id(); - assert_eq!( - "hello", - world.entity(id).get::().unwrap().value, - "Y should have the default value" - ); - assert_eq!( - 8, - world.entity(id).get::().unwrap().0, - "Z should have the manually provided value" - ); - } - - #[test] - fn runtime_required_components_override_1() { - #[derive(Component)] - struct X; - - #[derive(Component, Default)] - struct Y; - - #[derive(Component)] - struct Z(u32); - - let mut world = World::new(); - - // - X requires Y with default constructor - // - Y requires Z with custom constructor - // - X requires Z with custom constructor (more specific than X -> Y -> Z) - world.register_required_components::(); - world.register_required_components_with::(|| Z(5)); - world.register_required_components_with::(|| Z(7)); - - let id = world.spawn(X).id(); - - assert_eq!( - 7, - world.entity(id).get::().unwrap().0, - "Z should have the value provided by the constructor defined in X" - ); - } - - #[test] - fn runtime_required_components_override_2() { - // Same as `runtime_required_components_override_1` test but with different registration order - - #[derive(Component)] - struct X; - - #[derive(Component, Default)] - struct Y; - - #[derive(Component)] - struct Z(u32); - - let mut world = World::new(); - - // - X requires Y with default constructor - // - X requires Z with custom constructor (more specific than X -> Y -> Z) - // - Y requires Z with custom constructor - world.register_required_components::(); - world.register_required_components_with::(|| Z(7)); - world.register_required_components_with::(|| Z(5)); - - let id = world.spawn(X).id(); - - assert_eq!( - 7, - world.entity(id).get::().unwrap().0, - "Z should have the value provided by the constructor defined in X" - ); - } - - #[test] - fn runtime_required_components_propagate_up() { - // `A` requires `B` directly. - #[derive(Component)] - #[require(B)] - struct A; - - #[derive(Component, Default)] - struct B; - - #[derive(Component, Default)] - struct C; - - let mut world = World::new(); - - // `B` requires `C` with a runtime registration. - // `A` should also require `C` because it requires `B`. - world.register_required_components::(); - - let id = world.spawn(A).id(); - - assert!(world.entity(id).get::().is_some()); - } - - #[test] - fn runtime_required_components_propagate_up_even_more() { - #[derive(Component)] - struct A; - - #[derive(Component, Default)] - struct B; - - #[derive(Component, Default)] - struct C; - - #[derive(Component, Default)] - struct D; - - let mut world = World::new(); - - world.register_required_components::(); - world.register_required_components::(); - world.register_required_components::(); - - let id = world.spawn(A).id(); - - assert!(world.entity(id).get::().is_some()); - } - - #[test] - fn runtime_required_components_deep_require_does_not_override_shallow_require() { - #[derive(Component)] - struct A; - #[derive(Component, Default)] - struct B; - #[derive(Component, Default)] - struct C; - #[derive(Component)] - struct Counter(i32); - #[derive(Component, Default)] - struct D; - - let mut world = World::new(); - - world.register_required_components::(); - world.register_required_components::(); - world.register_required_components::(); - world.register_required_components_with::(|| Counter(2)); - // This should replace the require constructor in A since it is - // shallower. - world.register_required_components_with::(|| Counter(1)); - - let id = world.spawn(A).id(); - - // The "shallower" of the two components is used. - assert_eq!(world.entity(id).get::().unwrap().0, 1); - } - - #[test] - fn runtime_required_components_deep_require_does_not_override_shallow_require_deep_subtree_after_shallow( - ) { - #[derive(Component)] - struct A; - #[derive(Component, Default)] - struct B; - #[derive(Component, Default)] - struct C; - #[derive(Component, Default)] - struct D; - #[derive(Component, Default)] - struct E; - #[derive(Component)] - struct Counter(i32); - #[derive(Component, Default)] - struct F; - - let mut world = World::new(); - - world.register_required_components::(); - world.register_required_components::(); - world.register_required_components::(); - world.register_required_components::(); - world.register_required_components_with::(|| Counter(1)); - world.register_required_components_with::(|| Counter(2)); - world.register_required_components::(); - - let id = world.spawn(A).id(); - - // The "shallower" of the two components is used. - assert_eq!(world.entity(id).get::().unwrap().0, 1); - } - - #[test] - fn runtime_required_components_existing_archetype() { - #[derive(Component)] - struct X; - - #[derive(Component, Default)] - struct Y; - - let mut world = World::new(); - - // Registering required components after the archetype has already been created should panic. - // This may change in the future. - world.spawn(X); - assert!(matches!( - world.try_register_required_components::(), - Err(RequiredComponentsError::ArchetypeExists(_)) - )); - } - - #[test] - fn runtime_required_components_fail_with_duplicate() { - #[derive(Component)] - #[require(Y)] - struct X; - - #[derive(Component, Default)] - struct Y; - - let mut world = World::new(); - - // This should fail: Tried to register Y as a requirement for X, but the requirement already exists. - assert!(matches!( - world.try_register_required_components::(), - Err(RequiredComponentsError::DuplicateRegistration(_, _)) - )); - } - - #[test] - fn required_components_inheritance_depth() { - // Test that inheritance depths are computed correctly for requirements. - // - // Requirements with `require` attribute: - // - // A -> B -> C - // 0 1 - // - // Runtime requirements: - // - // X -> A -> B -> C - // 0 1 2 - // - // X -> Y -> Z -> B -> C - // 0 1 2 3 - - #[derive(Component, Default)] - #[require(B)] - struct A; - - #[derive(Component, Default)] - #[require(C)] - struct B; - - #[derive(Component, Default)] - struct C; - - #[derive(Component, Default)] - struct X; - - #[derive(Component, Default)] - struct Y; - - #[derive(Component, Default)] - struct Z; - - let mut world = World::new(); - - let a = world.register_component::(); - let b = world.register_component::(); - let c = world.register_component::(); - let y = world.register_component::(); - let z = world.register_component::(); - - world.register_required_components::(); - world.register_required_components::(); - world.register_required_components::(); - world.register_required_components::(); - - world.spawn(X); - - let required_a = world.get_required_components::().unwrap(); - let required_b = world.get_required_components::().unwrap(); - let required_c = world.get_required_components::().unwrap(); - let required_x = world.get_required_components::().unwrap(); - let required_y = world.get_required_components::().unwrap(); - let required_z = world.get_required_components::().unwrap(); - - /// Returns the component IDs and inheritance depths of the required components - /// in ascending order based on the component ID. - fn to_vec(required: &RequiredComponents) -> Vec<(ComponentId, u16)> { - let mut vec = required - .0 - .iter() - .map(|(id, component)| (*id, component.inheritance_depth)) - .collect::>(); - vec.sort_by_key(|(id, _)| *id); - vec - } - - // Check that the inheritance depths are correct for each component. - assert_eq!(to_vec(required_a), vec![(b, 0), (c, 1)]); - assert_eq!(to_vec(required_b), vec![(c, 0)]); - assert_eq!(to_vec(required_c), vec![]); - assert_eq!( - to_vec(required_x), - vec![(a, 0), (b, 1), (c, 2), (y, 0), (z, 1)] - ); - assert_eq!(to_vec(required_y), vec![(b, 1), (c, 2), (z, 0)]); - assert_eq!(to_vec(required_z), vec![(b, 0), (c, 1)]); - } - - #[test] - fn required_components_inheritance_depth_bias() { - #[derive(Component, PartialEq, Eq, Clone, Copy, Debug)] - struct MyRequired(bool); - - #[derive(Component, Default)] - #[require(MyRequired(false))] - struct MiddleMan; - - #[derive(Component, Default)] - #[require(MiddleMan)] - struct ConflictingRequire; - - #[derive(Component, Default)] - #[require(MyRequired(true))] - struct MyComponent; - - let mut world = World::new(); - let order_a = world - .spawn((ConflictingRequire, MyComponent)) - .get::() - .cloned(); - let order_b = world - .spawn((MyComponent, ConflictingRequire)) - .get::() - .cloned(); - - assert_eq!(order_a, Some(MyRequired(true))); - assert_eq!(order_b, Some(MyRequired(true))); - } - - #[test] - #[should_panic] - fn required_components_recursion_errors() { - #[derive(Component, Default)] - #[require(B)] - struct A; - - #[derive(Component, Default)] - #[require(C)] - struct B; - - #[derive(Component, Default)] - #[require(B)] - struct C; - - World::new().register_component::(); - } - - #[test] - #[should_panic] - fn required_components_self_errors() { - #[derive(Component, Default)] - #[require(A)] - struct A; - - World::new().register_component::(); - } - #[derive(Default)] struct CaptureMapper(Vec); impl EntityMapper for CaptureMapper { From 2d404435c910929abe102f74081812a50ff4a044 Mon Sep 17 00:00:00 2001 From: Giacomo Stevanato Date: Sun, 13 Jul 2025 15:39:19 +0200 Subject: [PATCH 4/9] Remove and/or adapt invalid tests --- crates/bevy_ecs/src/component/required.rs | 95 +---------------------- 1 file changed, 4 insertions(+), 91 deletions(-) diff --git a/crates/bevy_ecs/src/component/required.rs b/crates/bevy_ecs/src/component/required.rs index e7e986d9090df..4193f185eba97 100644 --- a/crates/bevy_ecs/src/component/required.rs +++ b/crates/bevy_ecs/src/component/required.rs @@ -537,15 +537,11 @@ pub(super) fn enforce_no_required_components_recursion( #[cfg(test)] mod tests { - use std::{ - string::{String, ToString}, - vec, - vec::Vec, - }; + use std::string::{String, ToString}; use crate::{ bundle::Bundle, - component::{Component, ComponentId, RequiredComponents, RequiredComponentsError}, + component::{Component, RequiredComponentsError}, prelude::Resource, world::World, }; @@ -1194,90 +1190,7 @@ mod tests { } #[test] - fn required_components_inheritance_depth() { - // Test that inheritance depths are computed correctly for requirements. - // - // Requirements with `require` attribute: - // - // A -> B -> C - // 0 1 - // - // Runtime requirements: - // - // X -> A -> B -> C - // 0 1 2 - // - // X -> Y -> Z -> B -> C - // 0 1 2 3 - - #[derive(Component, Default)] - #[require(B)] - struct A; - - #[derive(Component, Default)] - #[require(C)] - struct B; - - #[derive(Component, Default)] - struct C; - - #[derive(Component, Default)] - struct X; - - #[derive(Component, Default)] - struct Y; - - #[derive(Component, Default)] - struct Z; - - let mut world = World::new(); - - let a = world.register_component::(); - let b = world.register_component::(); - let c = world.register_component::(); - let y = world.register_component::(); - let z = world.register_component::(); - - world.register_required_components::(); - world.register_required_components::(); - world.register_required_components::(); - world.register_required_components::(); - - world.spawn(X); - - let required_a = world.get_required_components::().unwrap(); - let required_b = world.get_required_components::().unwrap(); - let required_c = world.get_required_components::().unwrap(); - let required_x = world.get_required_components::().unwrap(); - let required_y = world.get_required_components::().unwrap(); - let required_z = world.get_required_components::().unwrap(); - - /// Returns the component IDs and inheritance depths of the required components - /// in ascending order based on the component ID. - fn to_vec(required: &RequiredComponents) -> Vec<(ComponentId, u16)> { - let mut vec = required - .0 - .iter() - .map(|(id, component)| (*id, component.inheritance_depth)) - .collect::>(); - vec.sort_by_key(|(id, _)| *id); - vec - } - - // Check that the inheritance depths are correct for each component. - assert_eq!(to_vec(required_a), vec![(b, 0), (c, 1)]); - assert_eq!(to_vec(required_b), vec![(c, 0)]); - assert_eq!(to_vec(required_c), vec![]); - assert_eq!( - to_vec(required_x), - vec![(a, 0), (b, 1), (c, 2), (y, 0), (z, 1)] - ); - assert_eq!(to_vec(required_y), vec![(b, 1), (c, 2), (z, 0)]); - assert_eq!(to_vec(required_z), vec![(b, 0), (c, 1)]); - } - - #[test] - fn required_components_inheritance_depth_bias() { + fn required_components_bundle_priority() { #[derive(Component, PartialEq, Eq, Clone, Copy, Debug)] struct MyRequired(bool); @@ -1303,7 +1216,7 @@ mod tests { .get::() .cloned(); - assert_eq!(order_a, Some(MyRequired(true))); + assert_eq!(order_a, Some(MyRequired(false))); assert_eq!(order_b, Some(MyRequired(true))); } From 229a89714b779a4dc1f349f819aebeed06150f60 Mon Sep 17 00:00:00 2001 From: Giacomo Stevanato Date: Sun, 13 Jul 2025 15:54:49 +0200 Subject: [PATCH 5/9] Add migration guide --- .../migration-guides/required_components_rework.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 release-content/migration-guides/required_components_rework.md diff --git a/release-content/migration-guides/required_components_rework.md b/release-content/migration-guides/required_components_rework.md new file mode 100644 index 0000000000000..db5c55abb4eef --- /dev/null +++ b/release-content/migration-guides/required_components_rework.md @@ -0,0 +1,13 @@ +--- +title: Required components refactor +pull_requests: [20110] +--- + +The required components feature has been reworked to be more consistent around the priority of the required components and fix some soundness issues. In particular: + +- the priority of required components will now always follow a priority given by the depth-first/preorder traversal of the dependency tree. This was mostly the case before with a couple of exceptions that we are now fixing: + - when deriving the `Component` trait, sometimes required components at depth 1 had priority over components at depth 2 even if they came after in the depth-first ordering; + - registering runtime required components followed a breadth-first ordering and used the wrong inheritance depth for derived required components. +- uses of the inheritance depth were removed from the `RequiredComponent` struct and from the methods for registering runtime required components, as it's not unused for the depth-first ordering; +- `Component::register_required_components`, `RequiredComponents::register` and `RequiredComponents::register_by_id` are now `unsafe`; +- `RequiredComponentConstructor`'s only field is now private for safety reasons. From 87bd81b8abc320b6bd826df3ba336a4b67191fec Mon Sep 17 00:00:00 2001 From: Giacomo Stevanato Date: Sun, 13 Jul 2025 16:27:07 +0200 Subject: [PATCH 6/9] Add regression test --- crates/bevy_ecs/src/component/required.rs | 27 +++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/crates/bevy_ecs/src/component/required.rs b/crates/bevy_ecs/src/component/required.rs index 4193f185eba97..37ec81021f9d0 100644 --- a/crates/bevy_ecs/src/component/required.rs +++ b/crates/bevy_ecs/src/component/required.rs @@ -1247,4 +1247,31 @@ mod tests { World::new().register_component::(); } + + #[test] + fn regression_19333() { + #[derive(Component)] + struct X(bool); + + #[derive(Default, Component)] + #[require(X(false))] + struct Base; + + #[derive(Default, Component)] + #[require(X(true), Base)] + struct A; + + #[derive(Default, Component)] + #[require(A, Base)] + struct B; + + #[derive(Default, Component)] + #[require(B, Base)] + struct C; + + let mut w = World::new(); + + assert_eq!(w.spawn(B).get::().unwrap().0, true); + assert_eq!(w.spawn(C).get::().unwrap().0, true); + } } From 18c4ca405e9d22b2663827e7a611828f2a25ae10 Mon Sep 17 00:00:00 2001 From: Giacomo Stevanato Date: Sun, 13 Jul 2025 16:22:30 +0200 Subject: [PATCH 7/9] Fix CI issues --- crates/bevy_ecs/src/bundle/info.rs | 14 +++++-- crates/bevy_ecs/src/component/info.rs | 11 ++++-- crates/bevy_ecs/src/component/required.rs | 45 ++++++++++++----------- 3 files changed, 41 insertions(+), 29 deletions(-) diff --git a/crates/bevy_ecs/src/bundle/info.rs b/crates/bevy_ecs/src/bundle/info.rs index 3ac341fd73b3c..b92c2bb1a96f9 100644 --- a/crates/bevy_ecs/src/bundle/info.rs +++ b/crates/bevy_ecs/src/bundle/info.rs @@ -1,5 +1,8 @@ use alloc::{boxed::Box, vec, vec::Vec}; -use bevy_platform::collections::{HashMap, HashSet}; +use bevy_platform::{ + collections::{HashMap, HashSet}, + hash::FixedHasher, +}; use bevy_ptr::OwningPtr; use bevy_utils::TypeIdMap; use core::{any::TypeId, ptr::NonNull}; @@ -89,7 +92,10 @@ impl BundleInfo { mut component_ids: Vec, id: BundleId, ) -> BundleInfo { - let explicit_component_ids = component_ids.iter().copied().collect::>(); + let explicit_component_ids = component_ids + .iter() + .copied() + .collect::>(); // check for duplicates if explicit_component_ids.len() != component_ids.len() { @@ -113,7 +119,7 @@ impl BundleInfo { panic!("Bundle {bundle_type_name} has duplicate components: {names:?}"); } - let mut depth_first_components = IndexMap::new(); + let mut depth_first_components = IndexMap::<_, _, FixedHasher>::default(); for &component_id in &component_ids { // SAFETY: caller has verified that all ids are valid let info = unsafe { components.get_info_unchecked(component_id) }; @@ -155,7 +161,7 @@ impl BundleInfo { self.id } - /// Returns the length of the explicit components part of the [contributed_components](Self::contributed_components) list. + /// Returns the length of the explicit components part of the [`contributed_components`](Self::contributed_components) list. pub(super) fn explicit_components_len(&self) -> usize { self.contributed_components.len() - self.required_component_constructors.len() } diff --git a/crates/bevy_ecs/src/component/info.rs b/crates/bevy_ecs/src/component/info.rs index 9deeac9b4d28b..09e5ed948ac12 100644 --- a/crates/bevy_ecs/src/component/info.rs +++ b/crates/bevy_ecs/src/component/info.rs @@ -1,5 +1,5 @@ use alloc::{borrow::Cow, vec::Vec}; -use bevy_platform::sync::PoisonError; +use bevy_platform::{hash::FixedHasher, sync::PoisonError}; use bevy_ptr::OwningPtr; #[cfg(feature = "bevy_reflect")] use bevy_reflect::Reflect; @@ -34,7 +34,7 @@ pub struct ComponentInfo { /// The set of components that require this components. /// Invariant: this is stored in a depth-first order, that is components are stored after the components /// that they depend on. - pub(super) required_by: IndexSet, + pub(super) required_by: IndexSet, } impl ComponentInfo { @@ -527,7 +527,10 @@ impl Components { } #[inline] - pub(crate) fn get_required_by(&self, id: ComponentId) -> Option<&IndexSet> { + pub(crate) fn get_required_by( + &self, + id: ComponentId, + ) -> Option<&IndexSet> { self.components .get(id.0) .and_then(|info| info.as_ref().map(|info| &info.required_by)) @@ -537,7 +540,7 @@ impl Components { pub(crate) fn get_required_by_mut( &mut self, id: ComponentId, - ) -> Option<&mut IndexSet> { + ) -> Option<&mut IndexSet> { self.components .get_mut(id.0) .and_then(|info| info.as_mut().map(|info| &mut info.required_by)) diff --git a/crates/bevy_ecs/src/component/required.rs b/crates/bevy_ecs/src/component/required.rs index 37ec81021f9d0..0a0eef7d687be 100644 --- a/crates/bevy_ecs/src/component/required.rs +++ b/crates/bevy_ecs/src/component/required.rs @@ -1,5 +1,5 @@ use alloc::{format, vec::Vec}; -use bevy_platform::sync::Arc; +use bevy_platform::{hash::FixedHasher, sync::Arc}; use bevy_ptr::OwningPtr; use core::fmt::Debug; use indexmap::{IndexMap, IndexSet}; @@ -119,7 +119,7 @@ pub struct RequiredComponents { /// /// # Safety /// The [`RequiredComponent`] instance associated to each ID must be valid for its component. - pub(crate) direct: IndexMap, + pub(crate) direct: IndexMap, /// All the components that are required (i.e. including inherited ones), in depth-first order. Most importantly, /// components in this list always appear after all the components that they require. /// @@ -128,7 +128,7 @@ pub struct RequiredComponents { /// /// # Safety /// The [`RequiredComponent`] instance associated to each ID must be valid for its component. - pub(crate) all: IndexMap, + pub(crate) all: IndexMap, } impl Debug for RequiredComponents { @@ -149,7 +149,7 @@ impl RequiredComponents { /// /// # Safety /// - /// - all other components in this [`RequiredComponents`] instance must have been registrated in `components`. + /// - all other components in this [`RequiredComponents`] instance must have been registered in `components`. pub unsafe fn register( &mut self, components: &mut ComponentsRegistrator<'_>, @@ -171,7 +171,7 @@ impl RequiredComponents { /// # Safety /// /// - `component_id` must be a valid component in `components` for the type `C`; - /// - all other components in this [`RequiredComponents`] instance must have been registrated in `components`. + /// - all other components in this [`RequiredComponents`] instance must have been registered in `components`. pub unsafe fn register_by_id( &mut self, component_id: ComponentId, @@ -198,7 +198,7 @@ impl RequiredComponents { /// # Safety /// /// - `component_id` must be a valid component in `components`; - /// - all other components in this [`RequiredComponents`] instance must have been registrated in `components`; + /// - all other components in `self` must have been registered in `components`; /// - `constructor` must return a [`RequiredComponentConstructor`] that constructs a valid instance for the /// component with ID `component_id`. pub unsafe fn register_dynamic_with( @@ -219,14 +219,17 @@ impl RequiredComponents { entry.insert(required_component.clone()); // Register inherited required components. + // SAFETY: + // - the caller guarantees all components that were in `self` have been registered in `components`; + // - `component_id` has just been added, but is also guaranteed by the called to be valid in `components`. unsafe { Self::register_inherited_required_components_unchecked( &mut self.all, component_id, required_component, components, - ) - }; + ); + } true } @@ -235,7 +238,7 @@ impl RequiredComponents { /// /// # Safety /// - /// - all components in this [`RequiredComponents`] instance must have been registrated in `components`. + /// - all components in `self` must have been registered in `components`. unsafe fn rebuild_inherited_required_components(&mut self, components: &Components) { // Clear `all`, we are re-initializing it. self.all.clear(); @@ -252,7 +255,7 @@ impl RequiredComponents { required_id, required_component.clone(), components, - ) + ); } } } @@ -265,7 +268,7 @@ impl RequiredComponents { /// - `required_id` must have been registered in `components`; /// - `required_component` must hold a valid constructor for the component with id `required_id`. unsafe fn register_inherited_required_components_unchecked( - all: &mut IndexMap, + all: &mut IndexMap, required_id: ComponentId, required_component: RequiredComponent, components: &Components, @@ -281,7 +284,7 @@ impl RequiredComponents { for (&inherited_id, inherited_required) in &info.required_components().all { // This is an inherited required component: insert it only if not already present. // By the invariants of `RequiredComponents`, `info.required_components().all` holds the required - // components in a depth-first order, and this makes us store teh components in `self.all` also + // components in a depth-first order, and this makes us store the components in `self.all` also // in depth-first order, as long as we don't overwrite existing ones. // // SAFETY: @@ -407,7 +410,7 @@ impl Components { let new_required_components = required_components.all[old_required_count..] .keys() .copied() - .collect::>(); + .collect::>(); // Get all the new requiree components, i.e. `requiree` and all the components that `requiree` is required by. // SAFETY: The caller ensures that the `requiree` is valid. @@ -426,7 +429,7 @@ impl Components { for &indirect_requiree in &new_requiree_components { // Extract the required components to avoid conflicting borrows. Remember to put this back before continuing! // SAFETY: `indirect_requiree` comes from `self`, so it must be valid. - let mut required_components = std::mem::take(unsafe { + let mut required_components = core::mem::take(unsafe { self.get_required_components_mut(indirect_requiree) .debug_checked_unwrap() }); @@ -474,7 +477,7 @@ impl Components { ) { // Extract the required components to avoid conflicting borrows. Remember to put this back before returning! // SAFETY: The caller ensures that the `requiree` is valid. - let mut required_components = std::mem::take(unsafe { + let mut required_components = core::mem::take(unsafe { self.get_required_components_mut(requiree) .debug_checked_unwrap() }); @@ -537,7 +540,7 @@ pub(super) fn enforce_no_required_components_recursion( #[cfg(test)] mod tests { - use std::string::{String, ToString}; + use alloc::string::{String, ToString}; use crate::{ bundle::Bundle, @@ -1251,14 +1254,14 @@ mod tests { #[test] fn regression_19333() { #[derive(Component)] - struct X(bool); + struct X(usize); #[derive(Default, Component)] - #[require(X(false))] + #[require(X(0))] struct Base; #[derive(Default, Component)] - #[require(X(true), Base)] + #[require(X(1), Base)] struct A; #[derive(Default, Component)] @@ -1271,7 +1274,7 @@ mod tests { let mut w = World::new(); - assert_eq!(w.spawn(B).get::().unwrap().0, true); - assert_eq!(w.spawn(C).get::().unwrap().0, true); + assert_eq!(w.spawn(B).get::().unwrap().0, 1); + assert_eq!(w.spawn(C).get::().unwrap().0, 1); } } From 156f46e123a9ebdd596172dac5a9ddae98c449c9 Mon Sep 17 00:00:00 2001 From: Giacomo Stevanato Date: Tue, 15 Jul 2025 23:01:29 +0200 Subject: [PATCH 8/9] Nit: merge two impl blocks --- crates/bevy_ecs/src/component/required.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/crates/bevy_ecs/src/component/required.rs b/crates/bevy_ecs/src/component/required.rs index 0a0eef7d687be..3e0ba5b7d1d84 100644 --- a/crates/bevy_ecs/src/component/required.rs +++ b/crates/bevy_ecs/src/component/required.rs @@ -328,9 +328,7 @@ impl Components { required_by.insert(requiree); } } -} -impl Components { /// Registers the given component `R` and [required components] inherited from it as required by `T`. /// /// When `T` is added to an entity, `R` will also be added if it was not already provided. From a5c392180b0c1df42865221b0338bd34644c5e54 Mon Sep 17 00:00:00 2001 From: Giacomo Stevanato Date: Thu, 17 Jul 2025 18:56:12 +0200 Subject: [PATCH 9/9] Fix soundness issue and add more migration docs --- crates/bevy_ecs/macros/src/component.rs | 8 ++--- crates/bevy_ecs/src/component/mod.rs | 5 ++-- crates/bevy_ecs/src/component/register.rs | 9 +++--- crates/bevy_ecs/src/component/required.rs | 30 +++++++++++++++++++ .../required_components_rework.md | 2 ++ 5 files changed, 42 insertions(+), 12 deletions(-) diff --git a/crates/bevy_ecs/macros/src/component.rs b/crates/bevy_ecs/macros/src/component.rs index c6d530dc95374..743c16ca3e75e 100644 --- a/crates/bevy_ecs/macros/src/component.rs +++ b/crates/bevy_ecs/macros/src/component.rs @@ -245,8 +245,7 @@ pub fn derive_component(input: TokenStream) -> TokenStream { None => quote! { <#ident as Default>::default }, }; register_required.push(quote! { - // SAFETY: we registered all components with the same instance of components. - unsafe { required_components.register::<#ident>(components, #constructor) }; + required_components.register_required::<#ident>(#constructor); }); } } @@ -287,10 +286,9 @@ pub fn derive_component(input: TokenStream) -> TokenStream { impl #impl_generics #bevy_ecs_path::component::Component for #struct_name #type_generics #where_clause { const STORAGE_TYPE: #bevy_ecs_path::component::StorageType = #storage; type Mutability = #mutable_type; - unsafe fn register_required_components( + fn register_required_components( _requiree: #bevy_ecs_path::component::ComponentId, - components: &mut #bevy_ecs_path::component::ComponentsRegistrator, - required_components: &mut #bevy_ecs_path::component::RequiredComponents, + required_components: &mut #bevy_ecs_path::component::RequiredComponentsRegistrator, ) { #(#register_required)* } diff --git a/crates/bevy_ecs/src/component/mod.rs b/crates/bevy_ecs/src/component/mod.rs index eeb4b2fdf85d0..1d808f0e1d1ee 100644 --- a/crates/bevy_ecs/src/component/mod.rs +++ b/crates/bevy_ecs/src/component/mod.rs @@ -526,10 +526,9 @@ pub trait Component: Send + Sync + 'static { /// # Safety /// /// - `_required_components` must only contain components valid in `_components`. - unsafe fn register_required_components( + fn register_required_components( _component_id: ComponentId, - _components: &mut ComponentsRegistrator, - _required_components: &mut RequiredComponents, + _required_components: &mut RequiredComponentsRegistrator, ) { } diff --git a/crates/bevy_ecs/src/component/register.rs b/crates/bevy_ecs/src/component/register.rs index 6c4efae0c2b93..4fde7a639082a 100644 --- a/crates/bevy_ecs/src/component/register.rs +++ b/crates/bevy_ecs/src/component/register.rs @@ -5,7 +5,7 @@ use core::any::Any; use core::ops::DerefMut; use core::{any::TypeId, fmt::Debug, ops::Deref}; -use crate::component::enforce_no_required_components_recursion; +use crate::component::{enforce_no_required_components_recursion, RequiredComponentsRegistrator}; use crate::query::DebugCheckedUnwrap as _; use crate::{ component::{ @@ -232,11 +232,12 @@ impl<'w> ComponentsRegistrator<'w> { self.recursion_check_stack.push(id); let mut required_components = RequiredComponents::default(); // SAFETY: `required_components` is empty - unsafe { T::register_required_components(id, self, &mut required_components) }; + let mut required_components_registrator = + unsafe { RequiredComponentsRegistrator::new(self, &mut required_components) }; + T::register_required_components(id, &mut required_components_registrator); // SAFETY: // - `id` was just registered in `self` - // - `register_required_components` have been given `self` to register components in - // (TODO: this is not really true... but the alternative would be making `Component` `unsafe`...) + // - RequiredComponentsRegistrator guarantees that only components from `self` are included in `required_components`. unsafe { self.register_required_by(id, &required_components) }; self.recursion_check_stack.pop(); diff --git a/crates/bevy_ecs/src/component/required.rs b/crates/bevy_ecs/src/component/required.rs index 3e0ba5b7d1d84..14d295cb26eba 100644 --- a/crates/bevy_ecs/src/component/required.rs +++ b/crates/bevy_ecs/src/component/required.rs @@ -536,6 +536,36 @@ pub(super) fn enforce_no_required_components_recursion( } } +/// This is a safe handle around `ComponentsRegistrator` and `RequiredComponents` to register required components. +pub struct RequiredComponentsRegistrator<'a, 'w> { + components: &'a mut ComponentsRegistrator<'w>, + required_components: &'a mut RequiredComponents, +} + +impl<'a, 'w> RequiredComponentsRegistrator<'a, 'w> { + /// # Safety + /// + /// All components in `required_components` must have been registered in `components` + pub(super) unsafe fn new( + components: &'a mut ComponentsRegistrator<'w>, + required_components: &'a mut RequiredComponents, + ) -> Self { + Self { + components, + required_components, + } + } + + /// Register `C` as a required component. + pub fn register_required(&mut self, constructor: fn() -> C) { + // SAFETY: + unsafe { + self.required_components + .register(self.components, constructor); + } + } +} + #[cfg(test)] mod tests { use alloc::string::{String, ToString}; diff --git a/release-content/migration-guides/required_components_rework.md b/release-content/migration-guides/required_components_rework.md index db5c55abb4eef..c023143e4dd28 100644 --- a/release-content/migration-guides/required_components_rework.md +++ b/release-content/migration-guides/required_components_rework.md @@ -11,3 +11,5 @@ The required components feature has been reworked to be more consistent around t - uses of the inheritance depth were removed from the `RequiredComponent` struct and from the methods for registering runtime required components, as it's not unused for the depth-first ordering; - `Component::register_required_components`, `RequiredComponents::register` and `RequiredComponents::register_by_id` are now `unsafe`; - `RequiredComponentConstructor`'s only field is now private for safety reasons. + +The `Component::register_required_components` method has also changed signature. It now takes the `ComponentId` of the component currently being registered and a single other parameter `RequiredComponentsRegistrator` which combines the old `components` and `required_components` parameters, since exposing both of them was unsound. As previously discussed the `inheritance_depth` is now useless and has been removed, while the `recursion_check_stack` has been moved into `ComponentsRegistrator` and will be handled automatically.