Skip to content

Commit a71cfc0

Browse files
committed
fix: support load store indirection
1 parent e597077 commit a71cfc0

File tree

2 files changed

+163
-65
lines changed

2 files changed

+163
-65
lines changed

crates/rustc_codegen_spirv/src/linker/inline_globals.rs

Lines changed: 154 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,80 @@
11
use rspirv::binary::Disassemble;
22
use rspirv::dr::{Instruction, Module, Operand};
3-
use rspirv::spirv::Op;
3+
use rspirv::spirv::{Op, StorageClass};
44
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
55
use rustc_session::Session;
66

7+
// bool is if this needs stored
8+
#[derive(Debug, Clone, PartialEq)]
9+
struct NormalizedInstructions {
10+
vars: Vec<Instruction>,
11+
insts: Vec<Instruction>,
12+
root: u32,
13+
}
14+
15+
impl NormalizedInstructions {
16+
fn new(id: u32) -> Self {
17+
NormalizedInstructions {
18+
vars: Vec::new(),
19+
insts: Vec::new(),
20+
root: id,
21+
}
22+
}
23+
24+
fn extend(&mut self, o: NormalizedInstructions) {
25+
self.vars.extend(o.vars);
26+
self.insts.extend(o.insts);
27+
}
28+
29+
fn is_empty(&self) -> bool {
30+
self.insts.is_empty() && self.vars.is_empty()
31+
}
32+
33+
fn fix_ids(&mut self, bound: &mut u32, new_root: u32) {
34+
let mut id_map: FxHashMap<u32, u32> = FxHashMap::default();
35+
id_map.insert(self.root, new_root);
36+
for inst in &mut self.vars {
37+
Self::fix_instruction(self.root, inst, &mut id_map, bound, new_root);
38+
}
39+
for inst in &mut self.insts {
40+
Self::fix_instruction(self.root, inst, &mut id_map, bound, new_root);
41+
}
42+
}
43+
44+
fn fix_instruction(
45+
root: u32,
46+
inst: &mut Instruction,
47+
id_map: &mut FxHashMap<u32, u32>,
48+
bound: &mut u32,
49+
new_root: u32,
50+
) {
51+
for op in &mut inst.operands {
52+
match op {
53+
Operand::IdRef(id) => match id_map.get(id) {
54+
Some(new_id) => {
55+
*id = *new_id;
56+
}
57+
_ => {}
58+
},
59+
_ => {}
60+
}
61+
}
62+
if let Some(id) = &mut inst.result_id {
63+
if *id != root {
64+
id_map.insert(*id, *bound);
65+
*id = *bound;
66+
*bound += 1;
67+
} else {
68+
*id = new_root;
69+
}
70+
}
71+
}
72+
}
73+
774
#[derive(Debug, Clone, PartialEq)]
875
enum FunctionArg {
976
Invalid,
10-
Insts(Vec<Instruction>),
77+
Insts(NormalizedInstructions),
1178
}
1279

