Skip to content

Commit 910a149

Browse files
committed
linker/inline: use variables instead of OpPhis to unify branches.
This partially reverts commit 990425b.
1 parent b3afbb9 commit 910a149

File tree

1 file changed

+95
-48
lines changed
  • crates/rustc_codegen_spirv/src/linker

1 file changed

+95
-48
lines changed

crates/rustc_codegen_spirv/src/linker/inline.rs

Lines changed: 95 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,25 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
3737
.find(|inst| inst.class.opcode == Op::TypeVoid)
3838
.map(|inst| inst.result_id.unwrap())
3939
.unwrap_or(0);
40-
40+
let ptr_map: FxHashMap<_, _> = module
41+
.types_global_values
42+
.iter()
43+
.filter_map(|inst| {
44+
if inst.class.opcode == Op::TypePointer
45+
&& inst.operands[0].unwrap_storage_class() == StorageClass::Function
46+
{
47+
Some((inst.operands[1].unwrap_id_ref(), inst.result_id.unwrap()))
48+
} else {
49+
None
50+
}
51+
})
52+
.collect();
4153
let invalid_args = module.functions.iter().flat_map(get_invalid_args).collect();
42-
43-
// Drop all the functions we'll be inlining. (This also means we won't waste time processing
44-
// inlines in functions that will get inlined)
4554
let mut inliner = Inliner {
4655
header: module.header.as_mut().unwrap(),
56+
types_global_values: &mut module.types_global_values,
4757
void,
58+
ptr_map,
4859
functions: &functions,
4960
needs_inline: &to_delete,
5061
invalid_args,
@@ -272,7 +283,9 @@ fn args_invalid(invalid_args: &FxHashSet<Word>, call: &Instruction) -> bool {
272283

273284
struct Inliner<'m, 'map> {
274285
header: &'m mut ModuleHeader,
286+
types_global_values: &'m mut Vec<Instruction>,
275287
void: Word,
288+
ptr_map: FxHashMap<Word, Word>,
276289
functions: &'map FunctionMap,
277290
needs_inline: &'map [bool],
278291
invalid_args: FxHashSet<Word>,
@@ -285,6 +298,25 @@ impl Inliner<'_, '_> {
285298
result
286299
}
287300

301+
fn ptr_ty(&mut self, pointee: Word) -> Word {
302+
let existing = self.ptr_map.get(&pointee);
303+
if let Some(existing) = existing {
304+
return *existing;
305+
}
306+
let inst_id = self.id();
307+
self.types_global_values.push(Instruction::new(
308+
Op::TypePointer,
309+
None,
310+
Some(inst_id),
311+
vec![
312+
Operand::StorageClass(StorageClass::Function),
313+
Operand::IdRef(pointee),
314+
],
315+
));
316+
self.ptr_map.insert(pointee, inst_id);
317+
inst_id
318+
}
319+
288320
fn inline_fn(&mut self, functions: &mut [Function], index: usize) {
289321
let mut function = take(&mut functions[index]);
290322
let mut block_idx = 0;
@@ -348,23 +380,27 @@ impl Inliner<'_, '_> {
348380
});
349381
let mut rewrite_rules = callee_parameters.zip(call_arguments).collect();
350382

383+
let return_variable = if call_result_type.is_some() {
384+
Some(self.id())
385+
} else {
386+
None
387+
};
351388
let return_jump = self.id();
352389
// Rewrite OpReturns of the callee.
353-
let (mut inlined_blocks, phi_pairs) = get_inlined_blocks(callee, return_jump);
390+
let (mut inlined_blocks, return_values) =
391+
get_inlined_blocks(callee, return_variable, return_jump);
354392
// Clone the IDs of the callee, because otherwise they'd be defined multiple times if the
355393
// fn is inlined multiple times.
356394
self.add_clone_id_rules(&mut rewrite_rules, &inlined_blocks);
357395
// If any of the OpReturns were invalid, return will also be invalid.
358-
for (value, _) in &phi_pairs {
396+
for value in &return_values {
359397
if self.invalid_args.contains(value) {
360398
self.invalid_args.insert(call_result_id);
361399
self.invalid_args
362400
.insert(*rewrite_rules.get(value).unwrap_or(value));
363401
}
364402
}
365403
apply_rewrite_rules(&rewrite_rules, &mut inlined_blocks);
366-
// unnecessary: invalidate_more_args(&rewrite_rules, &mut self.invalid_args);
367-
// as no values from inside the inlined function ever make it directly out.
368404

369405
// Split the block containing the OpFunctionCall into two, around the call.
370406
let mut post_call_block_insts = caller.blocks[block_idx]
@@ -374,27 +410,32 @@ impl Inliner<'_, '_> {
374410
let call = caller.blocks[block_idx].instructions.pop().unwrap();
375411
assert!(call.class.opcode == Op::FunctionCall);
376412

413+
if let Some(call_result_type) = call_result_type {
414+
// Generate the storage space for the return value: Do this *after* the split above,
415+
// because if block_idx=0, inserting a variable here shifts call_index.
416+
insert_opvariable(
417+
&mut caller.blocks[0],
418+
self.ptr_ty(call_result_type),
419+
return_variable.unwrap(),
420+
);
421+
}
422+
377423
// Move the variables over from the inlined function to here.
378424
let mut callee_header = take(&mut inlined_blocks[0]).instructions;
379425
// TODO: OpLine handling
380426
let num_variables = callee_header.partition_point(|inst| inst.class.opcode == Op::Variable);
381427
// Rather than fuse blocks, generate a new jump here. Branch fusing will take care of
382428
// it, and we maintain the invariant that current block has finished processing.
383-
let first_block_id = self.id();
429+
let jump_to = self.id();
384430
inlined_blocks[0] = Block {
385-
label: Some(Instruction::new(
386-
Op::Label,
387-
None,
388-
Some(first_block_id),
389-
vec![],
390-
)),
431+
label: Some(Instruction::new(Op::Label, None, Some(jump_to), vec![])),
391432
instructions: callee_header.split_off(num_variables),
392433
};
393434
caller.blocks[block_idx].instructions.push(Instruction::new(
394435
Op::Branch,
395436
None,
396437
None,
397-
vec![Operand::IdRef(first_block_id)],
438+
vec![Operand::IdRef(jump_to)],
398439
));
399440
// Move the OpVariables of the callee to the caller.
400441
insert_opvariables(&mut caller.blocks[0], callee_header);
@@ -405,17 +446,10 @@ impl Inliner<'_, '_> {
405446
post_call_block_insts.insert(
406447
0,
407448
Instruction::new(
408-
Op::Phi,
449+
Op::Load,
409450
Some(call_result_type),
410451
Some(call_result_id),
411-
phi_pairs
412-
.into_iter()
413-
.flat_map(|(value, parent)| {
414-
use std::iter;
415-
iter::once(Operand::IdRef(*rewrite_rules.get(&value).unwrap_or(&value)))
416-
.chain(iter::once(Operand::IdRef(rewrite_rules[&parent])))
417-
})
418-
.collect(),
452+
vec![Operand::IdRef(return_variable.unwrap())],
419453
),
420454
);
421455
}
@@ -448,21 +482,53 @@ impl Inliner<'_, '_> {
448482
}
449483
}
450484

451-
fn get_inlined_blocks(function: &Function, return_jump: Word) -> (Vec<Block>, Vec<(Word, Word)>) {
485+
fn get_inlined_blocks(
486+
function: &Function,
487+
return_variable: Option<Word>,
488+
return_jump: Word,
489+
) -> (Vec<Block>, Vec<Word>) {
452490
let mut blocks = function.blocks.clone();
453-
let mut phipairs = Vec::new();
491+
let mut values = Vec::new();
454492
for block in &mut blocks {
455493
let last = block.instructions.last().unwrap();
456494
if let Op::Return | Op::ReturnValue = last.class.opcode {
457495
if Op::ReturnValue == last.class.opcode {
458496
let return_value = last.operands[0].id_ref_any().unwrap();
459-
phipairs.push((return_value, block.label_id().unwrap()));
497+
values.push(return_value);
498+
block.instructions.insert(
499+
block.instructions.len() - 1,
500+
Instruction::new(
501+
Op::Store,
502+
None,
503+
None,
504+
vec![
505+
Operand::IdRef(return_variable.unwrap()),
506+
Operand::IdRef(return_value),
507+
],
508+
),
509+
);
510+
} else {
511+
assert!(return_variable.is_none());
460512
}
461513
*block.instructions.last_mut().unwrap() =
462514
Instruction::new(Op::Branch, None, None, vec![Operand::IdRef(return_jump)]);
463515
}
464516
}
465-
(blocks, phipairs)
517+
(blocks, values)
518+
}
519+
520+
fn insert_opvariable(block: &mut Block, ptr_ty: Word, result_id: Word) {
521+
let index = block
522+
.instructions
523+
.partition_point(|inst| inst.class.opcode == Op::Variable);
524+
525+
let inst = Instruction::new(
526+
Op::Variable,
527+
Some(ptr_ty),
528+
Some(result_id),
529+
vec![Operand::StorageClass(StorageClass::Function)],
530+
);
531+
block.instructions.insert(index, inst)
466532
}
467533

468534
fn insert_opvariables(block: &mut Block, insts: Vec<Instruction>) {
@@ -474,7 +540,6 @@ fn insert_opvariables(block: &mut Block, insts: Vec<Instruction>) {
474540

475541
fn fuse_trivial_branches(function: &mut Function) {
476542
let mut chain_list = compute_outgoing_1to1_branches(&function.blocks);
477-
let mut rewrite_rules = FxHashMap::default();
478543

479544
for block_idx in 0..chain_list.len() {
480545
let mut next = chain_list[block_idx].take();
@@ -490,16 +555,6 @@ fn fuse_trivial_branches(function: &mut Function) {
490555
}
491556
Some(next_idx) => {
492557
let mut dest_insts = take(&mut function.blocks[next_idx].instructions);
493-
dest_insts.retain(|inst| {
494-
if inst.class.opcode == Op::Phi {
495-
assert_eq!(inst.operands.len(), 2);
496-
rewrite_rules
497-
.insert(inst.result_id.unwrap(), inst.operands[0].unwrap_id_ref());
498-
false
499-
} else {
500-
true
501-
}
502-
});
503558
let self_insts = &mut function.blocks[block_idx].instructions;
504559
self_insts.pop(); // pop the branch
505560
self_insts.append(&mut dest_insts);
@@ -509,14 +564,6 @@ fn fuse_trivial_branches(function: &mut Function) {
509564
}
510565
}
511566
function.blocks.retain(|b| !b.instructions.is_empty());
512-
// Calculate a closure, as these rules can be transitive
513-
let mut rewrite_rules_new = rewrite_rules.clone();
514-
for value in rewrite_rules_new.values_mut() {
515-
while let Some(next) = rewrite_rules.get(value) {
516-
*value = *next;
517-
}
518-
}
519-
apply_rewrite_rules(&rewrite_rules_new, &mut function.blocks);
520567
}
521568

522569
fn compute_outgoing_1to1_branches(blocks: &[Block]) -> Vec<Option<usize>> {

0 commit comments

Comments
 (0)