Skip to content

Commit 97f4a1f

Browse files
jpages and alexcrichton authored
[wast] Branch hinting proposal implementation (#1394)
* wast: parsing and encoding branch hint annotations. * Implement parsing/printing function branch hints * Adding a test for branch hinting * Touch up some branch hint parsing * Remove unnecessary inclusion in `Custom` enum since this isn't parsed there. * Refactor parsing slightly. * Fix encoding to use function indices instead of function offsets. * Refactor Rust idioms in encoding and remove some unnecessary abstractions. --------- Co-authored-by: Alex Crichton <alex@alexcrichton.com>
1 parent 8dc75f0 commit 97f4a1f

File tree

17 files changed

+452
-22
lines changed

17 files changed

+452
-22
lines changed

crates/wasmparser/src/readers/core.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
mod branch_hinting;
12
mod code;
23
mod coredumps;
34
mod custom;
@@ -18,6 +19,7 @@ mod tables;
1819
mod tags;
1920
mod types;
2021

22+
pub use self::branch_hinting::*;
2123
pub use self::code::*;
2224
pub use self::coredumps::*;
2325
pub use self::custom::*;
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
use crate::{BinaryReader, FromReader, Result, SectionLimited};
2+
3+
/// A reader for the `metadata.code.branch_hint` custom section.
4+
pub type BranchHintSectionReader<'a> = SectionLimited<'a, BranchHintFunction<'a>>;
5+
6+
/// Branch hints for a single function.
7+
///
8+
/// Produced from [`BranchHintSectionReader`].
9+
#[derive(Debug, Clone)]
10+
pub struct BranchHintFunction<'a> {
11+
/// The function that these branch hints apply to.
12+
pub func: u32,
13+
/// The branch hints available for this function.
14+
pub hints: SectionLimited<'a, BranchHint>,
15+
}
16+
17+
impl<'a> FromReader<'a> for BranchHintFunction<'a> {
18+
fn from_reader(reader: &mut BinaryReader<'a>) -> Result<Self> {
19+
let func = reader.read_var_u32()?;
20+
// FIXME(#188) ideally wouldn't have to do skips here
21+
let hints = reader.skip(|reader| {
22+
let items_count = reader.read_var_u32()?;
23+
for _ in 0..items_count {
24+
reader.read::<BranchHint>()?;
25+
}
26+
Ok(())
27+
})?;
28+
Ok(BranchHintFunction {
29+
func,
30+
hints: SectionLimited::new(hints.remaining_buffer(), hints.original_position())?,
31+
})
32+
}
33+
}
34+
35+
/// A hint for a single branch.
///
/// This is plain data (`u32` + `bool`), so in addition to being `Copy` it can
/// be compared for equality and used as a key in hashed collections.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct BranchHint {
    /// The byte offset, from the start of the function's body, of where the
    /// hinted instruction lives.
    pub func_offset: u32,
    /// Whether the branch is hinted to be taken (`true`) or not taken
    /// (`false`).
    pub taken: bool,
}
44+
45+
impl<'a> FromReader<'a> for BranchHint {
46+
fn from_reader(reader: &mut BinaryReader<'a>) -> Result<Self> {
47+
let func_offset = reader.read_var_u32()?;
48+
match reader.read_u8()? {
49+
1 => {}
50+
n => reader.invalid_leading_byte(n, "invalid branch hint byte")?,
51+
}
52+
let taken = match reader.read_u8()? {
53+
0 => false,
54+
1 => true,
55+
n => reader.invalid_leading_byte(n, "invalid branch hint taken byte")?,
56+
};
57+
Ok(BranchHint { func_offset, taken })
58+
}
59+
}

crates/wasmprinter/src/lib.rs

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use wasmparser::*;
1818
const MAX_LOCALS: u32 = 50000;
1919
const MAX_NESTING_TO_PRINT: u32 = 50;
2020
const MAX_WASM_FUNCTIONS: u32 = 1_000_000;
21+
const MAX_WASM_FUNCTION_SIZE: u32 = 128 * 1024;
2122

2223
mod operator;
2324

@@ -52,6 +53,7 @@ pub struct Printer {
5253
nesting: u32,
5354
line: usize,
5455
group_lines: Vec<usize>,
56+
code_section_hints: Vec<(u32, Vec<(usize, BranchHint)>)>,
5557
}
5658

5759
#[derive(Default)]
@@ -242,6 +244,15 @@ impl Printer {
242244
drop(self.register_component_names(state, reader));
243245
}
244246

