Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 1687f6f

Browse files
Merge pull request #286 from facebookresearch/pr/codegen_cuda_names
emitCudaKernel: avoid relying on isl set variable names
2 parents f4c4fc2 + 86825c5 commit 1687f6f

File tree

2 files changed

+57
-50
lines changed

2 files changed

+57
-50
lines changed

include/tc/core/polyhedral/cuda/codegen.h

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <string>
2020
#include <unordered_map>
2121

22+
#include "tc/core/halide2isl.h"
2223
#include "tc/core/polyhedral/cuda/mapped_scop.h"
2324
#include "tc/core/polyhedral/scop.h"
2425
#include "tc/external/isl.h"
@@ -60,38 +61,58 @@ void emitMappedTensorAccess(
6061

6162
} // namespace detail
6263

63-
using IteratorMapsType =
64-
std::unordered_map<isl::id, isl::pw_multi_aff, isl::IslIdIslHash>;
64+
/*
65+
* Information attached to an AST node during printing of the AST.
66+
* iteratorMap is the inverse schedule, mapping schedule dimensions
67+
* to the indices of the statement corresponding to the AST node.
68+
* build is the AST build at the point where the AST node is generated.
69+
* It is used to generate AST expressions in that context.
70+
*/
71+
struct NodeInfo {
72+
isl::pw_multi_aff iteratorMap;
73+
isl::ast_build build;
74+
};
75+
/*
76+
* Type used for mapping AST node identifier to the corresponding
77+
* AST node information.
78+
*/
79+
using NodeInfoMapType =
80+
std::unordered_map<isl::id, NodeInfo, isl::IslIdIslHash>;
6581

6682
struct CodegenContext {
6783
CodegenContext(
6884
std::stringstream& ss_,
6985
const MappedScop& s,
70-
const IteratorMapsType& i)
71-
: ss(ss_), mappedScop(s), iteratorMaps(i) {}
86+
const NodeInfoMapType& i)
87+
: ss(ss_), mappedScop(s), nodeInfoMap(i) {}
7288
CodegenContext(const CodegenContext& c)
73-
: ss(c.ss), mappedScop(c.mappedScop), iteratorMaps(c.iteratorMaps) {}
89+
: ss(c.ss), mappedScop(c.mappedScop), nodeInfoMap(c.nodeInfoMap) {}
7490

7591
const Scop& scop() const {
7692
return mappedScop.scop();
7793
}
7894

7995
std::stringstream& ss;
8096
const MappedScop& mappedScop;
81-
const IteratorMapsType& iteratorMaps;
97+
const NodeInfoMapType& nodeInfoMap;
8298
};
8399

84100
struct CodegenStatementContext : CodegenContext {
85101
CodegenStatementContext(const CodegenContext& c, isl::id astId)
86102
: CodegenContext(c), astNodeId(astId) {}
87103
isl::pw_multi_aff iteratorMap() const {
88-
return this->iteratorMaps.at(astNodeId);
104+
return this->nodeInfoMap.at(astNodeId).iteratorMap;
105+
}
106+
// Return the build where the AST node of this CodegenStatementContext
107+
// was constructed.
108+
isl::ast_build build() const {
109+
return this->nodeInfoMap.at(astNodeId).build;
89110
}
90111
isl::id statementId() const {
91-
return this->iteratorMaps.at(astNodeId).get_tuple_id(isl::dim_type::out);
112+
return this->iteratorMap().get_tuple_id(isl::dim_type::out);
92113
}
93114
isl::set domain() const {
94-
return isl::map::from(this->iteratorMaps.at(astNodeId)).range();
115+
return isl::map::from(this->iteratorMap()).range();
95116
}
96117
std::vector<Scop::PromotionInfo> activePromotions() const {
97118
std::vector<Scop::PromotionInfo> result;
@@ -103,6 +124,16 @@ struct CodegenStatementContext : CodegenContext {
103124
}
104125
return result;
105126
}
127+
// Make an affine function from a Halide Expr that is defined
128+
// over the instance set of the statement corresponding to
129+
// the AST node of this CodegenStatementContext. Return a
130+
// null isl::aff if the expression is not affine. Fail if any
131+
// of the variables does not correspond to a parameter or
132+
// an instance identifier of the statement.
133+
isl::aff makeIslAffFromExpr(const Halide::Expr& e) const {
134+
auto space = iteratorMap().get_space().params();
135+
return scop().makeIslAffFromStmtExpr(statementId(), space, e);
136+
}
106137

107138
isl::id astNodeId;
108139
};

