Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 9647fac

Browse files
Asm emission for the proper CPU target
This commit passes proper llvm::TargetMachine information in llvm_jit and codegen_llvm by introducing a proper TargetMachine at the LLVMJit level and avoids introducing adhoc objects. The TargetMachine is constructed either from the `--mcpu` flag if passed or from the `cpuid` information. As a consequence of all this, one can now emit AVX2 and AVX512 code. Before this commit, the TargetMachine was essentially a default one and only AVX code would be generated. To test and see it one can run with: ``` cd build && \ make -j 16 test_mapper_llvm && \ ./test/test_mapper_llvm --logtostderr=1 --llvm_dump_asm=1 --llvm_dump_after_opt=1 --llvm_dump_before_opt=1 --gtest_filter="*Batch*" --mcpu=skylake ``` Of course if one forces a more fancy architecture than one has, illegal instructions will likely be generated but at least the asm will be printed properly.
1 parent ccd3662 commit 9647fac

File tree

3 files changed

+62
-28
lines changed

3 files changed

+62
-28
lines changed

tc/core/polyhedral/codegen_llvm.cc

Lines changed: 57 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "llvm/IR/IRBuilder.h"
2828
#include "llvm/IR/LegacyPassManager.h"
2929
#include "llvm/IR/Verifier.h"
30+
#include "llvm/Support/TargetRegistry.h"
3031
#include "llvm/Support/TargetSelect.h"
3132
#include "llvm/Support/raw_ostream.h"
3233
#include "llvm/Transforms/IPO.h"
@@ -99,6 +100,32 @@ std::vector<int64_t> getTensorSizesWithoutLeadingDim(
99100
return sizes;
100101
}
101102

103+
// Set some options, grabbed from Halide + we force fast math atm
104+
static llvm::TargetOptions makeTargetOptions() {
105+
bool use_soft_float_abi = false;
106+
bool per_instruction_fast_math_flags = true;
107+
108+
llvm::TargetOptions options;
109+
options.AllowFPOpFusion = per_instruction_fast_math_flags
110+
? llvm::FPOpFusion::Strict
111+
: llvm::FPOpFusion::Fast;
112+
options.UnsafeFPMath = !per_instruction_fast_math_flags;
113+
options.NoInfsFPMath = !per_instruction_fast_math_flags;
114+
options.NoNaNsFPMath = !per_instruction_fast_math_flags;
115+
options.HonorSignDependentRoundingFPMathOption =
116+
!per_instruction_fast_math_flags;
117+
options.NoZerosInBSS = false;
118+
options.GuaranteedTailCallOpt = false;
119+
options.StackAlignmentOverride = 0;
120+
options.FunctionSections = true;
121+
options.UseInitArray = false;
122+
options.FloatABIType =
123+
use_soft_float_abi ? llvm::FloatABI::Soft : llvm::FloatABI::Hard;
124+
options.RelaxELFRelocations = false;
125+
126+
return options;
127+
}
128+
102129
static constexpr int kOptLevel = 3;
103130

104131
class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
@@ -116,6 +143,7 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
116143
const char* llvm_args[] = {"tc (LLVM argument parsing)", nullptr};
117144
llvm::cl::ParseCommandLineOptions(
118145
sizeof(llvm_args) / sizeof(*llvm_args) - 1, llvm_args);
146+
119147
init_context();
120148
module =
121149
llvm::make_unique<llvm::Module>("TensorComprehensionsModule", *context);
@@ -198,33 +226,35 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
198226
}
199227

