Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 1371777

Browse files
authored
Merge pull request #302 from facebookresearch/pr/reuse_within
memory promotion: check for any reuse within outer schedule
2 parents 43aa4ee + 5ff0f63 commit 1371777

File tree

3 files changed

+33
-19
lines changed

3 files changed

+33
-19
lines changed

tc/core/polyhedral/cuda/memory_promotion_heuristic.cc

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -210,25 +210,19 @@ isl::map fixOuterInputDimsAsParameters(isl::map map, int nDims) {
210210
}
211211

212212
/*
213-
* Check if a reference group features reuse at "depth" after applying
214-
* "schedule". In particular, consider first depth schedule dimensions as fixed
215-
* by equating them to parameters and check if the resulting relation is not
216-
* injective.
213+
* Check if a reference group features reuse within the "outer" schedule.
214+
* In particular, check that for some given point in the outer schedule and
215+
* some given group element, there is more than one statement instance
216+
* accessing the element within the point in the outer schedule.
217+
* In other words, check that the mapping from statement instances
218+
* to pairs of outer schedule points and group elements is not injective.
217219
*/
218-
bool hasReuse(
220+
bool hasReuseWithin(
219221
const TensorReferenceGroup& group,
220-
isl::union_map schedule,
221-
size_t depth) {
222-
auto scheduledAccessesUMap = group.originalAccesses().apply_domain(schedule);
223-
auto scheduledAccessMaps =
224-
isl::UnionAsVector<isl::union_map>(scheduledAccessesUMap);
225-
return std::any_of(
226-
scheduledAccessMaps.begin(),
227-
scheduledAccessMaps.end(),
228-
[schedule, depth](isl::map access) {
229-
access = fixOuterInputDimsAsParameters(access, static_cast<int>(depth));
230-
return !access.is_injective();
231-
});
222+
isl::multi_union_pw_aff outer) {
223+
auto map = isl::union_map::from(outer);
224+
map = map.range_product(group.originalAccesses());
225+
return !map.is_injective();
232226
}
233227

234228
/*
@@ -463,6 +457,8 @@ void promoteToSharedGreedy(
463457
for (auto bandNode : bands) {
464458
auto groupMap = TensorReferenceGroup::accessedBySubtree(bandNode, scop);
465459
auto partialSched = partialSchedule(root, bandNode);
460+
// Pure affine schedule without (mapping) filters.
461+
auto partialSchedMupa = partialScheduleMupa(root, bandNode);
466462
auto activePoints = activeDomainPoints(root, bandNode);
467463

468464
// Prepare groups for sorting, to have specified order necessary for
@@ -522,7 +518,7 @@ void promoteToSharedGreedy(
522518
}
523519
// Do not promote if the group features no reuse and is accessed in a
524520
// coalesced way.
525-
if (!hasReuse(*group, fullSched, depth) &&
521+
if (!hasReuseWithin(*group, partialSchedMupa) &&
526522
isCoalesced(
527523
threadIdxXScheduleDepthState,
528524
*group,
@@ -606,6 +602,8 @@ void promoteToRegistersBelowThreads(
606602
// per-thread-group access relations.
607603
auto points = activeDomainPoints(root, band);
608604
auto partialSched = partialSchedule(root, band);
605+
// Pure affine schedule without (mapping) filters.
606+
auto partialSchedMupa = partialScheduleMupa(root, band);
609607

610608
size_t nMappedThreads = 0;
611609
for (int j = 0; j < points.dim(isl::dim_type::param); ++j) {
@@ -643,7 +641,7 @@ void promoteToRegistersBelowThreads(
643641
points)) {
644642
continue;
645643
}
646-
if (!hasReuse(*group, fullSched, depth)) {
644+
if (!hasReuseWithin(*group, partialSchedMupa)) {
647645
continue;
648646
}
649647
// TODO: if something is already in shared, but reuse it within one

tc/core/polyhedral/schedule_transforms.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,14 @@ isl::multi_union_pw_aff prefixScheduleMupa(
454454
return prefix;
455455
}
456456

457+
isl::multi_union_pw_aff partialScheduleMupa(
458+
const detail::ScheduleTree* root,
459+
const detail::ScheduleTree* tree) {
460+
auto band = tree->elemAs<ScheduleTreeElemBand>();
461+
CHECK(band);
462+
return prefixScheduleMupa(root, tree).flat_range_product(band->mupa_);
463+
}
464+
457465
ScheduleTree* insertBandAbove(
458466
ScheduleTree* root,
459467
ScheduleTree* tree,

tc/core/polyhedral/schedule_transforms.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,14 @@ isl::multi_union_pw_aff prefixScheduleMupa(
276276
const detail::ScheduleTree* root,
277277
const detail::ScheduleTree* tree);
278278

279+
// Return the concatenation of all outer band node partial schedules,
280+
// including that of the node itself.
281+
// Note that this function does not take into account
282+
// any intermediate filter nodes.
283+
isl::multi_union_pw_aff partialScheduleMupa(
284+
const detail::ScheduleTree* root,
285+
const detail::ScheduleTree* tree);
286+
279287
// Get the set of domain points active at the given node. A domain
280288
// point is active if it was not filtered away on the path from the
281289
// root to the node. The root must be a domain element, otherwise no

0 commit comments

Comments
 (0)