Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 255e251

Browse files
Make insertion points explicit in LLVM IR
The usage of GetInsertBlock / SetInsertBlock creates a lot of implicit state that is hard to track. For simple structured control flow like we have it is very easy to do with significantly less surprising behavior. This commit kills implicit behavior with fire and forces the passing of insertion points. Insertion points are enforced at the beginning of each private API emit* method.
1 parent 6eb026f commit 255e251

File tree

1 file changed

+50
-42
lines changed

1 file changed

+50
-42
lines changed

tc/core/polyhedral/codegen_llvm.cc

Lines changed: 50 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ class LLVMCodegen {
340340

341341
// This creates a signature of the form:
342342
// input_data_types, output_data_types, parameters
343-
void createSignature(
343+
llvm::BasicBlock* createSignature(
344344
const std::vector<Halide::ImageParam>& inputs,
345345
const std::vector<Halide::OutputImageParam>& outputs,
346346
const std::vector<Halide::Internal::Parameter>& params,
@@ -383,40 +383,37 @@ class LLVMCodegen {
383383
it->addAttr(llvm::Attribute::ReadOnly);
384384
}
385385

386-
auto entryBB_ = llvm::BasicBlock::Create(llvmCtx, "entry", function);
387-
halide_cg.get_builder().SetInsertPoint(entryBB_);
386+
return llvm::BasicBlock::Create(llvmCtx, "entry", function);
388387
}
389388

390-
void CodeGen(isl::ast_node node) {
391-
emitAst(node);
392-
halide_cg.get_builder().CreateRetVoid();
393-
394-
if (llvm::verifyModule(*halide_cg.get_module())) {
395-
LOG(ERROR) << str();
396-
llvm::verifyModule(*halide_cg.get_module(), &llvm::outs());
397-
throw std::runtime_error("LLVM generated module is invalid.");
398-
}
399-
}
400-
401-
llvm::BasicBlock* emitAst(isl::ast_node node) {
389+
// This is the main entry point to emit pieces of LLVM IR
390+
// LLVM IR insertion is stateful, configured by SetInsertPoint
391+
// We make this an explicit parameter to avoid implicit conventions
392+
// All TC IR builder methods take an explicit insertionPoint.
393+
// The invariant in all emit* (except for emitAst) is that:
394+
// TC_CHECK_EQ(halide_cg.get_builder().GetInsertBlock(), insertionPoint);
395+
llvm::BasicBlock* emitAst(
396+
isl::ast_node node,
397+
llvm::BasicBlock* insertionPoint) {
398+
halide_cg.get_builder().SetInsertPoint(insertionPoint);
402399
if (auto forNode = node.as<isl::ast_node_for>()) {
403-
return emitFor(forNode);
400+
return emitFor(forNode, insertionPoint);
404401
} else if (auto userNode = node.as<isl::ast_node_user>()) {
405-
return emitStmt(userNode);
402+
return emitStmt(userNode, insertionPoint);
406403
} else if (auto blockNode = node.as<isl::ast_node_block>()) {
407-
llvm::BasicBlock* curBB;
404+
llvm::BasicBlock* curBB = insertionPoint;
408405
for (auto child : blockNode.get_children()) {
409-
curBB = emitAst(child);
406+
curBB = emitAst(child, curBB);
410407
}
411408
return curBB;
412409
} else {
413410
if (auto cond = node.as<isl::ast_node_if>()) {
414-
return emitIf(cond);
411+
return emitIf(cond, insertionPoint);
415412
} else {
416413
LOG(FATAL) << "NYI " << node << std::endl;
417414
}
418-
return static_cast<llvm::BasicBlock*>(nullptr); // avoid warning
419415
}
416+
return nullptr;
420417
}
421418

422419
private:
@@ -432,18 +429,19 @@ class LLVMCodegen {
432429
return arrTy->getPointerTo();
433430
}
434431

435-
llvm::BasicBlock* emitIf(isl::ast_node_if node) {
436-
auto* incoming = halide_cg.get_builder().GetInsertBlock();
437-
auto* function = incoming->getParent();
432+
llvm::BasicBlock* emitIf(
433+
isl::ast_node_if node,
434+
llvm::BasicBlock* insertionPoint) {
435+
TC_CHECK_EQ(halide_cg.get_builder().GetInsertBlock(), insertionPoint);
436+
auto* function = insertionPoint->getParent();
438437

439438
llvm::Value* condVal = halide_cg.codegen(node.get_cond());
440439
auto* thenBB = llvm::BasicBlock::Create(llvmCtx, "then", function);
441440
// Recursively emit "then" in a new thenBB
442-
halide_cg.get_builder().SetInsertPoint(thenBB);
443-
auto innerBB = emitAst(node.get_then());
441+
auto innerBB = emitAst(node.get_then(), thenBB);
444442

445443
// outer -> thenBB
446-
halide_cg.get_builder().SetInsertPoint(incoming);
444+
halide_cg.get_builder().SetInsertPoint(insertionPoint);
447445
// outer ---------> if_exit
448446
// TODO: When we support "else", go to elseBB instead of exit
449447
auto* exit = llvm::BasicBlock::Create(llvmCtx, "if_exit", function);
@@ -456,17 +454,17 @@ class LLVMCodegen {
456454
// Else is often empty in the absence of full tile extraction
457455
if (node.has_else()) {
458456
LOG(FATAL) << "NYI: else conditional branch";
459-
return halide_cg.get_builder().GetInsertBlock();
457+
return exit;
460458
}
461459

462-
// Set the insertion point to if_exit
463-
halide_cg.get_builder().SetInsertPoint(exit);
464-
return halide_cg.get_builder().GetInsertBlock();
460+
return exit;
465461
}
466462

467-
llvm::BasicBlock* emitFor(isl::ast_node_for node) {
468-
auto* incoming = halide_cg.get_builder().GetInsertBlock();
469-
auto* function = incoming->getParent();
463+
llvm::BasicBlock* emitFor(
464+
isl::ast_node_for node,
465+
llvm::BasicBlock* insertionPoint) {
466+
TC_CHECK_EQ(halide_cg.get_builder().GetInsertBlock(), insertionPoint);
467+
auto* function = insertionPoint->getParent();
470468
auto* headerBB = llvm::BasicBlock::Create(llvmCtx, "loop_header", function);
471469
auto* loopBodyBB = llvm::BasicBlock::Create(llvmCtx, "loop_body", function);
472470
auto* loopLatchBB =
@@ -485,16 +483,15 @@ class LLVMCodegen {
485483
phi = halide_cg.get_builder().CreatePHI(
486484
initVal->getType(), 2, iterator.get_name());
487485
halide_cg.sym_push(iterator.get_name(), phi);
488-
phi->addIncoming(initVal, incoming);
486+
phi->addIncoming(initVal, insertionPoint);
489487

490488
auto cond = halide_cg.codegen(node.get_cond());
491489
halide_cg.get_builder().CreateCondBr(cond, loopBodyBB, loopExitBB);
492490
}
493491

494492
// Create Body
495493
{
496-
halide_cg.get_builder().SetInsertPoint(loopBodyBB);
497-
auto* currentBB = emitAst(node.get_body());
494+
auto* currentBB = emitAst(node.get_body(), loopBodyBB);
498495
halide_cg.get_builder().SetInsertPoint(currentBB);
499496
halide_cg.get_builder().CreateBr(loopLatchBB);
500497
}
@@ -508,12 +505,14 @@ class LLVMCodegen {
508505
halide_cg.get_builder().CreateBr(headerBB);
509506
}
510507

511-
halide_cg.get_builder().SetInsertPoint(loopExitBB);
512508
halide_cg.sym_pop(iterator.get_name());
513-
return halide_cg.get_builder().GetInsertBlock();
509+
return loopExitBB;
514510
}
515511

516-
llvm::BasicBlock* emitStmt(isl::ast_node_user node) {
512+
llvm::BasicBlock* emitStmt(
513+
isl::ast_node_user node,
514+
llvm::BasicBlock* insertionPoint) {
515+
TC_CHECK_EQ(halide_cg.get_builder().GetInsertBlock(), insertionPoint);
517516
isl::ast_expr_op usrExp = node.get_expr().as<isl::ast_expr_op>();
518517
auto id = usrExp.get_arg(0).as<isl::ast_expr_id>().get_id();
519518
auto provide = scop_.halide.statements.at(id);
@@ -535,6 +534,9 @@ class LLVMCodegen {
535534

536535
llvm::Value* rhs = halide_cg.codegen(op->values[0]);
537536
halide_cg.get_builder().CreateStore(rhs, destAddr);
537+
// We must return halide_cg.get_builder().GetInsertBlock() because
538+
// Halide does not adhere to our conventions and when it emits multiple
539+
// blocks things may go haywire.
538540
return halide_cg.get_builder().GetInsertBlock();
539541
}
540542

@@ -625,12 +627,18 @@ std::unique_ptr<llvm::Module> emitLLVMKernel(
625627
cg.halide_cg.get_module()->setDataLayout(dataLayout);
626628
cg.halide_cg.get_module()->setTargetTriple(
627629
llvm::EngineBuilder().selectTarget()->getTargetTriple().str());
628-
cg.createSignature(
630+
auto entry = cg.createSignature(
629631
scop.halide.inputs,
630632
scop.halide.outputs,
631633
scop.halide.params,
632634
specializedName);
633-
cg.CodeGen(islCg.astNode);
635+
auto exit = cg.emitAst(islCg.astNode, entry);
636+
cg.halide_cg.get_builder().SetInsertPoint(exit);
637+
cg.halide_cg.get_builder().CreateRetVoid();
638+
639+
TC_CHECK(!llvm::verifyModule(*cg.halide_cg.get_module()))
640+
<< "LLVM generated module is invalid." << cg.str().c_str();
641+
634642
cg.halide_cg.optimize_module();
635643
return cg.halide_cg.move_module();
636644
}

0 commit comments

Comments
 (0)