Skip to content

Commit a630d51

Browse files
committed
[LLVM][WebAssembly] Revert changes to WebAssemblyCFGStackify.cpp and move branch hint collection to WebAssemblyAsmPrinter
This change is largely based by prior contributions by @kripken!
1 parent e4b0ccd commit a630d51

File tree

4 files changed

+101
-56
lines changed

4 files changed

+101
-56
lines changed

llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp

Lines changed: 100 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "llvm/BinaryFormat/Wasm.h"
3333
#include "llvm/CodeGen/Analysis.h"
3434
#include "llvm/CodeGen/AsmPrinter.h"
35+
#include "llvm/CodeGen/MachineBranchProbabilityInfo.h"
3536
#include "llvm/CodeGen/MachineConstantPool.h"
3637
#include "llvm/CodeGen/MachineInstr.h"
3738
#include "llvm/CodeGen/MachineModuleInfoImpls.h"
@@ -683,6 +684,100 @@ void WebAssemblyAsmPrinter::emitFunctionBodyStart() {
683684
AsmPrinter::emitFunctionBodyStart();
684685
}
685686

687+
// Try to infer branch target for a BR_IF instruction after MBB targets were
688+
// stackified by `WebAssemblyCFGStackify` using simple heuristics to avoid
689+
// having to simulate block-stack.
690+
const MachineBasicBlock *inferBranchTarget(const MachineInstr *MI,
691+
const MachineBasicBlock *MBB) {
692+
// Since we need to guess branch targets based on MBB successor order,
693+
// we need to make sure that the BR_IF is the last terminator to exclude
694+
// complicated edge cases.
695+
if (const auto Terminators = reverse(MBB->terminators());
696+
Terminators.begin() == Terminators.end() || &*Terminators.begin() != MI)
697+
return nullptr;
698+
699+
// Parent mbb might have more than the two successors (true / false) from
700+
// br_if due to eh pads / unwinds. We skip those cases.
701+
if (MBB->succ_size() != 2)
702+
return nullptr;
703+
704+
const MachineBasicBlock *Succ0 = *MBB->succ_begin();
705+
const MachineBasicBlock *Succ1 = *std::next(MBB->succ_begin());
706+
707+
// Find fallthrough block that is right after MBB and is the target of the
708+
// false-edge of the br_if
709+
assert(std::next(MBB->getIterator()) != MBB->getParent()->end() &&
710+
"MBB with br_if must have a basic block after it");
711+
const MachineBasicBlock *Fallthrough = &*std::next(MBB->getIterator());
712+
713+
// In some corner cases concerning exceptions, earlier optimizations
714+
// (`WebAssemblyCFGStackify::removeUnnecessaryInstrs` in particular) obfuscate
715+
// fallthrough control flow:
716+
//
717+
// bb0:
718+
// ;; successor: if.true, cont
719+
// br_if $if.true
720+
// br $cont
721+
//
722+
// ehpad: ...
723+
// cont: ... <- Continuation BB
724+
//
725+
// `br $cont` may be optimized away, making the `ehpad` seem like the
726+
// fallthrough block instead of `cont`. Give up on that case.
727+
if (Fallthrough != Succ0 && Fallthrough != Succ1)
728+
return nullptr;
729+
// return the true-block (desired branch target) which is !Fallthrough
730+
return Fallthrough == Succ0 ? Succ1 : Succ0;
731+
}
732+
733+
void WebAssemblyAsmPrinter::recordBranchHint(const MachineInstr *MI) {
734+
assert(MI->getOpcode() == WebAssembly::BR_IF);
735+
const MachineBasicBlock *MBB = MI->getParent();
736+
const MachineFunction *MF = MBB->getParent();
737+
738+
if (!MF->getSubtarget<WebAssemblySubtarget>().hasBranchHinting() ||
739+
!MBB->hasSuccessorProbabilities()) {
740+
return;
741+
}
742+
const MachineBranchProbabilityInfo *MBPI =
743+
&getAnalysis<MachineBranchProbabilityInfoWrapperPass>().getMBPI();
744+
const MachineBasicBlock *TrueBlock = inferBranchTarget(MI, MBB);
745+
if (TrueBlock == nullptr) {
746+
LLVM_DEBUG(dbgs() << "Could not infer branch target for " << *MI << '\n');
747+
return;
748+
}
749+
const BranchProbability Prob = MBPI->getEdgeProbability(MBB, TrueBlock);
750+
751+
const float ThresholdProbLow = WasmLowBranchProb.getValue();
752+
const float ThresholdProbHigh = WasmHighBranchProb.getValue();
753+
assert(ThresholdProbLow >= 0.0f && ThresholdProbLow <= 1.0f &&
754+
ThresholdProbHigh >= 0.0f && ThresholdProbHigh <= 1.0f &&
755+
"Branch probability thresholds must be in range [0.0-1.0]");
756+
757+
MCSymbol *BrIfSym = OutContext.createTempSymbol();
758+
OutStreamer->emitLabel(BrIfSym);
759+
constexpr uint8_t HintLikely = 0x01;
760+
constexpr uint8_t HintUnlikely = 0x00;
761+
const uint32_t D = BranchProbability::getOne().getDenominator();
762+
uint8_t HintValue;
763+
if (Prob > BranchProbability::getRaw(ThresholdProbHigh * D))
764+
HintValue = HintLikely;
765+
else if (Prob <= BranchProbability::getRaw(ThresholdProbLow * D))
766+
HintValue = HintUnlikely;
767+
else
768+
return; // Don't emit branch hint between thresholds
769+
770+
// we know that we only emit branch hints for internal functions,
771+
// therefore we can directly cast and don't need getMCSymbolForFunction
772+
MCSymbol *FuncSym = cast<MCSymbolWasm>(getSymbol(&MF->getFunction()));
773+
const uint32_t LocalFuncIdx = MF->getFunctionNumber();
774+
if (BranchHints.size() <= LocalFuncIdx) {
775+
BranchHints.resize(LocalFuncIdx + 1);
776+
BranchHints[LocalFuncIdx].FuncSym = FuncSym;
777+
}
778+
BranchHints[LocalFuncIdx].Hints.emplace_back(BrIfSym, HintValue);
779+
}
780+
686781
void WebAssemblyAsmPrinter::emitInstruction(const MachineInstr *MI) {
687782
LLVM_DEBUG(dbgs() << "EmitInstruction: " << *MI << '\n');
688783
WebAssembly_MC::verifyInstructionPredicates(MI->getOpcode(),
@@ -742,42 +837,12 @@ void WebAssemblyAsmPrinter::emitInstruction(const MachineInstr *MI) {
742837
WebAssemblyMCInstLower MCInstLowering(OutContext, *this);
743838
MCInst TmpInst;
744839
MCInstLowering.lower(MI, TmpInst);
745-
if (Subtarget->hasBranchHinting() && MFI) {
746-
if (const auto Prob = MFI->BranchProbabilities.find(MI);
747-
Prob != MFI->BranchProbabilities.end()) {
748-
const float ThresholdProbLow = WasmLowBranchProb.getValue();
749-
const float ThresholdProbHigh = WasmHighBranchProb.getValue();
750-
assert(ThresholdProbLow >= 0.0f && ThresholdProbLow <= 1.0f &&
751-
ThresholdProbHigh >= 0.0f && ThresholdProbHigh <= 1.0f &&
752-
"Branch probability thresholds must be in range [0.0-1.0]");
753-
754-
MCSymbol *BrIfSym = OutContext.createTempSymbol();
755-
OutStreamer->emitLabel(BrIfSym);
756-
constexpr uint8_t HintLikely = 0x01;
757-
constexpr uint8_t HintUnlikely = 0x00;
758-
const uint32_t D = BranchProbability::getOne().getDenominator();
759-
uint8_t HintValue;
760-
if (Prob->getSecond() >
761-
BranchProbability::getRaw(ThresholdProbHigh * D))
762-
HintValue = HintLikely;
763-
else if (Prob->getSecond() <=
764-
BranchProbability::getRaw(ThresholdProbLow * D))
765-
HintValue = HintUnlikely;
766-
else
767-
goto emit; // Don't emit branch hint between thresholds
768-
769-
// we know that we only emit branch hints for internal functions,
770-
// therefore we can directly cast and don't need getMCSymbolForFunction
771-
MCSymbol *FuncSym = cast<MCSymbolWasm>(getSymbol(&MF->getFunction()));
772-
const uint32_t LocalFuncIdx = MF->getFunctionNumber();
773-
if (BranchHints.size() <= LocalFuncIdx) {
774-
BranchHints.resize(LocalFuncIdx + 1);
775-
BranchHints[LocalFuncIdx].FuncSym = FuncSym;
776-
}
777-
BranchHints[LocalFuncIdx].Hints.emplace_back(BrIfSym, HintValue);
778-
}
840+
if (Subtarget->hasBranchHinting() &&
841+
MI->getOpcode() == WebAssembly::BR_IF) {
842+
// since we need to emit a label to later recover the instruction's
843+
// offset, this has to called before the instruction is emitted
844+
recordBranchHint(MI);
779845
}
780-
emit:
781846
EmitToStreamer(*OutStreamer, TmpInst);
782847
break;
783848
}

llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ class LLVM_LIBRARY_VISIBILITY WebAssemblyAsmPrinter final : public AsmPrinter {
9595
bool &InvokeDetected);
9696
MCSymbol *getOrCreateWasmSymbol(StringRef Name);
9797
void emitDecls(const Module &M);
98+
void recordBranchHint(const MachineInstr *MI);
9899
};
99100

100101
} // end namespace llvm

llvm/lib/Target/WebAssembly/WebAssemblyCFGStackify.cpp

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
#include "WebAssemblyUtilities.h"
3232
#include "llvm/ADT/Statistic.h"
3333
#include "llvm/BinaryFormat/Wasm.h"
34-
#include "llvm/CodeGen/MachineBranchProbabilityInfo.h"
3534
#include "llvm/CodeGen/MachineDominators.h"
3635
#include "llvm/CodeGen/MachineInstrBuilder.h"
3736
#include "llvm/CodeGen/MachineLoopInfo.h"
@@ -49,15 +48,13 @@ STATISTIC(NumCatchUnwindMismatches, "Number of catch unwind mismatches found");
4948
namespace {
5049
class WebAssemblyCFGStackify final : public MachineFunctionPass {
5150
MachineDominatorTree *MDT;
52-
MachineBranchProbabilityInfo *MBPI;
5351

5452
StringRef getPassName() const override { return "WebAssembly CFG Stackify"; }
5553

5654
void getAnalysisUsage(AnalysisUsage &AU) const override {
5755
AU.addRequired<MachineDominatorTreeWrapperPass>();
5856
AU.addRequired<MachineLoopInfoWrapperPass>();
5957
AU.addRequired<WebAssemblyExceptionInfo>();
60-
AU.addRequired<MachineBranchProbabilityInfoWrapperPass>();
6158
MachineFunctionPass::getAnalysisUsage(AU);
6259
}
6360

@@ -2611,21 +2608,6 @@ void WebAssemblyCFGStackify::rewriteDepthImmediates(MachineFunction &MF) {
26112608
Stack.push_back(std::make_pair(&MBB, &MI));
26122609
break;
26132610

2614-
case WebAssembly::BR_IF: {
2615-
// this is the last place where we can easily calculate the branch
2616-
// probabilities. We do not emit if-blocks, meaning only br_ifs have
2617-
// to be annotated with branch probabilities.
2618-
if (MF.getSubtarget<WebAssemblySubtarget>().hasBranchHinting() &&
2619-
MI.getParent()->hasSuccessorProbabilities()) {
2620-
const auto Prob = MBPI->getEdgeProbability(
2621-
MI.getParent(), MI.operands().begin()->getMBB());
2622-
WebAssemblyFunctionInfo *MFI = MF.getInfo<WebAssemblyFunctionInfo>();
2623-
assert(!MFI->BranchProbabilities.contains(&MI));
2624-
MFI->BranchProbabilities[&MI] = Prob;
2625-
}
2626-
RewriteOperands(MI);
2627-
break;
2628-
}
26292611
default:
26302612
if (MI.isTerminator())
26312613
RewriteOperands(MI);
@@ -2657,7 +2639,6 @@ bool WebAssemblyCFGStackify::runOnMachineFunction(MachineFunction &MF) {
26572639
<< MF.getName() << '\n');
26582640
const MCAsmInfo *MCAI = MF.getTarget().getMCAsmInfo();
26592641
MDT = &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
2660-
MBPI = &getAnalysis<MachineBranchProbabilityInfoWrapperPass>().getMBPI();
26612642

26622643
releaseMemory();
26632644

llvm/lib/Target/WebAssembly/WebAssemblyMachineFunctionInfo.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,6 @@ class WebAssemblyFunctionInfo final : public MachineFunctionInfo {
153153

154154
bool isCFGStackified() const { return CFGStackified; }
155155
void setCFGStackified(bool Value = true) { CFGStackified = Value; }
156-
157-
DenseMap<const MachineInstr *, BranchProbability> BranchProbabilities;
158156
};
159157

160158
void computeLegalValueVTs(const WebAssemblyTargetLowering &TLI,

0 commit comments

Comments
 (0)