33
33
34
34
#include " Halide.h"
35
35
36
- #include " isl/ast.h"
37
-
38
36
#include " tc/core/constants.h"
39
37
#include " tc/core/flags.h"
40
38
#include " tc/core/halide2isl.h"
@@ -83,12 +81,12 @@ llvm::Value* getLLVMConstantSignedInt64(int64_t v) {
83
81
}
84
82
85
83
int64_t IslExprToSInt (isl::ast_expr e) {
86
- CHECK (isl_ast_expr_get_type (e.get ()) == isl_ast_expr_type::isl_ast_expr_int);
87
- return toSInt (isl::manage (isl_ast_expr_get_val (e.get ())));
84
+ auto intExpr = e.as <isl::ast_expr_int>();
85
+ CHECK (intExpr);
86
+ return toSInt (intExpr.get_val ());
88
87
}
89
88
90
- int64_t islIdToInt (isl::ast_expr e, isl::set context) {
91
- CHECK (isl_ast_expr_get_type (e.get ()) == isl_ast_expr_type::isl_ast_expr_id);
89
+ int64_t islIdToInt (isl::ast_expr_id e, isl::set context) {
92
90
auto space = context.get_space ();
93
91
isl::aff param (isl::aff::param_on_domain_space (space, e.get_id ()));
94
92
auto p = context.sample_point ();
@@ -127,22 +125,21 @@ class IslAstExprInterpeter {
127
125
IslAstExprInterpeter (isl::set context) : context_(context){};
128
126
129
127
int64_t interpret (isl::ast_expr e) {
130
- switch (isl_ast_expr_get_type (e.get ())) {
131
- case isl_ast_expr_type::isl_ast_expr_int:
132
- return IslExprToSInt (e);
133
- case isl_ast_expr_type::isl_ast_expr_id:
134
- return islIdToInt (e, context_);
135
- case isl_ast_expr_type::isl_ast_expr_op:
136
- return interpretOp (e);
137
- default :
138
- CHECK (false ) << " NYI" ;
139
- return 0 ; // avoid warning
128
+ if (auto intExpr = e.as <isl::ast_expr_int>()) {
129
+ return IslExprToSInt (intExpr);
130
+ } else if (auto idExpr = e.as <isl::ast_expr_id>()) {
131
+ return islIdToInt (idExpr, context_);
132
+ } else if (auto opExpr = e.as <isl::ast_expr_op>()) {
133
+ return interpretOp (opExpr);
134
+ } else {
135
+ CHECK (false ) << " NYI" ;
136
+ return 0 ; // avoid warning
140
137
}
141
138
};
142
139
143
140
private:
144
- int64_t interpretOp (isl::ast_expr e) {
145
- switch (e.get_op_n_arg ()) {
141
+ int64_t interpretOp (isl::ast_expr_op e) {
142
+ switch (e.get_n_arg ()) {
146
143
case 1 :
147
144
return interpretUnaryOp (e);
148
145
case 2 :
@@ -153,28 +150,26 @@ class IslAstExprInterpeter {
153
150
}
154
151
}
155
152
156
- int64_t interpretBinaryOp (isl::ast_expr e) {
157
- auto left = interpret (e.get_op_arg (0 ));
158
- auto right = interpret (e.get_op_arg (1 ));
159
- switch (e.get_op_type ()) {
160
- case isl::ast_op_type::add:
161
- return left + right;
162
- case isl::ast_op_type::sub:
163
- return left - right;
164
- default :
165
- CHECK (false ) << " NYI: " << e;
166
- return 0 ; // avoid warning
153
+ int64_t interpretBinaryOp (isl::ast_expr_op e) {
154
+ auto left = interpret (e.get_arg (0 ));
155
+ auto right = interpret (e.get_arg (1 ));
156
+ if (e.as <isl::ast_op_add>()) {
157
+ return left + right;
158
+ } else if (e.as <isl::ast_op_sub>()) {
159
+ return left - right;
160
+ } else {
161
+ CHECK (false ) << " NYI: " << e;
162
+ return 0 ; // avoid warning
167
163
}
168
164
}
169
165
170
- int64_t interpretUnaryOp (isl::ast_expr e) {
171
- auto val = interpret (e.get_op_arg (0 ));
172
- switch (e.get_op_type ()) {
173
- case isl::ast_op_type::minus:
174
- return -val;
175
- default :
176
- CHECK (false ) << " NYI" ;
177
- return 0 ; // avoid warning
166
+ int64_t interpretUnaryOp (isl::ast_expr_op e) {
167
+ auto val = interpret (e.get_arg (0 ));
168
+ if (e.as <isl::ast_op_minus>()) {
169
+ return -val;
170
+ } else {
171
+ CHECK (false ) << " NYI" ;
172
+ return 0 ; // avoid warning
178
173
}
179
174
}
180
175
};
@@ -301,16 +296,13 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
301
296
};
302
297
303
298
llvm::Value* CodeGen_TC::getValue (isl::ast_expr expr) {
304
- switch (isl_ast_expr_get_type (expr.get ())) {
305
- case isl_ast_expr_type::isl_ast_expr_id:
306
- return sym_get (expr.get_id ().get_name ());
307
- case isl_ast_expr_type::isl_ast_expr_int: {
308
- auto val = isl::manage (isl_ast_expr_get_val (expr.get ()));
309
- return getLLVMConstantSignedInt64 (toSInt (val));
310
- }
311
- default :
312
- LOG (FATAL) << " NYI" ;
313
- return nullptr ;
299
+ if (auto idExpr = expr.as <isl::ast_expr_id>()) {
300
+ return sym_get (idExpr.get_id ().get_name ());
301
+ } else if (auto intExpr = expr.as <isl::ast_expr_int>()) {
302
+ return getLLVMConstantSignedInt64 (toSInt (intExpr.get_val ()));
303
+ } else {
304
+ LOG (FATAL) << " NYI" ;
305
+ return nullptr ;
314
306
}
315
307
}
316
308
@@ -483,7 +475,7 @@ class LLVMCodegen {
483
475
halide_cg.get_builder ().CreateBr (headerBB);
484
476
485
477
llvm::PHINode* phi = nullptr ;
486
- auto iterator = node.get_iterator ().get_id ();
478
+ auto iterator = node.get_iterator ().as <isl::ast_expr_id>(). get_id ();
487
479
488
480
// Loop Header
489
481
{
@@ -494,30 +486,25 @@ class LLVMCodegen {
494
486
halide_cg.sym_push (iterator.get_name (), phi);
495
487
phi->addIncoming (getLLVMConstantSignedInt64 (initVal), incoming);
496
488
497
- auto cond_expr = node.get_cond ();
498
- CHECK (
499
- cond_expr.get_op_type () == isl::ast_op_type::lt or
500
- cond_expr.get_op_type () == isl::ast_op_type::le)
489
+ auto cond_expr = node.get_cond ().as <isl::ast_expr_op>();
490
+ CHECK (cond_expr.as <isl::ast_op_lt>() or cond_expr.as <isl::ast_op_le>())
501
491
<< " I only know how to codegen lt and le" ;
502
- auto condLHS = cond_expr.get_op_arg (0 );
503
- CHECK (
504
- isl_ast_expr_get_type (condLHS.get ()) ==
505
- isl_ast_expr_type::isl_ast_expr_id);
492
+ auto condLHS = cond_expr.get_arg (0 ).as <isl::ast_expr_id>();
493
+ CHECK (condLHS);
506
494
CHECK_EQ (condLHS.get_id (), iterator);
507
495
508
496
IslAstExprInterpeter i (scop_.globalParameterContext );
509
- auto condRHSVal = i.interpret (cond_expr.get_op_arg (1 ));
497
+ auto condRHSVal = i.interpret (cond_expr.get_arg (1 ));
510
498
511
499
auto cond = [&]() {
512
500
auto constant = getLLVMConstantSignedInt64 (condRHSVal);
513
- switch (cond_expr.get_op_type ()) {
514
- case isl::ast_op_type::lt:
515
- return halide_cg.get_builder ().CreateICmpSLT (phi, constant);
516
- case isl::ast_op_type::le:
517
- return halide_cg.get_builder ().CreateICmpSLE (phi, constant);
518
- default :
519
- CHECK (false ) << " NYI" ;
520
- return static_cast <llvm::Value*>(nullptr ); // avoid warning
501
+ if (cond_expr.as <isl::ast_op_lt>()) {
502
+ return halide_cg.get_builder ().CreateICmpSLT (phi, constant);
503
+ } else if (cond_expr.as <isl::ast_op_le>()) {
504
+ return halide_cg.get_builder ().CreateICmpSLE (phi, constant);
505
+ } else {
506
+ CHECK (false ) << " NYI" ;
507
+ return static_cast <llvm::Value*>(nullptr ); // avoid warning
521
508
}
522
509
}();
523
510
halide_cg.get_builder ().CreateCondBr (cond, loopBodyBB, loopExitBB);
@@ -572,8 +559,8 @@ class LLVMCodegen {
572
559
}
573
560
574
561
llvm::BasicBlock* emitStmt (isl::ast_node_user node) {
575
- isl::ast_expr usrExp = node.get_expr ();
576
- auto id = usrExp.get_op_arg ( 0 ).get_id ();
562
+ isl::ast_expr_op usrExp = node.get_expr (). as <isl::ast_expr_op> ();
563
+ auto id = usrExp.get_arg ( 0 ). as <isl::ast_expr_id>( ).get_id ();
577
564
auto provide = scop_.halide .statements .at (id);
578
565
auto op = provide.as <Halide::Internal::Provide>();
579
566
CHECK (op) << " Expected a Provide node: " << provide << ' \n ' ;
@@ -632,11 +619,11 @@ IslCodegenRes codegenISL(const Scop& scop) {
632
619
StmtSubscriptExprMapType& stmtSubscripts) -> isl::ast_node {
633
620
auto user = node.as <isl::ast_node_user>();
634
621
CHECK (user);
635
- auto expr = user.get_expr ();
622
+ auto expr = user.get_expr (). as <isl::ast_expr_op>() ;
636
623
auto schedule = build.get_schedule ();
637
624
auto scheduleMap = isl::map::from_union_map (schedule);
638
625
639
- auto stmtId = expr.get_op_arg ( 0 ).get_id ();
626
+ auto stmtId = expr.get_arg ( 0 ). as <isl::ast_expr_id>( ).get_id ();
640
627
CHECK_EQ (0u , iteratorMaps.count (stmtId)) << " entry exists: " << stmtId;
641
628
auto iteratorMap = isl::pw_multi_aff (scheduleMap.reverse ());
642
629
auto iterators = scop.halide .iterators .at (stmtId);
0 commit comments