200228
public:
201-
void optimize_module() {
229+
void optimize_module(const llvm::TargetMachine& targetMachine) {
202230
LOG_IF(INFO, FLAGS_llvm_dump_before_opt)
203231
<< "[LLVM-IR] Before optimization:\n"
204232
<< toString(module.get());
205233

206-
llvm::legacy::FunctionPassManager functionPassManager(module.get());
207-
llvm::legacy::PassManager modulePassManager;
234+
std::unique_ptr<llvm::TargetMachine> targetMachineWithOptions(
235+
targetMachine.getTarget().createTargetMachine(
236+
targetMachine.getTargetTriple().str(),
237+
targetMachine.getTargetCPU(),
238+
targetMachine.getTargetFeatureString(),
239+
makeTargetOptions(),
240+
llvm::Reloc::PIC_,
241+
llvm::CodeModel::Small,
242+
llvm::CodeGenOpt::Aggressive));
208243

209-
std::unique_ptr<llvm::TargetMachine> targetMachine =
210-
Halide::Internal::make_target_machine(*module);
244+
llvm::legacy::PassManager modulePassManager;
211245
modulePassManager.add(llvm::createTargetTransformInfoWrapperPass(
212-
targetMachine ? targetMachine->getTargetIRAnalysis()
213-
: llvm::TargetIRAnalysis()));
246+
targetMachineWithOptions->getTargetIRAnalysis()));
247+
248+
llvm::legacy::FunctionPassManager functionPassManager(module.get());
214249
functionPassManager.add(llvm::createTargetTransformInfoWrapperPass(
215-
targetMachine ? targetMachine->getTargetIRAnalysis()
216-
: llvm::TargetIRAnalysis()));
250+
targetMachineWithOptions->getTargetIRAnalysis()));
217251

218252
llvm::PassManagerBuilder b;
219253
b.OptLevel = kOptLevel;
220254
b.Inliner = llvm::createFunctionInliningPass(b.OptLevel, 0, false);
221255
b.LoopVectorize = true;
222256
b.SLPVectorize = true;
223-
224-
if (targetMachine) {
225-
targetMachine->adjustPassManager(b);
226-
}
227-
257+
targetMachineWithOptions->adjustPassManager(b);
228258
b.populateFunctionPassManager(functionPassManager);
229259
b.populateModulePassManager(modulePassManager);
230260

@@ -233,7 +263,6 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
233263
for (llvm::Module::iterator i = module->begin(); i != module->end(); i++) {
234264
functionPassManager.run(*i);
235265
}
236-
237266
functionPassManager.doFinalization();
238267
modulePassManager.run(*module);
239268

