
Commit 0b3379d

basic heuristic for register promotion
Assuming that the mapping to threads starts from the innermost coincident schedule dimension and from thread x, promote to registers in each subtree below the band member mapped to thread x. Split bands if necessary to ensure that this member is the last one in the band. For each such band, collect the references to tensors accessed below it. Group together references whose footprints overlap and of which at least one is a write, to ensure the most recent value is read. For each group, consider promotion to registers only if the footprint contains a single element (hence promotable to a register) and if each element is accessed by at most one thread (registers are private to threads). Do not promote to registers references that were already promoted to shared memory, as this would require either copying from shared memory to registers or demoting from shared memory first. Do not insert synchronizations around these copies since no two threads access the same value. The compiler could load from memory to a register anyway for most arithmetic operations.
1 parent 778f274 commit 0b3379d
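
As a rough illustration of the rule described in the commit message, here is a minimal, self-contained sketch using toy types only; the names ToyGroup and worthPromotingToRegister are illustrative and are not part of the code added by this commit, which operates on isl relations instead of booleans.

#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Toy model of the promotion rule: a reference group goes to registers only if
// its footprint is a single element, no element is shared between threads, and
// it was not already promoted to shared memory.
struct ToyGroup {
  std::vector<size_t> footprintSizes;  // per-dimension footprint of the group
  bool elementSharedBetweenThreads;    // some element touched by >1 thread?
  bool alreadyInSharedMemory;          // group already promoted to shared?
};

bool worthPromotingToRegister(const ToyGroup& g) {
  auto nElements = std::accumulate(
      g.footprintSizes.begin(), g.footprintSizes.end(),
      size_t(1), std::multiplies<size_t>());
  return nElements == 1 && !g.elementSharedBetweenThreads &&
      !g.alreadyInSharedMemory;
}

int main() {
  // A per-thread accumulator: one element, private, not in shared memory.
  std::cout << worthPromotingToRegister({{1, 1}, false, false}) << "\n";  // 1
  // A tile already placed in shared memory: demoting it first is not worth it.
  std::cout << worthPromotingToRegister({{1}, false, true}) << "\n";      // 0
}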

File tree: 2 files changed (+174 −2 lines)


include/tc/core/polyhedral/memory_promotion_heuristic.h

Lines changed: 6 additions & 0 deletions
@@ -26,6 +26,7 @@ using ThreadIdxxScheduleDepthState =
     std::vector<std::pair<isl::union_set, size_t>>;
 
 class MappedScop;
+class Scop;
 
 // In the given mapped scop "mscop",
 // promote to shared memory at "depth" until "sharedMemorySize" is used.
@@ -40,5 +41,10 @@ void promoteGreedilyAtDepth(
     std::size_t depth,
     std::size_t sharedMemorySize,
     bool unrollCopies);
+
+void promoteToRegistersBelowThreads(
+    Scop& scop,
+    const ThreadIdxxScheduleDepthState& threadIdxxScheduleDepthState,
+    std::size_t nRegisters);
 } // namespace polyhedral
 } // namespace tc

src/core/polyhedral/memory_promotion_heuristic.cc

Lines changed: 168 additions & 2 deletions
@@ -159,8 +159,7 @@ isl::union_map fullSchedule(const detail::ScheduleTree* root) {
     }
   }
 
-  prefixMupa = isl::manage(isl_multi_union_pw_aff_intersect_domain(
-      prefixMupa.release(), domain.copy()));
+  prefixMupa = prefixMupa.intersect_domain(domain);
 
   schedule = schedule.unite(isl::union_map::from(prefixMupa));
   if (!schedule.is_single_valued()) {
@@ -315,6 +314,67 @@ bool isCoalesced(
   return true;
 }
 
+/*
+ * Check if the given "group" can be promoted to registers for the given active
+ * domain points under full "schedule" where "nThreads" consecutive dimensions
+ * are mapped to threads (the innermost of them being mapped to thread x) and
+ * the depth of this mapping can be obtained from threadIdxxScheduleDepthState.
+ *
+ * In particular, the group's footprint must contain only one element and the
+ * same tensor element should never be accessed by two different threads.
+ */
+bool isPromotableToRegisterBelowThreads(
+    const ThreadIdxxScheduleDepthState& threadIdxxScheduleDepthState,
+    const TensorReferenceGroup& group,
+    isl::union_map schedule,
+    size_t nThreads,
+    isl::union_set activePoints) {
+  auto originalAccesses = group.originalAccesses();
+
+  // Return early if more than one element needs to be stored in registers.
+  // TODO: support arrays in registers if they are only accessed with constant
+  // subscripts, e.g. if the inner loops are fully unrolled.
+  auto sizes = group.approximationSizes();
+  auto nElements =
+      std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies<size_t>());
+  if (nElements != 1) {
+    return false;
+  }
+
+  // Since this function is only supposed to be called on groups seen _below_
+  // the thread mapping, all refs in the group must have the same thread-x
+  // depth.
+  auto depth = 1 +
+      computeThreadIdxxScheduleDepth(
+          threadIdxxScheduleDepthState,
+          originalAccesses.domain().intersect(activePoints));
+
+  auto scheduledAccesses =
+      originalAccesses.gist_domain(originalAccesses.domain())
+          .apply_domain(schedule);
+
+  // Scheduled accesses contain maps from schedule dimensions to tensor
+  // subscripts.  Compute the relation between the schedule dimensions mapped
+  // to threads and the tensor subscripts by first removing the dimensions
+  // following the one mapped to thread x (the last one, assuming inverse
+  // mapping order), then by equating all dimensions not mapped to threads to
+  // parameters.  Promotion to registers is only allowed if the resulting
+  // relation is injective, i.e. the same tensor element is never accessed by
+  // more than one thread.  Note that our current check is overly conservative
+  // because different values of a schedule dimension may get mapped to the
+  // same thread, in which case they could access the same tensor element.
+  for (auto sa : isl::UnionAsVector<isl::union_map>(scheduledAccesses)) {
+    sa = sa.project_out(
+        isl::dim_type::in, depth, sa.dim(isl::dim_type::in) - depth);
+    sa = fixOuterInputDimsAsParameters(sa, depth - nThreads);
+    if (!sa.is_injective()) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
 /*
  * Starting from the root, find bands where depth is reached. Using
  * DFSPreorder to make sure order is specified and consistent for tests.
@@ -503,5 +563,111 @@ void promoteGreedilyAtDepth(
   mapCopiesToThreads(mscop, unrollCopies);
 }
 
+namespace {
+isl::val getParamValIfFixed(isl::union_set uset, int pos) {
+  auto val = isl::val::nan(uset.get_ctx());
+  for (auto set : isl::UnionAsVector<isl::union_set>(uset)) {
+    auto currentVal = set.plain_get_val_if_fixed(isl::dim_type::param, pos);
+    if (currentVal.is_nan()) {
+      return currentVal;
+    }
+    if (!val.is_nan() && val != currentVal) {
+      return isl::val::nan(uset.get_ctx());
+    }
+    val = currentVal;
+  }
+  return val;
+}
+} // namespace
+
+// Assuming the mapping to threads happens in inverse order, i.e. the innermost
+// loop is mapped to thread x, promote below that depth.
+void promoteToRegistersBelowThreads(
+    Scop& scop,
+    const ThreadIdxxScheduleDepthState& threadIdxxScheduleDepthState,
+    size_t nRegisters) {
+  using namespace tc::polyhedral::detail;
+
+  auto root = scop.scheduleRoot();
+
+  auto fullSched = fullSchedule(root);
+  for (const auto& kvp : threadIdxxScheduleDepthState) {
+    auto depth = kvp.second + 1;
+    auto subdomain = kvp.first;
+
+    // Collect all bands where a member is located at the given depth.
+    auto bands = bandsContainingScheduleDepth(root, depth);
+    // We may have no band members mapped to thread x in the case when we
+    // force-mapped everything to one thread.
+    if (bands.size() == 0) {
+      continue;
+    }
+
+    // Keep only those bands for which this depth was recorded.
+    std::function<bool(ScheduleTree*)> keepActive =
+        [root, subdomain](const ScheduleTree* tree) {
+          isl::union_set active = activeDomainPoints(root, tree);
+          return !active.intersect(subdomain).is_empty();
+        };
+    bands = functional::Filter(keepActive, bands);
+
+    // Make sure the band ends at thread x depth so we can promote below it.
+    bands = bandsSplitAfterDepth(bands, root, depth);
+
+    for (auto band : bands) {
+      // Find out how many threads are actually mapped.  Active domain points
+      // will involve all mapping parameters when we take them below the
+      // mapping.  Skip mapping parameters obviously mapped to 0, because they
+      // do not correspond to band members that should be fixed to obtain
+      // per-thread-group access relations.
+      auto points = activeDomainPoints(root, band);
+      size_t nMappedThreads = 0;
+      for (int j = 0; j < points.dim(isl::dim_type::param); ++j) {
+        auto id = points.get_space().get_dim_id(isl::dim_type::param, j);
+        for (size_t i = 0; i < mapping::ThreadId::kMaxDim; ++i) {
+          if (id != mapping::ThreadId::makeId(i)) {
+            continue;
+          }
+          if (getParamValIfFixed(points, j) ==
+              isl::val::zero(points.get_ctx())) {
+            continue;
+          }
+          ++nMappedThreads;
+          break;
+        }
+      }
+
+      auto groupMap = TensorReferenceGroup::accessedBySubtree(band, scop);
+      for (const auto& tensorGroups : groupMap) {
+        auto tensorId = tensorGroups.first;
+
+        // TODO: sorting of groups and counting the number of promoted elements
+
+        for (const auto& group : tensorGroups.second) {
+          auto sizes = group->approximationSizes();
+          // No point in promoting a scalar that will go to a register anyway.
+          if (sizes.size() == 0) {
+            continue;
+          }
+          if (!isPromotableToRegisterBelowThreads(
+                  threadIdxxScheduleDepthState,
+                  *group,
+                  fullSched,
+                  nMappedThreads,
+                  points)) {
+            continue;
+          }
+          if (!hasReuse(*group, fullSched, depth)) {
+            continue;
+          }
+          // TODO: if something is already in shared memory but is reused
+          // within one thread only, there is no point in keeping it in shared
+          // _if_ it gets promoted into a register.
+        }
+      }
+    }
+  }
+}
+
 } // namespace polyhedral
 } // namespace tc
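
As a side note, the injectivity condition enforced by isPromotableToRegisterBelowThreads can be illustrated without isl by a minimal stand-alone model: promotion is allowed only if no tensor element is reached from two different threads. The function name privatePerThread and the data below are made up for the example and are not part of this commit.

#include <iostream>
#include <map>
#include <utility>
#include <vector>

// Each pair maps a thread id to the index of the tensor element it accesses.
// The relation is "injective" in the sense used above if no element appears
// with two distinct thread ids.
bool privatePerThread(const std::vector<std::pair<int, int>>& threadToElement) {
  std::map<int, int> owner;  // element index -> thread that first touched it
  for (const auto& [thread, element] : threadToElement) {
    auto [it, inserted] = owner.emplace(element, thread);
    if (!inserted && it->second != thread) {
      return false;  // the same element is accessed by two different threads
    }
  }
  return true;
}

int main() {
  std::cout << privatePerThread({{0, 0}, {1, 1}, {2, 2}}) << "\n";  // 1: promotable
  std::cout << privatePerThread({{0, 5}, {1, 5}}) << "\n";          // 0: element 5 is shared
}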
