@@ -622,6 +622,9 @@ void promoteToRegistersBelowThreads(
622
622
// do not correspond to band members that should be fixed to obtain
623
623
// per-thread-group access relations.
624
624
auto points = activeDomainPoints (root, band);
625
+ auto partialSched = partialSchedule (root, band);
626
+ auto activeStmts = activeStatements (root, band);
627
+
625
628
size_t nMappedThreads = 0 ;
626
629
for (int j = 0 ; j < points.dim (isl::dim_type::param); ++j) {
627
630
auto id = points.get_space ().get_dim_id (isl::dim_type::param, j);
@@ -639,12 +642,12 @@ void promoteToRegistersBelowThreads(
639
642
}
640
643
641
644
auto groupMap = TensorReferenceGroup::accessedBySubtree (band, scop);
642
- for (const auto & tensorGroups : groupMap) {
645
+ for (auto & tensorGroups : groupMap) {
643
646
auto tensorId = tensorGroups.first ;
644
647
645
648
// TODO: sorting of groups and counting the number of promoted elements
646
649
647
- for (const auto & group : tensorGroups.second ) {
650
+ for (auto & group : tensorGroups.second ) {
648
651
auto sizes = group->approximationSizes ();
649
652
// No point in promoting a scalar that will go to a register anyway.
650
653
if (sizes.size () == 0 ) {
@@ -664,6 +667,13 @@ void promoteToRegistersBelowThreads(
664
667
// TODO: if something is already in shared, but reuse it within one
665
668
// thread only, there is no point in keeping it in shared _if_ it
666
669
// gets promoted into a register.
670
+ scop.promoteGroup (
671
+ Scop::PromotedDecl::Kind::Register,
672
+ tensorId,
673
+ std::move (group),
674
+ band,
675
+ activeStmts,
676
+ partialSched);
667
677
}
668
678
}
669
679
}
0 commit comments