Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 5bc3e2c

Browse files
author
Sven Verdoolaege
committed
emitLLVMKernel: avoid relying on isl set variable names
codegenISL relied on set variable names in two ways. First, it assumed that the variable names of the domain of the schedule were preserved through the entire AST generation process. Explicitly set the names of the space sent to halide2isl::makeIslAffFromExpr instead of assuming they are still available in the schedule. halide2isl::makeIslAffFromExpr itself still relies on the names in the space itself being preserved, but the space does not get modified in between, so the risk of the names disappearing is reduced. Second, codegenISL would manually set the names of the range of the schedule, expecting them to be preserved across subsequent operations, but this is again not guaranteed. Convert the corresponding expressions to isl::ast_expr objects instead. As a nice bonus, the code generation for subscripts and variables now share the same mechanism, while before one used isl::ast_expr objects, while the other used an isl::pw_multi_aff. As a result, a lot of code can be removed.
1 parent 98bf0f5 commit 5bc3e2c

File tree

3 files changed

+39
-61
lines changed

3 files changed

+39
-61
lines changed

include/tc/core/polyhedral/scop.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,17 @@ struct Scop {
358358
// Assumes such argument exists.
359359
const Halide::OutputImageParam& findArgument(isl::id id) const;
360360

361+
// Make an affine function from a Halide Expr that is defined
362+
// over the instance set of the statement with identifier "stmtId" and
363+
// with parameters specified by "paramSpace". Return a
364+
// null isl::aff if the expression is not affine. Fail if any
365+
// of the variables does not correspond to a parameter or
366+
// an instance identifier of the statement.
367+
isl::aff makeIslAffFromStmtExpr(
368+
isl::id stmtId,
369+
isl::space paramSpace,
370+
const Halide::Expr& e) const;
371+
361372
// Promote a tensor reference group to a storage of a given "kind",
362373
// inserting the copy
363374
// statements below the given node. Inserts an Extension node below the give

src/core/polyhedral/codegen_llvm.cc

Lines changed: 10 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ isl::aff makeIslAffFromExpr(isl::space space, const Halide::Expr& e);
7272

7373
namespace polyhedral {
7474

75-
using IteratorMapType = isl::pw_multi_aff;
75+
using IteratorMapType = std::unordered_map<std::string, isl::ast_expr>;
7676
using IteratorMapsType =
7777
std::unordered_map<isl::id, IteratorMapType, isl::IslIdIslHash>;
7878

@@ -97,14 +97,6 @@ llvm::Value* getLLVMConstantSignedInt64(int64_t v) {
9797
return llvm::ConstantInt::get(llvm::Type::getInt64Ty(llvmCtx), v, true);
9898
}
9999

100-
isl::aff extractAff(isl::pw_multi_aff pma) {
101-
isl::PMA pma_(pma);
102-
CHECK_EQ(pma_.size(), 1);
103-
isl::MA ma(pma_[0].second);
104-
CHECK_EQ(ma.size(), 1);
105-
return ma[0];
106-
}
107-
108100
int64_t IslExprToSInt(isl::ast_expr e) {
109101
CHECK(isl_ast_expr_get_type(e.get()) == isl_ast_expr_type::isl_ast_expr_int);
110102
assert(sizeof(long) <= 8); // long is assumed to fit to 64bits
@@ -278,44 +270,7 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
278270
}
279271
}
280272
void visit(const Halide::Internal::Variable* op) override {
281-
auto aff = halide2isl::makeIslAffFromExpr(
282-
iteratorMap_->get_space().range(), Halide::Expr(op));
283-
284-
auto subscriptPma = isl::pw_aff(aff).pullback(*iteratorMap_);
285-
auto subscriptAff = extractAff(subscriptPma);
286-
287-
// sanity checks
288-
CHECK_EQ(subscriptAff.dim(isl::dim_type::div), 0);
289-
CHECK_EQ(subscriptAff.dim(isl::dim_type::out), 1);
290-
for (int d = 0; d < subscriptAff.dim(isl::dim_type::param); ++d) {
291-
auto v = subscriptAff.get_coefficient_val(isl::dim_type::param, d);
292-
CHECK(v.is_zero());
293-
}
294-
295-
llvm::Optional<int> posOne;
296-
int sum = 0;
297-
for (int d = 0; d < subscriptAff.dim(isl::dim_type::in); ++d) {
298-
auto v = subscriptAff.get_coefficient_val(isl::dim_type::in, d);
299-
CHECK(v.is_zero() or v.is_one());
300-
if (v.is_zero()) {
301-
continue;
302-
}
303-
++sum;
304-
posOne = d;
305-
}
306-
CHECK_LE(sum, 1);
307-
308-
if (sum == 0) {
309-
value =
310-
getLLVMConstantSignedInt64(toSInt(subscriptAff.get_constant_val()));
311-
return;
312-
}
313-
CHECK(posOne);
314-
315-
std::string name(
316-
isl_aff_get_dim_name(subscriptAff.get(), isl_dim_in, *posOne));
317-
318-
value = sym_get(name);
273+
value = getValue(iteratorMap_->at(op->name));
319274
}
320275

321276
public:
@@ -709,34 +664,28 @@ IslCodegenRes codegenISL(const Scop& scop) {
709664
const Scop& scop,
710665
StmtSubscriptExprMapType& stmtSubscripts) -> isl::ast_node {
711666
auto expr = node.user_get_expr();
712-
// We rename loop-related dimensions manually.
713667
auto schedule = build.get_schedule();
714-
auto scheduleSpace = build.get_schedule_space();
715668
auto scheduleMap = isl::map::from_union_map(schedule);
716669

717670
auto stmtId = expr.get_op_arg(0).get_id();
718671
// auto nodeId = isl::id(
719672
// node.get_ctx(),
720673
// std::string(kAstNodeIdPrefix) + std::to_string(nAstNodes()++));
721674
CHECK_EQ(0, iteratorMaps.count(stmtId)) << "entry exists: " << stmtId;
722-
CHECK_EQ(
723-
scheduleMap.dim(isl::dim_type::out),
724-
scheduleSpace.dim(isl::dim_type::set));
725-
for (int i = 0; i < scheduleSpace.dim(isl::dim_type::set); ++i) {
726-
scheduleMap = scheduleMap.set_dim_id(
727-
isl::dim_type::out,
728-
i,
729-
scheduleSpace.get_dim_id(isl::dim_type::set, i));
730-
}
731675
auto iteratorMap = isl::pw_multi_aff(scheduleMap.reverse());
732-
iteratorMaps.emplace(stmtId, iteratorMap);
676+
auto iterators = scop.halide.iterators.at(stmtId);
677+
auto& stmtIteratorMap = iteratorMaps[stmtId];
678+
for (int i = 0; i < iterators.size(); ++i) {
679+
auto expr = build.expr_from(iteratorMap.get_pw_aff(i));
680+
stmtIteratorMap.emplace(iterators[i], expr);
681+
}
733682
auto& subscripts = stmtSubscripts[stmtId];
734683
auto provide =
735684
scop.halide.statements.at(stmtId).as<Halide::Internal::Provide>();
736685
for (auto e : provide->args) {
737686
const auto& map = iteratorMap;
738-
auto space = map.get_space().range();
739-
auto aff = halide2isl::makeIslAffFromExpr(space, e);
687+
auto space = map.get_space().params();
688+
auto aff = scop.makeIslAffFromStmtExpr(stmtId, space, e);
740689
auto pulled = isl::pw_aff(aff).pullback(map);
741690
CHECK_EQ(pulled.n_piece(), 1);
742691
subscripts.push_back(build.expr_from(pulled));

src/core/polyhedral/scop.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,5 +558,23 @@ const Halide::OutputImageParam& Scop::findArgument(isl::id id) const {
558558
return *halide.inputs.begin();
559559
}
560560

561+
isl::aff Scop::makeIslAffFromStmtExpr(
562+
isl::id stmtId,
563+
isl::space paramSpace,
564+
const Halide::Expr& e) const {
565+
auto ctx = stmtId.get_ctx();
566+
auto iterators = halide.iterators.at(stmtId);
567+
auto space = paramSpace.set_from_params();
568+
space = space.add_dims(isl::dim_type::set, iterators.size());
569+
// Set the names of the set dimensions of "space" for use
570+
// by halide2isl::makeIslAffFromExpr.
571+
for (int i = 0; i < iterators.size(); ++i) {
572+
isl::id id(ctx, iterators[i]);
573+
space = space.set_dim_id(isl::dim_type::set, i, id);
574+
}
575+
space = space.set_tuple_id(isl::dim_type::set, stmtId);
576+
return halide2isl::makeIslAffFromExpr(space, e);
577+
}
578+
561579
} // namespace polyhedral
562580
} // namespace tc

0 commit comments

Comments
 (0)