
Commit 22cf55d

Sven Verdoolaege authored and committed
[RFC] use templated isl types
Templated isl types require the user to specify the domain and range universes of isl objects, allowing the compiler to check whether it makes sense to combine pairs of objects. This RFC only converts isPromotableToRegistersBelow and some related functions to illustrate the effect. isPromotableToRegistersBelow was already applying its operations correctly, so the code itself did not require any changes. However, one variable was reused to store intermediate results of different types; it had to be split into several variables because those intermediates now have different types.
1 parent 9bc96e8 commit 22cf55d

20 files changed: +430 −155 lines
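To make the idea in the commit message concrete, here is a minimal sketch of the phantom-type pattern the templated types rely on. It is an illustration only, assuming the untyped isl::union_map from the existing bindings; the UnionMap wrapper, its apply_range member, and the tag names below are simplified stand-ins, not the interface this commit actually adds.

#include "tc/external/isl.h"

namespace tc {
namespace polyhedral {

// Empty tag types: never defined or instantiated, they only label
// universes at compile time.
struct Domain;
struct Thread;
struct Warp;

// Wraps an untyped isl::union_map and records in the C++ type which
// universes its domain (D) and range (R) are meant to live in.
template <typename D, typename R>
class UnionMap {
 public:
  explicit UnionMap(isl::union_map map) : map_(map) {}

  // Composing D -> R with R -> T yields D -> T.  Passing a map whose
  // domain tag is not R does not type-check, so a nonsensical
  // combination is rejected by the compiler instead of producing a
  // silently meaningless isl object at run time.
  template <typename T>
  UnionMap<D, T> apply_range(const UnionMap<R, T>& other) const {
    return UnionMap<D, T>(map_.apply_range(other.map_));
  }

 private:
  template <typename, typename>
  friend class UnionMap;
  isl::union_map map_;
};

} // namespace polyhedral
} // namespace tc

// Given UnionMap<Domain, Thread> domainToThread and
// UnionMap<Thread, Warp> threadToWarp,
//   domainToThread.apply_range(threadToWarp)
// has type UnionMap<Domain, Warp>, while
//   domainToThread.apply_range(domainToThread)
// fails to compile because Thread is not Domain.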

tc/core/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
@@ -30,7 +30,6 @@ add_library(
     polyhedral/schedule_tree_elem.cc
     polyhedral/schedule_print.cc
     polyhedral/scop.cc
-    polyhedral/separation.cc
     polyhedral/unroll.cc
 )
 target_include_directories(tc_core PUBLIC ${LLVM_INCLUDE_DIRS})

tc/core/polyhedral/cuda/mapped_scop.cc

Lines changed: 15 additions & 14 deletions
@@ -199,7 +199,7 @@ void fixThreadsBelow(
 bool separatedOut(
     Scop& scop,
     detail::ScheduleTree* tree,
-    isl::union_set updates) {
+    isl::UnionSet<Domain> updates) {
   auto domain = activeDomainPoints(scop.scheduleRoot(), tree);
   auto other = domain.subtract(updates);
   if (other.is_empty()) {
@@ -253,7 +253,7 @@ bool MappedScop::detectReductions(detail::ScheduleTree* tree) {
   // a single reduction for now.
   // Support for multiple reductions would require a check
   // that these reductions do not interfere with each other.
-  auto domain = band->mupa_.domain();
+  auto domain = isl::UnionSet<Domain>(band->mupa_.domain());
   auto updates = reductionUpdates(domain, scop());
   if (updates.n_set() != 1) {
     return false;
@@ -287,8 +287,8 @@ bool MappedScop::needReductionSeparation(const detail::ScheduleTree* st) {
   return !reductionBandUpdates_.at(st).separated;
 }

-isl::multi_union_pw_aff MappedScop::reductionMapSchedule(
-    const detail::ScheduleTree* st) {
+isl::MultiUnionPwAff<Domain, ReductionSchedule>
+MappedScop::reductionMapSchedule(const detail::ScheduleTree* st) {
   TC_CHECK(reductionBandUpdates_.count(st) == 1);
   auto reductionBand = st->elemAs<detail::ScheduleTreeElemBand>();
   TC_CHECK(reductionBand);
@@ -305,7 +305,7 @@ isl::multi_union_pw_aff MappedScop::reductionMapSchedule(
   reductionSchedule = reductionSchedule.drop_dims(
       isl::dim_type::set, 0, reductionDim - nMappedThreads + 1);

-  return reductionSchedule;
+  return isl::MultiUnionPwAff<Domain, ReductionSchedule>(reductionSchedule);
 }

 detail::ScheduleTree* MappedScop::separateReduction(detail::ScheduleTree* st) {
@@ -316,7 +316,7 @@ detail::ScheduleTree* MappedScop::separateReduction(detail::ScheduleTree* st) {

   auto root = scop_->scheduleRoot();
   auto domain = activeDomainPoints(root, st);
-  auto prefixSchedule = prefixScheduleMupa(root, st);
+  auto prefixSchedule = prefixScheduleMupa<Prefix>(root, st);
   auto reductionSchedule = reductionMapSchedule(st);
   auto space = reductionSchedule.get_space();
   auto size = isl::multi_val::zero(space);
@@ -479,7 +479,7 @@ constexpr auto kWarp = "warp";
  * (of size "warpSize") to a warp identifier,
  * based on the thread sizes s_x, s_y up to s_z in "block".
  */
-isl::multi_aff constructThreadToWarp(
+isl::MultiAff<Thread, Warp> constructThreadToWarp(
     isl::ctx ctx,
     const unsigned warpSize,
     const Block& block) {
@@ -498,35 +498,36 @@ isl::multi_aff constructThreadToWarp(

   aff = aff.scale_down(isl::val(ctx, warpSize)).floor();
   auto mapSpace = blockSpace.product(warpSpace).unwrap();
-  return isl::multi_aff(mapSpace, isl::aff_list(aff));
+  return isl::MultiAff<Thread, Warp>(
+      isl::multi_aff(mapSpace, isl::aff_list(aff)));
 }
 } // namespace

-isl::multi_union_pw_aff MappedScop::threadMappingSchedule(
+isl::MultiUnionPwAff<Domain, Thread> MappedScop::threadMappingSchedule(
     const detail::ScheduleTree* tree) const {
   std::vector<mapping::MappingId> ids;
   for (size_t i = 0; i < numThreads.view.size(); ++i) {
     ids.emplace_back(mapping::ThreadId::makeId(i));
   }
   auto tupleId = isl::id(tree->ctx_, kBlock);
-  return extractDomainToIds(scop_->scheduleRoot(), tree, ids, tupleId);
+  return extractDomainToIds<Thread>(scop_->scheduleRoot(), tree, ids, tupleId);
 }

-isl::multi_union_pw_aff MappedScop::blockMappingSchedule(
+isl::MultiUnionPwAff<Domain, Block> MappedScop::blockMappingSchedule(
     const detail::ScheduleTree* tree) const {
   std::vector<mapping::MappingId> ids;
   for (size_t i = 0; i < numBlocks.view.size(); ++i) {
     ids.emplace_back(mapping::BlockId::makeId(i));
   }
   auto tupleId = isl::id(tree->ctx_, kGrid);
-  return extractDomainToIds(scop_->scheduleRoot(), tree, ids, tupleId);
+  return extractDomainToIds<Block>(scop_->scheduleRoot(), tree, ids, tupleId);
 }

 Scop::SyncLevel MappedScop::findBestSync(
     detail::ScheduleTree* st1,
     detail::ScheduleTree* st2,
-    isl::multi_union_pw_aff domainToThread,
-    isl::multi_union_pw_aff domainToWarp) {
+    isl::MultiUnionPwAff<Domain, Thread> domainToThread,
+    isl::MultiUnionPwAff<Domain, Warp> domainToWarp) {
   // Active points in the two schedule trees
   auto stRoot = scop_->scheduleRoot();
   auto activePoints1 = activeDomainPointsBelow(stRoot, st1);

tc/core/polyhedral/cuda/mapped_scop.h

Lines changed: 7 additions & 5 deletions
@@ -24,6 +24,7 @@
 #include "tc/core/cuda/cuda_mapping_options.h"
 #include "tc/core/polyhedral/cuda/mapping_types.h"
 #include "tc/core/polyhedral/cuda/memory_promotion_heuristic.h"
+#include "tc/core/polyhedral/domain_types.h"
 #include "tc/core/polyhedral/scop.h"
 #include "tc/core/tensor.h"
 #include "tc/external/isl.h"
@@ -160,7 +161,8 @@ class MappedScop {
   // Return the schedule that will be used by mapInnermostBandsToThreads
   // for mapping to thread identifiers, with the last function
   // corresponding to thread identifier x.
-  isl::multi_union_pw_aff reductionMapSchedule(const detail::ScheduleTree* st);
+  isl::MultiUnionPwAff<Domain, ReductionSchedule> reductionMapSchedule(
+      const detail::ScheduleTree* st);
   // Separate out reductions that can be mapped to an entire block.
   // The remaining parts, if any, are no longer considered for replacement
   // by a library call.
@@ -175,8 +177,8 @@
   Scop::SyncLevel findBestSync(
       detail::ScheduleTree* st1,
       detail::ScheduleTree* st2,
-      isl::multi_union_pw_aff domainToThread,
-      isl::multi_union_pw_aff domainToWarp);
+      isl::MultiUnionPwAff<Domain, Thread> domainToThread,
+      isl::MultiUnionPwAff<Domain, Warp> domainToWarp);

  public:
   // Find best configuration of synchronizations in a sequence, minimizing
@@ -197,14 +199,14 @@ class MappedScop {
   // to the thread identifiers, where all branches in "tree"
   // are assumed to have been mapped to thread identifiers.
   // The result lives in a space of the form block[x, ...].
-  isl::multi_union_pw_aff threadMappingSchedule(
+  isl::MultiUnionPwAff<Domain, Thread> threadMappingSchedule(
       const detail::ScheduleTree* tree) const;

   // Extract a mapping from the domain elements active at "tree"
   // to the block identifiers, where all branches in "tree"
   // are assumed to have been mapped to block identifiers.
   // The result lives in a space of the form grid[x, ...].
-  isl::multi_union_pw_aff blockMappingSchedule(
+  isl::MultiUnionPwAff<Domain, Block> blockMappingSchedule(
       const detail::ScheduleTree* tree) const;

  private:

tc/core/polyhedral/cuda/memory_promotion_heuristic.cc

Lines changed: 30 additions & 20 deletions
@@ -131,6 +131,8 @@ std::vector<T> collectBranchMarkers(T root, T node) {
   return findThreadSpecificMarkers(node);
 }

+struct FullSchedule;
+
 /*
  * Transform schedule bands into a union_map.
  * Takes all partial schedules at leaves as MUPAs (without accounting for
@@ -139,7 +141,8 @@ std::vector<T> collectBranchMarkers(T root, T node) {
  * current leaves and transforms them into union maps.
  * Mapping filters are ignored.
  */
-isl::union_map fullSchedule(const detail::ScheduleTree* root) {
+isl::UnionMap<Domain, FullSchedule> fullSchedule(
+    const detail::ScheduleTree* root) {
   using namespace tc::polyhedral::detail;

   if (!root->elemAs<ScheduleTreeElemDomain>()) {
@@ -182,7 +185,7 @@ isl::union_map fullSchedule(const detail::ScheduleTree* root) {
       throw promotion::PromotionLogicError(ss.str());
     }
   }
-  return schedule;
+  return isl::UnionMap<Domain, FullSchedule>(schedule);
 }

 /*
@@ -263,7 +266,7 @@ bool promotionImprovesCoalescing(
     const detail::ScheduleTree* root,
     const detail::ScheduleTree* node,
     const TensorReferenceGroup& group,
-    isl::union_map schedule) {
+    isl::UnionMap<Domain, FullSchedule> schedule) {
   auto originalAccesses = group.originalAccesses();

   auto markers = collectBranchMarkers(root, node);
@@ -313,6 +316,8 @@ isl::union_set collectMappingsTo(const Scop& scop) {
   return mapping;
 }

+struct Unrolled;
+
 /*
  * Check that only unrolled loops may appear in access subscripts.
  * Because the scoping point can be above a branching tree, descend into each
@@ -343,11 +348,12 @@ isl::union_set collectMappingsTo(const Scop& scop) {
  * different references may have different values, but all of them remain
  * independent of non-unrolled loop iterators.
  */
+template <typename Outer>
 bool accessSubscriptsAreUnrolledLoops(
     const TensorReferenceGroup& group,
     const detail::ScheduleTree* root,
     const detail::ScheduleTree* scope,
-    isl::multi_union_pw_aff outerSchedule) {
+    isl::MultiUnionPwAff<Domain, Outer> outerSchedule) {
   using namespace detail;

   auto nodes = ScheduleTree::collect(scope);
@@ -366,7 +372,7 @@ bool accessSubscriptsAreUnrolledLoops(

     auto unrolledDims = isl::union_pw_aff_list(leaf->ctx_, 1);
     for (auto node : ancestors) {
-      auto band = node->elemAs<detail::ScheduleTreeElemBand>();
+      auto band = node->template elemAs<detail::ScheduleTreeElemBand>();
       if (!band) {
         continue;
       }
@@ -383,7 +389,8 @@ bool accessSubscriptsAreUnrolledLoops(

     auto space = isl::space(leaf->ctx_, 0, unrolledDims.n())
                      .align_params(subdomain.get_space());
-    auto unrolledDimsMupa = isl::multi_union_pw_aff(space, unrolledDims);
+    auto unrolledDimsMupa =
+        isl::MultiUnionPwAff<Domain, Unrolled>(space, unrolledDims);

     // It is possible that no loops are unrolled, in which case
     // unrolledDimsMupa is zero-dimensional and needs an explicit domain
@@ -392,10 +399,11 @@ bool accessSubscriptsAreUnrolledLoops(
         unrolledDimsMupa.intersect_domain(group.originalAccesses().domain());

     auto accesses = group.originalAccesses();
-    auto schedule = outerSchedule.flat_range_product(unrolledDimsMupa);
-    accesses = accesses.apply_domain(isl::union_map::from(schedule));
+    auto schedule = outerSchedule.range_product(unrolledDimsMupa);
+    auto scheduleMap = schedule.asUnionMap();
+    auto scheduledAccesses = accesses.apply_domain(scheduleMap);

-    if (!accesses.is_single_valued()) {
+    if (!scheduledAccesses.is_single_valued()) {
       return false;
     }
   }
@@ -415,23 +423,25 @@ bool accessSubscriptsAreUnrolledLoops(
  * thread associated to a given pair of tensor element and outer schedule
  * iteration.
  */
+template <typename Outer>
 bool isPromotableToRegistersBelow(
     const TensorReferenceGroup& group,
     const detail::ScheduleTree* root,
     const detail::ScheduleTree* scope,
-    isl::multi_union_pw_aff outer,
-    isl::multi_union_pw_aff thread) {
+    isl::MultiUnionPwAff<Domain, Outer> outer,
+    isl::MultiUnionPwAff<Domain, Thread> thread) {
   if (!accessSubscriptsAreUnrolledLoops(
-          group, root, scope, outer.flat_range_product(thread))) {
+          group, root, scope, outer.range_product(thread))) {
     return false;
   }

   auto originalAccesses = group.originalAccesses();
-  auto map = isl::union_map::from(outer);
-  map = map.range_product(originalAccesses);
-  map = map.apply_domain(isl::union_map::from(thread));
+  auto outerMap = isl::UnionMap<Domain, Outer>::from(outer);
+  auto pair = outerMap.range_product(originalAccesses);
+  auto threadMap = isl::UnionMap<Domain, Thread>::from(thread);
+  auto threadToPair = pair.apply_domain(threadMap);

-  return map.is_injective();
+  return threadToPair.is_injective();
 }

 /*
@@ -654,15 +664,15 @@ void promoteToRegistersBelow(MappedScop& mscop, detail::ScheduleTree* scope) {
   auto blockSchedule = mscop.blockMappingSchedule(mscop.schedule());

   // Pure affine schedule without (mapping) filters.
-  auto partialSchedMupa = partialScheduleMupa(root, scope);
+  auto partialSchedMupa = partialScheduleMupa<Scope>(root, scope);
   // Schedule with block mapping filter.
   auto partialSched =
       isl::union_map::from(partialSchedMupa).intersect_domain(blockMapping);
   // The following promotion validity and profitability checks need to be
   // performed with respect to the block mapping, so append the block schedule.
   // If the partial schedule contains it already, it will just end up with
   // identical dimensions without affecting the result of the checks.
-  partialSchedMupa = partialSchedMupa.flat_range_product(blockSchedule);
+  auto partialSchedBlockMupa = partialSchedMupa.range_product(blockSchedule);

   for (auto& tensorGroups : groupMap) {
     auto tensorId = tensorGroups.first;
@@ -676,11 +686,11 @@ void promoteToRegistersBelow(MappedScop& mscop, detail::ScheduleTree* scope) {
       continue;
     }
     if (!isPromotableToRegistersBelow(
-            *group, root, scope, partialSchedMupa, threadSchedule)) {
+            *group, root, scope, partialSchedBlockMupa, threadSchedule)) {
       continue;
     }
     // Check reuse within threads.
-    auto schedule = partialSchedMupa.flat_range_product(threadSchedule);
+    auto schedule = partialSchedBlockMupa.flat_range_product(threadSchedule);
     if (!hasReuseWithin(*group, schedule)) {
       continue;
     }

tc/core/polyhedral/domain_types.h

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+namespace tc {
+namespace polyhedral {
+
+struct Domain;
+struct Prefix;
+struct ReductionSchedule;
+struct Scope;
+struct Tensor;
+struct Thread;
+struct Warp;
+
+} // namespace polyhedral
+} // namespace tc

tc/core/polyhedral/memory_promotion.cc

Lines changed: 6 additions & 2 deletions
@@ -547,14 +547,18 @@ ScheduleTree* insertCopiesUnder(

   if (reads) {
     insertExtensionBefore(
-        root, tree, tree->child({0}), readExtension, std::move(readFilterNode));
+        root,
+        tree,
+        tree->child({0}),
+        isl::UnionMap<Prefix, Domain>(readExtension),
+        std::move(readFilterNode));
   }
   if (writes) {
     insertExtensionAfter(
         root,
         tree,
         tree->child({0}),
-        writeExtension,
+        isl::UnionMap<Prefix, Domain>(writeExtension),
         std::move(writeFilterNode));
   }

tc/core/polyhedral/memory_promotion.h

Lines changed: 4 additions & 2 deletions
@@ -17,6 +17,7 @@

 #include <iostream>

+#include "tc/core/polyhedral/domain_types.h"
 #include "tc/core/polyhedral/schedule_tree.h"
 #include "tc/core/polyhedral/scop.h"
 #include "tc/external/isl.h"
@@ -137,8 +138,9 @@ class TensorReferenceGroup {
   // range spaces.
   isl::union_map originalWrites() const;
   isl::union_map originalReads() const;
-  isl::union_map originalAccesses() const {
-    return originalWrites().unite(originalReads());
+  isl::UnionMap<Domain, Tensor> originalAccesses() const {
+    auto accesses = originalWrites().unite(originalReads());
+    return isl::UnionMap<Domain, Tensor>(accesses);
   }

   // Rectangular overapproximation of the set of tensor elements accessed below

tc/core/polyhedral/reduction_matcher.cc

Lines changed: 5 additions & 3 deletions
@@ -117,13 +117,15 @@ bool isAlmostIdentityReduction(isl::pw_aff pa, const Scop& scop) {

 } // namespace

-isl::union_set reductionUpdates(isl::union_set domain, const Scop& scop) {
-  auto update = isl::union_set::empty(domain.get_space());
+isl::UnionSet<Domain> reductionUpdates(
+    isl::UnionSet<Domain> domain,
+    const Scop& scop) {
+  auto update = isl::UnionSet<Domain>::empty(domain.get_space());
   domain.foreach_set([&update, &scop](isl::set set) {
     auto setId = set.get_tuple_id();
     std::vector<size_t> reductionDims;
     if (isReductionUpdateId(setId, scop, reductionDims)) {
-      update = update.unite(set);
+      update = update.unite(isl::UnionSet<Domain>(set));
     }
   });
   return update;