36
36
#include " isl/ast.h"
37
37
38
38
#include " tc/core/constants.h"
39
- // #include "tc/core/polyhedral/isl_mu_wrappers.h"
40
39
#include " tc/core/flags.h"
41
40
#include " tc/core/polyhedral/codegen.h"
42
41
#include " tc/core/polyhedral/schedule_isl_conversion.h"
43
42
#include " tc/core/polyhedral/scop.h"
44
43
#include " tc/core/scope_guard.h"
44
+ #include " tc/external/isl.h"
45
45
46
46
#ifndef LLVM_VERSION_MAJOR
47
47
#error LLVM_VERSION_MAJOR not set
@@ -76,10 +76,9 @@ namespace {
76
76
thread_local llvm::LLVMContext llvmCtx;
77
77
78
78
int64_t toSInt (isl::val v) {
79
- auto n = v.get_num_si ();
80
- auto d = v.get_den_si ();
81
- CHECK_EQ (n % d, 0 );
82
- return n / d;
79
+ CHECK (v.is_int ());
80
+ static_assert (sizeof (long ) <= 8 , " long is assumed to fit into 64bits" );
81
+ return v.get_num_si ();
83
82
}
84
83
85
84
llvm::Value* getLLVMConstantSignedInt64 (int64_t v) {
@@ -88,25 +87,16 @@ llvm::Value* getLLVMConstantSignedInt64(int64_t v) {
88
87
89
88
int64_t IslExprToSInt (isl::ast_expr e) {
90
89
CHECK (isl_ast_expr_get_type (e.get ()) == isl_ast_expr_type::isl_ast_expr_int);
91
- assert (sizeof (long ) <= 8 ); // long is assumed to fit to 64bits
92
90
return toSInt (isl::manage (isl_ast_expr_get_val (e.get ())));
93
91
}
94
92
95
93
int64_t islIdToInt (isl::ast_expr e, isl::set context) {
96
94
CHECK (isl_ast_expr_get_type (e.get ()) == isl_ast_expr_type::isl_ast_expr_id);
97
- CHECK_NE (-1 , context.find_dim_by_id (isl::dim_type::param, e.get_id ()));
98
- while (context.dim (isl::dim_type::param) > 1 ) {
99
- for (unsigned int d = 0 ; d < context.dim (isl::dim_type::param); ++d) {
100
- if (d == context.find_dim_by_id (isl::dim_type::param, e.get_id ())) {
101
- continue ;
102
- }
103
- context = context.remove_dims (isl::dim_type::param, d, 1 );
104
- }
105
- }
95
+ auto space = context.get_space ();
96
+ isl::aff param (isl::aff::param_on_domain_space (space, e.get_id ()));
106
97
auto p = context.sample_point ();
107
-
108
- auto val = toSInt (p.get_coordinate_val (isl::dim_type::param, 0 ));
109
- return val;
98
+ CHECK (context.is_equal (p));
99
+ return toSInt (param.eval (p));
110
100
}
111
101
112
102
int64_t getTensorSize (isl::set context, const Halide::Expr& e) {
@@ -319,8 +309,7 @@ llvm::Value* CodeGen_TC::getValue(isl::ast_expr expr) {
319
309
return sym_get (expr.get_id ().get_name ());
320
310
case isl_ast_expr_type::isl_ast_expr_int: {
321
311
auto val = isl::manage (isl_ast_expr_get_val (expr.get ()));
322
- CHECK (val.is_int ());
323
- return getLLVMConstantSignedInt64 (val.get_num_si ());
312
+ return getLLVMConstantSignedInt64 (toSInt (val));
324
313
}
325
314
default :
326
315
LOG (FATAL) << " NYI" ;
@@ -497,16 +486,15 @@ class LLVMCodegen {
497
486
halide_cg.get_builder ().CreateBr (headerBB);
498
487
499
488
llvm::PHINode* phi = nullptr ;
489
+ auto iterator = node.get_iterator ().get_id ();
500
490
501
491
// Loop Header
502
492
{
503
493
auto initVal = IslExprToSInt (node.get_init ());
504
494
halide_cg.get_builder ().SetInsertPoint (headerBB);
505
495
phi = halide_cg.get_builder ().CreatePHI (
506
- llvm::Type::getInt64Ty (llvmCtx),
507
- 2 ,
508
- node.get_iterator ().get_id ().get_name ());
509
- halide_cg.sym_push (node.get_iterator ().get_id ().get_name (), phi);
496
+ llvm::Type::getInt64Ty (llvmCtx), 2 , iterator.get_name ());
497
+ halide_cg.sym_push (iterator.get_name (), phi);
510
498
phi->addIncoming (getLLVMConstantSignedInt64 (initVal), incoming);
511
499
512
500
auto cond_expr = node.get_cond ();
@@ -518,7 +506,7 @@ class LLVMCodegen {
518
506
CHECK (
519
507
isl_ast_expr_get_type (condLHS.get ()) ==
520
508
isl_ast_expr_type::isl_ast_expr_id);
521
- CHECK_EQ (condLHS.get_id (), node. get_iterator (). get_id () );
509
+ CHECK_EQ (condLHS.get_id (), iterator );
522
510
523
511
IslAstExprInterpeter i (scop_.globalParameterContext );
524
512
auto condRHSVal = i.interpret (cond_expr.get_op_arg (1 ));
@@ -575,7 +563,7 @@ class LLVMCodegen {
575
563
}
576
564
577
565
halide_cg.get_builder ().SetInsertPoint (loopExitBB);
578
- halide_cg.sym_pop (node. get_iterator (). get_id () .get_name ());
566
+ halide_cg.sym_pop (iterator .get_name ());
579
567
#ifdef TAPIR_VERSION_MAJOR
580
568
if (parallel) {
581
569
auto * syncBB = llvm::BasicBlock::Create (llvmCtx, " synced" , function);
@@ -652,9 +640,6 @@ IslCodegenRes codegenISL(const Scop& scop) {
652
640
auto scheduleMap = isl::map::from_union_map (schedule);
653
641
654
642
auto stmtId = expr.get_op_arg (0 ).get_id ();
655
- // auto nodeId = isl::id(
656
- // node.get_ctx(),
657
- // std::string(kAstNodeIdPrefix) + std::to_string(nAstNodes()++));
658
643
CHECK_EQ (0 , iteratorMaps.count (stmtId)) << " entry exists: " << stmtId;
659
644
auto iteratorMap = isl::pw_multi_aff (scheduleMap.reverse ());
660
645
auto iterators = scop.halide .iterators .at (stmtId);
0 commit comments