From e23623d85dfa101d762c0399e12592aaa8738831 Mon Sep 17 00:00:00 2001 From: Hiram Chirino Date: Wed, 16 Apr 2025 15:37:41 -0400 Subject: [PATCH] Add Test to verify wat2wasm aot/interpreter has the same memory access. Signed-off-by: Hiram Chirino --- .../chicory/testing/LockStepMemory.java | 473 ++++++++++++++++++ .../testing/LockStepMemoryAccessTest.java | 105 ++++ .../com/dylibso/chicory/wabt/Wat2Wasm.java | 75 ++- .../dylibso/chicory/wabt/Wat2WasmTest.java | 11 + 4 files changed, 653 insertions(+), 11 deletions(-) create mode 100644 compiler-tests/src/test/java/com/dylibso/chicory/testing/LockStepMemory.java create mode 100644 compiler-tests/src/test/java/com/dylibso/chicory/testing/LockStepMemoryAccessTest.java diff --git a/compiler-tests/src/test/java/com/dylibso/chicory/testing/LockStepMemory.java b/compiler-tests/src/test/java/com/dylibso/chicory/testing/LockStepMemory.java new file mode 100644 index 000000000..86572bf9f --- /dev/null +++ b/compiler-tests/src/test/java/com/dylibso/chicory/testing/LockStepMemory.java @@ -0,0 +1,473 @@ +package com.dylibso.chicory.testing; + +import com.dylibso.chicory.runtime.Instance; +import com.dylibso.chicory.runtime.Memory; +import com.dylibso.chicory.wasm.types.DataSegment; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.VarHandle; +import java.nio.ByteOrder; +import java.nio.charset.Charset; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.ArrayBlockingQueue; + +public final class LockStepMemory implements Memory { + + private final Memory memory; + private final String name; + private final ArrayBlockingQueue eventsOut = new ArrayBlockingQueue<>(1); + private ArrayBlockingQueue eventsIn; + private long eventCounter; + + private LockStepMemory(String name, Memory memory) { + this.name = name; + this.memory = memory; + } + + public static LockStepMemory[] create( + String name1, Memory memory1, String name2, Memory memory2) { + LockStepMemory lockStepMemory1 = new LockStepMemory(name1, memory1); + LockStepMemory lockStepMemory2 = new LockStepMemory(name2, memory2); + lockStepMemory1.eventsIn = lockStepMemory2.eventsOut; + lockStepMemory2.eventsIn = lockStepMemory1.eventsOut; + return new LockStepMemory[] {lockStepMemory1, lockStepMemory2}; + } + + public long eventCounter() { + return eventCounter; + } + + static final class Event { + final String name; + final long eventId; + final String method; + final List args; + final Object result; + + Event(String name, long eventId, String method, List args, Object result) { + this.name = name; + this.eventId = eventId; + this.method = method; + this.args = args; + this.result = result; + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) { + return false; + } + Event event = (Event) o; + return eventId == event.eventId + && Objects.equals(method, event.method) + && Objects.equals(args, event.args) + && Objects.equals(result, event.result); + } + + @Override + public int hashCode() { + return Objects.hash(eventId, method, args, result); + } + } + + static final class ByteArrayWrapper { + private final byte[] data; + + ByteArrayWrapper(byte[] data) { + this.data = data; + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) { + return false; + } + ByteArrayWrapper that = (ByteArrayWrapper) o; + return Arrays.equals(this.data, that.data); + } + + @Override + public int hashCode() { + return Arrays.hashCode(data); + } + } + + private void exchange(String method, Object result, Object... args) { + exchange(new Event(this.name, eventCounter++, method, List.of(args), result)); + } + + private void exchange(Event expected) { + try { + eventsOut.put(expected); + Event actual = eventsIn.take(); + if (expected.eventId % 1_000_000 == 0) { + System.out.println( + String.format( + "%s: %s %d %s %s", + name, + expected.method, + expected.eventId, + expected.args, + expected.result)); + } + + if (!expected.equals(actual)) { + + throw new IllegalStateException( + String.format( + "Events of sync: \n" + + " %s event - id: %d, method: %s, args: %s, result:" + + " %s\n" + + " %s event - id: %d, method: %s, args: %s, result:" + + " %s\n", + expected.name, + expected.eventId, + expected.method, + expected.args, + expected.result, + actual.name, + actual.eventId, + actual.method, + actual.args, + actual.result)); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + + @Override + public void initPassiveSegment(int segmentId, int dest, int offset, int size) { + memory.initPassiveSegment(segmentId, dest, offset, size); + exchange("initPassiveSegment", null, segmentId, dest, offset, size); + } + + @Override + public int pages() { + int result = memory.pages(); + exchange("pages", result); + return result; + } + + @Override + public int grow(int size) { + int result = memory.grow(size); + exchange("grow", result, size); + return result; + } + + @Override + public void fill(byte value, int fromIndex, int toIndex) { + memory.fill(value, fromIndex, toIndex); + exchange("fill", null, fromIndex, toIndex); + } + + @Override + public void writeI32(int addr, int data) { + memory.writeI32(addr, data); + exchange("writeI32", null, addr, data); + } + + @Override + public void writeLong(int addr, long data) { + memory.writeLong(addr, data); + exchange("writeLong", null, addr, data); + } + + @Override + public long readF64(int addr) { + long result = memory.readF64(addr); + exchange("readF64", result, addr); + return result; + } + + @Override + public void writeByte(int addr, byte data) { + memory.writeByte(addr, data); + exchange("writeByte", null, addr, data); + } + + @Override + public int initialPages() { + int result = memory.initialPages(); + exchange("initialPages", result); + return result; + } + + @Override + public int maximumPages() { + int result = memory.maximumPages(); + exchange("maximumPages", result); + return result; + } + + @Override + public boolean shared() { + boolean shared = memory.shared(); + exchange("shared", shared); + return shared; + } + + @Override + public Object lock(int address) { + Object lock = memory.lock(address); + exchange("lock", address); + return lock; + } + + @Override + public int waitOn(int address, int expected, long timeout) { + int result = memory.waitOn(address, expected, timeout); + exchange("waitOn", result, address, expected, timeout); + return result; + } + + @Override + public int waitOn(int address, long expected, long timeout) { + int result = memory.waitOn(address, expected, timeout); + exchange("waitOn", result, address, expected, timeout); + return result; + } + + @Override + public int notify(int address, int maxThreads) { + int result = memory.notify(address, maxThreads); + exchange("notify", result, address, maxThreads); + return result; + } + + @Override + public void writeF32(int addr, float data) { + memory.writeF32(addr, data); + exchange("writeF32", null, addr, data); + } + + @Override + public void initialize(Instance instance, DataSegment[] dataSegments) { + memory.initialize(instance, dataSegments); + exchange("initialize", null); + } + + @Override + public void zero() { + memory.zero(); + exchange("zero", null); + } + + // private static final VarHandle SHORT_ARR_HANDLE = + // MethodHandles.byteArrayViewVarHandle(short[].class, ByteOrder.LITTLE_ENDIAN); + private static final VarHandle INT_ARR_HANDLE = + MethodHandles.byteArrayViewVarHandle(int[].class, ByteOrder.LITTLE_ENDIAN); + private static final VarHandle FLOAT_ARR_HANDLE = + MethodHandles.byteArrayViewVarHandle(float[].class, ByteOrder.LITTLE_ENDIAN); + + private float toFloat(long value) { + byte[] bytes = new byte[4]; + INT_ARR_HANDLE.set(bytes, 0, (int) value); + return (float) FLOAT_ARR_HANDLE.get(bytes, 0); + } + + // private static final VarHandle LONG_ARR_HANDLE = + // MethodHandles.byteArrayViewVarHandle(long[].class, ByteOrder.LITTLE_ENDIAN); + // private static final VarHandle DOUBLE_ARR_HANDLE = + // MethodHandles.byteArrayViewVarHandle(double[].class, ByteOrder.LITTLE_ENDIAN); + // private float toDouble(long value) { + // byte[] bytes = new byte[8]; + // LONG_ARR_HANDLE.set(bytes, 0, value); + // return (float) DOUBLE_ARR_HANDLE.get(bytes, 0); + // } + + @Override + public long readF32(int addr) { + long result = memory.readF32(addr); + exchange("readFloat", toFloat(result), addr); // remapped. + return result; + } + + @Override + public float readFloat(int addr) { + float result = memory.readFloat(addr); + exchange("readFloat", result, addr); + return result; + } + + @Override + public long readU16(int addr) { + long result = memory.readU16(addr); + exchange("readShort", (short) result, addr); // remapped. + return result; + } + + @Override + public short readShort(int addr) { + short result = memory.readShort(addr); + exchange("readShort", result, addr); + return result; + } + + @Override + public void writeF64(int addr, double data) { + memory.writeF64(addr, data); + exchange("writeF64", null, addr, data); + } + + @Override + public byte[] readBytes(int addr, int len) { + byte[] result = memory.readBytes(addr, len); + exchange("readBytes", new ByteArrayWrapper(result), addr, len); + return result; + } + + @Override + public void writeShort(int addr, short data) { + memory.writeShort(addr, data); + exchange("writeShort", null, addr, data); + } + + @Override + public double readDouble(int addr) { + double result = memory.readDouble(addr); + exchange("readDouble", result, addr); + return result; + } + + @Override + public void drop(int segment) { + memory.drop(segment); + exchange("drop", null, segment); + } + + @Override + public long readU32(int addr) { + long result = memory.readU32(addr); + exchange("readInt", (int) result, addr); // remapped. + return result; + } + + @Override + public long readI32(int addr) { + long result = memory.readI32(addr); + exchange("readInt", (int) result, addr); // remapped. + return result; + } + + @Override + public int readInt(int addr) { + int result = memory.readInt(addr); + exchange("readInt", result, addr); + return result; + } + + @Override + public long readI64(int addr) { + long result = memory.readI64(addr); + exchange("readLong", result, addr); // remapped. + return result; + } + + @Override + public long readLong(int addr) { + long result = memory.readLong(addr); + exchange("readLong", result, addr); + return result; + } + + @Override + public long readU8(int addr) { + long result = memory.readU8(addr); + exchange("read", (byte) result, addr); // remapped. + return result; + } + + @Override + public long readI8(int addr) { + byte result = memory.read(addr); + exchange("read", result, addr); // remapped. + return result; + } + + @Override + public byte read(int addr) { + byte result = memory.read(addr); + exchange("read", result, addr); + return result; + } + + @Override + public void writeString(int offset, String data, Charset charSet) { + memory.writeString(offset, data, charSet); + exchange("writeString", null, offset, data, charSet); + } + + @Override + public void writeString(int offset, String data) { + memory.writeString(offset, data); + exchange("writeString", null, offset, data); + } + + @Override + public long readI16(int addr) { + long result = memory.readI16(addr); + exchange("readI16", result, addr); + return result; + } + + @Override + public void write(int addr, byte[] data, int offset, int size) { + memory.write(addr, data, offset, size); + exchange("write", null, addr, new ByteArrayWrapper(data), offset, size); + } + + @Override + public void write(int addr, byte[] data) { + memory.write(addr, data); + exchange("write", null, addr, new ByteArrayWrapper(data)); + } + + @Override + public void copy(int dest, int src, int size) { + memory.copy(dest, src, size); + exchange("copy", null, dest, src, size); + } + + @Override + public String readCString(int addr) { + String result = memory.readCString(addr); + exchange("readCString", result, addr); + return result; + } + + @Override + public String readCString(int addr, Charset charSet) { + String result = memory.readCString(addr, charSet); + exchange("readCString", result, addr, charSet); + return result; + } + + @Override + public String readString(int addr, int len) { + String result = memory.readString(addr, len); + exchange("readString", result, addr, len); + return result; + } + + @Override + public String readString(int addr, int len, Charset charSet) { + String result = memory.readString(addr, len, charSet); + exchange("readString", result, addr, len, charSet); + return result; + } + + @Override + public void writeCString(int offset, String str, Charset charSet) { + memory.writeCString(offset, str, charSet); + exchange("writeCString", null, offset, str, charSet); + } + + @Override + public void writeCString(int offset, String str) { + memory.writeCString(offset, str); + exchange("writeCString", null, offset, str); + } +} diff --git a/compiler-tests/src/test/java/com/dylibso/chicory/testing/LockStepMemoryAccessTest.java b/compiler-tests/src/test/java/com/dylibso/chicory/testing/LockStepMemoryAccessTest.java new file mode 100644 index 000000000..e457c3578 --- /dev/null +++ b/compiler-tests/src/test/java/com/dylibso/chicory/testing/LockStepMemoryAccessTest.java @@ -0,0 +1,105 @@ +package com.dylibso.chicory.testing; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import com.dylibso.chicory.compiler.MachineFactoryCompiler; +import com.dylibso.chicory.corpus.WatGenerator; +import com.dylibso.chicory.runtime.ByteArrayMemory; +import com.dylibso.chicory.runtime.Instance; +import com.dylibso.chicory.runtime.Memory; +import com.dylibso.chicory.wabt.Wat2Wasm; +import com.dylibso.chicory.wasm.Parser; +import com.dylibso.chicory.wasm.types.MemoryLimits; +import java.io.IOException; +import java.util.function.Function; +import org.junit.jupiter.api.Test; + +public class LockStepMemoryAccessTest { + + public interface Workload { + void run(Function memoryFactory, String name); + } + + public LockStepMemory[] assertHasSameMemoryAccess(String name1, String name2, Workload workload) + throws IOException { + var pair = + LockStepMemory.create( + name1, + new ByteArrayMemory(new MemoryLimits(1024, 65536)), + name2, + new ByteArrayMemory(new MemoryLimits(1024, 65536))); + new Thread( + () -> { + workload.run((x) -> pair[1], name2); + }) + .start(); + workload.run((x) -> pair[0], name1); + return pair; + } + + private void testMemoryWat(Function memoryFactory, String name) { + var module = Parser.parse(getClass().getResourceAsStream("/compiled/memory.wat.wasm")); + Instance.Builder builder = Instance.builder(module).withMemoryFactory(memoryFactory); + if (name.startsWith("aot")) { + builder.withMachineFactory(MachineFactoryCompiler::compile); + } + var instance = builder.build(); + var run = instance.export("run32"); + var results = run.apply(42); + var result = results[0]; + assertEquals(42L, result); + + result = run.apply(Integer.MAX_VALUE)[0]; + assertEquals(Integer.MAX_VALUE, (int) result); + + result = run.apply(Integer.MIN_VALUE)[0]; + assertEquals(Integer.MIN_VALUE, (int) result); + + run = instance.export("run64"); + result = run.apply(42L)[0]; + assertEquals(42L, result); + + run = instance.export("run64"); + result = run.apply(Long.MIN_VALUE)[0]; + assertEquals(Long.MIN_VALUE, result); + + run = instance.export("run64"); + result = run.apply(Long.MAX_VALUE)[0]; + assertEquals(Long.MAX_VALUE, result); + } + + @Test + public void twoInterpretersHaveSameMemoryAccess() throws IOException { + var pair = assertHasSameMemoryAccess("interpreter 1", "interpreter 2", this::testMemoryWat); + assertEquals(14, pair[0].eventCounter()); + } + + @Test + public void twoAotsHaveSameMemoryAccess() throws IOException { + var pair = assertHasSameMemoryAccess("aot 1", "aot 2", this::testMemoryWat); + assertEquals(14, pair[0].eventCounter()); + } + + @Test + public void aotAndInterpretersHaveSameMemoryAccess() throws IOException { + var pair = assertHasSameMemoryAccess("interpreter", "aot", this::testMemoryWat); + assertEquals(14, pair[0].eventCounter()); + } + + private void testWat2Wasm(Function memoryFactory, String name) { + var options = Wat2Wasm.options().withMemoryFactory(memoryFactory); + if (name.startsWith("aot")) { + options.withExecutionType(Wat2Wasm.ExecutionType.AOT); + } else { + options.withExecutionType(Wat2Wasm.ExecutionType.INTERPRETED); + } + var wat = WatGenerator.bigWat(10_000, 0); + Wat2Wasm.parse(wat, options); + } + + @Test + public void wat2wasmAotAndInterpretersHaveSameMemoryAccess() throws IOException { + var pair = assertHasSameMemoryAccess("aot", "interpreter", this::testWat2Wasm); + assertEquals(14, pair[0].eventCounter()); + } +} diff --git a/wabt/src/main/java/com/dylibso/chicory/wabt/Wat2Wasm.java b/wabt/src/main/java/com/dylibso/chicory/wabt/Wat2Wasm.java index c6c0f3558..3611101c5 100644 --- a/wabt/src/main/java/com/dylibso/chicory/wabt/Wat2Wasm.java +++ b/wabt/src/main/java/com/dylibso/chicory/wabt/Wat2Wasm.java @@ -6,10 +6,13 @@ import com.dylibso.chicory.log.SystemLogger; import com.dylibso.chicory.runtime.ImportValues; import com.dylibso.chicory.runtime.Instance; +import com.dylibso.chicory.runtime.Memory; import com.dylibso.chicory.wasi.WasiExitException; import com.dylibso.chicory.wasi.WasiOptions; import com.dylibso.chicory.wasi.WasiPreview1; +import com.dylibso.chicory.wasm.Parser; import com.dylibso.chicory.wasm.WasmModule; +import com.dylibso.chicory.wasm.types.MemoryLimits; import io.roastedroot.zerofs.Configuration; import io.roastedroot.zerofs.ZeroFs; import java.io.ByteArrayInputStream; @@ -24,6 +27,7 @@ import java.nio.file.Path; import java.nio.file.StandardCopyOption; import java.util.List; +import java.util.function.Function; public final class Wat2Wasm { private static final Logger logger = new SystemLogger(); @@ -31,27 +35,27 @@ public final class Wat2Wasm { private Wat2Wasm() {} - public static byte[] parse(InputStream is) { - return parse(is, "temp.wast"); + public static byte[] parse(InputStream is, Options... options) { + return parse(is, "temp.wast", options); } - public static byte[] parse(File file) { + public static byte[] parse(File file, Options... options) { try (InputStream is = new FileInputStream(file)) { - return parse(is, file.getName()); + return parse(is, file.getName(), options); } catch (IOException e) { throw new UncheckedIOException(e); } } - public static byte[] parse(String wat) { + public static byte[] parse(String wat, Options... options) { try (InputStream is = new ByteArrayInputStream(wat.getBytes(StandardCharsets.UTF_8))) { - return parse(is, "temp.wast"); + return parse(is, "temp.wast", options); } catch (IOException e) { throw new UncheckedIOException(e); } } - private static byte[] parse(InputStream is, String fileName) { + private static byte[] parse(InputStream is, String fileName, Options... options) { try (ByteArrayOutputStream stdoutStream = new ByteArrayOutputStream(); ByteArrayOutputStream stderrStream = new ByteArrayOutputStream()) { @@ -76,10 +80,8 @@ private static byte[] parse(InputStream is, String fileName) { WasiPreview1.builder().withLogger(logger).withOptions(wasiOpts).build()) { ImportValues imports = ImportValues.builder().addFunction(wasi.toHostFunctions()).build(); - Instance.builder(MODULE) - .withMachineFactory(Wat2WasmModule::create) - .withImportValues(imports) - .build(); + Options b = options.length > 0 ? options[0] : options(); + b.build(imports); } catch (WasiExitException e) { if (e.exitCode() != 0) { throw new WatParseException( @@ -95,4 +97,55 @@ private static byte[] parse(InputStream is, String fileName) { throw new UncheckedIOException(e); } } + + public static Options options() { + return new Options(); + } + + public enum ExecutionType { + INTERPRETED, + AOT, + } + + public static final class Options { + + ExecutionType type = ExecutionType.AOT; + private Function memoryFactory; + + private Options() {} + + public Options withExecutionType(ExecutionType type) { + this.type = type; + return this; + } + + public Options withMemoryFactory(Function memoryFactory) { + this.memoryFactory = memoryFactory; + return this; + } + + Instance build(ImportValues imports) { + Instance.Builder builder; + switch (type) { + case INTERPRETED: + { + var module = Parser.parse(Wat2Wasm.class.getResourceAsStream("/wat2wasm")); + builder = Instance.builder(module); + break; + } + case AOT: + { + builder = + Instance.builder(MODULE).withMachineFactory(Wat2WasmModule::create); + break; + } + default: + throw new IllegalArgumentException("Unknown execution type: " + type); + } + if (memoryFactory != null) { + builder.withMemoryFactory(memoryFactory); + } + return builder.withImportValues(imports).build(); + } + } } diff --git a/wabt/src/test/java/com/dylibso/chicory/wabt/Wat2WasmTest.java b/wabt/src/test/java/com/dylibso/chicory/wabt/Wat2WasmTest.java index be6b6f698..3f1271477 100644 --- a/wabt/src/test/java/com/dylibso/chicory/wabt/Wat2WasmTest.java +++ b/wabt/src/test/java/com/dylibso/chicory/wabt/Wat2WasmTest.java @@ -23,6 +23,17 @@ public void shouldRunWat2Wasm() throws Exception { assertTrue(new String(result, UTF_8).contains("iterFact")); } + @Test + public void shouldRunWat2WasmInterpreted() throws IOException { + Wat2Wasm.Options options = + Wat2Wasm.options().withExecutionType(Wat2Wasm.ExecutionType.INTERPRETED); + var result = + Wat2Wasm.parse( + new File("../wasm-corpus/src/main/resources/wat/iterfact.wat"), options); + assertTrue(result.length > 0); + assertTrue(new String(result, UTF_8).contains("iterFact")); + } + @Test public void shouldRunWat2WasmOnString() { var moduleInstance =