Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit d12ed79

Browse files
authored
Merge pull request #30 from facebookresearch/pcpu
Permit parallel CPU backend compilation (though currently disabled)
2 parents 1ef8af5 + 675a42d commit d12ed79

File tree

4 files changed

+144
-61
lines changed

4 files changed

+144
-61
lines changed

include/tc/core/polyhedral/llvm_jit.h

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,6 @@ class Jit {
3434
llvm::orc::RTDyldObjectLinkingLayer objectLayer_;
3535
llvm::orc::IRCompileLayer<decltype(objectLayer_), llvm::orc::SimpleCompiler>
3636
compileLayer_;
37-
using OptimizeFunction = std::function<std::shared_ptr<llvm::Module>(
38-
std::shared_ptr<llvm::Module>)>;
39-
40-
llvm::orc::IRTransformLayer<decltype(compileLayer_), OptimizeFunction>
41-
optimizeLayer_;
4237

4338
public:
4439
Jit();
@@ -47,17 +42,14 @@ class Jit {
4742
const std::string& specializedName,
4843
const polyhedral::Scop& scop);
4944

50-
using ModuleHandle = decltype(optimizeLayer_)::ModuleHandleT;
45+
using ModuleHandle = decltype(compileLayer_)::ModuleHandleT;
5146
ModuleHandle addModule(std::unique_ptr<llvm::Module> M);
5247
void removeModule(ModuleHandle H);
5348

5449
llvm::JITSymbol findSymbol(const std::string name);
5550
llvm::JITTargetAddress getSymbolAddress(const std::string name);
5651

5752
llvm::TargetMachine& getTargetMachine();
58-
59-
private:
60-
std::shared_ptr<llvm::Module> optimizeModule(std::shared_ptr<llvm::Module> M);
6153
};
6254

6355
} // namespace tc

src/core/polyhedral/codegen_llvm.cc

Lines changed: 94 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,17 @@
2020

2121
#include "llvm/ADT/STLExtras.h"
2222
#include "llvm/ADT/SmallVector.h"
23+
#include "llvm/Analysis/TargetTransformInfo.h"
24+
#include "llvm/Config/llvm-config.h"
2325
#include "llvm/ExecutionEngine/ExecutionEngine.h"
2426
#include "llvm/IR/IRBuilder.h"
27+
#include "llvm/IR/LegacyPassManager.h"
2528
#include "llvm/IR/Verifier.h"
2629
#include "llvm/Support/TargetSelect.h"
2730
#include "llvm/Support/raw_ostream.h"
31+
#include "llvm/Transforms/IPO.h"
32+
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
33+
#include "llvm/Transforms/Tapir/CilkABI.h"
2834

2935
#include "Halide/Halide.h"
3036

@@ -37,6 +43,10 @@
3743
#include "tc/core/polyhedral/scop.h"
3844
#include "tc/core/scope_guard.h"
3945

46+
#ifndef LLVM_VERSION_MAJOR
47+
#error LLVM_VERSION_MAJOR not set
48+
#endif
49+
4050
using namespace Halide;
4151

4252
namespace tc {
@@ -195,14 +205,17 @@ class IslAstExprInterpeter {
195205
}
196206
};
197207

208+
DEFINE_bool(llvm_dump_before_opt, false, "Print IR before optimization");
209+
DEFINE_bool(llvm_dump_after_opt, false, "Print IR after optimization");
210+
static constexpr int kOptLevel = 3;
211+
198212
class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
199213
public:
200214
const isl::pw_multi_aff* iteratorMap_;
201215
CodeGen_TC(Target t) : CodeGen_X86(t) {}
202216

203217
using CodeGen_X86::codegen;
204218
using CodeGen_X86::llvm_type_of;
205-
using CodeGen_X86::optimize_module;
206219
using CodeGen_X86::sym_get;
207220
using CodeGen_X86::sym_pop;
208221
using CodeGen_X86::sym_push;
@@ -294,6 +307,52 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
294307

295308
value = sym_get(name);
296309
}
310+
311+
public:
312+
void optimize_module() {
313+
if (FLAGS_llvm_dump_before_opt) {
314+
module->print(llvm::dbgs(), nullptr, false, true);
315+
}
316+
317+
llvm::legacy::FunctionPassManager functionPassManager(module.get());
318+
llvm::legacy::PassManager modulePassManager;
319+
320+
std::unique_ptr<llvm::TargetMachine> targetMachine =
321+
Halide::Internal::make_target_machine(*module);
322+
modulePassManager.add(llvm::createTargetTransformInfoWrapperPass(
323+
targetMachine ? targetMachine->getTargetIRAnalysis()
324+
: llvm::TargetIRAnalysis()));
325+
functionPassManager.add(llvm::createTargetTransformInfoWrapperPass(
326+
targetMachine ? targetMachine->getTargetIRAnalysis()
327+
: llvm::TargetIRAnalysis()));
328+
329+
llvm::PassManagerBuilder b;
330+
b.OptLevel = kOptLevel;
331+
b.tapirTarget = new llvm::CilkABI();
332+
b.Inliner = llvm::createFunctionInliningPass(b.OptLevel, 0, false);
333+
b.LoopVectorize = true;
334+
b.SLPVectorize = true;
335+
336+
if (targetMachine) {
337+
targetMachine->adjustPassManager(b);
338+
}
339+
340+
b.populateFunctionPassManager(functionPassManager);
341+
b.populateModulePassManager(modulePassManager);
342+
343+
// Run optimization passes
344+
functionPassManager.doInitialization();
345+
for (llvm::Module::iterator i = module->begin(); i != module->end(); i++) {
346+
functionPassManager.run(*i);
347+
348+
functionPassManager.doFinalization();
349+
modulePassManager.run(*module);
350+
351+
if (FLAGS_llvm_dump_after_opt) {
352+
module->print(llvm::dbgs(), nullptr, false, true);
353+
}
354+
}
355+
}
297356
};
298357

