From 3e1dbabc7922cdd74045511c3e7038ee306fa0bf Mon Sep 17 00:00:00 2001 From: Alex S Date: Mon, 29 Nov 2021 14:25:25 +0300 Subject: [PATCH 01/12] linker/inline: upgrade cycle counting to topo sort & resolve quadratic inlining. By inlining in callee -> caller order, we avoid the need to continue inlining the code we just inlined. A simple reachability test from one of the entry points helps avoid unnecessary work as well. The algorithm however remains quadratic in case where OpAccessChains repeatedly find their way into function parameters. There are two ways out: either a more complex control flow analysis, or conservatively inlining all function calls which reference FunctionParameters as arguments. I don't think either case is very worth it. --- .../rustc_codegen_spirv/src/linker/inline.rs | 250 ++++++++++-------- 1 file changed, 136 insertions(+), 114 deletions(-) diff --git a/crates/rustc_codegen_spirv/src/linker/inline.rs b/crates/rustc_codegen_spirv/src/linker/inline.rs index c8ec854504..4ce53af9fa 100644 --- a/crates/rustc_codegen_spirv/src/linker/inline.rs +++ b/crates/rustc_codegen_spirv/src/linker/inline.rs @@ -14,19 +14,24 @@ use rustc_errors::ErrorGuaranteed; use rustc_session::Session; use std::mem::take; -type FunctionMap = FxHashMap; +type FunctionMap = FxHashMap; pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> { + let (disallowed_argument_types, disallowed_return_types) = + compute_disallowed_argument_and_return_types(module); + let mut to_delete: Vec<_> = module + .functions + .iter() + .map(|f| should_inline(&disallowed_argument_types, &disallowed_return_types, f)) + .collect(); // This algorithm gets real sad if there's recursion - but, good news, SPIR-V bans recursion - deny_recursion_in_module(sess, module)?; - + let postorder = compute_function_postorder(sess, module, &mut to_delete)?; let functions = module .functions .iter() - .map(|f| (f.def_id().unwrap(), f.clone())) + .enumerate() + .map(|(idx, f)| (f.def_id().unwrap(), idx)) .collect(); - let (disallowed_argument_types, disallowed_return_types) = - compute_disallowed_argument_and_return_types(module); let void = module .types_global_values .iter() @@ -34,23 +39,6 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> { .map_or(0, |inst| inst.result_id.unwrap()); // Drop all the functions we'll be inlining. (This also means we won't waste time processing // inlines in functions that will get inlined) - let mut dropped_ids = FxHashSet::default(); - module.functions.retain(|f| { - if should_inline(&disallowed_argument_types, &disallowed_return_types, f) { - // TODO: We should insert all defined IDs in this function. - dropped_ids.insert(f.def_id().unwrap()); - false - } else { - true - } - }); - // Drop OpName etc. for inlined functions - module.debug_names.retain(|inst| { - !inst.operands.iter().any(|op| { - op.id_ref_any() - .map_or(false, |id| dropped_ids.contains(&id)) - }) - }); let mut inliner = Inliner { header: module.header.as_mut().unwrap(), types_global_values: &mut module.types_global_values, @@ -59,76 +47,121 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> { disallowed_argument_types: &disallowed_argument_types, disallowed_return_types: &disallowed_return_types, }; - for function in &mut module.functions { - inliner.inline_fn(function); - fuse_trivial_branches(function); + for index in postorder { + inliner.inline_fn(&mut module.functions, index); + fuse_trivial_branches(&mut module.functions[index]); + } + let mut dropped_ids = FxHashSet::default(); + for i in (0..module.functions.len()).rev() { + if to_delete[i] { + dropped_ids.insert(module.functions.remove(i).def_id().unwrap()); + } } + // Drop OpName etc. for inlined functions + module.debug_names.retain(|inst| { + !inst.operands.iter().any(|op| { + op.id_ref_any() + .map_or(false, |id| dropped_ids.contains(&id)) + }) + }); Ok(()) } -// https://stackoverflow.com/a/53995651 -fn deny_recursion_in_module(sess: &Session, module: &Module) -> super::Result<()> { +/// Topological sorting algorithm due to T. Cormen +/// Starts from module's entry points, so only reachable functions will be returned +/// in post-traversal order of DFS. For all unvisited functions `module.functions[i]`, +/// `to_delete[i]` is set to true. +fn compute_function_postorder( + sess: &Session, + module: &Module, + to_delete: &mut [bool], +) -> super::Result> { let func_to_index: FxHashMap = module .functions .iter() .enumerate() .map(|(index, func)| (func.def_id().unwrap(), index)) .collect(); - let mut discovered = vec![false; module.functions.len()]; - let mut finished = vec![false; module.functions.len()]; + /// Possible node states for cycle-discovering DFS. + #[derive(Clone, PartialEq)] + enum NodeState { + /// Normal, not visited. + NotVisited, + /// Currently being visited. + Discovered, + /// DFS returned. + Finished, + /// Not visited, entry point. + Entry, + } + let mut states = vec![NodeState::NotVisited; module.functions.len()]; + for opep in module.entry_points.iter() { + let func_id = opep.operands[1].unwrap_id_ref(); + states[func_to_index[&func_id]] = NodeState::Entry; + } let mut has_recursion = None; + let mut postorder = vec![]; for index in 0..module.functions.len() { - if !discovered[index] && !finished[index] { + if NodeState::Entry == states[index] { visit( sess, module, index, - &mut discovered, - &mut finished, + &mut states[..], &mut has_recursion, + &mut postorder, &func_to_index, ); } } + for index in 0..module.functions.len() { + if NodeState::NotVisited == states[index] { + to_delete[index] = true; + } + } + fn visit( sess: &Session, module: &Module, current: usize, - discovered: &mut Vec, - finished: &mut Vec, + states: &mut [NodeState], has_recursion: &mut Option, + postorder: &mut Vec, func_to_index: &FxHashMap, ) { - discovered[current] = true; + states[current] = NodeState::Discovered; for next in calls(&module.functions[current], func_to_index) { - if discovered[next] { - let names = get_names(module); - let current_name = get_name(&names, module.functions[current].def_id().unwrap()); - let next_name = get_name(&names, module.functions[next].def_id().unwrap()); - *has_recursion = Some(sess.err(&format!( - "module has recursion, which is not allowed: `{}` calls `{}`", - current_name, next_name - ))); - break; - } - - if !finished[next] { - visit( - sess, - module, - next, - discovered, - finished, - has_recursion, - func_to_index, - ); + match states[next] { + NodeState::Discovered => { + let names = get_names(module); + let current_name = + get_name(&names, module.functions[current].def_id().unwrap()); + let next_name = get_name(&names, module.functions[next].def_id().unwrap()); + *has_recursion = Some(sess.err(&format!( + "module has recursion, which is not allowed: `{}` calls `{}`", + current_name, next_name + ))); + break; + } + NodeState::NotVisited | NodeState::Entry => { + visit( + sess, + module, + next, + states, + has_recursion, + postorder, + func_to_index, + ); + } + NodeState::Finished => {} } } - discovered[current] = false; - finished[current] = true; + states[current] = NodeState::Finished; + postorder.push(current) } fn calls<'a>( @@ -146,7 +179,7 @@ fn deny_recursion_in_module(sess: &Session, module: &Module) -> super::Result<() match has_recursion { Some(err) => Err(err), - None => Ok(()), + None => Ok(postorder), } } @@ -284,19 +317,27 @@ impl Inliner<'_, '_> { inst_id } - fn inline_fn(&mut self, function: &mut Function) { + fn inline_fn(&mut self, functions: &mut [Function], index: usize) { + let mut function = take(&mut functions[index]); let mut block_idx = 0; while block_idx < function.blocks.len() { - // If we successfully inlined a block, then repeat processing on the same block, in - // case the newly inlined block has more inlined calls. - // TODO: This is quadratic - if !self.inline_block(function, block_idx) { - block_idx += 1; - } + // If we successfully inlined a block, then continue processing on the next block or its tail. + // TODO: this is quadratic in cases where [`Op::AccessChain`]s cascade into inner arguments. + // For the common case of "we knew which functions to inline", it is linear. + self.inline_block(&mut function, &functions, block_idx); + block_idx += 1; } + functions[index] = function; } - fn inline_block(&mut self, caller: &mut Function, block_idx: usize) -> bool { + /// Inlines one block and returns whether inlining actually occurred. + /// After calling this, blocks[block_idx] is finished processing. + fn inline_block( + &mut self, + caller: &mut Function, + functions: &[Function], + block_idx: usize, + ) -> bool { // Find the first inlined OpFunctionCall let call = caller.blocks[block_idx] .instructions @@ -304,13 +345,11 @@ impl Inliner<'_, '_> { .enumerate() .filter(|(_, inst)| inst.class.opcode == Op::FunctionCall) .map(|(index, inst)| { - ( - index, - inst, - self.functions - .get(&inst.operands[0].id_ref_any().unwrap()) - .unwrap(), - ) + let idx = self + .functions + .get(&inst.operands[0].id_ref_any().unwrap()) + .unwrap(); + (index, inst, &functions[*idx]) }) .find(|(_, inst, f)| { should_inline( @@ -375,17 +414,23 @@ impl Inliner<'_, '_> { ); } - // Fuse the first block of the callee into the block of the caller. This is okay because - // it's illegal to branch to the first BB in a function. - let mut callee_header = inlined_blocks.remove(0).instructions; + // Move the variables over from the inlined function to here. + let mut callee_header = take(&mut inlined_blocks[0]).instructions; // TODO: OpLine handling - let num_variables = callee_header - .iter() - .position(|inst| inst.class.opcode != Op::Variable) - .unwrap_or(callee_header.len()); - caller.blocks[block_idx] - .instructions - .append(&mut callee_header.split_off(num_variables)); + let num_variables = callee_header.partition_point(|inst| inst.class.opcode == Op::Variable); + // Rather than fuse blocks, generate a new jump here. Branch fusing will take care of + // it, and we maintain the invariant that current block has finished processing. + let jump_to = self.id(); + inlined_blocks[0] = Block { + label: Some(Instruction::new(Op::Label, None, Some(jump_to), vec![])), + instructions: callee_header.split_off(num_variables), + }; + caller.blocks[block_idx].instructions.push(Instruction::new( + Op::Branch, + None, + None, + vec![Operand::IdRef(jump_to)], + )); // Move the OpVariables of the callee to the caller. insert_opvariables(&mut caller.blocks[0], callee_header); @@ -467,45 +512,22 @@ fn get_inlined_blocks( fn insert_opvariable(block: &mut Block, ptr_ty: Word, result_id: Word) { let index = block .instructions - .iter() - .enumerate() - .find_map(|(index, inst)| { - if inst.class.opcode != Op::Variable { - Some(index) - } else { - None - } - }); + .partition_point(|inst| inst.class.opcode == Op::Variable); + let inst = Instruction::new( Op::Variable, Some(ptr_ty), Some(result_id), vec![Operand::StorageClass(StorageClass::Function)], ); - match index { - Some(index) => block.instructions.insert(index, inst), - None => block.instructions.push(inst), - } + block.instructions.insert(index, inst) } -fn insert_opvariables(block: &mut Block, mut insts: Vec) { +fn insert_opvariables(block: &mut Block, insts: Vec) { let index = block .instructions - .iter() - .enumerate() - .find_map(|(index, inst)| { - if inst.class.opcode != Op::Variable { - Some(index) - } else { - None - } - }); - match index { - Some(index) => { - block.instructions.splice(index..index, insts); - } - None => block.instructions.append(&mut insts), - } + .partition_point(|inst| inst.class.opcode == Op::Variable); + block.instructions.splice(index..index, insts); } fn fuse_trivial_branches(function: &mut Function) { From 37e659de2564c172bdc375ed9d9e76ff549d0e95 Mon Sep 17 00:00:00 2001 From: Alex S Date: Mon, 29 Nov 2021 14:37:20 +0300 Subject: [PATCH 02/12] linker/inline: add pointer type caching We need pointer types, and re-checking all the types to see if we already have one is rather slow, it's better to keep track. --- .../rustc_codegen_spirv/src/linker/inline.rs | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/crates/rustc_codegen_spirv/src/linker/inline.rs b/crates/rustc_codegen_spirv/src/linker/inline.rs index 4ce53af9fa..087a19f667 100644 --- a/crates/rustc_codegen_spirv/src/linker/inline.rs +++ b/crates/rustc_codegen_spirv/src/linker/inline.rs @@ -37,12 +37,26 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> { .iter() .find(|inst| inst.class.opcode == Op::TypeVoid) .map_or(0, |inst| inst.result_id.unwrap()); + let ptr_map: FxHashMap<_, _> = module + .types_global_values + .iter() + .filter_map(|inst| { + if inst.class.opcode == Op::TypePointer + && inst.operands[0].unwrap_storage_class() == StorageClass::Function + { + Some((inst.operands[1].unwrap_id_ref(), inst.result_id.unwrap())) + } else { + None + } + }) + .collect(); // Drop all the functions we'll be inlining. (This also means we won't waste time processing // inlines in functions that will get inlined) let mut inliner = Inliner { header: module.header.as_mut().unwrap(), types_global_values: &mut module.types_global_values, void, + ptr_map, functions: &functions, disallowed_argument_types: &disallowed_argument_types, disallowed_return_types: &disallowed_return_types, @@ -281,6 +295,7 @@ struct Inliner<'m, 'map> { header: &'m mut ModuleHeader, types_global_values: &'m mut Vec, void: Word, + ptr_map: FxHashMap, functions: &'map FunctionMap, disallowed_argument_types: &'map FxHashSet, disallowed_return_types: &'map FxHashSet, @@ -295,14 +310,9 @@ impl Inliner<'_, '_> { } fn ptr_ty(&mut self, pointee: Word) -> Word { - // TODO: This is horribly slow, fix this - let existing = self.types_global_values.iter().find(|inst| { - inst.class.opcode == Op::TypePointer - && inst.operands[0].unwrap_storage_class() == StorageClass::Function - && inst.operands[1].unwrap_id_ref() == pointee - }); + let existing = self.ptr_map.get(&pointee); if let Some(existing) = existing { - return existing.result_id.unwrap(); + return *existing; } let inst_id = self.id(); self.types_global_values.push(Instruction::new( @@ -314,6 +324,7 @@ impl Inliner<'_, '_> { Operand::IdRef(pointee), ], )); + self.ptr_map.insert(pointee, inst_id); inst_id } From 99f25a603a4e0282250bec34274c97446c13223b Mon Sep 17 00:00:00 2001 From: Alex S Date: Mon, 29 Nov 2021 15:13:08 +0300 Subject: [PATCH 03/12] linker/inline: reuse information about which functions should be inlined. The functions we are going to delete definitely either need to be inlined, or are never called (so we don't care what to decide about them). --- .../rustc_codegen_spirv/src/linker/inline.rs | 28 ++++++++----------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/crates/rustc_codegen_spirv/src/linker/inline.rs b/crates/rustc_codegen_spirv/src/linker/inline.rs index 087a19f667..aea1a3978e 100644 --- a/crates/rustc_codegen_spirv/src/linker/inline.rs +++ b/crates/rustc_codegen_spirv/src/linker/inline.rs @@ -58,8 +58,7 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> { void, ptr_map, functions: &functions, - disallowed_argument_types: &disallowed_argument_types, - disallowed_return_types: &disallowed_return_types, + needs_inline: &to_delete, }; for index in postorder { inliner.inline_fn(&mut module.functions, index); @@ -297,9 +296,7 @@ struct Inliner<'m, 'map> { void: Word, ptr_map: FxHashMap, functions: &'map FunctionMap, - disallowed_argument_types: &'map FxHashSet, - disallowed_return_types: &'map FxHashSet, - // rewrite_rules: FxHashMap, + needs_inline: &'map [bool], } impl Inliner<'_, '_> { @@ -356,23 +353,20 @@ impl Inliner<'_, '_> { .enumerate() .filter(|(_, inst)| inst.class.opcode == Op::FunctionCall) .map(|(index, inst)| { - let idx = self - .functions - .get(&inst.operands[0].id_ref_any().unwrap()) - .unwrap(); - (index, inst, &functions[*idx]) + ( + index, + inst, + self.functions[&inst.operands[0].id_ref_any().unwrap()], + ) }) - .find(|(_, inst, f)| { - should_inline( - self.disallowed_argument_types, - self.disallowed_return_types, - f, - ) || args_invalid(caller, inst) + .find(|(index, inst, func_idx)| { + self.needs_inline[*func_idx] || args_invalid(caller, inst) }); - let (call_index, call_inst, callee) = match call { + let (call_index, call_inst, callee_idx) = match call { None => return false, Some(call) => call, }; + let callee = &functions[callee_idx]; let call_result_type = { let ty = call_inst.result_type.unwrap(); if ty == self.void { From 41a6089ccc4bcdc078d3fe52e5f01a8104c441d6 Mon Sep 17 00:00:00 2001 From: Alex S Date: Mon, 29 Nov 2021 17:10:49 +0300 Subject: [PATCH 04/12] linker/inline: pre-cache calculation of invalid args Since during inlining, the only escaping value is the return value, we can calculate and update whether it has an invalid-to-call-with value as well. (Note that this is, strictly speaking, more rigor than get_invalid_values() applies, because it doesn't look behind OpPhis) As a nice bonus, we got rid of OpLoad/OpStore in favor of OpPhi, which means no type mucking and no work created for mem2reg. --- .../rustc_codegen_spirv/src/linker/inline.rs | 168 +++++++----------- 1 file changed, 63 insertions(+), 105 deletions(-) diff --git a/crates/rustc_codegen_spirv/src/linker/inline.rs b/crates/rustc_codegen_spirv/src/linker/inline.rs index aea1a3978e..170adac77b 100644 --- a/crates/rustc_codegen_spirv/src/linker/inline.rs +++ b/crates/rustc_codegen_spirv/src/linker/inline.rs @@ -37,28 +37,17 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> { .iter() .find(|inst| inst.class.opcode == Op::TypeVoid) .map_or(0, |inst| inst.result_id.unwrap()); - let ptr_map: FxHashMap<_, _> = module - .types_global_values - .iter() - .filter_map(|inst| { - if inst.class.opcode == Op::TypePointer - && inst.operands[0].unwrap_storage_class() == StorageClass::Function - { - Some((inst.operands[1].unwrap_id_ref(), inst.result_id.unwrap())) - } else { - None - } - }) - .collect(); + + let invalid_args = module.functions.iter().flat_map(get_invalid_args).collect(); + // Drop all the functions we'll be inlining. (This also means we won't waste time processing // inlines in functions that will get inlined) let mut inliner = Inliner { header: module.header.as_mut().unwrap(), - types_global_values: &mut module.types_global_values, void, - ptr_map, functions: &functions, needs_inline: &to_delete, + invalid_args, }; for index in postorder { inliner.inline_fn(&mut module.functions, index); @@ -268,20 +257,21 @@ fn should_inline( // This should be more general, but a very common problem is passing an OpAccessChain to an // OpFunctionCall (i.e. `f(&s.x)`, or more commonly, `s.x.f()` where `f` takes `&self`), so detect // that case and inline the call. -fn args_invalid(function: &Function, call: &Instruction) -> bool { - for inst in function.all_inst_iter() { +fn get_invalid_args<'a>(function: &'a Function) -> impl Iterator + 'a { + function.all_inst_iter().filter_map(|inst| { if inst.class.opcode == Op::AccessChain { - let inst_result = inst.result_id.unwrap(); - if call - .operands - .iter() - .any(|op| *op == Operand::IdRef(inst_result)) - { - return true; - } + inst.result_id + } else { + None } - } - false + }) +} + +fn args_invalid(invalid_args: &FxHashSet, call: &Instruction) -> bool { + call.operands.iter().skip(1).any(|op| { + op.id_ref_any() + .map_or(false, |arg| invalid_args.contains(&arg)) + }) } // Steps: @@ -292,11 +282,10 @@ fn args_invalid(function: &Function, call: &Instruction) -> bool { struct Inliner<'m, 'map> { header: &'m mut ModuleHeader, - types_global_values: &'m mut Vec, void: Word, - ptr_map: FxHashMap, functions: &'map FunctionMap, needs_inline: &'map [bool], + invalid_args: FxHashSet, } impl Inliner<'_, '_> { @@ -306,25 +295,6 @@ impl Inliner<'_, '_> { result } - fn ptr_ty(&mut self, pointee: Word) -> Word { - let existing = self.ptr_map.get(&pointee); - if let Some(existing) = existing { - return *existing; - } - let inst_id = self.id(); - self.types_global_values.push(Instruction::new( - Op::TypePointer, - None, - Some(inst_id), - vec![ - Operand::StorageClass(StorageClass::Function), - Operand::IdRef(pointee), - ], - )); - self.ptr_map.insert(pointee, inst_id); - inst_id - } - fn inline_fn(&mut self, functions: &mut [Function], index: usize) { let mut function = take(&mut functions[index]); let mut block_idx = 0; @@ -359,8 +329,8 @@ impl Inliner<'_, '_> { self.functions[&inst.operands[0].id_ref_any().unwrap()], ) }) - .find(|(index, inst, func_idx)| { - self.needs_inline[*func_idx] || args_invalid(caller, inst) + .find(|(_, inst, func_idx)| { + self.needs_inline[*func_idx] || args_invalid(&self.invalid_args, inst) }); let (call_index, call_inst, callee_idx) = match call { None => return false, @@ -388,18 +358,23 @@ impl Inliner<'_, '_> { }); let mut rewrite_rules = callee_parameters.zip(call_arguments).collect(); - let return_variable = if call_result_type.is_some() { - Some(self.id()) - } else { - None - }; let return_jump = self.id(); // Rewrite OpReturns of the callee. - let mut inlined_blocks = get_inlined_blocks(callee, return_variable, return_jump); + let (mut inlined_blocks, phi_pairs) = get_inlined_blocks(callee, return_jump); // Clone the IDs of the callee, because otherwise they'd be defined multiple times if the // fn is inlined multiple times. self.add_clone_id_rules(&mut rewrite_rules, &inlined_blocks); + // If any of the OpReturns were invalid, return will also be invalid. + for (value, _) in &phi_pairs { + if self.invalid_args.contains(value) { + self.invalid_args.insert(call_result_id); + self.invalid_args + .insert(*rewrite_rules.get(value).unwrap_or(value)); + } + } apply_rewrite_rules(&rewrite_rules, &mut inlined_blocks); + // unnecessary: invalidate_more_args(&rewrite_rules, &mut self.invalid_args); + // as no values from inside the inlined function ever make it directly out. // Split the block containing the OpFunctionCall into two, around the call. let mut post_call_block_insts = caller.blocks[block_idx] @@ -409,32 +384,27 @@ impl Inliner<'_, '_> { let call = caller.blocks[block_idx].instructions.pop().unwrap(); assert!(call.class.opcode == Op::FunctionCall); - if let Some(call_result_type) = call_result_type { - // Generate the storage space for the return value: Do this *after* the split above, - // because if block_idx=0, inserting a variable here shifts call_index. - insert_opvariable( - &mut caller.blocks[0], - self.ptr_ty(call_result_type), - return_variable.unwrap(), - ); - } - // Move the variables over from the inlined function to here. let mut callee_header = take(&mut inlined_blocks[0]).instructions; // TODO: OpLine handling let num_variables = callee_header.partition_point(|inst| inst.class.opcode == Op::Variable); // Rather than fuse blocks, generate a new jump here. Branch fusing will take care of // it, and we maintain the invariant that current block has finished processing. - let jump_to = self.id(); + let first_block_id = self.id(); inlined_blocks[0] = Block { - label: Some(Instruction::new(Op::Label, None, Some(jump_to), vec![])), + label: Some(Instruction::new( + Op::Label, + None, + Some(first_block_id), + vec![], + )), instructions: callee_header.split_off(num_variables), }; caller.blocks[block_idx].instructions.push(Instruction::new( Op::Branch, None, None, - vec![Operand::IdRef(jump_to)], + vec![Operand::IdRef(first_block_id)], )); // Move the OpVariables of the callee to the caller. insert_opvariables(&mut caller.blocks[0], callee_header); @@ -445,10 +415,17 @@ impl Inliner<'_, '_> { post_call_block_insts.insert( 0, Instruction::new( - Op::Load, + Op::Phi, Some(call_result_type), Some(call_result_id), - vec![Operand::IdRef(return_variable.unwrap())], + phi_pairs + .into_iter() + .flat_map(|(value, parent)| { + use std::iter; + iter::once(Operand::IdRef(*rewrite_rules.get(&value).unwrap_or(&value))) + .chain(iter::once(Operand::IdRef(rewrite_rules[&parent]))) + }) + .collect(), ), ); } @@ -481,51 +458,21 @@ impl Inliner<'_, '_> { } } -fn get_inlined_blocks( - function: &Function, - return_variable: Option, - return_jump: Word, -) -> Vec { +fn get_inlined_blocks(function: &Function, return_jump: Word) -> (Vec, Vec<(Word, Word)>) { let mut blocks = function.blocks.clone(); + let mut phipairs = Vec::new(); for block in &mut blocks { let last = block.instructions.last().unwrap(); if let Op::Return | Op::ReturnValue = last.class.opcode { if Op::ReturnValue == last.class.opcode { let return_value = last.operands[0].id_ref_any().unwrap(); - block.instructions.insert( - block.instructions.len() - 1, - Instruction::new( - Op::Store, - None, - None, - vec![ - Operand::IdRef(return_variable.unwrap()), - Operand::IdRef(return_value), - ], - ), - ); - } else { - assert!(return_variable.is_none()); + phipairs.push((return_value, block.label_id().unwrap())) } *block.instructions.last_mut().unwrap() = Instruction::new(Op::Branch, None, None, vec![Operand::IdRef(return_jump)]); } } - blocks -} - -fn insert_opvariable(block: &mut Block, ptr_ty: Word, result_id: Word) { - let index = block - .instructions - .partition_point(|inst| inst.class.opcode == Op::Variable); - - let inst = Instruction::new( - Op::Variable, - Some(ptr_ty), - Some(result_id), - vec![Operand::StorageClass(StorageClass::Function)], - ); - block.instructions.insert(index, inst) + (blocks, phipairs) } fn insert_opvariables(block: &mut Block, insts: Vec) { @@ -537,6 +484,7 @@ fn insert_opvariables(block: &mut Block, insts: Vec) { fn fuse_trivial_branches(function: &mut Function) { let all_preds = compute_preds(&function.blocks); + let mut rewrite_rules = FxHashMap::default(); 'outer: for (dest_block, mut preds) in all_preds.iter().enumerate() { // if there's two trivial branches in a row, the middle one might get inlined before the // last one, so when processing the last one, skip through to the first one. @@ -553,12 +501,22 @@ fn fuse_trivial_branches(function: &mut Function) { let pred_insts = &function.blocks[pred].instructions; if pred_insts.last().unwrap().class.opcode == Op::Branch { let mut dest_insts = take(&mut function.blocks[dest_block].instructions); + dest_insts.retain(|inst| { + if inst.class.opcode == Op::Phi { + assert_eq!(inst.operands.len(), 2); + rewrite_rules.insert(inst.result_id.unwrap(), inst.operands[0].unwrap_id_ref()); + false + } else { + true + } + }); let pred_insts = &mut function.blocks[pred].instructions; pred_insts.pop(); // pop the branch pred_insts.append(&mut dest_insts); } } function.blocks.retain(|b| !b.instructions.is_empty()); + apply_rewrite_rules(&rewrite_rules, &mut function.blocks); } fn compute_preds(blocks: &[Block]) -> Vec> { From c048909a62b2a97987b05081cb501dc72aa4b5af Mon Sep 17 00:00:00 2001 From: Alex S Date: Mon, 29 Nov 2021 17:18:01 +0300 Subject: [PATCH 05/12] linker/inline: fuse trivial branches in less time Originally, this algorithm walked a linked list by the back-edges, copying and skipping. It is easier to just go with front-edges and gobble up a series of potential blocks at once. The predecessor finding algorithm really just wanted to find 1-to-1 edges (it was split between `compute_all_preds` and `fuse_trivial_branches`), so made it that. --- .../rustc_codegen_spirv/src/linker/inline.rs | 95 ++++++++++++------- 1 file changed, 60 insertions(+), 35 deletions(-) diff --git a/crates/rustc_codegen_spirv/src/linker/inline.rs b/crates/rustc_codegen_spirv/src/linker/inline.rs index 170adac77b..4d66674029 100644 --- a/crates/rustc_codegen_spirv/src/linker/inline.rs +++ b/crates/rustc_codegen_spirv/src/linker/inline.rs @@ -483,52 +483,77 @@ fn insert_opvariables(block: &mut Block, insts: Vec) { } fn fuse_trivial_branches(function: &mut Function) { - let all_preds = compute_preds(&function.blocks); + let mut chain_list = compute_outgoing_1to1_branches(&function.blocks); let mut rewrite_rules = FxHashMap::default(); - 'outer: for (dest_block, mut preds) in all_preds.iter().enumerate() { - // if there's two trivial branches in a row, the middle one might get inlined before the - // last one, so when processing the last one, skip through to the first one. - let pred = loop { - if preds.len() != 1 || preds[0] == dest_block { - continue 'outer; - } - let pred = preds[0]; - if !function.blocks[pred].instructions.is_empty() { - break pred; - } - preds = &all_preds[pred]; - }; - let pred_insts = &function.blocks[pred].instructions; - if pred_insts.last().unwrap().class.opcode == Op::Branch { - let mut dest_insts = take(&mut function.blocks[dest_block].instructions); - dest_insts.retain(|inst| { - if inst.class.opcode == Op::Phi { - assert_eq!(inst.operands.len(), 2); - rewrite_rules.insert(inst.result_id.unwrap(), inst.operands[0].unwrap_id_ref()); - false - } else { - true + + for block_idx in 0..chain_list.len() { + let mut next = chain_list[block_idx].take(); + loop { + match next { + None => { + // end of the chain list + break; } - }); - let pred_insts = &mut function.blocks[pred].instructions; - pred_insts.pop(); // pop the branch - pred_insts.append(&mut dest_insts); + Some(x) if x == block_idx => { + // loop detected + break; + } + Some(next_idx) => { + let mut dest_insts = take(&mut function.blocks[next_idx].instructions); + dest_insts.retain(|inst| { + if inst.class.opcode == Op::Phi { + assert_eq!(inst.operands.len(), 2); + rewrite_rules + .insert(inst.result_id.unwrap(), inst.operands[0].unwrap_id_ref()); + false + } else { + true + } + }); + let self_insts = &mut function.blocks[block_idx].instructions; + self_insts.pop(); // pop the branch + self_insts.append(&mut dest_insts); + next = chain_list[next_idx].take(); + } + } } } function.blocks.retain(|b| !b.instructions.is_empty()); apply_rewrite_rules(&rewrite_rules, &mut function.blocks); } -fn compute_preds(blocks: &[Block]) -> Vec> { - let mut result = vec![vec![]; blocks.len()]; +fn compute_outgoing_1to1_branches(blocks: &[Block]) -> Vec> { + let block_id_to_idx: FxHashMap<_, _> = blocks + .iter() + .enumerate() + .map(|(idx, block)| (block.label_id().unwrap(), idx)) + .collect(); + #[derive(Clone)] + enum NumIncoming { + Zero, + One(usize), + TooMany, + } + let mut incoming = vec![NumIncoming::Zero; blocks.len()]; for (source_idx, source) in blocks.iter().enumerate() { for dest_id in outgoing_edges(source) { - let dest_idx = blocks - .iter() - .position(|b| b.label_id().unwrap() == dest_id) - .unwrap(); - result[dest_idx].push(source_idx); + let dest_idx = block_id_to_idx[&dest_id]; + incoming[dest_idx] = match incoming[dest_idx] { + NumIncoming::Zero => NumIncoming::One(source_idx), + _ => NumIncoming::TooMany, + } + } + } + + let mut result = vec![None; blocks.len()]; + + for (dest_idx, inc) in incoming.iter().enumerate() { + if let &NumIncoming::One(source_idx) = inc { + if blocks[source_idx].instructions.last().unwrap().class.opcode == Op::Branch { + result[source_idx] = Some(dest_idx); + } } } + result } From a7cc8dbeaec42a5e0d827522663362bb63b71fee Mon Sep 17 00:00:00 2001 From: Alex S Date: Tue, 30 Nov 2021 15:26:02 +0300 Subject: [PATCH 06/12] linker/inline: make a proper closure in fuse_trivial_branches --- crates/rustc_codegen_spirv/src/linker/inline.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/crates/rustc_codegen_spirv/src/linker/inline.rs b/crates/rustc_codegen_spirv/src/linker/inline.rs index 4d66674029..9bddfe4bb9 100644 --- a/crates/rustc_codegen_spirv/src/linker/inline.rs +++ b/crates/rustc_codegen_spirv/src/linker/inline.rs @@ -519,7 +519,14 @@ fn fuse_trivial_branches(function: &mut Function) { } } function.blocks.retain(|b| !b.instructions.is_empty()); - apply_rewrite_rules(&rewrite_rules, &mut function.blocks); + // Calculate a closure, as these rules can be transitive + let mut rewrite_rules_new = rewrite_rules.clone(); + for value in rewrite_rules_new.values_mut() { + while let Some(next) = rewrite_rules.get(value) { + *value = *next; + } + } + apply_rewrite_rules(&rewrite_rules_new, &mut function.blocks); } fn compute_outgoing_1to1_branches(blocks: &[Block]) -> Vec> { From b360321ff120eb4225558e9dbbb858a68c5f7a8b Mon Sep 17 00:00:00 2001 From: Alex S Date: Tue, 30 Nov 2021 15:36:58 +0300 Subject: [PATCH 07/12] linker/inline: fix test regression Just inlining entry points deletes functions from tests and makes everyone sad. --- .../rustc_codegen_spirv/src/linker/inline.rs | 22 +++++-------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/crates/rustc_codegen_spirv/src/linker/inline.rs b/crates/rustc_codegen_spirv/src/linker/inline.rs index 9bddfe4bb9..5d11562685 100644 --- a/crates/rustc_codegen_spirv/src/linker/inline.rs +++ b/crates/rustc_codegen_spirv/src/linker/inline.rs @@ -19,13 +19,13 @@ type FunctionMap = FxHashMap; pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> { let (disallowed_argument_types, disallowed_return_types) = compute_disallowed_argument_and_return_types(module); - let mut to_delete: Vec<_> = module + let to_delete: Vec<_> = module .functions .iter() .map(|f| should_inline(&disallowed_argument_types, &disallowed_return_types, f)) .collect(); // This algorithm gets real sad if there's recursion - but, good news, SPIR-V bans recursion - let postorder = compute_function_postorder(sess, module, &mut to_delete)?; + let postorder = compute_function_postorder(sess, module, &to_delete)?; let functions = module .functions .iter() @@ -76,7 +76,7 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> { fn compute_function_postorder( sess: &Session, module: &Module, - to_delete: &mut [bool], + to_delete: &[bool], ) -> super::Result> { let func_to_index: FxHashMap = module .functions @@ -93,18 +93,12 @@ fn compute_function_postorder( Discovered, /// DFS returned. Finished, - /// Not visited, entry point. - Entry, } let mut states = vec![NodeState::NotVisited; module.functions.len()]; - for opep in module.entry_points.iter() { - let func_id = opep.operands[1].unwrap_id_ref(); - states[func_to_index[&func_id]] = NodeState::Entry; - } let mut has_recursion = None; let mut postorder = vec![]; for index in 0..module.functions.len() { - if NodeState::Entry == states[index] { + if NodeState::NotVisited == states[index] && !to_delete[index] { visit( sess, module, @@ -117,12 +111,6 @@ fn compute_function_postorder( } } - for index in 0..module.functions.len() { - if NodeState::NotVisited == states[index] { - to_delete[index] = true; - } - } - fn visit( sess: &Session, module: &Module, @@ -147,7 +135,7 @@ fn compute_function_postorder( ))); break; } - NodeState::NotVisited | NodeState::Entry => { + NodeState::NotVisited => { visit( sess, module, From c5d3abe0718cdb9afbbae01350e01d5e4229f91c Mon Sep 17 00:00:00 2001 From: Alex S Date: Tue, 30 Nov 2021 15:41:25 +0300 Subject: [PATCH 08/12] linker/inline: make Clippy happy --- crates/rustc_codegen_spirv/src/linker/inline.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/crates/rustc_codegen_spirv/src/linker/inline.rs b/crates/rustc_codegen_spirv/src/linker/inline.rs index 5d11562685..f4ec290e62 100644 --- a/crates/rustc_codegen_spirv/src/linker/inline.rs +++ b/crates/rustc_codegen_spirv/src/linker/inline.rs @@ -151,7 +151,7 @@ fn compute_function_postorder( } states[current] = NodeState::Finished; - postorder.push(current) + postorder.push(current); } fn calls<'a>( @@ -245,7 +245,7 @@ fn should_inline( // This should be more general, but a very common problem is passing an OpAccessChain to an // OpFunctionCall (i.e. `f(&s.x)`, or more commonly, `s.x.f()` where `f` takes `&self`), so detect // that case and inline the call. -fn get_invalid_args<'a>(function: &'a Function) -> impl Iterator + 'a { +fn get_invalid_args(function: &Function) -> impl Iterator + '_ { function.all_inst_iter().filter_map(|inst| { if inst.class.opcode == Op::AccessChain { inst.result_id @@ -290,14 +290,14 @@ impl Inliner<'_, '_> { // If we successfully inlined a block, then continue processing on the next block or its tail. // TODO: this is quadratic in cases where [`Op::AccessChain`]s cascade into inner arguments. // For the common case of "we knew which functions to inline", it is linear. - self.inline_block(&mut function, &functions, block_idx); + self.inline_block(&mut function, functions, block_idx); block_idx += 1; } functions[index] = function; } /// Inlines one block and returns whether inlining actually occurred. - /// After calling this, blocks[block_idx] is finished processing. + /// After calling this, `blocks[block_idx]` is finished processing. fn inline_block( &mut self, caller: &mut Function, @@ -454,7 +454,7 @@ fn get_inlined_blocks(function: &Function, return_jump: Word) -> (Vec, Ve if let Op::Return | Op::ReturnValue = last.class.opcode { if Op::ReturnValue == last.class.opcode { let return_value = last.operands[0].id_ref_any().unwrap(); - phipairs.push((return_value, block.label_id().unwrap())) + phipairs.push((return_value, block.label_id().unwrap())); } *block.instructions.last_mut().unwrap() = Instruction::new(Op::Branch, None, None, vec![Operand::IdRef(return_jump)]); From b473f2abd0f06499fbc975a74360427686d609b4 Mon Sep 17 00:00:00 2001 From: Alex S Date: Tue, 30 Nov 2021 15:54:53 +0300 Subject: [PATCH 09/12] linker/inline: use variables instead of OpPhis to unify branches. This partially reverts commit 990425b7d6b9e23c1bf2ee35da67f77b5a18024d. --- .../rustc_codegen_spirv/src/linker/inline.rs | 143 ++++++++++++------ 1 file changed, 95 insertions(+), 48 deletions(-) diff --git a/crates/rustc_codegen_spirv/src/linker/inline.rs b/crates/rustc_codegen_spirv/src/linker/inline.rs index f4ec290e62..3c825c87cd 100644 --- a/crates/rustc_codegen_spirv/src/linker/inline.rs +++ b/crates/rustc_codegen_spirv/src/linker/inline.rs @@ -37,14 +37,25 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> { .iter() .find(|inst| inst.class.opcode == Op::TypeVoid) .map_or(0, |inst| inst.result_id.unwrap()); - + let ptr_map: FxHashMap<_, _> = module + .types_global_values + .iter() + .filter_map(|inst| { + if inst.class.opcode == Op::TypePointer + && inst.operands[0].unwrap_storage_class() == StorageClass::Function + { + Some((inst.operands[1].unwrap_id_ref(), inst.result_id.unwrap())) + } else { + None + } + }) + .collect(); let invalid_args = module.functions.iter().flat_map(get_invalid_args).collect(); - - // Drop all the functions we'll be inlining. (This also means we won't waste time processing - // inlines in functions that will get inlined) let mut inliner = Inliner { header: module.header.as_mut().unwrap(), + types_global_values: &mut module.types_global_values, void, + ptr_map, functions: &functions, needs_inline: &to_delete, invalid_args, @@ -270,7 +281,9 @@ fn args_invalid(invalid_args: &FxHashSet, call: &Instruction) -> bool { struct Inliner<'m, 'map> { header: &'m mut ModuleHeader, + types_global_values: &'m mut Vec, void: Word, + ptr_map: FxHashMap, functions: &'map FunctionMap, needs_inline: &'map [bool], invalid_args: FxHashSet, @@ -283,6 +296,25 @@ impl Inliner<'_, '_> { result } + fn ptr_ty(&mut self, pointee: Word) -> Word { + let existing = self.ptr_map.get(&pointee); + if let Some(existing) = existing { + return *existing; + } + let inst_id = self.id(); + self.types_global_values.push(Instruction::new( + Op::TypePointer, + None, + Some(inst_id), + vec![ + Operand::StorageClass(StorageClass::Function), + Operand::IdRef(pointee), + ], + )); + self.ptr_map.insert(pointee, inst_id); + inst_id + } + fn inline_fn(&mut self, functions: &mut [Function], index: usize) { let mut function = take(&mut functions[index]); let mut block_idx = 0; @@ -346,14 +378,20 @@ impl Inliner<'_, '_> { }); let mut rewrite_rules = callee_parameters.zip(call_arguments).collect(); + let return_variable = if call_result_type.is_some() { + Some(self.id()) + } else { + None + }; let return_jump = self.id(); // Rewrite OpReturns of the callee. - let (mut inlined_blocks, phi_pairs) = get_inlined_blocks(callee, return_jump); + let (mut inlined_blocks, return_values) = + get_inlined_blocks(callee, return_variable, return_jump); // Clone the IDs of the callee, because otherwise they'd be defined multiple times if the // fn is inlined multiple times. self.add_clone_id_rules(&mut rewrite_rules, &inlined_blocks); // If any of the OpReturns were invalid, return will also be invalid. - for (value, _) in &phi_pairs { + for value in &return_values { if self.invalid_args.contains(value) { self.invalid_args.insert(call_result_id); self.invalid_args @@ -361,8 +399,6 @@ impl Inliner<'_, '_> { } } apply_rewrite_rules(&rewrite_rules, &mut inlined_blocks); - // unnecessary: invalidate_more_args(&rewrite_rules, &mut self.invalid_args); - // as no values from inside the inlined function ever make it directly out. // Split the block containing the OpFunctionCall into two, around the call. let mut post_call_block_insts = caller.blocks[block_idx] @@ -372,27 +408,32 @@ impl Inliner<'_, '_> { let call = caller.blocks[block_idx].instructions.pop().unwrap(); assert!(call.class.opcode == Op::FunctionCall); + if let Some(call_result_type) = call_result_type { + // Generate the storage space for the return value: Do this *after* the split above, + // because if block_idx=0, inserting a variable here shifts call_index. + insert_opvariable( + &mut caller.blocks[0], + self.ptr_ty(call_result_type), + return_variable.unwrap(), + ); + } + // Move the variables over from the inlined function to here. let mut callee_header = take(&mut inlined_blocks[0]).instructions; // TODO: OpLine handling let num_variables = callee_header.partition_point(|inst| inst.class.opcode == Op::Variable); // Rather than fuse blocks, generate a new jump here. Branch fusing will take care of // it, and we maintain the invariant that current block has finished processing. - let first_block_id = self.id(); + let jump_to = self.id(); inlined_blocks[0] = Block { - label: Some(Instruction::new( - Op::Label, - None, - Some(first_block_id), - vec![], - )), + label: Some(Instruction::new(Op::Label, None, Some(jump_to), vec![])), instructions: callee_header.split_off(num_variables), }; caller.blocks[block_idx].instructions.push(Instruction::new( Op::Branch, None, None, - vec![Operand::IdRef(first_block_id)], + vec![Operand::IdRef(jump_to)], )); // Move the OpVariables of the callee to the caller. insert_opvariables(&mut caller.blocks[0], callee_header); @@ -403,17 +444,10 @@ impl Inliner<'_, '_> { post_call_block_insts.insert( 0, Instruction::new( - Op::Phi, + Op::Load, Some(call_result_type), Some(call_result_id), - phi_pairs - .into_iter() - .flat_map(|(value, parent)| { - use std::iter; - iter::once(Operand::IdRef(*rewrite_rules.get(&value).unwrap_or(&value))) - .chain(iter::once(Operand::IdRef(rewrite_rules[&parent]))) - }) - .collect(), + vec![Operand::IdRef(return_variable.unwrap())], ), ); } @@ -446,21 +480,53 @@ impl Inliner<'_, '_> { } } -fn get_inlined_blocks(function: &Function, return_jump: Word) -> (Vec, Vec<(Word, Word)>) { +fn get_inlined_blocks( + function: &Function, + return_variable: Option, + return_jump: Word, +) -> (Vec, Vec) { let mut blocks = function.blocks.clone(); - let mut phipairs = Vec::new(); + let mut values = Vec::new(); for block in &mut blocks { let last = block.instructions.last().unwrap(); if let Op::Return | Op::ReturnValue = last.class.opcode { if Op::ReturnValue == last.class.opcode { let return_value = last.operands[0].id_ref_any().unwrap(); - phipairs.push((return_value, block.label_id().unwrap())); + values.push(return_value); + block.instructions.insert( + block.instructions.len() - 1, + Instruction::new( + Op::Store, + None, + None, + vec![ + Operand::IdRef(return_variable.unwrap()), + Operand::IdRef(return_value), + ], + ), + ); + } else { + assert!(return_variable.is_none()); } *block.instructions.last_mut().unwrap() = Instruction::new(Op::Branch, None, None, vec![Operand::IdRef(return_jump)]); } } - (blocks, phipairs) + (blocks, values) +} + +fn insert_opvariable(block: &mut Block, ptr_ty: Word, result_id: Word) { + let index = block + .instructions + .partition_point(|inst| inst.class.opcode == Op::Variable); + + let inst = Instruction::new( + Op::Variable, + Some(ptr_ty), + Some(result_id), + vec![Operand::StorageClass(StorageClass::Function)], + ); + block.instructions.insert(index, inst) } fn insert_opvariables(block: &mut Block, insts: Vec) { @@ -472,7 +538,6 @@ fn insert_opvariables(block: &mut Block, insts: Vec) { fn fuse_trivial_branches(function: &mut Function) { let mut chain_list = compute_outgoing_1to1_branches(&function.blocks); - let mut rewrite_rules = FxHashMap::default(); for block_idx in 0..chain_list.len() { let mut next = chain_list[block_idx].take(); @@ -488,16 +553,6 @@ fn fuse_trivial_branches(function: &mut Function) { } Some(next_idx) => { let mut dest_insts = take(&mut function.blocks[next_idx].instructions); - dest_insts.retain(|inst| { - if inst.class.opcode == Op::Phi { - assert_eq!(inst.operands.len(), 2); - rewrite_rules - .insert(inst.result_id.unwrap(), inst.operands[0].unwrap_id_ref()); - false - } else { - true - } - }); let self_insts = &mut function.blocks[block_idx].instructions; self_insts.pop(); // pop the branch self_insts.append(&mut dest_insts); @@ -507,14 +562,6 @@ fn fuse_trivial_branches(function: &mut Function) { } } function.blocks.retain(|b| !b.instructions.is_empty()); - // Calculate a closure, as these rules can be transitive - let mut rewrite_rules_new = rewrite_rules.clone(); - for value in rewrite_rules_new.values_mut() { - while let Some(next) = rewrite_rules.get(value) { - *value = *next; - } - } - apply_rewrite_rules(&rewrite_rules_new, &mut function.blocks); } fn compute_outgoing_1to1_branches(blocks: &[Block]) -> Vec> { From 6e4d191071ac797ec23551f059d5c6460f94a8e3 Mon Sep 17 00:00:00 2001 From: Alex S Date: Tue, 30 Nov 2021 16:03:13 +0300 Subject: [PATCH 10/12] linker/inline: code review https://github.com/EmbarkStudios/rust-gpu/pull/811#pullrequestreview-818929958 --- .../rustc_codegen_spirv/src/linker/inline.rs | 39 ++++++++----------- 1 file changed, 16 insertions(+), 23 deletions(-) diff --git a/crates/rustc_codegen_spirv/src/linker/inline.rs b/crates/rustc_codegen_spirv/src/linker/inline.rs index 3c825c87cd..3b9ed5fa71 100644 --- a/crates/rustc_codegen_spirv/src/linker/inline.rs +++ b/crates/rustc_codegen_spirv/src/linker/inline.rs @@ -24,14 +24,15 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> { .iter() .map(|f| should_inline(&disallowed_argument_types, &disallowed_return_types, f)) .collect(); - // This algorithm gets real sad if there's recursion - but, good news, SPIR-V bans recursion - let postorder = compute_function_postorder(sess, module, &to_delete)?; - let functions = module + // This algorithm gets real sad if there's recursion - but, good news, SPIR-V bans recursion, + // so we exit with an error if [`compute_function_postorder`] finds it. + let function_to_index = module .functions .iter() .enumerate() .map(|(idx, f)| (f.def_id().unwrap(), idx)) .collect(); + let postorder = compute_function_postorder(sess, module, &function_to_index, &to_delete)?; let void = module .types_global_values .iter() @@ -56,10 +57,13 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> { types_global_values: &mut module.types_global_values, void, ptr_map, - functions: &functions, + function_to_index: &function_to_index, needs_inline: &to_delete, invalid_args, }; + // Processing functions in post-order of call tree we ensure that + // inlined functions already have all of the inner functions inlined, so we don't do + // the same work multiple times. for index in postorder { inliner.inline_fn(&mut module.functions, index); fuse_trivial_branches(&mut module.functions[index]); @@ -87,14 +91,9 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> { fn compute_function_postorder( sess: &Session, module: &Module, + func_to_index: &FxHashMap, to_delete: &[bool], ) -> super::Result> { - let func_to_index: FxHashMap = module - .functions - .iter() - .enumerate() - .map(|(index, func)| (func.def_id().unwrap(), index)) - .collect(); /// Possible node states for cycle-discovering DFS. #[derive(Clone, PartialEq)] enum NodeState { @@ -284,7 +283,7 @@ struct Inliner<'m, 'map> { types_global_values: &'m mut Vec, void: Word, ptr_map: FxHashMap, - functions: &'map FunctionMap, + function_to_index: &'map FunctionMap, needs_inline: &'map [bool], invalid_args: FxHashSet, } @@ -328,14 +327,9 @@ impl Inliner<'_, '_> { functions[index] = function; } - /// Inlines one block and returns whether inlining actually occurred. + /// Inlines one block. /// After calling this, `blocks[block_idx]` is finished processing. - fn inline_block( - &mut self, - caller: &mut Function, - functions: &[Function], - block_idx: usize, - ) -> bool { + fn inline_block(&mut self, caller: &mut Function, functions: &[Function], block_idx: usize) { // Find the first inlined OpFunctionCall let call = caller.blocks[block_idx] .instructions @@ -346,14 +340,14 @@ impl Inliner<'_, '_> { ( index, inst, - self.functions[&inst.operands[0].id_ref_any().unwrap()], + self.function_to_index[&inst.operands[0].id_ref_any().unwrap()], ) }) .find(|(_, inst, func_idx)| { self.needs_inline[*func_idx] || args_invalid(&self.invalid_args, inst) }); let (call_index, call_inst, callee_idx) = match call { - None => return false, + None => return, Some(call) => call, }; let callee = &functions[callee_idx]; @@ -422,7 +416,8 @@ impl Inliner<'_, '_> { let mut callee_header = take(&mut inlined_blocks[0]).instructions; // TODO: OpLine handling let num_variables = callee_header.partition_point(|inst| inst.class.opcode == Op::Variable); - // Rather than fuse blocks, generate a new jump here. Branch fusing will take care of + // Rather than fuse the first block of the inline function to the current block, + // generate a new jump here. Branch fusing will take care of // it, and we maintain the invariant that current block has finished processing. let jump_to = self.id(); inlined_blocks[0] = Block { @@ -463,8 +458,6 @@ impl Inliner<'_, '_> { caller .blocks .splice((block_idx + 1)..(block_idx + 1), inlined_blocks); - - true } fn add_clone_id_rules(&mut self, rewrite_rules: &mut FxHashMap, blocks: &[Block]) { From dcd35122bd22188a5f1666234809a21756fc9d95 Mon Sep 17 00:00:00 2001 From: Alex S Date: Tue, 30 Nov 2021 17:01:10 +0300 Subject: [PATCH 11/12] linker/inline: added test for cascade inlining. --- .../rustc_codegen_spirv/src/linker/inline.rs | 7 +- crates/rustc_codegen_spirv/src/linker/test.rs | 114 ++++++++++++++++++ 2 files changed, 118 insertions(+), 3 deletions(-) diff --git a/crates/rustc_codegen_spirv/src/linker/inline.rs b/crates/rustc_codegen_spirv/src/linker/inline.rs index 3b9ed5fa71..cafee55d20 100644 --- a/crates/rustc_codegen_spirv/src/linker/inline.rs +++ b/crates/rustc_codegen_spirv/src/linker/inline.rs @@ -386,10 +386,11 @@ impl Inliner<'_, '_> { self.add_clone_id_rules(&mut rewrite_rules, &inlined_blocks); // If any of the OpReturns were invalid, return will also be invalid. for value in &return_values { - if self.invalid_args.contains(value) { + let value_rewritten = *rewrite_rules.get(value).unwrap_or(value); + // value_rewritten might be originally a function argument + if self.invalid_args.contains(value) || self.invalid_args.contains(&value_rewritten) { self.invalid_args.insert(call_result_id); - self.invalid_args - .insert(*rewrite_rules.get(value).unwrap_or(value)); + self.invalid_args.insert(value_rewritten); } } apply_rewrite_rules(&rewrite_rules, &mut inlined_blocks); diff --git a/crates/rustc_codegen_spirv/src/linker/test.rs b/crates/rustc_codegen_spirv/src/linker/test.rs index 9c281a2398..186c3b8c55 100644 --- a/crates/rustc_codegen_spirv/src/linker/test.rs +++ b/crates/rustc_codegen_spirv/src/linker/test.rs @@ -505,3 +505,117 @@ fn names_and_decorations() { without_header_eq(result, expect); } + +#[test] +fn cascade_inlining_of_ptr_args() { + let a = assemble_spirv( + r#"OpCapability Linkage + OpDecorate %1 LinkageAttributes "foo" Export + %2 = OpTypeInt 32 0 + %8 = OpConstant %2 0 + %3 = OpTypeStruct %2 %2 + %4 = OpTypePointer Function %2 + %5 = OpTypePointer Function %3 + %6 = OpTypeFunction %4 %5 + %1 = OpFunction %4 Const %6 + %7 = OpFunctionParameter %5 + %10 = OpLabel + %9 = OpAccessChain %4 %7 %8 + OpReturnValue %9 + OpFunctionEnd + "#, + ); + + let b = assemble_spirv( + r#"OpCapability Linkage + OpDecorate %1 LinkageAttributes "bar" Export + %2 = OpTypeInt 32 0 + %4 = OpTypePointer Function %2 + %6 = OpTypeFunction %2 %4 + %1 = OpFunction %2 None %6 + %7 = OpFunctionParameter %4 + %10 = OpLabel + %8 = OpLoad %2 %7 + OpReturnValue %8 + OpFunctionEnd + "#, + ); + + let c = assemble_spirv( + r#"OpCapability Linkage + OpDecorate %1 LinkageAttributes "baz" Export + %2 = OpTypeInt 32 0 + %4 = OpTypePointer Function %2 + %6 = OpTypeFunction %4 %4 + %1 = OpFunction %4 None %6 + %7 = OpFunctionParameter %4 + %10 = OpLabel + OpReturnValue %7 + OpFunctionEnd + "#, + ); + + // In here, inlining foo should mark its return result as a not-fit-for-function-consumption + // pointer and inline "baz" as well. That would lead to inlining "bar" too. + let d = assemble_spirv( + r#"OpCapability Linkage + OpDecorate %10 LinkageAttributes "foo" Import + OpDecorate %12 LinkageAttributes "bar" Import + OpDecorate %14 LinkageAttributes "baz" Import + OpName %1 "main" + %2 = OpTypeInt 32 0 + %3 = OpTypeStruct %2 %2 + %4 = OpTypePointer Function %2 + %5 = OpTypePointer Function %3 + %6 = OpTypeFunction %4 %5 + %7 = OpTypeFunction %2 %4 + %8 = OpTypeFunction %4 %4 + %10 = OpFunction %4 Const %6 + %11 = OpFunctionParameter %5 + OpFunctionEnd + %12 = OpFunction %2 None %7 + %13 = OpFunctionParameter %4 + OpFunctionEnd + %14 = OpFunction %4 None %8 + %15 = OpFunctionParameter %4 + OpFunctionEnd + %21 = OpTypeFunction %2 %5 + %1 = OpFunction %2 None %14 + %22 = OpFunctionParameter %5 + %23 = OpLabel + %24 = OpFunctionCall %4 %10 %22 + %25 = OpFunctionCall %4 %14 %24 + %26 = OpFunctionCall %2 %12 %25 + OpReturnValue %26 + OpFunctionEnd + "#, + ); + + let result = assemble_and_link(&[&a, &b, &c, &d]).unwrap(); + let expect = r#"OpName %1 "main" + %2 = OpTypeInt 32 0 + %3 = OpConstant %2 0 + %4 = OpTypeStruct %2 %2 + %5 = OpTypePointer Function %2 + %6 = OpTypePointer Function %4 + %7 = OpTypeFunction %5 %6 + %8 = OpTypeFunction %2 %5 + %9 = OpTypeFunction %5 %5 + %10 = OpTypeFunction %2 %6 + %11 = OpTypePointer Function %5 + %12 = OpFunction %2 None %8 + %13 = OpFunctionParameter %5 + %14 = OpLabel + %15 = OpLoad %2 %13 + OpReturnValue %15 + OpFunctionEnd + %1 = OpFunction %2 None %16 + %17 = OpFunctionParameter %6 + %18 = OpLabel + %19 = OpAccessChain %5 %17 %3 + %20 = OpLoad %2 %19 + OpReturnValue %20 + OpFunctionEnd"#; + + without_header_eq(result, expect); +} From 3722fc63206f1c39dc9773fc888de9a9baf866f8 Mon Sep 17 00:00:00 2001 From: Alex S Date: Tue, 30 Nov 2021 17:04:19 +0300 Subject: [PATCH 12/12] Make clippy happy --- crates/rustc_codegen_spirv/src/linker/inline.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/rustc_codegen_spirv/src/linker/inline.rs b/crates/rustc_codegen_spirv/src/linker/inline.rs index cafee55d20..163ed73e42 100644 --- a/crates/rustc_codegen_spirv/src/linker/inline.rs +++ b/crates/rustc_codegen_spirv/src/linker/inline.rs @@ -116,7 +116,7 @@ fn compute_function_postorder( &mut states[..], &mut has_recursion, &mut postorder, - &func_to_index, + func_to_index, ); } } @@ -520,7 +520,7 @@ fn insert_opvariable(block: &mut Block, ptr_ty: Word, result_id: Word) { Some(result_id), vec![Operand::StorageClass(StorageClass::Function)], ); - block.instructions.insert(index, inst) + block.instructions.insert(index, inst); } fn insert_opvariables(block: &mut Block, insts: Vec) {