Skip to content

Commit c3f7e2e

Browse files
authored
[Branch Hinting] Add branch hint handling in RemoveUnusedBrs (#7706)
Add some utilities for easily updating/copying/clearing branch hints, then use those in the pass. As a drive-by, move flip() from wasm-builder.h, which everyone was including but only this one cpp file was using, and update it in the new location (this avoids including the new branch hints header in a central place).
1 parent d5e7f18 commit c3f7e2e

File tree

5 files changed

+1129
-11
lines changed

5 files changed

+1129
-11
lines changed

src/ir/branch-hints.h

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
/*
2+
* Copyright 2025 WebAssembly Community Group participants
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#ifndef wasm_ir_branch_hint_h
18+
#define wasm_ir_branch_hint_h
19+
20+
#include "wasm.h"
21+
22+
//
23+
// Branch hint utilities to get them, set, flip, etc.
24+
//
25+
26+
namespace wasm::BranchHints {
27+
28+
// Get the branch hint for an expression.
29+
inline std::optional<bool> get(Expression* expr, Function* func) {
30+
auto iter = func->codeAnnotations.find(expr);
31+
if (iter == func->codeAnnotations.end()) {
32+
// No annotations at all.
33+
return {};
34+
}
35+
return iter->second.branchLikely;
36+
}
37+
38+
// Set the branch hint for an expression, trampling anything existing before.
39+
inline void set(Expression* expr, std::optional<bool> likely, Function* func) {
40+
// When we are writing an empty hint, do not create an empty annotation if one
41+
// did not exist.
42+
if (!likely && !func->codeAnnotations.count(expr)) {
43+
return;
44+
}
45+
func->codeAnnotations[expr].branchLikely = likely;
46+
}
47+
48+
// Clear the branch hint for an expression.
49+
inline void clear(Expression* expr, Function* func) {
50+
// Do not create an empty annotation if one did not exist.
51+
auto iter = func->codeAnnotations.find(expr);
52+
if (iter == func->codeAnnotations.end()) {
53+
return;
54+
}
55+
iter->second.branchLikely = {};
56+
}
57+
58+
// Copy the branch hint for an expression to another, trampling anything
59+
// existing before for the latter.
60+
inline void copyTo(Expression* from, Expression* to, Function* func) {
61+
auto fromLikely = get(from, func);
62+
set(to, fromLikely, func);
63+
}
64+
65+
// Flip the branch hint for an expression (if it exists).
66+
inline void flip(Expression* expr, Function* func) {
67+
if (auto likely = get(expr, func)) {
68+
set(expr, !*likely, func);
69+
}
70+
}
71+
72+
// Copy the branch hint for an expression to another, flipping it while we do
73+
// so.
74+
inline void copyFlippedTo(Expression* from, Expression* to, Function* func) {
75+
copyTo(from, to, func);
76+
flip(to, func);
77+
}
78+
79+
// Given two expressions to read from, apply the AND hint to a target. That is,
80+
// the target will be true when both inputs are true. |to| may be equal to
81+
// |from1| or |from2|. The hint of |to| is trampled.
82+
inline void applyAndTo(Expression* from1,
83+
Expression* from2,
84+
Expression* to,
85+
Function* func) {
86+
// If from1 and from2 are both likely, then from1 && from2 is slightly less
87+
// likely, but we assume our hints are nearly certain, so we apply it. And,
88+
// converse, if from1 and from2 and both unlikely, then from1 && from2 is even
89+
// less likely, so we can once more apply a hint. If the hints differ, then
90+
// one is unlikely or unknown, and we can't say anything about from1 && from2.
91+
auto from1Hint = BranchHints::get(from1, func);
92+
auto from2Hint = BranchHints::get(from2, func);
93+
if (from1Hint == from2Hint) {
94+
set(to, from1Hint, func);
95+
} else {
96+
// The hints do not even match.
97+
BranchHints::clear(to, func);
98+
}
99+
}
100+
101+
// As |applyAndTo|, but now the condition on |to| the OR of |from1| and |from2|.
102+
inline void applyOrTo(Expression* from1,
103+
Expression* from2,
104+
Expression* to,
105+
Function* func) {
106+
// If one is likely then so is the from1 || from2. If both are unlikely then
107+
// from1 || from2 is slightly more likely, but we assume our hints are nearly
108+
// certain, so we apply it.
109+
auto from1Hint = BranchHints::get(from1, func);
110+
auto from2Hint = BranchHints::get(from2, func);
111+
if ((from1Hint && *from1Hint) || (from2Hint && *from2Hint)) {
112+
set(to, true, func);
113+
} else if (from1Hint && from2Hint) {
114+
// We ruled out that either one is present and true, so if both are present,
115+
// both must be false.
116+
assert(!*from1Hint && !*from2Hint);
117+
set(to, false, func);
118+
} else {
119+
// We don't know.
120+
BranchHints::clear(to, func);
121+
}
122+
}
123+
124+
} // namespace wasm::BranchHints
125+
126+
#endif // wasm_ir_branch_hint_h

src/passes/RemoveUnusedBrs.cpp

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
// Removes branches for which we go to where they go anyhow
1919
//
2020

21+
#include "ir/branch-hints.h"
2122
#include "ir/branch-utils.h"
2223
#include "ir/cost.h"
2324
#include "ir/drop.h"
@@ -396,6 +397,7 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
396397
curr->condition, br->value, getPassOptions(), *getModule())) {
397398
if (!br->condition) {
398399
br->condition = curr->condition;
400+
BranchHints::copyTo(curr, br, getFunction());
399401
} else {
400402
// In this case we can replace
401403
// if (condition1) br_if (condition2)
@@ -427,6 +429,7 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
427429
// That keeps the order of the two conditions as it was originally.
428430
br->condition =
429431
builder.makeSelect(br->condition, curr->condition, zero);
432+
BranchHints::applyAndTo(curr, br, br, getFunction());
430433
}
431434
br->finalize();
432435
replaceCurrent(Builder(*getModule()).dropIfConcretelyTyped(br));
@@ -459,6 +462,7 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
459462
Builder builder(*getModule());
460463
curr->condition = builder.makeSelect(
461464
child->condition, curr->condition, builder.makeConst(int32_t(0)));
465+
BranchHints::applyAndTo(curr, child, curr, getFunction());
462466
curr->ifTrue = child->ifTrue;
463467
}
464468
}
@@ -689,6 +693,7 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
689693
brIf->condition = builder.makeUnary(EqZInt32, brIf->condition);
690694
last->name = brIf->name;
691695
brIf->name = loop->name;
696+
BranchHints::flip(brIf, getFunction());
692697
return true;
693698
} else {
694699
// there are elements in the middle,
@@ -709,6 +714,7 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
709714
builder.makeIf(brIf->condition,
710715
builder.makeBreak(brIf->name),
711716
stealSlice(builder, block, i + 1, list.size()));
717+
BranchHints::copyTo(brIf, list[i], getFunction());
712718
block->finalize();
713719
return true;
714720
}
@@ -1210,6 +1216,7 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
12101216
// we are an if-else where the ifTrue is a break without a
12111217
// condition, so we can do this
12121218
ifTrueBreak->condition = iff->condition;
1219+
BranchHints::copyTo(iff, ifTrueBreak, getFunction());
12131220
ifTrueBreak->finalize();
12141221
list[i] = Builder(*getModule()).dropIfConcretelyTyped(ifTrueBreak);
12151222
ExpressionManipulator::spliceIntoBlock(curr, i + 1, iff->ifFalse);
@@ -1224,6 +1231,7 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
12241231
*getModule())) {
12251232
ifFalseBreak->condition =
12261233
Builder(*getModule()).makeUnary(EqZInt32, iff->condition);
1234+
BranchHints::copyFlippedTo(iff, ifFalseBreak, getFunction());
12271235
ifFalseBreak->finalize();
12281236
list[i] = Builder(*getModule()).dropIfConcretelyTyped(ifFalseBreak);
12291237
ExpressionManipulator::spliceIntoBlock(curr, i + 1, iff->ifTrue);
@@ -1256,7 +1264,9 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
12561264
Builder builder(*getModule());
12571265
br1->condition =
12581266
builder.makeBinary(OrInt32, br1->condition, br2->condition);
1267+
BranchHints::applyOrTo(br1, br2, br1, getFunction());
12591268
ExpressionManipulator::nop(br2);
1269+
BranchHints::clear(br2, getFunction());
12601270
}
12611271
}
12621272
} else {
@@ -1396,9 +1406,12 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
13961406
// no other breaks to that name, so we can do this
13971407
if (!drop) {
13981408
assert(!br->value);
1399-
replaceCurrent(builder.makeIf(
1400-
builder.makeUnary(EqZInt32, br->condition), curr));
1409+
auto* iff = builder.makeIf(
1410+
builder.makeUnary(EqZInt32, br->condition), curr);
1411+
replaceCurrent(iff);
1412+
BranchHints::copyFlippedTo(br, iff, getFunction());
14011413
ExpressionManipulator::nop(br);
1414+
BranchHints::clear(br, getFunction());
14021415
curr->finalize(curr->type);
14031416
} else {
14041417
// To use an if, the value must have no side effects, as in the
@@ -1409,8 +1422,9 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
14091422
if (EffectAnalyzer::canReorder(
14101423
passOptions, *getModule(), br->condition, br->value)) {
14111424
ExpressionManipulator::nop(list[0]);
1412-
replaceCurrent(
1413-
builder.makeIf(br->condition, br->value, curr));
1425+
auto* iff = builder.makeIf(br->condition, br->value, curr);
1426+
BranchHints::copyTo(br, iff, getFunction());
1427+
replaceCurrent(iff);
14141428
}
14151429
} else {
14161430
// The value has side effects, so it must always execute. We
@@ -1529,6 +1543,14 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
15291543
optimizeSetIf(getCurrentPointer());
15301544
}
15311545

1546+
// Flip an if's condition with an eqz, and flip its arms.
1547+
void flip(If* iff) {
1548+
std::swap(iff->ifTrue, iff->ifFalse);
1549+
iff->condition =
1550+
Builder(*getModule()).makeUnary(EqZInt32, iff->condition);
1551+
BranchHints::flip(iff, getFunction());
1552+
}
1553+
15321554
void optimizeSetIf(Expression** currp) {
15331555
if (optimizeSetIfWithBrArm(currp)) {
15341556
return;
@@ -1570,9 +1592,10 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
15701592
// Wonderful, do it!
15711593
Builder builder(*getModule());
15721594
if (flipCondition) {
1573-
builder.flip(iff);
1595+
flip(iff);
15741596
}
15751597
br->condition = iff->condition;
1598+
BranchHints::copyTo(iff, br, getFunction());
15761599
br->finalize();
15771600
set->value = two;
15781601
auto* block = builder.makeSequence(br, set);
@@ -1640,7 +1663,7 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
16401663
Builder builder(*getModule());
16411664
LocalGet* get = iff->ifTrue->dynCast<LocalGet>();
16421665
if (get && get->index == set->index) {
1643-
builder.flip(iff);
1666+
flip(iff);
16441667
} else {
16451668
get = iff->ifFalse->dynCast<LocalGet>();
16461669
if (get && get->index != set->index) {
@@ -1901,6 +1924,7 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
19011924
curr->type = Type::unreachable;
19021925
block->list.push_back(curr);
19031926
block->finalize();
1927+
BranchHints::clear(curr, getFunction());
19041928
// The type changed, so refinalize.
19051929
refinalize = true;
19061930
} else {

src/wasm-builder.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1481,11 +1481,6 @@ class Builder {
14811481
return makeDrop(curr);
14821482
}
14831483

1484-
void flip(If* iff) {
1485-
std::swap(iff->ifTrue, iff->ifFalse);
1486-
iff->condition = makeUnary(EqZInt32, iff->condition);
1487-
}
1488-
14891484
// Returns a replacement with the precise same type, and with minimal contents
14901485
// as best we can. As a replacement, this may reuse the input node.
14911486
template<typename T> Expression* replaceWithIdenticalType(T* curr) {

0 commit comments

Comments
 (0)