diff --git a/crates/rustc_codegen_spirv/src/linker/inline.rs b/crates/rustc_codegen_spirv/src/linker/inline.rs index c8ec854504..163ed73e42 100644 --- a/crates/rustc_codegen_spirv/src/linker/inline.rs +++ b/crates/rustc_codegen_spirv/src/linker/inline.rs @@ -14,36 +14,66 @@ use rustc_errors::ErrorGuaranteed; use rustc_session::Session; use std::mem::take; -type FunctionMap = FxHashMap; +type FunctionMap = FxHashMap; pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> { - // This algorithm gets real sad if there's recursion - but, good news, SPIR-V bans recursion - deny_recursion_in_module(sess, module)?; - - let functions = module + let (disallowed_argument_types, disallowed_return_types) = + compute_disallowed_argument_and_return_types(module); + let to_delete: Vec<_> = module .functions .iter() - .map(|f| (f.def_id().unwrap(), f.clone())) + .map(|f| should_inline(&disallowed_argument_types, &disallowed_return_types, f)) .collect(); - let (disallowed_argument_types, disallowed_return_types) = - compute_disallowed_argument_and_return_types(module); + // This algorithm gets real sad if there's recursion - but, good news, SPIR-V bans recursion, + // so we exit with an error if [`compute_function_postorder`] finds it. + let function_to_index = module + .functions + .iter() + .enumerate() + .map(|(idx, f)| (f.def_id().unwrap(), idx)) + .collect(); + let postorder = compute_function_postorder(sess, module, &function_to_index, &to_delete)?; let void = module .types_global_values .iter() .find(|inst| inst.class.opcode == Op::TypeVoid) .map_or(0, |inst| inst.result_id.unwrap()); - // Drop all the functions we'll be inlining. (This also means we won't waste time processing - // inlines in functions that will get inlined) + let ptr_map: FxHashMap<_, _> = module + .types_global_values + .iter() + .filter_map(|inst| { + if inst.class.opcode == Op::TypePointer + && inst.operands[0].unwrap_storage_class() == StorageClass::Function + { + Some((inst.operands[1].unwrap_id_ref(), inst.result_id.unwrap())) + } else { + None + } + }) + .collect(); + let invalid_args = module.functions.iter().flat_map(get_invalid_args).collect(); + let mut inliner = Inliner { + header: module.header.as_mut().unwrap(), + types_global_values: &mut module.types_global_values, + void, + ptr_map, + function_to_index: &function_to_index, + needs_inline: &to_delete, + invalid_args, + }; + // Processing functions in post-order of call tree we ensure that + // inlined functions already have all of the inner functions inlined, so we don't do + // the same work multiple times. + for index in postorder { + inliner.inline_fn(&mut module.functions, index); + fuse_trivial_branches(&mut module.functions[index]); + } let mut dropped_ids = FxHashSet::default(); - module.functions.retain(|f| { - if should_inline(&disallowed_argument_types, &disallowed_return_types, f) { - // TODO: We should insert all defined IDs in this function. - dropped_ids.insert(f.def_id().unwrap()); - false - } else { - true + for i in (0..module.functions.len()).rev() { + if to_delete[i] { + dropped_ids.insert(module.functions.remove(i).def_id().unwrap()); } - }); + } // Drop OpName etc. 
for inlined functions module.debug_names.retain(|inst| { !inst.operands.iter().any(|op| { @@ -51,42 +81,42 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> { .map_or(false, |id| dropped_ids.contains(&id)) }) }); - let mut inliner = Inliner { - header: module.header.as_mut().unwrap(), - types_global_values: &mut module.types_global_values, - void, - functions: &functions, - disallowed_argument_types: &disallowed_argument_types, - disallowed_return_types: &disallowed_return_types, - }; - for function in &mut module.functions { - inliner.inline_fn(function); - fuse_trivial_branches(function); - } Ok(()) } -// https://stackoverflow.com/a/53995651 -fn deny_recursion_in_module(sess: &Session, module: &Module) -> super::Result<()> { - let func_to_index: FxHashMap = module - .functions - .iter() - .enumerate() - .map(|(index, func)| (func.def_id().unwrap(), index)) - .collect(); - let mut discovered = vec![false; module.functions.len()]; - let mut finished = vec![false; module.functions.len()]; +/// Topological sorting algorithm due to T. Cormen +/// Starts from module's entry points, so only reachable functions will be returned +/// in post-traversal order of DFS. For all unvisited functions `module.functions[i]`, +/// `to_delete[i]` is set to true. +fn compute_function_postorder( + sess: &Session, + module: &Module, + func_to_index: &FxHashMap, + to_delete: &[bool], +) -> super::Result> { + /// Possible node states for cycle-discovering DFS. + #[derive(Clone, PartialEq)] + enum NodeState { + /// Normal, not visited. + NotVisited, + /// Currently being visited. + Discovered, + /// DFS returned. + Finished, + } + let mut states = vec![NodeState::NotVisited; module.functions.len()]; let mut has_recursion = None; + let mut postorder = vec![]; for index in 0..module.functions.len() { - if !discovered[index] && !finished[index] { + if NodeState::NotVisited == states[index] && !to_delete[index] { visit( sess, module, index, - &mut discovered, - &mut finished, + &mut states[..], &mut has_recursion, - &func_to_index, + &mut postorder, + func_to_index, ); } } @@ -95,40 +125,43 @@ fn deny_recursion_in_module(sess: &Session, module: &Module) -> super::Result<() sess: &Session, module: &Module, current: usize, - discovered: &mut Vec, - finished: &mut Vec, + states: &mut [NodeState], has_recursion: &mut Option, + postorder: &mut Vec, func_to_index: &FxHashMap, ) { - discovered[current] = true; + states[current] = NodeState::Discovered; for next in calls(&module.functions[current], func_to_index) { - if discovered[next] { - let names = get_names(module); - let current_name = get_name(&names, module.functions[current].def_id().unwrap()); - let next_name = get_name(&names, module.functions[next].def_id().unwrap()); - *has_recursion = Some(sess.err(&format!( - "module has recursion, which is not allowed: `{}` calls `{}`", - current_name, next_name - ))); - break; - } - - if !finished[next] { - visit( - sess, - module, - next, - discovered, - finished, - has_recursion, - func_to_index, - ); + match states[next] { + NodeState::Discovered => { + let names = get_names(module); + let current_name = + get_name(&names, module.functions[current].def_id().unwrap()); + let next_name = get_name(&names, module.functions[next].def_id().unwrap()); + *has_recursion = Some(sess.err(&format!( + "module has recursion, which is not allowed: `{}` calls `{}`", + current_name, next_name + ))); + break; + } + NodeState::NotVisited => { + visit( + sess, + module, + next, + states, + has_recursion, + postorder, 
+ func_to_index, + ); + } + NodeState::Finished => {} } } - discovered[current] = false; - finished[current] = true; + states[current] = NodeState::Finished; + postorder.push(current); } fn calls<'a>( @@ -146,7 +179,7 @@ fn deny_recursion_in_module(sess: &Session, module: &Module) -> super::Result<() match has_recursion { Some(err) => Err(err), - None => Ok(()), + None => Ok(postorder), } } @@ -222,20 +255,21 @@ fn should_inline( // This should be more general, but a very common problem is passing an OpAccessChain to an // OpFunctionCall (i.e. `f(&s.x)`, or more commonly, `s.x.f()` where `f` takes `&self`), so detect // that case and inline the call. -fn args_invalid(function: &Function, call: &Instruction) -> bool { - for inst in function.all_inst_iter() { +fn get_invalid_args(function: &Function) -> impl Iterator + '_ { + function.all_inst_iter().filter_map(|inst| { if inst.class.opcode == Op::AccessChain { - let inst_result = inst.result_id.unwrap(); - if call - .operands - .iter() - .any(|op| *op == Operand::IdRef(inst_result)) - { - return true; - } + inst.result_id + } else { + None } - } - false + }) +} + +fn args_invalid(invalid_args: &FxHashSet, call: &Instruction) -> bool { + call.operands.iter().skip(1).any(|op| { + op.id_ref_any() + .map_or(false, |arg| invalid_args.contains(&arg)) + }) } // Steps: @@ -248,10 +282,10 @@ struct Inliner<'m, 'map> { header: &'m mut ModuleHeader, types_global_values: &'m mut Vec, void: Word, - functions: &'map FunctionMap, - disallowed_argument_types: &'map FxHashSet, - disallowed_return_types: &'map FxHashSet, - // rewrite_rules: FxHashMap, + ptr_map: FxHashMap, + function_to_index: &'map FunctionMap, + needs_inline: &'map [bool], + invalid_args: FxHashSet, } impl Inliner<'_, '_> { @@ -262,14 +296,9 @@ impl Inliner<'_, '_> { } fn ptr_ty(&mut self, pointee: Word) -> Word { - // TODO: This is horribly slow, fix this - let existing = self.types_global_values.iter().find(|inst| { - inst.class.opcode == Op::TypePointer - && inst.operands[0].unwrap_storage_class() == StorageClass::Function - && inst.operands[1].unwrap_id_ref() == pointee - }); + let existing = self.ptr_map.get(&pointee); if let Some(existing) = existing { - return existing.result_id.unwrap(); + return *existing; } let inst_id = self.id(); self.types_global_values.push(Instruction::new( @@ -281,22 +310,26 @@ impl Inliner<'_, '_> { Operand::IdRef(pointee), ], )); + self.ptr_map.insert(pointee, inst_id); inst_id } - fn inline_fn(&mut self, function: &mut Function) { + fn inline_fn(&mut self, functions: &mut [Function], index: usize) { + let mut function = take(&mut functions[index]); let mut block_idx = 0; while block_idx < function.blocks.len() { - // If we successfully inlined a block, then repeat processing on the same block, in - // case the newly inlined block has more inlined calls. - // TODO: This is quadratic - if !self.inline_block(function, block_idx) { - block_idx += 1; - } + // If we successfully inlined a block, then continue processing on the next block or its tail. + // TODO: this is quadratic in cases where [`Op::AccessChain`]s cascade into inner arguments. + // For the common case of "we knew which functions to inline", it is linear. + self.inline_block(&mut function, functions, block_idx); + block_idx += 1; } + functions[index] = function; } - fn inline_block(&mut self, caller: &mut Function, block_idx: usize) -> bool { + /// Inlines one block. + /// After calling this, `blocks[block_idx]` is finished processing. 
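Editor's note: the `compute_function_postorder` change above folds recursion detection and post-order collection into a single three-state DFS. The following is a minimal sketch of that idea on a plain adjacency list — `postorder`, `State`, and the graph shape are invented for illustration and are not the linker's actual types:

```rust
// Sketch only: three-state DFS that returns call-graph nodes in post-order and
// reports recursion, mirroring the NodeState idea in the diff above.
#[derive(Clone, PartialEq)]
enum State {
    NotVisited,
    Discovered, // currently on the DFS stack
    Finished,
}

fn postorder(callees: &[Vec<usize>]) -> Result<Vec<usize>, String> {
    fn visit(
        callees: &[Vec<usize>],
        current: usize,
        states: &mut [State],
        out: &mut Vec<usize>,
    ) -> Result<(), String> {
        states[current] = State::Discovered;
        for &next in &callees[current] {
            match states[next] {
                // A back edge to a node still on the stack means recursion.
                State::Discovered => return Err(format!("{current} calls {next} recursively")),
                State::NotVisited => visit(callees, next, states, out)?,
                State::Finished => {}
            }
        }
        states[current] = State::Finished;
        out.push(current); // every callee was pushed first => post-order
        Ok(())
    }

    let mut states = vec![State::NotVisited; callees.len()];
    let mut out = Vec::new();
    for i in 0..callees.len() {
        if states[i] == State::NotVisited {
            visit(callees, i, &mut states, &mut out)?;
        }
    }
    Ok(out)
}

fn main() {
    // 0 calls 1 and 2, 1 calls 2: post-order lists 2 before 1 before 0,
    // so a callee is always fully inlined before any of its callers.
    let graph = vec![vec![1, 2], vec![2], vec![]];
    assert_eq!(postorder(&graph).unwrap(), vec![2, 1, 0]);
}
```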
+ fn inline_block(&mut self, caller: &mut Function, functions: &[Function], block_idx: usize) { // Find the first inlined OpFunctionCall let call = caller.blocks[block_idx] .instructions @@ -307,22 +340,17 @@ impl Inliner<'_, '_> { ( index, inst, - self.functions - .get(&inst.operands[0].id_ref_any().unwrap()) - .unwrap(), + self.function_to_index[&inst.operands[0].id_ref_any().unwrap()], ) }) - .find(|(_, inst, f)| { - should_inline( - self.disallowed_argument_types, - self.disallowed_return_types, - f, - ) || args_invalid(caller, inst) + .find(|(_, inst, func_idx)| { + self.needs_inline[*func_idx] || args_invalid(&self.invalid_args, inst) }); - let (call_index, call_inst, callee) = match call { - None => return false, + let (call_index, call_inst, callee_idx) = match call { + None => return, Some(call) => call, }; + let callee = &functions[callee_idx]; let call_result_type = { let ty = call_inst.result_type.unwrap(); if ty == self.void { @@ -351,10 +379,20 @@ impl Inliner<'_, '_> { }; let return_jump = self.id(); // Rewrite OpReturns of the callee. - let mut inlined_blocks = get_inlined_blocks(callee, return_variable, return_jump); + let (mut inlined_blocks, return_values) = + get_inlined_blocks(callee, return_variable, return_jump); // Clone the IDs of the callee, because otherwise they'd be defined multiple times if the // fn is inlined multiple times. self.add_clone_id_rules(&mut rewrite_rules, &inlined_blocks); + // If any of the OpReturns were invalid, return will also be invalid. + for value in &return_values { + let value_rewritten = *rewrite_rules.get(value).unwrap_or(value); + // value_rewritten might be originally a function argument + if self.invalid_args.contains(value) || self.invalid_args.contains(&value_rewritten) { + self.invalid_args.insert(call_result_id); + self.invalid_args.insert(value_rewritten); + } + } apply_rewrite_rules(&rewrite_rules, &mut inlined_blocks); // Split the block containing the OpFunctionCall into two, around the call. @@ -375,17 +413,24 @@ impl Inliner<'_, '_> { ); } - // Fuse the first block of the callee into the block of the caller. This is okay because - // it's illegal to branch to the first BB in a function. - let mut callee_header = inlined_blocks.remove(0).instructions; + // Move the variables over from the inlined function to here. + let mut callee_header = take(&mut inlined_blocks[0]).instructions; // TODO: OpLine handling - let num_variables = callee_header - .iter() - .position(|inst| inst.class.opcode != Op::Variable) - .unwrap_or(callee_header.len()); - caller.blocks[block_idx] - .instructions - .append(&mut callee_header.split_off(num_variables)); + let num_variables = callee_header.partition_point(|inst| inst.class.opcode == Op::Variable); + // Rather than fuse the first block of the inline function to the current block, + // generate a new jump here. Branch fusing will take care of + // it, and we maintain the invariant that current block has finished processing. + let jump_to = self.id(); + inlined_blocks[0] = Block { + label: Some(Instruction::new(Op::Label, None, Some(jump_to), vec![])), + instructions: callee_header.split_off(num_variables), + }; + caller.blocks[block_idx].instructions.push(Instruction::new( + Op::Branch, + None, + None, + vec![Operand::IdRef(jump_to)], + )); // Move the OpVariables of the callee to the caller. 
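Editor's note: the `ptr_ty`/`ptr_map` change earlier in this hunk replaces a linear scan of `types_global_values` on every pointer-type request with a precomputed hash map. A minimal sketch of that memoization pattern, with made-up names standing in for the real `Inliner` fields:

```rust
use std::collections::HashMap;

// Sketch: deduplicate "pointer-to-T" type definitions via a memo map instead of
// rescanning the whole type list on each request, as the new `ptr_map` does.
struct TypeTable {
    next_id: u32,
    // pointee type id -> Function-storage pointer type id
    ptr_map: HashMap<u32, u32>,
}

impl TypeTable {
    fn ptr_ty(&mut self, pointee: u32) -> u32 {
        if let Some(&existing) = self.ptr_map.get(&pointee) {
            return existing; // O(1) instead of O(number of types)
        }
        let id = self.next_id;
        self.next_id += 1;
        // A real implementation would also push the OpTypePointer instruction here.
        self.ptr_map.insert(pointee, id);
        id
    }
}

fn main() {
    let mut table = TypeTable { next_id: 100, ptr_map: HashMap::new() };
    let a = table.ptr_ty(7);
    let b = table.ptr_ty(7);
    assert_eq!(a, b); // the second request reuses the cached pointer type
}
```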
insert_opvariables(&mut caller.blocks[0], callee_header); @@ -414,8 +459,6 @@ impl Inliner<'_, '_> { caller .blocks .splice((block_idx + 1)..(block_idx + 1), inlined_blocks); - - true } fn add_clone_id_rules(&mut self, rewrite_rules: &mut FxHashMap, blocks: &[Block]) { @@ -435,13 +478,15 @@ fn get_inlined_blocks( function: &Function, return_variable: Option, return_jump: Word, -) -> Vec { +) -> (Vec, Vec) { let mut blocks = function.blocks.clone(); + let mut values = Vec::new(); for block in &mut blocks { let last = block.instructions.last().unwrap(); if let Op::Return | Op::ReturnValue = last.class.opcode { if Op::ReturnValue == last.class.opcode { let return_value = last.operands[0].id_ref_any().unwrap(); + values.push(return_value); block.instructions.insert( block.instructions.len() - 1, Instruction::new( @@ -461,89 +506,90 @@ fn get_inlined_blocks( Instruction::new(Op::Branch, None, None, vec![Operand::IdRef(return_jump)]); } } - blocks + (blocks, values) } fn insert_opvariable(block: &mut Block, ptr_ty: Word, result_id: Word) { let index = block .instructions - .iter() - .enumerate() - .find_map(|(index, inst)| { - if inst.class.opcode != Op::Variable { - Some(index) - } else { - None - } - }); + .partition_point(|inst| inst.class.opcode == Op::Variable); + let inst = Instruction::new( Op::Variable, Some(ptr_ty), Some(result_id), vec![Operand::StorageClass(StorageClass::Function)], ); - match index { - Some(index) => block.instructions.insert(index, inst), - None => block.instructions.push(inst), - } + block.instructions.insert(index, inst); } -fn insert_opvariables(block: &mut Block, mut insts: Vec) { +fn insert_opvariables(block: &mut Block, insts: Vec) { let index = block .instructions - .iter() - .enumerate() - .find_map(|(index, inst)| { - if inst.class.opcode != Op::Variable { - Some(index) - } else { - None - } - }); - match index { - Some(index) => { - block.instructions.splice(index..index, insts); - } - None => block.instructions.append(&mut insts), - } + .partition_point(|inst| inst.class.opcode == Op::Variable); + block.instructions.splice(index..index, insts); } fn fuse_trivial_branches(function: &mut Function) { - let all_preds = compute_preds(&function.blocks); - 'outer: for (dest_block, mut preds) in all_preds.iter().enumerate() { - // if there's two trivial branches in a row, the middle one might get inlined before the - // last one, so when processing the last one, skip through to the first one. 
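Editor's note: the `insert_opvariable(s)` helpers just above switch from a hand-rolled `iter().enumerate().find_map(...)` search to `slice::partition_point`, relying on the invariant that `OpVariable`s form a prefix of the block. A quick illustration with plain data instead of `Instruction`s (the `Op` enum here is a stand-in):

```rust
// Sketch: partition_point returns the index of the first element for which the
// predicate is false, i.e. the insertion point right after the OpVariable prefix.
fn main() {
    #[derive(PartialEq)]
    enum Op { Variable, Load, Store }

    let block = [Op::Variable, Op::Variable, Op::Load, Op::Store];
    // Valid because the block is already partitioned: all OpVariables come first.
    let insert_at = block.partition_point(|op| *op == Op::Variable);
    assert_eq!(insert_at, 2);
    // New OpVariables (or a whole batch, via splice) go in at `insert_at`,
    // keeping them ahead of every non-variable instruction.
}
```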
- let pred = loop { - if preds.len() != 1 || preds[0] == dest_block { - continue 'outer; - } - let pred = preds[0]; - if !function.blocks[pred].instructions.is_empty() { - break pred; + let mut chain_list = compute_outgoing_1to1_branches(&function.blocks); + + for block_idx in 0..chain_list.len() { + let mut next = chain_list[block_idx].take(); + loop { + match next { + None => { + // end of the chain list + break; + } + Some(x) if x == block_idx => { + // loop detected + break; + } + Some(next_idx) => { + let mut dest_insts = take(&mut function.blocks[next_idx].instructions); + let self_insts = &mut function.blocks[block_idx].instructions; + self_insts.pop(); // pop the branch + self_insts.append(&mut dest_insts); + next = chain_list[next_idx].take(); + } } - preds = &all_preds[pred]; - }; - let pred_insts = &function.blocks[pred].instructions; - if pred_insts.last().unwrap().class.opcode == Op::Branch { - let mut dest_insts = take(&mut function.blocks[dest_block].instructions); - let pred_insts = &mut function.blocks[pred].instructions; - pred_insts.pop(); // pop the branch - pred_insts.append(&mut dest_insts); } } function.blocks.retain(|b| !b.instructions.is_empty()); } -fn compute_preds(blocks: &[Block]) -> Vec> { - let mut result = vec![vec![]; blocks.len()]; +fn compute_outgoing_1to1_branches(blocks: &[Block]) -> Vec> { + let block_id_to_idx: FxHashMap<_, _> = blocks + .iter() + .enumerate() + .map(|(idx, block)| (block.label_id().unwrap(), idx)) + .collect(); + #[derive(Clone)] + enum NumIncoming { + Zero, + One(usize), + TooMany, + } + let mut incoming = vec![NumIncoming::Zero; blocks.len()]; for (source_idx, source) in blocks.iter().enumerate() { for dest_id in outgoing_edges(source) { - let dest_idx = blocks - .iter() - .position(|b| b.label_id().unwrap() == dest_id) - .unwrap(); - result[dest_idx].push(source_idx); + let dest_idx = block_id_to_idx[&dest_id]; + incoming[dest_idx] = match incoming[dest_idx] { + NumIncoming::Zero => NumIncoming::One(source_idx), + _ => NumIncoming::TooMany, + } } } + + let mut result = vec![None; blocks.len()]; + + for (dest_idx, inc) in incoming.iter().enumerate() { + if let &NumIncoming::One(source_idx) = inc { + if blocks[source_idx].instructions.last().unwrap().class.opcode == Op::Branch { + result[source_idx] = Some(dest_idx); + } + } + } + result } diff --git a/crates/rustc_codegen_spirv/src/linker/test.rs b/crates/rustc_codegen_spirv/src/linker/test.rs index 9c281a2398..186c3b8c55 100644 --- a/crates/rustc_codegen_spirv/src/linker/test.rs +++ b/crates/rustc_codegen_spirv/src/linker/test.rs @@ -505,3 +505,117 @@ fn names_and_decorations() { without_header_eq(result, expect); } + +#[test] +fn cascade_inlining_of_ptr_args() { + let a = assemble_spirv( + r#"OpCapability Linkage + OpDecorate %1 LinkageAttributes "foo" Export + %2 = OpTypeInt 32 0 + %8 = OpConstant %2 0 + %3 = OpTypeStruct %2 %2 + %4 = OpTypePointer Function %2 + %5 = OpTypePointer Function %3 + %6 = OpTypeFunction %4 %5 + %1 = OpFunction %4 Const %6 + %7 = OpFunctionParameter %5 + %10 = OpLabel + %9 = OpAccessChain %4 %7 %8 + OpReturnValue %9 + OpFunctionEnd + "#, + ); + + let b = assemble_spirv( + r#"OpCapability Linkage + OpDecorate %1 LinkageAttributes "bar" Export + %2 = OpTypeInt 32 0 + %4 = OpTypePointer Function %2 + %6 = OpTypeFunction %2 %4 + %1 = OpFunction %2 None %6 + %7 = OpFunctionParameter %4 + %10 = OpLabel + %8 = OpLoad %2 %7 + OpReturnValue %8 + OpFunctionEnd + "#, + ); + + let c = assemble_spirv( + r#"OpCapability Linkage + OpDecorate %1 LinkageAttributes 
"baz" Export + %2 = OpTypeInt 32 0 + %4 = OpTypePointer Function %2 + %6 = OpTypeFunction %4 %4 + %1 = OpFunction %4 None %6 + %7 = OpFunctionParameter %4 + %10 = OpLabel + OpReturnValue %7 + OpFunctionEnd + "#, + ); + + // In here, inlining foo should mark its return result as a not-fit-for-function-consumption + // pointer and inline "baz" as well. That would lead to inlining "bar" too. + let d = assemble_spirv( + r#"OpCapability Linkage + OpDecorate %10 LinkageAttributes "foo" Import + OpDecorate %12 LinkageAttributes "bar" Import + OpDecorate %14 LinkageAttributes "baz" Import + OpName %1 "main" + %2 = OpTypeInt 32 0 + %3 = OpTypeStruct %2 %2 + %4 = OpTypePointer Function %2 + %5 = OpTypePointer Function %3 + %6 = OpTypeFunction %4 %5 + %7 = OpTypeFunction %2 %4 + %8 = OpTypeFunction %4 %4 + %10 = OpFunction %4 Const %6 + %11 = OpFunctionParameter %5 + OpFunctionEnd + %12 = OpFunction %2 None %7 + %13 = OpFunctionParameter %4 + OpFunctionEnd + %14 = OpFunction %4 None %8 + %15 = OpFunctionParameter %4 + OpFunctionEnd + %21 = OpTypeFunction %2 %5 + %1 = OpFunction %2 None %14 + %22 = OpFunctionParameter %5 + %23 = OpLabel + %24 = OpFunctionCall %4 %10 %22 + %25 = OpFunctionCall %4 %14 %24 + %26 = OpFunctionCall %2 %12 %25 + OpReturnValue %26 + OpFunctionEnd + "#, + ); + + let result = assemble_and_link(&[&a, &b, &c, &d]).unwrap(); + let expect = r#"OpName %1 "main" + %2 = OpTypeInt 32 0 + %3 = OpConstant %2 0 + %4 = OpTypeStruct %2 %2 + %5 = OpTypePointer Function %2 + %6 = OpTypePointer Function %4 + %7 = OpTypeFunction %5 %6 + %8 = OpTypeFunction %2 %5 + %9 = OpTypeFunction %5 %5 + %10 = OpTypeFunction %2 %6 + %11 = OpTypePointer Function %5 + %12 = OpFunction %2 None %8 + %13 = OpFunctionParameter %5 + %14 = OpLabel + %15 = OpLoad %2 %13 + OpReturnValue %15 + OpFunctionEnd + %1 = OpFunction %2 None %16 + %17 = OpFunctionParameter %6 + %18 = OpLabel + %19 = OpAccessChain %5 %17 %3 + %20 = OpLoad %2 %19 + OpReturnValue %20 + OpFunctionEnd"#; + + without_header_eq(result, expect); +}