@@ -159,8 +159,7 @@ isl::union_map fullSchedule(const detail::ScheduleTree* root) {
     }
   }

-  prefixMupa = isl::manage(isl_multi_union_pw_aff_intersect_domain(
-      prefixMupa.release(), domain.copy()));
+  prefixMupa = prefixMupa.intersect_domain(domain);

   schedule = schedule.unite(isl::union_map::from(prefixMupa));
   if (!schedule.is_single_valued()) {
@@ -315,6 +314,67 @@ bool isCoalesced(
   return true;
 }

+/*
+ * Check if the given "group" can be promoted to registers for the given active
+ * domain points under full "schedule" where "nThreads" consecutive dimensions
+ * are mapped to threads (the innermost of them being mapped to thread x) and
+ * the depth of this mapping can be obtained from threadIdxxScheduleDepthState.
+ *
+ * In particular, the group's footprint must contain only one element and the
+ * same tensor element should never be accessed by two different threads.
+ */
+bool isPromotableToRegisterBelowThreads(
+    const ThreadIdxxScheduleDepthState& threadIdxxScheduleDepthState,
+    const TensorReferenceGroup& group,
+    isl::union_map schedule,
+    size_t nThreads,
+    isl::union_set activePoints) {
+  auto originalAccesses = group.originalAccesses();
+
+  // Return early if more than one element needs to be stored in registers.
+  // TODO: support arrays in registers if they are only accessed with constant
+  // subscripts, e.g. if the inner loops are fully unrolled.
+  auto sizes = group.approximationSizes();
+  auto nElements =
+      std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies<size_t>());
+  if (nElements != 1) {
+    return false;
+  }
+
+  // Since this function is only supposed to be called on groups seen _below_
+  // the thread mapping, all refs in the group must have the same thread-x
+  // depth.
+  auto depth = 1 +
+      computeThreadIdxxScheduleDepth(
+          threadIdxxScheduleDepthState,
+          originalAccesses.domain().intersect(activePoints));
+
+  auto scheduledAccesses =
+      originalAccesses.gist_domain(originalAccesses.domain())
+          .apply_domain(schedule);
+
+  // Scheduled accesses contain maps from schedule dimensions to tensor
+  // subscripts. Compute the relation between the schedule dimensions mapped
+  // to threads and tensor subscripts by first removing the dimensions that
+  // follow the one mapped to thread x (the last one, assuming inverse mapping
+  // order), then by equating all dimensions not mapped to threads to
+  // parameters. Promotion to registers is only allowed if the resulting
+  // relation is injective, i.e. the same tensor element is never accessed by
+  // more than one thread. Note that our current check is overly conservative
+  // because different values of a schedule dimension may get mapped to the
+  // same thread, in which case they could access the same tensor element.
+  for (auto sa : isl::UnionAsVector<isl::union_map>(scheduledAccesses)) {
+    sa = sa.project_out(
+        isl::dim_type::in, depth, sa.dim(isl::dim_type::in) - depth);
+    sa = fixOuterInputDimsAsParameters(sa, depth - nThreads);
+    if (!sa.is_injective()) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
 /*
  * Starting from the root, find bands where depth is reached. Using
  * DFSPreorder to make sure order is specified and consistent for tests.
@@ -503,5 +563,111 @@ void promoteGreedilyAtDepth(
   mapCopiesToThreads(mscop, unrollCopies);
 }

+namespace {
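+// Return the value of the parameter at position "pos" if it is fixed to the
+// same value across all sets of "uset"; return a NaN value otherwise, in
+// particular if the parameter is not fixed in some set.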
+isl::val getParamValIfFixed(isl::union_set uset, int pos) {
+  auto val = isl::val::nan(uset.get_ctx());
+  for (auto set : isl::UnionAsVector<isl::union_set>(uset)) {
+    auto currentVal = set.plain_get_val_if_fixed(isl::dim_type::param, pos);
+    if (currentVal.is_nan()) {
+      return currentVal;
+    }
+    if (!val.is_nan() && val != currentVal) {
+      return isl::val::nan(uset.get_ctx());
+    }
+    val = currentVal;
+  }
+  return val;
+}
+} // namespace
+
+// Assuming the mapping to threads happens in inverse order, i.e. the innermost
+// loop is mapped to thread x, promote below that depth.
+void promoteToRegistersBelowThreads(
+    Scop& scop,
+    const ThreadIdxxScheduleDepthState& threadIdxxScheduleDepthState,
+    size_t nRegisters) {
+  using namespace tc::polyhedral::detail;
+
+  auto root = scop.scheduleRoot();
+
+  auto fullSched = fullSchedule(root);
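+  // Each entry of threadIdxxScheduleDepthState associates a set of domain
+  // points with the schedule depth at which they are mapped to thread x.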
+  for (const auto& kvp : threadIdxxScheduleDepthState) {
+    auto depth = kvp.second + 1;
+    auto subdomain = kvp.first;
+
+    // Collect all bands where a member is located at the given depth.
+    auto bands = bandsContainingScheduleDepth(root, depth);
+    // We may have no band members mapped to thread x when everything was
+    // force-mapped to a single thread.
+    if (bands.size() == 0) {
+      continue;
+    }
+
+    // Keep only those bands for which this depth was recorded.
+    std::function<bool(ScheduleTree*)> keepActive =
+        [root, subdomain](const ScheduleTree* tree) {
+          isl::union_set active = activeDomainPoints(root, tree);
+          return !active.intersect(subdomain).is_empty();
+        };
+    bands = functional::Filter(keepActive, bands);
+
+    // Make sure the band ends at thread x depth so we can promote below it.
+    bands = bandsSplitAfterDepth(bands, root, depth);
+
+    for (auto band : bands) {
+      // Find out how many threads are actually mapped. Active domain points
+      // will involve all mapping parameters when we take them below the
+      // mapping. Skip mapping parameters obviously mapped to 0, because they
+      // do not correspond to band members that should be fixed to obtain
+      // per-thread-group access relations.
+      auto points = activeDomainPoints(root, band);
+      size_t nMappedThreads = 0;
+      for (int j = 0; j < points.dim(isl::dim_type::param); ++j) {
+        auto id = points.get_space().get_dim_id(isl::dim_type::param, j);
+        for (size_t i = 0; i < mapping::ThreadId::kMaxDim; ++i) {
+          if (id != mapping::ThreadId::makeId(i)) {
+            continue;
+          }
+          if (getParamValIfFixed(points, j) ==
+              isl::val::zero(points.get_ctx())) {
+            continue;
+          }
+          ++nMappedThreads;
+          break;
+        }
+      }
+
+      auto groupMap = TensorReferenceGroup::accessedBySubtree(band, scop);
+      for (const auto& tensorGroups : groupMap) {
+        auto tensorId = tensorGroups.first;
+
+        // TODO: sorting of groups and counting the number of promoted elements
+
+        for (const auto& group : tensorGroups.second) {
+          auto sizes = group->approximationSizes();
+          // No point in promoting a scalar that will go to a register anyway.
+          if (sizes.size() == 0) {
+            continue;
+          }
+          if (!isPromotableToRegisterBelowThreads(
+                  threadIdxxScheduleDepthState,
+                  *group,
+                  fullSched,
+                  nMappedThreads,
+                  points)) {
+            continue;
+          }
+          if (!hasReuse(*group, fullSched, depth)) {
+            continue;
+          }
+          // TODO: if something is already in shared, but is reused within one
+          // thread only, there is no point in keeping it in shared _if_ it
+          // gets promoted into a register.
+        }
+      }
+    }
+  }
+
 } // namespace polyhedral
 } // namespace tc