Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit dfe48ca

Browse files
committed
use domain sets instead of statement ids in memory promotion
The original implementation of memory promotion (and the related parts of code generation) assumed that the schedule tree branches only at the granularity of entire statements. It therefore used statement ids as a lightweight key for finding the relevant active promotions and schedule parts. With the introduction of full/partial tile separation, this assumption no longer holds. Make memory promotion use domain sets instead of statement ids, and find the relevant promotions and schedule parts by intersecting these sets with the sets of statement instances active at a particular location in the schedule tree.
1 parent bd2d0cd commit dfe48ca

File tree

5 files changed

+26
-40
lines changed

5 files changed

+26
-40
lines changed

include/tc/core/polyhedral/codegen_cuda.h

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,18 @@ struct CodegenStatementContext : CodegenContext {
9696
isl::id statementId() const {
9797
return this->iteratorMaps.at(astNodeId).get_tuple_id(isl::dim_type::out);
9898
}
99+
isl::set domain() const {
100+
return isl::map::from(this->iteratorMaps.at(astNodeId)).range();
101+
}
99102
std::vector<Scop::PromotionInfo> activePromotions() const {
100-
auto stmtId = statementId();
101-
const auto& promotions = this->scop().activePromotions();
102-
if (promotions.count(stmtId) == 0) {
103-
return {};
103+
std::vector<Scop::PromotionInfo> result;
104+
auto dom = isl::union_set(this->domain());
105+
for (const auto& kvp : this->scop().activePromotions()) {
106+
if (!kvp.first.intersect(dom).is_empty()) {
107+
result.emplace_back(kvp.second);
108+
}
104109
}
105-
return promotions.at(stmtId);
110+
return result;
106111
}
107112

108113
isl::id astNodeId;

include/tc/core/polyhedral/scop.h

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -324,9 +324,8 @@ struct Scop {
324324
return promotedDecls_;
325325
}
326326

327-
const std::
328-
unordered_map<isl::id, std::vector<PromotionInfo>, isl::IslIdIslHash>&
329-
activePromotions() const {
327+
const std::vector<std::pair<isl::union_set, PromotionInfo>>&
328+
activePromotions() const {
330329
return activePromotions_;
331330
}
332331

@@ -377,7 +376,6 @@ struct Scop {
377376
isl::id tensorId,
378377
std::unique_ptr<TensorReferenceGroup>&& gr,
379378
detail::ScheduleTree* tree,
380-
const std::unordered_set<isl::id, isl::IslIdIslHash>& activeStmts,
381379
isl::union_map schedule,
382380
bool forceLastExtentOdd = false);
383381

@@ -468,9 +466,10 @@ struct Scop {
468466
std::unordered_map<isl::id, size_t, isl::IslIdIslHash> groupCounts_;
469467
// groupId -> (tensorId, groupSizes)
470468
std::unordered_map<isl::id, PromotedDecl, isl::IslIdIslHash> promotedDecls_;
471-
// stmtId -> (group, partial schedule, groupId)
472-
std::unordered_map<isl::id, std::vector<PromotionInfo>, isl::IslIdIslHash>
473-
activePromotions_;
469+
// (domain, group, partial schedule, groupId)
470+
// Note that domain is a non-unique key, i.e. multiple groups can be listed
471+
// for the same domain, or for partially intersecting domains.
472+
std::vector<std::pair<isl::union_set, PromotionInfo>> activePromotions_;
474473
};
475474

476475
std::ostream& operator<<(std::ostream& os, const Scop&);

src/core/polyhedral/codegen_cuda.cc

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -503,17 +503,6 @@ isl::space findDomainSpaceById(const CodegenStatementContext& context) {
503503
return isl::space();
504504
}
505505

506-
isl::map findScheduleByStmtId(isl::union_map schedule, isl::id stmtId) {
507-
for (auto s : isl::UnionAsVector<isl::union_map>(schedule)) {
508-
if (s.get_tuple_id(isl::dim_type::in) == stmtId) {
509-
return s;
510-
}
511-
}
512-
CHECK(false) << "could not find schedule for " << stmtId << " in "
513-
<< schedule;
514-
return isl::map();
515-
}
516-
517506
isl::multi_aff makeMultiAffAccess(
518507
isl::id tensorId,
519508
const std::vector<Halide::Expr>& subscripts,
@@ -633,9 +622,9 @@ void emitMappedTensorAccess(
633622
auto promotion = promotionInfo.group->promotion(); // MA :: [S -> O] -> P
634623
promotion = promotion.set_tuple_id(isl::dim_type::out, promotionInfo.groupId);
635624
auto iteratorMap = context.iteratorMap(); // PMA :: A -> D
636-
auto schedule = findScheduleByStmtId(
637-
promotionInfo.outerSchedule,
638-
context.statementId()); // map :: D -> S
625+
auto schedule =
626+
isl::map::from_union_map(promotionInfo.outerSchedule.intersect_domain(
627+
context.domain())); // map :: D -> S
639628

640629
CHECK(schedule.is_single_valued())
641630
<< "expected single-valued schedule, got " << schedule;

src/core/polyhedral/memory_promotion_heuristic.cc

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -300,9 +300,7 @@ bool isCoalesced(
300300
auto partialSchedule = isl::map::from_union_map(partialScheduleUMap);
301301
auto scheduleToNextX = makeNextElementMap(
302302
partialSchedule.get_space().range(), threadIdxxDepth);
303-
auto scheduledAccess = isl::map(access)
304-
.gist_domain(access.domain())
305-
.apply_domain(partialSchedule);
303+
auto scheduledAccess = isl::map(access).apply_domain(partialSchedule);
306304
auto accessedByAdjacentX = scheduleToNextX.apply_domain(scheduledAccess)
307305
.apply_range(scheduledAccess);
308306

@@ -349,9 +347,7 @@ bool isPromotableToRegisterBelowThreads(
349347
threadIdxxScheduleDepthState,
350348
originalAccesses.domain().intersect(activePoints));
351349

352-
auto scheduledAccesses =
353-
originalAccesses.gist_domain(originalAccesses.domain())
354-
.apply_domain(schedule);
350+
auto scheduledAccesses = originalAccesses.apply_domain(schedule);
355351

356352
// Scheduled accesses contain maps from schedule dimensions to tensor
357353
// subscripts. Compute the relation between the schedule dimensions
@@ -460,7 +456,6 @@ void promoteToSharedGreedy(
460456
size_t remainingMemory = maxMemory;
461457
for (auto bandNode : bands) {
462458
auto groupMap = TensorReferenceGroup::accessedBySubtree(bandNode, scop);
463-
auto activeStmts = activeStatements(root, bandNode);
464459
auto partialSched = partialSchedule(root, bandNode);
465460
auto activePoints = activeDomainPoints(root, bandNode);
466461

@@ -535,7 +530,6 @@ void promoteToSharedGreedy(
535530
tensorId,
536531
std::move(group),
537532
bandNode,
538-
activeStmts,
539533
partialSched,
540534
true);
541535
remainingMemory -= memoryRequirement;

src/core/polyhedral/scop.cc

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,10 @@ void Scop::promoteGroup(
184184
isl::id tensorId,
185185
std::unique_ptr<TensorReferenceGroup>&& gr,
186186
ScheduleTree* tree,
187-
const std::unordered_set<isl::id, isl::IslIdIslHash>& activeStmts,
188187
isl::union_map schedule,
189188
bool forceLastExtentOdd) {
189+
auto activePoints = activeDomainPoints(scheduleRoot(), tree);
190+
190191
for (const auto& id : activeStmts) {
191192
for (const auto& prom : activePromotions_[id]) {
192193
if (promotedDecls_.count(prom.groupId) != 0 &&
@@ -211,10 +212,10 @@ void Scop::promoteGroup(
211212
}
212213
promotedDecls_[groupId] = PromotedDecl{tensorId, sizes, kind};
213214

215+
// FIXME: we can now store a unique pointer...
214216
auto group = std::shared_ptr<TensorReferenceGroup>(std::move(gr));
215-
for (const auto& id : activeStmts) {
216-
activePromotions_[id].push_back(PromotionInfo{group, schedule, groupId});
217-
}
217+
activePromotions_.emplace_back(
218+
std::make_pair(activePoints, PromotionInfo{group, schedule, groupId}));
218219
}
219220

220221
void Scop::insertSyncsAroundCopies(ScheduleTree* tree) {
@@ -259,7 +260,6 @@ void Scop::promoteEverythingAt(std::vector<size_t> pos) {
259260
auto tree = scheduleRoot()->child(pos);
260261

261262
checkFiltersDisjointStatements(scheduleRoot());
262-
auto activeStmts = activeStatements(root, tree);
263263
auto schedule = partialSchedule(root, tree);
264264

265265
auto groupMap = TensorReferenceGroup::accessedBySubtree(tree, *this);
@@ -270,7 +270,6 @@ void Scop::promoteEverythingAt(std::vector<size_t> pos) {
270270
p.first,
271271
std::move(gr),
272272
tree,
273-
activeStmts,
274273
schedule);
275274
}
276275
}

0 commit comments

Comments
 (0)