From c56efaedfa0a28d842365c6c98a11af591eab1a5 Mon Sep 17 00:00:00 2001 From: lcnr Date: Fri, 23 May 2025 14:46:38 +0000 Subject: [PATCH 1/2] add additional `TypeFlags` fast paths --- .../src/infer/canonical/canonicalizer.rs | 4 +++ compiler/rustc_infer/src/infer/resolve.rs | 8 ++++++ compiler/rustc_middle/src/ty/erase_regions.rs | 8 ++++++ compiler/rustc_middle/src/ty/fold.rs | 4 +++ compiler/rustc_middle/src/ty/predicate.rs | 2 ++ .../rustc_middle/src/ty/structural_impls.rs | 27 ++++++++++++++++++- .../src/canonicalizer.rs | 4 +++ .../rustc_next_trait_solver/src/resolve.rs | 20 ++++++++++++-- .../src/solve/eval_ctxt/canonical.rs | 11 ++++---- .../src/solve/eval_ctxt/mod.rs | 16 +++++++++++ .../src/solve/inspect/analyse.rs | 9 +++---- compiler/rustc_type_ir/src/binder.rs | 8 ++++++ compiler/rustc_type_ir/src/fold.rs | 8 ++++++ compiler/rustc_type_ir/src/inherent.rs | 12 +++++++++ compiler/rustc_type_ir/src/interner.rs | 4 +-- compiler/rustc_type_ir/src/visit.rs | 4 +-- 16 files changed, 131 insertions(+), 18 deletions(-) diff --git a/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs b/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs index 0b543f091f730..060447ba72068 100644 --- a/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs +++ b/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs @@ -497,6 +497,10 @@ impl<'cx, 'tcx> TypeFolder> for Canonicalizer<'cx, 'tcx> { fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> { if p.flags().intersects(self.needs_canonical_flags) { p.super_fold_with(self) } else { p } } + + fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> { + if c.flags().intersects(self.needs_canonical_flags) { c.super_fold_with(self) } else { c } + } } impl<'cx, 'tcx> Canonicalizer<'cx, 'tcx> { diff --git a/compiler/rustc_infer/src/infer/resolve.rs b/compiler/rustc_infer/src/infer/resolve.rs index 4b0ace8c554d6..a95f24b5b95d0 100644 --- a/compiler/rustc_infer/src/infer/resolve.rs +++ b/compiler/rustc_infer/src/infer/resolve.rs @@ -55,6 +55,14 @@ impl<'a, 'tcx> TypeFolder> for OpportunisticVarResolver<'a, 'tcx> { ct.super_fold_with(self) } } + + fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> { + if !p.has_non_region_infer() { p } else { p.super_fold_with(self) } + } + + fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> { + if !c.has_non_region_infer() { c } else { c.super_fold_with(self) } + } } /// The opportunistic region resolver opportunistically resolves regions diff --git a/compiler/rustc_middle/src/ty/erase_regions.rs b/compiler/rustc_middle/src/ty/erase_regions.rs index 45a0b1288db87..f4fead7e9526d 100644 --- a/compiler/rustc_middle/src/ty/erase_regions.rs +++ b/compiler/rustc_middle/src/ty/erase_regions.rs @@ -86,4 +86,12 @@ impl<'tcx> TypeFolder> for RegionEraserVisitor<'tcx> { p } } + + fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> { + if c.has_type_flags(TypeFlags::HAS_BINDER_VARS | TypeFlags::HAS_FREE_REGIONS) { + c.super_fold_with(self) + } else { + c + } + } } diff --git a/compiler/rustc_middle/src/ty/fold.rs b/compiler/rustc_middle/src/ty/fold.rs index 8d6871d2f1fee..b2057fa36d7fc 100644 --- a/compiler/rustc_middle/src/ty/fold.rs +++ b/compiler/rustc_middle/src/ty/fold.rs @@ -177,6 +177,10 @@ where fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> { if p.has_vars_bound_at_or_above(self.current_index) { p.super_fold_with(self) } else { p } } + + fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> { + if c.has_vars_bound_at_or_above(self.current_index) { c.super_fold_with(self) } else { c } + } } impl<'tcx> TyCtxt<'tcx> { diff --git a/compiler/rustc_middle/src/ty/predicate.rs b/compiler/rustc_middle/src/ty/predicate.rs index 551d816941b6e..bc2ac42b6b1f8 100644 --- a/compiler/rustc_middle/src/ty/predicate.rs +++ b/compiler/rustc_middle/src/ty/predicate.rs @@ -238,6 +238,8 @@ impl<'tcx> Clause<'tcx> { } } +impl<'tcx> rustc_type_ir::inherent::Clauses> for ty::Clauses<'tcx> {} + #[extension(pub trait ExistentialPredicateStableCmpExt<'tcx>)] impl<'tcx> ExistentialPredicate<'tcx> { /// Compares via an ordering that will not change if modules are reordered or other changes are diff --git a/compiler/rustc_middle/src/ty/structural_impls.rs b/compiler/rustc_middle/src/ty/structural_impls.rs index 58f7bc75054bb..def7ad6cb3a11 100644 --- a/compiler/rustc_middle/src/ty/structural_impls.rs +++ b/compiler/rustc_middle/src/ty/structural_impls.rs @@ -570,6 +570,19 @@ impl<'tcx> TypeFoldable> for ty::Clause<'tcx> { } } +impl<'tcx> TypeFoldable> for ty::Clauses<'tcx> { + fn try_fold_with>>( + self, + folder: &mut F, + ) -> Result { + folder.try_fold_clauses(self) + } + + fn fold_with>>(self, folder: &mut F) -> Self { + folder.fold_clauses(self) + } +} + impl<'tcx> TypeVisitable> for ty::Predicate<'tcx> { fn visit_with>>(&self, visitor: &mut V) -> V::Result { visitor.visit_predicate(*self) @@ -615,6 +628,19 @@ impl<'tcx> TypeSuperVisitable> for ty::Clauses<'tcx> { } } +impl<'tcx> TypeSuperFoldable> for ty::Clauses<'tcx> { + fn try_super_fold_with>>( + self, + folder: &mut F, + ) -> Result { + ty::util::try_fold_list(self, folder, |tcx, v| tcx.mk_clauses(v)) + } + + fn super_fold_with>>(self, folder: &mut F) -> Self { + ty::util::fold_list(self, folder, |tcx, v| tcx.mk_clauses(v)) + } +} + impl<'tcx> TypeFoldable> for ty::Const<'tcx> { fn try_fold_with>>( self, @@ -775,7 +801,6 @@ macro_rules! list_fold { } list_fold! { - ty::Clauses<'tcx> : mk_clauses, &'tcx ty::List> : mk_poly_existential_predicates, &'tcx ty::List> : mk_place_elems, &'tcx ty::List> : mk_patterns, diff --git a/compiler/rustc_next_trait_solver/src/canonicalizer.rs b/compiler/rustc_next_trait_solver/src/canonicalizer.rs index addeb3e2b78e0..1aced3f261b37 100644 --- a/compiler/rustc_next_trait_solver/src/canonicalizer.rs +++ b/compiler/rustc_next_trait_solver/src/canonicalizer.rs @@ -572,4 +572,8 @@ impl, I: Interner> TypeFolder for Canonicaliz fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate { if p.flags().intersects(NEEDS_CANONICAL) { p.super_fold_with(self) } else { p } } + + fn fold_clauses(&mut self, c: I::Clauses) -> I::Clauses { + if c.flags().intersects(NEEDS_CANONICAL) { c.super_fold_with(self) } else { c } + } } diff --git a/compiler/rustc_next_trait_solver/src/resolve.rs b/compiler/rustc_next_trait_solver/src/resolve.rs index 39abec2d7d8db..c3c57eccd6eff 100644 --- a/compiler/rustc_next_trait_solver/src/resolve.rs +++ b/compiler/rustc_next_trait_solver/src/resolve.rs @@ -11,7 +11,7 @@ use crate::delegate::SolverDelegate; // EAGER RESOLUTION /// Resolves ty, region, and const vars to their inferred values or their root vars. -pub struct EagerResolver<'a, D, I = ::Interner> +struct EagerResolver<'a, D, I = ::Interner> where D: SolverDelegate, I: Interner, @@ -22,8 +22,20 @@ where cache: DelayedMap, } +pub fn eager_resolve_vars>( + delegate: &D, + value: T, +) -> T { + if value.has_infer() { + let mut folder = EagerResolver::new(delegate); + value.fold_with(&mut folder) + } else { + value + } +} + impl<'a, D: SolverDelegate> EagerResolver<'a, D> { - pub fn new(delegate: &'a D) -> Self { + fn new(delegate: &'a D) -> Self { EagerResolver { delegate, cache: Default::default() } } } @@ -90,4 +102,8 @@ impl, I: Interner> TypeFolder for EagerResolv fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate { if p.has_infer() { p.super_fold_with(self) } else { p } } + + fn fold_clauses(&mut self, c: I::Clauses) -> I::Clauses { + if c.has_infer() { c.super_fold_with(self) } else { c } + } } diff --git a/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/canonical.rs b/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/canonical.rs index 455a178595b29..2828b13f03623 100644 --- a/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/canonical.rs +++ b/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/canonical.rs @@ -22,7 +22,7 @@ use tracing::{debug, instrument, trace}; use crate::canonicalizer::Canonicalizer; use crate::delegate::SolverDelegate; -use crate::resolve::EagerResolver; +use crate::resolve::eager_resolve_vars; use crate::solve::eval_ctxt::CurrentGoalKind; use crate::solve::{ CanonicalInput, CanonicalResponse, Certainty, EvalCtxt, ExternalConstraintsData, Goal, @@ -61,8 +61,7 @@ where // so we only canonicalize the lookup table and ignore // duplicate entries. let opaque_types = self.delegate.clone_opaque_types_lookup_table(); - let (goal, opaque_types) = - (goal, opaque_types).fold_with(&mut EagerResolver::new(self.delegate)); + let (goal, opaque_types) = eager_resolve_vars(self.delegate, (goal, opaque_types)); let mut orig_values = Default::default(); let canonical = Canonicalizer::canonicalize_input( @@ -157,8 +156,8 @@ where let external_constraints = self.compute_external_query_constraints(certainty, normalization_nested_goals); - let (var_values, mut external_constraints) = (self.var_values, external_constraints) - .fold_with(&mut EagerResolver::new(self.delegate)); + let (var_values, mut external_constraints) = + eager_resolve_vars(self.delegate, (self.var_values, external_constraints)); // Remove any trivial or duplicated region constraints once we've resolved regions let mut unique = HashSet::default(); @@ -469,7 +468,7 @@ where { let var_values = CanonicalVarValues { var_values: delegate.cx().mk_args(var_values) }; let state = inspect::State { var_values, data }; - let state = state.fold_with(&mut EagerResolver::new(delegate)); + let state = eager_resolve_vars(delegate, state); Canonicalizer::canonicalize_response(delegate, max_input_universe, &mut vec![], state) } diff --git a/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs b/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs index dfabb94ebfc60..926b5c8123a06 100644 --- a/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs +++ b/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs @@ -848,6 +848,22 @@ where } } } + + fn visit_predicate(&mut self, p: I::Predicate) -> Self::Result { + if p.has_non_region_infer() || p.has_placeholders() { + p.super_visit_with(self) + } else { + ControlFlow::Continue(()) + } + } + + fn visit_clauses(&mut self, c: I::Clauses) -> Self::Result { + if c.has_non_region_infer() || c.has_placeholders() { + c.super_visit_with(self) + } else { + ControlFlow::Continue(()) + } + } } let mut visitor = ContainsTermOrNotNameable { diff --git a/compiler/rustc_trait_selection/src/solve/inspect/analyse.rs b/compiler/rustc_trait_selection/src/solve/inspect/analyse.rs index 9795655e84222..84808dc5b7e9e 100644 --- a/compiler/rustc_trait_selection/src/solve/inspect/analyse.rs +++ b/compiler/rustc_trait_selection/src/solve/inspect/analyse.rs @@ -15,9 +15,9 @@ use rustc_infer::infer::{DefineOpaqueTypes, InferCtxt, InferOk}; use rustc_macros::extension; use rustc_middle::traits::ObligationCause; use rustc_middle::traits::solve::{Certainty, Goal, GoalSource, NoSolution, QueryResult}; -use rustc_middle::ty::{TyCtxt, TypeFoldable, VisitorResult, try_visit}; +use rustc_middle::ty::{TyCtxt, VisitorResult, try_visit}; use rustc_middle::{bug, ty}; -use rustc_next_trait_solver::resolve::EagerResolver; +use rustc_next_trait_solver::resolve::eager_resolve_vars; use rustc_next_trait_solver::solve::inspect::{self, instantiate_canonical_state}; use rustc_next_trait_solver::solve::{GenerateProofTree, MaybeCause, SolverDelegateEvalExt as _}; use rustc_span::{DUMMY_SP, Span}; @@ -187,8 +187,7 @@ impl<'a, 'tcx> InspectCandidate<'a, 'tcx> { let _ = term_hack.constrain(infcx, span, param_env); } - let opt_impl_args = - opt_impl_args.map(|impl_args| impl_args.fold_with(&mut EagerResolver::new(infcx))); + let opt_impl_args = opt_impl_args.map(|impl_args| eager_resolve_vars(infcx, impl_args)); let goals = instantiated_goals .into_iter() @@ -392,7 +391,7 @@ impl<'a, 'tcx> InspectGoal<'a, 'tcx> { infcx, depth, orig_values, - goal: uncanonicalized_goal.fold_with(&mut EagerResolver::new(infcx)), + goal: eager_resolve_vars(infcx, uncanonicalized_goal), result, evaluation_kind: evaluation.kind, normalizes_to_term_hack, diff --git a/compiler/rustc_type_ir/src/binder.rs b/compiler/rustc_type_ir/src/binder.rs index 000cf1e1fd8b1..1b056b887dba6 100644 --- a/compiler/rustc_type_ir/src/binder.rs +++ b/compiler/rustc_type_ir/src/binder.rs @@ -711,6 +711,14 @@ impl<'a, I: Interner> TypeFolder for ArgFolder<'a, I> { c.super_fold_with(self) } } + + fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate { + if p.has_param() { p.super_fold_with(self) } else { p } + } + + fn fold_clauses(&mut self, c: I::Clauses) -> I::Clauses { + if c.has_param() { c.super_fold_with(self) } else { c } + } } impl<'a, I: Interner> ArgFolder<'a, I> { diff --git a/compiler/rustc_type_ir/src/fold.rs b/compiler/rustc_type_ir/src/fold.rs index ce1188070ca7d..a5eb8699e5fc6 100644 --- a/compiler/rustc_type_ir/src/fold.rs +++ b/compiler/rustc_type_ir/src/fold.rs @@ -152,6 +152,10 @@ pub trait TypeFolder: Sized { fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate { p.super_fold_with(self) } + + fn fold_clauses(&mut self, c: I::Clauses) -> I::Clauses { + c.super_fold_with(self) + } } /// This trait is implemented for every folding traversal. There is a fold @@ -190,6 +194,10 @@ pub trait FallibleTypeFolder: Sized { fn try_fold_predicate(&mut self, p: I::Predicate) -> Result { p.try_super_fold_with(self) } + + fn try_fold_clauses(&mut self, c: I::Clauses) -> Result { + c.try_super_fold_with(self) + } } /////////////////////////////////////////////////////////////////////////// diff --git a/compiler/rustc_type_ir/src/inherent.rs b/compiler/rustc_type_ir/src/inherent.rs index ee4a8096462a0..b3b49b2c6ecac 100644 --- a/compiler/rustc_type_ir/src/inherent.rs +++ b/compiler/rustc_type_ir/src/inherent.rs @@ -510,6 +510,18 @@ pub trait Clause>: fn instantiate_supertrait(self, cx: I, trait_ref: ty::Binder>) -> Self; } +pub trait Clauses>: + Copy + + Debug + + Hash + + Eq + + TypeSuperVisitable + + TypeSuperFoldable + + Flags + + SliceLike +{ +} + /// Common capabilities of placeholder kinds pub trait PlaceholderLike: Copy + Debug + Hash + Eq { fn universe(self) -> ty::UniverseIndex; diff --git a/compiler/rustc_type_ir/src/interner.rs b/compiler/rustc_type_ir/src/interner.rs index 7e88114df460f..a9917192144ff 100644 --- a/compiler/rustc_type_ir/src/interner.rs +++ b/compiler/rustc_type_ir/src/interner.rs @@ -12,7 +12,7 @@ use crate::ir_print::IrPrint; use crate::lang_items::TraitSolverLangItem; use crate::relate::Relate; use crate::solve::{CanonicalInput, ExternalConstraintsData, PredefinedOpaquesData, QueryResult}; -use crate::visit::{Flags, TypeSuperVisitable, TypeVisitable}; +use crate::visit::{Flags, TypeVisitable}; use crate::{self as ty, search_graph}; #[cfg_attr(feature = "nightly", rustc_diagnostic_item = "type_ir_interner")] @@ -146,7 +146,7 @@ pub trait Interner: type ParamEnv: ParamEnv; type Predicate: Predicate; type Clause: Clause; - type Clauses: Copy + Debug + Hash + Eq + TypeSuperVisitable + Flags; + type Clauses: Clauses; fn with_global_cache(self, f: impl FnOnce(&mut search_graph::GlobalCache) -> R) -> R; diff --git a/compiler/rustc_type_ir/src/visit.rs b/compiler/rustc_type_ir/src/visit.rs index ccb84e2591122..fc3864dd5ae6e 100644 --- a/compiler/rustc_type_ir/src/visit.rs +++ b/compiler/rustc_type_ir/src/visit.rs @@ -120,8 +120,8 @@ pub trait TypeVisitor: Sized { p.super_visit_with(self) } - fn visit_clauses(&mut self, p: I::Clauses) -> Self::Result { - p.super_visit_with(self) + fn visit_clauses(&mut self, c: I::Clauses) -> Self::Result { + c.super_visit_with(self) } fn visit_error(&mut self, _guar: I::ErrorGuaranteed) -> Self::Result { From 0830ce036f92673fa54a06cc4eacb47426850d33 Mon Sep 17 00:00:00 2001 From: lcnr Date: Mon, 26 May 2025 11:00:29 +0000 Subject: [PATCH 2/2] assert we never incorrectly canonicalize envs --- compiler/rustc_next_trait_solver/src/canonicalizer.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/compiler/rustc_next_trait_solver/src/canonicalizer.rs b/compiler/rustc_next_trait_solver/src/canonicalizer.rs index 1aced3f261b37..e5ca2bda45923 100644 --- a/compiler/rustc_next_trait_solver/src/canonicalizer.rs +++ b/compiler/rustc_next_trait_solver/src/canonicalizer.rs @@ -574,6 +574,13 @@ impl, I: Interner> TypeFolder for Canonicaliz } fn fold_clauses(&mut self, c: I::Clauses) -> I::Clauses { + match self.canonicalize_mode { + CanonicalizeMode::Input { keep_static: true } + | CanonicalizeMode::Response { max_input_universe: _ } => {} + CanonicalizeMode::Input { keep_static: false } => { + panic!("erasing 'static in env") + } + } if c.flags().intersects(NEEDS_CANONICAL) { c.super_fold_with(self) } else { c } } }