
Commit 990425b

linker/inline: pre-cache calculation of invalid args
Since during inlining the only escaping value is the return value, we can calculate and update whether it, too, carries an invalid-to-call-with value. (Note that this is, strictly speaking, more rigor than get_invalid_args() applies, because that function doesn't look behind OpPhis.) As a nice bonus, we got rid of the OpLoad/OpStore pair in favor of an OpPhi, which means no pointer-type mucking and no work created for mem2reg.
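To make the OpPhi point concrete, here is a minimal sketch of the instruction the new code builds in place of the old OpVariable/OpStore/OpLoad plumbing. It is not code from this commit: the helper name build_return_phi is made up for illustration, but it uses the same rspirv dr::Instruction API the diff below relies on. Each inlined OpReturnValue contributes a (value, parent-block-label) pair, and the OpPhi at the merge block simply lists those pairs.

// Illustrative only: build the OpPhi that selects the callee's return value,
// given one (value, parent block label) pair per inlined OpReturnValue.
use rspirv::dr::{Instruction, Operand};
use rspirv::spirv::{Op, Word};

fn build_return_phi(
    result_type: Word,
    result_id: Word,
    phi_pairs: &[(Word, Word)], // (returned value id, label id of the returning block)
) -> Instruction {
    let operands = phi_pairs
        .iter()
        .flat_map(|&(value, parent)| vec![Operand::IdRef(value), Operand::IdRef(parent)])
        .collect();
    // OpPhi takes alternating (value, parent) id operands; no Function-storage
    // OpVariable or pointer type is needed, so mem2reg has nothing left to clean up.
    Instruction::new(Op::Phi, Some(result_type), Some(result_id), operands)
}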
1 parent 58deeab commit 990425b

File tree

1 file changed (+63 −105):
  • crates/rustc_codegen_spirv/src/linker/inline.rs

crates/rustc_codegen_spirv/src/linker/inline.rs

Lines changed: 63 additions & 105 deletions
@@ -37,28 +37,17 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
         .find(|inst| inst.class.opcode == Op::TypeVoid)
         .map(|inst| inst.result_id.unwrap())
         .unwrap_or(0);
-    let ptr_map: FxHashMap<_, _> = module
-        .types_global_values
-        .iter()
-        .filter_map(|inst| {
-            if inst.class.opcode == Op::TypePointer
-                && inst.operands[0].unwrap_storage_class() == StorageClass::Function
-            {
-                Some((inst.operands[1].unwrap_id_ref(), inst.result_id.unwrap()))
-            } else {
-                None
-            }
-        })
-        .collect();
+
+    let invalid_args = module.functions.iter().flat_map(get_invalid_args).collect();
+
     // Drop all the functions we'll be inlining. (This also means we won't waste time processing
     // inlines in functions that will get inlined)
     let mut inliner = Inliner {
         header: module.header.as_mut().unwrap(),
-        types_global_values: &mut module.types_global_values,
         void,
-        ptr_map,
         functions: &functions,
         needs_inline: &to_delete,
+        invalid_args,
     };
     for index in postorder {
         inliner.inline_fn(&mut module.functions, index);
@@ -270,20 +259,21 @@ fn should_inline(
 // This should be more general, but a very common problem is passing an OpAccessChain to an
 // OpFunctionCall (i.e. `f(&s.x)`, or more commonly, `s.x.f()` where `f` takes `&self`), so detect
 // that case and inline the call.
-fn args_invalid(function: &Function, call: &Instruction) -> bool {
-    for inst in function.all_inst_iter() {
+fn get_invalid_args<'a>(function: &'a Function) -> impl Iterator<Item = Word> + 'a {
+    function.all_inst_iter().filter_map(|inst| {
         if inst.class.opcode == Op::AccessChain {
-            let inst_result = inst.result_id.unwrap();
-            if call
-                .operands
-                .iter()
-                .any(|op| *op == Operand::IdRef(inst_result))
-            {
-                return true;
-            }
+            inst.result_id
+        } else {
+            None
         }
-    }
-    false
+    })
+}
+
+fn args_invalid(invalid_args: &FxHashSet<Word>, call: &Instruction) -> bool {
+    call.operands.iter().skip(1).any(|op| {
+        op.id_ref_any()
+            .map_or(false, |arg| invalid_args.contains(&arg))
+    })
 }
 
 // Steps:
@@ -294,11 +284,10 @@ fn args_invalid(function: &Function, call: &Instruction) -> bool {
 
 struct Inliner<'m, 'map> {
     header: &'m mut ModuleHeader,
-    types_global_values: &'m mut Vec<Instruction>,
     void: Word,
-    ptr_map: FxHashMap<Word, Word>,
     functions: &'map FunctionMap,
     needs_inline: &'map [bool],
+    invalid_args: FxHashSet<Word>,
 }
 
 impl Inliner<'_, '_> {
@@ -308,25 +297,6 @@ impl Inliner<'_, '_> {
         result
     }
 
-    fn ptr_ty(&mut self, pointee: Word) -> Word {
-        let existing = self.ptr_map.get(&pointee);
-        if let Some(existing) = existing {
-            return *existing;
-        }
-        let inst_id = self.id();
-        self.types_global_values.push(Instruction::new(
-            Op::TypePointer,
-            None,
-            Some(inst_id),
-            vec![
-                Operand::StorageClass(StorageClass::Function),
-                Operand::IdRef(pointee),
-            ],
-        ));
-        self.ptr_map.insert(pointee, inst_id);
-        inst_id
-    }
-
     fn inline_fn(&mut self, functions: &mut [Function], index: usize) {
         let mut function = take(&mut functions[index]);
         let mut block_idx = 0;
@@ -361,8 +331,8 @@ impl Inliner<'_, '_> {
                     self.functions[&inst.operands[0].id_ref_any().unwrap()],
                 )
             })
-            .find(|(index, inst, func_idx)| {
-                self.needs_inline[*func_idx] || args_invalid(caller, inst)
+            .find(|(_, inst, func_idx)| {
+                self.needs_inline[*func_idx] || args_invalid(&self.invalid_args, inst)
             });
         let (call_index, call_inst, callee_idx) = match call {
             None => return false,
@@ -390,18 +360,23 @@ impl Inliner<'_, '_> {
         });
         let mut rewrite_rules = callee_parameters.zip(call_arguments).collect();
 
-        let return_variable = if call_result_type.is_some() {
-            Some(self.id())
-        } else {
-            None
-        };
         let return_jump = self.id();
         // Rewrite OpReturns of the callee.
-        let mut inlined_blocks = get_inlined_blocks(callee, return_variable, return_jump);
+        let (mut inlined_blocks, phi_pairs) = get_inlined_blocks(callee, return_jump);
         // Clone the IDs of the callee, because otherwise they'd be defined multiple times if the
         // fn is inlined multiple times.
         self.add_clone_id_rules(&mut rewrite_rules, &inlined_blocks);
+        // If any of the OpReturns were invalid, return will also be invalid.
+        for (value, _) in &phi_pairs {
+            if self.invalid_args.contains(value) {
+                self.invalid_args.insert(call_result_id);
+                self.invalid_args
+                    .insert(*rewrite_rules.get(value).unwrap_or(value));
+            }
+        }
         apply_rewrite_rules(&rewrite_rules, &mut inlined_blocks);
+        // unnecessary: invalidate_more_args(&rewrite_rules, &mut self.invalid_args);
+        // as no values from inside the inlined function ever make it directly out.
 
         // Split the block containing the OpFunctionCall into two, around the call.
         let mut post_call_block_insts = caller.blocks[block_idx]
@@ -411,32 +386,27 @@ impl Inliner<'_, '_> {
         let call = caller.blocks[block_idx].instructions.pop().unwrap();
         assert!(call.class.opcode == Op::FunctionCall);
 
-        if let Some(call_result_type) = call_result_type {
-            // Generate the storage space for the return value: Do this *after* the split above,
-            // because if block_idx=0, inserting a variable here shifts call_index.
-            insert_opvariable(
-                &mut caller.blocks[0],
-                self.ptr_ty(call_result_type),
-                return_variable.unwrap(),
-            );
-        }
-
         // Move the variables over from the inlined function to here.
         let mut callee_header = take(&mut inlined_blocks[0]).instructions;
         // TODO: OpLine handling
         let num_variables = callee_header.partition_point(|inst| inst.class.opcode == Op::Variable);
         // Rather than fuse blocks, generate a new jump here. Branch fusing will take care of
         // it, and we maintain the invariant that current block has finished processing.
-        let jump_to = self.id();
+        let first_block_id = self.id();
         inlined_blocks[0] = Block {
-            label: Some(Instruction::new(Op::Label, None, Some(jump_to), vec![])),
+            label: Some(Instruction::new(
+                Op::Label,
+                None,
+                Some(first_block_id),
+                vec![],
+            )),
             instructions: callee_header.split_off(num_variables),
         };
         caller.blocks[block_idx].instructions.push(Instruction::new(
             Op::Branch,
             None,
             None,
-            vec![Operand::IdRef(jump_to)],
+            vec![Operand::IdRef(first_block_id)],
         ));
         // Move the OpVariables of the callee to the caller.
         insert_opvariables(&mut caller.blocks[0], callee_header);
@@ -447,10 +417,17 @@ impl Inliner<'_, '_> {
             post_call_block_insts.insert(
                 0,
                 Instruction::new(
-                    Op::Load,
+                    Op::Phi,
                     Some(call_result_type),
                     Some(call_result_id),
-                    vec![Operand::IdRef(return_variable.unwrap())],
+                    phi_pairs
+                        .into_iter()
+                        .flat_map(|(value, parent)| {
+                            use std::iter;
+                            iter::once(Operand::IdRef(*rewrite_rules.get(&value).unwrap_or(&value)))
+                                .chain(iter::once(Operand::IdRef(rewrite_rules[&parent])))
+                        })
+                        .collect(),
                 ),
             );
         }
@@ -483,51 +460,21 @@ impl Inliner<'_, '_> {
     }
 }
 
-fn get_inlined_blocks(
-    function: &Function,
-    return_variable: Option<Word>,
-    return_jump: Word,
-) -> Vec<Block> {
+fn get_inlined_blocks(function: &Function, return_jump: Word) -> (Vec<Block>, Vec<(Word, Word)>) {
     let mut blocks = function.blocks.clone();
+    let mut phipairs = Vec::new();
     for block in &mut blocks {
         let last = block.instructions.last().unwrap();
         if let Op::Return | Op::ReturnValue = last.class.opcode {
             if Op::ReturnValue == last.class.opcode {
                 let return_value = last.operands[0].id_ref_any().unwrap();
-                block.instructions.insert(
-                    block.instructions.len() - 1,
-                    Instruction::new(
-                        Op::Store,
-                        None,
-                        None,
-                        vec![
-                            Operand::IdRef(return_variable.unwrap()),
-                            Operand::IdRef(return_value),
-                        ],
-                    ),
-                );
-            } else {
-                assert!(return_variable.is_none());
+                phipairs.push((return_value, block.label_id().unwrap()))
             }
             *block.instructions.last_mut().unwrap() =
                 Instruction::new(Op::Branch, None, None, vec![Operand::IdRef(return_jump)]);
         }
     }
-    blocks
-}
-
-fn insert_opvariable(block: &mut Block, ptr_ty: Word, result_id: Word) {
-    let index = block
-        .instructions
-        .partition_point(|inst| inst.class.opcode == Op::Variable);
-
-    let inst = Instruction::new(
-        Op::Variable,
-        Some(ptr_ty),
-        Some(result_id),
-        vec![Operand::StorageClass(StorageClass::Function)],
-    );
-    block.instructions.insert(index, inst)
+    (blocks, phipairs)
 }
 
 fn insert_opvariables(block: &mut Block, insts: Vec<Instruction>) {
@@ -539,6 +486,7 @@ fn insert_opvariables(block: &mut Block, insts: Vec<Instruction>) {
 
 fn fuse_trivial_branches(function: &mut Function) {
     let all_preds = compute_preds(&function.blocks);
+    let mut rewrite_rules = FxHashMap::default();
     'outer: for (dest_block, mut preds) in all_preds.iter().enumerate() {
         // if there's two trivial branches in a row, the middle one might get inlined before the
         // last one, so when processing the last one, skip through to the first one.
@@ -555,12 +503,22 @@ fn fuse_trivial_branches(function: &mut Function) {
         let pred_insts = &function.blocks[pred].instructions;
         if pred_insts.last().unwrap().class.opcode == Op::Branch {
             let mut dest_insts = take(&mut function.blocks[dest_block].instructions);
+            dest_insts.retain(|inst| {
+                if inst.class.opcode == Op::Phi {
+                    assert_eq!(inst.operands.len(), 2);
+                    rewrite_rules.insert(inst.result_id.unwrap(), inst.operands[0].unwrap_id_ref());
+                    false
+                } else {
+                    true
+                }
+            });
             let pred_insts = &mut function.blocks[pred].instructions;
             pred_insts.pop(); // pop the branch
             pred_insts.append(&mut dest_insts);
         }
     }
     function.blocks.retain(|b| !b.instructions.is_empty());
+    apply_rewrite_rules(&rewrite_rules, &mut function.blocks);
 }
 
 fn compute_preds(blocks: &[Block]) -> Vec<Vec<usize>> {
