diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt index 6fab09b1bc2..0e6c7320fc1 100644 --- a/src/passes/CMakeLists.txt +++ b/src/passes/CMakeLists.txt @@ -114,6 +114,7 @@ set(passes_SOURCES ReorderGlobals.cpp ReorderLocals.cpp ReReloop.cpp + TailCall.cpp TrapMode.cpp TypeGeneralizing.cpp TypeRefining.cpp diff --git a/src/passes/TailCall.cpp b/src/passes/TailCall.cpp new file mode 100644 index 00000000000..d0ac118af07 --- /dev/null +++ b/src/passes/TailCall.cpp @@ -0,0 +1,86 @@ + +#include "pass.h" +#include "wasm-traversal.h" +#include "wasm.h" +#include + +namespace wasm { + +namespace { + +struct Finder : TryDepthWalker { + std::vector tailCalls; + std::vector tailCallIndirects; + void visitFunction(Function* curr) { checkTailCall(curr->body); } + void visitReturn(Return* curr) { checkTailCall(curr->value); } + +private: + void checkTailCall(Expression* expr) { + if (expr == nullptr) { + return; + } + if (tryDepth > 0) { + // We are in a try block, so we cannot optimize tail calls. + return; + } + if (auto* call = expr->dynCast()) { + if (!call->isReturn && call->type == getFunction()->getResults()) { + tailCalls.push_back(call); + } + return; + } + if (auto* call = expr->dynCast()) { + if (!call->isReturn && call->type == getFunction()->getResults()) { + tailCallIndirects.push_back(call); + } + return; + } + if (auto* block = expr->dynCast()) { + return checkTailCall(block->list); + } + if (auto* ifElse = expr->dynCast()) { + checkTailCall(ifElse->ifTrue); + checkTailCall(ifElse->ifFalse); + return; + } + } + void checkTailCall(ExpressionList const& exprs) { + if (exprs.empty()) { + return; + } + checkTailCall(exprs.back()); + return; + } +}; + +} // namespace + +struct TailCallOptimizer : public Pass { + bool isFunctionParallel() override { return true; } + std::unique_ptr create() override { + return std::make_unique(); + } + void runOnFunction(Module* module, Function* function) override { + if (!module->features.hasTailCall()) { + return; + } + Finder finder{}; + finder.walkFunctionInModule(function, module); + for (Call* call : finder.tailCalls) { + if (!call->isReturn) { + call->isReturn = true; + call->finalize(); + } + } + for (CallIndirect* call : finder.tailCallIndirects) { + if (!call->isReturn) { + call->isReturn = true; + call->finalize(); + } + } + } +}; + +Pass* createTailCallPass() { return new TailCallOptimizer(); } + +} // namespace wasm diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp index 2042bc71d3a..ffc3ecf941f 100644 --- a/src/passes/pass.cpp +++ b/src/passes/pass.cpp @@ -552,6 +552,8 @@ void PassRegistry::registerPasses() { registerPass("strip-target-features", "strip the wasm target features section", createStripTargetFeaturesPass); + registerPass( + "tail-call", "transform call to return call", createTailCallPass); registerPass("translate-to-new-eh", "deprecated; same as translate-to-exnref", createTranslateToExnrefPass); diff --git a/src/passes/passes.h b/src/passes/passes.h index e051e466e72..fc58f3775d1 100644 --- a/src/passes/passes.h +++ b/src/passes/passes.h @@ -176,6 +176,7 @@ Pass* createStripEHPass(); Pass* createStubUnsupportedJSOpsPass(); Pass* createSSAifyPass(); Pass* createSSAifyNoMergePass(); +Pass* createTailCallPass(); Pass* createTable64LoweringPass(); Pass* createTranslateToExnrefPass(); Pass* createTrapModeClamp();