Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 835035f

Browse files
author
Sven Verdoolaege
committed
[RFC] use templated isl types
Templated isl types require the user to specify the domain and range universes of isl objects, allowing the compiler to check whether it makes sense to combine pairs of objects. This RFC only converts isPromotableToRegistersBelow and some related functions to illustrate the effect. The isPromotableToRegistersBelow function was already applying operations correctly, so the code itself did not require any changes. However, one variable was reused to store different types of intermediate results, and this one had to be split up into several variables because they now have different types.
1 parent 8d5fe7e commit 835035f

23 files changed

+237
-156
lines changed

tc/core/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ add_library(
3232
polyhedral/schedule_print.cc
3333
polyhedral/schedule_utils.cc
3434
polyhedral/scop.cc
35-
polyhedral/separation.cc
3635
polyhedral/unroll.cc
3736
polyhedral/utils.cc
3837
)

tc/core/polyhedral/cuda/mapped_scop.cc

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,10 @@ void fixThreadsBelow(MappedScop& mscop, ScheduleTree* tree, size_t begin) {
226226
* Anything that depends on an update statement is ordered after
227227
* the update statements. Anything else is ordered before.
228228
*/
229-
bool separatedOut(Scop& scop, ScheduleTree* tree, isl::union_set updates) {
229+
bool separatedOut(
230+
Scop& scop,
231+
ScheduleTree* tree,
232+
isl::UnionSet<Statement> updates) {
230233
auto domain = activeDomainPoints(scop.scheduleRoot(), tree);
231234
auto other = domain.subtract(updates);
232235
if (other.is_empty()) {
@@ -285,7 +288,7 @@ bool MappedScop::detectReductions(ScheduleTree* tree) {
285288
// a single reduction for now.
286289
// Support for multiple reductions would require a check
287290
// that these reductions do not interfere with each other.
288-
auto domain = isl::UnionSet<Statement>(band->mupa_.domain());
291+
auto domain = band->mupa_.domain();
289292
auto updates = reductionUpdates(domain, scop());
290293
if (updates.n_set() != 1) {
291294
return false;
@@ -502,54 +505,54 @@ constexpr auto kWarp = "warp";
502505
* (of size "warpSize") to a warp identifier,
503506
* based on the thread sizes s_x, s_y up to s_z in "block".
504507
*/
505-
isl::multi_aff constructThreadToWarp(
508+
isl::MultiAff<Thread, Warp> constructThreadToWarp(
506509
isl::ctx ctx,
507510
const unsigned warpSize,
508511
const Block& block) {
509-
auto space = isl::space(ctx, 0);
512+
auto space = isl::Space<>(isl::space(ctx, 0));
510513
auto id = isl::id(ctx, kBlock);
511-
auto blockSpace = space.add_named_tuple_id_ui(id, block.view.size());
512-
auto warpSpace = space.add_named_tuple_id_ui(isl::id(ctx, kWarp), 1);
513-
auto aff = isl::aff::zero_on_domain(blockSpace);
514+
auto blockSpace = space.add_named_tuple_id_ui<Thread>(id, block.view.size());
515+
auto warpSpace = space.add_named_tuple_id_ui<Warp>(isl::id(ctx, kWarp), 1);
516+
auto aff = isl::AffOn<Thread>::zero_on_domain(blockSpace);
514517

515518
auto nThread = block.view.size();
516-
auto identity = isl::multi_aff::identity(blockSpace.map_from_set());
519+
auto identity = isl::MultiAff<Thread, Thread>::identity(blockSpace.map_from_set());
517520
for (int i = nThread - 1; i >= 0; --i) {
518521
aff = aff.scale(isl::val(ctx, block.view[i]));
519522
aff = aff.add(identity.get_aff(i));
520523
}
521524

522525
aff = aff.scale_down(isl::val(ctx, warpSize)).floor();
523-
auto mapSpace = blockSpace.product(warpSpace).unwrap();
524-
return isl::multi_aff(mapSpace, isl::aff_list(aff));
526+
auto mapSpace = blockSpace.product(warpSpace).unwrap<Thread, Warp>();
527+
return isl::MultiAff<Thread, Warp>(mapSpace, aff.asAffList());
525528
}
526529
} // namespace
527530

528-
isl::multi_union_pw_aff MappedScop::threadMappingSchedule(
531+
isl::MultiUnionPwAff<Statement, Thread> MappedScop::threadMappingSchedule(
529532
const ScheduleTree* tree) const {
530533
std::vector<mapping::MappingId> ids;
531534
for (size_t i = 0; i < numThreads.view.size(); ++i) {
532535
ids.emplace_back(mapping::ThreadId::makeId(i));
533536
}
534537
auto tupleId = isl::id(tree->ctx_, kBlock);
535-
return extractDomainToIds(scop_->scheduleRoot(), tree, ids, tupleId);
538+
return extractDomainToIds<Thread>(scop_->scheduleRoot(), tree, ids, tupleId);
536539
}
537540

538-
isl::multi_union_pw_aff MappedScop::blockMappingSchedule(
541+
isl::MultiUnionPwAff<Statement, Block> MappedScop::blockMappingSchedule(
539542
const ScheduleTree* tree) const {
540543
std::vector<mapping::MappingId> ids;
541544
for (size_t i = 0; i < numBlocks.view.size(); ++i) {
542545
ids.emplace_back(mapping::BlockId::makeId(i));
543546
}
544547
auto tupleId = isl::id(tree->ctx_, kGrid);
545-
return extractDomainToIds(scop_->scheduleRoot(), tree, ids, tupleId);
548+
return extractDomainToIds<Block>(scop_->scheduleRoot(), tree, ids, tupleId);
546549
}
547550

548551
Scop::SyncLevel MappedScop::findBestSync(
549552
ScheduleTree* st1,
550553
ScheduleTree* st2,
551-
isl::multi_union_pw_aff domainToThread,
552-
isl::multi_union_pw_aff domainToWarp) {
554+
isl::MultiUnionPwAff<Statement, Thread> domainToThread,
555+
isl::MultiUnionPwAff<Statement, Warp> domainToWarp) {
553556
// Active points in the two schedule trees
554557
auto stRoot = scop_->scheduleRoot();
555558
auto activePoints1 = activeDomainPointsBelow(stRoot, st1);

tc/core/polyhedral/cuda/mapped_scop.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -188,8 +188,8 @@ class MappedScop {
188188
Scop::SyncLevel findBestSync(
189189
detail::ScheduleTree* st1,
190190
detail::ScheduleTree* st2,
191-
isl::multi_union_pw_aff domainToThread,
192-
isl::multi_union_pw_aff domainToWarp);
191+
isl::MultiUnionPwAff<Statement, Thread> domainToThread,
192+
isl::MultiUnionPwAff<Statement, Warp> domainToWarp);
193193

194194
public:
195195
// Find best configuration of synchronizations in a sequence, minimizing
@@ -210,14 +210,14 @@ class MappedScop {
210210
// to the thread identifiers, where all branches in "tree"
211211
// are assumed to have been mapped to thread identifiers.
212212
// The result lives in a space of the form block[x, ...].
213-
isl::multi_union_pw_aff threadMappingSchedule(
213+
isl::MultiUnionPwAff<Statement, Thread> threadMappingSchedule(
214214
const detail::ScheduleTree* tree) const;
215215

216216
// Extract a mapping from the domain elements active at "tree"
217217
// to the block identifiers, where all branches in "tree"
218218
// are assumed to have been mapped to block identifiers.
219219
// The result lives in a space of the form grid[x, ...].
220-
isl::multi_union_pw_aff blockMappingSchedule(
220+
isl::MultiUnionPwAff<Statement, Block> blockMappingSchedule(
221221
const detail::ScheduleTree* tree) const;
222222

223223
private:

tc/core/polyhedral/cuda/memory_promotion_heuristic.cc

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ bool promotionImprovesCoalescing(
225225
auto depth = marker->scheduleDepth(root);
226226
auto activePoints = activeDomainPoints(root, mapping);
227227
auto localAccesses = originalAccesses.intersect_domain(activePoints);
228-
auto schedule = prefixSchedule(root, marker);
228+
auto schedule = prefixSchedule<Prefix>(root, marker);
229229
auto scheduledAccesses = localAccesses.apply_domain(schedule);
230230
for (auto access : isl::UnionAsVector<isl::union_map>(scheduledAccesses)) {
231231
auto scheduleSpace = access.get_space().domain();
@@ -262,6 +262,8 @@ isl::union_set collectMappingsTo(const Scop& scop) {
262262
return mapping;
263263
}
264264

265+
struct Unrolled;
266+
265267
/*
266268
* Check that only unrolled loops may appear in access subscripts.
267269
* Because the scoping point can be above a branching tree, descend into each
@@ -292,11 +294,12 @@ isl::union_set collectMappingsTo(const Scop& scop) {
292294
* different references may have different values, but all of them remain
293295
* independent of non-unrolled loop iterators.
294296
*/
297+
template <typename Outer>
295298
bool accessSubscriptsAreUnrolledLoops(
296299
const TensorReferenceGroup& group,
297300
const detail::ScheduleTree* root,
298301
const detail::ScheduleTree* scope,
299-
isl::multi_union_pw_aff outerSchedule) {
302+
isl::MultiUnionPwAff<Statement, Outer> outerSchedule) {
300303
using namespace detail;
301304

302305
auto nodes = ScheduleTree::collect(scope);
@@ -315,7 +318,7 @@ bool accessSubscriptsAreUnrolledLoops(
315318

316319
auto unrolledDims = isl::union_pw_aff_list(leaf->ctx_, 1);
317320
for (auto node : ancestors) {
318-
auto band = node->as<detail::ScheduleTreeBand>();
321+
auto band = node->template as<detail::ScheduleTreeBand>();
319322
if (!band) {
320323
continue;
321324
}
@@ -331,8 +334,9 @@ bool accessSubscriptsAreUnrolledLoops(
331334
}
332335

333336
auto space =
334-
subdomain.get_space().add_unnamed_tuple_ui(unrolledDims.size());
335-
auto unrolledDimsMupa = isl::multi_union_pw_aff(space, unrolledDims);
337+
subdomain.get_space().template add_unnamed_tuple_ui<Unrolled>(unrolledDims.size());
338+
auto unrolledDimsMupa = isl::MultiUnionPwAff<Statement, Unrolled>(
339+
space, isl::UnionPwAffListOn<Statement>(unrolledDims));
336340

337341
// It is possible that no loops are unrolled, in which case
338342
// unrolledDimsMupa is zero-dimensional and needs an explicit domain
@@ -341,10 +345,11 @@ bool accessSubscriptsAreUnrolledLoops(
341345
unrolledDimsMupa.intersect_domain(group.originalAccesses().domain());
342346

343347
auto accesses = group.originalAccesses();
344-
auto schedule = outerSchedule.flat_range_product(unrolledDimsMupa);
345-
accesses = accesses.apply_domain(isl::union_map::from(schedule));
348+
auto schedule = outerSchedule.range_product(unrolledDimsMupa);
349+
auto scheduleMap = schedule.toUnionMap();
350+
auto scheduledAccesses = accesses.apply_domain(scheduleMap);
346351

347-
if (!accesses.is_single_valued()) {
352+
if (!scheduledAccesses.is_single_valued()) {
348353
return false;
349354
}
350355
}
@@ -364,23 +369,25 @@ bool accessSubscriptsAreUnrolledLoops(
364369
* thread associated to a given pair of tensor element and outer schedule
365370
* iteration.
366371
*/
372+
template <typename Outer>
367373
bool isPromotableToRegistersBelow(
368374
const TensorReferenceGroup& group,
369375
const detail::ScheduleTree* root,
370376
const detail::ScheduleTree* scope,
371-
isl::multi_union_pw_aff outer,
372-
isl::multi_union_pw_aff thread) {
377+
isl::MultiUnionPwAff<Statement, Outer> outer,
378+
isl::MultiUnionPwAff<Statement, Thread> thread) {
373379
if (!accessSubscriptsAreUnrolledLoops(
374-
group, root, scope, outer.flat_range_product(thread))) {
380+
group, root, scope, outer.range_product(thread))) {
375381
return false;
376382
}
377383

378384
auto originalAccesses = group.originalAccesses();
379-
auto map = isl::union_map::from(outer);
380-
map = map.range_product(originalAccesses);
381-
map = map.apply_domain(isl::union_map::from(thread));
385+
auto outerMap = isl::UnionMap<Statement, Outer>::from(outer);
386+
auto pair = outerMap.range_product(originalAccesses);
387+
auto threadMap = isl::UnionMap<Statement, Thread>::from(thread);
388+
auto threadToPair = pair.apply_domain(threadMap);
382389

383-
return map.is_injective();
390+
return threadToPair.is_injective();
384391
}
385392

386393
/*
@@ -653,15 +660,15 @@ void promoteToRegistersBelow(MappedScop& mscop, detail::ScheduleTree* scope) {
653660
auto blockSchedule = mscop.blockMappingSchedule(mscop.schedule());
654661

655662
// Pure affine schedule without (mapping) filters.
656-
auto partialSchedMupa = partialScheduleMupa(root, scope);
663+
auto partialSchedMupa = partialScheduleMupa<Scope>(root, scope);
657664
// Schedule with block mapping filter.
658665
auto partialSched =
659666
isl::union_map::from(partialSchedMupa).intersect_domain(blockMapping);
660667
// The following promotion validity and profitability checks need to be
661668
// performed with respect to the block mapping, so append the block schedule.
662669
// If the partial schedule contains it already, it will just end up with
663670
// identical dimensions without affecting the result of the checks.
664-
partialSchedMupa = partialSchedMupa.flat_range_product(blockSchedule);
671+
auto partialSchedBlockMupa = partialSchedMupa.range_product(blockSchedule);
665672

666673
for (auto& tensorGroups : groupMap) {
667674
auto tensorId = tensorGroups.first;
@@ -675,11 +682,11 @@ void promoteToRegistersBelow(MappedScop& mscop, detail::ScheduleTree* scope) {
675682
continue;
676683
}
677684
if (!isPromotableToRegistersBelow(
678-
*group, root, scope, partialSchedMupa, threadSchedule)) {
685+
*group, root, scope, partialSchedBlockMupa, threadSchedule)) {
679686
continue;
680687
}
681688
// Check reuse within threads.
682-
auto schedule = partialSchedMupa.flat_range_product(threadSchedule);
689+
auto schedule = partialSchedBlockMupa.flat_range_product(threadSchedule);
683690
if (!hasReuseWithin(*group, schedule)) {
684691
continue;
685692
}

tc/core/polyhedral/memory_promotion.cc

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -408,25 +408,25 @@ namespace {
408408
// each dimension of the tensor is constrained by the min_aff on the left and
409409
// by the min_aff + extent_aff on the right. Intersect this set with the
410410
// context of the scop.
411-
isl::set tensorElementsSet(const Scop& scop, isl::id tensorId) {
411+
isl::Set<Tensor> tensorElementsSet(const Scop& scop, isl::id tensorId) {
412412
auto halideParameter = scop.findArgument(tensorId).parameter();
413-
auto space = isl::Space<>(scop.domain().get_space());
413+
auto space = scop.domain().get_space();
414414
auto nDim = halideParameter.dimensions();
415415
auto tensorTuple = constructTensorTuple(space, tensorId, nDim);
416416
auto tensorSpace = tensorTuple.get_space();
417417

418-
auto tensorElements = isl::set::universe(tensorSpace);
419-
auto identity = isl::multi_aff::identity(tensorSpace.map_from_set());
418+
auto tensorElements = isl::Set<Tensor>::universe(tensorSpace);
419+
auto identity = isl::MultiAff<Tensor,Tensor>::identity(tensorSpace.map_from_set());
420420
for (int i = 0; i < nDim; ++i) {
421-
isl::aff minAff = halide2isl::makeIslAffFromExpr(
421+
auto minAff = halide2isl::makeIslAffFromExpr(
422422
space, halideParameter.min_constraint(i));
423-
isl::aff extentAff = halide2isl::makeIslAffFromExpr(
423+
auto extentAff = halide2isl::makeIslAffFromExpr(
424424
space, halideParameter.extent_constraint(i));
425-
minAff = minAff.unbind_params_insert_domain(tensorTuple);
426-
extentAff = extentAff.unbind_params_insert_domain(tensorTuple);
425+
auto minAff2 = minAff.unbind_params_insert_domain(tensorTuple);
426+
auto extentAff2 = extentAff.unbind_params_insert_domain(tensorTuple);
427427
auto aff = identity.get_aff(i);
428-
tensorElements = tensorElements & (minAff <= isl::aff_set(aff)) &
429-
(isl::aff_set(aff) < (minAff + extentAff));
428+
tensorElements = tensorElements & (minAff2.le_set(aff)) &
429+
(aff.lt_set(minAff2 + extentAff2));
430430
}
431431

432432
tensorElements = tensorElements.intersect_params(scop.context());
@@ -493,8 +493,8 @@ ScheduleTree* insertCopiesUnder(
493493
auto writeSchedule = isl::multi_union_pw_aff(identityCopySchedule.pullback(
494494
isl::multi_aff::wrapped_range_map(writeSpace)));
495495

496-
auto readBandNode = ScheduleTree::makeBand(readSchedule);
497-
auto writeBandNode = ScheduleTree::makeBand(writeSchedule);
496+
auto readBandNode = ScheduleTree::makeBand(isl::MultiUnionPwAff<Statement, Band>(readSchedule));
497+
auto writeBandNode = ScheduleTree::makeBand(isl::MultiUnionPwAff<Statement, Band>(writeSchedule));
498498

499499
if (unrollAllCopies) {
500500
unrollAllMembers(readBandNode->as<detail::ScheduleTreeBand>());
@@ -542,14 +542,18 @@ ScheduleTree* insertCopiesUnder(
542542

543543
if (reads) {
544544
insertExtensionBefore(
545-
root, tree, tree->child({0}), readExtension, std::move(readFilterNode));
545+
root,
546+
tree,
547+
tree->child({0}),
548+
isl::UnionMap<Prefix, Statement>(readExtension),
549+
std::move(readFilterNode));
546550
}
547551
if (writes) {
548552
insertExtensionAfter(
549553
root,
550554
tree,
551555
tree->child({0}),
552-
writeExtension,
556+
isl::UnionMap<Prefix, Statement>(writeExtension),
553557
std::move(writeFilterNode));
554558
}
555559

tc/core/polyhedral/memory_promotion.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#include <iostream>
1919

20+
#include "tc/core/polyhedral/domain_types.h"
2021
#include "tc/core/polyhedral/schedule_tree.h"
2122
#include "tc/core/polyhedral/scop.h"
2223
#include "tc/external/isl.h"
@@ -136,8 +137,9 @@ class TensorReferenceGroup {
136137
// range spaces.
137138
isl::union_map originalWrites() const;
138139
isl::union_map originalReads() const;
139-
isl::union_map originalAccesses() const {
140-
return originalWrites().unite(originalReads());
140+
isl::UnionMap<Statement, Tensor> originalAccesses() const {
141+
auto accesses = originalWrites().unite(originalReads());
142+
return isl::UnionMap<Statement, Tensor>(accesses);
141143
}
142144

143145
// Rectangular overapproximation of the set of tensor elements accessed below

tc/core/polyhedral/schedule_isl_conversion.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
#include "tc/core/check.h"
2525
#include "tc/core/flags.h"
26+
#include "tc/core/polyhedral/domain_types.h"
2627
#include "tc/core/polyhedral/schedule_transforms.h"
2728
#include "tc/external/isl.h"
2829

@@ -81,7 +82,7 @@ isl::schedule_node insertBranch(
8182
*/
8283
std::vector<size_t> findCorePositions(
8384
const ScheduleTree* st,
84-
isl::union_set domain) {
85+
isl::UnionSet<Statement> domain) {
8586
std::vector<size_t> positions;
8687
TC_CHECK(st->as<ScheduleTreeSequence>());
8788
for (size_t i = 0; i < st->numChildren(); ++i) {
@@ -125,7 +126,7 @@ isl::schedule_node insertExtension(
125126
isl::schedule_node node,
126127
const ScheduleTree* st) {
127128
auto depth0 = node.get_tree_depth();
128-
auto domain = node.get_universe_domain();
129+
auto domain = isl::UnionSet<Statement>(node.get_universe_domain());
129130
auto child = st->child({0});
130131
auto corePos = findCorePositions(child, domain);
131132
TC_CHECK(!corePos.empty());
@@ -242,16 +243,17 @@ std::unique_ptr<ScheduleTreeBand> fromIslScheduleNodeBand(
242243
for (size_t i = 0; i < n; ++i) {
243244
coincident[i] = b.member_get_coincident(i);
244245
}
246+
auto mupa = isl::MultiUnionPwAff<Statement, Band>(b.get_partial_schedule());
245247
return ScheduleTreeBand::make(
246-
b.get_partial_schedule(), b.get_permutable(), coincident, unroll);
248+
mupa, b.get_permutable(), coincident, unroll);
247249
}
248250

249251
std::unique_ptr<ScheduleTree> elemFromIslScheduleNode(isl::schedule_node node) {
250252
auto ctx = node.get_ctx();
251253
if (auto band = node.as<isl::schedule_node_band>()) {
252254
return fromIslScheduleNodeBand(band);
253255
} else if (auto context = node.as<isl::schedule_node_context>()) {
254-
auto c = context.get_context();
256+
auto c = isl::Set<Prefix>(context.get_context());
255257
return ScheduleTreeContext::make(c);
256258
} else if (auto domain = node.as<isl::schedule_node_domain>()) {
257259
auto c = domain.get_domain();

0 commit comments

Comments
 (0)