247+
Payload::CustomSection(c) if c.name() == "metadata.code.branch_hint" => {
248+
drop(
249+
self.register_branch_hint_section(BranchHintSectionReader::new(
250+
c.data(),
251+
c.data_offset(),
252+
)?),
253+
);
254+
}
255+
245256
Payload::End(_) => break,
246257
_ => {}
247258
}
@@ -1109,10 +1120,20 @@ impl Printer {
11091120
.print_core_functype_idx(state, ty, Some(func_idx))?
11101121
.unwrap_or(0);
11111122

1123+
// Hints are stored on `self` in reverse order of function index so
1124+
// check the last one and see if it matches this function.
1125+
let hints = match self.code_section_hints.last() {
1126+
Some((f, _)) if *f == func_idx => {
1127+
let (_, hints) = self.code_section_hints.pop().unwrap();
1128+
hints
1129+
}
1130+
_ => Vec::new(),
1131+
};
1132+
11121133
if self.print_skeleton {
11131134
self.result.push_str(" ...");
11141135
} else {
1115-
self.print_func_body(state, func_idx, params, &mut body)?;
1136+
self.print_func_body(state, func_idx, params, &mut body, &hints)?;
11161137
}
11171138

11181139
self.end_group();
@@ -1128,10 +1149,12 @@ impl Printer {
11281149
func_idx: u32,
11291150
params: u32,
11301151
body: &mut BinaryReader<'_>,
1152+
mut branch_hints: &[(usize, BranchHint)],
11311153
) -> Result<()> {
11321154
let mut first = true;
11331155
let mut local_idx = 0;
11341156
let mut locals = NamedLocalPrinter::new("local");
1157+
let func_start = body.original_position();
11351158
for _ in 0..body.read_var_u32()? {
11361159
let offset = body.original_position();
11371160
let cnt = body.read_var_u32()?;
@@ -1163,7 +1186,21 @@ impl Printer {
11631186
let mut buf = String::new();
11641187
let mut op_printer = operator::PrintOperator::new(self, state);
11651188
while !body.eof() {
1166-
// TODO
1189+
// Branch hints are stored in increasing order of their body offset
1190+
// so print them whenever their instruction comes up.
1191+
if let Some(((hint_offset, hint), rest)) = branch_hints.split_first() {
1192+
if hint.func_offset == (body.original_position() - func_start) as u32 {
1193+
branch_hints = rest;
1194+
op_printer.printer.newline(*hint_offset);
1195+
let desc = if hint.taken { "\"\\01\"" } else { "\"\\00\"" };
1196+
write!(
1197+
op_printer.printer.result,
1198+
"(@metadata.code.branch_hint {})",
1199+
desc
1200+
)?;
1201+
}
1202+
}
1203+
11671204
let offset = body.original_position();
11681205
mem::swap(&mut buf, &mut op_printer.printer.result);
11691206
let op_kind = body.visit_operator(&mut op_printer)??;
@@ -2688,6 +2725,26 @@ impl Printer {
26882725
}
26892726
Ok(())
26902727
}
2728+
2729+
fn register_branch_hint_section(&mut self, section: BranchHintSectionReader<'_>) -> Result<()> {
2730+
self.code_section_hints.clear();
2731+
for func in section {
2732+
let func = func?;
2733+
if self.code_section_hints.len() >= MAX_WASM_FUNCTIONS as usize {
2734+
bail!("found too many hints");
2735+
}
2736+
if func.hints.count() >= MAX_WASM_FUNCTION_SIZE {
2737+
bail!("found too many hints");
2738+
}
2739+
let hints = func
2740+
.hints
2741+
.into_iter_with_offsets()
2742+
.collect::<wasmparser::Result<Vec<_>>>()?;
2743+
self.code_section_hints.push((func.func, hints));
2744+
}
2745+
self.code_section_hints.reverse();
2746+
Ok(())
2747+
}
26912748
}
26922749

26932750
struct NamedLocalPrinter {

crates/wast/src/component/component.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ impl<'a> Parse<'a> for Component<'a> {
111111
let _r = parser.register_annotation("custom");
112112
let _r = parser.register_annotation("producers");
113113
let _r = parser.register_annotation("name");
114+
let _r = parser.register_annotation("metadata.code.branch_hint");
114115

115116
let span = parser.parse::<kw::component>()?.0;
116117
let id = parser.parse()?;

0 commit comments

Comments
 (0)