Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 3b8f190

Browse files
authored
Merge pull request #368 from facebookresearch/pr/registers
isPromotableToRegisterBelowThreads: directly use relevant parts of schedule
2 parents 78399a0 + b90e351 commit 3b8f190

File tree

3 files changed

+45
-72
lines changed

3 files changed

+45
-72
lines changed

tc/core/polyhedral/cuda/memory_promotion_heuristic.cc

Lines changed: 22 additions & 70 deletions
Original file line number | Diff line number | Diff line change
@@ -200,37 +200,6 @@ isl::union_map fullSchedule(const detail::ScheduleTree* root) {
200200
return schedule;
201201
}
202202

203-
/*
204-
* Insert map constraints that equate first "nDims" input dimensions to newly
205-
* introduced parameters.
206-
*/
207-
isl::map fixOuterInputDimsAsParameters(isl::map map, unsigned nDims) {
208-
if (nDims < 0 || nDims > map.dim(isl::dim_type::in)) {
209-
std::stringstream ss;
210-
ss << nDims << " is out of [0, " << map.dim(isl::dim_type::in)
211-
<< ") range";
212-
throw promotion::OutOfRangeException(ss.str());
213-
}
214-
215-
auto fixedMap = map;
216-
auto localSpace = isl::local_space(map.get_space().domain());
217-
auto nParams = map.dim(isl::dim_type::param);
218-
localSpace = localSpace.add_dims(isl::dim_type::param, nDims);
219-
for (unsigned i = 0; i < nDims; ++i) {
220-
localSpace = localSpace.set_dim_name(
221-
isl::dim_type::param,
222-
nParams + i,
223-
"__tcFixerParam" + std::to_string(i));
224-
}
225-
for (unsigned i = 0; i < nDims; ++i) {
226-
auto left = isl::aff(localSpace, isl::dim_type::param, nParams + i);
227-
auto right = isl::aff(localSpace, isl::dim_type::set, i);
228-
auto dom = isl::aff_set(left) == right;
229-
fixedMap = fixedMap.intersect_domain(dom);
230-
}
231-
return fixedMap;
232-
}
233-
234203
/*
235204
* Check if a reference group features reuse within the "outer" schedule.
236205
* In particular, check that for some given point in the outer schedule and
@@ -339,19 +308,25 @@ bool promotionImprovesCoalescing(
339308
}
340309

341310
/*
342-
* Check if the given "group" can be promoted to registers for the given active
343-
* domain points under full "schedule" where "nThreads" consecutive dimensions
344-
* at "depth"
345-
* are mapped to threads (the innermost of them being mapped to thread x).
311+
* Check if the given "group" can be promoted to registers for the given
312+
* mapping to thread identifiers and within the given outer schedule.
346313
*
347314
* In particular, the group's footprint must contain only one element and the
348-
* same tensor element should never be accessed by two different threads.
315+
* same tensor element should never be accessed by two different threads
316+
* within the same iteration of the outer schedule.
317+
* The second test is performed by checking that there is only a single
318+
* thread associated to a given pair of tensor element and outer schedule
319+
* iteration.
320+
* Note that the test for a single thread is performed by looking
321+
* at the range of "thread". This range may be larger than the number
322+
* of threads, such that multiple instances may get mapped to the same thread.
323+
* Requiring different such instances is therefore slightly more conservative
324+
* than strictly needed.
349325
*/
350326
bool isPromotableToRegisterBelowThreads(
351327
const TensorReferenceGroup& group,
352-
isl::union_map schedule,
353-
size_t depth,
354-
size_t nThreads) {
328+
isl::multi_union_pw_aff outer,
329+
isl::multi_union_pw_aff thread) {
355330
auto originalAccesses = group.originalAccesses();
356331

357332
// Return early if more than one element needs to be stored in registers.
@@ -364,28 +339,11 @@ bool isPromotableToRegisterBelowThreads(
364339
return false;
365340
}
366341

367-
auto scheduledAccesses = originalAccesses.apply_domain(schedule);
368-
369-
// Scheduled accesses contain maps from schedule dimensions to tensor
370-
// subscripts. Compute the relation between the schedule dimensions
371-
// mapped to threads and tensor subscripts by first removing dimensions
372-
// following the one mapped to thread x (last one assuming inverse mapping
373-
// order), then by equating all dimensions not mapped to threads to
374-
// parameters. Promotion to registers is only allowed if the resulting
375-
// relation is injective, i.e. the same tensor element is never accessed by
376-
// more than one thread. Note that our current check is overly conservative
377-
// because different values of schedule dimension may get mapped to the same
378-
// thread, in which case they could access the same tensor element.
379-
for (auto sa : isl::UnionAsVector<isl::union_map>(scheduledAccesses)) {
380-
sa = sa.project_out(
381-
isl::dim_type::in, depth, sa.dim(isl::dim_type::in) - depth);
382-
sa = fixOuterInputDimsAsParameters(sa, depth - nThreads);
383-
if (!sa.is_injective()) {
384-
return false;
385-
}
386-
}
342+
auto map = isl::union_map::from(outer);
343+
map = map.range_product(group.originalAccesses());
344+
map = map.apply_domain(isl::union_map::from(thread));
387345

388-
return true;
346+
return map.is_injective();
389347
}
390348

