Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit ab7215c

Browse files
author
Sven Verdoolaege
committed
promoteToRegistersBelowThreads: extract depth from marker node
The thread mapping filters are now inserted right above the band members that are mapped to thread identifiers, while the marker node appears underneath them. The difference in schedule depth between the marker node and the outermost thread mapping filter is therefore equal to the number of thread identifiers mapped to band members. Once the mapping filter has been changed to explicitly keep track of the thread identifiers mapped to band members, this code can be further simplified. The outermost thread mapping filter is returned because findThreadMappingAncestor will be reused in isCoalesced, where the result will also be used to derive the active domain points outside the mapping. In that case, it is important to have a pointer to the outermost mapping filter so that none of the mapping filters get mixed into those active domain points.
1 parent 2263130 commit ab7215c

File tree

2 files changed

+21
-36
lines changed

2 files changed

+21
-36
lines changed

tc/core/polyhedral/cuda/memory_promotion_heuristic.cc

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,21 @@ size_t computeThreadIdxXScheduleDepth(
298298
return *depths.begin();
299299
}
300300

301+
/*
302+
* Return the outermost thread mapping filter among the ancestors of "node",
303+
* assuming that there is at least one.
304+
*/
305+
const detail::ScheduleTree* findThreadMappingAncestor(
306+
const detail::ScheduleTree* root,
307+
const detail::ScheduleTree* node) {
308+
auto ancestors = node->ancestors(root);
309+
ancestors = functional::Filter(isThreadMapping, ancestors);
310+
if (ancestors.size() < 1) {
311+
throw promotion::PromotionLogicError("missing MappingFilter");
312+
}
313+
return ancestors[0];
314+
}
315+
301316
/*
302317
* Check if a reference group is accessed in a coalesced way.
303318
*
@@ -595,32 +610,17 @@ void promoteToRegistersBelowThreads(Scop& scop, size_t nRegisters) {
595610
auto markers = findThreadSpecificMarkers(root);
596611

597612
for (auto marker : markers) {
598-
// Find out how many threads are actually mapped. Active domain points
599-
// will involve all mapping parameters when we take them below the
600-
// mapping. Skip mapping parameters obviously mapped to 0, because they
601-
// do not correspond to band members that should be fixed to obtain
602-
// per-thread-group access relations.
603-
auto points = activeDomainPoints(root, marker);
604613
auto partialSched = prefixSchedule(root, marker);
605614
// Pure affine schedule without (mapping) filters.
606615
auto partialSchedMupa = prefixScheduleMupa(root, marker);
607616

608617
auto depth = marker->scheduleDepth(root);
609-
size_t nMappedThreads = 0;
610-
for (unsigned j = 0; j < points.dim(isl::dim_type::param); ++j) {
611-
auto id = points.get_space().get_dim_id(isl::dim_type::param, j);
612-
for (size_t i = 0; i < mapping::ThreadId::kMaxDim; ++i) {
613-
if (id != mapping::ThreadId::makeId(i)) {
614-
continue;
615-
}
616-
if (isl::getParamValIfFixed(points, j) ==
617-
isl::val::zero(points.get_ctx())) {
618-
continue;
619-
}
620-
++nMappedThreads;
621-
break;
622-
}
623-
}
618+
619+
// Thread mapping filters are inserted immediately above the members
620+
// mapped to threads. The number of intermediate band members
621+
// is therefore equal to the number of mapped thread identifiers.
622+
auto mapping = findThreadMappingAncestor(root, marker);
623+
size_t nMappedThreads = marker->scheduleDepth(mapping);
624624

625625
auto groupMap = TensorReferenceGroup::accessedBySubtree(marker, scop);
626626
for (auto& tensorGroups : groupMap) {

tc/external/detail/islpp.h

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -395,21 +395,6 @@ auto end(L& list) -> ListIter<decltype(list.get(0)), L> {
395395
using detail::begin;
396396
using detail::end;
397397

398-
template <typename T>
399-
isl::val getParamValIfFixed(T t, int pos) {
400-
auto val = isl::val::nan(t.get_ctx());
401-
for (auto set : isl::UnionAsVector<T>(t)) {
402-
auto currentVal = set.plain_get_val_if_fixed(isl::dim_type::param, pos);
403-
if (currentVal.is_nan()) {
404-
return currentVal;
405-
}
406-
if (!val.is_nan() && val != currentVal) {
407-
return isl::val::nan(t.get_ctx());
408-
}
409-
val = currentVal;
410-
}
411-
return val;
412-
}
413398
} // namespace isl
414399

415400
namespace isl {

0 commit comments

Comments
 (0)