Skip to content

Commit b473f2a

Browse files
committed
linker/inline: use variables instead of OpPhis to unify branches.
This partially reverts commit 990425b.
1 parent c5d3abe commit b473f2a

File tree

1 file changed

+95
-48
lines changed
  • crates/rustc_codegen_spirv/src/linker

1 file changed

+95
-48
lines changed

crates/rustc_codegen_spirv/src/linker/inline.rs

Lines changed: 95 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,25 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
3737
.iter()
3838
.find(|inst| inst.class.opcode == Op::TypeVoid)
3939
.map_or(0, |inst| inst.result_id.unwrap());
40-
40+
let ptr_map: FxHashMap<_, _> = module
41+
.types_global_values
42+
.iter()
43+
.filter_map(|inst| {
44+
if inst.class.opcode == Op::TypePointer
45+
&& inst.operands[0].unwrap_storage_class() == StorageClass::Function
46+
{
47+
Some((inst.operands[1].unwrap_id_ref(), inst.result_id.unwrap()))
48+
} else {
49+
None
50+
}
51+
})
52+
.collect();
4153
let invalid_args = module.functions.iter().flat_map(get_invalid_args).collect();
42-
43-
// Drop all the functions we'll be inlining. (This also means we won't waste time processing
44-
// inlines in functions that will get inlined)
4554
let mut inliner = Inliner {
4655
header: module.header.as_mut().unwrap(),
56+
types_global_values: &mut module.types_global_values,
4757
void,
58+
ptr_map,
4859
functions: &functions,
4960
needs_inline: &to_delete,
5061
invalid_args,
@@ -270,7 +281,9 @@ fn args_invalid(invalid_args: &FxHashSet<Word>, call: &Instruction) -> bool {
270281

271282
struct Inliner<'m, 'map> {
272283
header: &'m mut ModuleHeader,
284+
types_global_values: &'m mut Vec<Instruction>,
273285
void: Word,
286+
ptr_map: FxHashMap<Word, Word>,
274287
functions: &'map FunctionMap,
275288
needs_inline: &'map [bool],
276289
invalid_args: FxHashSet<Word>,
@@ -283,6 +296,25 @@ impl Inliner<'_, '_> {
283296
result
284297
}
285298

299+
fn ptr_ty(&mut self, pointee: Word) -> Word {
300+
let existing = self.ptr_map.get(&pointee);
301+
if let Some(existing) = existing {
302+
return *existing;
303+
}
304+
let inst_id = self.id();
305+
self.types_global_values.push(Instruction::new(
306+
Op::TypePointer,
307+
None,
308+
Some(inst_id),
309+
vec![
310+
Operand::StorageClass(StorageClass::Function),
311+
Operand::IdRef(pointee),
312+
],
313+
));
314+
self.ptr_map.insert(pointee, inst_id);
315+
inst_id
316+
}
317+
286318
fn inline_fn(&mut self, functions: &mut [Function], index: usize) {
287319
let mut function = take(&mut functions[index]);
288320
let mut block_idx = 0;
@@ -346,23 +378,27 @@ impl Inliner<'_, '_> {
346378
});
347379
let mut rewrite_rules = callee_parameters.zip(call_arguments).collect();
348380

381+
let return_variable = if call_result_type.is_some() {
382+
Some(self.id())
383+
} else {
384+
None
385+
};
349386
let return_jump = self.id();
350387
// Rewrite OpReturns of the callee.
351-
let (mut inlined_blocks, phi_pairs) = get_inlined_blocks(callee, return_jump);
388+
let (mut inlined_blocks, return_values) =
389+
get_inlined_blocks(callee, return_variable, return_jump);
352390
// Clone the IDs of the callee, because otherwise they'd be defined multiple times if the
353391
// fn is inlined multiple times.
354392
self.add_clone_id_rules(&mut rewrite_rules, &inlined_blocks);
355393
// If any of the OpReturns were invalid, return will also be invalid.
356-
for (value, _) in &phi_pairs {
394+
for value in &return_values {
357395
if self.invalid_args.contains(value) {
358396
self.invalid_args.insert(call_result_id);
359397
self.invalid_args
360398
.insert(*rewrite_rules.get(value).unwrap_or(value));
361399
}
362400
}
363401
apply_rewrite_rules(&rewrite_rules, &mut inlined_blocks);
364-
// unnecessary: invalidate_more_args(&rewrite_rules, &mut self.invalid_args);
365-
// as no values from inside the inlined function ever make it directly out.
366402

367403
// Split the block containing the OpFunctionCall into two, around the call.
368404
let mut post_call_block_insts = caller.blocks[block_idx]
@@ -372,27 +408,32 @@ impl Inliner<'_, '_> {
372408
let call = caller.blocks[block_idx].instructions.pop().unwrap();
373409
assert!(call.class.opcode == Op::FunctionCall);
374410

411+
if let Some(call_result_type) = call_result_type {
412+
// Generate the storage space for the return value: Do this *after* the split above,
413+
// because if block_idx=0, inserting a variable here shifts call_index.
414+
insert_opvariable(
415+
&mut caller.blocks[0],
416+
self.ptr_ty(call_result_type),
417+
return_variable.unwrap(),
418+
);
419+
}
420+
375421
// Move the variables over from the inlined function to here.
376422
let mut callee_header = take(&mut inlined_blocks[0]).instructions;
377423
// TODO: OpLine handling
378424
let num_variables = callee_header.partition_point(|inst| inst.class.opcode == Op::Variable);
379425
// Rather than fuse blocks, generate a new jump here. Branch fusing will take care of
380426
// it, and we maintain the invariant that current block has finished processing.
381-
let first_block_id = self.id();
427+
let jump_to = self.id();
382428
inlined_blocks[0] = Block {
383-
label: Some(Instruction::new(
384-
Op::Label,
385-
None,
386-
Some(first_block_id),
387-
vec![],
388-
)),
429+
label: Some(Instruction::new(Op::Label, None, Some(jump_to), vec![])),
389430
instructions: callee_header.split_off(num_variables),
390431
};
391432
caller.blocks[block_idx].instructions.push(Instruction::new(
392433
Op::Branch,
393434
None,
394435
None,
395-
vec![Operand::IdRef(first_block_id)],
436+
vec![Operand::IdRef(jump_to)],
396437
));
397438
// Move the OpVariables of the callee to the caller.
398439
insert_opvariables(&mut caller.blocks[0], callee_header);
@@ -403,17 +444,10 @@ impl Inliner<'_, '_> {
403444
post_call_block_insts.insert(
404445
0,
405446
Instruction::new(
406-
Op::Phi,
447+
Op::Load,
407448
Some(call_result_type),
408449
Some(call_result_id),
409-
phi_pairs
410-
.into_iter()
411-
.flat_map(|(value, parent)| {
412-
use std::iter;
413-
iter::once(Operand::IdRef(*rewrite_rules.get(&value).unwrap_or(&value)))
414-
.chain(iter::once(Operand::IdRef(rewrite_rules[&parent])))
415-
})
416-
.collect(),
450+
vec![Operand::IdRef(return_variable.unwrap())],
417451
),
418452
);
419453
}
@@ -446,21 +480,53 @@ impl Inliner<'_, '_> {
446480
}
447481
}
448482

449-
fn get_inlined_blocks(function: &Function, return_jump: Word) -> (Vec<Block>, Vec<(Word, Word)>) {
483+
fn get_inlined_blocks(
484+
function: &Function,
485+
return_variable: Option<Word>,
486+
return_jump: Word,
487+
) -> (Vec<Block>, Vec<Word>) {
450488
let mut blocks = function.blocks.clone();
451-
let mut phipairs = Vec::new();
489+
let mut values = Vec::new();
452490
for block in &mut blocks {
453491
let last = block.instructions.last().unwrap();
454492
if let Op::Return | Op::ReturnValue = last.class.opcode {
455493
if Op::ReturnValue == last.class.opcode {
456494
let return_value = last.operands[0].id_ref_any().unwrap();
457-
phipairs.push((return_value, block.label_id().unwrap()));
495+
values.push(return_value);
496+
block.instructions.insert(
497+
block.instructions.len() - 1,
498+
Instruction::new(
499+
Op::Store,
500+
None,
501+
None,
502+
vec![
503+
Operand::IdRef(return_variable.unwrap()),
504+
Operand::IdRef(return_value),
505+
],
506+
),
507+
);
508+
} else {
509+
assert!(return_variable.is_none());
458510
}
459511
*block.instructions.last_mut().unwrap() =
460512
Instruction::new(Op::Branch, None, None, vec![Operand::IdRef(return_jump)]);
461513
}
462514
}
463-
(blocks, phipairs)
515+
(blocks, values)
516+
}
517+
518+
fn insert_opvariable(block: &mut Block, ptr_ty: Word, result_id: Word) {
519+
let index = block
520+
.instructions
521+
.partition_point(|inst| inst.class.opcode == Op::Variable);
522+
523+
let inst = Instruction::new(
524+
Op::Variable,
525+
Some(ptr_ty),
526+
Some(result_id),
527+
vec![Operand::StorageClass(StorageClass::Function)],
528+
);
529+
block.instructions.insert(index, inst)
464530
}
465531

466532
fn insert_opvariables(block: &mut Block, insts: Vec<Instruction>) {
@@ -472,7 +538,6 @@ fn insert_opvariables(block: &mut Block, insts: Vec<Instruction>) {
472538

473539
fn fuse_trivial_branches(function: &mut Function) {
474540
let mut chain_list = compute_outgoing_1to1_branches(&function.blocks);
475-
let mut rewrite_rules = FxHashMap::default();
476541

477542
for block_idx in 0..chain_list.len() {
478543
let mut next = chain_list[block_idx].take();
@@ -488,16 +553,6 @@ fn fuse_trivial_branches(function: &mut Function) {
488553
}
489554
Some(next_idx) => {
490555
let mut dest_insts = take(&mut function.blocks[next_idx].instructions);
491-
dest_insts.retain(|inst| {
492-
if inst.class.opcode == Op::Phi {
493-
assert_eq!(inst.operands.len(), 2);
494-
rewrite_rules
495-
.insert(inst.result_id.unwrap(), inst.operands[0].unwrap_id_ref());
496-
false
497-
} else {
498-
true
499-
}
500-
});
501556
let self_insts = &mut function.blocks[block_idx].instructions;
502557
self_insts.pop(); // pop the branch
503558
self_insts.append(&mut dest_insts);
@@ -507,14 +562,6 @@ fn fuse_trivial_branches(function: &mut Function) {
507562
}
508563
}
509564
function.blocks.retain(|b| !b.instructions.is_empty());
510-
// Calculate a closure, as these rules can be transitive
511-
let mut rewrite_rules_new = rewrite_rules.clone();
512-
for value in rewrite_rules_new.values_mut() {
513-
while let Some(next) = rewrite_rules.get(value) {
514-
*value = *next;
515-
}
516-
}
517-
apply_rewrite_rules(&rewrite_rules_new, &mut function.blocks);
518565
}
519566

520567
fn compute_outgoing_1to1_branches(blocks: &[Block]) -> Vec<Option<usize>> {

0 commit comments

Comments
 (0)