55
55
using namespace Halide ;
56
56
57
57
namespace tc {
58
-
59
58
namespace polyhedral {
60
-
59
+ namespace {
61
60
using IteratorMapType = std::unordered_map<std::string, isl::ast_expr>;
62
61
using IteratorMapsType =
63
62
std::unordered_map<isl::id, IteratorMapType, isl::IslIdIslHash>;
64
63
65
64
using StmtSubscriptExprMapType =
66
65
std::unordered_map<isl::id, std::vector<isl::ast_expr>, isl::IslIdIslHash>;
67
66
68
- namespace {
67
+ struct IslCodegenRes {
68
+ IteratorMapsType iteratorMaps;
69
+ StmtSubscriptExprMapType stmtSubscripts;
70
+ isl::ast_node astNode;
71
+ };
72
+
73
+ isl::ast_node collectIteratorMaps (
74
+ isl::ast_node node,
75
+ isl::ast_build build,
76
+ IteratorMapsType& iteratorMaps,
77
+ const Scop& scop,
78
+ StmtSubscriptExprMapType& stmtSubscripts) {
79
+ auto user = node.as <isl::ast_node_user>();
80
+ TC_CHECK (user);
81
+ auto expr = user.get_expr ().as <isl::ast_expr_op>();
82
+ auto schedule = build.get_schedule ();
83
+ auto scheduleMap = isl::map::from_union_map (schedule);
84
+
85
+ auto stmtId = expr.get_arg (0 ).as <isl::ast_expr_id>().get_id ();
86
+ TC_CHECK_EQ (0u , iteratorMaps.count (stmtId)) << " entry exists: " << stmtId;
87
+ auto iteratorMap = isl::pw_multi_aff (scheduleMap.reverse ());
88
+ auto tuple = scop.halide .domains .at (stmtId).tuple ;
89
+ auto & stmtIteratorMap = iteratorMaps[stmtId];
90
+ for (int i = 0 ; i < tuple.size (); ++i) {
91
+ auto expr = build.expr_from (iteratorMap.get_pw_aff (i));
92
+ stmtIteratorMap.emplace (tuple.get_id (i).get_name (), expr);
93
+ }
94
+ auto & subscripts = stmtSubscripts[stmtId];
95
+ auto provide =
96
+ scop.halide .statements .at (stmtId).as <Halide::Internal::Provide>();
97
+ for (auto e : provide->args ) {
98
+ const auto & map = iteratorMap;
99
+ auto aff = scop.makeIslAffFromStmtExpr (stmtId, e);
100
+ auto pulled = isl::pw_aff (aff).pullback (map);
101
+ TC_CHECK_EQ (pulled.n_piece (), 1 );
102
+ subscripts.push_back (build.expr_from (pulled));
103
+ }
104
+ return node.set_annotation (stmtId);
105
+ }
106
+
107
+ static IslCodegenRes codegenISL (const Scop& scop) {
108
+ IteratorMapsType iteratorMaps;
109
+ StmtSubscriptExprMapType stmtSubscripts;
110
+ auto collect = [&iteratorMaps, &scop, &stmtSubscripts](
111
+ isl::ast_node n, isl::ast_build b) -> isl::ast_node {
112
+ auto & uv = iteratorMaps;
113
+ return collectIteratorMaps (n, b, uv, scop, stmtSubscripts);
114
+ };
115
+
116
+ auto schedule = detail::toIslSchedule (scop.scheduleRoot ());
117
+ auto astBuild = isl::ast_build (schedule.get_ctx ());
118
+ astBuild = astBuild.set_at_each_domain (collect);
119
+ auto root = scop.scheduleRoot ();
120
+ astBuild = astBuild.set_iterators (Codegen::makeLoopIterators (root));
121
+ auto astNode = astBuild.node_from (schedule);
122
+ return {
123
+ std::move (iteratorMaps), std::move (stmtSubscripts), std::move (astNode)};
124
+ }
69
125
70
126
thread_local llvm::LLVMContext llvmCtx;
71
127
@@ -324,6 +380,71 @@ Halide::Expr CodeGen_TC::makeHalideExpr(isl::ast_expr expr) {
324
380
}
325
381
326
382
class LLVMCodegen {
383
+ public:
384
+ LLVMCodegen (
385
+ const std::string& specializedName,
386
+ const Scop& scop,
387
+ const llvm::TargetMachine& targetMachine)
388
+ : scop_(scop),
389
+ islCg_ (codegenISL(scop_)),
390
+ iteratorMaps_(islCg_.iteratorMaps),
391
+ stmtSubscripts_(islCg_.stmtSubscripts),
392
+ targetMachine(targetMachine),
393
+ // we don't use Halide to tinker with llvm::Module optimization so we
394
+ // tthe Halide target can be whatever.
395
+ halide_cg(Halide::get_host_target()) {
396
+ halide_cg.set_context (llvmCtx);
397
+ halide_cg.init_module ();
398
+ halide_cg.get_module ()->setDataLayout (targetMachine.createDataLayout ());
399
+ halide_cg.get_module ()->setTargetTriple (
400
+ targetMachine.getTargetTriple ().str ());
401
+ auto entry = createSignature (
402
+ scop.halide .inputs ,
403
+ scop.halide .outputs ,
404
+ scop.halide .params ,
405
+ specializedName);
406
+ auto exit = emitAst (islCg_.astNode , entry);
407
+ halide_cg.get_builder ().SetInsertPoint (exit);
408
+ halide_cg.get_builder ().CreateRetVoid ();
409
+
410
+ TC_CHECK (!llvm::verifyModule (*halide_cg.get_module ()))
411
+ << " LLVM generated module is invalid." << str ().c_str ();
412
+
413
+ halide_cg.optimize_module (targetMachine);
414
+
415
+ if (FLAGS_llvm_dump_asm) {
416
+ std::string pat (" /tmp/tcXXXXXX" );
417
+ std::vector<char > ifn (pat.begin (), pat.end ());
418
+ TC_CHECK_GE (mkstemp (ifn.data ()), 0 ); // string.c_str is const char*
419
+ std::string fileName (ifn.begin (), ifn.end ());
420
+ std::string optFile = fileName + " -opt.ll" ;
421
+ std::string asmFile = fileName + " .s" ;
422
+ // cstdio's std::remove to delete files
423
+ tc::ScopeGuard sgi ([&]() {
424
+ std::remove (optFile.c_str ());
425
+ std::remove (asmFile.c_str ());
426
+ });
427
+ {
428
+ std::ofstream ostream (optFile, std::ios::binary);
429
+ ostream << str ();
430
+ }
431
+ utils::checkedSystemCall (
432
+ std::string (TC_STRINGIFY (TC_LLVM_BIN_DIR)) + " /llc" ,
433
+ {FLAGS_llvm_dump_asm_options,
434
+ utils::CPUID::llcFlags (),
435
+ optFile,
436
+ std::string (" -o " ) + asmFile});
437
+
438
+ std::ifstream is (asmFile);
439
+ std::string str (
440
+ (std::istreambuf_iterator<char >(is)),
441
+ std::istreambuf_iterator<char >());
442
+ LOG (INFO) << " Dumping asm for: " << utils::CPUID::llcFlags () << " \n "
443
+ << str;
444
+ }
445
+ }
446
+
447
+ private:
327
448
void collectTensor (const Halide::OutputImageParam& t) {
328
449
auto sizes = getTensorSizesWithoutLeadingDim (t, scop_.context ());
329
450
if (not sizes.empty ()) {
@@ -354,23 +475,16 @@ class LLVMCodegen {
354
475
}
355
476
}
356
477
357
- public:
358
- LLVMCodegen (
359
- const Scop& scop,
360
- const IteratorMapsType& iteratorMaps,
361
- const StmtSubscriptExprMapType& stmtSubscripts,
362
- const llvm::TargetMachine& targetMachine)
363
- : scop_(scop),
364
- iteratorMaps_ (iteratorMaps),
365
- stmtSubscripts_(stmtSubscripts),
366
- targetMachine(targetMachine),
367
- halide_cg(Halide::Target(
368
- Halide::Target::OSUnknown,
369
- Halide::Target::X86,
370
- 64 )) {
371
- halide_cg.set_context (llvmCtx);
372
-
373
- halide_cg.init_module ();
478
+ llvm::Type* makePtrToArrayType (
479
+ llvm::Type* baseTy,
480
+ const std::vector<int64_t >& sizes) {
481
+ TC_CHECK_GE (sizes.size (), 1u );
482
+ TC_CHECK (baseTy);
483
+ llvm::Type* arrTy = llvm::ArrayType::get (baseTy, sizes.back ());
484
+ for (auto s = sizes.rbegin () + 1 ; s != sizes.rend (); ++s) {
485
+ arrTy = llvm::ArrayType::get (arrTy, *s);
486
+ }
487
+ return arrTy->getPointerTo ();
374
488
}
375
489
376
490
// This creates a signature of the form:
@@ -451,19 +565,6 @@ class LLVMCodegen {
451
565
return nullptr ;
452
566
}
453
567
454
- private:
455
- llvm::Type* makePtrToArrayType (
456
- llvm::Type* baseTy,
457
- const std::vector<int64_t >& sizes) {
458
- TC_CHECK_GE (sizes.size (), 1u );
459
- TC_CHECK (baseTy);
460
- llvm::Type* arrTy = llvm::ArrayType::get (baseTy, sizes.back ());
461
- for (auto s = sizes.rbegin () + 1 ; s != sizes.rend (); ++s) {
462
- arrTy = llvm::ArrayType::get (arrTy, *s);
463
- }
464
- return arrTy->getPointerTo ();
465
- }
466
-
467
568
llvm::BasicBlock* emitIf (
468
569
isl::ast_node_if node,
469
570
llvm::BasicBlock* insertionPoint) {
@@ -582,6 +683,7 @@ class LLVMCodegen {
582
683
583
684
private:
584
685
const Scop& scop_;
686
+ const IslCodegenRes islCg_;
585
687
const IteratorMapsType& iteratorMaps_;
586
688
const StmtSubscriptExprMapType& stmtSubscripts_;
587
689
@@ -592,120 +694,13 @@ class LLVMCodegen {
592
694
const llvm::TargetMachine& targetMachine;
593
695
CodeGen_TC halide_cg;
594
696
};
595
-
596
- struct IslCodegenRes {
597
- IteratorMapsType iteratorMaps;
598
- StmtSubscriptExprMapType stmtSubscripts;
599
- isl::ast_node astNode;
600
- };
601
-
602
- isl::ast_node collectIteratorMaps (
603
- isl::ast_node node,
604
- isl::ast_build build,
605
- IteratorMapsType& iteratorMaps,
606
- const Scop& scop,
607
- StmtSubscriptExprMapType& stmtSubscripts) {
608
- auto user = node.as <isl::ast_node_user>();
609
- TC_CHECK (user);
610
- auto expr = user.get_expr ().as <isl::ast_expr_op>();
611
- auto schedule = build.get_schedule ();
612
- auto scheduleMap = isl::map::from_union_map (schedule);
613
-
614
- auto stmtId = expr.get_arg (0 ).as <isl::ast_expr_id>().get_id ();
615
- TC_CHECK_EQ (0u , iteratorMaps.count (stmtId)) << " entry exists: " << stmtId;
616
- auto iteratorMap = isl::pw_multi_aff (scheduleMap.reverse ());
617
- auto tuple = scop.halide .domains .at (stmtId).tuple ;
618
- auto & stmtIteratorMap = iteratorMaps[stmtId];
619
- for (int i = 0 ; i < tuple.size (); ++i) {
620
- auto expr = build.expr_from (iteratorMap.get_pw_aff (i));
621
- stmtIteratorMap.emplace (tuple.get_id (i).get_name (), expr);
622
- }
623
- auto & subscripts = stmtSubscripts[stmtId];
624
- auto provide =
625
- scop.halide .statements .at (stmtId).as <Halide::Internal::Provide>();
626
- for (auto e : provide->args ) {
627
- const auto & map = iteratorMap;
628
- auto aff = scop.makeIslAffFromStmtExpr (stmtId, e);
629
- auto pulled = isl::pw_aff (aff).pullback (map);
630
- TC_CHECK_EQ (pulled.n_piece (), 1 );
631
- subscripts.push_back (build.expr_from (pulled));
632
- }
633
- return node.set_annotation (stmtId);
634
- }
635
-
636
- static IslCodegenRes codegenISL (const Scop& scop) {
637
- IteratorMapsType iteratorMaps;
638
- StmtSubscriptExprMapType stmtSubscripts;
639
- auto collect = [&iteratorMaps, &scop, &stmtSubscripts](
640
- isl::ast_node n, isl::ast_build b) -> isl::ast_node {
641
- auto & uv = iteratorMaps;
642
- return collectIteratorMaps (n, b, uv, scop, stmtSubscripts);
643
- };
644
-
645
- auto schedule = detail::toIslSchedule (scop.scheduleRoot ());
646
- auto astBuild = isl::ast_build (schedule.get_ctx ());
647
- astBuild = astBuild.set_at_each_domain (collect);
648
- auto root = scop.scheduleRoot ();
649
- astBuild = astBuild.set_iterators (Codegen::makeLoopIterators (root));
650
- auto astNode = astBuild.node_from (schedule);
651
- return {
652
- std::move (iteratorMaps), std::move (stmtSubscripts), std::move (astNode)};
653
- }
654
-
655
697
} // namespace
656
698
657
699
std::unique_ptr<llvm::Module> emitLLVMKernel (
658
700
const std::string& specializedName,
659
701
const Scop& scop,
660
702
const llvm::TargetMachine& targetMachine) {
661
- auto islCg = codegenISL (scop);
662
- LLVMCodegen cg (scop, islCg.iteratorMaps , islCg.stmtSubscripts , targetMachine);
663
- cg.halide_cg .get_module ()->setDataLayout (targetMachine.createDataLayout ());
664
- cg.halide_cg .get_module ()->setTargetTriple (
665
- llvm::EngineBuilder ().selectTarget ()->getTargetTriple ().str ());
666
- auto entry = cg.createSignature (
667
- scop.halide .inputs ,
668
- scop.halide .outputs ,
669
- scop.halide .params ,
670
- specializedName);
671
- auto exit = cg.emitAst (islCg.astNode , entry);
672
- cg.halide_cg .get_builder ().SetInsertPoint (exit);
673
- cg.halide_cg .get_builder ().CreateRetVoid ();
674
-
675
- TC_CHECK (!llvm::verifyModule (*cg.halide_cg .get_module ()))
676
- << " LLVM generated module is invalid." << cg.str ().c_str ();
677
-
678
- cg.halide_cg .optimize_module (cg.targetMachine );
679
- if (FLAGS_llvm_dump_asm) {
680
- std::string pat (" /tmp/tcXXXXXX" );
681
- std::vector<char > ifn (pat.begin (), pat.end ());
682
- TC_CHECK_GE (mkstemp (ifn.data ()), 0 ); // string.c_str is const char*
683
- std::string fileName (ifn.begin (), ifn.end ());
684
- std::string optFile = fileName + " -opt.ll" ;
685
- std::string asmFile = fileName + " .s" ;
686
- // cstdio's std::remove to delete files
687
- tc::ScopeGuard sgi ([&]() {
688
- std::remove (optFile.c_str ());
689
- std::remove (asmFile.c_str ());
690
- });
691
- {
692
- std::ofstream ostream (optFile, std::ios::binary);
693
- ostream << cg.str ();
694
- }
695
- utils::checkedSystemCall (
696
- std::string (TC_STRINGIFY (TC_LLVM_BIN_DIR)) + " /llc" ,
697
- {FLAGS_llvm_dump_asm_options,
698
- utils::CPUID::llcFlags (),
699
- optFile,
700
- std::string (" -o " ) + asmFile});
701
- {
702
- std::ifstream is (asmFile);
703
- std::string str (
704
- (std::istreambuf_iterator<char >(is)),
705
- std::istreambuf_iterator<char >());
706
- LOG (INFO) << str;
707
- }
708
- }
703
+ LLVMCodegen cg (specializedName, scop, targetMachine);
709
704
return cg.halide_cg .move_module ();
710
705
}
711
706
0 commit comments