Skip to content

Commit c10a1ca

Browse files
authored
Implement bool fusion pass (#776)
Fixes #677
1 parent f58c6f2 commit c10a1ca

File tree

4 files changed

+219
-1
lines changed

4 files changed

+219
-1
lines changed

crates/rustc_codegen_spirv/src/linker/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ pub fn link(sess: &Session, mut inputs: Vec<Module>, opts: &Options) -> Result<L
266266
for func in &mut output.functions {
267267
peephole_opts::composite_construct(&types, func);
268268
peephole_opts::vector_ops(output.header.as_mut().unwrap(), &types, func);
269+
peephole_opts::bool_fusion(output.header.as_mut().unwrap(), &types, func);
269270
}
270271
}
271272

crates/rustc_codegen_spirv/src/linker/peephole_opts.rs

Lines changed: 167 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
use super::id;
12
use rspirv::dr::{Function, Instruction, Module, ModuleHeader, Operand};
23
use rspirv::spirv::{Op, Word};
3-
use rustc_data_structures::fx::FxHashMap;
4+
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
45
use rustc_middle::bug;
56

67
pub fn collect_types(module: &Module) -> FxHashMap<Word, Instruction> {
@@ -447,3 +448,168 @@ pub fn vector_ops(
447448
}
448449
}
449450
}
451+
452+
fn can_fuse_bool(
453+
types: &FxHashMap<Word, Instruction>,
454+
defs: &FxHashMap<Word, (usize, Instruction)>,
455+
inst: &Instruction,
456+
) -> bool {
457+
fn constant_value(types: &FxHashMap<Word, Instruction>, val: Word) -> Option<u32> {
458+
let inst = match types.get(&val) {
459+
None => return None,
460+
Some(inst) => inst,
461+
};
462+
if inst.class.opcode != Op::Constant {
463+
return None;
464+
}
465+
match inst.operands[0] {
466+
Operand::LiteralInt32(v) => Some(v),
467+
_ => None,
468+
}
469+
}
470+
471+
fn visit(
472+
types: &FxHashMap<Word, Instruction>,
473+
defs: &FxHashMap<Word, (usize, Instruction)>,
474+
visited: &mut FxHashSet<Word>,
475+
value: Word,
476+
) -> bool {
477+
if visited.insert(value) {
478+
let inst = match defs.get(&value) {
479+
Some((_, inst)) => inst,
480+
None => return false,
481+
};
482+
match inst.class.opcode {
483+
Op::Select => {
484+
constant_value(types, inst.operands[1].unwrap_id_ref()) == Some(1)
485+
&& constant_value(types, inst.operands[2].unwrap_id_ref()) == Some(0)
486+
}
487+
Op::Phi => inst
488+
.operands
489+
.iter()
490+
.step_by(2)
491+
.all(|op| visit(types, defs, visited, op.unwrap_id_ref())),
492+
_ => false,
493+
}
494+
} else {
495+
true
496+
}
497+
}
498+
499+
if inst.class.opcode != Op::INotEqual
500+
|| constant_value(types, inst.operands[1].unwrap_id_ref()) != Some(0)
501+
{
502+
return false;
503+
}
504+
let int_value = inst.operands[0].unwrap_id_ref();
505+
506+
visit(types, defs, &mut FxHashSet::default(), int_value)
507+
}
508+
509+
fn fuse_bool(
510+
header: &mut ModuleHeader,
511+
defs: &FxHashMap<Word, (usize, Instruction)>,
512+
phis_to_insert: &mut Vec<(usize, Instruction)>,
513+
already_mapped: &mut FxHashMap<Word, Word>,
514+
bool_ty: Word,
515+
int_value: Word,
516+
) -> Word {
517+
if let Some(&result) = already_mapped.get(&int_value) {
518+
return result;
519+
}
520+
let (block_of_inst, inst) = defs.get(&int_value).unwrap();
521+
match inst.class.opcode {
522+
Op::Select => inst.operands[0].unwrap_id_ref(),
523+
Op::Phi => {
524+
let result_id = id(header);
525+
already_mapped.insert(int_value, result_id);
526+
let new_phi_args = inst
527+
.operands
528+
.chunks(2)
529+
.flat_map(|arr| {
530+
let phi_value = &arr[0];
531+
let block = &arr[1];
532+
[
533+
Operand::IdRef(fuse_bool(
534+
header,
535+
defs,
536+
phis_to_insert,
537+
already_mapped,
538+
bool_ty,
539+
phi_value.unwrap_id_ref(),
540+
)),
541+
block.clone(),
542+
]
543+
})
544+
.collect::<Vec<_>>();
545+
let inst = Instruction::new(Op::Phi, Some(bool_ty), Some(result_id), new_phi_args);
546+
phis_to_insert.push((*block_of_inst, inst));
547+
result_id
548+
}
549+
_ => bug!("can_fuse_bool should have prevented this case"),
550+
}
551+
}
552+
553+
// The compiler generates a lot of code that looks like this:
554+
// %v_int = OpSelect %int %v %const_1 %const_0
555+
// %v2 = OpINotEqual %bool %v_int %const_0
556+
// (This is due to rustc/spirv not supporting bools in memory, and needing to convert to u8, but
557+
// then things get inlined/mem2reg'd)
558+
//
559+
// This pass fuses together those two instructions to strip out the intermediate integer variable.
560+
// The purpose is to make simple code that doesn't actually do memory-stuff with bools not require
561+
// the Int8 capability (and so we can't rely on spirv-opt to do this same pass).
562+
//
563+
// Unfortunately, things get complicated because of phis: the majority of actually useful cases to
564+
// do this pass need to track pseudo-bool ints through phi instructions.
565+
//
566+
// The logic goes like:
567+
// 1) Figure out what we *can* fuse. This means finding OpINotEqual instructions (converting back
568+
// from int->bool) and tracing the value back recursively through any phis, and making sure each
569+
// one terminates in either a loop back around to something we've already seen, or an OpSelect
570+
// (converting from bool->int).
571+
// 2) Do the fusion. Trace back through phis, generating a second bool-typed phi alongside the
572+
// original int-typed phi, and when hitting an OpSelect, taking the bool value directly.
573+
// 3) DCE the dead OpSelects/int-typed OpPhis (done in a later pass). We don't nuke them here,
574+
// since they might be used elsewhere, and don't want to accidentally leave a dangling
575+
// reference.
576+
pub fn bool_fusion(
577+
header: &mut ModuleHeader,
578+
types: &FxHashMap<Word, Instruction>,
579+
function: &mut Function,
580+
) {
581+
let defs: FxHashMap<Word, (usize, Instruction)> = function
582+
.blocks
583+
.iter()
584+
.enumerate()
585+
.flat_map(|(block_id, block)| {
586+
block
587+
.instructions
588+
.iter()
589+
.filter_map(move |inst| Some((inst.result_id?, (block_id, inst.clone()))))
590+
})
591+
.collect();
592+
let mut rewrite_rules = FxHashMap::default();
593+
let mut phis_to_insert = Default::default();
594+
let mut already_mapped = Default::default();
595+
for block in &mut function.blocks {
596+
for inst in &mut block.instructions {
597+
if can_fuse_bool(types, &defs, inst) {
598+
let rewrite_to = fuse_bool(
599+
header,
600+
&defs,
601+
&mut phis_to_insert,
602+
&mut already_mapped,
603+
inst.result_type.unwrap(),
604+
inst.operands[0].unwrap_id_ref(),
605+
);
606+
rewrite_rules.insert(inst.result_id.unwrap(), rewrite_to);
607+
*inst = Instruction::new(Op::Nop, None, None, Vec::new());
608+
}
609+
}
610+
}
611+
for (block, phi) in phis_to_insert {
612+
function.blocks[block].instructions.insert(0, phi);
613+
}
614+
super::apply_rewrite_rules(&rewrite_rules, &mut function.blocks);
615+
}

