43
43
44
44
#include < glog/logging.h>
45
45
46
+ using tc::polyhedral::detail::ScheduleTree;
47
+ using tc::polyhedral::detail::ScheduleTreeBand;
48
+ using tc::polyhedral::detail::ScheduleTreeContext;
49
+ using tc::polyhedral::detail::ScheduleTreeSequence;
50
+
46
51
namespace tc {
47
52
namespace polyhedral {
53
+ namespace cuda {
48
54
49
55
namespace {
50
56
@@ -77,7 +83,7 @@ static void checkMinimum(isl::union_set domain, isl::union_pw_aff_list list) {
77
83
template <typename ExceptionType>
78
84
inline void throwIfHasPattern (
79
85
ScheduleTreeMatcher matcher,
80
- const detail:: ScheduleTree* root) {
86
+ const ScheduleTree* root) {
81
87
auto candidates = match (matcher, root);
82
88
if (candidates.size () > 0 ) {
83
89
std::stringstream ss;
@@ -87,7 +93,7 @@ inline void throwIfHasPattern(
87
93
}
88
94
}
89
95
90
- void validate (const detail:: ScheduleTree* root) {
96
+ void validate (const ScheduleTree* root) {
91
97
throwIfHasPattern<EmptyFilterException>(
92
98
filter (
93
99
[](isl::union_set uset) { return !uset || uset.is_empty (); }, any ()),
@@ -98,7 +104,7 @@ void validate(const detail::ScheduleTree* root) {
98
104
root);
99
105
}
100
106
101
- bool anyNonCoincidentMember (const detail:: ScheduleTreeBand* band) {
107
+ bool anyNonCoincidentMember (const ScheduleTreeBand* band) {
102
108
return band->nOuterCoincident () < band->nMember ();
103
109
}
104
110
@@ -129,9 +135,7 @@ const CudaDim& mappingSize<mapping::ThreadId>(const MappedScop* mscop) {
129
135
// Return a pointer to the updated node (below the inserted filter)
130
136
// for call chaining purposes.
131
137
template <typename MappingTypeId>
132
- detail::ScheduleTree* MappedScop::map (
133
- detail::ScheduleTree* tree,
134
- isl::union_pw_aff_list list) {
138
+ ScheduleTree* MappedScop::map (ScheduleTree* tree, isl::union_pw_aff_list list) {
135
139
size_t nToMap = list.size ();
136
140
const auto & extent = mappingSize<MappingTypeId>(this ).view ;
137
141
TC_CHECK_LE (nToMap, extent.size ()) << " dimension overflow" ;
@@ -160,16 +164,14 @@ detail::ScheduleTree* MappedScop::map(
160
164
161
165
checkMinimum (domain, affList);
162
166
163
- auto mapping = detail:: ScheduleTree::makeMapping (idList, affList);
167
+ auto mapping = ScheduleTree::makeMapping (idList, affList);
164
168
tree = insertNodeAbove (root, tree, std::move (mapping))->child ({0 });
165
169
166
170
return tree;
167
171
}
168
172
169
- detail::ScheduleTree* MappedScop::mapBlocksForward (
170
- detail::ScheduleTree* band,
171
- size_t nToMap) {
172
- auto bandNode = band->as <detail::ScheduleTreeBand>();
173
+ ScheduleTree* MappedScop::mapBlocksForward (ScheduleTree* band, size_t nToMap) {
174
+ auto bandNode = band->as <ScheduleTreeBand>();
173
175
TC_CHECK (bandNode) << " expected a band, got " << *band;
174
176
175
177
auto list = bandNode->mupa_ .get_union_pw_aff_list ();
@@ -180,7 +182,7 @@ detail::ScheduleTree* MappedScop::mapBlocksForward(
180
182
// Uses as many blockSizes elements as outer coincident dimensions in the
181
183
// outermost band
182
184
void MappedScop::mapToBlocksAndScaleBand (
183
- detail:: ScheduleTree* band,
185
+ ScheduleTree* band,
184
186
std::vector<size_t > tileSizes) {
185
187
using namespace tc ::polyhedral::detail;
186
188
@@ -205,16 +207,13 @@ namespace {
205
207
* the remaining thread identifiers starting at "begin" to zero.
206
208
* Add a marker underneath that marks the subtree that is thread specific.
207
209
*/
208
- void fixThreadsBelow (
209
- MappedScop& mscop,
210
- detail::ScheduleTree* tree,
211
- size_t begin) {
210
+ void fixThreadsBelow (MappedScop& mscop, ScheduleTree* tree, size_t begin) {
212
211
size_t end = mscop.numThreads .view .size ();
213
212
if (begin == end) {
214
213
return ;
215
214
}
216
215
217
- auto band = detail:: ScheduleTree::makeEmptyBand (mscop.scop ().scheduleRoot ());
216
+ auto band = ScheduleTree::makeEmptyBand (mscop.scop ().scheduleRoot ());
218
217
auto bandTree = insertNodeBelow (tree, std::move (band));
219
218
mscop.mapThreadsBackward (bandTree);
220
219
}
@@ -226,10 +225,7 @@ void fixThreadsBelow(
226
225
* Anything that depends on an update statement is ordered after
227
226
* the update statements. Anything else is ordered before.
228
227
*/
229
- bool separatedOut (
230
- Scop& scop,
231
- detail::ScheduleTree* tree,
232
- isl::union_set updates) {
228
+ bool separatedOut (Scop& scop, ScheduleTree* tree, isl::union_set updates) {
233
229
auto domain = activeDomainPoints (scop.scheduleRoot (), tree);
234
230
auto other = domain.subtract (updates);
235
231
if (other.is_empty ()) {
@@ -254,7 +250,7 @@ bool separatedOut(
254
250
255
251
} // namespace
256
252
257
- bool MappedScop::detectReductions (detail:: ScheduleTree* tree) {
253
+ bool MappedScop::detectReductions (ScheduleTree* tree) {
258
254
// Do not bother with reductions if block is of size 1 in the x direction.
259
255
if (numThreads.view .size () == 0 || numThreads.view [0 ] == 1 ) {
260
256
return false ;
@@ -264,7 +260,7 @@ bool MappedScop::detectReductions(detail::ScheduleTree* tree) {
264
260
for (auto c : tree->children ()) {
265
261
found |= detectReductions (c);
266
262
}
267
- auto band = tree->as <detail:: ScheduleTreeBand>();
263
+ auto band = tree->as <ScheduleTreeBand>();
268
264
// Nested reductions are not currently supported.
269
265
if (!band || found) {
270
266
return found;
@@ -314,7 +310,7 @@ bool MappedScop::detectReductions(detail::ScheduleTree* tree) {
314
310
return true ;
315
311
}
316
312
317
- bool MappedScop::needReductionSeparation (const detail:: ScheduleTree* st) {
313
+ bool MappedScop::needReductionSeparation (const ScheduleTree* st) {
318
314
if (reductionBandUpdates_.count (st) != 1 ) {
319
315
return false ;
320
316
}
@@ -323,9 +319,9 @@ bool MappedScop::needReductionSeparation(const detail::ScheduleTree* st) {
323
319
}
324
320
325
321
isl::multi_union_pw_aff MappedScop::reductionMapSchedule (
326
- const detail:: ScheduleTree* st) {
322
+ const ScheduleTree* st) {
327
323
TC_CHECK (reductionBandUpdates_.count (st) == 1 );
328
- auto reductionBand = st->as <detail:: ScheduleTreeBand>();
324
+ auto reductionBand = st->as <ScheduleTreeBand>();
329
325
TC_CHECK (reductionBand);
330
326
331
327
auto nMember = reductionBand->nMember ();
@@ -337,7 +333,7 @@ isl::multi_union_pw_aff MappedScop::reductionMapSchedule(
337
333
return reductionBand->memberRange (first, nMappedThreads);
338
334
}
339
335
340
- detail:: ScheduleTree* MappedScop::separateReduction (detail:: ScheduleTree* st) {
336
+ ScheduleTree* MappedScop::separateReduction (ScheduleTree* st) {
341
337
auto reduction = st;
342
338
// This function either separates full blocks (if needed) or
343
339
// disables the reduction handling.
@@ -386,23 +382,22 @@ detail::ScheduleTree* MappedScop::separateReduction(detail::ScheduleTree* st) {
386
382
return st->ancestor (root, 2 );
387
383
}
388
384
389
- detail::ScheduleTree* MappedScop::mapThreadsBackward (
390
- detail::ScheduleTree* band) {
391
- auto bandNode = band->as <detail::ScheduleTreeBand>();
385
+ ScheduleTree* MappedScop::mapThreadsBackward (ScheduleTree* band) {
386
+ auto bandNode = band->as <ScheduleTreeBand>();
392
387
TC_CHECK (bandNode);
393
388
auto nMember = bandNode->nMember ();
394
389
auto nToMap = std::min (nMember, numThreads.view .size ());
395
390
TC_CHECK_LE (nToMap, 3u ) << " mapping to too many threads" ;
396
391
397
392
auto ctx = band->ctx_ ;
398
- insertNodeBelow (band, detail:: ScheduleTree::makeThreadSpecificMarker (ctx));
393
+ insertNodeBelow (band, ScheduleTree::makeThreadSpecificMarker (ctx));
399
394
400
395
auto list = bandNode->mupa_ .get_union_pw_aff_list ().reverse ();
401
396
list = list.drop (nToMap, list.size () - nToMap);
402
397
return map<mapping::ThreadId>(band, list);
403
398
}
404
399
405
- size_t MappedScop::mapToThreads (detail:: ScheduleTree* band) {
400
+ size_t MappedScop::mapToThreads (ScheduleTree* band) {
406
401
using namespace tc ::polyhedral::detail;
407
402
408
403
auto bandNode = band->as <ScheduleTreeBand>();
@@ -473,17 +468,15 @@ namespace {
473
468
* That is, assuming "st" is a sequence node, does the last child
474
469
* need to be protected from the next iteration of the first child?
475
470
*/
476
- bool hasOuterSequentialMember (
477
- const detail::ScheduleTree* root,
478
- detail::ScheduleTree* st) {
471
+ bool hasOuterSequentialMember (const ScheduleTree* root, ScheduleTree* st) {
479
472
auto ancestors = st->ancestors (root);
480
473
std::reverse (ancestors.begin (), ancestors.end ());
481
474
for (auto a : ancestors) {
482
- auto band = a->as <detail:: ScheduleTreeBand>();
475
+ auto band = a->as <ScheduleTreeBand>();
483
476
if (band && band->nMember () > band->nOuterCoincident ()) {
484
477
return true ;
485
478
}
486
- if (a->as <detail:: ScheduleTreeSequence>()) {
479
+ if (a->as <ScheduleTreeSequence>()) {
487
480
return false ;
488
481
}
489
482
}
@@ -532,7 +525,7 @@ isl::multi_aff constructThreadToWarp(
532
525
} // namespace
533
526
534
527
isl::multi_union_pw_aff MappedScop::threadMappingSchedule (
535
- const detail:: ScheduleTree* tree) const {
528
+ const ScheduleTree* tree) const {
536
529
std::vector<mapping::MappingId> ids;
537
530
for (size_t i = 0 ; i < numThreads.view .size (); ++i) {
538
531
ids.emplace_back (mapping::ThreadId::makeId (i));
@@ -542,7 +535,7 @@ isl::multi_union_pw_aff MappedScop::threadMappingSchedule(
542
535
}
543
536
544
537
isl::multi_union_pw_aff MappedScop::blockMappingSchedule (
545
- const detail:: ScheduleTree* tree) const {
538
+ const ScheduleTree* tree) const {
546
539
std::vector<mapping::MappingId> ids;
547
540
for (size_t i = 0 ; i < numBlocks.view .size (); ++i) {
548
541
ids.emplace_back (mapping::BlockId::makeId (i));
@@ -552,8 +545,8 @@ isl::multi_union_pw_aff MappedScop::blockMappingSchedule(
552
545
}
553
546
554
547
Scop::SyncLevel MappedScop::findBestSync (
555
- detail:: ScheduleTree* st1,
556
- detail:: ScheduleTree* st2,
548
+ ScheduleTree* st1,
549
+ ScheduleTree* st2,
557
550
isl::multi_union_pw_aff domainToThread,
558
551
isl::multi_union_pw_aff domainToWarp) {
559
552
// Active points in the two schedule trees
@@ -571,7 +564,7 @@ Scop::SyncLevel MappedScop::findBestSync(
571
564
572
565
TC_CHECK_LE (1u , scop_->scheduleRoot ()->children ().size ());
573
566
auto contextSt = scop_->scheduleRoot ()->children ()[0 ];
574
- auto contextElem = contextSt->as <detail:: ScheduleTreeContext>();
567
+ auto contextElem = contextSt->as <ScheduleTreeContext>();
575
568
TC_CHECK (nullptr != contextElem);
576
569
dependences = dependences.intersect_params (contextElem->context_ );
577
570
@@ -733,8 +726,8 @@ std::vector<std::pair<int, int>> MappedScop::findBestSyncConfigInSeq(
733
726
return solutionWithBestBegining;
734
727
}
735
728
736
- void MappedScop::insertBestSyncInSeq (detail:: ScheduleTree* seq) {
737
- TC_CHECK (seq->as <detail:: ScheduleTreeSequence>());
729
+ void MappedScop::insertBestSyncInSeq (ScheduleTree* seq) {
730
+ TC_CHECK (seq->as <ScheduleTreeSequence>());
738
731
739
732
auto children = seq->children ();
740
733
auto nChildren = children.size ();
@@ -796,7 +789,7 @@ void MappedScop::insertBestSyncInSeq(detail::ScheduleTree* seq) {
796
789
// the next iteration of the first child if there may be such
797
790
// a next iteration that is not already covered by synchronization
798
791
// on an outer node.
799
- size_t MappedScop::mapInnermostBandsToThreads (detail:: ScheduleTree* st) {
792
+ size_t MappedScop::mapInnermostBandsToThreads (ScheduleTree* st) {
800
793
if (needReductionSeparation (st)) {
801
794
st = separateReduction (st);
802
795
}
@@ -808,7 +801,7 @@ size_t MappedScop::mapInnermostBandsToThreads(detail::ScheduleTree* st) {
808
801
}
809
802
auto n = nChildren > 0 ? *std::max_element (nInner.begin (), nInner.end ()) : 0 ;
810
803
if (nChildren > 1 ) {
811
- auto needSync = st->as <detail:: ScheduleTreeSequence>() && n > 0 ;
804
+ auto needSync = st->as <ScheduleTreeSequence>() && n > 0 ;
812
805
if (n > 0 ) {
813
806
for (size_t i = 0 ; i < nChildren; ++i) {
814
807
fixThreadsBelow (*this , children[i], nInner[i]);
@@ -819,7 +812,7 @@ size_t MappedScop::mapInnermostBandsToThreads(detail::ScheduleTree* st) {
819
812
}
820
813
}
821
814
822
- if (auto band = st->as <detail:: ScheduleTreeBand>()) {
815
+ if (auto band = st->as <ScheduleTreeBand>()) {
823
816
if (n == 0 ) {
824
817
// If children were not mapped to threads, the current band can be mapped.
825
818
// First, map the coincidence and reduction dimension to threads.
@@ -957,8 +950,8 @@ std::tuple<std::string, tc::Grid, tc::Block> MappedScop::codegen(
957
950
// Split out a single reduction tile (in the directions other than
958
951
// the reduction) and insert reduction synchronizations outside this tile.
959
952
// Return a pointer to the split off tile.
960
- detail:: ScheduleTree* MappedScop::splitOutReductionTileAndInsertSyncs (
961
- detail:: ScheduleTree* band) {
953
+ ScheduleTree* MappedScop::splitOutReductionTileAndInsertSyncs (
954
+ ScheduleTree* band) {
962
955
using namespace polyhedral ::detail;
963
956
size_t n = numThreads.view .size ();
964
957
@@ -1103,5 +1096,6 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
1103
1096
return mappedScop;
1104
1097
}
1105
1098
1099
+ } // namespace cuda
1106
1100
} // namespace polyhedral
1107
1101
} // namespace tc
0 commit comments