Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 86825c5

Browse files
author
Sven Verdoolaege
committed
emitCudaKernel: avoid relying on isl set variable names in range of schedule
In particular, keep track of the AST build where each AST node was created and use that to generate AST expressions. Note that an AST build is a relatively large data structure, so this may end up consuming some memory if many statements or AST nodes are involved. An alternative would be to generate the AST expressions while the AST is being constructed, as is done in PPCG and in the LLVM code generator, but this requires two passes over the Halide data structures, once to generate the AST expressions and once to print them and would be a more involved change overall. A third alternative would be to generate AST expressions for the individual statement indices during the AST generation (using iteratorMap) and then to generate AST expressions in terms of those statement indices during the printing of the AST, plugging in the AST expressions for the statement indices.
1 parent 4b252e5 commit 86825c5

File tree

2 files changed

+40
-38
lines changed

2 files changed

+40
-38
lines changed

include/tc/core/polyhedral/cuda/codegen.h

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,32 +61,52 @@ void emitMappedTensorAccess(
6161

6262
} // namespace detail
6363

64-
using IteratorMapsType =
65-
std::unordered_map<isl::id, isl::pw_multi_aff, isl::IslIdIslHash>;
64+
/*
65+
* Information attached to an AST node during printing of the AST.
66+
* iteratorMap is the inverse schedule, mapping schedule dimensions
67+
* to the indices of the statement corresponding to the AST node.
68+
* build is the AST build at the point where the AST node is generated.
69+
* It is used to generate AST expressions in that context.
70+
*/
71+
struct NodeInfo {
72+
isl::pw_multi_aff iteratorMap;
73+
isl::ast_build build;
74+
};
75+
/*
76+
* Type used for mapping AST node identifier to the corresponding
77+
* AST node information.
78+
*/
79+
using NodeInfoMapType =
80+
std::unordered_map<isl::id, NodeInfo, isl::IslIdIslHash>;
6681

6782
struct CodegenContext {
6883
CodegenContext(
6984
std::stringstream& ss_,
7085
const MappedScop& s,
71-
const IteratorMapsType& i)
72-
: ss(ss_), mappedScop(s), iteratorMaps(i) {}
86+
const NodeInfoMapType& i)
87+
: ss(ss_), mappedScop(s), nodeInfoMap(i) {}
7388
CodegenContext(const CodegenContext& c)
74-
: ss(c.ss), mappedScop(c.mappedScop), iteratorMaps(c.iteratorMaps) {}
89+
: ss(c.ss), mappedScop(c.mappedScop), nodeInfoMap(c.nodeInfoMap) {}
7590

7691
const Scop& scop() const {
7792
return mappedScop.scop();
7893
}
7994

8095
std::stringstream& ss;
8196
const MappedScop& mappedScop;
82-
const IteratorMapsType& iteratorMaps;
97+
const NodeInfoMapType& nodeInfoMap;
8398
};
8499

