@@ -157,6 +157,35 @@ std::vector<T> findThreadSpecificMarkers(T root) {
       root, ScheduleTreeType::ThreadSpecificMarker);
 }
 
+/*
+ * Return the thread specific markers in the tree rooted at "root"
+ * that are relevant for "node".
+ *
+ * Every branch in the tree has exactly one thread marker.
+ * If "node" appears underneath a thread marker, then return
+ * that single thread marker.
+ * Otherwise, return the (possibly multiple) thread markers
+ * in the subtree rooted at "node".
+ */
+template <typename T>
+std::vector<T> collectBranchMarkers(T root, T node) {
+  using namespace detail;
+  static_assert(
+      std::is_convertible<T, const ScheduleTree*>::value,
+      "expecting ScheduleTree");
+
+  auto filterMarker = [](T tree) {
+    return tree->type_ == ScheduleTreeType::ThreadSpecificMarker;
+  };
+
+  auto ancestors = node->ancestors(root);
+  ancestors = functional::Filter(filterMarker, ancestors);
+  if (ancestors.size() > 0) {
+    return ancestors;
+  }
+  return findThreadSpecificMarkers(node);
+}
+
 /*
  * Transform schedule bands into a union_map.
  * Takes all partial schedules at leaves as MUPAs (without accounting for
@@ -277,27 +306,6 @@ isl::map makeNextElementMap(isl::space setSpace, unsigned dim) {
   return isl::map(identityMA);
 }
 
-// Obtain the depth of the schedule dimension that was mapped to threadIdx.x
-// for the domain elements identified by "s". Assumes the depth is the same
-// for all these elements.
-size_t computeThreadIdxXScheduleDepth(
-    const ThreadIdxXScheduleDepthState& threadIdxXScheduleDepthState,
-    isl::union_set s) {
-  std::unordered_set<size_t> depths;
-  for (auto p : threadIdxXScheduleDepthState) {
-    if (!p.first.intersect(s).is_empty()) {
-      depths.insert(p.second);
-    }
-  }
-  if (depths.size() != 1) {
-    std::stringstream ss;
-    ss << "threadIdx.x depth " << (depths.size() == 0 ? "unknown" : "diverged")
-       << " for " << s;
-    throw promotion::PromotionLogicError(ss.str());
-  }
-  return *depths.begin();
-}
-
 /*
  * Return the outermost thread mapping filter among the ancestors of "node",
  * assuming that there is at least one.
@@ -318,42 +326,49 @@ const detail::ScheduleTree* findThreadMappingAncestor(
  *
  * If the reference group is not already accessed in a coalesced way,
  * then the group should be promoted.
+ * If a branch is mapped to a single thread, then the accesses
+ * in that branch are not considered to contribute to the usefulness
+ * of promoting.
+ *
  * The check for coalesced accesses is performed as follows.
  * Check if incrementing the schedule dimension mapped to
  * Thread::x results in the last tensor index being incremented as well.
  * Since accesses in the group may belong to different statements, which may
- * have different loops mapped to Thread::x, perform the check for each basic
- * map in the union of access maps taking into account which dimension is
- * mapped for a particular statement (domain of the basic map). The group is
+ * have different loops mapped to Thread::x, perform the check for each thread
+ * mapping on the statements active at "node" (either a single ancestor,
+ * or one or more descendants).
+ * The iteration over the spaces is used to handle the case where
+ * one of the subbranches does not access the tensor and
+ * the scheduled accesses are empty. The group is
  * accessed in a coalesced way if all references in this group are accessed in
  * a coalesced way.
  */
 bool promotionImprovesCoalescing(
-    const ThreadIdxXScheduleDepthState& threadIdxXScheduleDepthState,
+    const detail::ScheduleTree* root,
+    const detail::ScheduleTree* node,
     const TensorReferenceGroup& group,
-    isl::union_map schedule,
-    isl::union_set activePoints) {
+    isl::union_map schedule) {
   auto originalAccesses = group.originalAccesses();
 
-  for (auto accessMap : isl::UnionAsVector<isl::union_map>(originalAccesses)) {
-    for (auto access : accessMap.get_basic_map_list()) {
+  auto markers = collectBranchMarkers(root, node);
+  for (auto marker : markers) {
+    auto mapping = findThreadMappingAncestor(root, marker);
+    size_t nMappedThreads = marker->scheduleDepth(mapping);
+    if (nMappedThreads == 0) {
+      continue;
+    }
+    auto depth = marker->scheduleDepth(root);
+    auto activePoints = activeDomainPoints(root, mapping);
+    auto localAccesses = originalAccesses.intersect_domain(activePoints);
+    auto scheduledAccesses = localAccesses.apply_domain(schedule);
+    for (auto access : isl::UnionAsVector<isl::union_map>(scheduledAccesses)) {
+      auto scheduleSpace = access.get_space().domain();
       auto tensorSpace = access.get_space().range();
       auto elementToNext = makeNextElementMap(
           tensorSpace, tensorSpace.dim(isl::dim_type::set) - 1);
-      auto domainUMap = isl::union_set(isl::set(access.domain()));
-      int threadIdxXDepth = computeThreadIdxXScheduleDepth(
-          threadIdxXScheduleDepthState, domainUMap.intersect(activePoints));
-      auto partialScheduleUMap =
-          schedule.intersect_domain(domainUMap.universe());
-      if (partialScheduleUMap.n_map() != 1) {
-        throw promotion::PromotionLogicError("expected single schedule space");
-      }
-      auto partialSchedule = isl::map::from_union_map(partialScheduleUMap);
-      auto scheduleToNextX = makeNextElementMap(
-          partialSchedule.get_space().range(), threadIdxXDepth);
-      auto scheduledAccess = isl::map(access).apply_domain(partialSchedule);
-      auto accessedByAdjacentX = scheduleToNextX.apply_domain(scheduledAccess)
-                                     .apply_range(scheduledAccess);
+      auto scheduleToNextX = makeNextElementMap(scheduleSpace, depth - 1);
+      auto accessedByAdjacentX =
+          scheduleToNextX.apply_domain(access).apply_range(access);
 
       if (not accessedByAdjacentX.is_subset(elementToNext)) {
         return true;
@@ -467,7 +482,6 @@ std::vector<detail::ScheduleTree*> bandsSplitAfterDepth(
  */
 void promoteToSharedGreedy(
     Scop& scop,
-    const ThreadIdxXScheduleDepthState& threadIdxXScheduleDepthState,
     const Block& block,
     size_t depth,
     size_t maxMemory) {
@@ -561,11 +575,7 @@ void promoteToSharedGreedy(
       // Do not promote if the group features no reuse and is accessed in a
       // coalesced way.
       if (!hasReuseWithin(*group, partialSchedMupa) &&
-          !promotionImprovesCoalescing(
-              threadIdxXScheduleDepthState,
-              *group,
-              fullSched,
-              activePoints)) {
+          !promotionImprovesCoalescing(root, bandNode, *group, fullSched)) {
         continue;
       }
 
@@ -586,17 +596,12 @@ void promoteToSharedGreedy(
 
 void promoteGreedilyAtDepth(
     MappedScop& mscop,
-    const ThreadIdxXScheduleDepthState& threadIdxXScheduleDepthState,
     size_t depth,
     size_t sharedMemorySize,
     bool unrollCopies) {
   // 1. Promote using heuristic.
   promoteToSharedGreedy(
-      mscop.scop(),
-      threadIdxXScheduleDepthState,
-      mscop.numThreads,
-      depth,
-      sharedMemorySize);
+      mscop.scop(), mscop.numThreads, depth, sharedMemorySize);
 
   // 2. Map copies to shared, state by copy
   mapCopiesToThreads(mscop, unrollCopies);
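
The coalescing criterion that promotionImprovesCoalescing encodes with isl maps can also be seen on a toy example. The sketch below is not part of the patch and uses made-up names (offsetRowMajor, accessedInCoalescedWay, a hypothetical tensor size N = 64); it only mirrors the check described in the comment above: incrementing the index that plays the role of threadIdx.x should advance the last (fastest-varying) tensor index by one. An access pattern like A[i][t] passes the check, so promoting it would not improve coalescing; A[t][i] fails it, which is exactly the case where promotion to shared memory can pay off.

// Toy illustration of the coalescing check, independent of isl and of this
// patch.  "t" plays the role of the schedule dimension mapped to threadIdx.x.
#include <cstddef>
#include <iostream>

constexpr size_t N = 64;

// Linearized offset of A[row][col] for a row-major N x N tensor.
size_t offsetRowMajor(size_t row, size_t col) {
  return row * N + col;
}

// Returns true if bumping "t" by one always moves the access to the element
// whose last index is one larger, i.e. the next address in memory (coalesced).
template <typename Access>
bool accessedInCoalescedWay(Access access) {
  for (size_t i = 0; i < N; ++i) {
    for (size_t t = 0; t + 1 < N; ++t) {
      if (access(i, t + 1) != access(i, t) + 1) {
        return false;
      }
    }
  }
  return true;
}

int main() {
  // A[i][t]: adjacent threads touch adjacent addresses -> coalesced,
  // so promotion does not improve coalescing for this reference.
  auto lastDimThreaded = [](size_t i, size_t t) { return offsetRowMajor(i, t); };
  // A[t][i]: adjacent threads are N elements apart -> not coalesced,
  // so this reference is a promotion candidate.
  auto firstDimThreaded = [](size_t i, size_t t) { return offsetRowMajor(t, i); };

  std::cout << std::boolalpha
            << "A[i][t] coalesced: " << accessedInCoalescedWay(lastDimThreaded)
            << "\nA[t][i] coalesced: " << accessedInCoalescedWay(firstDimThreaded)
            << "\n";
}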