Skip to content

Commit e597077

Browse files
committed
inline globals!
1 parent 5ac500d commit e597077

File tree

4 files changed

+302
-3
lines changed

4 files changed

+302
-3
lines changed

crates/rustc_codegen_spirv/src/linker/inline.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,12 +209,16 @@ fn should_inline(
209209
) -> bool {
210210
let def = function.def.as_ref().unwrap();
211211
let control = def.operands[0].unwrap_function_control();
212-
control.contains(FunctionControl::INLINE)
212+
let should = control.contains(FunctionControl::INLINE)
213213
|| function
214214
.parameters
215215
.iter()
216216
.any(|inst| disallowed_argument_types.contains(inst.result_type.as_ref().unwrap()))
217-
|| disallowed_return_types.contains(&function.def.as_ref().unwrap().result_type.unwrap())
217+
|| disallowed_return_types.contains(&function.def.as_ref().unwrap().result_type.unwrap());
218+
// if should && control.contains(FunctionControl::DONT_INLINE) {
219+
// println!("should not be inlined!");
220+
// }
221+
should
218222
}
219223

220224
// This should be more general, but a very common problem is passing an OpAccessChain to an
Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
use rspirv::binary::Disassemble;
2+
use rspirv::dr::{Instruction, Module, Operand};
3+
use rspirv::spirv::Op;
4+
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
5+
use rustc_session::Session;
6+
7+
#[derive(Debug, Clone, PartialEq)]
8+
enum FunctionArg {
9+
Invalid,
10+
Insts(Vec<Instruction>),
11+
}
12+
13+
pub fn inline_global_varaibles(sess: &Session, module: &mut Module) -> super::Result<()> {
14+
let mut i = 0;
15+
let mut cont = true;
16+
std::fs::write("res0.txt", module.disassemble());
17+
while cont {
18+
cont = inline_global_varaibles_rec(sess, module)?;
19+
i += 1;
20+
std::fs::write(format!("res{}.txt", i), module.disassemble());
21+
}
22+
Ok(())
23+
}
24+
25+
fn inline_global_varaibles_rec(sess: &Session, module: &mut Module) -> super::Result<bool> {
26+
// first collect global stuff
27+
let mut variables: FxHashSet<u32> = FxHashSet::default();
28+
let mut function_types: FxHashMap<u32, Instruction> = FxHashMap::default();
29+
for global_inst in &module.types_global_values {
30+
let opcode = global_inst.class.opcode;
31+
if opcode == Op::Variable || opcode == Op::Constant {
32+
variables.insert(global_inst.result_id.unwrap());
33+
} else if opcode == Op::TypeFunction {
34+
function_types.insert(global_inst.result_id.unwrap(), global_inst.clone());
35+
}
36+
}
37+
// then we keep track of which function parameter are always called with the same expression that only uses global variables
38+
let mut function_args: FxHashMap<(u32, u32), FunctionArg> = FxHashMap::default();
39+
for caller in &module.functions {
40+
let mut insts: FxHashMap<u32, Instruction> = FxHashMap::default();
41+
for block in &caller.blocks {
42+
for inst in &block.instructions {
43+
if inst.result_id.is_some() {
44+
insts.insert(inst.result_id.unwrap(), inst.clone());
45+
}
46+
if inst.class.opcode == Op::FunctionCall {
47+
let function_id = match &inst.operands[0] {
48+
&Operand::IdRef(w) => w,
49+
_ => panic!(),
50+
};
51+
for i in 1..inst.operands.len() {
52+
let key = (function_id, i as u32 - 1);
53+
match &inst.operands[i] {
54+
&Operand::IdRef(w) => match &function_args.get(&key) {
55+
None => match get_const_arg_insts(&variables, &insts, w) {
56+
Some(insts) => {
57+
function_args.insert(key, FunctionArg::Insts(insts));
58+
}
59+
None => {
60+
function_args.insert(key, FunctionArg::Invalid);
61+
}
62+
},
63+
Some(FunctionArg::Insts(w2)) => {
64+
let new_insts = get_const_arg_insts(&variables, &insts, w);
65+
match new_insts {
66+
Some(new_insts) => {
67+
if new_insts != *w2 {
68+
function_args.insert(key, FunctionArg::Invalid);
69+
}
70+
}
71+
None => {
72+
function_args.insert(key, FunctionArg::Invalid);
73+
}
74+
}
75+
}
76+
_ => {
77+
function_args.insert(key, FunctionArg::Invalid);
78+
}
79+
},
80+
_ => {
81+
function_args.insert(key, FunctionArg::Invalid);
82+
}
83+
};
84+
}
85+
}
86+
}
87+
}
88+
}
89+
function_args.retain(|_, k| match k {
90+
FunctionArg::Invalid => false,
91+
FunctionArg::Insts(v) => !v.is_empty(),
92+
});
93+
if function_args.is_empty() {
94+
return Ok(false);
95+
}
96+
let mut bound = module.header.as_ref().unwrap().bound;
97+
for function in &mut module.functions {
98+
let def = function.def.as_mut().unwrap();
99+
let fid = def.result_id.unwrap();
100+
let mut insts: Vec<Instruction> = Vec::new();
101+
let mut j: u32 = 0;
102+
let mut i = 0;
103+
let mut removed_indexes: Vec<u32> = Vec::new();
104+
// callee side. remove parameters from function def
105+
while i < function.parameters.len() {
106+
let mut removed = false;
107+
match &function_args.get(&(fid, j)) {
108+
Some(FunctionArg::Insts(arg)) => {
109+
let parameter = function.parameters.remove(i);
110+
let mut arg = arg.clone();
111+
arg.reverse();
112+
insts_replacing_captured_ids(&mut arg, &mut bound);
113+
let index = arg.len() - 1;
114+
arg[index].result_id = parameter.result_id;
115+
insts.extend(arg);
116+
removed_indexes.push(j);
117+
removed = true;
118+
}
119+
_ => (),
120+
}
121+
if !removed {
122+
i += 1;
123+
}
124+
j += 1;
125+
}
126+
// callee side. and add a new function type in global section
127+
if removed_indexes.len() > 0 {
128+
if let Operand::IdRef(tid) = def.operands[1] {
129+
let mut function_type: Instruction = function_types.get(&tid).unwrap().clone();
130+
let tid: u32 = bound;
131+
bound += 1;
132+
for i in removed_indexes.iter().rev() {
133+
let i = *i as usize + 1;
134+
function_type.operands.remove(i);
135+
function_type.result_id = Some(tid);
136+
}
137+
def.operands[1] = Operand::IdRef(tid);
138+
module.types_global_values.push(function_type);
139+
}
140+
}
141+
// callee side. insert initialization instructions, which reuse the ids of the removed parameters
142+
if !function.blocks.is_empty() {
143+
let first_block = &mut function.blocks[0];
144+
// skip some instructions that must be at top of block
145+
let mut i = 0;
146+
loop {
147+
if i >= first_block.instructions.len() {
148+
break;
149+
}
150+
let inst = &first_block.instructions[i];
151+
if inst.class.opcode == Op::Label || inst.class.opcode == Op::Variable {
152+
} else {
153+
break;
154+
}
155+
i += 1;
156+
}
157+
first_block.instructions.splice(i..i, insts);
158+
}
159+
// caller side, remove parameters from function call
160+
for block in &mut function.blocks {
161+
for inst in &mut block.instructions {
162+
if inst.class.opcode == Op::FunctionCall {
163+
let function_id = match &inst.operands[0] {
164+
&Operand::IdRef(w) => w,
165+
_ => panic!(),
166+
};
167+
let mut removed_size = 0;
168+
for i in 0..inst.operands.len() - 1 {
169+
if function_args.contains_key(&(function_id, i as u32)) {
170+
inst.operands.remove(i - removed_size + 1);
171+
removed_size += 1;
172+
}
173+
}
174+
}
175+
}
176+
}
177+
}
178+
if let Some(header) = &mut module.header {
179+
header.bound = bound;
180+
}
181+
Ok(true)
182+
}
183+
184+
fn insts_replacing_captured_ids(arg: &mut Vec<Instruction>, bound: &mut u32) {
185+
let mut id_map: FxHashMap<u32, u32> = FxHashMap::default();
186+
for ins in arg {
187+
if let Some(id) = &mut ins.result_id {
188+
for op in &mut ins.operands {
189+
match op {
190+
Operand::IdRef(id) => match id_map.get(id) {
191+
Some(new_id) => {
192+
*id = *new_id;
193+
}
194+
_ => {}
195+
},
196+
_ => {}
197+
}
198+
}
199+
id_map.insert(*id, *bound);
200+
*id = *bound;
201+
*bound += 1;
202+
}
203+
}
204+
}
205+
206+
fn get_const_arg_operands(
207+
variables: &FxHashSet<u32>,
208+
insts: &FxHashMap<u32, Instruction>,
209+
operand: &Operand,
210+
) -> Option<Vec<Instruction>> {
211+
match operand {
212+
Operand::IdRef(id) => {
213+
let insts = get_const_arg_insts(variables, insts, *id)?;
214+
return Some(insts);
215+
}
216+
Operand::LiteralInt32(_) => {},
217+
Operand::LiteralInt64(_) => {},
218+
Operand::LiteralFloat32(_) => {},
219+
Operand::LiteralFloat64(_) => {},
220+
Operand::LiteralExtInstInteger(_) => {},
221+
Operand::LiteralSpecConstantOpInteger(_) => {},
222+
Operand::LiteralString(_) => {},
223+
_ => {
224+
// TOOD add more cases
225+
return None;
226+
}
227+
}
228+
return Some(Vec::new());
229+
}
230+
231+
fn get_const_arg_insts(
232+
variables: &FxHashSet<u32>,
233+
insts: &FxHashMap<u32, Instruction>,
234+
id: u32,
235+
) -> Option<Vec<Instruction>> {
236+
let mut result: Vec<Instruction> = Vec::new();
237+
if variables.contains(&id) {
238+
return Some(result);
239+
}
240+
let par: &Instruction = insts.get(&id)?;
241+
if par.class.opcode == Op::AccessChain {
242+
result.push(par.clone());
243+
for oprand in &par.operands {
244+
let insts = get_const_arg_operands(variables, insts, oprand)?;
245+
result.extend(insts);
246+
}
247+
} else if par.class.opcode == Op::FunctionCall {
248+
result.push(par.clone());
249+
// skip first, first is function id
250+
for oprand in &par.operands[1..] {
251+
let insts = get_const_arg_operands(variables, insts, oprand)?;
252+
result.extend(insts);
253+
}
254+
} else {
255+
// TOOD add more cases
256+
return None;
257+
}
258+
Some(result)
259+
}

