@@ -108,22 +108,69 @@ const CudaDim& mappingSize<mapping::ThreadId>(const MappedScop* mscop) {
 }
 } // namespace

+// Map the elements in "list" to successive block or thread identifiers,
+// with the first element mapped to identifier X. The extents are obtained
+// from the initial elements of numBlocks or numThreads. The identifiers
+// must not already be present in the space of the partial schedules in
+// "list", and the extents must be non-zero. The mapping corresponds to
+// inserting a filter node with the condition 'list % extent = ids'.
+// The mapping is inserted above "tree".
+//
+// Returns a pointer to the updated node (below the inserted filter)
+// for call chaining purposes.
 template <typename MappingTypeId>
-void MappedScop::mapRemaining(detail::ScheduleTree* tree, size_t nMapped) {
-  size_t nToMap = mappingSize<MappingTypeId>(this).view.size();
-  if (nMapped >= nToMap) {
-    return;
+detail::ScheduleTree* MappedScop::map(
+    detail::ScheduleTree* tree,
+    isl::union_pw_aff_list list) {
+  size_t nToMap = list.n();
+  const auto& extent = mappingSize<MappingTypeId>(this).view;
+  CHECK_LE(nToMap, extent.size()) << "dimension overflow";
+
+  auto root = scop_->scheduleRoot();
+  auto domain = activeDomainPoints(root, tree).universe();
+  auto filter = domain;
+
+  std::unordered_set<MappingTypeId, typename MappingTypeId::Hash> idSet;
+  for (size_t i = 0; i < nToMap; ++i) {
+    auto id = MappingTypeId::makeId(i);
+    auto upa = list.get(i);
+    // Introduce the "mapping" parameter after checking it is not already
+    // present in the schedule space.
+    CHECK(not upa.involves_param(id));
+    CHECK_NE(extent[i], 0u) << "NYI: mapping to 0";
+
+    // Create mapping filter by equating the newly introduced
+    // parameter ids[i] to the "i"-th affine function modulo its extent.
+    upa = upa.mod_val(isl::val(tree->ctx_, extent[i]));
+    upa = upa.sub(isl::union_pw_aff::param_on_domain(domain, id));
+    filter = filter.intersect(upa.zero_union_set());
+
+    idSet.emplace(id);
   }

-  std::unordered_set<MappingTypeId, typename MappingTypeId::Hash> ids;
-  for (size_t i = nMapped; i < nToMap; ++i) {
-    ids.insert(MappingTypeId::makeId(i));
+  std::unordered_set<MappingTypeId, typename MappingTypeId::Hash> unmapped;
+  for (size_t i = nToMap; i < extent.size(); ++i) {
+    auto id = MappingTypeId::makeId(i);
+    unmapped.emplace(id);
+    idSet.emplace(id);
   }
-  auto root = scop_->scheduleRoot();
-  auto domain = activeDomainPoints(root, tree);
-  auto filter = makeFixRemainingZeroFilter(domain, ids);
-  auto mapping = detail::ScheduleTree::makeMappingFilter(filter, ids);
-  insertNodeAbove(root, tree, std::move(mapping));
+  filter = filter.intersect(makeFixRemainingZeroFilter(domain, unmapped));
+
+  auto mapping = detail::ScheduleTree::makeMappingFilter(filter, idSet);
+  tree = insertNodeAbove(root, tree, std::move(mapping))->child({0});
+
+  return tree;
+}
+
+detail::ScheduleTree* MappedScop::mapBlocksForward(
+    detail::ScheduleTree* band,
+    size_t nToMap) {
+  auto bandNode = band->elemAs<detail::ScheduleTreeElemBand>();
+  CHECK(bandNode) << "expected a band, got " << *band;
+
+  auto list = bandNode->mupa_.get_union_pw_aff_list();
+  list = list.drop(nToMap, list.n() - nToMap);
+  return map<mapping::BlockId>(band, list);
 }

 // Uses as many blockSizes elements as outer coincident dimensions in the
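The condition 'list % extent = ids' in the comment above is easiest to see on concrete numbers. Below is a minimal standalone sketch, not part of the commit: plain integers stand in for the isl affine functions in "list", and the extents play the role of numBlocks; it shows which block a given schedule point is assigned to under the modulo mapping.

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical values of list[0] and list[1] at one schedule point.
  std::vector<long> schedule = {7, 130};
  // Hypothetical grid extents (numBlocks), e.g. gridDim.x = 4, gridDim.y = 64.
  std::vector<long> extent = {4, 64};
  // The inserted filter keeps the point iff schedule[i] % extent[i] == ids[i],
  // so the point runs on the block with these coordinates:
  for (std::size_t i = 0; i < schedule.size(); ++i) {
    std::cout << "ids[" << i << "] = " << schedule[i] % extent[i] << "\n";
  }
  // Prints ids[0] = 3 and ids[1] = 2.
}
```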
@@ -142,10 +189,7 @@ void MappedScop::mapToBlocksAndScaleBand(
   // and no more than block dimensions to be mapped
   nBlocksToMap = std::min(nBlocksToMap, numBlocks.view.size());

-  for (size_t i = 0; i < nBlocksToMap; ++i) {
-    band = map(band, i, mapping::BlockId::makeId(i));
-  }
-  mapRemaining<mapping::BlockId>(band, nBlocksToMap);
+  mapBlocksForward(band, nBlocksToMap);
   bandScale(band, tileSizes);
 }

@@ -166,10 +210,7 @@ void fixThreadsBelow(

   auto band = detail::ScheduleTree::makeEmptyBand(mscop.scop().scheduleRoot());
   auto bandTree = insertNodeBelow(tree, std::move(band));
-  auto ctx = tree->ctx_;
-  insertNodeBelow(
-      bandTree, detail::ScheduleTree::makeThreadSpecificMarker(ctx));
-  mscop.mapRemaining<mapping::ThreadId>(bandTree, begin);
+  mscop.mapThreadsBackward(bandTree);
 }

 bool MappedScop::detectReductions(detail::ScheduleTree* tree) {
@@ -305,6 +346,22 @@ detail::ScheduleTree* MappedScop::separateReduction(detail::ScheduleTree* st) {
   return st->ancestor(root, 2);
 }

+detail::ScheduleTree* MappedScop::mapThreadsBackward(
+    detail::ScheduleTree* band) {
+  auto bandNode = band->elemAs<detail::ScheduleTreeElemBand>();
+  CHECK(bandNode);
+  auto nMember = bandNode->nMember();
+  auto nToMap = std::min(nMember, numThreads.view.size());
+  CHECK_LE(nToMap, 3) << "mapping to too many threads";
+
+  auto ctx = band->ctx_;
+  insertNodeBelow(band, detail::ScheduleTree::makeThreadSpecificMarker(ctx));
+
+  auto list = bandNode->mupa_.get_union_pw_aff_list().reverse();
+  list = list.drop(nToMap, list.n() - nToMap);
+  return map<mapping::ThreadId>(band, list);
+}
+
 size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
   using namespace tc::polyhedral::detail;

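mapThreadsBackward above maps band members to threads starting from the innermost member and thread x; the reverse() followed by drop() is what realizes that order. Here is a minimal sketch of the reverse-then-truncate step, using a plain vector in place of isl::union_pw_aff_list (member names are illustrative):

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

int main() {
  // Band members listed outermost to innermost.
  std::vector<std::string> members = {"d0", "d1", "d2", "d3"};
  std::size_t nToMap = 3;  // at most three thread identifiers (x, y, z)
  // reverse(): innermost member first, mirroring list.reverse().
  std::reverse(members.begin(), members.end());
  // drop(nToMap, n - nToMap): keep only the first nToMap elements.
  members.resize(nToMap);
  // Thread x <- d3 (innermost), y <- d2, z <- d1; d0 stays unmapped.
  assert(members == (std::vector<std::string>{"d3", "d2", "d1"}));
}
```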
@@ -355,20 +412,9 @@ size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
     bandSplit(scop_->scheduleRoot(), band, nMappedThreads);
   }

-  auto ctx = band->ctx_;
-  insertNodeBelow(band, detail::ScheduleTree::makeThreadSpecificMarker(ctx));
-
   CHECK_GT(nMappedThreads, 0) << "not mapping to threads";
-  CHECK_LE(nMappedThreads, 3) << "mapping to too many threads";

-  // Map the coincident dimensions to threads starting from the innermost and
-  // from thread x.
-  for (size_t i = 0; i < nMappedThreads; ++i) {
-    auto id = mapping::ThreadId::makeId(i);
-    auto dim = nMappedThreads - 1 - i;
-    band = map(band, dim, id);
-  }
-  mapRemaining<mapping::ThreadId>(band, nMappedThreads);
+  mapThreadsBackward(band);

   if (isReduction) {
     splitOutReductionAndInsertSyncs(band, nMappedThreads - 1);
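Putting the pieces together on assumed sizes: for a band with two coincident members (d0 outer, d1 inner) and numThreads = {32, 8}, mapThreadsBackward reverses the members to (d1, d0), so the inserted filter is d1 % 32 = t0 and d0 % 8 = t1, with t2 fixed to zero by makeFixRemainingZeroFilter. A tiny check of that arithmetic for one statement instance, with values chosen purely for illustration:

```cpp
#include <cassert>

int main() {
  // Statement instance at (d0, d1) = (9, 70), numThreads = {32, 8}.
  long d0 = 9, d1 = 70;
  long t0 = d1 % 32;  // thread x <- innermost member d1: 70 % 32 == 6
  long t1 = d0 % 8;   // thread y <- next member d0:       9 % 8 == 1
  long t2 = 0;        // remaining identifier fixed to zero
  assert(t0 == 6 && t1 == 1 && t2 == 0);
}
```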