|
| 1 | +use super::id; |
1 | 2 | use rspirv::dr::{Function, Instruction, Module, ModuleHeader, Operand};
|
2 | 3 | use rspirv::spirv::{Op, Word};
|
3 |
| -use rustc_data_structures::fx::FxHashMap; |
| 4 | +use rustc_data_structures::fx::{FxHashMap, FxHashSet}; |
4 | 5 | use rustc_middle::bug;
|
5 | 6 |
|
6 | 7 | pub fn collect_types(module: &Module) -> FxHashMap<Word, Instruction> {
|
@@ -447,3 +448,168 @@ pub fn vector_ops(
|
447 | 448 | }
|
448 | 449 | }
|
449 | 450 | }
|
| 451 | + |
| 452 | +fn can_fuse_bool( |
| 453 | + types: &FxHashMap<Word, Instruction>, |
| 454 | + defs: &FxHashMap<Word, (usize, Instruction)>, |
| 455 | + inst: &Instruction, |
| 456 | +) -> bool { |
| 457 | + fn constant_value(types: &FxHashMap<Word, Instruction>, val: Word) -> Option<u32> { |
| 458 | + let inst = match types.get(&val) { |
| 459 | + None => return None, |
| 460 | + Some(inst) => inst, |
| 461 | + }; |
| 462 | + if inst.class.opcode != Op::Constant { |
| 463 | + return None; |
| 464 | + } |
| 465 | + match inst.operands[0] { |
| 466 | + Operand::LiteralInt32(v) => Some(v), |
| 467 | + _ => None, |
| 468 | + } |
| 469 | + } |
| 470 | + |
| 471 | + fn visit( |
| 472 | + types: &FxHashMap<Word, Instruction>, |
| 473 | + defs: &FxHashMap<Word, (usize, Instruction)>, |
| 474 | + visited: &mut FxHashSet<Word>, |
| 475 | + value: Word, |
| 476 | + ) -> bool { |
| 477 | + if visited.insert(value) { |
| 478 | + let inst = match defs.get(&value) { |
| 479 | + Some((_, inst)) => inst, |
| 480 | + None => return false, |
| 481 | + }; |
| 482 | + match inst.class.opcode { |
| 483 | + Op::Select => { |
| 484 | + constant_value(types, inst.operands[1].unwrap_id_ref()) == Some(1) |
| 485 | + && constant_value(types, inst.operands[2].unwrap_id_ref()) == Some(0) |
| 486 | + } |
| 487 | + Op::Phi => inst |
| 488 | + .operands |
| 489 | + .iter() |
| 490 | + .step_by(2) |
| 491 | + .all(|op| visit(types, defs, visited, op.unwrap_id_ref())), |
| 492 | + _ => false, |
| 493 | + } |
| 494 | + } else { |
| 495 | + true |
| 496 | + } |
| 497 | + } |
| 498 | + |
| 499 | + if inst.class.opcode != Op::INotEqual |
| 500 | + || constant_value(types, inst.operands[1].unwrap_id_ref()) != Some(0) |
| 501 | + { |
| 502 | + return false; |
| 503 | + } |
| 504 | + let int_value = inst.operands[0].unwrap_id_ref(); |
| 505 | + |
| 506 | + visit(types, defs, &mut FxHashSet::default(), int_value) |
| 507 | +} |
| 508 | + |
| 509 | +fn fuse_bool( |
| 510 | + header: &mut ModuleHeader, |
| 511 | + defs: &FxHashMap<Word, (usize, Instruction)>, |
| 512 | + phis_to_insert: &mut Vec<(usize, Instruction)>, |
| 513 | + already_mapped: &mut FxHashMap<Word, Word>, |
| 514 | + bool_ty: Word, |
| 515 | + int_value: Word, |
| 516 | +) -> Word { |
| 517 | + if let Some(&result) = already_mapped.get(&int_value) { |
| 518 | + return result; |
| 519 | + } |
| 520 | + let (block_of_inst, inst) = defs.get(&int_value).unwrap(); |
| 521 | + match inst.class.opcode { |
| 522 | + Op::Select => inst.operands[0].unwrap_id_ref(), |
| 523 | + Op::Phi => { |
| 524 | + let result_id = id(header); |
| 525 | + already_mapped.insert(int_value, result_id); |
| 526 | + let new_phi_args = inst |
| 527 | + .operands |
| 528 | + .chunks(2) |
| 529 | + .flat_map(|arr| { |
| 530 | + let phi_value = &arr[0]; |
| 531 | + let block = &arr[1]; |
| 532 | + [ |
| 533 | + Operand::IdRef(fuse_bool( |
| 534 | + header, |
| 535 | + defs, |
| 536 | + phis_to_insert, |
| 537 | + already_mapped, |
| 538 | + bool_ty, |
| 539 | + phi_value.unwrap_id_ref(), |
| 540 | + )), |
| 541 | + block.clone(), |
| 542 | + ] |
| 543 | + }) |
| 544 | + .collect::<Vec<_>>(); |
| 545 | + let inst = Instruction::new(Op::Phi, Some(bool_ty), Some(result_id), new_phi_args); |
| 546 | + phis_to_insert.push((*block_of_inst, inst)); |
| 547 | + result_id |
| 548 | + } |
| 549 | + _ => bug!("can_fuse_bool should have prevented this case"), |
| 550 | + } |
| 551 | +} |
| 552 | + |
| 553 | +// The compiler generates a lot of code that looks like this: |
| 554 | +// %v_int = OpSelect %int %v %const_1 %const_0 |
| 555 | +// %v2 = OpINotEqual %bool %v_int %const_0 |
| 556 | +// (This is due to rustc/spirv not supporting bools in memory, and needing to convert to u8, but |
| 557 | +// then things get inlined/mem2reg'd) |
| 558 | +// |
| 559 | +// This pass fuses together those two instructions to strip out the intermediate integer variable. |
| 560 | +// The purpose is to make simple code that doesn't actually do memory-stuff with bools not require |
| 561 | +// the Int8 capability (and so we can't rely on spirv-opt to do this same pass). |
| 562 | +// |
| 563 | +// Unfortunately, things get complicated because of phis: the majority of actually useful cases to |
| 564 | +// do this pass need to track pseudo-bool ints through phi instructions. |
| 565 | +// |
| 566 | +// The logic goes like: |
| 567 | +// 1) Figure out what we *can* fuse. This means finding OpINotEqual instructions (converting back |
| 568 | +// from int->bool) and tracing the value back recursively through any phis, and making sure each |
| 569 | +// one terminates in either a loop back around to something we've already seen, or an OpSelect |
| 570 | +// (converting from bool->int). |
| 571 | +// 2) Do the fusion. Trace back through phis, generating a second bool-typed phi alongside the |
| 572 | +// original int-typed phi, and when hitting an OpSelect, taking the bool value directly. |
| 573 | +// 3) DCE the dead OpSelects/int-typed OpPhis (done in a later pass). We don't nuke them here, |
| 574 | +// since they might be used elsewhere, and don't want to accidentally leave a dangling |
| 575 | +// reference. |
| 576 | +pub fn bool_fusion( |
| 577 | + header: &mut ModuleHeader, |
| 578 | + types: &FxHashMap<Word, Instruction>, |
| 579 | + function: &mut Function, |
| 580 | +) { |
| 581 | + let defs: FxHashMap<Word, (usize, Instruction)> = function |
| 582 | + .blocks |
| 583 | + .iter() |
| 584 | + .enumerate() |
| 585 | + .flat_map(|(block_id, block)| { |
| 586 | + block |
| 587 | + .instructions |
| 588 | + .iter() |
| 589 | + .filter_map(move |inst| Some((inst.result_id?, (block_id, inst.clone())))) |
| 590 | + }) |
| 591 | + .collect(); |
| 592 | + let mut rewrite_rules = FxHashMap::default(); |
| 593 | + let mut phis_to_insert = Default::default(); |
| 594 | + let mut already_mapped = Default::default(); |
| 595 | + for block in &mut function.blocks { |
| 596 | + for inst in &mut block.instructions { |
| 597 | + if can_fuse_bool(types, &defs, inst) { |
| 598 | + let rewrite_to = fuse_bool( |
| 599 | + header, |
| 600 | + &defs, |
| 601 | + &mut phis_to_insert, |
| 602 | + &mut already_mapped, |
| 603 | + inst.result_type.unwrap(), |
| 604 | + inst.operands[0].unwrap_id_ref(), |
| 605 | + ); |
| 606 | + rewrite_rules.insert(inst.result_id.unwrap(), rewrite_to); |
| 607 | + *inst = Instruction::new(Op::Nop, None, None, Vec::new()); |
| 608 | + } |
| 609 | + } |
| 610 | + } |
| 611 | + for (block, phi) in phis_to_insert { |
| 612 | + function.blocks[block].instructions.insert(0, phi); |
| 613 | + } |
| 614 | + super::apply_rewrite_rules(&rewrite_rules, &mut function.blocks); |
| 615 | +} |
0 commit comments