@@ -432,30 +432,30 @@ class LLVMCodegen {
432
432
}
433
433
434
434
llvm::BasicBlock* emitAst (isl::ast_node node) {
435
- switch ( node.get_type ()) {
436
- case isl::ast_node_type::_for:
437
- return emitFor ( node);
438
- case isl::ast_node_type::user:
439
- return emitStmt ( node);
440
- case isl::ast_node_type::block:
441
- return emitBlock (node);
442
- case isl::ast_node_type::_if:
435
+ if ( auto forNode = node.as <isl::ast_node_for> ()) {
436
+ return emitFor (forNode);
437
+ } else if ( auto userNode = node. as <isl::ast_node_user>()) {
438
+ return emitStmt (userNode);
439
+ } else if ( auto blockNode = node. as <isl::ast_node_block>()) {
440
+ return emitBlock (blockNode);
441
+ } else {
442
+ if (node. as < isl::ast_node_if>()) {
443
443
LOG (FATAL) << " NYI if node: " << node << std::endl;
444
- default :
444
+ } else {
445
445
LOG (FATAL) << " NYI " << node << std::endl;
446
- return static_cast <llvm::BasicBlock*>(nullptr ); // avoid warning
446
+ }
447
+ return static_cast <llvm::BasicBlock*>(nullptr ); // avoid warning
447
448
}
448
449
}
449
450
450
451
private:
451
- llvm::BasicBlock* emitBlock (isl::ast_node node) {
452
+ llvm::BasicBlock* emitBlock (isl::ast_node_block node) {
452
453
auto * function = halide_cg.get_builder ().GetInsertBlock ()->getParent ();
453
454
auto * currBB = llvm::BasicBlock::Create (llvmCtx, " block_exit" , function);
454
455
halide_cg.get_builder ().CreateBr (currBB);
455
456
halide_cg.get_builder ().SetInsertPoint (currBB);
456
457
457
- CHECK (node.get_type () == isl::ast_node_type::block);
458
- for (auto child : node.block_get_children ()) {
458
+ for (auto child : node.get_children ()) {
459
459
currBB = emitAst (child);
460
460
halide_cg.get_builder ().SetInsertPoint (currBB);
461
461
}
@@ -479,9 +479,7 @@ class LLVMCodegen {
479
479
return arrTy->getPointerTo ();
480
480
}
481
481
482
- llvm::BasicBlock* emitFor (isl::ast_node node) {
483
- CHECK (node.get_type () == isl::ast_node_type::_for);
484
-
482
+ llvm::BasicBlock* emitFor (isl::ast_node_for node) {
485
483
IteratorLLVMValueMapType iterPHIs;
486
484
487
485
auto * incoming = halide_cg.get_builder ().GetInsertBlock ();
@@ -514,16 +512,16 @@ class LLVMCodegen {
514
512
515
513
// Loop Header
516
514
{
517
- auto initVal = IslExprToSInt (node.for_get_init ());
515
+ auto initVal = IslExprToSInt (node.get_init ());
518
516
halide_cg.get_builder ().SetInsertPoint (headerBB);
519
517
phi = halide_cg.get_builder ().CreatePHI (
520
518
llvm::Type::getInt64Ty (llvmCtx),
521
519
2 ,
522
- node.for_get_iterator ().get_id ().get_name ());
523
- halide_cg.sym_push (node.for_get_iterator ().get_id ().get_name (), phi);
520
+ node.get_iterator ().get_id ().get_name ());
521
+ halide_cg.sym_push (node.get_iterator ().get_id ().get_name (), phi);
524
522
phi->addIncoming (getLLVMConstantSignedInt64 (initVal), incoming);
525
523
526
- auto cond_expr = node.for_get_cond ();
524
+ auto cond_expr = node.get_cond ();
527
525
CHECK (
528
526
cond_expr.get_op_type () == isl::ast_op_type::lt or
529
527
cond_expr.get_op_type () == isl::ast_op_type::le)
@@ -532,7 +530,7 @@ class LLVMCodegen {
532
530
CHECK (
533
531
isl_ast_expr_get_type (condLHS.get ()) ==
534
532
isl_ast_expr_type::isl_ast_expr_id);
535
- CHECK_EQ (condLHS.get_id (), node.for_get_iterator ().get_id ());
533
+ CHECK_EQ (condLHS.get_id (), node.get_iterator ().get_id ());
536
534
537
535
IslAstExprInterpeter i (scop_.globalParameterContext );
538
536
auto condRHSVal = i.interpret (cond_expr.get_op_arg (1 ));
@@ -565,7 +563,7 @@ class LLVMCodegen {
565
563
halide_cg.get_builder ().SetInsertPoint (detachedBB);
566
564
}
567
565
#endif
568
- auto * currentBB = emitAst (node.for_get_body ());
566
+ auto * currentBB = emitAst (node.get_body ());
569
567
halide_cg.get_builder ().SetInsertPoint (currentBB);
570
568
571
569
if (parallel) {
@@ -580,7 +578,7 @@ class LLVMCodegen {
580
578
// Create Latch
581
579
{
582
580
halide_cg.get_builder ().SetInsertPoint (loopLatchBB);
583
- auto incVal = IslExprToSInt (node.for_get_inc ());
581
+ auto incVal = IslExprToSInt (node.get_inc ());
584
582
phi->addIncoming (
585
583
halide_cg.get_builder ().CreateAdd (
586
584
phi, getLLVMConstantSignedInt64 (incVal)),
@@ -589,7 +587,7 @@ class LLVMCodegen {
589
587
}
590
588
591
589
halide_cg.get_builder ().SetInsertPoint (loopExitBB);
592
- halide_cg.sym_pop (node.for_get_iterator ().get_id ().get_name ());
590
+ halide_cg.sym_pop (node.get_iterator ().get_id ().get_name ());
593
591
#ifdef TAPIR_VERSION_MAJOR
594
592
if (parallel) {
595
593
auto * syncBB = llvm::BasicBlock::Create (llvmCtx, " synced" , function);
@@ -600,9 +598,8 @@ class LLVMCodegen {
600
598
return halide_cg.get_builder ().GetInsertBlock ();
601
599
}
602
600
603
- llvm::BasicBlock* emitStmt (isl::ast_node node) {
604
- CHECK (node.get_type () == isl::ast_node_type::user);
605
- isl::ast_expr usrExp = node.user_get_expr ();
601
+ llvm::BasicBlock* emitStmt (isl::ast_node_user node) {
602
+ isl::ast_expr usrExp = node.get_expr ();
606
603
auto id = usrExp.get_op_arg (0 ).get_id ();
607
604
auto provide = scop_.halide .statements .at (id);
608
605
auto op = provide.as <Halide::Internal::Provide>();
@@ -660,7 +657,9 @@ IslCodegenRes codegenISL(const Scop& scop) {
660
657
IteratorMapsType& iteratorMaps,
661
658
const Scop& scop,
662
659
StmtSubscriptExprMapType& stmtSubscripts) -> isl::ast_node {
663
- auto expr = node.user_get_expr ();
660
+ auto user = node.as <isl::ast_node_user>();
661
+ CHECK (user);
662
+ auto expr = user.get_expr ();
664
663
auto schedule = build.get_schedule ();
665
664
auto scheduleMap = isl::map::from_union_map (schedule);
666
665
0 commit comments