@@ -346,35 +346,118 @@ isl::union_set collectMappingsTo(const Scop& scop) {
   return mapping;
 }

+/*
+ * Check that only unrolled loops may appear in access subscripts.
+ * Because the scoping point can be above a branching tree, descend into each
+ * leaf of the subtree below the scoping point. For each leaf, construct an
+ * affine multi-expression containing only those band members between the
+ * scoping point and the leaf that are fully unrolled.
+ *
+ * Within each instance of the scope loops, check that loops that are either
+ * unrolled or mapped to threads access a single tensor element in the group
+ * (other loop indices will then not appear in the subscripts, making register
+ * promotion possible). In other words, check that the relation between the
+ * flat product of prefix, thread-mapped, and unrolled loop indices and the
+ * accessed elements is single-valued.
+ *
+ * If band members are mapped to blocks(threads), they may still correspond to
+ * loops in the code in cases where the number of blocks(threads) is less than
+ * the extent of the band member. If there is no "unroll" flag on these
+ * members, we require that they not appear in the access subscripts, similarly
+ * to regular loops. This is slightly more conservative than necessary because
+ * the actual generated loop iterators may disappear from the access after
+ * mapping to threads in cases where they are used with a modulo that is less
+ * than the number of blocks(threads). Precise analysis requires non-trivial
+ * schedule manipulations or explicit tiling by grid(block) sizes before
+ * mapping to blocks(threads).
+ *
+ * TODO: note that if a group is formed from partially overlapping references,
+ * one must consider per-reference access relations for single-valuedness, as
+ * different references may have different values, but all of them must remain
+ * independent of non-unrolled loop iterators.
+ */
+bool accessSubscriptsAreUnrolledLoops(
+    const TensorReferenceGroup& group,
+    const detail::ScheduleTree* root,
+    const detail::ScheduleTree* scope,
+    isl::multi_union_pw_aff outerSchedule) {
+  using namespace detail;
+
+  auto nodes = ScheduleTree::collect(scope);
+  auto leaves = functional::Filter(
+      [](const ScheduleTree* tree) { return tree->numChildren() == 0; }, nodes);
+
+  auto domainNode = root->elemAs<detail::ScheduleTreeElemDomain>();
+  TC_CHECK(domainNode);
+  auto domain = domainNode->domain_;
+
+  // Descend into every leaf.
+  for (auto leaf : leaves) {
+    auto ancestors = leaf->ancestors(root);
+    ancestors.push_back(leaf);
+    auto subdomain = activeDomainPointsBelow(root, leaf);
+
+    auto unrolledDims = isl::union_pw_aff_list(leaf->ctx_, 1);
+    for (auto node : ancestors) {
+      auto band = node->elemAs<detail::ScheduleTreeElemBand>();
+      if (!band) {
+        continue;
+      }
+
+      isl::multi_union_pw_aff schedule = band->mupa_;
+      schedule = schedule.intersect_domain(subdomain);
+      for (size_t i = 0, e = band->nMember(); i < e; ++i) {
+        if (!band->unroll_[i]) {
+          continue;
+        }
+        unrolledDims = unrolledDims.add(schedule.get_union_pw_aff(i));
+      }
+    }
+
+    auto space = isl::space(leaf->ctx_, 0, unrolledDims.n())
+                     .align_params(subdomain.get_space());
+    auto unrolledDimsMupa = isl::multi_union_pw_aff(space, unrolledDims);
+
+    // It is possible that no loops are unrolled, in which case
+    // unrolledDimsMupa is zero-dimensional and needs an explicit domain
+    // to be convertible to a union_map.
+    unrolledDimsMupa =
+        unrolledDimsMupa.intersect_domain(group.originalAccesses().domain());
+
+    auto accesses = group.originalAccesses();
+    auto schedule = outerSchedule.flat_range_product(unrolledDimsMupa);
+    accesses = accesses.apply_domain(isl::union_map::from(schedule));
+
+    if (!accesses.is_single_valued()) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
 /*
  * Check if the given "group" can be promoted to registers for the given
  * mapping to thread identifiers and within the given outer schedule.
  *
- * In particular, the group's footprint must contain only one element and the
+ * In particular, all tensor subscripts that may appear in the promoted access
+ * must be either unrolled loops or thread identifiers, and the
  * same tensor element should never be accessed by two different threads
  * within the same iteration of the outer schedule.
  * The second test is performed by checking that there is only a single
  * thread associated to a given pair of tensor element and outer schedule
  * iteration.
- * Note that the test for a single thread is performed by looking
- * at the range of "thread". This range may be larger than the number
- * of threads, such that multiple instances may get mapped to the same thread.
- * Requiring different such instances is therefore slightly more conservative
- * than strictly needed.
 */
-bool isPromotableToRegisterBelowThreads(
+bool isPromotableToRegistersBelow(
     const TensorReferenceGroup& group,
+    const detail::ScheduleTree* root,
+    const detail::ScheduleTree* scope,
     isl::multi_union_pw_aff outer,
     isl::multi_union_pw_aff thread) {
   auto originalAccesses = group.originalAccesses();

-  // Return early if more than one element needs to be stored in registers.
-  // TODO: support arrays in registers if they are only accessed with constant
-  // subscripts, e.g. if the inner loops are fully unrolled.
-  auto sizes = group.approximationSizes();
-  auto nElements =
-      std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies<size_t>());
-  if (nElements != 1) {
+  if (!accessSubscriptsAreUnrolledLoops(
+          group, root, scope, outer.flat_range_product(thread))) {
     return false;
   }

@@ -567,21 +650,20 @@ void promoteGreedilyAtDepth(
 }

 // Promote at the positions of the thread specific markers.
-void promoteToRegistersBelowThreads(Scop& scop, size_t nRegisters) {
+void promoteToRegistersBelowThreads(MappedScop& mscop, size_t nRegisters) {
   using namespace tc::polyhedral::detail;

+  auto& scop = mscop.scop();
   auto root = scop.scheduleRoot();
+  auto threadMapping = mscop.threadMappingSchedule(root);

   {
     auto markers = findThreadSpecificMarkers(root);

     for (auto marker : markers) {
       auto partialSched = prefixSchedule(root, marker);
       // Pure affine schedule without (mapping) filters.
-      auto mapping = findThreadMappingAncestor(root, marker);
-      auto prefixSchedMupa = prefixScheduleMupa(root, mapping);
-      auto mapSchedMupa = infixScheduleMupa(root, mapping, marker);
-      auto partialSchedMupa = prefixSchedMupa.flat_range_product(mapSchedMupa);
+      auto partialSchedMupa = partialScheduleMupa(root, marker);

       // Because this function is called below the thread mapping marker,
       // partialSched has been intersected with both the block and the thread
@@ -600,8 +682,8 @@ void promoteToRegistersBelowThreads(Scop& scop, size_t nRegisters) {
           if (sizes.size() == 0) {
             continue;
           }
-          if (!isPromotableToRegisterBelowThreads(
-                  *group, prefixSchedMupa, mapSchedMupa)) {
+          if (!isPromotableToRegistersBelow(
+                  *group, root, marker, partialSchedMupa, threadMapping)) {
             continue;
           }
           if (!hasReuseWithin(*group, partialSchedMupa)) {
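For intuition, a minimal sketch of the kind of access the new check accepts (this code is illustrative only, with hypothetical names and sizes, and is not part of the patch): when a subscript is built solely from a thread identifier and a fully unrolled loop iterator, each (thread, unrolled-instance) pair touches exactly one element, so every element of the promoted group can live in its own register; a subscript using an ordinary, non-unrolled loop iterator would fail the single-valuedness test.

// Hypothetical illustration, not generated by or contained in the patch.
__device__ void promotedCopy(const float* A, float* out, int t /* thread id */) {
  float reg[4]; // promoted group: after unrolling, reg[0..3] become registers
#pragma unroll
  for (int i = 0; i < 4; ++i) {
    // The subscript depends only on the thread id and the fully unrolled
    // iterator, so each (t, i) instance accesses a single tensor element.
    reg[i] = A[4 * t + i];
  }
#pragma unroll
  for (int i = 0; i < 4; ++i) {
    out[4 * t + i] = reg[i] + 1.0f;
  }
}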