@@ -451,11 +451,12 @@ isl::Set<Tensor> tensorElementsSet(const Scop& scop, isl::id tensorId) {
451
451
* Note that this function drops the name of the target space of "schedule",
452
452
* but this space is irrelevant for the caller.
453
453
*/
454
- isl::multi_aff dropDummyTensorDimensions (
455
- isl::multi_aff schedule,
454
+ template <typename Domain>
455
+ isl::MultiAff<Domain, Domain> dropDummyTensorDimensions (
456
+ isl::MultiAff<Domain, Domain> schedule,
456
457
const Scop::PromotedDecl& decl) {
457
458
auto list = schedule.get_aff_list ();
458
- auto space = schedule.get_space ().domain ();
459
+ auto domainSpace = schedule.get_space ().domain ();
459
460
460
461
auto n = list.size ();
461
462
for (int i = n - 1 ; i >= 0 ; --i) {
@@ -464,8 +465,8 @@ isl::multi_aff dropDummyTensorDimensions(
464
465
}
465
466
}
466
467
467
- space = space. add_unnamed_tuple_ui (list.size ());
468
- return isl::multi_aff (space, list);
468
+ auto space = domainSpace. template add_unnamed_tuple_ui <Domain> (list.size ());
469
+ return isl::MultiAff<Domain, Domain> (space, list);
469
470
}
470
471
471
472
inline void unrollAllMembers (detail::ScheduleTreeBand* band) {
@@ -489,20 +490,25 @@ ScheduleTree* insertCopiesUnder(
489
490
// Take the set of all tensor elements.
490
491
auto tensorElements = tensorElementsSet (scop, tensorId);
491
492
492
- auto promotion = isl::map (group.promotion ()).set_range_tuple_id (groupId);
493
+ auto promotion =
494
+ group.promotion ().asMap ().set_range_tuple_id <Promoted>(groupId);
493
495
auto promotionSpace = promotion.get_space ();
494
496
495
- auto identityCopySchedule =
496
- isl::multi_aff::identity ( promotionSpace.range ().map_from_set ());
497
+ auto identityCopySchedule = isl::MultiAff<Promoted, Promoted>:: identity (
498
+ promotionSpace.range ().map_from_set ());
497
499
// Only iterate over significant tensor dimensions.
498
500
auto decl = scop.promotedDecl (groupId);
499
501
identityCopySchedule = dropDummyTensorDimensions (identityCopySchedule, decl);
500
- auto readSpace = promotionSpace.wrap ().set_set_tuple_id (readId);
501
- auto writeSpace = promotionSpace.wrap ().set_set_tuple_id (writeId);
502
+ auto readSpace = promotionSpace.wrap ().set_set_tuple_id <Statement> (readId);
503
+ auto writeSpace = promotionSpace.wrap ().set_set_tuple_id <Statement> (writeId);
502
504
auto readSchedule = isl::multi_union_pw_aff (identityCopySchedule.pullback (
503
- isl::multi_aff::wrapped_range_map (readSpace)));
505
+ isl::MultiAff<
506
+ isl::NamedPair<Statement, isl::Pair<Prefix, Tensor>, Promoted>,
507
+ Promoted>::wrapped_range_map (readSpace)));
504
508
auto writeSchedule = isl::multi_union_pw_aff (identityCopySchedule.pullback (
505
- isl::multi_aff::wrapped_range_map (writeSpace)));
509
+ isl::MultiAff<
510
+ isl::NamedPair<Statement, isl::Pair<Prefix, Tensor>, Promoted>,
511
+ Promoted>::wrapped_range_map (writeSpace)));
506
512
507
513
auto readBandNode = ScheduleTree::makeBand (
508
514
isl::MultiUnionPwAff<Statement, Band>(readSchedule));
@@ -524,19 +530,19 @@ ScheduleTree* insertCopiesUnder(
524
530
auto promotedFootprint =
525
531
group.promotedFootprint ().set_tuple_id <Promoted>(groupId);
526
532
auto scheduleUniverse =
527
- isl::set ::universe (promotionSpace.domain ().unwrap ().domain ());
533
+ isl::Set<Prefix> ::universe (promotionSpace.domain ().unwrap ().domain ());
528
534
auto arrayId = promotionSpace.domain ().unwrap ().get_map_range_tuple_id ();
529
535
auto approximatedRead =
530
536
group.approximateScopedAccesses ().intersect_range (tensorElements).wrap ();
531
537
auto product = approximatedRead.product (promotedFootprint);
532
538
auto readExtension =
533
- extension.intersect_range (product).set_range_tuple_id (readId);
539
+ extension.intersect_range (product).set_range_tuple_id <Statement> (readId);
534
540
auto writtenElements = group.scopedWrites ()
535
541
.intersect_range (tensorElements)
536
542
.wrap ()
537
543
.product (promotedFootprint);
538
- auto writeExtension =
539
- extension. intersect_range (writtenElements). set_range_tuple_id (writeId);
544
+ auto writeExtension = extension. intersect_range (writtenElements)
545
+ . set_range_tuple_id <Statement> (writeId);
540
546
541
547
auto readFilterNode = ScheduleTree::makeFilter (
542
548
isl::set::universe (readExtension.get_space ().range ()),
@@ -557,18 +563,14 @@ ScheduleTree* insertCopiesUnder(
557
563
558
564
if (reads) {
559
565
insertExtensionBefore (
560
- root,
561
- tree,
562
- tree->child ({0 }),
563
- isl::UnionMap<Prefix, Statement>(readExtension),
564
- std::move (readFilterNode));
566
+ root, tree, tree->child ({0 }), readExtension, std::move (readFilterNode));
565
567
}
566
568
if (writes) {
567
569
insertExtensionAfter (
568
570
root,
569
571
tree,
570
572
tree->child ({0 }),
571
- isl::UnionMap<Prefix, Statement>( writeExtension) ,
573
+ writeExtension,
572
574
std::move (writeFilterNode));
573
575
}
574
576
0 commit comments