Skip to content

James.lawrence/deterministic randomized compilation #2394

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 130 additions & 25 deletions internal/engine/wazevo/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@ package wazevo

import (
"context"
"crypto/md5"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"math/rand/v2"
"runtime"
"sort"
"sync"
"sync/atomic"
"unsafe"

"github.com/tetratelabs/wazero/api"
Expand Down Expand Up @@ -209,7 +213,6 @@ func (e *engine) compileModule(ctx context.Context, module *wasm.Module, listene

// Creates new compiler instances which are reused for each function.
ssaBuilder := ssa.NewBuilder()
fe := frontend.NewFrontendCompiler(module, ssaBuilder, &cm.offsets, ensureTermination, withListener, needSourceInfo)
machine := newMachine()
be := backend.NewCompiler(ctx, machine, ssaBuilder)

Expand All @@ -227,27 +230,86 @@ func (e *engine) compileModule(ctx context.Context, module *wasm.Module, listene
needCallTrampoline := callTrampolineIslandSize > 0
var callTrampolineIslandOffsets []int // Holds the offsets of trampoline islands.

for i := range module.CodeSection {
if wazevoapi.DeterministicCompilationVerifierEnabled {
i = wazevoapi.DeterministicCompilationVerifierGetRandomizedLocalFunctionIndex(ctx, i)
}
type CompiledLocalFuncResult struct {
Body []byte
RelsPerFunc []backend.RelocationInfo
IDX wasm.Index
SourceOffsetInfo []backend.SourceOffsetInfo
}

compiledFuncs := make([]CompiledLocalFuncResult, len(module.CodeSection))

workers := runtime.GOMAXPROCS(0)

wg := sync.WaitGroup{}
wg.Add(workers)

ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)

// constant seed for demonstration purposes.
seq, perm := sequence(0, len(module.CodeSection))
resultmutex := &sync.Mutex{}

for range workers {
go func() {
defer wg.Done()

ssaBuilder := ssa.NewBuilder()
machine := newMachine()
fe := frontend.NewFrontendCompiler(module, ssaBuilder, &cm.offsets, ensureTermination, withListener, needSourceInfo)

for ix := seq.Pop(); ix < len(module.CodeSection); ix = seq.Pop() {
if err := ctx.Err(); err != nil {
// Compilation canceled!
return
}
i := perm[ix]
if wazevoapi.DeterministicCompilationVerifierEnabled {
i = wazevoapi.DeterministicCompilationVerifierGetRandomizedLocalFunctionIndex(ctx, i)
}

fidx := wasm.Index(i + importedFns)

if wazevoapi.NeedFunctionNameInContext {
def := module.FunctionDefinition(fidx)
name := def.DebugName()
if len(def.ExportNames()) > 0 {
name = def.ExportNames()[0]
}
ctx = wazevoapi.SetCurrentFunctionName(ctx, i, fmt.Sprintf("[%d/%d]%s", i, len(module.CodeSection)-1, name))
}

be := backend.NewCompiler(ctx, machine, ssaBuilder)

needListener := len(listeners) > 0 && listeners[i] != nil

body, relsPerFunc, err := e.compileLocalWasmFunction(ctx, module, wasm.Index(i), fe, ssaBuilder, be, needListener)
if err != nil {
cancel(fmt.Errorf("compile function %d/%d: %v", i, len(module.CodeSection)-1, err))
return
}

resultmutex.Lock()
compiledFuncs[i] = CompiledLocalFuncResult{
Body: body,
RelsPerFunc: relsPerFunc,
IDX: fidx,
SourceOffsetInfo: be.SourceOffsetInfo(),
}
resultmutex.Unlock()
}
}()
}

fidx := wasm.Index(i + importedFns)
wg.Wait()

if wazevoapi.NeedFunctionNameInContext {
def := module.FunctionDefinition(fidx)
name := def.DebugName()
if len(def.ExportNames()) > 0 {
name = def.ExportNames()[0]
}
ctx = wazevoapi.SetCurrentFunctionName(ctx, i, fmt.Sprintf("[%d/%d]%s", i, len(module.CodeSection)-1, name))
}
if err := context.Cause(ctx); err != nil {
return nil, err
}

needListener := len(listeners) > 0 && listeners[i] != nil
body, relsPerFunc, err := e.compileLocalWasmFunction(ctx, module, wasm.Index(i), fe, ssaBuilder, be, needListener)
if err != nil {
return nil, fmt.Errorf("compile function %d/%d: %v", i, len(module.CodeSection)-1, err)
}
for i := range module.CodeSection {
fn := compiledFuncs[i]

// Align 16-bytes boundary.
totalSize = (totalSize + 15) &^ 15
Expand All @@ -259,26 +321,26 @@ func (e *engine) compileModule(ctx context.Context, module *wasm.Module, listene
cm.sourceMap.executableOffsets = append(cm.sourceMap.executableOffsets, uintptr(totalSize))
cm.sourceMap.wasmBinaryOffsets = append(cm.sourceMap.wasmBinaryOffsets, module.CodeSection[i].BodyOffsetInCodeSection)

for _, info := range be.SourceOffsetInfo() {
for _, info := range fn.SourceOffsetInfo {
cm.sourceMap.executableOffsets = append(cm.sourceMap.executableOffsets, uintptr(totalSize)+uintptr(info.ExecutableOffset))
cm.sourceMap.wasmBinaryOffsets = append(cm.sourceMap.wasmBinaryOffsets, uint64(info.SourceOffset))
}
}

fref := frontend.FunctionIndexToFuncRef(fidx)
fref := frontend.FunctionIndexToFuncRef(fn.IDX)
refToBinaryOffset[fref] = totalSize

// At this point, relocation offsets are relative to the start of the function body,
// so we adjust it to the start of the executable.
for _, r := range relsPerFunc {
for _, r := range fn.RelsPerFunc {
r.Offset += int64(totalSize)
rels = append(rels, r)
}

bodies[i] = body
totalSize += len(body)
bodies[i] = fn.Body
totalSize += len(fn.Body)
if wazevoapi.PrintMachineCodeHexPerFunction {
fmt.Printf("[[[machine code for %s]]]\n%s\n\n", wazevoapi.GetCurrentFunctionName(ctx), hex.EncodeToString(body))
fmt.Printf("[[[machine code for %s]]]\n%s\n\n", wazevoapi.GetCurrentFunctionName(ctx), hex.EncodeToString(fn.Body))
}

if needCallTrampoline {
Expand Down Expand Up @@ -841,3 +903,46 @@ func (cm *compiledModule) getSourceOffset(pc uintptr) uint64 {
}
return cm.sourceMap.wasmBinaryOffsets[index]
}

// seq is a counter that hands out consecutive integers to concurrent callers.
type seq struct {
	// current is the most recently returned value; Pop advances it atomically.
	current int64
}

// Pop atomically increments the counter by one and returns the new value.
// It never signals exhaustion itself; callers compare the result against
// their own upper bound.
func (s *seq) Pop() int {
	next := atomic.AddInt64(&s.current, 1)
	return int(next)
}

func chaCha8[T ~[]byte | string](seed T) *rand.ChaCha8 {
var (
vector [32]byte
source = []byte(seed)
)

v1 := md5.Sum(source)
v2 := md5.Sum(append(v1[:], source...))
copy(vector[:15], v1[:])
copy(vector[16:], v2[:])

return rand.NewChaCha8(vector)
}

// sequence returns a fresh work counter together with a pseudo-random
// permutation of the indices [0, length) derived from seed.
//
// The same seed and length always produce the same permutation. The seed is
// serialized with a fixed (little-endian) byte order so the result does not
// depend on the byte order of the host.
func sequence(seed uint64, length int) (*seq, []int) {
	// A fixed byte order keeps the derived key identical across architectures.
	var seedBytes [8]byte
	binary.LittleEndian.PutUint64(seedBytes[:], seed)

	// Start from the identity permutation 0, 1, ..., length-1.
	perm := make([]int, length)
	for idx := range perm {
		perm[idx] = idx
	}

	prng := rand.New(chaCha8(seedBytes[:]))
	prng.Shuffle(len(perm), func(i, j int) {
		perm[i], perm[j] = perm[j], perm[i]
	})

	// The counter starts at -1 so that the first Pop yields 0.
	return &seq{current: -1}, perm
}