Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 2a546c8

Browse files
author
Sven Verdoolaege
committed
[RFC] use templated isl types
Templated isl types require the user to specify the domain and range universes of isl objects, allowing the compiler to check whether it makes sense to combine pairs of objects. This RFC only converts isPromotableToRegistersBelow and some related functions to illustrate the effect. The isPromotableToRegistersBelow was already applying operations correctly, so the code itself did not require any changes. However, one variable was reused to store different types of intermediate result and this one had to be split up into several variables because they now have different types.
1 parent 221fa20 commit 2a546c8

File tree

9 files changed

+140
-35
lines changed

9 files changed

+140
-35
lines changed

tc/core/polyhedral/cuda/mapped_scop.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -502,24 +502,26 @@ isl::multi_aff constructThreadToWarp(
502502
}
503503
} // namespace
504504

505-
isl::multi_union_pw_aff MappedScop::threadMappingSchedule(
505+
isl::MultiUnionPwAff<Domain, Thread> MappedScop::threadMappingSchedule(
506506
const detail::ScheduleTree* tree) const {
507507
std::vector<mapping::MappingId> ids;
508508
for (size_t i = 0; i < numThreads.view.size(); ++i) {
509509
ids.emplace_back(mapping::ThreadId::makeId(i));
510510
}
511511
auto tupleId = isl::id(tree->ctx_, kBlock);
512-
return extractDomainToIds(scop_->scheduleRoot(), tree, ids, tupleId);
512+
auto schedule = extractDomainToIds(scop_->scheduleRoot(), tree, ids, tupleId);
513+
return isl::MultiUnionPwAff<Domain, Thread>(schedule);
513514
}
514515

515-
isl::multi_union_pw_aff MappedScop::blockMappingSchedule(
516+
isl::MultiUnionPwAff<Domain, Block> MappedScop::blockMappingSchedule(
516517
const detail::ScheduleTree* tree) const {
517518
std::vector<mapping::MappingId> ids;
518519
for (size_t i = 0; i < numBlocks.view.size(); ++i) {
519520
ids.emplace_back(mapping::BlockId::makeId(i));
520521
}
521522
auto tupleId = isl::id(tree->ctx_, kGrid);
522-
return extractDomainToIds(scop_->scheduleRoot(), tree, ids, tupleId);
523+
auto schedule = extractDomainToIds(scop_->scheduleRoot(), tree, ids, tupleId);
524+
return isl::MultiUnionPwAff<Domain, Block>(schedule);
523525
}
524526

525527
Scop::SyncLevel MappedScop::findBestSync(

tc/core/polyhedral/cuda/mapped_scop.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "tc/core/cuda/cuda_mapping_options.h"
2525
#include "tc/core/polyhedral/cuda/mapping_types.h"
2626
#include "tc/core/polyhedral/cuda/memory_promotion_heuristic.h"
27+
#include "tc/core/polyhedral/domain_types.h"
2728
#include "tc/core/polyhedral/scop.h"
2829
#include "tc/core/tensor.h"
2930
#include "tc/external/isl.h"
@@ -197,14 +198,14 @@ class MappedScop {
197198
// to the thread identifiers, where all branches in "tree"
198199
// are assumed to have been mapped to thread identifiers.
199200
// The result lives in a space of the form block[x, ...].
200-
isl::multi_union_pw_aff threadMappingSchedule(
201+
isl::MultiUnionPwAff<Domain, Thread> threadMappingSchedule(
201202
const detail::ScheduleTree* tree) const;
202203

203204
// Extract a mapping from the domain elements active at "tree"
204205
// to the block identifiers, where all branches in "tree"
205206
// are assumed to have been mapped to block identifiers.
206207
// The result lives in a space of the form grid[x, ...].
207-
isl::multi_union_pw_aff blockMappingSchedule(
208+
isl::MultiUnionPwAff<Domain, Block> blockMappingSchedule(
208209
const detail::ScheduleTree* tree) const;
209210

210211
private:

tc/core/polyhedral/cuda/memory_promotion_heuristic.cc

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ std::vector<T> collectBranchMarkers(T root, T node) {
131131
return findThreadSpecificMarkers(node);
132132
}
133133

134+
struct FullSchedule;
135+
134136
/*
135137
* Transform schedule bands into a union_map.
136138
* Takes all partial schedules at leaves as MUPAs (without accounting for
@@ -139,7 +141,8 @@ std::vector<T> collectBranchMarkers(T root, T node) {
139141
* current leaves and transforms them into union maps.
140142
* Mapping filters are ignored.
141143
*/
142-
isl::union_map fullSchedule(const detail::ScheduleTree* root) {
144+
isl::UnionMap<Domain, FullSchedule> fullSchedule(
145+
const detail::ScheduleTree* root) {
143146
using namespace tc::polyhedral::detail;
144147

145148
if (!root->elemAs<ScheduleTreeElemDomain>()) {
@@ -182,7 +185,7 @@ isl::union_map fullSchedule(const detail::ScheduleTree* root) {
182185
throw promotion::PromotionLogicError(ss.str());
183186
}
184187
}
185-
return schedule;
188+
return isl::UnionMap<Domain, FullSchedule>(schedule);
186189
}
187190

188191
/*
@@ -263,7 +266,7 @@ bool promotionImprovesCoalescing(
263266
const detail::ScheduleTree* root,
264267
const detail::ScheduleTree* node,
265268
const TensorReferenceGroup& group,
266-
isl::union_map schedule) {
269+
isl::UnionMap<Domain, FullSchedule> schedule) {
267270
auto originalAccesses = group.originalAccesses();
268271

269272
auto markers = collectBranchMarkers(root, node);
@@ -313,6 +316,8 @@ isl::union_set collectMappingsTo(const Scop& scop) {
313316
return mapping;
314317
}
315318

319+
struct Unrolled;
320+
316321
/*
317322
* Check that only unrolled loops may appear in access subscripts.
318323
* Because the scoping point can be above a branching tree, descend into each
@@ -343,11 +348,12 @@ isl::union_set collectMappingsTo(const Scop& scop) {
343348
* different references may have different values, but all of them remain
344349
* independent of non-unrolled loop iterators.
345350
*/
351+
template <typename Outer>
346352
bool accessSubscriptsAreUnrolledLoops(
347353
const TensorReferenceGroup& group,
348354
const detail::ScheduleTree* root,
349355
const detail::ScheduleTree* scope,
350-
isl::multi_union_pw_aff outerSchedule) {
356+
isl::MultiUnionPwAff<Domain, Outer> outerSchedule) {
351357
using namespace detail;
352358

353359
auto nodes = ScheduleTree::collect(scope);
@@ -365,7 +371,7 @@ bool accessSubscriptsAreUnrolledLoops(
365371
auto subdomain = activeDomainPointsBelow(root, leaf);
366372

367373
auto unrolledDims = isl::union_pw_aff_list(leaf->ctx_, 1);
368-
for (auto node : ancestors) {
374+
for (const detail::ScheduleTree* node : ancestors) {
369375
auto band = node->elemAs<detail::ScheduleTreeElemBand>();
370376
if (!band) {
371377
continue;
@@ -383,7 +389,8 @@ bool accessSubscriptsAreUnrolledLoops(
383389

384390
auto space = isl::space(leaf->ctx_, 0, unrolledDims.n())
385391
.align_params(subdomain.get_space());
386-
auto unrolledDimsMupa = isl::multi_union_pw_aff(space, unrolledDims);
392+
auto unrolledDimsMupa =
393+
isl::MultiUnionPwAff<Domain, Unrolled>(space, unrolledDims);
387394

388395
// It is possible that no loops are unrolled, in which case
389396
// unrolledDimsMupa is zero-dimensional and needs an explicit domain
@@ -392,10 +399,11 @@ bool accessSubscriptsAreUnrolledLoops(
392399
unrolledDimsMupa.intersect_domain(group.originalAccesses().domain());
393400

394401
auto accesses = group.originalAccesses();
395-
auto schedule = outerSchedule.flat_range_product(unrolledDimsMupa);
396-
accesses = accesses.apply_domain(isl::union_map::from(schedule));
402+
auto schedule = outerSchedule.range_product(unrolledDimsMupa);
403+
auto scheduleMap = schedule.asUnionMap();
404+
auto scheduledAccesses = accesses.apply_domain(scheduleMap);
397405

398-
if (!accesses.is_single_valued()) {
406+
if (!scheduledAccesses.is_single_valued()) {
399407
return false;
400408
}
401409
}
@@ -415,23 +423,25 @@ bool accessSubscriptsAreUnrolledLoops(
415423
* thread associated to a given pair of tensor element and outer schedule
416424
* iteration.
417425
*/
426+
template <typename Outer>
418427
bool isPromotableToRegistersBelow(
419428
const TensorReferenceGroup& group,
420429
const detail::ScheduleTree* root,
421430
const detail::ScheduleTree* scope,
422-
isl::multi_union_pw_aff outer,
423-
isl::multi_union_pw_aff thread) {
431+
isl::MultiUnionPwAff<Domain, Outer> outer,
432+
isl::MultiUnionPwAff<Domain, Thread> thread) {
424433
if (!accessSubscriptsAreUnrolledLoops(
425-
group, root, scope, outer.flat_range_product(thread))) {
434+
group, root, scope, outer.range_product(thread))) {
426435
return false;
427436
}
428437

429438
auto originalAccesses = group.originalAccesses();
430-
auto map = isl::union_map::from(outer);
431-
map = map.range_product(originalAccesses);
432-
map = map.apply_domain(isl::union_map::from(thread));
439+
auto outerMap = isl::UnionMap<Domain, Outer>::from(outer);
440+
auto pair = outerMap.range_product(originalAccesses);
441+
auto threadMap = isl::UnionMap<Domain, Thread>::from(thread);
442+
auto threadToPair = pair.apply_domain(threadMap);
433443

434-
return map.is_injective();
444+
return threadToPair.is_injective();
435445
}
436446

437447
/*
@@ -654,15 +664,15 @@ void promoteToRegistersBelow(MappedScop& mscop, detail::ScheduleTree* scope) {
654664
auto blockSchedule = mscop.blockMappingSchedule(mscop.schedule());
655665

656666
// Pure affine schedule without (mapping) filters.
657-
auto partialSchedMupa = partialScheduleMupa(root, scope);
667+
auto partialSchedMupa = partialScheduleMupa<Scope>(root, scope);
658668
// Schedule with block mapping filter.
659669
auto partialSched =
660670
isl::union_map::from(partialSchedMupa).intersect_domain(blockMapping);
661671
// The following promotion validity and profitability checks need to be
662672
// performed with respect to the block mapping, so append the block schedule.
663673
// If the partial schedule contains it already, it will just end up with
664674
// identical dimensions without affecting the result of the checks.
665-
partialSchedMupa = partialSchedMupa.flat_range_product(blockSchedule);
675+
auto partialSchedBlockMupa = partialSchedMupa.range_product(blockSchedule);
666676

667677
for (auto& tensorGroups : groupMap) {
668678
auto tensorId = tensorGroups.first;
@@ -676,11 +686,11 @@ void promoteToRegistersBelow(MappedScop& mscop, detail::ScheduleTree* scope) {
676686
continue;
677687
}
678688
if (!isPromotableToRegistersBelow(
679-
*group, root, scope, partialSchedMupa, threadSchedule)) {
689+
*group, root, scope, partialSchedBlockMupa, threadSchedule)) {
680690
continue;
681691
}
682692
// Check reuse within threads.
683-
auto schedule = partialSchedMupa.flat_range_product(threadSchedule);
693+
auto schedule = partialSchedBlockMupa.flat_range_product(threadSchedule);
684694
if (!hasReuseWithin(*group, schedule)) {
685695
continue;
686696
}

tc/core/polyhedral/domain_types.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
namespace tc {
2+
namespace polyhedral {
3+
4+
struct Domain;
5+
struct Scope;
6+
struct Tensor;
7+
struct Thread;
8+
9+
} // namespace polyhedral
10+
} // namespace tc

tc/core/polyhedral/memory_promotion.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#include <iostream>
1919

20+
#include "tc/core/polyhedral/domain_types.h"
2021
#include "tc/core/polyhedral/schedule_tree.h"
2122
#include "tc/core/polyhedral/scop.h"
2223
#include "tc/external/isl.h"
@@ -137,8 +138,9 @@ class TensorReferenceGroup {
137138
// range spaces.
138139
isl::union_map originalWrites() const;
139140
isl::union_map originalReads() const;
140-
isl::union_map originalAccesses() const {
141-
return originalWrites().unite(originalReads());
141+
isl::UnionMap<Domain, Tensor> originalAccesses() const {
142+
auto accesses = originalWrites().unite(originalReads());
143+
return isl::UnionMap<Domain, Tensor>(accesses);
142144
}
143145

144146
// Rectangular overapproximation of the set of tensor elements accessed below

tc/core/polyhedral/schedule_transforms.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,13 @@ isl::union_map prefixSchedule(
9898
return partialScheduleImpl(root, node, false);
9999
}
100100

101+
namespace detail {
101102
isl::union_map partialSchedule(
102103
const ScheduleTree* root,
103104
const ScheduleTree* node) {
104105
return partialScheduleImpl(root, node, true);
105106
}
107+
} // namespace detail
106108

107109
namespace {
108110
/*
@@ -131,7 +133,7 @@ isl::union_set applyMapping(isl::union_set domain, const ScheduleTree* node) {
131133
// Domain elements are introduced by the root domain node. Some nodes
132134
// refine this set of elements based on "filter". Extension nodes
133135
// are considered to introduce additional domain points.
134-
isl::union_set collectDomain(
136+
isl::UnionSet<Domain> collectDomain(
135137
const ScheduleTree* root,
136138
const vector<const ScheduleTree*>& nodes,
137139
isl::union_set (*filter)(isl::union_set domain, const ScheduleTree* node)) {
@@ -150,12 +152,12 @@ isl::union_set collectDomain(
150152
domain = domain.unite(parentSchedule.range().apply(extension));
151153
}
152154
}
153-
return domain;
155+
return isl::UnionSet<Domain>(domain);
154156
}
155157

156158
// Get the set of domain elements that are active below
157159
// the given branch of nodes.
158-
isl::union_set activeDomainPointsHelper(
160+
isl::UnionSet<Domain> activeDomainPointsHelper(
159161
const ScheduleTree* root,
160162
const vector<const ScheduleTree*>& nodes) {
161163
return collectDomain(root, nodes, &applyFilter);
@@ -169,7 +171,7 @@ isl::union_set prefixMappingFilter(
169171
return collectDomain(root, node->ancestors(root), &applyMapping);
170172
}
171173

172-
isl::union_set activeDomainPoints(
174+
isl::UnionSet<Domain> activeDomainPoints(
173175
const ScheduleTree* root,
174176
const ScheduleTree* node) {
175177
return activeDomainPointsHelper(root, node->ancestors(root));
@@ -503,13 +505,15 @@ isl::multi_union_pw_aff prefixScheduleMupa(
503505
return infixScheduleMupa(root, root, tree);
504506
}
505507

508+
namespace detail {
506509
isl::multi_union_pw_aff partialScheduleMupa(
507510
const detail::ScheduleTree* root,
508511
const detail::ScheduleTree* tree) {
509512
auto prefix = prefixScheduleMupa(root, tree);
510513
auto band = tree->elemAs<ScheduleTreeElemBand>();
511514
return band ? prefix.flat_range_product(band->mupa_) : prefix;
512515
}
516+
} // namespace detail
513517

514518
void updateTopLevelContext(detail::ScheduleTree* root, isl::set context) {
515519
if (!matchOne(tc::polyhedral::domain(tc::polyhedral::context(any())), root)) {

tc/core/polyhedral/schedule_transforms.h

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include <unordered_set>
2222
#include <vector>
2323

24+
#include "tc/core/polyhedral/domain_types.h"
2425
#include "tc/core/polyhedral/functional.h"
2526
#include "tc/core/polyhedral/mapping_types.h"
2627
#include "tc/core/polyhedral/options.h"
@@ -272,9 +273,18 @@ isl::union_map extendSchedule(
272273

273274
// Get the partial schedule defined by ancestors of the given node and the node
274275
// itself.
276+
namespace detail {
275277
isl::union_map partialSchedule(
276278
const detail::ScheduleTree* root,
277279
const detail::ScheduleTree* node);
280+
}
281+
template <typename Schedule>
282+
isl::UnionMap<Domain, Schedule> partialSchedule(
283+
const detail::ScheduleTree* root,
284+
const detail::ScheduleTree* tree) {
285+
auto partial = detail::partialSchedule(root, tree);
286+
return isl::UnionMap<Domain, Schedule>(partial);
287+
}
278288

279289
// Return the schedule defined by the ancestors of the given node.
280290
isl::union_map prefixSchedule(
@@ -306,15 +316,24 @@ isl::multi_union_pw_aff prefixScheduleMupa(
306316
// including that of the node itself.
307317
// Note that this function does not take into account
308318
// any intermediate filter nodes.
319+
namespace detail {
309320
isl::multi_union_pw_aff partialScheduleMupa(
310321
const detail::ScheduleTree* root,
311322
const detail::ScheduleTree* tree);
323+
}
324+
template <typename Schedule>
325+
isl::MultiUnionPwAff<Domain, Schedule> partialScheduleMupa(
326+
const detail::ScheduleTree* root,
327+
const detail::ScheduleTree* tree) {
328+
auto partial = detail::partialScheduleMupa(root, tree);
329+
return isl::MultiUnionPwAff<Domain, Schedule>(partial);
330+
}
312331

313332
// Get the set of domain points active at the given node. A domain
314333
// point is active if it was not filtered away on the path from the
315334
// root to the node. The root must be a domain element, otherwise no
316335
// elements would be considered active.
317-
isl::union_set activeDomainPoints(
336+
isl::UnionSet<Domain> activeDomainPoints(
318337
const detail::ScheduleTree* root,
319338
const detail::ScheduleTree* node);
320339

0 commit comments

Comments
 (0)