299358
class LLVMCodegen {
@@ -451,6 +510,19 @@ class LLVMCodegen {
451510
llvm::BasicBlock::Create(llvmCtx, "loop_latch", function);
452511
auto* loopExitBB = llvm::BasicBlock::Create(llvmCtx, "loop_exit", function);
453512

513+
// TODO: integrate query ISL as to whether the relevant loop ought be
514+
// parallelized
515+
bool parallel = false;
516+
517+
llvm::Value* SyncRegion = nullptr;
518+
if (parallel) {
519+
SyncRegion = halide_cg.get_builder().CreateCall(
520+
llvm::Intrinsic::getDeclaration(
521+
function->getParent(), llvm::Intrinsic::syncregion_start),
522+
{},
523+
"syncreg");
524+
}
525+
454526
halide_cg.get_builder().CreateBr(headerBB);
455527

456528
llvm::PHINode* phi = nullptr;
@@ -498,9 +570,22 @@ class LLVMCodegen {
498570
// Create Body
499571
{
500572
halide_cg.get_builder().SetInsertPoint(loopBodyBB);
573+
574+
if (parallel) {
575+
auto* detachedBB =
576+
llvm::BasicBlock::Create(llvmCtx, "det.achd", function);
577+
halide_cg.get_builder().CreateDetach(
578+
detachedBB, loopLatchBB, SyncRegion);
579+
halide_cg.get_builder().SetInsertPoint(detachedBB);
580+
}
501581
auto* currentBB = emitAst(node.for_get_body());
502582
halide_cg.get_builder().SetInsertPoint(currentBB);
503-
halide_cg.get_builder().CreateBr(loopLatchBB);
583+
584+
if (parallel) {
585+
halide_cg.get_builder().CreateReattach(loopLatchBB, SyncRegion);
586+
} else {
587+
halide_cg.get_builder().CreateBr(loopLatchBB);
588+
}
504589
}
505590

506591
// Create Latch
@@ -516,6 +601,11 @@ class LLVMCodegen {
516601

517602
halide_cg.get_builder().SetInsertPoint(loopExitBB);
518603
halide_cg.sym_pop(node.for_get_iterator().get_id().get_name());
604+
if (parallel) {
605+
auto* syncBB = llvm::BasicBlock::Create(llvmCtx, "synced", function);
606+
halide_cg.get_builder().CreateSync(syncBB, SyncRegion);
607+
halide_cg.get_builder().SetInsertPoint(syncBB);
608+
}
519609
return halide_cg.get_builder().GetInsertBlock();
520610
}
521611

@@ -583,7 +673,8 @@ class LLVMCodegen {
583673
CodeGen_TC halide_cg;
584674
};
585675

586-
// Create a list of isl ids to be used as loop iterators when building the AST.
676+
// Create a list of isl ids to be used as loop iterators when building the
677+
// AST.
587678
//
588679
// Note that this function can be scrapped as ISL can generate some default
589680
// iterator names. However, it may come handy for associating extra info with

src/core/polyhedral/llvm_jit.cc

Lines changed: 48 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16+
#include <stdexcept>
17+
1618
#include "tc/core/polyhedral/llvm_jit.h"
1719

1820
#include "llvm/ExecutionEngine/ExecutionEngine.h"
@@ -29,17 +31,55 @@
2931

3032
using namespace llvm;
3133

34+
// Parse through ldconfig to find the path of a particular
35+
// shared library. This is an unfortunate way to have to
36+
// find it, but I couldn't immediately find something in
37+
// imported libraries that would resolve this for us.
38+
std::string find_library_path(std::string library) {
39+
std::string command = "ldconfig -p | grep " + library;
40+
41+
FILE* fpipe = popen(command.c_str(), "r");
42+
43+
if (fpipe == nullptr) {
44+
throw std::runtime_error("Failed to popen()");
45+
}
46+
47+
std::string output;
48+
char buffer[512];
49+
50+
while (1) {
51+
int charactersRead = fread(buffer, 1, sizeof(buffer), fpipe);
52+
if (charactersRead == 0)
53+
break;
54+
output += std::string(buffer, charactersRead);
55+
}
56+
pclose(fpipe);
57+
58+
int idx = output.rfind("=> ");
59+
if (idx == std::string::npos) {
60+
throw std::runtime_error("Failed locate library: " + library);
61+
}
62+
output = output.substr(idx + 3);
63+
if (output.length() > 0 && output[output.length() - 1] == '\n') {
64+
output = output.substr(0, output.length() - 1);
65+
}
66+
return output;
67+
}
68+
3269
namespace tc {
3370

3471
Jit::Jit()
3572
: TM_(EngineBuilder().selectTarget()),
3673
DL_(TM_->createDataLayout()),
3774
objectLayer_([]() { return std::make_shared<SectionMemoryManager>(); }),
38-
compileLayer_(objectLayer_, orc::SimpleCompiler(*TM_)),
39-
optimizeLayer_(compileLayer_, [this](std::shared_ptr<Module> M) {
40-
return optimizeModule(std::move(M));
41-
}) {
42-
sys::DynamicLibrary::LoadLibraryPermanently(nullptr);
75+
compileLayer_(objectLayer_, orc::SimpleCompiler(*TM_)) {
76+
std::string err;
77+
78+
auto path = find_library_path("libcilkrts.so");
79+
sys::DynamicLibrary::LoadLibraryPermanently(path.c_str(), &err);
80+
if (err != "") {
81+
throw std::runtime_error("Failed to find cilkrts: " + err);
82+
}
4383
}
4484

4585
void Jit::codegenScop(
@@ -57,7 +97,7 @@ Jit::ModuleHandle Jit::addModule(std::unique_ptr<Module> M) {
5797
M->setTargetTriple(TM_->getTargetTriple().str());
5898
auto Resolver = orc::createLambdaResolver(
5999
[&](const std::string& Name) {
60-
if (auto Sym = optimizeLayer_.findSymbol(Name, false))
100+
if (auto Sym = compileLayer_.findSymbol(Name, false))
61101
return Sym;
62102
return JITSymbol(nullptr);
63103
},
@@ -67,7 +107,7 @@ Jit::ModuleHandle Jit::addModule(std::unique_ptr<Module> M) {
67107
return JITSymbol(nullptr);
68108
});
69109

70-
auto res = optimizeLayer_.addModule(std::move(M), std::move(Resolver));
110+
auto res = compileLayer_.addModule(std::move(M), std::move(Resolver));
71111
CHECK(res) << "Failed to jit compile.";
72112
return *res;
73113
}
@@ -76,7 +116,7 @@ JITSymbol Jit::findSymbol(const std::string Name) {
76116
std::string MangledName;
77117
raw_string_ostream MangledNameStream(MangledName);
78118
Mangler::getNameWithPrefix(MangledNameStream, Name, DL_);
79-
return optimizeLayer_.findSymbol(MangledNameStream.str(), true);
119+
return compileLayer_.findSymbol(MangledNameStream.str(), true);
80120
}
81121

82122
JITTargetAddress Jit::getSymbolAddress(const std::string Name) {
@@ -85,44 +125,4 @@ JITTargetAddress Jit::getSymbolAddress(const std::string Name) {
85125
return *res;
86126
}
87127

88-
DEFINE_bool(llvm_no_opt, false, "Disable LLVM optimizations");
89-
DEFINE_bool(llvm_debug_passes, false, "Print pass debug info");
90-
DEFINE_bool(llvm_dump_optimized_ir, false, "Print optimized IR");
91-
92-
std::shared_ptr<Module> Jit::optimizeModule(std::shared_ptr<Module> M) {
93-
if (FLAGS_llvm_no_opt) {
94-
return M;
95-
}
96-
97-
PassBuilder PB(TM_.get());
98-
AAManager AA;
99-
CHECK(PB.parseAAPipeline(AA, "default"))
100-
<< "Unable to parse AA pipeline description.";
101-
LoopAnalysisManager LAM(FLAGS_llvm_debug_passes);
102-
FunctionAnalysisManager FAM(FLAGS_llvm_debug_passes);
103-
CGSCCAnalysisManager CGAM(FLAGS_llvm_debug_passes);
104-
ModuleAnalysisManager MAM(FLAGS_llvm_debug_passes);
105-
FAM.registerPass([&] { return std::move(AA); });
106-
PB.registerModuleAnalyses(MAM);
107-
PB.registerCGSCCAnalyses(CGAM);
108-
PB.registerFunctionAnalyses(FAM);
109-
PB.registerLoopAnalyses(LAM);
110-
PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
111-
112-
ModulePassManager MPM(FLAGS_llvm_debug_passes);
113-
MPM.addPass(VerifierPass());
114-
CHECK(PB.parsePassPipeline(MPM, "default<O3>", true, FLAGS_llvm_debug_passes))
115-
<< "Unable to parse pass pipline description.";
116-
MPM.addPass(VerifierPass());
117-
118-
MPM.run(*M, MAM);
119-
120-
if (FLAGS_llvm_dump_optimized_ir) {
121-
// M->dump(); // does not link
122-
M->print(llvm::errs(), nullptr);
123-
}
124-
125-
return M;
126-
}
127-
128128
} // namespace tc

third-party/halide

Submodule halide updated 75 files

0 commit comments

Comments
 (0)