@@ -329,10 +358,12 @@ class LLVMCodegen {
329358
LLVMCodegen(
330359
const Scop& scop,
331360
const IteratorMapsType& iteratorMaps,
332-
const StmtSubscriptExprMapType& stmtSubscripts)
361+
const StmtSubscriptExprMapType& stmtSubscripts,
362+
const llvm::TargetMachine& targetMachine)
333363
: scop_(scop),
334364
iteratorMaps_(iteratorMaps),
335365
stmtSubscripts_(stmtSubscripts),
366+
targetMachine(targetMachine),
336367
halide_cg(Halide::Target(
337368
Halide::Target::OSUnknown,
338369
Halide::Target::X86,
@@ -558,6 +589,7 @@ class LLVMCodegen {
558589
std::vector<std::string> argNames_;
559590

560591
public:
592+
const llvm::TargetMachine& targetMachine;
561593
CodeGen_TC halide_cg;
562594
};
563595

@@ -601,7 +633,7 @@ isl::ast_node collectIteratorMaps(
601633
return node.set_annotation(stmtId);
602634
}
603635

604-
IslCodegenRes codegenISL(const Scop& scop) {
636+
static IslCodegenRes codegenISL(const Scop& scop) {
605637
IteratorMapsType iteratorMaps;
606638
StmtSubscriptExprMapType stmtSubscripts;
607639
auto collect = [&iteratorMaps, &scop, &stmtSubscripts](
@@ -625,10 +657,10 @@ IslCodegenRes codegenISL(const Scop& scop) {
625657
std::unique_ptr<llvm::Module> emitLLVMKernel(
626658
const std::string& specializedName,
627659
const Scop& scop,
628-
const llvm::DataLayout& dataLayout) {
660+
const llvm::TargetMachine& targetMachine) {
629661
auto islCg = codegenISL(scop);
630-
LLVMCodegen cg(scop, islCg.iteratorMaps, islCg.stmtSubscripts);
631-
cg.halide_cg.get_module()->setDataLayout(dataLayout);
662+
LLVMCodegen cg(scop, islCg.iteratorMaps, islCg.stmtSubscripts, targetMachine);
663+
cg.halide_cg.get_module()->setDataLayout(targetMachine.createDataLayout());
632664
cg.halide_cg.get_module()->setTargetTriple(
633665
llvm::EngineBuilder().selectTarget()->getTargetTriple().str());
634666
auto entry = cg.createSignature(
@@ -643,7 +675,7 @@ std::unique_ptr<llvm::Module> emitLLVMKernel(
643675
TC_CHECK(!llvm::verifyModule(*cg.halide_cg.get_module()))
644676
<< "LLVM generated module is invalid." << cg.str().c_str();
645677

646-
cg.halide_cg.optimize_module();
678+
cg.halide_cg.optimize_module(cg.targetMachine);
647679
if (FLAGS_llvm_dump_asm) {
648680
std::string pat("/tmp/tcXXXXXX");
649681
std::vector<char> ifn(pat.begin(), pat.end());
@@ -662,7 +694,10 @@ std::unique_ptr<llvm::Module> emitLLVMKernel(
662694
}
663695
utils::checkedSystemCall(
664696
std::string(TC_STRINGIFY(TC_LLVM_BIN_DIR)) + "/llc",
665-
{FLAGS_llvm_dump_asm_options, utils::CPUID::llcFlags(), optFile, std::string("-o ") + asmFile});
697+
{FLAGS_llvm_dump_asm_options,
698+
utils::CPUID::llcFlags(),
699+
optFile,
700+
std::string("-o ") + asmFile});
666701
{
667702
std::ifstream is(asmFile);
668703
std::string str(
@@ -671,7 +706,6 @@ std::unique_ptr<llvm::Module> emitLLVMKernel(
671706
LOG(INFO) << str;
672707
}
673708
}
674-
675709
return cg.halide_cg.move_module();
676710
}
677711

tc/core/polyhedral/codegen_llvm.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ struct Scop;
5555
std::unique_ptr<llvm::Module> emitLLVMKernel(
5656
const std::string& specializedName,
5757
const Scop& scop,
58-
const llvm::DataLayout& dataLayout);
58+
const llvm::TargetMachine& targetMachine);
5959

6060
// TODO: I want to do something like the following, but compilation was unhappy
6161
// using initialize_llvm = Halide::Internal::CodeGen_LLVM::initialize_llvm;

tc/core/polyhedral/llvm_jit.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,11 @@
3131
#include "tc/core/check.h"
3232
#include "tc/core/flags.h"
3333
#include "tc/core/polyhedral/codegen_llvm.h"
34+
#include "tc/core/utils/cpu.h"
3435

3536
using namespace llvm;
3637

3738
namespace tc {
38-
3939
Jit::Jit()
4040
: ES(),
4141
Resolver(llvm::orc::createLegacyLookupResolver(
@@ -51,7 +51,7 @@ Jit::Jit()
5151
return nullptr;
5252
},
5353
[](Error err) { throw std::runtime_error("Lookup failed!"); })),
54-
TM_(EngineBuilder().selectTarget()),
54+
TM_(EngineBuilder().setMCPU(utils::CPUID::mcpu()).selectTarget()),
5555
DL_(TM_->createDataLayout()),
5656
objectLayer_(
5757
ES,
@@ -71,8 +71,8 @@ void Jit::addModule(std::shared_ptr<Module> M) {
7171
std::shared_ptr<Module> Jit::codegenScop(
7272
const std::string& specializedName,
7373
const polyhedral::Scop& scop) {
74-
std::shared_ptr<Module> mod = emitLLVMKernel(
75-
specializedName, scop, getTargetMachine().createDataLayout());
74+
std::shared_ptr<Module> mod =
75+
emitLLVMKernel(specializedName, scop, getTargetMachine());
7676
addModule(mod);
7777
return mod;
7878
}

0 commit comments

Comments
 (0)