20
20
21
21
#include " llvm/ADT/STLExtras.h"
22
22
#include " llvm/ADT/SmallVector.h"
23
+ #include " llvm/Analysis/TargetTransformInfo.h"
23
24
#include " llvm/ExecutionEngine/ExecutionEngine.h"
24
25
#include " llvm/IR/IRBuilder.h"
26
+ #include " llvm/IR/LegacyPassManager.h"
25
27
#include " llvm/IR/Verifier.h"
26
28
#include " llvm/Support/TargetSelect.h"
27
29
#include " llvm/Support/raw_ostream.h"
30
+ #include " llvm/Transforms/IPO/PassManagerBuilder.h"
31
+ #include " llvm/Transforms/IPO.h"
32
+ #include " llvm/Transforms/Tapir/CilkABI.h"
28
33
29
34
#include " Halide/Halide.h"
30
35
@@ -202,7 +207,6 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
202
207
203
208
using CodeGen_X86::codegen;
204
209
using CodeGen_X86::llvm_type_of;
205
- using CodeGen_X86::optimize_module;
206
210
using CodeGen_X86::sym_get;
207
211
using CodeGen_X86::sym_pop;
208
212
using CodeGen_X86::sym_push;
@@ -294,6 +298,81 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
294
298
295
299
value = sym_get (name);
296
300
}
301
+ public:
302
+ void optimize_module () {
303
+ Halide::Internal::debug (3 ) << " Optimizing module\n " ;
304
+
305
+ if (Halide::Internal::debug::debug_level () >= 3 ) {
306
+ #if LLVM_VERSION >= 50
307
+ module ->print (dbgs (), nullptr , false , true );
308
+ #else
309
+ module ->dump ();
310
+ #endif
311
+ }
312
+
313
+ // We override PassManager::add so that we have an opportunity to
314
+ // blacklist problematic LLVM passes.
315
+ class MyFunctionPassManager : public llvm ::legacy::FunctionPassManager {
316
+ public:
317
+ MyFunctionPassManager (llvm::Module *m) : llvm::legacy::FunctionPassManager(m) {}
318
+ virtual void add (llvm::Pass *p) override {
319
+ Halide::Internal::debug (2 ) << " Adding function pass: " << p->getPassName ().str () << " \n " ;
320
+ llvm::legacy::FunctionPassManager::add (p);
321
+ }
322
+ };
323
+
324
+ class MyModulePassManager : public llvm ::legacy::PassManager {
325
+ public:
326
+ virtual void add (llvm::Pass *p) override {
327
+ Halide::Internal::debug (2 ) << " Adding module pass: " << p->getPassName ().str () << " \n " ;
328
+ llvm::legacy::PassManager::add (p);
329
+ }
330
+ };
331
+
332
+ MyFunctionPassManager function_pass_manager (module .get ());
333
+ MyModulePassManager module_pass_manager;
334
+
335
+ std::unique_ptr<llvm::TargetMachine> TM = Halide::Internal::make_target_machine (*module );
336
+ module_pass_manager.add (llvm::createTargetTransformInfoWrapperPass (TM ? TM->getTargetIRAnalysis () : llvm::TargetIRAnalysis ()));
337
+ function_pass_manager.add (llvm::createTargetTransformInfoWrapperPass (TM ? TM->getTargetIRAnalysis () : llvm::TargetIRAnalysis ()));
338
+
339
+ llvm::PassManagerBuilder b;
340
+ b.OptLevel = 3 ;
341
+ b.tapirTarget = new llvm::tapir::CilkABI ();
342
+ #if LLVM_VERSION >= 50
343
+ b.Inliner = llvm::createFunctionInliningPass (b.OptLevel , 0 , false );
344
+ #else
345
+ b.Inliner = llvm::createFunctionInliningPass (b.OptLevel , 0 );
346
+ #endif
347
+ b.LoopVectorize = true ;
348
+ b.SLPVectorize = true ;
349
+
350
+ #if LLVM_VERSION >= 50
351
+ if (TM) {
352
+ TM->adjustPassManager (b);
353
+ }
354
+ #endif
355
+
356
+ b.populateFunctionPassManager (function_pass_manager);
357
+ b.populateModulePassManager (module_pass_manager);
358
+
359
+ // Run optimization passes
360
+ function_pass_manager.doInitialization ();
361
+ for (llvm::Module::iterator i = module ->begin (); i != module ->end (); i++) {
362
+ function_pass_manager.run (*i);
363
+ }
364
+ function_pass_manager.doFinalization ();
365
+ module_pass_manager.run (*module );
366
+
367
+ Halide::Internal::debug (3 ) << " After LLVM optimizations:\n " ;
368
+ if (Halide::Internal::debug::debug_level () >= 2 ) {
369
+ #if LLVM_VERSION >= 50
370
+ module ->print (dbgs (), nullptr , false , true );
371
+ #else
372
+ module ->dump ();
373
+ #endif
374
+ }
375
+ }
297
376
};
298
377
299
378
class LLVMCodegen {
@@ -451,6 +530,17 @@ class LLVMCodegen {
451
530
llvm::BasicBlock::Create (llvmCtx, " loop_latch" , function);
452
531
auto * loopExitBB = llvm::BasicBlock::Create (llvmCtx, " loop_exit" , function);
453
532
533
+ bool parallel = true ;
534
+
535
+ llvm::Value* SyncRegion = nullptr ;
536
+ if (parallel) {
537
+ SyncRegion = halide_cg.get_builder ().CreateCall (
538
+ llvm::Intrinsic::getDeclaration (function->getParent (), llvm::Intrinsic::syncregion_start),
539
+ {},
540
+ " syncreg"
541
+ );
542
+ }
543
+
454
544
halide_cg.get_builder ().CreateBr (headerBB);
455
545
456
546
llvm::PHINode* phi = nullptr ;
@@ -498,9 +588,20 @@ class LLVMCodegen {
498
588
// Create Body
499
589
{
500
590
halide_cg.get_builder ().SetInsertPoint (loopBodyBB);
591
+
592
+ if (parallel) {
593
+ auto * detachedBB = llvm::BasicBlock::Create (llvmCtx, " det.achd" , function);
594
+ halide_cg.get_builder ().CreateDetach (detachedBB, loopLatchBB, SyncRegion);
595
+ halide_cg.get_builder ().SetInsertPoint (detachedBB);
596
+ }
501
597
auto * currentBB = emitAst (node.for_get_body ());
502
598
halide_cg.get_builder ().SetInsertPoint (currentBB);
503
- halide_cg.get_builder ().CreateBr (loopLatchBB);
599
+
600
+ if (parallel) {
601
+ halide_cg.get_builder ().CreateReattach (loopLatchBB, SyncRegion);
602
+ } else {
603
+ halide_cg.get_builder ().CreateBr (loopLatchBB);
604
+ }
504
605
}
505
606
506
607
// Create Latch
@@ -516,6 +617,11 @@ class LLVMCodegen {
516
617
517
618
halide_cg.get_builder ().SetInsertPoint (loopExitBB);
518
619
halide_cg.sym_pop (node.for_get_iterator ().get_id ().get_name ());
620
+ if (parallel) {
621
+ auto * syncBB = llvm::BasicBlock::Create (llvmCtx, " synced" , function);
622
+ halide_cg.get_builder ().CreateSync (syncBB, SyncRegion);
623
+ halide_cg.get_builder ().SetInsertPoint (syncBB);
624
+ }
519
625
return halide_cg.get_builder ().GetInsertBlock ();
520
626
}
521
627
0 commit comments