src/core/polyhedral/cuda/codegen.cc

Lines changed: 17 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
#include <utility>
2222

2323
#include "tc/core/flags.h"
24-
#include "tc/core/halide2isl.h"
2524
#include "tc/core/islpp_wrap.h"
2625
#include "tc/core/libraries.h"
2726
#include "tc/core/polyhedral/codegen.h"
@@ -368,11 +367,7 @@ void emitReductionInit(
368367
namespace {
369368
template <typename AFF>
370369
void emitAccess(AFF access, const CodegenStatementContext& context) {
371-
// Use a temporary isl::ast_build to print the expression.
372-
// Ideally, this should use the build at the point
373-
// where the user statement was created.
374-
auto astBuild = isl::ast_build::from_context(access.domain());
375-
context.ss << astBuild.access_from(access).to_C_str();
370+
context.ss << context.build().access_from(access).to_C_str();
376371
}
377372
} // namespace
378373

@@ -401,6 +396,8 @@ void AstPrinter::emitStmt(isl::ast_node_user node) {
401396
auto stmtId = usrExp.get_op_arg(0).get_id();
402397
auto nodeId = node.get_annotation();
403398
auto statementContext = CodegenStatementContext(context_, nodeId);
399+
CHECK_EQ(context_.nodeInfoMap.count(nodeId), 1)
400+
<< "no info for node " << nodeId;
404401

405402
WS ws;
406403
context_.ss << ws.tab();
@@ -414,8 +411,6 @@ void AstPrinter::emitStmt(isl::ast_node_user node) {
414411
emitReductionInit(stmtId, updateId, context_);
415412
inReduction_ = true;
416413
} else if (inReduction_ && context_.scop().isReductionUpdate(stmtId)) {
417-
CHECK_EQ(context_.iteratorMaps.count(nodeId), 1)
418-
<< "no iterator remapping for op " << nodeId;
419414
emitReductionUpdate(stmtId, statementContext);
420415
reductionUpdateNodeId_ = nodeId;
421416
} else if (context_.scop().isSyncId(stmtId)) {
@@ -424,14 +419,11 @@ void AstPrinter::emitStmt(isl::ast_node_user node) {
424419
stmtId.get_name() == kReadIdName || stmtId.get_name() == kWriteIdName) {
425420
emitCopyStmt(statementContext);
426421
} else { // regular statement
427-
CHECK_EQ(context_.iteratorMaps.count(nodeId), 1)
428-
<< "no iterator remapping for op " << nodeId;
429-
auto mappedStmtId =
430-
context_.iteratorMaps.at(nodeId).get_tuple_id(isl::dim_type::out);
422+
auto mappedStmtId = statementContext.statementId();
431423
CHECK_EQ(stmtId, mappedStmtId)
432424
<< "statement ids in expr (" << stmtId << ") and in iteratorMaps ("
433425
<< mappedStmtId << ") do not match";
434-
emitUserStmt(stmtId, CodegenStatementContext(context_, nodeId));
426+
emitUserStmt(stmtId, statementContext);
435427
}
436428
}
437429

@@ -461,11 +453,10 @@ namespace detail {
461453
isl::pw_aff makeAffFromMappedExpr(
462454
const Halide::Expr& expr,
463455
const CodegenStatementContext& context) {
464-
auto space = context.iteratorMap().get_space().range();
465456
// We only expect this to be called on encountering a free
466457
// variable. Compound expressions should be emitted as Halide.
467458
CHECK(expr.as<Halide::Internal::Variable>());
468-
auto aff = halide2isl::makeIslAffFromExpr(space, expr);
459+
auto aff = context.makeIslAffFromExpr(expr);
469460
auto pwaff = isl::pw_aff(aff).pullback(context.iteratorMap());
470461
return pwaff;
471462
}
@@ -495,8 +486,7 @@ isl::multi_aff makeMultiAffAccess(
495486

496487
auto ma = isl::multi_aff::zero(space);
497488
for (size_t i = 0; i < subscripts.size(); ++i) {
498-
ma = ma.set_aff(
499-
i, halide2isl::makeIslAffFromExpr(domainSpace, subscripts[i]));
489+
ma = ma.set_aff(i, context.makeIslAffFromExpr(subscripts[i]));
500490
}
501491
return ma;
502492
}
@@ -520,11 +510,7 @@ void emitHalideExpr(
520510
void visit(const Halide::Internal::Variable* op) {
521511
auto pwAff = tc::polyhedral::detail::makeAffFromMappedExpr(
522512
Halide::Expr(op), context);
523-
// Use a temporary isl::ast_build to print the expression.
524-
// Ideally, this should use the build at the point
525-
// where the user statement was created.
526-
auto astBuild = isl::ast_build::from_context(pwAff.domain());
527-
auto expr = astBuild.expr_from(pwAff);
513+
auto expr = context.build().expr_from(pwAff);
528514
auto s = expr.to_C_str();
529515
if (!is_identifier_or_nonnegative_integer(expr)) {
530516
s = "(" + s + ")";
@@ -724,42 +710,32 @@ string emitCudaKernel(
724710
emitTensorViews(ss, scop.halide.inputs, paramValues);
725711
emitTmpDecl(ss, scop);
726712
emitPromotedArrayViewsHalide(ss, scop);
727-
IteratorMapsType iteratorMaps;
728-
auto collect = [&iteratorMaps](
713+
NodeInfoMapType nodeInfoMap;
714+
auto collect = [&nodeInfoMap](
729715
isl::ast_node n, isl::ast_build b) -> isl::ast_node {
730716
auto collectIteratorMaps =
731717
[](isl::ast_node node,
732718
isl::ast_build build,
733-
IteratorMapsType* iteratorMaps) -> isl::ast_node {
719+
NodeInfoMapType* nodeInfoMap) -> isl::ast_node {
734720
auto user = node.as<isl::ast_node_user>();
735721
CHECK(user);
736722
auto expr = user.get_expr();
737723
auto stmtId = expr.get_op_arg(0).get_id();
738-
// We rename loop-related dimensions manually.
739724
auto schedule = build.get_schedule();
740-
auto scheduleSpace = build.get_schedule_space();
741725
auto scheduleMap = isl::map::from_union_map(schedule);
742726

743727
auto nodeId = isl::id(
744728
node.get_ctx(),
745729
std::string(kAstNodeIdPrefix) + std::to_string(nAstNodes()++));
746-
CHECK_EQ(0, iteratorMaps->count(nodeId)) << "entry exists: " << nodeId;
747-
CHECK_EQ(
748-
scheduleMap.dim(isl::dim_type::out),
749-
scheduleSpace.dim(isl::dim_type::set));
750-
for (int i = 0; i < scheduleSpace.dim(isl::dim_type::set); ++i) {
751-
scheduleMap = scheduleMap.set_dim_id(
752-
isl::dim_type::out,
753-
i,
754-
scheduleSpace.get_dim_id(isl::dim_type::set, i));
755-
}
730+
CHECK_EQ(0, nodeInfoMap->count(nodeId)) << "entry exists: " << nodeId;
756731

757-
auto iteratorMap = isl::pw_multi_aff(scheduleMap.reverse());
758-
iteratorMaps->emplace(nodeId, iteratorMap);
732+
auto& nodeInfo = (*nodeInfoMap)[nodeId];
733+
nodeInfo.iteratorMap = isl::pw_multi_aff(scheduleMap.reverse());
734+
nodeInfo.build = build;
759735
return node.set_annotation(nodeId);
760736
};
761737

762-
return collectIteratorMaps(n, b, &iteratorMaps);
738+
return collectIteratorMaps(n, b, &nodeInfoMap);
763739
};
764740

765741
auto bands = detail::ScheduleTree::collect(
@@ -781,7 +757,7 @@ string emitCudaKernel(
781757
astBuild = astBuild.set_at_each_domain(collect);
782758
astBuild = astBuild.set_iterators(Codegen::makeLoopIterators(ctx, maxDepth));
783759
auto astNode = astBuild.node_from(schedule);
784-
AstPrinter(CodegenContext(ss, mscop, iteratorMaps)).emit(astNode);
760+
AstPrinter(CodegenContext(ss, mscop, nodeInfoMap)).emit(astNode);
785761
ss << "}" << endl;
786762

787763
return ss.str();

0 commit comments

Comments
 (0)