Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit dffba78

Browse files
Merge pull request #331 from facebookresearch/pr/ast_expr
Bump isl for exporting isl_ast_expr as subclasses
2 parents c60f444 + 3f12a71 commit dffba78

File tree

3 files changed

+61
-74
lines changed

3 files changed

+61
-74
lines changed

tc/core/polyhedral/codegen_llvm.cc

Lines changed: 56 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@
3333

3434
#include "Halide.h"
3535

36-
#include "isl/ast.h"
37-
3836
#include "tc/core/constants.h"
3937
#include "tc/core/flags.h"
4038
#include "tc/core/halide2isl.h"
@@ -83,12 +81,12 @@ llvm::Value* getLLVMConstantSignedInt64(int64_t v) {
8381
}
8482

8583
int64_t IslExprToSInt(isl::ast_expr e) {
86-
CHECK(isl_ast_expr_get_type(e.get()) == isl_ast_expr_type::isl_ast_expr_int);
87-
return toSInt(isl::manage(isl_ast_expr_get_val(e.get())));
84+
auto intExpr = e.as<isl::ast_expr_int>();
85+
CHECK(intExpr);
86+
return toSInt(intExpr.get_val());
8887
}
8988

90-
int64_t islIdToInt(isl::ast_expr e, isl::set context) {
91-
CHECK(isl_ast_expr_get_type(e.get()) == isl_ast_expr_type::isl_ast_expr_id);
89+
int64_t islIdToInt(isl::ast_expr_id e, isl::set context) {
9290
auto space = context.get_space();
9391
isl::aff param(isl::aff::param_on_domain_space(space, e.get_id()));
9492
auto p = context.sample_point();
@@ -127,22 +125,21 @@ class IslAstExprInterpeter {
127125
IslAstExprInterpeter(isl::set context) : context_(context){};
128126

129127
int64_t interpret(isl::ast_expr e) {
130-
switch (isl_ast_expr_get_type(e.get())) {
131-
case isl_ast_expr_type::isl_ast_expr_int:
132-
return IslExprToSInt(e);
133-
case isl_ast_expr_type::isl_ast_expr_id:
134-
return islIdToInt(e, context_);
135-
case isl_ast_expr_type::isl_ast_expr_op:
136-
return interpretOp(e);
137-
default:
138-
CHECK(false) << "NYI";
139-
return 0; // avoid warning
128+
if (auto intExpr = e.as<isl::ast_expr_int>()) {
129+
return IslExprToSInt(intExpr);
130+
} else if (auto idExpr = e.as<isl::ast_expr_id>()) {
131+
return islIdToInt(idExpr, context_);
132+
} else if (auto opExpr = e.as<isl::ast_expr_op>()) {
133+
return interpretOp(opExpr);
134+
} else {
135+
CHECK(false) << "NYI";
136+
return 0; // avoid warning
140137
}
141138
};
142139

143140
private:
144-
int64_t interpretOp(isl::ast_expr e) {
145-
switch (e.get_op_n_arg()) {
141+
int64_t interpretOp(isl::ast_expr_op e) {
142+
switch (e.get_n_arg()) {
146143
case 1:
147144
return interpretUnaryOp(e);
148145
case 2:
@@ -153,28 +150,26 @@ class IslAstExprInterpeter {
153150
}
154151
}
155152

156-
int64_t interpretBinaryOp(isl::ast_expr e) {
157-
auto left = interpret(e.get_op_arg(0));
158-
auto right = interpret(e.get_op_arg(1));
159-
switch (e.get_op_type()) {
160-
case isl::ast_op_type::add:
161-
return left + right;
162-
case isl::ast_op_type::sub:
163-
return left - right;
164-
default:
165-
CHECK(false) << "NYI: " << e;
166-
return 0; // avoid warning
153+
int64_t interpretBinaryOp(isl::ast_expr_op e) {
154+
auto left = interpret(e.get_arg(0));
155+
auto right = interpret(e.get_arg(1));
156+
if (e.as<isl::ast_op_add>()) {
157+
return left + right;
158+
} else if (e.as<isl::ast_op_sub>()) {
159+
return left - right;
160+
} else {
161+
CHECK(false) << "NYI: " << e;
162+
return 0; // avoid warning
167163
}
168164
}
169165

170-
int64_t interpretUnaryOp(isl::ast_expr e) {
171-
auto val = interpret(e.get_op_arg(0));
172-
switch (e.get_op_type()) {
173-
case isl::ast_op_type::minus:
174-
return -val;
175-
default:
176-
CHECK(false) << "NYI";
177-
return 0; // avoid warning
166+
int64_t interpretUnaryOp(isl::ast_expr_op e) {
167+
auto val = interpret(e.get_arg(0));
168+
if (e.as<isl::ast_op_minus>()) {
169+
return -val;
170+
} else {
171+
CHECK(false) << "NYI";
172+
return 0; // avoid warning
178173
}
179174
}
180175
};
@@ -301,16 +296,13 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
301296
};
302297

303298
llvm::Value* CodeGen_TC::getValue(isl::ast_expr expr) {
304-
switch (isl_ast_expr_get_type(expr.get())) {
305-
case isl_ast_expr_type::isl_ast_expr_id:
306-
return sym_get(expr.get_id().get_name());
307-
case isl_ast_expr_type::isl_ast_expr_int: {
308-
auto val = isl::manage(isl_ast_expr_get_val(expr.get()));
309-
return getLLVMConstantSignedInt64(toSInt(val));
310-
}
311-
default:
312-
LOG(FATAL) << "NYI";
313-
return nullptr;
299+
if (auto idExpr = expr.as<isl::ast_expr_id>()) {
300+
return sym_get(idExpr.get_id().get_name());
301+
} else if (auto intExpr = expr.as<isl::ast_expr_int>()) {
302+
return getLLVMConstantSignedInt64(toSInt(intExpr.get_val()));
303+
} else {
304+
LOG(FATAL) << "NYI";
305+
return nullptr;
314306
}
315307
}
316308

@@ -483,7 +475,7 @@ class LLVMCodegen {
483475
halide_cg.get_builder().CreateBr(headerBB);
484476

485477
llvm::PHINode* phi = nullptr;
486-
auto iterator = node.get_iterator().get_id();
478+
auto iterator = node.get_iterator().as<isl::ast_expr_id>().get_id();
487479

488480
// Loop Header
489481
{
@@ -494,30 +486,25 @@ class LLVMCodegen {
494486
halide_cg.sym_push(iterator.get_name(), phi);
495487
phi->addIncoming(getLLVMConstantSignedInt64(initVal), incoming);
496488

497-
auto cond_expr = node.get_cond();
498-
CHECK(
499-
cond_expr.get_op_type() == isl::ast_op_type::lt or
500-
cond_expr.get_op_type() == isl::ast_op_type::le)
489+
auto cond_expr = node.get_cond().as<isl::ast_expr_op>();
490+
CHECK(cond_expr.as<isl::ast_op_lt>() or cond_expr.as<isl::ast_op_le>())
501491
<< "I only know how to codegen lt and le";
502-
auto condLHS = cond_expr.get_op_arg(0);
503-
CHECK(
504-
isl_ast_expr_get_type(condLHS.get()) ==
505-
isl_ast_expr_type::isl_ast_expr_id);
492+
auto condLHS = cond_expr.get_arg(0).as<isl::ast_expr_id>();
493+
CHECK(condLHS);
506494
CHECK_EQ(condLHS.get_id(), iterator);
507495

508496
IslAstExprInterpeter i(scop_.globalParameterContext);
509-
auto condRHSVal = i.interpret(cond_expr.get_op_arg(1));
497+
auto condRHSVal = i.interpret(cond_expr.get_arg(1));
510498

511499
auto cond = [&]() {
512500
auto constant = getLLVMConstantSignedInt64(condRHSVal);
513-
switch (cond_expr.get_op_type()) {
514-
case isl::ast_op_type::lt:
515-
return halide_cg.get_builder().CreateICmpSLT(phi, constant);
516-
case isl::ast_op_type::le:
517-
return halide_cg.get_builder().CreateICmpSLE(phi, constant);
518-
default:
519-
CHECK(false) << "NYI";
520-
return static_cast<llvm::Value*>(nullptr); // avoid warning
501+
if (cond_expr.as<isl::ast_op_lt>()) {
502+
return halide_cg.get_builder().CreateICmpSLT(phi, constant);
503+
} else if (cond_expr.as<isl::ast_op_le>()) {
504+
return halide_cg.get_builder().CreateICmpSLE(phi, constant);
505+
} else {
506+
CHECK(false) << "NYI";
507+
return static_cast<llvm::Value*>(nullptr); // avoid warning
521508
}
522509
}();
523510
halide_cg.get_builder().CreateCondBr(cond, loopBodyBB, loopExitBB);
@@ -572,8 +559,8 @@ class LLVMCodegen {
572559
}
573560

574561
llvm::BasicBlock* emitStmt(isl::ast_node_user node) {
575-
isl::ast_expr usrExp = node.get_expr();
576-
auto id = usrExp.get_op_arg(0).get_id();
562+
isl::ast_expr_op usrExp = node.get_expr().as<isl::ast_expr_op>();
563+
auto id = usrExp.get_arg(0).as<isl::ast_expr_id>().get_id();
577564
auto provide = scop_.halide.statements.at(id);
578565
auto op = provide.as<Halide::Internal::Provide>();
579566
CHECK(op) << "Expected a Provide node: " << provide << '\n';
@@ -632,11 +619,11 @@ IslCodegenRes codegenISL(const Scop& scop) {
632619
StmtSubscriptExprMapType& stmtSubscripts) -> isl::ast_node {
633620
auto user = node.as<isl::ast_node_user>();
634621
CHECK(user);
635-
auto expr = user.get_expr();
622+
auto expr = user.get_expr().as<isl::ast_expr_op>();
636623
auto schedule = build.get_schedule();
637624
auto scheduleMap = isl::map::from_union_map(schedule);
638625

639-
auto stmtId = expr.get_op_arg(0).get_id();
626+
auto stmtId = expr.get_arg(0).as<isl::ast_expr_id>().get_id();
640627
CHECK_EQ(0u, iteratorMaps.count(stmtId)) << "entry exists: " << stmtId;
641628
auto iteratorMap = isl::pw_multi_aff(scheduleMap.reverse());
642629
auto iterators = scop.halide.iterators.at(stmtId);

tc/core/polyhedral/cuda/codegen.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -392,8 +392,8 @@ void emitCopyStmt(const CodegenStatementContext& context) {
392392
}
393393

394394
void AstPrinter::emitStmt(isl::ast_node_user node) {
395-
isl::ast_expr usrExp = node.get_expr();
396-
auto stmtId = usrExp.get_op_arg(0).get_id();
395+
isl::ast_expr_op usrExp = node.get_expr().as<isl::ast_expr_op>();
396+
auto stmtId = usrExp.get_arg(0).as<isl::ast_expr_id>().get_id();
397397
auto nodeId = node.get_annotation();
398398
auto statementContext = CodegenStatementContext(context_, nodeId);
399399
CHECK_EQ(context_.nodeInfoMap.count(nodeId), 1u)
@@ -719,8 +719,8 @@ string emitCudaKernel(
719719
NodeInfoMapType* nodeInfoMap) -> isl::ast_node {
720720
auto user = node.as<isl::ast_node_user>();
721721
CHECK(user);
722-
auto expr = user.get_expr();
723-
auto stmtId = expr.get_op_arg(0).get_id();
722+
auto expr = user.get_expr().as<isl::ast_expr_op>();
723+
auto stmtId = expr.get_arg(0).as<isl::ast_expr_id>().get_id();
724724
auto schedule = build.get_schedule();
725725
auto scheduleMap = isl::map::from_union_map(schedule);
726726

third-party/islpp

Submodule islpp updated from 310e910 to a28f039

0 commit comments

Comments
 (0)