391349
/*
@@ -573,22 +531,16 @@ void promoteToRegistersBelowThreads(Scop& scop, size_t nRegisters) {
573531

574532
auto root = scop.scheduleRoot();
575533

576-
auto fullSched = fullSchedule(root);
577534
{
578535
auto markers = findThreadSpecificMarkers(root);
579536

580537
for (auto marker : markers) {
581538
auto partialSched = prefixSchedule(root, marker);
582539
// Pure affine schedule without (mapping) filters.
583-
auto partialSchedMupa = prefixScheduleMupa(root, marker);
584-
585-
auto depth = marker->scheduleDepth(root);
586-
587-
// Thread mapping filters are inserted immediately above the members
588-
// mapped to threads. The number of intermediate band members
589-
// is therefore equal to the number of mapped thread identifiers.
590540
auto mapping = findThreadMappingAncestor(root, marker);
591-
size_t nMappedThreads = marker->scheduleDepth(mapping);
541+
auto prefixSchedMupa = prefixScheduleMupa(root, mapping);
542+
auto mapSchedMupa = infixScheduleMupa(root, mapping, marker);
543+
auto partialSchedMupa = prefixSchedMupa.flat_range_product(mapSchedMupa);
592544

593545
auto groupMap = TensorReferenceGroup::accessedBySubtree(marker, scop);
594546
for (auto& tensorGroups : groupMap) {
@@ -603,7 +555,7 @@ void promoteToRegistersBelowThreads(Scop& scop, size_t nRegisters) {
603555
continue;
604556
}
605557
if (!isPromotableToRegisterBelowThreads(
606-
*group, fullSched, depth, nMappedThreads)) {
558+
*group, prefixSchedMupa, mapSchedMupa)) {
607559
continue;
608560
}
609561
if (!hasReuseWithin(*group, partialSchedMupa)) {

tc/core/polyhedral/schedule_transforms.cc

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -450,16 +450,19 @@ ostream& operator<<(ostream& os, const vector<Args...>& v) {
450450
}
451451
} // namespace
452452

453-
isl::multi_union_pw_aff prefixScheduleMupa(
453+
isl::multi_union_pw_aff infixScheduleMupa(
454454
const ScheduleTree* root,
455+
const ScheduleTree* relativeRoot,
455456
const ScheduleTree* tree) {
456457
auto domainElem = root->elemAs<ScheduleTreeElemDomain>();
457458
CHECK(domainElem);
458459
auto domain = domainElem->domain_.universe();
459460
auto zero = isl::multi_val::zero(domain.get_space().set_from_params());
460461
auto prefix = isl::multi_union_pw_aff(domain, zero);
462+
// Work around bug in isl.
463+
prefix = prefix.intersect_domain(domain);
461464
prefix = foldl(
462-
filterType<ScheduleTreeElemBand>(tree->ancestors(root)),
465+
filterType<ScheduleTreeElemBand>(tree->ancestors(relativeRoot)),
463466
[](const ScheduleTree* st, isl::multi_union_pw_aff prefix) {
464467
auto mupa = st->elemAs<ScheduleTreeElemBand>()->mupa_;
465468
return prefix.flat_range_product(mupa);
@@ -468,6 +471,12 @@ isl::multi_union_pw_aff prefixScheduleMupa(
468471
return prefix;
469472
}
470473

474+
isl::multi_union_pw_aff prefixScheduleMupa(
475+
const ScheduleTree* root,
476+
const ScheduleTree* tree) {
477+
return infixScheduleMupa(root, root, tree);
478+
}
479+
471480
isl::multi_union_pw_aff partialScheduleMupa(
472481
const detail::ScheduleTree* root,
473482
const detail::ScheduleTree* tree) {

tc/core/polyhedral/schedule_transforms.h

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -264,6 +264,18 @@ isl::union_map prefixSchedule(
264264
const detail::ScheduleTree* root,
265265
const detail::ScheduleTree* node);
266266

267+
// Return the concatenation of all band node partial schedules
268+
// from "relativeRoot" (inclusive) to "tree" (exclusive)
269+
// within a tree rooted at "root".
270+
// If there are no intermediate band nodes, then return a zero-dimensional
271+
// function on the universe domain of the schedule tree.
272+
// Note that this function does not take into account
273+
// any intermediate filter nodes.
274+
isl::multi_union_pw_aff infixScheduleMupa(
275+
const detail::ScheduleTree* root,
276+
const detail::ScheduleTree* relativeRoot,
277+
const detail::ScheduleTree* tree);
278+
267279
// Return the concatenation of all outer band node partial schedules.
268280
// If there are no outer band nodes, then return a zero-dimensional
269281
// function on the universe domain of the schedule tree.

0 commit comments

Comments
 (0)