From 71e1a8b613ab6017686f3bbee5371cdd1b4c15fa Mon Sep 17 00:00:00 2001 From: David Desmarais-Michaud Date: Fri, 21 Feb 2025 22:56:12 -0500 Subject: [PATCH 1/3] wazevo: concurrent local wasm function compilation Signed-off-by: David Desmarais-Michaud --- internal/engine/wazevo/engine.go | 112 ++++++++++++++++++++++++------- 1 file changed, 88 insertions(+), 24 deletions(-) diff --git a/internal/engine/wazevo/engine.go b/internal/engine/wazevo/engine.go index a6df3e7e79..a6ffbeb96f 100644 --- a/internal/engine/wazevo/engine.go +++ b/internal/engine/wazevo/engine.go @@ -209,7 +209,6 @@ func (e *engine) compileModule(ctx context.Context, module *wasm.Module, listene // Creates new compiler instances which are reused for each function. ssaBuilder := ssa.NewBuilder() - fe := frontend.NewFrontendCompiler(module, ssaBuilder, &cm.offsets, ensureTermination, withListener, needSourceInfo) machine := newMachine() be := backend.NewCompiler(ctx, machine, ssaBuilder) @@ -227,27 +226,81 @@ func (e *engine) compileModule(ctx context.Context, module *wasm.Module, listene needCallTrampoline := callTrampolineIslandSize > 0 var callTrampolineIslandOffsets []int // Holds the offsets of trampoline islands. - for i := range module.CodeSection { - if wazevoapi.DeterministicCompilationVerifierEnabled { - i = wazevoapi.DeterministicCompilationVerifierGetRandomizedLocalFunctionIndex(ctx, i) - } + type CompiledLocalFuncResult struct { + Body []byte + RelsPerFunc []backend.RelocationInfo + IDX wasm.Index + SourceOffsetInfo []backend.SourceOffsetInfo + } + + compiledFuncs := make([]CompiledLocalFuncResult, len(module.CodeSection)) + + workers := runtime.GOMAXPROCS(0) + + wg := sync.WaitGroup{} + wg.Add(workers) + + ctx, cancel := context.WithCancelCause(ctx) + defer cancel(nil) + + sections := sequence(len(module.CodeSection)) + + for range workers { + go func() { + defer wg.Done() + + ssaBuilder := ssa.NewBuilder() + machine := newMachine() + fe := frontend.NewFrontendCompiler(module, ssaBuilder, &cm.offsets, ensureTermination, withListener, needSourceInfo) + + for i := range sections { + if err := ctx.Err(); err != nil { + // Compilation canceled! + return + } + + if wazevoapi.DeterministicCompilationVerifierEnabled { + i = wazevoapi.DeterministicCompilationVerifierGetRandomizedLocalFunctionIndex(ctx, i) + } + + fidx := wasm.Index(i + importedFns) + + if wazevoapi.NeedFunctionNameInContext { + def := module.FunctionDefinition(fidx) + name := def.DebugName() + if len(def.ExportNames()) > 0 { + name = def.ExportNames()[0] + } + ctx = wazevoapi.SetCurrentFunctionName(ctx, i, fmt.Sprintf("[%d/%d]%s", i, len(module.CodeSection)-1, name)) + } - fidx := wasm.Index(i + importedFns) + be := backend.NewCompiler(ctx, machine, ssaBuilder) - if wazevoapi.NeedFunctionNameInContext { - def := module.FunctionDefinition(fidx) - name := def.DebugName() - if len(def.ExportNames()) > 0 { - name = def.ExportNames()[0] + needListener := len(listeners) > 0 && listeners[i] != nil + + body, relsPerFunc, err := e.compileLocalWasmFunction(ctx, module, wasm.Index(i), fe, ssaBuilder, be, needListener) + if err != nil { + cancel(fmt.Errorf("compile function %d/%d: %v", i, len(module.CodeSection)-1, err)) + return + } + compiledFuncs[i] = CompiledLocalFuncResult{ + Body: body, + RelsPerFunc: relsPerFunc, + IDX: fidx, + SourceOffsetInfo: be.SourceOffsetInfo(), + } } - ctx = wazevoapi.SetCurrentFunctionName(ctx, i, fmt.Sprintf("[%d/%d]%s", i, len(module.CodeSection)-1, name)) - } + }() + } - needListener := len(listeners) > 0 && listeners[i] != nil - body, relsPerFunc, err := e.compileLocalWasmFunction(ctx, module, wasm.Index(i), fe, ssaBuilder, be, needListener) - if err != nil { - return nil, fmt.Errorf("compile function %d/%d: %v", i, len(module.CodeSection)-1, err) - } + wg.Wait() + + if err := context.Cause(ctx); err != nil { + return nil, err + } + + for i := range module.CodeSection { + fn := compiledFuncs[i] // Align 16-bytes boundary. totalSize = (totalSize + 15) &^ 15 @@ -259,26 +312,26 @@ func (e *engine) compileModule(ctx context.Context, module *wasm.Module, listene cm.sourceMap.executableOffsets = append(cm.sourceMap.executableOffsets, uintptr(totalSize)) cm.sourceMap.wasmBinaryOffsets = append(cm.sourceMap.wasmBinaryOffsets, module.CodeSection[i].BodyOffsetInCodeSection) - for _, info := range be.SourceOffsetInfo() { + for _, info := range fn.SourceOffsetInfo { cm.sourceMap.executableOffsets = append(cm.sourceMap.executableOffsets, uintptr(totalSize)+uintptr(info.ExecutableOffset)) cm.sourceMap.wasmBinaryOffsets = append(cm.sourceMap.wasmBinaryOffsets, uint64(info.SourceOffset)) } } - fref := frontend.FunctionIndexToFuncRef(fidx) + fref := frontend.FunctionIndexToFuncRef(fn.IDX) refToBinaryOffset[fref] = totalSize // At this point, relocation offsets are relative to the start of the function body, // so we adjust it to the start of the executable. - for _, r := range relsPerFunc { + for _, r := range fn.RelsPerFunc { r.Offset += int64(totalSize) rels = append(rels, r) } - bodies[i] = body - totalSize += len(body) + bodies[i] = fn.Body + totalSize += len(fn.Body) if wazevoapi.PrintMachineCodeHexPerFunction { - fmt.Printf("[[[machine code for %s]]]\n%s\n\n", wazevoapi.GetCurrentFunctionName(ctx), hex.EncodeToString(body)) + fmt.Printf("[[[machine code for %s]]]\n%s\n\n", wazevoapi.GetCurrentFunctionName(ctx), hex.EncodeToString(fn.Body)) } if needCallTrampoline { @@ -841,3 +894,14 @@ func (cm *compiledModule) getSourceOffset(pc uintptr) uint64 { } return cm.sourceMap.wasmBinaryOffsets[index] } + +func sequence(size int) <-chan int { + result := make(chan int) + go func() { + for i := range size { + result <- i + } + close(result) + }() + return result +} From 1c3b73f039c85e73ba87009951d0364b6d136df3 Mon Sep 17 00:00:00 2001 From: James Lawrence Date: Thu, 3 Apr 2025 11:34:35 -0400 Subject: [PATCH 2/3] deterministically shuffle the functions to be compiled based on a seed. --- internal/engine/wazevo/engine.go | 57 ++++++++++++++++++++++++++------ 1 file changed, 46 insertions(+), 11 deletions(-) diff --git a/internal/engine/wazevo/engine.go b/internal/engine/wazevo/engine.go index a6ffbeb96f..90766ae766 100644 --- a/internal/engine/wazevo/engine.go +++ b/internal/engine/wazevo/engine.go @@ -2,12 +2,16 @@ package wazevo import ( "context" + "crypto/md5" + "encoding/binary" "encoding/hex" "errors" "fmt" + "math/rand/v2" "runtime" "sort" "sync" + "sync/atomic" "unsafe" "github.com/tetratelabs/wazero/api" @@ -243,7 +247,9 @@ func (e *engine) compileModule(ctx context.Context, module *wasm.Module, listene ctx, cancel := context.WithCancelCause(ctx) defer cancel(nil) - sections := sequence(len(module.CodeSection)) + // constant seed for demonstration purposes. + seq := sequence(0, module.CodeSection) + resultmutex := &sync.Mutex{} for range workers { go func() { @@ -253,7 +259,7 @@ func (e *engine) compileModule(ctx context.Context, module *wasm.Module, listene machine := newMachine() fe := frontend.NewFrontendCompiler(module, ssaBuilder, &cm.offsets, ensureTermination, withListener, needSourceInfo) - for i := range sections { + for i := seq.Pop(); i < len(module.CodeSection); i = seq.Pop() { if err := ctx.Err(); err != nil { // Compilation canceled! return @@ -283,12 +289,15 @@ func (e *engine) compileModule(ctx context.Context, module *wasm.Module, listene cancel(fmt.Errorf("compile function %d/%d: %v", i, len(module.CodeSection)-1, err)) return } + + resultmutex.Lock() compiledFuncs[i] = CompiledLocalFuncResult{ Body: body, RelsPerFunc: relsPerFunc, IDX: fidx, SourceOffsetInfo: be.SourceOffsetInfo(), } + resultmutex.Unlock() } }() } @@ -895,13 +904,39 @@ func (cm *compiledModule) getSourceOffset(pc uintptr) uint64 { return cm.sourceMap.wasmBinaryOffsets[index] } -func sequence(size int) <-chan int { - result := make(chan int) - go func() { - for i := range size { - result <- i - } - close(result) - }() - return result +type seq struct { + current int64 +} + +func (t *seq) Pop() int { + return int(atomic.AddInt64(&t.current, 1)) +} + +func chaCha8[T ~[]byte | string](seed T) *rand.ChaCha8 { + var ( + vector [32]byte + source = []byte(seed) + ) + + v1 := md5.Sum(source) + v2 := md5.Sum(append(v1[:], source...)) + copy(vector[:15], v1[:]) + copy(vector[16:], v2[:]) + + return rand.NewChaCha8(vector) +} + +func sequence(seed uint64, src []wasm.Code) *seq { + uint64Bytes := func(v uint64) []byte { + var b [8]byte + binary.NativeEndian.PutUint64(b[:], v) + return b[:] + } + prng := rand.New(chaCha8(uint64Bytes(seed))) + + prng.Shuffle(len(src), func(i, j int) { + src[i], src[j] = src[j], src[i] + }) + + return &seq{current: -1} } From 44337ac1cee344096bf591626627221ae4173e11 Mon Sep 17 00:00:00 2001 From: James Lawrence Date: Thu, 3 Apr 2025 12:44:46 -0400 Subject: [PATCH 3/3] generate permutation order without changing actual order of the code sections --- internal/engine/wazevo/engine.go | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/internal/engine/wazevo/engine.go b/internal/engine/wazevo/engine.go index 90766ae766..8a533dc381 100644 --- a/internal/engine/wazevo/engine.go +++ b/internal/engine/wazevo/engine.go @@ -248,7 +248,7 @@ func (e *engine) compileModule(ctx context.Context, module *wasm.Module, listene defer cancel(nil) // constant seed for demonstration purposes. - seq := sequence(0, module.CodeSection) + seq, perm := sequence(0, len(module.CodeSection)) resultmutex := &sync.Mutex{} for range workers { @@ -259,12 +259,12 @@ func (e *engine) compileModule(ctx context.Context, module *wasm.Module, listene machine := newMachine() fe := frontend.NewFrontendCompiler(module, ssaBuilder, &cm.offsets, ensureTermination, withListener, needSourceInfo) - for i := seq.Pop(); i < len(module.CodeSection); i = seq.Pop() { + for ix := seq.Pop(); ix < len(module.CodeSection); ix = seq.Pop() { if err := ctx.Err(); err != nil { // Compilation canceled! return } - + i := perm[ix] if wazevoapi.DeterministicCompilationVerifierEnabled { i = wazevoapi.DeterministicCompilationVerifierGetRandomizedLocalFunctionIndex(ctx, i) } @@ -926,17 +926,23 @@ func chaCha8[T ~[]byte | string](seed T) *rand.ChaCha8 { return rand.NewChaCha8(vector) } -func sequence(seed uint64, src []wasm.Code) *seq { +func sequence(seed uint64, length int) (*seq, []int) { uint64Bytes := func(v uint64) []byte { var b [8]byte binary.NativeEndian.PutUint64(b[:], v) return b[:] } + + perm := make([]int, 0, length) + for idx := range length { + perm = append(perm, idx) + } + prng := rand.New(chaCha8(uint64Bytes(seed))) - prng.Shuffle(len(src), func(i, j int) { - src[i], src[j] = src[j], src[i] + prng.Shuffle(len(perm), func(i, j int) { + perm[i], perm[j] = perm[j], perm[i] }) - return &seq{current: -1} + return &seq{current: -1}, perm }