|
32 | 32 | #include "llvm/BinaryFormat/Wasm.h"
|
33 | 33 | #include "llvm/CodeGen/Analysis.h"
|
34 | 34 | #include "llvm/CodeGen/AsmPrinter.h"
|
| 35 | +#include "llvm/CodeGen/MachineBranchProbabilityInfo.h" |
35 | 36 | #include "llvm/CodeGen/MachineConstantPool.h"
|
36 | 37 | #include "llvm/CodeGen/MachineInstr.h"
|
37 | 38 | #include "llvm/CodeGen/MachineModuleInfoImpls.h"
|
@@ -683,6 +684,100 @@ void WebAssemblyAsmPrinter::emitFunctionBodyStart() {
|
683 | 684 | AsmPrinter::emitFunctionBodyStart();
|
684 | 685 | }
|
685 | 686 |
|
| 687 | +// Try to infer branch target for a BR_IF instruction after MBB targets were |
| 688 | +// stackified by `WebAssemblyCFGStackify` using simple heuristics to avoid |
| 689 | +// having to simulate block-stack. |
| 690 | +const MachineBasicBlock *inferBranchTarget(const MachineInstr *MI, |
| 691 | + const MachineBasicBlock *MBB) { |
| 692 | + // Since we need to guess branch targets based on MBB successor order, |
| 693 | + // we need to make sure that the BR_IF is the last terminator to exclude |
| 694 | + // complicated edge cases. |
| 695 | + if (const auto Terminators = reverse(MBB->terminators()); |
| 696 | + Terminators.begin() == Terminators.end() || &*Terminators.begin() != MI) |
| 697 | + return nullptr; |
| 698 | + |
| 699 | + // Parent mbb might have more than the two successors (true / false) from |
| 700 | + // br_if due to eh pads / unwinds. We skip those cases. |
| 701 | + if (MBB->succ_size() != 2) |
| 702 | + return nullptr; |
| 703 | + |
| 704 | + const MachineBasicBlock *Succ0 = *MBB->succ_begin(); |
| 705 | + const MachineBasicBlock *Succ1 = *std::next(MBB->succ_begin()); |
| 706 | + |
| 707 | + // Find fallthrough block that is right after MBB and is the target of the |
| 708 | + // false-edge of the br_if |
| 709 | + assert(std::next(MBB->getIterator()) != MBB->getParent()->end() && |
| 710 | + "MBB with br_if must have a basic block after it"); |
| 711 | + const MachineBasicBlock *Fallthrough = &*std::next(MBB->getIterator()); |
| 712 | + |
| 713 | + // In some corner cases concerning exceptions, earlier optimizations |
| 714 | + // (`WebAssemblyCFGStackify::removeUnnecessaryInstrs` in particular) obfuscate |
| 715 | + // fallthrough control flow: |
| 716 | + // |
| 717 | + // bb0: |
| 718 | + // ;; successor: if.true, cont |
| 719 | + // br_if $if.true |
| 720 | + // br $cont |
| 721 | + // |
| 722 | + // ehpad: ... |
| 723 | + // cont: ... <- Continuation BB |
| 724 | + // |
| 725 | + // `br $cont` may be optimized away, making the `ehpad` seem like the |
| 726 | + // fallthrough block instead of `cont`. Give up on that case. |
| 727 | + if (Fallthrough != Succ0 && Fallthrough != Succ1) |
| 728 | + return nullptr; |
| 729 | + // return the true-block (desired branch target) which is !Fallthrough |
| 730 | + return Fallthrough == Succ0 ? Succ1 : Succ0; |
| 731 | +} |
| 732 | + |
| 733 | +void WebAssemblyAsmPrinter::recordBranchHint(const MachineInstr *MI) { |
| 734 | + assert(MI->getOpcode() == WebAssembly::BR_IF); |
| 735 | + const MachineBasicBlock *MBB = MI->getParent(); |
| 736 | + const MachineFunction *MF = MBB->getParent(); |
| 737 | + |
| 738 | + if (!MF->getSubtarget<WebAssemblySubtarget>().hasBranchHinting() || |
| 739 | + !MBB->hasSuccessorProbabilities()) { |
| 740 | + return; |
| 741 | + } |
| 742 | + const MachineBranchProbabilityInfo *MBPI = |
| 743 | + &getAnalysis<MachineBranchProbabilityInfoWrapperPass>().getMBPI(); |
| 744 | + const MachineBasicBlock *TrueBlock = inferBranchTarget(MI, MBB); |
| 745 | + if (TrueBlock == nullptr) { |
| 746 | + LLVM_DEBUG(dbgs() << "Could not infer branch target for " << *MI << '\n'); |
| 747 | + return; |
| 748 | + } |
| 749 | + const BranchProbability Prob = MBPI->getEdgeProbability(MBB, TrueBlock); |
| 750 | + |
| 751 | + const float ThresholdProbLow = WasmLowBranchProb.getValue(); |
| 752 | + const float ThresholdProbHigh = WasmHighBranchProb.getValue(); |
| 753 | + assert(ThresholdProbLow >= 0.0f && ThresholdProbLow <= 1.0f && |
| 754 | + ThresholdProbHigh >= 0.0f && ThresholdProbHigh <= 1.0f && |
| 755 | + "Branch probability thresholds must be in range [0.0-1.0]"); |
| 756 | + |
| 757 | + MCSymbol *BrIfSym = OutContext.createTempSymbol(); |
| 758 | + OutStreamer->emitLabel(BrIfSym); |
| 759 | + constexpr uint8_t HintLikely = 0x01; |
| 760 | + constexpr uint8_t HintUnlikely = 0x00; |
| 761 | + const uint32_t D = BranchProbability::getOne().getDenominator(); |
| 762 | + uint8_t HintValue; |
| 763 | + if (Prob > BranchProbability::getRaw(ThresholdProbHigh * D)) |
| 764 | + HintValue = HintLikely; |
| 765 | + else if (Prob <= BranchProbability::getRaw(ThresholdProbLow * D)) |
| 766 | + HintValue = HintUnlikely; |
| 767 | + else |
| 768 | + return; // Don't emit branch hint between thresholds |
| 769 | + |
| 770 | + // we know that we only emit branch hints for internal functions, |
| 771 | + // therefore we can directly cast and don't need getMCSymbolForFunction |
| 772 | + MCSymbol *FuncSym = cast<MCSymbolWasm>(getSymbol(&MF->getFunction())); |
| 773 | + const uint32_t LocalFuncIdx = MF->getFunctionNumber(); |
| 774 | + if (BranchHints.size() <= LocalFuncIdx) { |
| 775 | + BranchHints.resize(LocalFuncIdx + 1); |
| 776 | + BranchHints[LocalFuncIdx].FuncSym = FuncSym; |
| 777 | + } |
| 778 | + BranchHints[LocalFuncIdx].Hints.emplace_back(BrIfSym, HintValue); |
| 779 | +} |
| 780 | + |
686 | 781 | void WebAssemblyAsmPrinter::emitInstruction(const MachineInstr *MI) {
|
687 | 782 | LLVM_DEBUG(dbgs() << "EmitInstruction: " << *MI << '\n');
|
688 | 783 | WebAssembly_MC::verifyInstructionPredicates(MI->getOpcode(),
|
@@ -742,42 +837,12 @@ void WebAssemblyAsmPrinter::emitInstruction(const MachineInstr *MI) {
|
742 | 837 | WebAssemblyMCInstLower MCInstLowering(OutContext, *this);
|
743 | 838 | MCInst TmpInst;
|
744 | 839 | MCInstLowering.lower(MI, TmpInst);
|
745 |
| - if (Subtarget->hasBranchHinting() && MFI) { |
746 |
| - if (const auto Prob = MFI->BranchProbabilities.find(MI); |
747 |
| - Prob != MFI->BranchProbabilities.end()) { |
748 |
| - const float ThresholdProbLow = WasmLowBranchProb.getValue(); |
749 |
| - const float ThresholdProbHigh = WasmHighBranchProb.getValue(); |
750 |
| - assert(ThresholdProbLow >= 0.0f && ThresholdProbLow <= 1.0f && |
751 |
| - ThresholdProbHigh >= 0.0f && ThresholdProbHigh <= 1.0f && |
752 |
| - "Branch probability thresholds must be in range [0.0-1.0]"); |
753 |
| - |
754 |
| - MCSymbol *BrIfSym = OutContext.createTempSymbol(); |
755 |
| - OutStreamer->emitLabel(BrIfSym); |
756 |
| - constexpr uint8_t HintLikely = 0x01; |
757 |
| - constexpr uint8_t HintUnlikely = 0x00; |
758 |
| - const uint32_t D = BranchProbability::getOne().getDenominator(); |
759 |
| - uint8_t HintValue; |
760 |
| - if (Prob->getSecond() > |
761 |
| - BranchProbability::getRaw(ThresholdProbHigh * D)) |
762 |
| - HintValue = HintLikely; |
763 |
| - else if (Prob->getSecond() <= |
764 |
| - BranchProbability::getRaw(ThresholdProbLow * D)) |
765 |
| - HintValue = HintUnlikely; |
766 |
| - else |
767 |
| - goto emit; // Don't emit branch hint between thresholds |
768 |
| - |
769 |
| - // we know that we only emit branch hints for internal functions, |
770 |
| - // therefore we can directly cast and don't need getMCSymbolForFunction |
771 |
| - MCSymbol *FuncSym = cast<MCSymbolWasm>(getSymbol(&MF->getFunction())); |
772 |
| - const uint32_t LocalFuncIdx = MF->getFunctionNumber(); |
773 |
| - if (BranchHints.size() <= LocalFuncIdx) { |
774 |
| - BranchHints.resize(LocalFuncIdx + 1); |
775 |
| - BranchHints[LocalFuncIdx].FuncSym = FuncSym; |
776 |
| - } |
777 |
| - BranchHints[LocalFuncIdx].Hints.emplace_back(BrIfSym, HintValue); |
778 |
| - } |
| 840 | + if (Subtarget->hasBranchHinting() && |
| 841 | + MI->getOpcode() == WebAssembly::BR_IF) { |
| 842 | + // since we need to emit a label to later recover the instruction's |
| 843 | + // offset, this has to called before the instruction is emitted |
| 844 | + recordBranchHint(MI); |
779 | 845 | }
|
780 |
| - emit: |
781 | 846 | EmitToStreamer(*OutStreamer, TmpInst);
|
782 | 847 | break;
|
783 | 848 | }
|
|
0 commit comments