From c02cff1f0a907a863a0c7037f24e97166d1ca8cc Mon Sep 17 00:00:00 2001 From: Scott McMurray Date: Thu, 28 Mar 2024 13:45:14 -0700 Subject: [PATCH] Rework MIR inlining costs A bunch of the current costs are surprising, probably accidentally from not writing out the matches in full. For example, a runtime-length `memcpy` was treated as the same cost as an `Unreachable`. This reworks things around two main ideas: - Give everything a baseline cost, because even "free" things do take effort to MIR inline, and that's easy to calculate - Then just penalize those things that are materially more than the baseline, like how `[foo; 123]` is far more work than `BinOp::AddUnchecked` in an `Rvalue` By including costs for locals and vardebuginfo this makes some things overall more expensive, but because it also greatly reduces the cost for simple things like local variable addition, other things also become less expensive overall. --- .../rustc_mir_transform/src/cost_checker.rs | 147 +++++++++++++++--- compiler/rustc_mir_transform/src/inline.rs | 11 ++ 2 files changed, 139 insertions(+), 19 deletions(-) diff --git a/compiler/rustc_mir_transform/src/cost_checker.rs b/compiler/rustc_mir_transform/src/cost_checker.rs index 2c692c9500303..7e42cf86ca9d0 100644 --- a/compiler/rustc_mir_transform/src/cost_checker.rs +++ b/compiler/rustc_mir_transform/src/cost_checker.rs @@ -2,10 +2,30 @@ use rustc_middle::mir::visit::*; use rustc_middle::mir::*; use rustc_middle::ty::{self, ParamEnv, Ty, TyCtxt}; -const INSTR_COST: usize = 5; -const CALL_PENALTY: usize = 25; -const LANDINGPAD_PENALTY: usize = 50; -const RESUME_PENALTY: usize = 45; +// Even if they're zero-cost at runtime, everything has *some* cost to inline +// in terms of copying them into the MIR caller, processing them in codegen, etc. +// These baseline costs give a simple usually-too-low estimate of the cost, +// which will be updated afterwards to account for the "real" costs. 
+const STMT_BASELINE_COST: usize = 1; +const BLOCK_BASELINE_COST: usize = 3; +const DEBUG_BASELINE_COST: usize = 1; +const LOCAL_BASELINE_COST: usize = 1; + +// These penalties represent the cost above baseline for those things which +// have substantially more cost than is typical for their kind. +const CALL_PENALTY: usize = 22; +const LANDINGPAD_PENALTY: usize = 47; +const RESUME_PENALTY: usize = 42; +const DEREF_PENALTY: usize = 4; +const CHECKED_OP_PENALTY: usize = 2; +const THREAD_LOCAL_PENALTY: usize = 20; +const SMALL_SWITCH_PENALTY: usize = 3; +const LARGE_SWITCH_PENALTY: usize = 20; + +// Passing arguments isn't free, so give a bonus to functions with lots of them: +// if the body is small despite lots of arguments, some are probably unused. +const EXTRA_ARG_BONUS: usize = 4; +const MAX_ARG_BONUS: usize = CALL_PENALTY; /// Verify that the callee body is compatible with the caller. #[derive(Clone)] @@ -27,6 +47,20 @@ impl<'b, 'tcx> CostChecker<'b, 'tcx> { CostChecker { tcx, param_env, callee_body, instance, cost: 0 } } + // `Inline` doesn't call `visit_body`, so this is separate from the visitor. + pub fn before_body(&mut self, body: &Body<'tcx>) { + self.cost += BLOCK_BASELINE_COST * body.basic_blocks.len(); + self.cost += DEBUG_BASELINE_COST * body.var_debug_info.len(); + self.cost += LOCAL_BASELINE_COST * body.local_decls.len(); + + let total_statements = body.basic_blocks.iter().map(|x| x.statements.len()).sum::<usize>(); + self.cost += STMT_BASELINE_COST * total_statements; + + if let Some(extra_args) = body.arg_count.checked_sub(2) { + self.cost = self.cost.saturating_sub((EXTRA_ARG_BONUS * extra_args).min(MAX_ARG_BONUS)); + } + } + pub fn cost(&self) -> usize { self.cost } @@ -41,14 +75,70 @@ impl<'b, 'tcx> CostChecker<'b, 'tcx> { } } impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> { - fn visit_statement(&mut self, statement: &Statement<'tcx>, _: Location) { - // Don't count StorageLive/StorageDead in the inlining cost. 
- match statement.kind { - StatementKind::StorageLive(_) + fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) { + match &statement.kind { + StatementKind::Assign(place_and_rvalue) => { + if place_and_rvalue.0.is_indirect_first_projection() { + self.cost += DEREF_PENALTY; + } + self.visit_rvalue(&place_and_rvalue.1, location); + } + StatementKind::Intrinsic(intr) => match &**intr { + NonDivergingIntrinsic::Assume(..) => {} + NonDivergingIntrinsic::CopyNonOverlapping(_cno) => { + self.cost += CALL_PENALTY; + } + }, + StatementKind::FakeRead(..) + | StatementKind::SetDiscriminant { .. } + | StatementKind::StorageLive(_) | StatementKind::StorageDead(_) + | StatementKind::Retag(..) + | StatementKind::PlaceMention(..) + | StatementKind::AscribeUserType(..) + | StatementKind::Coverage(..) | StatementKind::Deinit(_) - | StatementKind::Nop => {} - _ => self.cost += INSTR_COST, + | StatementKind::ConstEvalCounter + | StatementKind::Nop => { + // No extra cost for these + } + } + } + + fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, _location: Location) { + match rvalue { + Rvalue::Use(operand) => { + if let Some(place) = operand.place() + && place.is_indirect_first_projection() + { + self.cost += DEREF_PENALTY; + } + } + Rvalue::Repeat(_item, count) => { + let count = count.try_to_target_usize(self.tcx).unwrap_or(u64::MAX); + self.cost += (STMT_BASELINE_COST * count as usize).min(CALL_PENALTY); + } + Rvalue::Aggregate(_kind, fields) => { + self.cost += STMT_BASELINE_COST * fields.len(); + } + Rvalue::CheckedBinaryOp(..) => { + self.cost += CHECKED_OP_PENALTY; + } + Rvalue::ThreadLocalRef(..) => { + self.cost += THREAD_LOCAL_PENALTY; + } + Rvalue::Ref(..) + | Rvalue::AddressOf(..) + | Rvalue::Len(..) + | Rvalue::Cast(..) + | Rvalue::BinaryOp(..) + | Rvalue::NullaryOp(..) + | Rvalue::UnaryOp(..) + | Rvalue::Discriminant(..) + | Rvalue::ShallowInitBox(..) + | Rvalue::CopyForDeref(..) 
=> { + // No extra cost for these + } } } @@ -63,24 +153,35 @@ impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> { if let UnwindAction::Cleanup(_) = unwind { self.cost += LANDINGPAD_PENALTY; } - } else { - self.cost += INSTR_COST; } } - TerminatorKind::Call { func: Operand::Constant(ref f), unwind, .. } => { - let fn_ty = self.instantiate_ty(f.const_.ty()); - self.cost += if let ty::FnDef(def_id, _) = *fn_ty.kind() + TerminatorKind::Call { ref func, unwind, .. } => { + if let Some(f) = func.constant() + && let fn_ty = self.instantiate_ty(f.ty()) + && let ty::FnDef(def_id, _) = *fn_ty.kind() && tcx.intrinsic(def_id).is_some() { // Don't give intrinsics the extra penalty for calls - INSTR_COST } else { - CALL_PENALTY + self.cost += CALL_PENALTY; }; if let UnwindAction::Cleanup(_) = unwind { self.cost += LANDINGPAD_PENALTY; } } + TerminatorKind::SwitchInt { ref discr, ref targets } => { + if let Operand::Constant(..) = discr { + // This'll be a goto once we're monomorphizing + } else { + // 0/1/unreachable is extremely common (bool, Option, Result, ...) + // but once there's more this can be a fair bit of work. + self.cost += if targets.all_targets().len() <= 3 { + SMALL_SWITCH_PENALTY + } else { + LARGE_SWITCH_PENALTY + }; + } + } TerminatorKind::Assert { unwind, .. } => { self.cost += CALL_PENALTY; if let UnwindAction::Cleanup(_) = unwind { @@ -89,12 +190,20 @@ impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> { } TerminatorKind::UnwindResume => self.cost += RESUME_PENALTY, TerminatorKind::InlineAsm { unwind, .. } => { - self.cost += INSTR_COST; if let UnwindAction::Cleanup(_) = unwind { self.cost += LANDINGPAD_PENALTY; } } - _ => self.cost += INSTR_COST, + TerminatorKind::Goto { .. } + | TerminatorKind::UnwindTerminate(..) + | TerminatorKind::Return + | TerminatorKind::Yield { .. } + | TerminatorKind::CoroutineDrop + | TerminatorKind::FalseEdge { .. } + | TerminatorKind::FalseUnwind { .. 
} + | TerminatorKind::Unreachable => { + // No extra cost for these + } } } } diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs index 5f74841151cda..ea75a72d21729 100644 --- a/compiler/rustc_mir_transform/src/inline.rs +++ b/compiler/rustc_mir_transform/src/inline.rs @@ -506,6 +506,17 @@ impl<'tcx> Inliner<'tcx> { let mut checker = CostChecker::new(self.tcx, self.param_env, Some(callsite.callee), callee_body); + checker.before_body(callee_body); + + let baseline_cost = checker.cost(); + if baseline_cost > threshold { + debug!( + "NOT inlining {:?} [baseline_cost={} > threshold={}]", + callsite, baseline_cost, threshold + ); + return Err("baseline_cost above threshold"); + } + // Traverse the MIR manually so we can account for the effects of inlining on the CFG. let mut work_list = vec![START_BLOCK]; let mut visited = BitSet::new_empty(callee_body.basic_blocks.len());