Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit c6e1989

Browse files
Refactor LLVMCodegen
This commit avoids leaking all the guts of LLVMCodegen to the emitLLVMKernel function and restructures some code.
1 parent 9647fac commit c6e1989

File tree

1 file changed

+136
-141
lines changed

1 file changed

+136
-141
lines changed

tc/core/polyhedral/codegen_llvm.cc

Lines changed: 136 additions & 141 deletions
Original file line number | Diff line number | Diff line change
@@ -55,17 +55,73 @@
5555
using namespace Halide;
5656

5757
namespace tc {
58-
5958
namespace polyhedral {
60-
59+
namespace {
6160
using IteratorMapType = std::unordered_map<std::string, isl::ast_expr>;
6261
using IteratorMapsType =
6362
std::unordered_map<isl::id, IteratorMapType, isl::IslIdIslHash>;
6463

6564
using StmtSubscriptExprMapType =
6665
std::unordered_map<isl::id, std::vector<isl::ast_expr>, isl::IslIdIslHash>;
6766

68-
namespace {
67+
struct IslCodegenRes {
68+
IteratorMapsType iteratorMaps;
69+
StmtSubscriptExprMapType stmtSubscripts;
70+
isl::ast_node astNode;
71+
};
72+
73+
isl::ast_node collectIteratorMaps(
74+
isl::ast_node node,
75+
isl::ast_build build,
76+
IteratorMapsType& iteratorMaps,
77+
const Scop& scop,
78+
StmtSubscriptExprMapType& stmtSubscripts) {
79+
auto user = node.as<isl::ast_node_user>();
80+
TC_CHECK(user);
81+
auto expr = user.get_expr().as<isl::ast_expr_op>();
82+
auto schedule = build.get_schedule();
83+
auto scheduleMap = isl::map::from_union_map(schedule);
84+
85+
auto stmtId = expr.get_arg(0).as<isl::ast_expr_id>().get_id();
86+
TC_CHECK_EQ(0u, iteratorMaps.count(stmtId)) << "entry exists: " << stmtId;
87+
auto iteratorMap = isl::pw_multi_aff(scheduleMap.reverse());
88+
auto tuple = scop.halide.domains.at(stmtId).tuple;
89+
auto& stmtIteratorMap = iteratorMaps[stmtId];
90+
for (int i = 0; i < tuple.size(); ++i) {
91+
auto expr = build.expr_from(iteratorMap.get_pw_aff(i));
92+
stmtIteratorMap.emplace(tuple.get_id(i).get_name(), expr);
93+
}
94+
auto& subscripts = stmtSubscripts[stmtId];
95+
auto provide =
96+
scop.halide.statements.at(stmtId).as<Halide::Internal::Provide>();
97+
for (auto e : provide->args) {
98+
const auto& map = iteratorMap;
99+
auto aff = scop.makeIslAffFromStmtExpr(stmtId, e);
100+
auto pulled = isl::pw_aff(aff).pullback(map);
101+
TC_CHECK_EQ(pulled.n_piece(), 1);
102+
subscripts.push_back(build.expr_from(pulled));
103+
}
104+
return node.set_annotation(stmtId);
105+
}
106+
107+
static IslCodegenRes codegenISL(const Scop& scop) {
108+
IteratorMapsType iteratorMaps;
109+
StmtSubscriptExprMapType stmtSubscripts;
110+
auto collect = [&iteratorMaps, &scop, &stmtSubscripts](
111+
isl::ast_node n, isl::ast_build b) -> isl::ast_node {
112+
auto& uv = iteratorMaps;
113+
return collectIteratorMaps(n, b, uv, scop, stmtSubscripts);
114+
};
115+
116+
auto schedule = detail::toIslSchedule(scop.scheduleRoot());
117+
auto astBuild = isl::ast_build(schedule.get_ctx());
118+
astBuild = astBuild.set_at_each_domain(collect);
119+
auto root = scop.scheduleRoot();
120+
astBuild = astBuild.set_iterators(Codegen::makeLoopIterators(root));
121+
auto astNode = astBuild.node_from(schedule);
122+
return {
123+
std::move(iteratorMaps), std::move(stmtSubscripts), std::move(astNode)};
124+
}
69125

70126
thread_local llvm::LLVMContext llvmCtx;
71127

@@ -324,6 +380,71 @@ Halide::Expr CodeGen_TC::makeHalideExpr(isl::ast_expr expr) {
324380
}
325381

326382
class LLVMCodegen {
383+
public:
384+
LLVMCodegen(
385+
const std::string& specializedName,
386+
const Scop& scop,
387+
const llvm::TargetMachine& targetMachine)
388+
: scop_(scop),
389+
islCg_(codegenISL(scop_)),
390+
iteratorMaps_(islCg_.iteratorMaps),
391+
stmtSubscripts_(islCg_.stmtSubscripts),
392+
targetMachine(targetMachine),
393+
// we don't use Halide to tinker with llvm::Module optimization, so
394+
// the Halide target can be whatever.
395+
halide_cg(Halide::get_host_target()) {
396+
halide_cg.set_context(llvmCtx);
397+
halide_cg.init_module();
398+
halide_cg.get_module()->setDataLayout(targetMachine.createDataLayout());
399+
halide_cg.get_module()->setTargetTriple(
400+
targetMachine.getTargetTriple().str());
401+
auto entry = createSignature(
402+
scop.halide.inputs,
403+
scop.halide.outputs,
404+
scop.halide.params,
405+
specializedName);
406+
auto exit = emitAst(islCg_.astNode, entry);
407+
halide_cg.get_builder().SetInsertPoint(exit);
408+
halide_cg.get_builder().CreateRetVoid();
409+
410+
TC_CHECK(!llvm::verifyModule(*halide_cg.get_module()))
411+
<< "LLVM generated module is invalid." << str().c_str();
412+
413+
halide_cg.optimize_module(targetMachine);
414+
415+
if (FLAGS_llvm_dump_asm) {
416+
std::string pat("/tmp/tcXXXXXX");
417+
std::vector<char> ifn(pat.begin(), pat.end());
418+
TC_CHECK_GE(mkstemp(ifn.data()), 0); // string.c_str is const char*
419+
std::string fileName(ifn.begin(), ifn.end());
420+
std::string optFile = fileName + "-opt.ll";
421+
std::string asmFile = fileName + ".s";
422+
// cstdio's std::remove to delete files
423+
tc::ScopeGuard sgi([&]() {
424+
std::remove(optFile.c_str());
425+
std::remove(asmFile.c_str());
426+
});
427+
{
428+
std::ofstream ostream(optFile, std::ios::binary);
429+
ostream << str();
430+
}
431+
utils::checkedSystemCall(
432+
std::string(TC_STRINGIFY(TC_LLVM_BIN_DIR)) + "/llc",
433+
{FLAGS_llvm_dump_asm_options,
434+
utils::CPUID::llcFlags(),
435+
optFile,
436+
std::string("-o ") + asmFile});
437+
438+
std::ifstream is(asmFile);
439+
std::string str(
440+
(std::istreambuf_iterator<char>(is)),
441+
std::istreambuf_iterator<char>());
442+
LOG(INFO) << "Dumping asm for: " << utils::CPUID::llcFlags() << "\n"
443+
<< str;
444+
}
445+
}
446+
447+
private:
327448
void collectTensor(const Halide::OutputImageParam& t) {
328449
auto sizes = getTensorSizesWithoutLeadingDim(t, scop_.context());
329450
if (not sizes.empty()) {
@@ -354,23 +475,16 @@ class LLVMCodegen {
354475
}
355476
}
356477

357-
public:
358-
LLVMCodegen(
359-
const Scop& scop,
360-
const IteratorMapsType& iteratorMaps,
361-
const StmtSubscriptExprMapType& stmtSubscripts,
362-
const llvm::TargetMachine& targetMachine)
363-
: scop_(scop),
364-
iteratorMaps_(iteratorMaps),
365-
stmtSubscripts_(stmtSubscripts),
366-
targetMachine(targetMachine),
367-
halide_cg(Halide::Target(
368-
Halide::Target::OSUnknown,
369-
Halide::Target::X86,
370-
64)) {
371-
halide_cg.set_context(llvmCtx);
372-
373-
halide_cg.init_module();
478+
llvm::Type* makePtrToArrayType(
479+
llvm::Type* baseTy,
480+
const std::vector<int64_t>& sizes) {
481+
TC_CHECK_GE(sizes.size(), 1u);
482+
TC_CHECK(baseTy);
483+
llvm::Type* arrTy = llvm::ArrayType::get(baseTy, sizes.back());
484+
for (auto s = sizes.rbegin() + 1; s != sizes.rend(); ++s) {
485+
arrTy = llvm::ArrayType::get(arrTy, *s);
486+
}
487+
return arrTy->getPointerTo();
374488
}
375489

376490
// This creates a signature of the form:
@@ -451,19 +565,6 @@ class LLVMCodegen {
451565
return nullptr;
452566
}
453567

454-
private:
455-
llvm::Type* makePtrToArrayType(
456-
llvm::Type* baseTy,
457-
const std::vector<int64_t>& sizes) {
458-
TC_CHECK_GE(sizes.size(), 1u);
459-
TC_CHECK(baseTy);
460-
llvm::Type* arrTy = llvm::ArrayType::get(baseTy, sizes.back());
461-
for (auto s = sizes.rbegin() + 1; s != sizes.rend(); ++s) {
462-
arrTy = llvm::ArrayType::get(arrTy, *s);
463-
}
464-
return arrTy->getPointerTo();
465-
}
466-
467568
llvm::BasicBlock* emitIf(
468569
isl::ast_node_if node,
469570
llvm::BasicBlock* insertionPoint) {
@@ -582,6 +683,7 @@ class LLVMCodegen {
582683

583684
private:
584685
const Scop& scop_;
686+
const IslCodegenRes islCg_;
585687
const IteratorMapsType& iteratorMaps_;
586688
const StmtSubscriptExprMapType& stmtSubscripts_;
587689

@@ -592,120 +694,13 @@ class LLVMCodegen {
592694
const llvm::TargetMachine& targetMachine;
593695
CodeGen_TC halide_cg;
594696
};
595-
596-
struct IslCodegenRes {
597-
IteratorMapsType iteratorMaps;
598-
StmtSubscriptExprMapType stmtSubscripts;
599-
isl::ast_node astNode;
600-
};
601-
602-
isl::ast_node collectIteratorMaps(
603-
isl::ast_node node,
604-
isl::ast_build build,
605-
IteratorMapsType& iteratorMaps,
606-
const Scop& scop,
607-
StmtSubscriptExprMapType& stmtSubscripts) {
608-
auto user = node.as<isl::ast_node_user>();
609-
TC_CHECK(user);
610-
auto expr = user.get_expr().as<isl::ast_expr_op>();
611-
auto schedule = build.get_schedule();
612-
auto scheduleMap = isl::map::from_union_map(schedule);
613-
614-
auto stmtId = expr.get_arg(0).as<isl::ast_expr_id>().get_id();
615-
TC_CHECK_EQ(0u, iteratorMaps.count(stmtId)) << "entry exists: " << stmtId;
616-
auto iteratorMap = isl::pw_multi_aff(scheduleMap.reverse());
617-
auto tuple = scop.halide.domains.at(stmtId).tuple;
618-
auto& stmtIteratorMap = iteratorMaps[stmtId];
619-
for (int i = 0; i < tuple.size(); ++i) {
620-
auto expr = build.expr_from(iteratorMap.get_pw_aff(i));
621-
stmtIteratorMap.emplace(tuple.get_id(i).get_name(), expr);
622-
}
623-
auto& subscripts = stmtSubscripts[stmtId];
624-
auto provide =
625-
scop.halide.statements.at(stmtId).as<Halide::Internal::Provide>();
626-
for (auto e : provide->args) {
627-
const auto& map = iteratorMap;
628-
auto aff = scop.makeIslAffFromStmtExpr(stmtId, e);
629-
auto pulled = isl::pw_aff(aff).pullback(map);
630-
TC_CHECK_EQ(pulled.n_piece(), 1);
631-
subscripts.push_back(build.expr_from(pulled));
632-
}
633-
return node.set_annotation(stmtId);
634-
}
635-
636-
static IslCodegenRes codegenISL(const Scop& scop) {
637-
IteratorMapsType iteratorMaps;
638-
StmtSubscriptExprMapType stmtSubscripts;
639-
auto collect = [&iteratorMaps, &scop, &stmtSubscripts](
640-
isl::ast_node n, isl::ast_build b) -> isl::ast_node {
641-
auto& uv = iteratorMaps;
642-
return collectIteratorMaps(n, b, uv, scop, stmtSubscripts);
643-
};
644-
645-
auto schedule = detail::toIslSchedule(scop.scheduleRoot());
646-
auto astBuild = isl::ast_build(schedule.get_ctx());
647-
astBuild = astBuild.set_at_each_domain(collect);
648-
auto root = scop.scheduleRoot();
649-
astBuild = astBuild.set_iterators(Codegen::makeLoopIterators(root));
650-
auto astNode = astBuild.node_from(schedule);
651-
return {
652-
std::move(iteratorMaps), std::move(stmtSubscripts), std::move(astNode)};
653-
}
654-
655697
} // namespace
656698

657699
std::unique_ptr<llvm::Module> emitLLVMKernel(
658700
const std::string& specializedName,
659701
const Scop& scop,
660702
const llvm::TargetMachine& targetMachine) {
661-
auto islCg = codegenISL(scop);
662-
LLVMCodegen cg(scop, islCg.iteratorMaps, islCg.stmtSubscripts, targetMachine);
663-
cg.halide_cg.get_module()->setDataLayout(targetMachine.createDataLayout());
664-
cg.halide_cg.get_module()->setTargetTriple(
665-
llvm::EngineBuilder().selectTarget()->getTargetTriple().str());
666-
auto entry = cg.createSignature(
667-
scop.halide.inputs,
668-
scop.halide.outputs,
669-
scop.halide.params,
670-
specializedName);
671-
auto exit = cg.emitAst(islCg.astNode, entry);
672-
cg.halide_cg.get_builder().SetInsertPoint(exit);
673-
cg.halide_cg.get_builder().CreateRetVoid();
674-
675-
TC_CHECK(!llvm::verifyModule(*cg.halide_cg.get_module()))
676-
<< "LLVM generated module is invalid." << cg.str().c_str();
677-
678-
cg.halide_cg.optimize_module(cg.targetMachine);
679-
if (FLAGS_llvm_dump_asm) {
680-
std::string pat("/tmp/tcXXXXXX");
681-
std::vector<char> ifn(pat.begin(), pat.end());
682-
TC_CHECK_GE(mkstemp(ifn.data()), 0); // string.c_str is const char*
683-
std::string fileName(ifn.begin(), ifn.end());
684-
std::string optFile = fileName + "-opt.ll";
685-
std::string asmFile = fileName + ".s";
686-
// cstdio's std::remove to delete files
687-
tc::ScopeGuard sgi([&]() {
688-
std::remove(optFile.c_str());
689-
std::remove(asmFile.c_str());
690-
});
691-
{
692-
std::ofstream ostream(optFile, std::ios::binary);
693-
ostream << cg.str();
694-
}
695-
utils::checkedSystemCall(
696-
std::string(TC_STRINGIFY(TC_LLVM_BIN_DIR)) + "/llc",
697-
{FLAGS_llvm_dump_asm_options,
698-
utils::CPUID::llcFlags(),
699-
optFile,
700-
std::string("-o ") + asmFile});
701-
{
702-
std::ifstream is(asmFile);
703-
std::string str(
704-
(std::istreambuf_iterator<char>(is)),
705-
std::istreambuf_iterator<char>());
706-
LOG(INFO) << str;
707-
}
708-
}
703+
LLVMCodegen cg(specializedName, scop, targetMachine);
709704
return cg.halide_cg.move_module();
710705
}
711706

0 commit comments

Comments
 (0)