crates/rustc_codegen_spirv/src/linker/mod.rs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ mod destructure_composites;
66
mod duplicates;
77
mod entry_interface;
88
mod import_export_link;
9+
mod simpl_op_store_var;
10+
mod inline_globals;
911
mod inline;
1012
mod ipo;
1113
mod mem2reg;
@@ -20,7 +22,7 @@ use std::borrow::Cow;
2022

2123
use crate::codegen_cx::SpirvMetadata;
2224
use crate::decorations::{CustomDecoration, UnrollLoopsDecoration};
23-
use rspirv::binary::{Assemble, Consumer};
25+
use rspirv::binary::{Assemble, Consumer, Disassemble};
2426
use rspirv::dr::{Block, Instruction, Loader, Module, ModuleHeader, Operand};
2527
use rspirv::spirv::{Op, StorageClass, Word};
2628
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
@@ -152,6 +154,7 @@ pub fn link(sess: &Session, mut inputs: Vec<Module>, opts: &Options) -> Result<L
152154
std::fs::write(path, spirv_tools::binary::from_binary(&output.assemble())).unwrap();
153155
}
154156

157+
155158
// remove duplicates (https://github.com/KhronosGroup/SPIRV-Tools/blob/e7866de4b1dc2a7e8672867caeb0bdca49f458d3/source/opt/remove_duplicates_pass.cpp)
156159
{
157160
let _timer = sess.timer("link_remove_duplicates");
@@ -216,6 +219,22 @@ pub fn link(sess: &Session, mut inputs: Vec<Module>, opts: &Options) -> Result<L
216219
);
217220
}
218221

