Skip to content

Commit 0bbea9e

Browse files
committed
Introduce remove_unused_params pass, to run just before zombie removal.
1 parent 5e23a73 commit 0bbea9e

File tree

3 files changed

+133
-2
lines changed

3 files changed

+133
-2
lines changed

crates/rustc_codegen_spirv/src/linker/ipo.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use rustc_data_structures::fx::FxHashMap;
1111
type FuncIdx = usize;
1212

1313
pub struct CallGraph {
14-
entry_points: IndexSet<FuncIdx>,
14+
pub entry_points: IndexSet<FuncIdx>,
1515

1616
/// `callees[i].contains(j)` implies `functions[i]` calls `functions[j]`.
1717
callees: Vec<IndexSet<FuncIdx>>,
@@ -39,7 +39,15 @@ impl CallGraph {
3939
.map(|func| {
4040
func.all_inst_iter()
4141
.filter(|inst| inst.class.opcode == Op::FunctionCall)
42-
.map(|inst| func_id_to_idx[&inst.operands[0].unwrap_id_ref()])
42+
.filter_map(|inst| {
43+
// FIXME(eddyb) `func_id_to_idx` should always have an
44+
// entry for a callee ID, but when ran early enough
45+
// (before zombie removal), the callee ID might not
46+
// point to an `OpFunction` (unsure what, `OpUndef`?).
47+
func_id_to_idx
48+
.get(&inst.operands[0].unwrap_id_ref())
49+
.copied()
50+
})
4351
.collect()
4452
})
4553
.collect();

crates/rustc_codegen_spirv/src/linker/mod.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ mod import_export_link;
88
mod inline;
99
mod ipo;
1010
mod mem2reg;
11+
mod param_weakening;
1112
mod peephole_opts;
1213
mod simple_passes;
1314
mod specializer;
@@ -130,6 +131,14 @@ pub fn link(sess: &Session, mut inputs: Vec<Module>, opts: &Options) -> Result<L
130131
import_export_link::run(sess, &mut output)?;
131132
}
132133

134+
// HACK(eddyb) this has to run before the `remove_zombies` pass, so that any
135+
// zombies that are passed as call arguments, but eventually unused, won't
136+
// be (incorrectly) considered used.
137+
{
138+
let _timer = sess.timer("link_remove_unused_params");
139+
output = param_weakening::remove_unused_params(output);
140+
}
141+
133142
{
134143
let _timer = sess.timer("link_remove_zombies");
135144
zombies::remove_zombies(sess, &mut output);
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
//! Interprocedural optimizations that "weaken" function parameters, i.e. they
2+
//! replace parameter types with "simpler" ones, or outright remove parameters,
3+
//! based on how those parameters are used in the function and/or what arguments
4+
//! get passed from callers.
5+
//!
6+
use crate::linker::ipo::CallGraph;
7+
use indexmap::IndexMap;
8+
use rspirv::dr::{Builder, Module, Operand};
9+
use rspirv::spirv::{Op, Word};
10+
use rustc_data_structures::fx::FxHashMap;
11+
use rustc_index::bit_set::BitSet;
12+
use std::mem;
13+
14+
pub fn remove_unused_params(module: Module) -> Module {
15+
let call_graph = CallGraph::collect(&module);
16+
17+
// Gather all of the unused parameters for each function, transitively.
18+
// (i.e. parameters which are passed, as call arguments, to functions that
19+
// won't use them, are also considered unused, through any number of calls)
20+
let mut unused_params_per_func_id: IndexMap<Word, BitSet<usize>> = IndexMap::new();
21+
for func_idx in call_graph.post_order() {
22+
// Skip entry points, as they're the only "exported" functions, at least
23+
// at link-time (likely only relevant to `Kernel`s, but not `Shader`s).
24+
if call_graph.entry_points.contains(&func_idx) {
25+
continue;
26+
}
27+
28+
let func = &module.functions[func_idx];
29+
30+
let params_id_to_idx: FxHashMap<Word, usize> = func
31+
.parameters
32+
.iter()
33+
.enumerate()
34+
.map(|(i, p)| (p.result_id.unwrap(), i))
35+
.collect();
36+
let mut unused_params = BitSet::new_filled(func.parameters.len());
37+
for inst in func.all_inst_iter() {
38+
// If this is a call, we can ignore the arguments passed to the
39+
// callee parameters we already determined to be unused, because
40+
// those parameters (and matching arguments) will get removed later.
41+
let (operands, ignore_operands) = if inst.class.opcode == Op::FunctionCall {
42+
(
43+
&inst.operands[1..],
44+
unused_params_per_func_id.get(&inst.operands[0].unwrap_id_ref()),
45+
)
46+
} else {
47+
(&inst.operands[..], None)
48+
};
49+
50+
for (i, operand) in operands.iter().enumerate() {
51+
if let Some(ignore_operands) = ignore_operands {
52+
if ignore_operands.contains(i) {
53+
continue;
54+
}
55+
}
56+
57+
if let Operand::IdRef(id) = operand {
58+
if let Some(&param_idx) = params_id_to_idx.get(id) {
59+
unused_params.remove(param_idx);
60+
}
61+
}
62+
}
63+
}
64+
65+
if !unused_params.is_empty() {
66+
unused_params_per_func_id.insert(func.def_id().unwrap(), unused_params);
67+
}
68+
}
69+
70+
// Remove unused parameters and call arguments for unused parameters.
71+
let mut builder = Builder::new_from_module(module);
72+
for func_idx in 0..builder.module_ref().functions.len() {
73+
let func = &mut builder.module_mut().functions[func_idx];
74+
let unused_params = unused_params_per_func_id.get(&func.def_id().unwrap());
75+
if let Some(unused_params) = unused_params {
76+
func.parameters = mem::take(&mut func.parameters)
77+
.into_iter()
78+
.enumerate()
79+
.filter(|&(i, _)| !unused_params.contains(i))
80+
.map(|(_, p)| p)
81+
.collect();
82+
}
83+
84+
for inst in func.all_inst_iter_mut() {
85+
if inst.class.opcode == Op::FunctionCall {
86+
if let Some(unused_callee_params) =
87+
unused_params_per_func_id.get(&inst.operands[0].unwrap_id_ref())
88+
{
89+
inst.operands = mem::take(&mut inst.operands)
90+
.into_iter()
91+
.enumerate()
92+
.filter(|&(i, _)| i == 0 || !unused_callee_params.contains(i - 1))
93+
.map(|(_, o)| o)
94+
.collect();
95+
}
96+
}
97+
}
98+
99+
// Regenerate the function type from remaining parameters, if necessary.
100+
if unused_params.is_some() {
101+
let return_type = func.def.as_mut().unwrap().result_type.unwrap();
102+
let new_param_types: Vec<_> = func
103+
.parameters
104+
.iter()
105+
.map(|inst| inst.result_type.unwrap())
106+
.collect();
107+
let new_func_type = builder.type_function(return_type, new_param_types);
108+
let func = &mut builder.module_mut().functions[func_idx];
109+
func.def.as_mut().unwrap().operands[1] = Operand::IdRef(new_func_type);
110+
}
111+
}
112+
113+
builder.module()
114+
}

0 commit comments

Comments
 (0)