Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 0d10566

Browse files
Merge pull request #246 from facebookresearch/pr/clean_llvm
some cleanups to codegen_llvm.cc
2 parents 9a9fcf5 + c8bdfa4 commit 0d10566

File tree

1 file changed

+14
-29
lines changed

1 file changed

+14
-29
lines changed

src/core/polyhedral/codegen_llvm.cc

Lines changed: 14 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,12 @@
3636
#include "isl/ast.h"
3737

3838
#include "tc/core/constants.h"
39-
//#include "tc/core/polyhedral/isl_mu_wrappers.h"
4039
#include "tc/core/flags.h"
4140
#include "tc/core/polyhedral/codegen.h"
4241
#include "tc/core/polyhedral/schedule_isl_conversion.h"
4342
#include "tc/core/polyhedral/scop.h"
4443
#include "tc/core/scope_guard.h"
44+
#include "tc/external/isl.h"
4545

4646
#ifndef LLVM_VERSION_MAJOR
4747
#error LLVM_VERSION_MAJOR not set
@@ -76,10 +76,9 @@ namespace {
7676
thread_local llvm::LLVMContext llvmCtx;
7777

7878
int64_t toSInt(isl::val v) {
79-
auto n = v.get_num_si();
80-
auto d = v.get_den_si();
81-
CHECK_EQ(n % d, 0);
82-
return n / d;
79+
CHECK(v.is_int());
80+
static_assert(sizeof(long) <= 8, "long is assumed to fit into 64bits");
81+
return v.get_num_si();
8382
}
8483

8584
llvm::Value* getLLVMConstantSignedInt64(int64_t v) {
@@ -88,25 +87,16 @@ llvm::Value* getLLVMConstantSignedInt64(int64_t v) {
8887

8988
int64_t IslExprToSInt(isl::ast_expr e) {
9089
CHECK(isl_ast_expr_get_type(e.get()) == isl_ast_expr_type::isl_ast_expr_int);
91-
assert(sizeof(long) <= 8); // long is assumed to fit to 64bits
9290
return toSInt(isl::manage(isl_ast_expr_get_val(e.get())));
9391
}
9492

9593
int64_t islIdToInt(isl::ast_expr e, isl::set context) {
9694
CHECK(isl_ast_expr_get_type(e.get()) == isl_ast_expr_type::isl_ast_expr_id);
97-
CHECK_NE(-1, context.find_dim_by_id(isl::dim_type::param, e.get_id()));
98-
while (context.dim(isl::dim_type::param) > 1) {
99-
for (unsigned int d = 0; d < context.dim(isl::dim_type::param); ++d) {
100-
if (d == context.find_dim_by_id(isl::dim_type::param, e.get_id())) {
101-
continue;
102-
}
103-
context = context.remove_dims(isl::dim_type::param, d, 1);
104-
}
105-
}
95+
auto space = context.get_space();
96+
isl::aff param(isl::aff::param_on_domain_space(space, e.get_id()));
10697
auto p = context.sample_point();
107-
108-
auto val = toSInt(p.get_coordinate_val(isl::dim_type::param, 0));
109-
return val;
98+
CHECK(context.is_equal(p));
99+
return toSInt(param.eval(p));
110100
}
111101

112102
int64_t getTensorSize(isl::set context, const Halide::Expr& e) {
@@ -319,8 +309,7 @@ llvm::Value* CodeGen_TC::getValue(isl::ast_expr expr) {
319309
return sym_get(expr.get_id().get_name());
320310
case isl_ast_expr_type::isl_ast_expr_int: {
321311
auto val = isl::manage(isl_ast_expr_get_val(expr.get()));
322-
CHECK(val.is_int());
323-
return getLLVMConstantSignedInt64(val.get_num_si());
312+
return getLLVMConstantSignedInt64(toSInt(val));
324313
}
325314
default:
326315
LOG(FATAL) << "NYI";
@@ -497,16 +486,15 @@ class LLVMCodegen {
497486
halide_cg.get_builder().CreateBr(headerBB);
498487

499488
llvm::PHINode* phi = nullptr;
489+
auto iterator = node.get_iterator().get_id();
500490

501491
// Loop Header
502492
{
503493
auto initVal = IslExprToSInt(node.get_init());
504494
halide_cg.get_builder().SetInsertPoint(headerBB);
505495
phi = halide_cg.get_builder().CreatePHI(
506-
llvm::Type::getInt64Ty(llvmCtx),
507-
2,
508-
node.get_iterator().get_id().get_name());
509-
halide_cg.sym_push(node.get_iterator().get_id().get_name(), phi);
496+
llvm::Type::getInt64Ty(llvmCtx), 2, iterator.get_name());
497+
halide_cg.sym_push(iterator.get_name(), phi);
510498
phi->addIncoming(getLLVMConstantSignedInt64(initVal), incoming);
511499

512500
auto cond_expr = node.get_cond();
@@ -518,7 +506,7 @@ class LLVMCodegen {
518506
CHECK(
519507
isl_ast_expr_get_type(condLHS.get()) ==
520508
isl_ast_expr_type::isl_ast_expr_id);
521-
CHECK_EQ(condLHS.get_id(), node.get_iterator().get_id());
509+
CHECK_EQ(condLHS.get_id(), iterator);
522510

523511
IslAstExprInterpeter i(scop_.globalParameterContext);
524512
auto condRHSVal = i.interpret(cond_expr.get_op_arg(1));
@@ -575,7 +563,7 @@ class LLVMCodegen {
575563
}
576564

577565
halide_cg.get_builder().SetInsertPoint(loopExitBB);
578-
halide_cg.sym_pop(node.get_iterator().get_id().get_name());
566+
halide_cg.sym_pop(iterator.get_name());
579567
#ifdef TAPIR_VERSION_MAJOR
580568
if (parallel) {
581569
auto* syncBB = llvm::BasicBlock::Create(llvmCtx, "synced", function);
@@ -652,9 +640,6 @@ IslCodegenRes codegenISL(const Scop& scop) {
652640
auto scheduleMap = isl::map::from_union_map(schedule);
653641

654642
auto stmtId = expr.get_op_arg(0).get_id();
655-
// auto nodeId = isl::id(
656-
// node.get_ctx(),
657-
// std::string(kAstNodeIdPrefix) + std::to_string(nAstNodes()++));
658643
CHECK_EQ(0, iteratorMaps.count(stmtId)) << "entry exists: " << stmtId;
659644
auto iteratorMap = isl::pw_multi_aff(scheduleMap.reverse());
660645
auto iterators = scop.halide.iterators.at(stmtId);

0 commit comments

Comments
 (0)