Skip to content

Integrating LLVM optimizations with wasm-opt #7634

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ if(EMSCRIPTEN)
set(BUILD_LLVM_DWARF OFF)
endif()

# Add a new option to use system LLVM
option(USE_SYSTEM_LLVM "Use system LLVM instead of third_party LLVM" ON)

option(BUILD_STATIC_LIB "Build as a static library" OFF)
if(MSVC)
# We don't have dllexport declarations set up for Windows yet.
Expand Down Expand Up @@ -186,8 +189,23 @@ endfunction()

include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src)
include_directories(SYSTEM ${CMAKE_CURRENT_SOURCE_DIR}/third_party/FP16/include)
if(BUILD_LLVM_DWARF)
include_directories(SYSTEM ${CMAKE_CURRENT_SOURCE_DIR}/third_party/llvm-project/include)
if(USE_SYSTEM_LLVM)
# Set LLVM_DIR to the installed LLVM's CMake directory
# set(LLVM_DIR "/usr/lib/llvm-18/cmake" CACHE PATH "Path to LLVM CMake modules")

# Find the required LLVM components
find_package(LLVM REQUIRED CONFIG)

# Include system LLVM headers
include_directories(SYSTEM ${LLVM_INCLUDE_DIRS})

# Optional: Print LLVM version for debugging
message(STATUS "Using LLVM ${LLVM_PACKAGE_VERSION}")
else()
# Legacy setup: Use third_party LLVM for DWARF
if(BUILD_LLVM_DWARF)
include_directories(SYSTEM ${CMAKE_CURRENT_SOURCE_DIR}/third_party/llvm-project/include)
endif()
endif()

