From 26c0252f6cd46334d0ad792cc0032ef0029b3f0d Mon Sep 17 00:00:00 2001 From: Thomas Lively Date: Thu, 19 Dec 2024 19:09:53 -0800 Subject: [PATCH 1/2] Update GlobalStructInference to handle atomics GlobalStructInference optimizes gets of immutable fields of structs that are only ever instantiated to initialize immutable globals. Due to all the immutability, it's not possible for the optimized reads to synchronize with any writes via the accessed memory, so we just need to be careful to replace removed seqcst gets with seqcst fences. As a drive-by, fix some stale comments in gsi.wast. --- src/passes/GlobalStructInference.cpp | 35 +++- test/lit/passes/gsi.wast | 228 ++++++++++++++++++++++++++- 2 files changed, 251 insertions(+), 12 deletions(-) diff --git a/src/passes/GlobalStructInference.cpp b/src/passes/GlobalStructInference.cpp index 4158db05104..321193d9ece 100644 --- a/src/passes/GlobalStructInference.cpp +++ b/src/passes/GlobalStructInference.cpp @@ -359,6 +359,11 @@ struct GlobalStructInference : public Pass { // refined, which could change the struct.get's type. refinalize = true; } + // No need to worry about atomic gets here. We will still read from + // the same memory location as before and preserve all side effects + // (including synchronization) that were previously present. The + // memory location is immutable anyway, so there cannot be any writes + // to synchronize with in the first place. curr->ref = builder.makeSequence( builder.makeDrop(builder.makeRefAs(RefAsNonNull, curr->ref)), builder.makeGlobalGet(global, globalType)); @@ -457,10 +462,18 @@ struct GlobalStructInference : public Pass { // the early return above) so that only leaves 1 and 2. if (values.size() == 1) { // The case of 1 value is simple: trap if the ref is null, and - // otherwise return the value. - replaceCurrent(builder.makeSequence( - builder.makeDrop(builder.makeRefAs(RefAsNonNull, curr->ref)), - getReadValue(values[0]))); + // otherwise return the value. We must also fence if the get was + // seqcst. No additional work is necessary for a acquire get because + // there cannot have been any writes to this immutable field that it + // would synchronize with. + Expression* replacement = + builder.makeDrop(builder.makeRefAs(RefAsNonNull, curr->ref)); + if (curr->order == MemoryOrder::SeqCst) { + replacement = + builder.blockify(replacement, builder.makeAtomicFence()); + } + replaceCurrent( + builder.blockify(replacement, getReadValue(values[0]))); return; } assert(values.size() == 2); @@ -486,11 +499,19 @@ struct GlobalStructInference : public Pass { // of their execution matters (they may note globals for un-nesting). auto* left = getReadValue(values[0]); auto* right = getReadValue(values[1]); - // Note that we must trap on null, so add a ref.as_non_null here. + // Note that we must trap on null, so add a ref.as_non_null here. We + // must also add a fence if this get is seqcst. As before, no extra work + // is necessary for an acquire get because there cannot be a write is + // synchronizes with. + Expression* getGlobal = + builder.makeGlobalGet(checkGlobal, wasm.getGlobal(checkGlobal)->type); + if (curr->order == MemoryOrder::SeqCst) { + getGlobal = + builder.makeSequence(builder.makeAtomicFence(), getGlobal); + } replaceCurrent(builder.makeSelect( builder.makeRefEq(builder.makeRefAs(RefAsNonNull, curr->ref), - builder.makeGlobalGet( - checkGlobal, wasm.getGlobal(checkGlobal)->type)), + getGlobal), left, right)); } diff --git a/test/lit/passes/gsi.wast b/test/lit/passes/gsi.wast index 299b19e00f2..7bc83efd40e 100644 --- a/test/lit/passes/gsi.wast +++ b/test/lit/passes/gsi.wast @@ -144,16 +144,16 @@ ;; CHECK-NEXT: ) ;; CHECK-NEXT: ) (func $test1 (param $struct1 (ref null $struct1)) (param $struct2 (ref null $struct2)) - ;; We can infer that this get must reference $global1 and make the reference - ;; point to that. Note that we do not infer the value of 42 here, but leave - ;; it for other passes to do. + ;; Even though the value here is not known at compile time - it reads an + ;; imported global - we can still infer that we are reading from $global1. (drop (struct.get $struct1 0 (local.get $struct1) ) ) - ;; Even though the value here is not known at compile time - it reads an - ;; imported global - we can still infer that we are reading from $global2. + ;; We can infer that this get must reference $global2 and make the reference + ;; point to that. Note that we do not infer the value of 42 here, but leave + ;; it for other passes to do. (drop (struct.get $struct2 0 (local.get $struct2) @@ -1944,3 +1944,221 @@ ) ) ) + +;; Test atomic gets. +(module + (rec + ;; CHECK: (rec + ;; CHECK-NEXT: (type $one (shared (struct (field i32)))) + (type $one (shared (struct (field i32)))) + ;; CHECK: (type $two (shared (struct (field i32)))) + (type $two (shared (struct (field i32)))) + ;; CHECK: (type $two-same (shared (struct (field i32)))) + (type $two-same (shared (struct (field i32)))) + ) + + ;; CHECK: (type $3 (func (param (ref $one)))) + + ;; CHECK: (type $4 (func (param (ref $two)))) + + ;; CHECK: (type $5 (func (param (ref $two-same)))) + + ;; CHECK: (global $one (ref $one) (struct.new $one + ;; CHECK-NEXT: (i32.const 42) + ;; CHECK-NEXT: )) + (global $one (ref $one) (struct.new $one (i32.const 42))) + + ;; CHECK: (global $two-a (ref $two) (struct.new $two + ;; CHECK-NEXT: (i32.const 42) + ;; CHECK-NEXT: )) + (global $two-a (ref $two) (struct.new $two (i32.const 42))) + + ;; CHECK: (global $two-b (ref $two) (struct.new $two + ;; CHECK-NEXT: (i32.const 1337) + ;; CHECK-NEXT: )) + (global $two-b (ref $two) (struct.new $two (i32.const 1337))) + + ;; CHECK: (global $two-same-a (ref $two-same) (struct.new $two-same + ;; CHECK-NEXT: (i32.const 42) + ;; CHECK-NEXT: )) + (global $two-same-a (ref $two-same) (struct.new $two-same (i32.const 42))) + + ;; CHECK: (global $two-same-b (ref $two-same) (struct.new $two-same + ;; CHECK-NEXT: (i32.const 42) + ;; CHECK-NEXT: )) + (global $two-same-b (ref $two-same) (struct.new $two-same (i32.const 42))) + + ;; CHECK: (func $one (type $3) (param $0 (ref $one)) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (struct.get $one 0 + ;; CHECK-NEXT: (block (result (ref $one)) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (ref.as_non_null + ;; CHECK-NEXT: (local.get $0) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (global.get $one) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (struct.atomic.get acqrel $one 0 + ;; CHECK-NEXT: (block (result (ref $one)) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (ref.as_non_null + ;; CHECK-NEXT: (local.get $0) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (global.get $one) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (struct.atomic.get $one 0 + ;; CHECK-NEXT: (block (result (ref $one)) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (ref.as_non_null + ;; CHECK-NEXT: (local.get $0) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (global.get $one) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $one (param (ref $one)) + (drop + (struct.get $one 0 + (local.get 0) + ) + ) + (drop + (struct.atomic.get acqrel $one 0 + (local.get 0) + ) + ) + (drop + (struct.atomic.get $one 0 + (local.get 0) + ) + ) + ) + + ;; CHECK: (func $two (type $4) (param $0 (ref $two)) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (select + ;; CHECK-NEXT: (i32.const 42) + ;; CHECK-NEXT: (i32.const 1337) + ;; CHECK-NEXT: (ref.eq + ;; CHECK-NEXT: (ref.as_non_null + ;; CHECK-NEXT: (local.get $0) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (global.get $two-a) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (select + ;; CHECK-NEXT: (i32.const 42) + ;; CHECK-NEXT: (i32.const 1337) + ;; CHECK-NEXT: (ref.eq + ;; CHECK-NEXT: (ref.as_non_null + ;; CHECK-NEXT: (local.get $0) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (global.get $two-a) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (select + ;; CHECK-NEXT: (i32.const 42) + ;; CHECK-NEXT: (i32.const 1337) + ;; CHECK-NEXT: (ref.eq + ;; CHECK-NEXT: (ref.as_non_null + ;; CHECK-NEXT: (local.get $0) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (block (result (ref $two)) + ;; CHECK-NEXT: (atomic.fence) + ;; CHECK-NEXT: (global.get $two-a) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $two (param (ref $two)) + (drop + (struct.get $two 0 + (local.get 0) + ) + ) + (drop + ;; This is optimized normally because there cannot be any writes it + ;; synchronizes with. + (struct.atomic.get acqrel $two 0 + (local.get 0) + ) + ) + (drop + ;; This requires a fence to maintain its effect on the global order of + ;; seqcst operations. + (struct.atomic.get $two 0 + (local.get 0) + ) + ) + ) + + ;; CHECK: (func $two-same (type $5) (param $0 (ref $two-same)) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (block (result i32) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (ref.as_non_null + ;; CHECK-NEXT: (local.get $0) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (i32.const 42) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (block (result i32) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (ref.as_non_null + ;; CHECK-NEXT: (local.get $0) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (i32.const 42) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (block (result i32) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (ref.as_non_null + ;; CHECK-NEXT: (local.get $0) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (atomic.fence) + ;; CHECK-NEXT: (i32.const 42) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $two-same (param (ref $two-same)) + (drop + (struct.get $two-same 0 + (local.get 0) + ) + ) + (drop + ;; This is optimized normally because there cannot be any writes it + ;; synchronizes with. + (struct.atomic.get acqrel $two-same 0 + (local.get 0) + ) + ) + (drop + ;; This requires a fence to maintain its effect on the global order of + ;; seqcst operations. + (struct.atomic.get $two-same 0 + (local.get 0) + ) + ) + ) +) From 5c2bbb76aebed1af1c75ae0c04ec32836e2efd8b Mon Sep 17 00:00:00 2001 From: Thomas Lively Date: Thu, 19 Dec 2024 22:24:57 -0800 Subject: [PATCH 2/2] Super-minimal POC for bounded translation validation Given a source and target module, validate that the functions in the target module are refinements of the corresponding functions in the source module by translating the function bodies to SMT expressions and having Z3 either prove that they are equal for all inputs or find a counterexample showing that they are not. This minimal proof-of-concept can already prove that this source module: ``` (func $test (param $x i32) (result i32) (i32.mul (local.get 0) (i32.const 2) ) ) ``` Is refined by this target module: ``` (func $test (param $x i32) (result i32) (i32.shl (local.get 0) (i32.const 1) ) ) ``` But not by this target module: ``` (func $test (param $x i32) (result i32) (i32.shl (local.get 0) (i32.const 2) ) ) ``` --- CMakeLists.txt | 3 + src/tools/CMakeLists.txt | 9 ++ src/tools/wasm-validate-refinement.cpp | 185 +++++++++++++++++++++++++ 3 files changed, 197 insertions(+) create mode 100644 src/tools/wasm-validate-refinement.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index d02096fa06e..08e458f5c5a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -61,6 +61,8 @@ option(EMSCRIPTEN_ENABLE_PTHREADS "Enable pthreads in emscripten build" OFF) # This is useful for debugging, performance analysis, and other testing. option(EMSCRIPTEN_ENABLE_SINGLE_FILE "Enable SINGLE_FILE mode in emscripten build" ON) +option(BUILD_WASM_VALIDATE_REFINEMENT "Build the wasm-validate-refinement tool" OFF) + # For git users, attempt to generate a more useful version string if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/.git) find_package(Git QUIET REQUIRED) @@ -460,6 +462,7 @@ else() message(STATUS "Building libbinaryen as shared library.") add_library(binaryen SHARED ${binaryen_SOURCES} ${binaryen_objs}) endif() + target_link_libraries(binaryen ${CMAKE_THREAD_LIBS_INIT}) if(INSTALL_LIBS OR NOT BUILD_STATIC_LIB) install(TARGETS binaryen diff --git a/src/tools/CMakeLists.txt b/src/tools/CMakeLists.txt index c73797f54c2..105ad3d011c 100644 --- a/src/tools/CMakeLists.txt +++ b/src/tools/CMakeLists.txt @@ -21,5 +21,14 @@ if(NOT BUILD_EMSCRIPTEN_TOOLS_ONLY) binaryen_add_executable(wasm-fuzz-types "${fuzzing_SOURCES};wasm-fuzz-types.cpp") binaryen_add_executable(wasm-fuzz-lattices "${fuzzing_SOURCES};wasm-fuzz-lattices.cpp") endif() +if(BUILD_WASM_VALIDATE_REFINEMENT) + binaryen_add_executable(wasm-validate-refinement wasm-validate-refinement.cpp) + add_library(z3 SHARED IMPORTED) + if(NOT LIBZ3_LOCATION) + find_library(LIBZ3_LOCATION z3) + endif() + set_property(TARGET z3 PROPERTY IMPORTED_LOCATION ${LIBZ3_LOCATION}) + target_link_libraries(wasm-validate-refinement z3) +endif() add_subdirectory(wasm-split) diff --git a/src/tools/wasm-validate-refinement.cpp b/src/tools/wasm-validate-refinement.cpp new file mode 100644 index 00000000000..07d1507048e --- /dev/null +++ b/src/tools/wasm-validate-refinement.cpp @@ -0,0 +1,185 @@ +/* + * Copyright 2024 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "support/command-line.h" +#include "wasm-io.h" + +#include +#include + +using namespace wasm; + +struct ToSMT : UnifiedExpressionVisitor { + z3::context& ctx; + Function* func; + std::vector params; + + ToSMT(z3::context& ctx, Function* func) : ctx(ctx), func(func) { + initParams(func); + } + + void initParams(Function* func) { + for (Index i = 0; i < func->getNumParams(); ++i) { + auto type = func->getLocalType(i); + auto name = func->getLocalNameOrGeneric(i).str.data(); + if (type.isBasic()) { + switch (type.getBasic()) { + case Type::none: + case Type::unreachable: + case Type::f32: + case Type::f64: + break; + case Type::i32: + params.push_back(ctx.bv_const(name, 32)); + continue; + case Type::i64: + params.push_back(ctx.bv_const(name, 64)); + continue; + case Type::v128: + params.push_back(ctx.bv_const(name, 128)); + continue; + } + } + WASM_UNREACHABLE("unimplemented param type"); + } + } + + z3::expr visitExpression(Expression* curr) { + WASM_UNREACHABLE("unimplemented expression"); + } + + z3::expr visitLocalGet(LocalGet* curr) { + assert(curr->index < func->getNumParams() && "TODO"); + return params[curr->index]; + } + + z3::expr visitConst(Const* curr) { + assert(curr->type.isBasic()); + switch (curr->type.getBasic()) { + case Type::none: + case Type::unreachable: + break; + case Type::f32: + case Type::f64: + WASM_UNREACHABLE("TODO: fp const"); + case Type::i32: + return ctx.bv_val(curr->value.geti32(), 32); + case Type::i64: + return ctx.bv_val(curr->value.geti64(), 64); + case Type::v128: + WASM_UNREACHABLE("TODO: v128.const"); + } + WASM_UNREACHABLE("unexpected type"); + } + + z3::expr visitBinary(Binary* curr) { + auto lhs = visit(curr->left); + auto rhs = visit(curr->right); + switch (curr->op) { + case MulInt32: + return lhs * rhs; + case ShlInt32: + return z3::shl(lhs, rhs); + default: + break; + } + WASM_UNREACHABLE("unimplemented binary op"); + } +}; + +z3::expr funcToSMT(z3::context& ctx, Function* func) { + return ToSMT(ctx, func).visit(func->body); +} + +z3::expr refinedBy(const z3::expr& src, const z3::expr& tgt) { + // TODO: Something more complicated! + return tgt == src; +} + +void prove(const z3::expr& conjecture) { + z3::context& ctx = conjecture.ctx(); + z3::solver solver(ctx); + solver.add(!conjecture); + std::cout << "Proving conjecture:\n" << conjecture << "\n"; + if (solver.check() == z3::unsat) { + std::cout << "proved!\n"; + } else { + std::cout << "counterexample:\n" << solver.get_model() << "\n"; + } +} + +void checkRefinement(Function* src, Function* tgt) { + z3::context ctx; + auto srcSMT = funcToSMT(ctx, src); + auto tgtSMT = funcToSMT(ctx, tgt); + prove(refinedBy(srcSMT, tgtSMT)); +} + +struct ValidateRefinementOptions : Options { + std::string source; + std::string target; + ValidateRefinementOptions(const std::string& command, const std::string& desc) + : Options(command, desc) { + add("--source", + "-s", + "The original module", + "", + Arguments::One, + [&](Options*, const std::string& val) { source = val; }); + add("--target", + "-t", + "The transformed module", + "", + Arguments::One, + [&](Options*, const std::string& val) { target = val; }); + } +}; + +int main(int argc, const char* argv[]) { + ValidateRefinementOptions options( + "wasm-validate-refinement", + "Bounded translation validation for WebAssembly"); + + options.parse(argc, argv); + + if (options.source.empty()) { + std::cerr << "Source module must be provided (--source)\n"; + return 1; + } + + if (options.target.empty()) { + std::cerr << "Target module must be provided (--target)\n"; + return 1; + } + + Module src, tgt; + + ModuleReader().read(options.source, src); + ModuleReader().read(options.target, tgt); + + // TODO: Verify that src and tgt have matching global structures, including + // function signatures. + + for (size_t i = 0; i < src.functions.size(); ++i) { + if (src.functions[i]->imported()) { + continue; + } + + assert(i < tgt.functions.size() && !tgt.functions[i]->imported()); + + checkRefinement(src.functions[i].get(), tgt.functions[i].get()); + } +}