diff --git a/internal/engine/wazevo/engine.go b/internal/engine/wazevo/engine.go index a6df3e7e79..8a533dc381 100644 --- a/internal/engine/wazevo/engine.go +++ b/internal/engine/wazevo/engine.go @@ -2,12 +2,16 @@ package wazevo import ( "context" + "crypto/md5" + "encoding/binary" "encoding/hex" "errors" "fmt" + "math/rand/v2" "runtime" "sort" "sync" + "sync/atomic" "unsafe" "github.com/tetratelabs/wazero/api" @@ -209,7 +213,6 @@ func (e *engine) compileModule(ctx context.Context, module *wasm.Module, listene // Creates new compiler instances which are reused for each function. ssaBuilder := ssa.NewBuilder() - fe := frontend.NewFrontendCompiler(module, ssaBuilder, &cm.offsets, ensureTermination, withListener, needSourceInfo) machine := newMachine() be := backend.NewCompiler(ctx, machine, ssaBuilder) @@ -227,27 +230,86 @@ func (e *engine) compileModule(ctx context.Context, module *wasm.Module, listene needCallTrampoline := callTrampolineIslandSize > 0 var callTrampolineIslandOffsets []int // Holds the offsets of trampoline islands. - for i := range module.CodeSection { - if wazevoapi.DeterministicCompilationVerifierEnabled { - i = wazevoapi.DeterministicCompilationVerifierGetRandomizedLocalFunctionIndex(ctx, i) - } + type CompiledLocalFuncResult struct { + Body []byte + RelsPerFunc []backend.RelocationInfo + IDX wasm.Index + SourceOffsetInfo []backend.SourceOffsetInfo + } + + compiledFuncs := make([]CompiledLocalFuncResult, len(module.CodeSection)) + + workers := runtime.GOMAXPROCS(0) + + wg := sync.WaitGroup{} + wg.Add(workers) + + ctx, cancel := context.WithCancelCause(ctx) + defer cancel(nil) + + // constant seed for demonstration purposes. + seq, perm := sequence(0, len(module.CodeSection)) + resultmutex := &sync.Mutex{} + + for range workers { + go func() { + defer wg.Done() + + ssaBuilder := ssa.NewBuilder() + machine := newMachine() + fe := frontend.NewFrontendCompiler(module, ssaBuilder, &cm.offsets, ensureTermination, withListener, needSourceInfo) + + for ix := seq.Pop(); ix < len(module.CodeSection); ix = seq.Pop() { + if err := ctx.Err(); err != nil { + // Compilation canceled! + return + } + i := perm[ix] + if wazevoapi.DeterministicCompilationVerifierEnabled { + i = wazevoapi.DeterministicCompilationVerifierGetRandomizedLocalFunctionIndex(ctx, i) + } + + fidx := wasm.Index(i + importedFns) + + if wazevoapi.NeedFunctionNameInContext { + def := module.FunctionDefinition(fidx) + name := def.DebugName() + if len(def.ExportNames()) > 0 { + name = def.ExportNames()[0] + } + ctx = wazevoapi.SetCurrentFunctionName(ctx, i, fmt.Sprintf("[%d/%d]%s", i, len(module.CodeSection)-1, name)) + } + + be := backend.NewCompiler(ctx, machine, ssaBuilder) + + needListener := len(listeners) > 0 && listeners[i] != nil + + body, relsPerFunc, err := e.compileLocalWasmFunction(ctx, module, wasm.Index(i), fe, ssaBuilder, be, needListener) + if err != nil { + cancel(fmt.Errorf("compile function %d/%d: %v", i, len(module.CodeSection)-1, err)) + return + } + + resultmutex.Lock() + compiledFuncs[i] = CompiledLocalFuncResult{ + Body: body, + RelsPerFunc: relsPerFunc, + IDX: fidx, + SourceOffsetInfo: be.SourceOffsetInfo(), + } + resultmutex.Unlock() + } + }() + } - fidx := wasm.Index(i + importedFns) + wg.Wait() - if wazevoapi.NeedFunctionNameInContext { - def := module.FunctionDefinition(fidx) - name := def.DebugName() - if len(def.ExportNames()) > 0 { - name = def.ExportNames()[0] - } - ctx = wazevoapi.SetCurrentFunctionName(ctx, i, fmt.Sprintf("[%d/%d]%s", i, len(module.CodeSection)-1, name)) - } + if err := context.Cause(ctx); err != nil { + return nil, err + } - needListener := len(listeners) > 0 && listeners[i] != nil - body, relsPerFunc, err := e.compileLocalWasmFunction(ctx, module, wasm.Index(i), fe, ssaBuilder, be, needListener) - if err != nil { - return nil, fmt.Errorf("compile function %d/%d: %v", i, len(module.CodeSection)-1, err) - } + for i := range module.CodeSection { + fn := compiledFuncs[i] // Align 16-bytes boundary. totalSize = (totalSize + 15) &^ 15 @@ -259,26 +321,26 @@ func (e *engine) compileModule(ctx context.Context, module *wasm.Module, listene cm.sourceMap.executableOffsets = append(cm.sourceMap.executableOffsets, uintptr(totalSize)) cm.sourceMap.wasmBinaryOffsets = append(cm.sourceMap.wasmBinaryOffsets, module.CodeSection[i].BodyOffsetInCodeSection) - for _, info := range be.SourceOffsetInfo() { + for _, info := range fn.SourceOffsetInfo { cm.sourceMap.executableOffsets = append(cm.sourceMap.executableOffsets, uintptr(totalSize)+uintptr(info.ExecutableOffset)) cm.sourceMap.wasmBinaryOffsets = append(cm.sourceMap.wasmBinaryOffsets, uint64(info.SourceOffset)) } } - fref := frontend.FunctionIndexToFuncRef(fidx) + fref := frontend.FunctionIndexToFuncRef(fn.IDX) refToBinaryOffset[fref] = totalSize // At this point, relocation offsets are relative to the start of the function body, // so we adjust it to the start of the executable. - for _, r := range relsPerFunc { + for _, r := range fn.RelsPerFunc { r.Offset += int64(totalSize) rels = append(rels, r) } - bodies[i] = body - totalSize += len(body) + bodies[i] = fn.Body + totalSize += len(fn.Body) if wazevoapi.PrintMachineCodeHexPerFunction { - fmt.Printf("[[[machine code for %s]]]\n%s\n\n", wazevoapi.GetCurrentFunctionName(ctx), hex.EncodeToString(body)) + fmt.Printf("[[[machine code for %s]]]\n%s\n\n", wazevoapi.GetCurrentFunctionName(ctx), hex.EncodeToString(fn.Body)) } if needCallTrampoline { @@ -841,3 +903,46 @@ func (cm *compiledModule) getSourceOffset(pc uintptr) uint64 { } return cm.sourceMap.wasmBinaryOffsets[index] } + +type seq struct { + current int64 +} + +func (t *seq) Pop() int { + return int(atomic.AddInt64(&t.current, 1)) +} + +func chaCha8[T ~[]byte | string](seed T) *rand.ChaCha8 { + var ( + vector [32]byte + source = []byte(seed) + ) + + v1 := md5.Sum(source) + v2 := md5.Sum(append(v1[:], source...)) + copy(vector[:15], v1[:]) + copy(vector[16:], v2[:]) + + return rand.NewChaCha8(vector) +} + +func sequence(seed uint64, length int) (*seq, []int) { + uint64Bytes := func(v uint64) []byte { + var b [8]byte + binary.NativeEndian.PutUint64(b[:], v) + return b[:] + } + + perm := make([]int, 0, length) + for idx := range length { + perm = append(perm, idx) + } + + prng := rand.New(chaCha8(uint64Bytes(seed))) + + prng.Shuffle(len(perm), func(i, j int) { + perm[i], perm[j] = perm[j], perm[i] + }) + + return &seq{current: -1}, perm +}