222+
// this is needed so we can inline more global variables
223+
{
224+
let _timer = sess.timer("simpl_op_store_var");
225+
simpl_op_store_var::simpl_op_store_var(sess, &mut output)?;
226+
}
227+
228+
{
229+
let _timer = sess.timer("link_inline_global");
230+
inline_globals::inline_global_varaibles(sess, &mut output)?;
231+
}
232+
233+
{
234+
let _timer = sess.timer("link_inline_global");
235+
inline_globals::inline_global_varaibles(sess, &mut output)?;
236+
}
237+
219238
{
220239
let _timer = sess.timer("link_inline");
221240
inline::inline(sess, &mut output)?;
@@ -302,6 +321,8 @@ pub fn link(sess: &Session, mut inputs: Vec<Module>, opts: &Options) -> Result<L
302321
simple_passes::sort_globals(&mut output);
303322
}
304323

324+
std::fs::write("res.txt", output.disassemble());
325+
305326
let mut output = if opts.emit_multiple_modules {
306327
let modules = output
307328
.entry_points
@@ -348,5 +369,6 @@ pub fn link(sess: &Session, mut inputs: Vec<Module>, opts: &Options) -> Result<L
348369
};
349370
}
350371

372+
351373
Ok(output)
352374
}

tests/ui/arch/inline_global.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// build-pass
2+
use spirv_std as _;
3+
4+
/// Returns `con[1] - con[0]` converted to `f32`.
///
/// Marked `#[inline(never)]` so that the call from the entry point remains a real
/// function call taking the slice as an argument.
#[inline(never)]
fn sdf(con: &mut [u32]) -> f32 {
    let difference = con[1] - con[0];
    difference as f32
}
8+
9+
#[spirv(fragment)]
10+
pub fn main(
11+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] runtime_array: &mut [u32],
12+
) {
13+
sdf(runtime_array);
14+
}

0 commit comments

Comments
 (0)