@@ -340,7 +340,7 @@ class LLVMCodegen {
340
340
341
341
// This creates a signature of the form:
342
342
// input_data_types, output_data_types, parameters
343
- void createSignature (
343
+ llvm::BasicBlock* createSignature (
344
344
const std::vector<Halide::ImageParam>& inputs,
345
345
const std::vector<Halide::OutputImageParam>& outputs,
346
346
const std::vector<Halide::Internal::Parameter>& params,
@@ -383,40 +383,37 @@ class LLVMCodegen {
383
383
it->addAttr (llvm::Attribute::ReadOnly);
384
384
}
385
385
386
- auto entryBB_ = llvm::BasicBlock::Create (llvmCtx, " entry" , function);
387
- halide_cg.get_builder ().SetInsertPoint (entryBB_);
386
+ return llvm::BasicBlock::Create (llvmCtx, " entry" , function);
388
387
}
389
388
390
- void CodeGen (isl::ast_node node) {
391
- emitAst (node);
392
- halide_cg.get_builder ().CreateRetVoid ();
393
-
394
- if (llvm::verifyModule (*halide_cg.get_module ())) {
395
- LOG (ERROR) << str ();
396
- llvm::verifyModule (*halide_cg.get_module (), &llvm::outs ());
397
- throw std::runtime_error (" LLVM generated module is invalid." );
398
- }
399
- }
400
-
401
- llvm::BasicBlock* emitAst (isl::ast_node node) {
389
+ // Main entry point for emitting LLVM IR for an isl AST node.
390
+ // LLVM IR insertion is stateful (configured by SetInsertPoint), so instead of
391
+ // relying on an implicit convention, the insertion block is an explicit parameter:
392
+ // all TC IR builder methods take an explicit insertionPoint, and emitAst points the builder at it.
393
+ // Returns the block in which emission should continue. Invariant in all emit* (except for emitAst):
394
+ // TC_CHECK_EQ(halide_cg.get_builder().GetInsertBlock(), insertionPoint);
395
+ llvm::BasicBlock* emitAst (
396
+ isl::ast_node node,
397
+ llvm::BasicBlock* insertionPoint) {
398
+ halide_cg.get_builder ().SetInsertPoint (insertionPoint);
402
399
if (auto forNode = node.as <isl::ast_node_for>()) {
403
- return emitFor (forNode);
400
+ return emitFor (forNode, insertionPoint );
404
401
} else if (auto userNode = node.as <isl::ast_node_user>()) {
405
- return emitStmt (userNode);
402
+ return emitStmt (userNode, insertionPoint );
406
403
} else if (auto blockNode = node.as <isl::ast_node_block>()) {
407
- llvm::BasicBlock* curBB;
404
+ llvm::BasicBlock* curBB = insertionPoint ; // each child starts where the previous child ended
408
405
for (auto child : blockNode.get_children ()) {
409
- curBB = emitAst (child);
406
+ curBB = emitAst (child, curBB );
410
407
}
411
408
return curBB;
412
409
} else {
413
410
if (auto cond = node.as <isl::ast_node_if>()) {
414
- return emitIf (cond);
411
+ return emitIf (cond, insertionPoint );
415
412
} else {
416
413
LOG (FATAL) << " NYI " << node << std::endl;
417
414
}
418
- return static_cast <llvm::BasicBlock*>(nullptr ); // avoid warning
419
415
}
416
+ return nullptr ; // only reached via the unsupported-node (NYI) path above
420
417
}
421
418
422
419
private:
@@ -432,18 +429,19 @@ class LLVMCodegen {
432
429
return arrTy->getPointerTo ();
433
430
}
434
431
435
- llvm::BasicBlock* emitIf (isl::ast_node_if node) {
436
- auto * incoming = halide_cg.get_builder ().GetInsertBlock ();
437
- auto * function = incoming->getParent ();
432
+ llvm::BasicBlock* emitIf (
433
+ isl::ast_node_if node,
434
+ llvm::BasicBlock* insertionPoint) {
435
+ TC_CHECK_EQ (halide_cg.get_builder ().GetInsertBlock (), insertionPoint);
436
+ auto * function = insertionPoint->getParent ();
438
437
439
438
llvm::Value* condVal = halide_cg.codegen (node.get_cond ());
440
439
auto * thenBB = llvm::BasicBlock::Create (llvmCtx, " then" , function);
441
440
// Recursively emit "then" in a new thenBB
442
- halide_cg.get_builder ().SetInsertPoint (thenBB);
443
- auto innerBB = emitAst (node.get_then ());
441
+ auto innerBB = emitAst (node.get_then (), thenBB);
444
442
445
443
// outer -> thenBB
446
- halide_cg.get_builder ().SetInsertPoint (incoming );
444
+ halide_cg.get_builder ().SetInsertPoint (insertionPoint );
447
445
// outer ---------> if_exit
448
446
// TODO: When we support "else", go to elseBB instead of exit
449
447
auto * exit = llvm::BasicBlock::Create (llvmCtx, " if_exit" , function);
@@ -456,17 +454,17 @@ class LLVMCodegen {
456
454
// Else is often empty in the absence of full tile extraction
457
455
if (node.has_else ()) {
458
456
LOG (FATAL) << " NYI: else conditional branch" ;
459
- return halide_cg. get_builder (). GetInsertBlock () ;
457
+ return exit ;
460
458
}
461
459
462
- // Set the insertion point to if_exit
463
- halide_cg.get_builder ().SetInsertPoint (exit);
464
- return halide_cg.get_builder ().GetInsertBlock ();
460
+ return exit;
465
461
}
466
462
467
- llvm::BasicBlock* emitFor (isl::ast_node_for node) {
468
- auto * incoming = halide_cg.get_builder ().GetInsertBlock ();
469
- auto * function = incoming->getParent ();
463
+ llvm::BasicBlock* emitFor (
464
+ isl::ast_node_for node,
465
+ llvm::BasicBlock* insertionPoint) {
466
+ TC_CHECK_EQ (halide_cg.get_builder ().GetInsertBlock (), insertionPoint);
467
+ auto * function = insertionPoint->getParent ();
470
468
auto * headerBB = llvm::BasicBlock::Create (llvmCtx, " loop_header" , function);
471
469
auto * loopBodyBB = llvm::BasicBlock::Create (llvmCtx, " loop_body" , function);
472
470
auto * loopLatchBB =
@@ -485,16 +483,15 @@ class LLVMCodegen {
485
483
phi = halide_cg.get_builder ().CreatePHI (
486
484
initVal->getType (), 2 , iterator.get_name ());
487
485
halide_cg.sym_push (iterator.get_name (), phi);
488
- phi->addIncoming (initVal, incoming );
486
+ phi->addIncoming (initVal, insertionPoint );
489
487
490
488
auto cond = halide_cg.codegen (node.get_cond ());
491
489
halide_cg.get_builder ().CreateCondBr (cond, loopBodyBB, loopExitBB);
492
490
}
493
491
494
492
// Create Body
495
493
{
496
- halide_cg.get_builder ().SetInsertPoint (loopBodyBB);
497
- auto * currentBB = emitAst (node.get_body ());
494
+ auto * currentBB = emitAst (node.get_body (), loopBodyBB);
498
495
halide_cg.get_builder ().SetInsertPoint (currentBB);
499
496
halide_cg.get_builder ().CreateBr (loopLatchBB);
500
497
}
@@ -508,12 +505,14 @@ class LLVMCodegen {
508
505
halide_cg.get_builder ().CreateBr (headerBB);
509
506
}
510
507
511
- halide_cg.get_builder ().SetInsertPoint (loopExitBB);
512
508
halide_cg.sym_pop (iterator.get_name ());
513
- return halide_cg. get_builder (). GetInsertBlock () ;
509
+ return loopExitBB ;
514
510
}
515
511
516
- llvm::BasicBlock* emitStmt (isl::ast_node_user node) {
512
+ llvm::BasicBlock* emitStmt (
513
+ isl::ast_node_user node,
514
+ llvm::BasicBlock* insertionPoint) {
515
+ TC_CHECK_EQ (halide_cg.get_builder ().GetInsertBlock (), insertionPoint);
517
516
isl::ast_expr_op usrExp = node.get_expr ().as <isl::ast_expr_op>();
518
517
auto id = usrExp.get_arg (0 ).as <isl::ast_expr_id>().get_id ();
519
518
auto provide = scop_.halide .statements .at (id);
@@ -535,6 +534,9 @@ class LLVMCodegen {
535
534
536
535
llvm::Value* rhs = halide_cg.codegen (op->values [0 ]);
537
536
halide_cg.get_builder ().CreateStore (rhs, destAddr);
537
+ // Return the builder's current block rather than insertionPoint: Halide's
538
+ // codegen does not adhere to our explicit-insertion-point convention and may
539
+ // emit multiple blocks, leaving the builder in a different block than we started in.
538
540
return halide_cg.get_builder ().GetInsertBlock ();
539
541
}
540
542
@@ -625,12 +627,18 @@ std::unique_ptr<llvm::Module> emitLLVMKernel(
625
627
cg.halide_cg .get_module ()->setDataLayout (dataLayout);
626
628
cg.halide_cg .get_module ()->setTargetTriple (
627
629
llvm::EngineBuilder ().selectTarget ()->getTargetTriple ().str ());
628
- cg.createSignature (
630
+ auto entry = cg.createSignature (
629
631
scop.halide .inputs ,
630
632
scop.halide .outputs ,
631
633
scop.halide .params ,
632
634
specializedName);
633
- cg.CodeGen (islCg.astNode );
635
+ auto exit = cg.emitAst (islCg.astNode , entry);
636
+ cg.halide_cg .get_builder ().SetInsertPoint (exit);
637
+ cg.halide_cg .get_builder ().CreateRetVoid ();
638
+
639
+ TC_CHECK (!llvm::verifyModule (*cg.halide_cg .get_module ()))
640
+ << " LLVM generated module is invalid." << cg.str ().c_str ();
641
+
634
642
cg.halide_cg .optimize_module ();
635
643
return cg.halide_cg .move_module ();
636
644
}
0 commit comments