1380
pub fn inline_global_varaibles(sess: &Session, module: &mut Module) -> super::Result<()> {
@@ -36,14 +103,30 @@ fn inline_global_varaibles_rec(sess: &Session, module: &mut Module) -> super::Re
36103
}
37104
// then we keep track of which function parameter are always called with the same expression that only uses global variables
38105
let mut function_args: FxHashMap<(u32, u32), FunctionArg> = FxHashMap::default();
106+
let mut bound = module.header.as_ref().unwrap().bound;
39107
for caller in &module.functions {
40108
let mut insts: FxHashMap<u32, Instruction> = FxHashMap::default();
109+
// for variables that only stored once and it's stored as a ref
110+
let mut ref_stores: FxHashMap<u32, Option<u32>> = FxHashMap::default();
41111
for block in &caller.blocks {
42112
for inst in &block.instructions {
43113
if inst.result_id.is_some() {
44114
insts.insert(inst.result_id.unwrap(), inst.clone());
45115
}
46-
if inst.class.opcode == Op::FunctionCall {
116+
if inst.class.opcode == Op::Store {
117+
if let Operand::IdRef(to) = inst.operands[0] {
118+
if let Operand::IdRef(from) = inst.operands[1] {
119+
match ref_stores.get(&to) {
120+
None => {
121+
ref_stores.insert(to, Some(from));
122+
}
123+
Some(_) => {
124+
ref_stores.insert(to, None);
125+
}
126+
}
127+
}
128+
}
129+
} else if inst.class.opcode == Op::FunctionCall {
47130
let function_id = match &inst.operands[0] {
48131
&Operand::IdRef(w) => w,
49132
_ => panic!(),
@@ -52,16 +135,19 @@ fn inline_global_varaibles_rec(sess: &Session, module: &mut Module) -> super::Re
52135
let key = (function_id, i as u32 - 1);
53136
match &inst.operands[i] {
54137
&Operand::IdRef(w) => match &function_args.get(&key) {
55-
None => match get_const_arg_insts(&variables, &insts, w) {
56-
Some(insts) => {
57-
function_args.insert(key, FunctionArg::Insts(insts));
58-
}
59-
None => {
60-
function_args.insert(key, FunctionArg::Invalid);
138+
None => {
139+
match get_const_arg_insts(bound, &variables, &insts, &ref_stores, w) {
140+
Some(insts) => {
141+
function_args.insert(key, FunctionArg::Insts(insts));
142+
}
143+
None => {
144+
function_args.insert(key, FunctionArg::Invalid);
145+
}
61146
}
62-
},
147+
}
63148
Some(FunctionArg::Insts(w2)) => {
64-
let new_insts = get_const_arg_insts(&variables, &insts, w);
149+
let new_insts =
150+
get_const_arg_insts(bound, &variables, &insts, &ref_stores, w);
65151
match new_insts {
66152
Some(new_insts) => {
67153
if new_insts != *w2 {
@@ -93,11 +179,10 @@ fn inline_global_varaibles_rec(sess: &Session, module: &mut Module) -> super::Re
93179
if function_args.is_empty() {
94180
return Ok(false);
95181
}
96-
let mut bound = module.header.as_ref().unwrap().bound;
97182
for function in &mut module.functions {
98183
let def = function.def.as_mut().unwrap();
99184
let fid = def.result_id.unwrap();
100-
let mut insts: Vec<Instruction> = Vec::new();
185+
let mut insts = NormalizedInstructions::new(0);
101186
let mut j: u32 = 0;
102187
let mut i = 0;
103188
let mut removed_indexes: Vec<u32> = Vec::new();
@@ -108,10 +193,7 @@ fn inline_global_varaibles_rec(sess: &Session, module: &mut Module) -> super::Re
108193
Some(FunctionArg::Insts(arg)) => {
109194
let parameter = function.parameters.remove(i);
110195
let mut arg = arg.clone();
111-
arg.reverse();
112-
insts_replacing_captured_ids(&mut arg, &mut bound);
113-
let index = arg.len() - 1;
114-
arg[index].result_id = parameter.result_id;
196+
arg.fix_ids(&mut bound, parameter.result_id.unwrap());
115197
insts.extend(arg);
116198
removed_indexes.push(j);
117199
removed = true;
@@ -132,15 +214,16 @@ fn inline_global_varaibles_rec(sess: &Session, module: &mut Module) -> super::Re
132214
for i in removed_indexes.iter().rev() {
133215
let i = *i as usize + 1;
134216
function_type.operands.remove(i);
135-
function_type.result_id = Some(tid);
136217
}
218+
function_type.result_id = Some(tid);
137219
def.operands[1] = Operand::IdRef(tid);
138220
module.types_global_values.push(function_type);
139221
}
140222
}
141223
// callee side. insert initialization instructions, which reuse the ids of the removed parameters
142224
if !function.blocks.is_empty() {
143225
let first_block = &mut function.blocks[0];
226+
first_block.instructions.splice(0..0, insts.vars);
144227
// skip some instructions that must be at top of block
145228
let mut i = 0;
146229
loop {
@@ -154,7 +237,7 @@ fn inline_global_varaibles_rec(sess: &Session, module: &mut Module) -> super::Re
154237
}
155238
i += 1;
156239
}
157-
first_block.instructions.splice(i..i, insts);
240+
first_block.instructions.splice(i..i, insts.insts);
158241
}
159242
// caller side, remove parameters from function call
160243
for block in &mut function.blocks {
@@ -181,76 +264,90 @@ fn inline_global_varaibles_rec(sess: &Session, module: &mut Module) -> super::Re
181264
Ok(true)
182265
}
183266

184-
fn insts_replacing_captured_ids(arg: &mut Vec<Instruction>, bound: &mut u32) {
185-
let mut id_map: FxHashMap<u32, u32> = FxHashMap::default();
186-
for ins in arg {
187-
if let Some(id) = &mut ins.result_id {
188-
for op in &mut ins.operands {
189-
match op {
190-
Operand::IdRef(id) => match id_map.get(id) {
191-
Some(new_id) => {
192-
*id = *new_id;
193-
}
194-
_ => {}
195-
},
196-
_ => {}
197-
}
198-
}
199-
id_map.insert(*id, *bound);
200-
*id = *bound;
201-
*bound += 1;
202-
}
203-
}
204-
}
205-
206267
fn get_const_arg_operands(
207268
variables: &FxHashSet<u32>,
208269
insts: &FxHashMap<u32, Instruction>,
270+
ref_stores: &FxHashMap<u32, Option<u32>>,
209271
operand: &Operand,
210-
) -> Option<Vec<Instruction>> {
272+
) -> Option<NormalizedInstructions> {
211273
match operand {
212274
Operand::IdRef(id) => {
213-
let insts = get_const_arg_insts(variables, insts, *id)?;
275+
let insts = get_const_arg_insts_rec(variables, insts, ref_stores, *id)?;
214276
return Some(insts);
215277
}
216-
Operand::LiteralInt32(_) => {},
217-
Operand::LiteralInt64(_) => {},
218-
Operand::LiteralFloat32(_) => {},
219-
Operand::LiteralFloat64(_) => {},
220-
Operand::LiteralExtInstInteger(_) => {},
221-
Operand::LiteralSpecConstantOpInteger(_) => {},
222-
Operand::LiteralString(_) => {},
278+
Operand::LiteralInt32(_) => {}
279+
Operand::LiteralInt64(_) => {}
280+
Operand::LiteralFloat32(_) => {}
281+
Operand::LiteralFloat64(_) => {}
282+
Operand::LiteralExtInstInteger(_) => {}
283+
Operand::LiteralSpecConstantOpInteger(_) => {}
284+
Operand::LiteralString(_) => {}
223285
_ => {
224286
// TOOD add more cases
225287
return None;
226288
}
227289
}
228-
return Some(Vec::new());
290+
return Some(NormalizedInstructions::new(0));
229291
}
230292

231293
fn get_const_arg_insts(
294+
mut bound: u32,
295+
variables: &FxHashSet<u32>,
296+
insts: &FxHashMap<u32, Instruction>,
297+
ref_stores: &FxHashMap<u32, Option<u32>>,
298+
id: u32,
299+
) -> Option<NormalizedInstructions> {
300+
let mut res = get_const_arg_insts_rec(variables, insts, ref_stores, id)?;
301+
res.insts.reverse();
302+
// the bound passed in is always the same
303+
// we need to normalize the ids, so they are the same when compared
304+
let fake_root = bound;
305+
bound += 1;
306+
res.fix_ids(&mut bound, fake_root);
307+
res.root = fake_root;
308+
Some(res)
309+
}
310+
311+
fn get_const_arg_insts_rec(
232312
variables: &FxHashSet<u32>,
233313
insts: &FxHashMap<u32, Instruction>,
314+
ref_stores: &FxHashMap<u32, Option<u32>>,
234315
id: u32,
235-
) -> Option<Vec<Instruction>> {
236-
let mut result: Vec<Instruction> = Vec::new();
316+
) -> Option<NormalizedInstructions> {
317+
let mut result = NormalizedInstructions::new(id);
237318
if variables.contains(&id) {
238319
return Some(result);
239320
}
240321
let par: &Instruction = insts.get(&id)?;
241322
if par.class.opcode == Op::AccessChain {
242-
result.push(par.clone());
323+
result.insts.push(par.clone());
243324
for oprand in &par.operands {
244-
let insts = get_const_arg_operands(variables, insts, oprand)?;
325+
let insts = get_const_arg_operands(variables, insts, ref_stores, oprand)?;
245326
result.extend(insts);
246327
}
247328
} else if par.class.opcode == Op::FunctionCall {
248-
result.push(par.clone());
329+
result.insts.push(par.clone());
249330
// skip first, first is function id
250331
for oprand in &par.operands[1..] {
251-
let insts = get_const_arg_operands(variables, insts, oprand)?;
332+
let insts = get_const_arg_operands(variables, insts, ref_stores, oprand)?;
252333
result.extend(insts);
253334
}
335+
} else if par.class.opcode == Op::Variable {
336+
result.vars.push(par.clone());
337+
let stored = ref_stores.get(&id)?;
338+
let stored = (*stored)?;
339+
result.insts.push(Instruction::new(
340+
Op::Store,
341+
None,
342+
None,
343+
vec![Operand::IdRef(id), Operand::IdRef(stored)],
344+
));
345+
let new_insts = get_const_arg_insts_rec(variables, insts, ref_stores, stored)?;
346+
result.extend(new_insts);
347+
} else if par.class.opcode == Op::ArrayLength {
348+
result.insts.push(par.clone());
349+
let insts = get_const_arg_operands(variables, insts, ref_stores, &par.operands[0])?;
350+
result.extend(insts);
254351
} else {
255352
// TOOD add more cases
256353
return None;

crates/rustc_codegen_spirv/src/linker/mod.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -220,19 +220,20 @@ pub fn link(sess: &Session, mut inputs: Vec<Module>, opts: &Options) -> Result<L
220220
}
221221

222222
// this is needed so we can inline more global variables
223-
{
224-
let _timer = sess.timer("simpl_op_store_var");
225-
simpl_op_store_var::simpl_op_store_var(sess, &mut output)?;
226-
}
223+
// {
224+
// let _timer = sess.timer("simpl_op_store_var");
225+
// simpl_op_store_var::simpl_op_store_var(sess, &mut output)?;
226+
// }
227227

228228
{
229229
let _timer = sess.timer("link_inline_global");
230230
inline_globals::inline_global_varaibles(sess, &mut output)?;
231231
}
232232

233+
// needed because inline global create duplicate types...
233234
{
234-
let _timer = sess.timer("link_inline_global");
235-
inline_globals::inline_global_varaibles(sess, &mut output)?;
235+
let _timer = sess.timer("link_remove_duplicate_types_round_2");
236+
duplicates::remove_duplicate_types(&mut output);
236237
}
237238

238239
{
@@ -320,8 +321,8 @@ pub fn link(sess: &Session, mut inputs: Vec<Module>, opts: &Options) -> Result<L
320321
let _timer = sess.timer("link_sort_globals");
321322
simple_passes::sort_globals(&mut output);
322323
}
323-
324-
std::fs::write("res.txt", output.disassemble());
324+
325+
// std::fs::write("res.txt", output.disassemble());
325326

326327
let mut output = if opts.emit_multiple_modules {
327328
let modules = output

0 commit comments

Comments
 (0)