@@ -483,6 +483,100 @@ std::vector<detail::ScheduleTree*> bandsSplitAfterDepth(
  return functional::Map(splitAtDepth, bands);
}

+ /*
+  * Promote to shared memory in "scop" below the node "bandNode". Use at most
+  * "remainingMemory" bytes, and update the variable to reflect the amount of
+  * available shared memory remaining after promotion. "fullSched" is the union
+  * of schedules at leaves of the schedule tree, expected to be computed by
+  * "fullSchedule".
+  */
+ void promoteToSharedBelow(
+     Scop& scop,
+     detail::ScheduleTree* bandNode,
+     isl::union_map fullSched,
+     size_t& remainingMemory) {
+   auto root = scop.scheduleRoot();
+   auto partialSched = partialSchedule(root, bandNode);
+
+   auto groupMap = TensorReferenceGroup::accessedWithin(
+       partialSched, scop.reads, scop.writes);
+   // Pure affine schedule without (mapping) filters.
+   auto partialSchedMupa = partialScheduleMupa(root, bandNode);
+
+   // Prepare groups for sorting, to have specified order necessary for
+   // reproducibility and tests.
+   using TensorGroupList = std::pair<isl::id, TensorGroupsInfo>;
+   std::vector<TensorGroupList> groupLists(
+       std::make_move_iterator(groupMap.begin()),
+       std::make_move_iterator(groupMap.end()));
+
+   // Computes the total number of references in all groups.
+   auto refsCount = [](const TensorGroupsInfo& info) {
+     size_t refs = 0;
+     for (auto const& group : info) {
+       refs += group->referenceIds().size();
+     }
+     return refs;
+   };
+
+   // Sort by the total number of references, then by name. Because names are
+   // guaranteed to be unique, the order is total.
+   std::sort(
+       groupLists.begin(),
+       groupLists.end(),
+       [refsCount](const TensorGroupList& l1, const TensorGroupList& l2) {
+         auto r1 = refsCount(l1.second);
+         auto r2 = refsCount(l2.second);
+         return r1 == r2 ? l1.first.get_name() < l2.first.get_name() : r1 < r2;
+       });
+   for (auto& tensorGroups : groupLists) {
+     auto tensorId = tensorGroups.first;
+     // Sort the reference groups to prioritize groups with more references as
+     // they are more likely to benefit from promotion.
+     std::sort(
+         tensorGroups.second.begin(),
+         tensorGroups.second.end(),
+         [refsCount](
+             const std::unique_ptr<TensorReferenceGroup>& group1,
+             const std::unique_ptr<TensorReferenceGroup>& group2) {
+           return group1->referenceIds().size() > group2->referenceIds().size();
+         });
+
+     for (auto& group : tensorGroups.second) {
+       auto sizes = group->approximationSizes();
+       if (sizes.size() == 0) {
+         throw promotion::PromotionLogicError("cannot promote a scalar");
+       }
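+       // Pad the innermost dimension to an odd number of elements (presumably
+       // to reduce shared memory bank conflicts).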
+       if (sizes.back() % 2 == 0) {
+         sizes.back() += 1;
+       }
+       auto nApproximationElements = std::accumulate(
+           sizes.begin(), sizes.end(), 1, std::multiplies<size_t>());
+       size_t memoryRequirement =
+           nApproximationElements * scop.findArgument(tensorId).type().bytes();
+       if (memoryRequirement > remainingMemory) {
+         continue;
+       }
+       // Do not promote if the group features no reuse and is accessed in a
+       // coalesced way.
+       if (!hasReuseWithin(*group, partialSchedMupa) &&
+           !promotionImprovesCoalescing(root, bandNode, *group, fullSched)) {
+         continue;
+       }
+
+       scop.promoteGroup(
+           Scop::PromotedDecl::Kind::SharedMem,
+           tensorId,
+           std::move(group),
+           bandNode,
+           partialSched,
+           true);
+       remainingMemory -= memoryRequirement;
+     }
+   }
+   scop.insertSyncsAroundCopies(bandNode);
+ }
+
/*
 * For every place in the schedule tree where schedule depth (i.e., the number
 * of preceding band members) is "depth", promote tensor reference groups to
@@ -525,86 +619,7 @@ void promoteToSharedGreedy(
  // both.
  size_t remainingMemory = maxMemory;
  for (auto bandNode : bands) {
-     auto partialSched = partialSchedule(root, bandNode);
-
-     auto groupMap =
-         TensorReferenceGroup::accessedWithin(partialSched, scop.body);
-     // Pure affine schedule without (mapping) filters.
-     auto partialSchedMupa = partialScheduleMupa(root, bandNode);
-
-     // Prepare groups for sorting, to have specified order necessary for
-     // reproducibility and tests.
-     using TensorGroupList = std::pair<isl::id, TensorGroupsInfo>;
-     std::vector<TensorGroupList> groupLists(
-         std::make_move_iterator(groupMap.begin()),
-         std::make_move_iterator(groupMap.end()));
-
-     // Computes the total number of references in all groups.
-     auto refsCount = [](const TensorGroupsInfo& info) {
-       size_t refs = 0;
-       for (auto const& group : info) {
-         refs += group->referenceIds().size();
-       }
-       return refs;
-     };
-
-     // Sort by the total number of references, then by name. Because names are
-     // guarenteed to be unique, the order is total.
-     std::sort(
-         groupLists.begin(),
-         groupLists.end(),
-         [refsCount](const TensorGroupList& l1, const TensorGroupList& l2) {
-           auto r1 = refsCount(l1.second);
-           auto r2 = refsCount(l2.second);
-           return r1 == r2 ? l1.first.get_name() < l2.first.get_name() : r1 < r2;
-         });
-     for (auto& tensorGroups : groupLists) {
-       auto tensorId = tensorGroups.first;
-       // Sort the reference groups to prioritize groups with more references as
-       // they are more likely to benefit from promotion.
-       std::sort(
-           tensorGroups.second.begin(),
-           tensorGroups.second.end(),
-           [refsCount](
-               const std::unique_ptr<TensorReferenceGroup>& group1,
-               const std::unique_ptr<TensorReferenceGroup>& group2) {
-             return group1->referenceIds().size() >
-                 group2->referenceIds().size();
-           });
-
-       for (auto& group : tensorGroups.second) {
-         auto sizes = group->approximationSizes();
-         if (sizes.size() == 0) {
-           throw promotion::PromotionLogicError("cannot promote a scalar");
-         }
-         if (sizes.back() % 2 == 0) {
-           sizes.back() += 1;
-         }
-         auto nApproximationElements = std::accumulate(
-             sizes.begin(), sizes.end(), 1, std::multiplies<size_t>());
-         size_t memoryRequirement =
-             nApproximationElements * scop.findArgument(tensorId).type().bytes();
-         if (memoryRequirement > remainingMemory) {
-           continue;
-         }
-         // Do not promote if the group features no reuse and is accessed in a
-         // coalesced way.
-         if (!hasReuseWithin(*group, partialSchedMupa) &&
-             !promotionImprovesCoalescing(root, bandNode, *group, fullSched)) {
-           continue;
-         }
-
-         scop.promoteGroup(
-             Scop::PromotedDecl::Kind::SharedMem,
-             tensorId,
-             std::move(group),
-             bandNode,
-             partialSched,
-             true);
-         remainingMemory -= memoryRequirement;
-       }
-     }
-     scop.insertSyncsAroundCopies(bandNode);
+     promoteToSharedBelow(scop, bandNode, fullSched, remainingMemory);
  }
}
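
For orientation, here is a minimal sketch of how the extracted helper is driven after this change. It is not the file's exact code: the parameter list of promoteToSharedGreedy and the band-collection helper are assumptions based on the names visible in the hunk headers and in the comment above promoteToSharedBelow.

    // Hypothetical driver shape (assumed signature and helpers), showing how the
    // shared-memory budget flows through promoteToSharedBelow band by band.
    void promoteToSharedGreedy(Scop& scop, size_t depth, size_t maxMemory) {
      auto root = scop.scheduleRoot();
      // Collect and split bands at the requested depth (bandsSplitAfterDepth is
      // visible in the first hunk header; the collection helper is assumed).
      auto bands = bandsContainingScheduleDepth(root, depth);
      bands = bandsSplitAfterDepth(bands, root, depth);
      // Union of schedules at the leaves, as described in the new comment.
      auto fullSched = fullSchedule(root);
      size_t remainingMemory = maxMemory;
      for (auto bandNode : bands) {
        // Each call promotes whatever groups still fit and decrements
        // remainingMemory accordingly.
        promoteToSharedBelow(scop, bandNode, fullSched, remainingMemory);
      }
    }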