85100
struct CodegenStatementContext : CodegenContext {
86101
CodegenStatementContext(const CodegenContext& c, isl::id astId)
87102
: CodegenContext(c), astNodeId(astId) {}
88103
isl::pw_multi_aff iteratorMap() const {
89-
return this->iteratorMaps.at(astNodeId);
104+
return this->nodeInfoMap.at(astNodeId).iteratorMap;
105+
}
106+
// Return the build where the AST node of this CodegenStatementContext
107+
// was constructed.
108+
isl::ast_build build() const {
109+
return this->nodeInfoMap.at(astNodeId).build;
90110
}
91111
isl::id statementId() const {
92112
return this->iteratorMap().get_tuple_id(isl::dim_type::out);

src/core/polyhedral/cuda/codegen.cc

Lines changed: 13 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -367,11 +367,7 @@ void emitReductionInit(
367367
namespace {
368368
template <typename AFF>
369369
void emitAccess(AFF access, const CodegenStatementContext& context) {
370-
// Use a temporary isl::ast_build to print the expression.
371-
// Ideally, this should use the build at the point
372-
// where the user statement was created.
373-
auto astBuild = isl::ast_build::from_context(access.domain());
374-
context.ss << astBuild.access_from(access).to_C_str();
370+
context.ss << context.build().access_from(access).to_C_str();
375371
}
376372
} // namespace
377373

@@ -400,8 +396,8 @@ void AstPrinter::emitStmt(isl::ast_node_user node) {
400396
auto stmtId = usrExp.get_op_arg(0).get_id();
401397
auto nodeId = node.get_annotation();
402398
auto statementContext = CodegenStatementContext(context_, nodeId);
403-
CHECK_EQ(context_.iteratorMaps.count(nodeId), 1)
404-
<< "no iterator remapping for op " << nodeId;
399+
CHECK_EQ(context_.nodeInfoMap.count(nodeId), 1)
400+
<< "no info for node " << nodeId;
405401

406402
WS ws;
407403
context_.ss << ws.tab();
@@ -514,11 +510,7 @@ void emitHalideExpr(
514510
void visit(const Halide::Internal::Variable* op) {
515511
auto pwAff = tc::polyhedral::detail::makeAffFromMappedExpr(
516512
Halide::Expr(op), context);
517-
// Use a temporary isl::ast_build to print the expression.
518-
// Ideally, this should use the build at the point
519-
// where the user statement was created.
520-
auto astBuild = isl::ast_build::from_context(pwAff.domain());
521-
auto expr = astBuild.expr_from(pwAff);
513+
auto expr = context.build().expr_from(pwAff);
522514
auto s = expr.to_C_str();
523515
if (!is_identifier_or_nonnegative_integer(expr)) {
524516
s = "(" + s + ")";
@@ -718,42 +710,32 @@ string emitCudaKernel(
718710
emitTensorViews(ss, scop.halide.inputs, paramValues);
719711
emitTmpDecl(ss, scop);
720712
emitPromotedArrayViewsHalide(ss, scop);
721-
IteratorMapsType iteratorMaps;
722-
auto collect = [&iteratorMaps](
713+
NodeInfoMapType nodeInfoMap;
714+
auto collect = [&nodeInfoMap](
723715
isl::ast_node n, isl::ast_build b) -> isl::ast_node {
724716
auto collectIteratorMaps =
725717
[](isl::ast_node node,
726718
isl::ast_build build,
727-
IteratorMapsType* iteratorMaps) -> isl::ast_node {
719+
NodeInfoMapType* nodeInfoMap) -> isl::ast_node {
728720
auto user = node.as<isl::ast_node_user>();
729721
CHECK(user);
730722
auto expr = user.get_expr();
731723
auto stmtId = expr.get_op_arg(0).get_id();
732-
// We rename loop-related dimensions manually.
733724
auto schedule = build.get_schedule();
734-
auto scheduleSpace = build.get_schedule_space();
735725
auto scheduleMap = isl::map::from_union_map(schedule);
736726

737727
auto nodeId = isl::id(
738728
node.get_ctx(),
739729
std::string(kAstNodeIdPrefix) + std::to_string(nAstNodes()++));
740-
CHECK_EQ(0, iteratorMaps->count(nodeId)) << "entry exists: " << nodeId;
741-
CHECK_EQ(
742-
scheduleMap.dim(isl::dim_type::out),
743-
scheduleSpace.dim(isl::dim_type::set));
744-
for (int i = 0; i < scheduleSpace.dim(isl::dim_type::set); ++i) {
745-
scheduleMap = scheduleMap.set_dim_id(
746-
isl::dim_type::out,
747-
i,
748-
scheduleSpace.get_dim_id(isl::dim_type::set, i));
749-
}
730+
CHECK_EQ(0, nodeInfoMap->count(nodeId)) << "entry exists: " << nodeId;
750731

751-
auto iteratorMap = isl::pw_multi_aff(scheduleMap.reverse());
752-
iteratorMaps->emplace(nodeId, iteratorMap);
732+
auto& nodeInfo = (*nodeInfoMap)[nodeId];
733+
nodeInfo.iteratorMap = isl::pw_multi_aff(scheduleMap.reverse());
734+
nodeInfo.build = build;
753735
return node.set_annotation(nodeId);
754736
};
755737

756-
return collectIteratorMaps(n, b, &iteratorMaps);
738+
return collectIteratorMaps(n, b, &nodeInfoMap);
757739
};
758740

759741
auto bands = detail::ScheduleTree::collect(
@@ -775,7 +757,7 @@ string emitCudaKernel(
775757
astBuild = astBuild.set_at_each_domain(collect);
776758
astBuild = astBuild.set_iterators(Codegen::makeLoopIterators(ctx, maxDepth));
777759
auto astNode = astBuild.node_from(schedule);
778-
AstPrinter(CodegenContext(ss, mscop, iteratorMaps)).emit(astNode);
760+
AstPrinter(CodegenContext(ss, mscop, nodeInfoMap)).emit(astNode);
779761
ss << "}" << endl;
780762

781763
return ss.str();

0 commit comments

Comments
 (0)