|
| 1 | +//! Interprocedural optimizations that "weaken" function parameters, i.e. they |
| 2 | +//! replace parameter types with "simpler" ones, or outright remove parameters, |
| 3 | +//! based on how those parameters are used in the function and/or what arguments |
| 4 | +//! get passed from callers. |
| 5 | +//! |
| 6 | +use crate::linker::ipo::CallGraph; |
| 7 | +use indexmap::IndexMap; |
| 8 | +use rspirv::dr::{Builder, Module, Operand}; |
| 9 | +use rspirv::spirv::{Op, Word}; |
| 10 | +use rustc_data_structures::fx::FxHashMap; |
| 11 | +use rustc_index::bit_set::BitSet; |
| 12 | +use std::mem; |
| 13 | + |
| 14 | +pub fn remove_unused_params(module: Module) -> Module { |
| 15 | + let call_graph = CallGraph::collect(&module); |
| 16 | + |
| 17 | + // Gather all of the unused parameters for each function, transitively. |
| 18 | + // (i.e. parameters which are passed, as call arguments, to functions that |
| 19 | + // won't use them, are also considered unused, through any number of calls) |
| 20 | + let mut unused_params_per_func_id: IndexMap<Word, BitSet<usize>> = IndexMap::new(); |
| 21 | + for func_idx in call_graph.post_order() { |
| 22 | + // Skip entry points, as they're the only "exported" functions, at least |
| 23 | + // at link-time (likely only relevant to `Kernel`s, but not `Shader`s). |
| 24 | + if call_graph.entry_points.contains(&func_idx) { |
| 25 | + continue; |
| 26 | + } |
| 27 | + |
| 28 | + let func = &module.functions[func_idx]; |
| 29 | + |
| 30 | + let params_id_to_idx: FxHashMap<Word, usize> = func |
| 31 | + .parameters |
| 32 | + .iter() |
| 33 | + .enumerate() |
| 34 | + .map(|(i, p)| (p.result_id.unwrap(), i)) |
| 35 | + .collect(); |
| 36 | + let mut unused_params = BitSet::new_filled(func.parameters.len()); |
| 37 | + for inst in func.all_inst_iter() { |
| 38 | + // If this is a call, we can ignore the arguments passed to the |
| 39 | + // callee parameters we already determined to be unused, because |
| 40 | + // those parameters (and matching arguments) will get removed later. |
| 41 | + let (operands, ignore_operands) = if inst.class.opcode == Op::FunctionCall { |
| 42 | + ( |
| 43 | + &inst.operands[1..], |
| 44 | + unused_params_per_func_id.get(&inst.operands[0].unwrap_id_ref()), |
| 45 | + ) |
| 46 | + } else { |
| 47 | + (&inst.operands[..], None) |
| 48 | + }; |
| 49 | + |
| 50 | + for (i, operand) in operands.iter().enumerate() { |
| 51 | + if let Some(ignore_operands) = ignore_operands { |
| 52 | + if ignore_operands.contains(i) { |
| 53 | + continue; |
| 54 | + } |
| 55 | + } |
| 56 | + |
| 57 | + if let Operand::IdRef(id) = operand { |
| 58 | + if let Some(¶m_idx) = params_id_to_idx.get(id) { |
| 59 | + unused_params.remove(param_idx); |
| 60 | + } |
| 61 | + } |
| 62 | + } |
| 63 | + } |
| 64 | + |
| 65 | + if !unused_params.is_empty() { |
| 66 | + unused_params_per_func_id.insert(func.def_id().unwrap(), unused_params); |
| 67 | + } |
| 68 | + } |
| 69 | + |
| 70 | + // Remove unused parameters and call arguments for unused parameters. |
| 71 | + let mut builder = Builder::new_from_module(module); |
| 72 | + for func_idx in 0..builder.module_ref().functions.len() { |
| 73 | + let func = &mut builder.module_mut().functions[func_idx]; |
| 74 | + let unused_params = unused_params_per_func_id.get(&func.def_id().unwrap()); |
| 75 | + if let Some(unused_params) = unused_params { |
| 76 | + func.parameters = mem::take(&mut func.parameters) |
| 77 | + .into_iter() |
| 78 | + .enumerate() |
| 79 | + .filter(|&(i, _)| !unused_params.contains(i)) |
| 80 | + .map(|(_, p)| p) |
| 81 | + .collect(); |
| 82 | + } |
| 83 | + |
| 84 | + for inst in func.all_inst_iter_mut() { |
| 85 | + if inst.class.opcode == Op::FunctionCall { |
| 86 | + if let Some(unused_callee_params) = |
| 87 | + unused_params_per_func_id.get(&inst.operands[0].unwrap_id_ref()) |
| 88 | + { |
| 89 | + inst.operands = mem::take(&mut inst.operands) |
| 90 | + .into_iter() |
| 91 | + .enumerate() |
| 92 | + .filter(|&(i, _)| i == 0 || !unused_callee_params.contains(i - 1)) |
| 93 | + .map(|(_, o)| o) |
| 94 | + .collect(); |
| 95 | + } |
| 96 | + } |
| 97 | + } |
| 98 | + |
| 99 | + // Regenerate the function type from remaining parameters, if necessary. |
| 100 | + if unused_params.is_some() { |
| 101 | + let return_type = func.def.as_mut().unwrap().result_type.unwrap(); |
| 102 | + let new_param_types: Vec<_> = func |
| 103 | + .parameters |
| 104 | + .iter() |
| 105 | + .map(|inst| inst.result_type.unwrap()) |
| 106 | + .collect(); |
| 107 | + let new_func_type = builder.type_function(return_type, new_param_types); |
| 108 | + let func = &mut builder.module_mut().functions[func_idx]; |
| 109 | + func.def.as_mut().unwrap().operands[1] = Operand::IdRef(new_func_type); |
| 110 | + } |
| 111 | + } |
| 112 | + |
| 113 | + builder.module() |
| 114 | +} |
0 commit comments