tests/ui/lang/core/unwrap_or.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// unwrap_or generates some memory-bools (as u8). Test to make sure they're fused away.
2+
// OpINotEqual, as well as %bool, should not appear in the output.
3+
4+
// build-pass
5+
// compile-flags: -C llvm-args=--disassemble-entry=main
6+
7+
use spirv_std as _;
8+
9+
#[spirv(fragment)]
10+
pub fn main(out: &mut u32) {
11+
*out = None.unwrap_or(15);
12+
}

tests/ui/lang/core/unwrap_or.stderr

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
%1 = OpFunction %2 None %3
2+
%4 = OpLabel
3+
OpLine %5 11 11
4+
%6 = OpCompositeInsert %7 %8 %9 0
5+
OpLine %5 11 11
6+
%10 = OpCompositeExtract %11 %6 1
7+
OpLine %12 767 14
8+
%13 = OpBitcast %14 %8
9+
OpLine %12 767 8
10+
OpSelectionMerge %15 None
11+
OpSwitch %13 %16 0 %17 1 %18
12+
%16 = OpLabel
13+
OpLine %12 767 14
14+
OpUnreachable
15+
%17 = OpLabel
16+
OpLine %12 769 20
17+
OpBranch %15
18+
%18 = OpLabel
19+
OpLine %12 771 4
20+
OpBranch %15
21+
%15 = OpLabel
22+
%19 = OpPhi %20 %21 %17 %22 %18
23+
%23 = OpPhi %11 %24 %17 %10 %18
24+
OpBranch %25
25+
%25 = OpLabel
26+
OpLine %12 771 4
27+
OpSelectionMerge %26 None
28+
OpBranchConditional %19 %27 %28
29+
%27 = OpLabel
30+
OpLine %12 771 4
31+
OpBranch %26
32+
%28 = OpLabel
33+
OpBranch %26
34+
%26 = OpLabel
35+
OpLine %5 11 4
36+
OpStore %29 %23
37+
OpLine %5 12 1
38+
OpReturn
39+
OpFunctionEnd

0 commit comments

Comments
 (0)