Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 17dbbcc

Browse files
committed
Parallel CPU
1 parent 1ef8af5 commit 17dbbcc

File tree

3 files changed

+113
-21
lines changed

3 files changed

+113
-21
lines changed

include/tc/core/polyhedral/llvm_jit.h

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,30 +34,21 @@ class Jit {
3434
llvm::orc::RTDyldObjectLinkingLayer objectLayer_;
3535
llvm::orc::IRCompileLayer<decltype(objectLayer_), llvm::orc::SimpleCompiler>
3636
compileLayer_;
37-
using OptimizeFunction = std::function<std::shared_ptr<llvm::Module>(
38-
std::shared_ptr<llvm::Module>)>;
39-
40-
llvm::orc::IRTransformLayer<decltype(compileLayer_), OptimizeFunction>
41-
optimizeLayer_;
42-
4337
public:
4438
Jit();
4539

4640
void codegenScop(
4741
const std::string& specializedName,
4842
const polyhedral::Scop& scop);
4943

50-
using ModuleHandle = decltype(optimizeLayer_)::ModuleHandleT;
44+
using ModuleHandle = decltype(compileLayer_)::ModuleHandleT;
5145
ModuleHandle addModule(std::unique_ptr<llvm::Module> M);
5246
void removeModule(ModuleHandle H);
5347

5448
llvm::JITSymbol findSymbol(const std::string name);
5549
llvm::JITTargetAddress getSymbolAddress(const std::string name);
5650

5751
llvm::TargetMachine& getTargetMachine();
58-
59-
private:
60-
std::shared_ptr<llvm::Module> optimizeModule(std::shared_ptr<llvm::Module> M);
6152
};
6253

6354
} // namespace tc

src/core/polyhedral/codegen_llvm.cc

Lines changed: 108 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,16 @@
2020

2121
#include "llvm/ADT/STLExtras.h"
2222
#include "llvm/ADT/SmallVector.h"
23+
#include "llvm/Analysis/TargetTransformInfo.h"
2324
#include "llvm/ExecutionEngine/ExecutionEngine.h"
2425
#include "llvm/IR/IRBuilder.h"
26+
#include "llvm/IR/LegacyPassManager.h"
2527
#include "llvm/IR/Verifier.h"
2628
#include "llvm/Support/TargetSelect.h"
2729
#include "llvm/Support/raw_ostream.h"
30+
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
31+
#include "llvm/Transforms/IPO.h"
32+
#include "llvm/Transforms/Tapir/CilkABI.h"
2833

2934
#include "Halide/Halide.h"
3035

@@ -202,7 +207,6 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
202207

203208
using CodeGen_X86::codegen;
204209
using CodeGen_X86::llvm_type_of;
205-
using CodeGen_X86::optimize_module;
206210
using CodeGen_X86::sym_get;
207211
using CodeGen_X86::sym_pop;
208212
using CodeGen_X86::sym_push;
@@ -294,6 +298,81 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
294298

295299
value = sym_get(name);
296300
}
301+
public:
302+
void optimize_module() {
303+
Halide::Internal::debug(3) << "Optimizing module\n";
304+
305+
if (Halide::Internal::debug::debug_level() >= 3) {
306+
#if LLVM_VERSION >= 50
307+
module->print(dbgs(), nullptr, false, true);
308+
#else
309+
module->dump();
310+
#endif
311+
}
312+
313+
// We override PassManager::add so that we have an opportunity to
314+
// blacklist problematic LLVM passes.
315+
class MyFunctionPassManager : public llvm::legacy::FunctionPassManager {
316+
public:
317+
MyFunctionPassManager(llvm::Module *m) : llvm::legacy::FunctionPassManager(m) {}
318+
virtual void add(llvm::Pass *p) override {
319+
Halide::Internal::debug(2) << "Adding function pass: " << p->getPassName().str() << "\n";
320+
llvm::legacy::FunctionPassManager::add(p);
321+
}
322+
};
323+
324+
class MyModulePassManager : public llvm::legacy::PassManager {
325+
public:
326+
virtual void add(llvm::Pass *p) override {
327+
Halide::Internal::debug(2) << "Adding module pass: " << p->getPassName().str() << "\n";
328+
llvm::legacy::PassManager::add(p);
329+
}
330+
};
331+
332+
MyFunctionPassManager function_pass_manager(module.get());
333+
MyModulePassManager module_pass_manager;
334+
335+
std::unique_ptr<llvm::TargetMachine> TM = Halide::Internal::make_target_machine(*module);
336+
module_pass_manager.add(llvm::createTargetTransformInfoWrapperPass(TM ? TM->getTargetIRAnalysis() : llvm::TargetIRAnalysis()));
337+
function_pass_manager.add(llvm::createTargetTransformInfoWrapperPass(TM ? TM->getTargetIRAnalysis() : llvm::TargetIRAnalysis()));
338+
339+
llvm::PassManagerBuilder b;
340+
b.OptLevel = 3;
341+
b.tapirTarget = new llvm::tapir::CilkABI();
342+
#if LLVM_VERSION >= 50
343+
b.Inliner = llvm::createFunctionInliningPass(b.OptLevel, 0, false);
344+
#else
345+
b.Inliner = llvm::createFunctionInliningPass(b.OptLevel, 0);
346+
#endif
347+
b.LoopVectorize = true;
348+
b.SLPVectorize = true;
349+
350+
#if LLVM_VERSION >= 50
351+
if (TM) {
352+
TM->adjustPassManager(b);
353+
}
354+
#endif
355+
356+
b.populateFunctionPassManager(function_pass_manager);
357+
b.populateModulePassManager(module_pass_manager);
358+
359+
// Run optimization passes
360+
function_pass_manager.doInitialization();
361+
for (llvm::Module::iterator i = module->begin(); i != module->end(); i++) {
362+
function_pass_manager.run(*i);
363+
}
364+
function_pass_manager.doFinalization();
365+
module_pass_manager.run(*module);
366+
367+
Halide::Internal::debug(3) << "After LLVM optimizations:\n";
368+
if (Halide::Internal::debug::debug_level() >= 2) {
369+
#if LLVM_VERSION >= 50
370+
module->print(dbgs(), nullptr, false, true);
371+
#else
372+
module->dump();
373+
#endif
374+
}
375+
}
297376
};
298377