# Add output directory to include path so config.h can be found
Expand Down
6 changes: 6 additions & 0 deletions src/passes/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ set(passes_SOURCES
JSPI.cpp
LegalizeJSInterface.cpp
LimitSegments.cpp
LLVMOpt.cpp
LLVMMemoryCopyFillLowering.cpp
LocalCSE.cpp
LocalSubtyping.cpp
Expand Down Expand Up @@ -146,4 +147,9 @@ if(EMSCRIPTEN)
list(REMOVE_ITEM passes_SOURCES "hash-stringify-walker.cpp")
list(REMOVE_ITEM passes_SOURCES "Outlining.cpp")
endif()
if(USE_SYSTEM_LLVM)
target_link_libraries(binaryen
LLVMCore
)
endif()
target_sources(binaryen PRIVATE ${passes_SOURCES})
319 changes: 319 additions & 0 deletions src/passes/LLVMOpt.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,319 @@
/*
* Copyright 2024 WebAssembly Community Group participants
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

//
// Run LLVM to optimize the wasm.
//

#include "pass.h"
#include "support/utilities.h"
#include "wasm-builder.h"
#include "wasm-traversal.h"
#include "wasm.h"

#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/GlobalAlias.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"

#include <cassert>
#include <iostream>
#include <llvm-18/llvm/IR/BasicBlock.h>
#include <llvm-18/llvm/IR/Constants.h>
#include <llvm-18/llvm/IR/GlobalVariable.h>
#include <llvm-18/llvm/IR/Instructions.h>
#include <llvm-18/llvm/Support/Alignment.h>
#include <llvm-18/llvm/Support/Casting.h>
#include <memory>
#include <string>

using namespace llvm;

namespace wasm {

llvm::Type* convertWasmTypeToLLVMType(Type* wasmType,
llvm::IRBuilder<>& builder);
llvm::Type* determineLoadType(Load* wasmLoad, llvm::IRBuilder<>& builder);

struct LLVMCodeGen : Visitor<LLVMCodeGen, Value*> {

std::unique_ptr<LLVMContext> llvmCtx;
std::unique_ptr<llvm::Module> llvmMod;
std::unique_ptr<IRBuilder<>> llvmBuilder;

std::unique_ptr<Builder> wasmBuilder;
Module* module;
Function* func;

int funCounter;

LLVMCodeGen() {
llvmCtx = std::make_unique<LLVMContext>();
llvmMod = std::make_unique<llvm::Module>("LLVMOpt", *llvmCtx);
llvmBuilder = std::make_unique<IRBuilder<>>(*llvmCtx);

llvmMod->setTargetTriple("wasm32-unknown-unknown");

funCounter = 0;
}

Value* visitModule(Module* module_) {
module = module_;
for (auto& mem : module->memories) {
// Handle memory
visitMemory(mem.get());
}
for (auto& func : module->functions) {
visitFunction(func.get());
// if (auto llvmFn = dyn_cast<llvm::Function>(visitFunction(func.get())))
// {
// auto* entryBB = llvm::BasicBlock::Create(*llvmCtx, "entry", llvmFn);
// llvmBuilder->SetInsertPoint(entryBB);
// }
}
std::cout << "Module generated successfully!" << std::endl;
llvmMod->dump();
return nullptr;
}

Value* visitMemory(Memory* mem) {
// Given: Memory* mem = ...;
auto pageSize = Memory::kPageSize; // 64KB
uint64_t totalSize = mem->initial * pageSize;

// Create global memory buffer
ArrayType* memType = ArrayType::get(llvmBuilder->getInt8Ty(), totalSize);
GlobalVariable* llvmMem = new GlobalVariable(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How would this be rewritten back into, presumably, multi-memory WASM? Or should this pass only work on single-memory WASM?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still relatively new to compilers and WebAssembly (currently studying both), so please forgive any naivety in my code. This was just a small attempt...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a newbie, I’m trying to learn compilers and sys like llvm and wasm. However, the docs are huge and not very beginner-friendly. Any advice on where and how to start? Thanks!

*llvmMod,
memType, // Type (e.g., [65536 x i8])
false, // isConstant = false (mutable)
GlobalValue::AvailableExternallyLinkage, // TODO: Check this
ConstantAggregateZero::get(memType),
"wasm_memory");

// For shared memory (if needed)
if (mem->shared) {
llvmMem->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
}
return nullptr;
}
Value* visitFunction(Function* func) {
std::cout << "[visitFunction]" << std::endl;

// Create function type
std::vector<llvm::Type*> llvmParamsTypes;
std::vector<llvm::Type*> llvmResultTypes;
for (auto param : func->getParams()) {
llvmParamsTypes.push_back(
convertWasmTypeToLLVMType(&param, *llvmBuilder));
}
for (auto result : func->getResults()) {
llvmResultTypes.push_back(
convertWasmTypeToLLVMType(&result, *llvmBuilder));
}
// Currently not support multiple value results.
assert(llvmResultTypes.size() == 1);
llvm::FunctionType* funcType =
llvm::FunctionType::get(llvmResultTypes[0], llvmParamsTypes, false);

llvm::Function* llvmFn = llvm::Function::Create(
funcType,
GlobalValue::ExternalLinkage, // TODO: be used outside, to prevents LLVM
// from assuming the function is unused
"func" + std::to_string(funCounter++),
llvmMod.get());

// Create a new basic block to start insertion into.
BasicBlock* BB = BasicBlock::Create(*llvmCtx, "entry", llvmFn);
llvmBuilder->SetInsertPoint(BB);

if (Value* ret = visit(func->body)) {
// Finish off the function.
llvmBuilder->CreateRet(ret);
return llvmFn;
}

// Error reading body, remove function
llvmFn->eraseFromParent();
assert(0 && "Now we just trap");
return nullptr;
}

// Visiting.
Value* visitBlock(Block* block) {
Value* ret = nullptr;
for (auto expr : block->list) {
ret = visit(expr);
}
return ret;
}

Value* visitConst(Const* c) {
assert(c->type.isBasic());
switch (c->type.getBasic()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note LLVM supports WASM's externref

case wasm::Type::i32:
return ConstantInt::get(llvmBuilder->getInt32Ty(), c->value.geti32());
case wasm::Type::f32:
return ConstantInt::get(llvmBuilder->getFloatTy(), c->value.getf32());
case wasm::Type::i64:
return ConstantInt::get(llvmBuilder->getInt64Ty(), c->value.geti64());
case wasm::Type::f64:
return ConstantInt::get(llvmBuilder->getDoubleTy(), c->value.getf64());
case wasm::Type::v128:
case wasm::Type::none:
case wasm::Type::unreachable:
break;
default:
WASM_UNREACHABLE("Unknown type");
}
return nullptr;
}

Value* visitStore(Store* store) {
// 1. Get the @wasm_memory global variable.
GlobalVariable* wasmMemory =
llvmMod->getGlobalVariable("wasm_memory", true);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If multiple memories are supported, this would be unsound.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The initial though would be:

  1. for MVP, we just need map each instructions while carefully dealing with semantics gaps such as UB in LLVM, inconsistency of FP spec and other low-level differences between the source and target semantics.
  2. for non-MVP (GC), we could perform code slicing to collect non GC parts (LLVM-optimizable), and transpile & send them to LLVM optimizer, then retrieve optimized code back and "stitch" it back.

So, in my humble view, it's better to satrt with MVP first.

assert(wasmMemory && "wasm_memory global not found");

// 2. Cast the global to i8* (pointer to first byte of memory)
Value* basePtr = llvmBuilder->CreatePointerCast(
wasmMemory, llvmBuilder->getInt8Ty()->getPointerTo(), "mem_base");

// 3. Evaluate the pointer (offset from memory base)
auto* totalOffset = llvmBuilder->CreateAdd(
visit(store->ptr),
ConstantInt::get(llvmBuilder->getInt32Ty(), store->offset.addr),
"mem_offset");

// 4. Compute the final address: basePtr + totalOffSet
Value* addr = llvmBuilder->CreateGEP(
llvmBuilder->getInt8Ty(), basePtr, totalOffset, "mem_ptr");

// 5. Determine the type of the value to store
llvm::Type* storedType =
convertWasmTypeToLLVMType(&store->valueType, *llvmBuilder);
Value* typedPtr =
llvmBuilder->CreateBitCast(addr, storedType->getPointerTo(), "typed_ptr");

// 6. Create the store instruction
StoreInst* storeInst =
llvmBuilder->CreateStore(visit(store->value), typedPtr);
storeInst->setAlignment(llvm::Align(store->align.addr));

// Handle alignment (WebAssembly uses pow2 alignment)
return storeInst;
}

Value* visitLoad(Load* load) {
// As above visitStore
GlobalVariable* wasmMemory = llvmMod->getGlobalVariable("wasm_memory");
assert(wasmMemory && "wasm_memory global not found");
Value* basePtr = llvmBuilder->CreatePointerCast(
wasmMemory, llvmBuilder->getPtrTy(), "mem_base");
Value* totalOffset = llvmBuilder->CreateAdd(
visit(load->ptr),
ConstantInt::get(llvmBuilder->getInt32Ty(), load->offset.addr),
"mem_offset");
Value* finalAddr = llvmBuilder->CreateGEP(
llvmBuilder->getInt8Ty(), basePtr, totalOffset, "mem_ptr");
llvm::Type* loadedType = determineLoadType(load, *llvmBuilder);
Value* typedPtr = llvmBuilder->CreateBitCast(
finalAddr, loadedType->getPointerTo(), "typed_ptr");

// Handle alignment (WebAssembly uses pow2 alignment)
unsigned alignment = load->align != 0 ? (1 << load->align) : load->bytes;
LoadInst* loadInst =
llvmBuilder->CreateLoad(loadedType, typedPtr, load->memory.toString());

loadInst->setAlignment(llvm::Align(alignment));
loadInst->setVolatile(load->isAtomic); // Atomic implies volatile in LLVM
return loadInst;
}
};

struct LLVMOpt : public Pass {
std::unique_ptr<Pass> create() override {
return std::make_unique<LLVMOpt>();
}

void run(Module* module) override { LLVMCodeGen().visitModule(module); }
};

Pass* createLLVMOptPass() { return new LLVMOpt(); }

llvm::Type* convertWasmTypeToLLVMType(Type* wasmType,
llvm::IRBuilder<>& llvmBuilder) {
switch (wasmType->getBasic()) {
case Type::i32:
return llvmBuilder.getInt32Ty();
case Type::i64:
return llvmBuilder.getInt64Ty();
case Type::f32:
return llvmBuilder.getFloatTy();
case Type::f64:
return llvmBuilder.getDoubleTy();
case Type::v128: /* handle vector type */
break;
default:
// Handle other cases or throw error
WASM_UNREACHABLE("Unsupported WASM type");
}
return nullptr;
}

