Skip to content

Commit 41a6089

Browse files
committed
linker/inline: pre-cache calculation of invalid args
Since the only value that escapes during inlining is the return value, we can also calculate and update whether that return value is itself invalid to pass as a call argument. (Note that this is, strictly speaking, more rigor than get_invalid_values() applies, because that function doesn't look behind OpPhis.) As a nice bonus, we got rid of OpLoad/OpStore in favor of OpPhi, which means no type mucking and no work created for mem2reg.
1 parent 99f25a6 commit 41a6089

File tree

1 file changed

+63
-105
lines changed
  • crates/rustc_codegen_spirv/src/linker

1 file changed

+63
-105
lines changed

crates/rustc_codegen_spirv/src/linker/inline.rs

Lines changed: 63 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -37,28 +37,17 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
3737
.iter()
3838
.find(|inst| inst.class.opcode == Op::TypeVoid)
3939
.map_or(0, |inst| inst.result_id.unwrap());
40-
let ptr_map: FxHashMap<_, _> = module
41-
.types_global_values
42-
.iter()
43-
.filter_map(|inst| {
44-
if inst.class.opcode == Op::TypePointer
45-
&& inst.operands[0].unwrap_storage_class() == StorageClass::Function
46-
{
47-
Some((inst.operands[1].unwrap_id_ref(), inst.result_id.unwrap()))
48-
} else {
49-
None
50-
}
51-
})
52-
.collect();
40+
41+
let invalid_args = module.functions.iter().flat_map(get_invalid_args).collect();
42+
5343
// Drop all the functions we'll be inlining. (This also means we won't waste time processing
5444
// inlines in functions that will get inlined)
5545
let mut inliner = Inliner {
5646
header: module.header.as_mut().unwrap(),
57-
types_global_values: &mut module.types_global_values,
5847
void,
59-
ptr_map,
6048
functions: &functions,
6149
needs_inline: &to_delete,
50+
invalid_args,
6251
};
6352
for index in postorder {
6453
inliner.inline_fn(&mut module.functions, index);
@@ -268,20 +257,21 @@ fn should_inline(
268257
// This should be more general, but a very common problem is passing an OpAccessChain to an
269258
// OpFunctionCall (i.e. `f(&s.x)`, or more commonly, `s.x.f()` where `f` takes `&self`), so detect
270259
// that case and inline the call.
271-
fn args_invalid(function: &Function, call: &Instruction) -> bool {
272-
for inst in function.all_inst_iter() {
260+
fn get_invalid_args<'a>(function: &'a Function) -> impl Iterator<Item = Word> + 'a {
261+
function.all_inst_iter().filter_map(|inst| {
273262
if inst.class.opcode == Op::AccessChain {
274-
let inst_result = inst.result_id.unwrap();
275-
if call
276-
.operands
277-
.iter()
278-
.any(|op| *op == Operand::IdRef(inst_result))
279-
{
280-
return true;
281-
}
263+
inst.result_id
264+
} else {
265+
None
282266
}
283-
}
284-
false
267+
})
268+
}
269+
270+
fn args_invalid(invalid_args: &FxHashSet<Word>, call: &Instruction) -> bool {
271+
call.operands.iter().skip(1).any(|op| {
272+
op.id_ref_any()
273+
.map_or(false, |arg| invalid_args.contains(&arg))
274+
})
285275
}
286276

287277
// Steps:
@@ -292,11 +282,10 @@ fn args_invalid(function: &Function, call: &Instruction) -> bool {
292282

293283
struct Inliner<'m, 'map> {
294284
header: &'m mut ModuleHeader,
295-
types_global_values: &'m mut Vec<Instruction>,
296285
void: Word,
297-
ptr_map: FxHashMap<Word, Word>,
298286
functions: &'map FunctionMap,
299287
needs_inline: &'map [bool],
288+
invalid_args: FxHashSet<Word>,
300289
}
301290

302291
impl Inliner<'_, '_> {
@@ -306,25 +295,6 @@ impl Inliner<'_, '_> {
306295
result
307296
}
308297

309-
fn ptr_ty(&mut self, pointee: Word) -> Word {
310-
let existing = self.ptr_map.get(&pointee);
311-
if let Some(existing) = existing {
312-
return *existing;
313-
}
314-
let inst_id = self.id();
315-
self.types_global_values.push(Instruction::new(
316-
Op::TypePointer,
317-
None,
318-
Some(inst_id),
319-
vec![
320-
Operand::StorageClass(StorageClass::Function),
321-
Operand::IdRef(pointee),
322-
],
323-
));
324-
self.ptr_map.insert(pointee, inst_id);
325-
inst_id
326-
}
327-
328298
fn inline_fn(&mut self, functions: &mut [Function], index: usize) {
329299
let mut function = take(&mut functions[index]);
330300
let mut block_idx = 0;
@@ -359,8 +329,8 @@ impl Inliner<'_, '_> {
359329
self.functions[&inst.operands[0].id_ref_any().unwrap()],
360330
)
361331
})
362-
.find(|(index, inst, func_idx)| {
363-
self.needs_inline[*func_idx] || args_invalid(caller, inst)
332+
.find(|(_, inst, func_idx)| {
333+
self.needs_inline[*func_idx] || args_invalid(&self.invalid_args, inst)
364334
});
365335
let (call_index, call_inst, callee_idx) = match call {
366336
None => return false,
@@ -388,18 +358,23 @@ impl Inliner<'_, '_> {
388358
});
389359
let mut rewrite_rules = callee_parameters.zip(call_arguments).collect();
390360

391-
let return_variable = if call_result_type.is_some() {
392-
Some(self.id())
393-
} else {
394-
None
395-
};
396361
let return_jump = self.id();
397362
// Rewrite OpReturns of the callee.
398-
let mut inlined_blocks = get_inlined_blocks(callee, return_variable, return_jump);
363+
let (mut inlined_blocks, phi_pairs) = get_inlined_blocks(callee, return_jump);
399364
// Clone the IDs of the callee, because otherwise they'd be defined multiple times if the
400365
// fn is inlined multiple times.
401366
self.add_clone_id_rules(&mut rewrite_rules, &inlined_blocks);
367+
// If any of the OpReturns were invalid, return will also be invalid.
368+
for (value, _) in &phi_pairs {
369+
if self.invalid_args.contains(value) {
370+
self.invalid_args.insert(call_result_id);
371+
self.invalid_args
372+
.insert(*rewrite_rules.get(value).unwrap_or(value));
373+
}
374+
}
402375
apply_rewrite_rules(&rewrite_rules, &mut inlined_blocks);
376+
// unnecessary: invalidate_more_args(&rewrite_rules, &mut self.invalid_args);
377+
// as no values from inside the inlined function ever make it directly out.
403378

404379
// Split the block containing the OpFunctionCall into two, around the call.
405380
let mut post_call_block_insts = caller.blocks[block_idx]
@@ -409,32 +384,27 @@ impl Inliner<'_, '_> {
409384
let call = caller.blocks[block_idx].instructions.pop().unwrap();
410385
assert!(call.class.opcode == Op::FunctionCall);
411386

412-
if let Some(call_result_type) = call_result_type {
413-
// Generate the storage space for the return value: Do this *after* the split above,
414-
// because if block_idx=0, inserting a variable here shifts call_index.
415-
insert_opvariable(
416-
&mut caller.blocks[0],
417-
self.ptr_ty(call_result_type),
418-
return_variable.unwrap(),
419-
);
420-
}
421-
422387
// Move the variables over from the inlined function to here.
423388
let mut callee_header = take(&mut inlined_blocks[0]).instructions;
424389
// TODO: OpLine handling
425390
let num_variables = callee_header.partition_point(|inst| inst.class.opcode == Op::Variable);
426391
// Rather than fuse blocks, generate a new jump here. Branch fusing will take care of
427392
// it, and we maintain the invariant that current block has finished processing.
428-
let jump_to = self.id();
393+
let first_block_id = self.id();
429394
inlined_blocks[0] = Block {
430-
label: Some(Instruction::new(Op::Label, None, Some(jump_to), vec![])),
395+
label: Some(Instruction::new(
396+
Op::Label,
397+
None,
398+
Some(first_block_id),
399+
vec![],
400+
)),
431401
instructions: callee_header.split_off(num_variables),
432402
};
433403
caller.blocks[block_idx].instructions.push(Instruction::new(
434404
Op::Branch,
435405
None,
436406
None,
437-
vec![Operand::IdRef(jump_to)],
407+
vec![Operand::IdRef(first_block_id)],
438408
));
439409
// Move the OpVariables of the callee to the caller.
440410
insert_opvariables(&mut caller.blocks[0], callee_header);
@@ -445,10 +415,17 @@ impl Inliner<'_, '_> {
445415
post_call_block_insts.insert(
446416
0,
447417
Instruction::new(
448-
Op::Load,
418+
Op::Phi,
449419
Some(call_result_type),
450420
Some(call_result_id),
451-
vec![Operand::IdRef(return_variable.unwrap())],
421+
phi_pairs
422+
.into_iter()
423+
.flat_map(|(value, parent)| {
424+
use std::iter;
425+
iter::once(Operand::IdRef(*rewrite_rules.get(&value).unwrap_or(&value)))
426+
.chain(iter::once(Operand::IdRef(rewrite_rules[&parent])))
427+
})
428+
.collect(),
452429
),
453430
);
454431
}
@@ -481,51 +458,21 @@ impl Inliner<'_, '_> {
481458
}
482459
}
483460

484-
fn get_inlined_blocks(
485-
function: &Function,
486-
return_variable: Option<Word>,
487-
return_jump: Word,
488-
) -> Vec<Block> {
461+
fn get_inlined_blocks(function: &Function, return_jump: Word) -> (Vec<Block>, Vec<(Word, Word)>) {
489462
let mut blocks = function.blocks.clone();
463+
let mut phipairs = Vec::new();
490464
for block in &mut blocks {
491465
let last = block.instructions.last().unwrap();
492466
if let Op::Return | Op::ReturnValue = last.class.opcode {
493467
if Op::ReturnValue == last.class.opcode {
494468
let return_value = last.operands[0].id_ref_any().unwrap();
495-
block.instructions.insert(
496-
block.instructions.len() - 1,
497-
Instruction::new(
498-
Op::Store,
499-
None,
500-
None,
501-
vec![
502-
Operand::IdRef(return_variable.unwrap()),
503-
Operand::IdRef(return_value),
504-
],
505-
),
506-
);
507-
} else {
508-
assert!(return_variable.is_none());
469+
phipairs.push((return_value, block.label_id().unwrap()))
509470
}
510471
*block.instructions.last_mut().unwrap() =
511472
Instruction::new(Op::Branch, None, None, vec![Operand::IdRef(return_jump)]);
512473
}
513474
}
514-
blocks
515-
}
516-
517-
fn insert_opvariable(block: &mut Block, ptr_ty: Word, result_id: Word) {
518-
let index = block
519-
.instructions
520-
.partition_point(|inst| inst.class.opcode == Op::Variable);
521-
522-
let inst = Instruction::new(
523-
Op::Variable,
524-
Some(ptr_ty),
525-
Some(result_id),
526-
vec![Operand::StorageClass(StorageClass::Function)],
527-
);
528-
block.instructions.insert(index, inst)
475+
(blocks, phipairs)
529476
}
530477

531478
fn insert_opvariables(block: &mut Block, insts: Vec<Instruction>) {
@@ -537,6 +484,7 @@ fn insert_opvariables(block: &mut Block, insts: Vec<Instruction>) {
537484

538485
fn fuse_trivial_branches(function: &mut Function) {
539486
let all_preds = compute_preds(&function.blocks);
487+
let mut rewrite_rules = FxHashMap::default();
540488
'outer: for (dest_block, mut preds) in all_preds.iter().enumerate() {
541489
// if there's two trivial branches in a row, the middle one might get inlined before the
542490
// last one, so when processing the last one, skip through to the first one.
@@ -553,12 +501,22 @@ fn fuse_trivial_branches(function: &mut Function) {
553501
let pred_insts = &function.blocks[pred].instructions;
554502
if pred_insts.last().unwrap().class.opcode == Op::Branch {
555503
let mut dest_insts = take(&mut function.blocks[dest_block].instructions);
504+
dest_insts.retain(|inst| {
505+
if inst.class.opcode == Op::Phi {
506+
assert_eq!(inst.operands.len(), 2);
507+
rewrite_rules.insert(inst.result_id.unwrap(), inst.operands[0].unwrap_id_ref());
508+
false
509+
} else {
510+
true
511+
}
512+
});
556513
let pred_insts = &mut function.blocks[pred].instructions;
557514
pred_insts.pop(); // pop the branch
558515
pred_insts.append(&mut dest_insts);
559516
}
560517
}
561518
function.blocks.retain(|b| !b.instructions.is_empty());
519+
apply_rewrite_rules(&rewrite_rules, &mut function.blocks);
562520
}
563521

564522
fn compute_preds(blocks: &[Block]) -> Vec<Vec<usize>> {

0 commit comments

Comments
 (0)