Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 3ee1d4b

Browse files
author
Sven Verdoolaege
committed
Bump isl for exporting isl_ast_node as subclasses
1 parent d4d6f2e commit 3ee1d4b

File tree

3 files changed

+62
-68
lines changed

3 files changed

+62
-68
lines changed

src/core/polyhedral/codegen_llvm.cc

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -432,30 +432,30 @@ class LLVMCodegen {
432432
}
433433

434434
llvm::BasicBlock* emitAst(isl::ast_node node) {
435-
switch (node.get_type()) {
436-
case isl::ast_node_type::_for:
437-
return emitFor(node);
438-
case isl::ast_node_type::user:
439-
return emitStmt(node);
440-
case isl::ast_node_type::block:
441-
return emitBlock(node);
442-
case isl::ast_node_type::_if:
435+
if (auto forNode = node.as<isl::ast_node_for>()) {
436+
return emitFor(forNode);
437+
} else if (auto userNode = node.as<isl::ast_node_user>()) {
438+
return emitStmt(userNode);
439+
} else if (auto blockNode = node.as<isl::ast_node_block>()) {
440+
return emitBlock(blockNode);
441+
} else {
442+
if (node.as<isl::ast_node_if>()) {
443443
LOG(FATAL) << "NYI if node: " << node << std::endl;
444-
default:
444+
} else {
445445
LOG(FATAL) << "NYI " << node << std::endl;
446-
return static_cast<llvm::BasicBlock*>(nullptr); // avoid warning
446+
}
447+
return static_cast<llvm::BasicBlock*>(nullptr); // avoid warning
447448
}
448449
}
449450

450451
private:
451-
llvm::BasicBlock* emitBlock(isl::ast_node node) {
452+
llvm::BasicBlock* emitBlock(isl::ast_node_block node) {
452453
auto* function = halide_cg.get_builder().GetInsertBlock()->getParent();
453454
auto* currBB = llvm::BasicBlock::Create(llvmCtx, "block_exit", function);
454455
halide_cg.get_builder().CreateBr(currBB);
455456
halide_cg.get_builder().SetInsertPoint(currBB);
456457

457-
CHECK(node.get_type() == isl::ast_node_type::block);
458-
for (auto child : node.block_get_children()) {
458+
for (auto child : node.get_children()) {
459459
currBB = emitAst(child);
460460
halide_cg.get_builder().SetInsertPoint(currBB);
461461
}
@@ -479,9 +479,7 @@ class LLVMCodegen {
479479
return arrTy->getPointerTo();
480480
}
481481

482-
llvm::BasicBlock* emitFor(isl::ast_node node) {
483-
CHECK(node.get_type() == isl::ast_node_type::_for);
484-
482+
llvm::BasicBlock* emitFor(isl::ast_node_for node) {
485483
IteratorLLVMValueMapType iterPHIs;
486484

487485
auto* incoming = halide_cg.get_builder().GetInsertBlock();
@@ -514,16 +512,16 @@ class LLVMCodegen {
514512

515513
// Loop Header
516514
{
517-
auto initVal = IslExprToSInt(node.for_get_init());
515+
auto initVal = IslExprToSInt(node.get_init());
518516
halide_cg.get_builder().SetInsertPoint(headerBB);
519517
phi = halide_cg.get_builder().CreatePHI(
520518
llvm::Type::getInt64Ty(llvmCtx),
521519
2,
522-
node.for_get_iterator().get_id().get_name());
523-
halide_cg.sym_push(node.for_get_iterator().get_id().get_name(), phi);
520+
node.get_iterator().get_id().get_name());
521+
halide_cg.sym_push(node.get_iterator().get_id().get_name(), phi);
524522
phi->addIncoming(getLLVMConstantSignedInt64(initVal), incoming);
525523

526-
auto cond_expr = node.for_get_cond();
524+
auto cond_expr = node.get_cond();
527525
CHECK(
528526
cond_expr.get_op_type() == isl::ast_op_type::lt or
529527
cond_expr.get_op_type() == isl::ast_op_type::le)
@@ -532,7 +530,7 @@ class LLVMCodegen {
532530
CHECK(
533531
isl_ast_expr_get_type(condLHS.get()) ==
534532
isl_ast_expr_type::isl_ast_expr_id);
535-
CHECK_EQ(condLHS.get_id(), node.for_get_iterator().get_id());
533+
CHECK_EQ(condLHS.get_id(), node.get_iterator().get_id());
536534

537535
IslAstExprInterpeter i(scop_.globalParameterContext);
538536
auto condRHSVal = i.interpret(cond_expr.get_op_arg(1));
@@ -565,7 +563,7 @@ class LLVMCodegen {
565563
halide_cg.get_builder().SetInsertPoint(detachedBB);
566564
}
567565
#endif
568-
auto* currentBB = emitAst(node.for_get_body());
566+
auto* currentBB = emitAst(node.get_body());
569567
halide_cg.get_builder().SetInsertPoint(currentBB);
570568

571569
if (parallel) {
@@ -580,7 +578,7 @@ class LLVMCodegen {
580578
// Create Latch
581579
{
582580
halide_cg.get_builder().SetInsertPoint(loopLatchBB);
583-
auto incVal = IslExprToSInt(node.for_get_inc());
581+
auto incVal = IslExprToSInt(node.get_inc());
584582
phi->addIncoming(
585583
halide_cg.get_builder().CreateAdd(
586584
phi, getLLVMConstantSignedInt64(incVal)),
@@ -589,7 +587,7 @@ class LLVMCodegen {
589587
}
590588

591589
halide_cg.get_builder().SetInsertPoint(loopExitBB);
592-
halide_cg.sym_pop(node.for_get_iterator().get_id().get_name());
590+
halide_cg.sym_pop(node.get_iterator().get_id().get_name());
593591
#ifdef TAPIR_VERSION_MAJOR
594592
if (parallel) {
595593
auto* syncBB = llvm::BasicBlock::Create(llvmCtx, "synced", function);
@@ -600,9 +598,8 @@ class LLVMCodegen {
600598
return halide_cg.get_builder().GetInsertBlock();
601599
}
602600

603-
llvm::BasicBlock* emitStmt(isl::ast_node node) {
604-
CHECK(node.get_type() == isl::ast_node_type::user);
605-
isl::ast_expr usrExp = node.user_get_expr();
601+
llvm::BasicBlock* emitStmt(isl::ast_node_user node) {
602+
isl::ast_expr usrExp = node.get_expr();
606603
auto id = usrExp.get_op_arg(0).get_id();
607604
auto provide = scop_.halide.statements.at(id);
608605
auto op = provide.as<Halide::Internal::Provide>();
@@ -660,7 +657,9 @@ IslCodegenRes codegenISL(const Scop& scop) {
660657
IteratorMapsType& iteratorMaps,
661658
const Scop& scop,
662659
StmtSubscriptExprMapType& stmtSubscripts) -> isl::ast_node {
663-
auto expr = node.user_get_expr();
660+
auto user = node.as<isl::ast_node_user>();
661+
CHECK(user);
662+
auto expr = user.get_expr();
664663
auto schedule = build.get_schedule();
665664
auto scheduleMap = isl::map::from_union_map(schedule);
666665

src/core/polyhedral/cuda/codegen.cc

Lines changed: 34 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,9 @@ struct AstPrinter {
8080
}
8181

8282
private:
83-
void emitFor(isl::ast_node node);
84-
void emitIf(isl::ast_node node);
85-
void emitStmt(isl::ast_node node);
83+
void emitFor(isl::ast_node_for node);
84+
void emitIf(isl::ast_node_if node);
85+
void emitStmt(isl::ast_node_user node);
8686
void emitAst(isl::ast_node node);
8787

8888
private:
@@ -216,26 +216,26 @@ void emitTensorViews(
216216
}
217217
}
218218

219-
void AstPrinter::emitFor(isl::ast_node node) {
219+
void AstPrinter::emitFor(isl::ast_node_for node) {
220220
WS ws;
221221
context_.ss << ws.tab();
222-
string iter = node.for_get_iterator().to_C_str();
223-
context_.ss << "for (int " << iter << " = " << node.for_get_init().to_C_str()
224-
<< "; " << node.for_get_cond().to_C_str() << "; " << iter
225-
<< " += " << node.for_get_inc().to_C_str() << ") {" << endl;
226-
emitAst(node.for_get_body());
222+
string iter = node.get_iterator().to_C_str();
223+
context_.ss << "for (int " << iter << " = " << node.get_init().to_C_str()
224+
<< "; " << node.get_cond().to_C_str() << "; " << iter
225+
<< " += " << node.get_inc().to_C_str() << ") {" << endl;
226+
emitAst(node.get_body());
227227
context_.ss << ws.tab() << "}" << endl;
228228
}
229229

230-
void AstPrinter::emitIf(isl::ast_node node) {
230+
void AstPrinter::emitIf(isl::ast_node_if node) {
231231
WS ws;
232232
context_.ss << ws.tab();
233-
context_.ss << "if (" << node.if_get_cond().to_C_str() << ") {" << endl;
234-
emitAst(node.if_get_then());
233+
context_.ss << "if (" << node.get_cond().to_C_str() << ") {" << endl;
234+
emitAst(node.get_then());
235235
context_.ss << ws.tab() << "}";
236-
if (node.if_has_else()) {
236+
if (node.has_else()) {
237237
context_.ss << " else {" << endl;
238-
emitAst(node.if_get_else());
238+
emitAst(node.get_else());
239239
context_.ss << ws.tab() << "}";
240240
}
241241
context_.ss << endl;
@@ -388,8 +388,8 @@ void emitCopyStmt(const CodegenStatementContext& context) {
388388
context.ss << ";" << std::endl;
389389
}
390390

391-
void AstPrinter::emitStmt(isl::ast_node node) {
392-
isl::ast_expr usrExp = node.user_get_expr();
391+
void AstPrinter::emitStmt(isl::ast_node_user node) {
392+
isl::ast_expr usrExp = node.get_expr();
393393
auto stmtId = usrExp.get_op_arg(0).get_id();
394394
auto nodeId = node.get_annotation();
395395
auto statementContext = CodegenStatementContext(context_, nodeId);
@@ -428,28 +428,21 @@ void AstPrinter::emitStmt(isl::ast_node node) {
428428
}
429429

430430
void AstPrinter::emitAst(isl::ast_node node) {
431-
switch (node.get_type()) {
432-
case isl::ast_node_type::_for:
433-
emitFor(node);
434-
break;
435-
case isl::ast_node_type::_if:
436-
emitIf(node);
437-
break;
438-
case isl::ast_node_type::block:
439-
for (auto child : node.block_get_children()) {
440-
emitAst(child);
441-
}
442-
break;
443-
case isl::ast_node_type::mark:
444-
CHECK(false) << "mark";
445-
// emitAst(node.mark_get_node());
446-
break;
447-
case isl::ast_node_type::user:
448-
emitStmt(node);
449-
break;
450-
default:
451-
LOG(FATAL) << "NYI " << node << endl;
452-
return;
431+
if (auto forNode = node.as<isl::ast_node_for>()) {
432+
emitFor(forNode);
433+
} else if (auto ifNode = node.as<isl::ast_node_if>()) {
434+
emitIf(ifNode);
435+
} else if (auto blockNode = node.as<isl::ast_node_block>()) {
436+
for (auto child : blockNode.get_children()) {
437+
emitAst(child);
438+
}
439+
} else if (node.as<isl::ast_node_mark>()) {
440+
CHECK(false) << "mark";
441+
// emitAst(node.mark_get_node());
442+
} else if (auto userNode = node.as<isl::ast_node_user>()) {
443+
emitStmt(userNode);
444+
} else {
445+
LOG(FATAL) << "NYI " << node << endl;
453446
}
454447
}
455448

@@ -746,7 +739,9 @@ string emitCudaKernel(
746739
[](isl::ast_node node,
747740
isl::ast_build build,
748741
IteratorMapsType* iteratorMaps) -> isl::ast_node {
749-
auto expr = node.user_get_expr();
742+
auto user = node.as<isl::ast_node_user>();
743+
CHECK(user);
744+
auto expr = user.get_expr();
750745
auto stmtId = expr.get_op_arg(0).get_id();
751746
// We rename loop-related dimensions manually.
752747
auto schedule = build.get_schedule();

third-party/islpp

Submodule islpp updated from 3e13c53 to 1da73b6

0 commit comments

Comments
 (0)