llvm::Type* determineLoadType(Load* wasmLoad, llvm::IRBuilder<>& llvmBuilder) {
// First check the explicit type if available
if (wasmLoad->type != Type::none) {
switch (wasmLoad->type.getBasic()) {
case Type::i32:
return llvmBuilder.getInt32Ty();
case Type::i64:
return llvmBuilder.getInt64Ty();
case Type::f32:
return llvmBuilder.getFloatTy();
case Type::f64:
return llvmBuilder.getDoubleTy();
case Type::v128: /* SIMD */
return llvm::FixedVectorType::get(llvmBuilder.getInt32Ty(), 4);
default:
break;
}
}

// Fall back to byte size and signedness
switch (wasmLoad->bytes) {
case 1:
return wasmLoad->signed_ ? llvmBuilder.getInt8Ty()
: llvmBuilder.getInt8Ty();
case 2:
return wasmLoad->signed_ ? llvmBuilder.getInt16Ty()
: llvmBuilder.getInt16Ty();
case 4:
return wasmLoad->signed_ ? llvmBuilder.getInt32Ty()
: llvmBuilder.getInt32Ty();
case 8:
return llvmBuilder.getInt64Ty();
default:
assert(false && "Invalid load byte size");
return llvmBuilder.getInt32Ty(); // fallback
}
}

} // namespace wasm
4 changes: 4 additions & 0 deletions src/passes/pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,10 @@ void PassRegistry::registerPasses() {
registerPass("table64-lowering",
"alias for memory64-lowering",
createMemory64LoweringPass);
registerPass("llvm-opt",
"run llvm to optimize the wasm",
createLLVMOptPass);

registerPass("llvm-memory-copy-fill-lowering",
"Lower memory.copy and memory.fill to wasm mvp and disable "
"the bulk-memory feature.",
Expand Down
1 change: 1 addition & 0 deletions src/passes/passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ Pass* createIntrinsicLoweringPass();
Pass* createTraceCallsPass();
Pass* createInstrumentLocalsPass();
Pass* createInstrumentMemoryPass();
Pass* createLLVMOptPass();
Pass* createLLVMMemoryCopyFillLoweringPass();
Pass* createLoopInvariantCodeMotionPass();
Pass* createMemory64LoweringPass();
Expand Down
Loading