299378
class LLVMCodegen {
@@ -451,6 +530,17 @@ class LLVMCodegen {
451530
llvm::BasicBlock::Create(llvmCtx, "loop_latch", function);
452531
auto* loopExitBB = llvm::BasicBlock::Create(llvmCtx, "loop_exit", function);
453532

533+
bool parallel = true;
534+
535+
llvm::Value* SyncRegion = nullptr;
536+
if (parallel) {
537+
SyncRegion = halide_cg.get_builder().CreateCall(
538+
llvm::Intrinsic::getDeclaration(function->getParent(), llvm::Intrinsic::syncregion_start),
539+
{},
540+
"syncreg"
541+
);
542+
}
543+
454544
halide_cg.get_builder().CreateBr(headerBB);
455545

456546
llvm::PHINode* phi = nullptr;
@@ -498,9 +588,20 @@ class LLVMCodegen {
498588
// Create Body
499589
{
500590
halide_cg.get_builder().SetInsertPoint(loopBodyBB);
591+
592+
if (parallel) {
593+
auto* detachedBB = llvm::BasicBlock::Create(llvmCtx, "det.achd", function);
594+
halide_cg.get_builder().CreateDetach(detachedBB, loopLatchBB, SyncRegion);
595+
halide_cg.get_builder().SetInsertPoint(detachedBB);
596+
}
501597
auto* currentBB = emitAst(node.for_get_body());
502598
halide_cg.get_builder().SetInsertPoint(currentBB);
503-
halide_cg.get_builder().CreateBr(loopLatchBB);
599+
600+
if (parallel) {
601+
halide_cg.get_builder().CreateReattach(loopLatchBB, SyncRegion);
602+
} else {
603+
halide_cg.get_builder().CreateBr(loopLatchBB);
604+
}
504605
}
505606

506607
// Create Latch
@@ -516,6 +617,11 @@ class LLVMCodegen {
516617

517618
halide_cg.get_builder().SetInsertPoint(loopExitBB);
518619
halide_cg.sym_pop(node.for_get_iterator().get_id().get_name());
620+
if (parallel) {
621+
auto* syncBB = llvm::BasicBlock::Create(llvmCtx, "synced", function);
622+
halide_cg.get_builder().CreateSync(syncBB, SyncRegion);
623+
halide_cg.get_builder().SetInsertPoint(syncBB);
624+
}
519625
return halide_cg.get_builder().GetInsertBlock();
520626
}
521627

src/core/polyhedral/llvm_jit.cc

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,7 @@ Jit::Jit()
3535
: TM_(EngineBuilder().selectTarget()),
3636
DL_(TM_->createDataLayout()),
3737
objectLayer_([]() { return std::make_shared<SectionMemoryManager>(); }),
38-
compileLayer_(objectLayer_, orc::SimpleCompiler(*TM_)),
39-
optimizeLayer_(compileLayer_, [this](std::shared_ptr<Module> M) {
40-
return optimizeModule(std::move(M));
41-
}) {
38+
compileLayer_(objectLayer_, orc::SimpleCompiler(*TM_)) {
4239
sys::DynamicLibrary::LoadLibraryPermanently(nullptr);
4340
}
4441

@@ -57,7 +54,7 @@ Jit::ModuleHandle Jit::addModule(std::unique_ptr<Module> M) {
5754
M->setTargetTriple(TM_->getTargetTriple().str());
5855
auto Resolver = orc::createLambdaResolver(
5956
[&](const std::string& Name) {
60-
if (auto Sym = optimizeLayer_.findSymbol(Name, false))
57+
if (auto Sym = compileLayer_.findSymbol(Name, false))
6158
return Sym;
6259
return JITSymbol(nullptr);
6360
},
@@ -67,7 +64,7 @@ Jit::ModuleHandle Jit::addModule(std::unique_ptr<Module> M) {
6764
return JITSymbol(nullptr);
6865
});
6966

70-
auto res = optimizeLayer_.addModule(std::move(M), std::move(Resolver));
67+
auto res = compileLayer_.addModule(std::move(M), std::move(Resolver));
7168
CHECK(res) << "Failed to jit compile.";
7269
return *res;
7370
}
@@ -76,7 +73,7 @@ JITSymbol Jit::findSymbol(const std::string Name) {
7673
std::string MangledName;
7774
raw_string_ostream MangledNameStream(MangledName);
7875
Mangler::getNameWithPrefix(MangledNameStream, Name, DL_);
79-
return optimizeLayer_.findSymbol(MangledNameStream.str(), true);
76+
return compileLayer_.findSymbol(MangledNameStream.str(), true);
8077
}
8178

8279
JITTargetAddress Jit::getSymbolAddress(const std::string Name) {
@@ -124,5 +121,3 @@ std::shared_ptr<Module> Jit::optimizeModule(std::shared_ptr<Module> M) {
124121

125122
return M;
126123
}
127-
128-
} // namespace tc

0 commit comments

Comments
 (0)