27
27
#include " llvm/IR/IRBuilder.h"
28
28
#include " llvm/IR/LegacyPassManager.h"
29
29
#include " llvm/IR/Verifier.h"
30
+ #include " llvm/Support/TargetRegistry.h"
30
31
#include " llvm/Support/TargetSelect.h"
31
32
#include " llvm/Support/raw_ostream.h"
32
33
#include " llvm/Transforms/IPO.h"
@@ -99,6 +100,32 @@ std::vector<int64_t> getTensorSizesWithoutLeadingDim(
99
100
return sizes;
100
101
}
101
102
103
+ // Set some options, grabbed from Halide + we force fast math atm
104
+ static llvm::TargetOptions makeTargetOptions () {
105
+ bool use_soft_float_abi = false ;
106
+ bool per_instruction_fast_math_flags = true ;
107
+
108
+ llvm::TargetOptions options;
109
+ options.AllowFPOpFusion = per_instruction_fast_math_flags
110
+ ? llvm::FPOpFusion::Strict
111
+ : llvm::FPOpFusion::Fast;
112
+ options.UnsafeFPMath = !per_instruction_fast_math_flags;
113
+ options.NoInfsFPMath = !per_instruction_fast_math_flags;
114
+ options.NoNaNsFPMath = !per_instruction_fast_math_flags;
115
+ options.HonorSignDependentRoundingFPMathOption =
116
+ !per_instruction_fast_math_flags;
117
+ options.NoZerosInBSS = false ;
118
+ options.GuaranteedTailCallOpt = false ;
119
+ options.StackAlignmentOverride = 0 ;
120
+ options.FunctionSections = true ;
121
+ options.UseInitArray = false ;
122
+ options.FloatABIType =
123
+ use_soft_float_abi ? llvm::FloatABI::Soft : llvm::FloatABI::Hard;
124
+ options.RelaxELFRelocations = false ;
125
+
126
+ return options;
127
+ }
128
+
102
129
static constexpr int kOptLevel = 3 ;
103
130
104
131
class CodeGen_TC : public Halide ::Internal::CodeGen_X86 {
@@ -116,6 +143,7 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
116
143
const char * llvm_args[] = {" tc (LLVM argument parsing)" , nullptr };
117
144
llvm::cl::ParseCommandLineOptions (
118
145
sizeof (llvm_args) / sizeof (*llvm_args) - 1 , llvm_args);
146
+
119
147
init_context ();
120
148
module =
121
149
llvm::make_unique<llvm::Module>(" TensorComprehensionsModule" , *context);
@@ -198,33 +226,35 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
198
226
}
199
227
200
228
public:
201
- void optimize_module () {
229
+ void optimize_module (const llvm::TargetMachine& targetMachine ) {
202
230
LOG_IF (INFO, FLAGS_llvm_dump_before_opt)
203
231
<< " [LLVM-IR] Before optimization:\n "
204
232
<< toString (module .get ());
205
233
206
- llvm::legacy::FunctionPassManager functionPassManager (module .get ());
207
- llvm::legacy::PassManager modulePassManager;
234
+ std::unique_ptr<llvm::TargetMachine> targetMachineWithOptions (
235
+ targetMachine.getTarget ().createTargetMachine (
236
+ targetMachine.getTargetTriple ().str (),
237
+ targetMachine.getTargetCPU (),
238
+ targetMachine.getTargetFeatureString (),
239
+ makeTargetOptions (),
240
+ llvm::Reloc::PIC_,
241
+ llvm::CodeModel::Small,
242
+ llvm::CodeGenOpt::Aggressive));
208
243
209
- std::unique_ptr<llvm::TargetMachine> targetMachine =
210
- Halide::Internal::make_target_machine (*module );
244
+ llvm::legacy::PassManager modulePassManager;
211
245
modulePassManager.add (llvm::createTargetTransformInfoWrapperPass (
212
- targetMachine ? targetMachine->getTargetIRAnalysis ()
213
- : llvm::TargetIRAnalysis ()));
246
+ targetMachineWithOptions->getTargetIRAnalysis ()));
247
+
248
+ llvm::legacy::FunctionPassManager functionPassManager (module .get ());
214
249
functionPassManager.add (llvm::createTargetTransformInfoWrapperPass (
215
- targetMachine ? targetMachine->getTargetIRAnalysis ()
216
- : llvm::TargetIRAnalysis ()));
250
+ targetMachineWithOptions->getTargetIRAnalysis ()));
217
251
218
252
llvm::PassManagerBuilder b;
219
253
b.OptLevel = kOptLevel ;
220
254
b.Inliner = llvm::createFunctionInliningPass (b.OptLevel , 0 , false );
221
255
b.LoopVectorize = true ;
222
256
b.SLPVectorize = true ;
223
-
224
- if (targetMachine) {
225
- targetMachine->adjustPassManager (b);
226
- }
227
-
257
+ targetMachineWithOptions->adjustPassManager (b);
228
258
b.populateFunctionPassManager (functionPassManager);
229
259
b.populateModulePassManager (modulePassManager);
230
260
@@ -233,7 +263,6 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
233
263
for (llvm::Module::iterator i = module ->begin (); i != module ->end (); i++) {
234
264
functionPassManager.run (*i);
235
265
}
236
-
237
266
functionPassManager.doFinalization ();
238
267
modulePassManager.run (*module );
239
268
@@ -329,10 +358,12 @@ class LLVMCodegen {
329
358
LLVMCodegen (
330
359
const Scop& scop,
331
360
const IteratorMapsType& iteratorMaps,
332
- const StmtSubscriptExprMapType& stmtSubscripts)
361
+ const StmtSubscriptExprMapType& stmtSubscripts,
362
+ const llvm::TargetMachine& targetMachine)
333
363
: scop_(scop),
334
364
iteratorMaps_ (iteratorMaps),
335
365
stmtSubscripts_(stmtSubscripts),
366
+ targetMachine(targetMachine),
336
367
halide_cg(Halide::Target(
337
368
Halide::Target::OSUnknown,
338
369
Halide::Target::X86,
@@ -558,6 +589,7 @@ class LLVMCodegen {
558
589
std::vector<std::string> argNames_;
559
590
560
591
public:
592
+ const llvm::TargetMachine& targetMachine;
561
593
CodeGen_TC halide_cg;
562
594
};
563
595
@@ -601,7 +633,7 @@ isl::ast_node collectIteratorMaps(
601
633
return node.set_annotation (stmtId);
602
634
}
603
635
604
- IslCodegenRes codegenISL (const Scop& scop) {
636
+ static IslCodegenRes codegenISL (const Scop& scop) {
605
637
IteratorMapsType iteratorMaps;
606
638
StmtSubscriptExprMapType stmtSubscripts;
607
639
auto collect = [&iteratorMaps, &scop, &stmtSubscripts](
@@ -625,10 +657,10 @@ IslCodegenRes codegenISL(const Scop& scop) {
625
657
std::unique_ptr<llvm::Module> emitLLVMKernel (
626
658
const std::string& specializedName,
627
659
const Scop& scop,
628
- const llvm::DataLayout& dataLayout ) {
660
+ const llvm::TargetMachine& targetMachine ) {
629
661
auto islCg = codegenISL (scop);
630
- LLVMCodegen cg (scop, islCg.iteratorMaps , islCg.stmtSubscripts );
631
- cg.halide_cg .get_module ()->setDataLayout (dataLayout );
662
+ LLVMCodegen cg (scop, islCg.iteratorMaps , islCg.stmtSubscripts , targetMachine );
663
+ cg.halide_cg .get_module ()->setDataLayout (targetMachine. createDataLayout () );
632
664
cg.halide_cg .get_module ()->setTargetTriple (
633
665
llvm::EngineBuilder ().selectTarget ()->getTargetTriple ().str ());
634
666
auto entry = cg.createSignature (
@@ -643,7 +675,7 @@ std::unique_ptr<llvm::Module> emitLLVMKernel(
643
675
TC_CHECK (!llvm::verifyModule (*cg.halide_cg .get_module ()))
644
676
<< " LLVM generated module is invalid." << cg.str ().c_str ();
645
677
646
- cg.halide_cg .optimize_module ();
678
+ cg.halide_cg .optimize_module (cg. targetMachine );
647
679
if (FLAGS_llvm_dump_asm) {
648
680
std::string pat (" /tmp/tcXXXXXX" );
649
681
std::vector<char > ifn (pat.begin (), pat.end ());
@@ -662,7 +694,10 @@ std::unique_ptr<llvm::Module> emitLLVMKernel(
662
694
}
663
695
utils::checkedSystemCall (
664
696
std::string (TC_STRINGIFY (TC_LLVM_BIN_DIR)) + " /llc" ,
665
- {FLAGS_llvm_dump_asm_options, utils::CPUID::llcFlags (), optFile, std::string (" -o " ) + asmFile});
697
+ {FLAGS_llvm_dump_asm_options,
698
+ utils::CPUID::llcFlags (),
699
+ optFile,
700
+ std::string (" -o " ) + asmFile});
666
701
{
667
702
std::ifstream is (asmFile);
668
703
std::string str (
@@ -671,7 +706,6 @@ std::unique_ptr<llvm::Module> emitLLVMKernel(
671
706
LOG (INFO) << str;
672
707
}
673
708
}
674
-
675
709
return cg.halide_cg .move_module ();
676
710
}
677
711
0 commit comments