@@ -507,8 +507,16 @@ void updateTopLevelContext(detail::ScheduleTree* root, isl::set context) {
507
507
contextElem->context_ = contextElem->context_ & context;
508
508
}
509
509
510
- ScheduleTree* insertSequenceAbove (ScheduleTree* root, ScheduleTree* tree) {
511
- auto parent = tree->ancestor (root, 1 );
510
+ namespace {
511
+
512
+ // In a tree starting at "root", insert a sequence node with
513
+ // as only child the node identified by "tree"
514
+ // within the subtree at "relativeRoot".
515
+ ScheduleTree* insertSequenceAbove (
516
+ const ScheduleTree* root,
517
+ ScheduleTree* relativeRoot,
518
+ ScheduleTree* tree) {
519
+ auto parent = tree->ancestor (relativeRoot, 1 );
512
520
auto childPos = tree->positionInParent (parent);
513
521
auto filter = activeDomainPoints (root, tree).universe ();
514
522
parent->insertChild (
@@ -518,6 +526,12 @@ ScheduleTree* insertSequenceAbove(ScheduleTree* root, ScheduleTree* tree) {
518
526
return parent->child ({childPos});
519
527
}
520
528
529
+ } // namespace
530
+
531
+ ScheduleTree* insertSequenceAbove (ScheduleTree* root, ScheduleTree* tree) {
532
+ return insertSequenceAbove (root, root, tree);
533
+ }
534
+
521
535
void insertSequenceBelow (
522
536
const detail::ScheduleTree* root,
523
537
detail::ScheduleTree* tree) {
@@ -544,49 +558,77 @@ namespace {
544
558
/*
545
559
* Insert an empty extension node above "st" in a tree with the given root and
546
560
* return a pointer to the inserted extension node.
561
+ * The modification is performed within the subtree at "relativeRoot".
547
562
*/
548
563
detail::ScheduleTree* insertEmptyExtensionAbove (
549
- ScheduleTree* root,
564
+ const ScheduleTree* root,
565
+ ScheduleTree* relativeRoot,
550
566
ScheduleTree* st) {
551
567
auto domain = root->elemAs <ScheduleTreeElemDomain>();
552
568
CHECK (domain);
553
569
auto space = domain->domain_ .get_space ();
554
570
auto extension = isl::union_map::empty (space);
555
- return insertExtensionAbove (root , st, extension);
571
+ return insertExtensionAbove (relativeRoot , st, extension);
556
572
}
557
- } // namespace
558
573
559
- void insertExtensionLabelAt (
560
- ScheduleTree* root,
574
+ /*
575
+ * Construct an extension map for a zero-dimensional statement
576
+ * with the given identifier.
577
+ */
578
+ isl::map labelExtension (ScheduleTree* root, ScheduleTree* tree, isl::id id) {
579
+ auto prefix = prefixScheduleMupa (root, tree);
580
+ auto scheduleSpace = prefix.get_space ();
581
+ auto space = scheduleSpace.params ().set_from_params ().set_tuple_id (
582
+ isl::dim_type::set, id);
583
+ auto extensionSpace = scheduleSpace.map_from_domain_and_range (space);
584
+ return isl::map::universe (extensionSpace);
585
+ }
586
+
587
+ /*
588
+ * Construct a filter node for a zero-dimensional extension statement
589
+ * with the given extension map.
590
+ */
591
+ ScheduleTreeUPtr labelFilterFromExtension (isl::map extension) {
592
+ return detail::ScheduleTree::makeFilter (extension.range ());
593
+ }
594
+
595
+ /*
596
+ * Given a sequence node in the schedule tree, insert
597
+ * an extension with the given extension map and extension filter node
598
+ * before the child at position "pos".
599
+ * If "pos" is equal to the number of children, then
600
+ * the statement is added after the last child.
601
+ * The modification is performed within the subtree at "relativeRoot".
602
+ */
603
+ void insertExtensionAt (
604
+ const ScheduleTree* root,
605
+ ScheduleTree* relativeRoot,
561
606
ScheduleTree* seqNode,
562
607
size_t pos,
563
- isl::id id) {
564
- auto extensionTree = seqNode->ancestor (root, 1 );
608
+ isl::union_map extension,
609
+ ScheduleTreeUPtr&& filterNode) {
610
+ auto extensionTree = seqNode->ancestor (relativeRoot, 1 );
565
611
auto extensionNode =
566
612
extensionTree->elemAs <detail::ScheduleTreeElemExtension>();
567
613
if (!extensionNode) {
568
- extensionTree = insertEmptyExtensionAbove (root, seqNode);
614
+ extensionTree = insertEmptyExtensionAbove (root, relativeRoot, seqNode);
569
615
extensionNode = extensionTree->elemAs <detail::ScheduleTreeElemExtension>();
570
616
}
571
617
CHECK (extensionNode);
572
618
CHECK (seqNode->elemAs <detail::ScheduleTreeElemSequence>());
573
- auto prefix = prefixScheduleMupa (root, extensionTree);
574
- auto scheduleSpace = prefix.get_space ();
575
- auto space = scheduleSpace.params ().set_from_params ().set_tuple_id (
576
- isl::dim_type::set, id);
577
- auto extensionSpace = scheduleSpace.map_from_domain_and_range (space);
578
- auto extension = isl::map::universe (extensionSpace);
579
619
extensionNode->extension_ = extensionNode->extension_ .unite (extension);
580
- auto filterNode = detail::ScheduleTree::makeFilter (extension.range ());
581
620
seqNode->insertChild (pos, std::move (filterNode));
582
621
}
622
+ } // namespace
583
623
584
- void insertExtensionLabelBefore (
585
- ScheduleTree* root,
624
+ void insertExtensionBefore (
625
+ const ScheduleTree* root,
626
+ ScheduleTree* relativeRoot,
586
627
ScheduleTree* tree,
587
- isl::id id) {
628
+ isl::union_map extension,
629
+ ScheduleTreeUPtr&& filterNode) {
588
630
size_t pos;
589
- auto parent = tree->ancestor (root , 1 );
631
+ auto parent = tree->ancestor (relativeRoot , 1 );
590
632
ScheduleTree* seqTree;
591
633
if (tree->elemAs <detail::ScheduleTreeElemExtension>()) {
592
634
tree = tree->child ({0 });
@@ -598,21 +640,24 @@ void insertExtensionLabelBefore(
598
640
} else if (
599
641
parent->elemAs <detail::ScheduleTreeElemFilter>() &&
600
642
parent->ancestor (root, 1 )->elemAs <detail::ScheduleTreeElemSequence>()) {
601
- seqTree = parent->ancestor (root , 1 );
643
+ seqTree = parent->ancestor (relativeRoot , 1 );
602
644
pos = parent->positionInParent (seqTree);
603
645
} else {
604
- seqTree = insertSequenceAbove (root, tree);
646
+ seqTree = insertSequenceAbove (root, relativeRoot, tree);
605
647
pos = 0 ;
606
648
}
607
- insertExtensionLabelAt (root, seqTree, pos, id);
649
+ insertExtensionAt (
650
+ root, relativeRoot, seqTree, pos, extension, std::move (filterNode));
608
651
}
609
652
610
- void insertExtensionLabelAfter (
611
- ScheduleTree* root,
653
+ void insertExtensionAfter (
654
+ const ScheduleTree* root,
655
+ ScheduleTree* relativeRoot,
612
656
ScheduleTree* tree,
613
- isl::id id) {
657
+ isl::union_map extension,
658
+ ScheduleTreeUPtr&& filterNode) {
614
659
size_t pos;
615
- auto parent = tree->ancestor (root , 1 );
660
+ auto parent = tree->ancestor (relativeRoot , 1 );
616
661
ScheduleTree* seqTree;
617
662
if (tree->elemAs <detail::ScheduleTreeElemExtension>()) {
618
663
tree = tree->child ({0 });
@@ -624,13 +669,42 @@ void insertExtensionLabelAfter(
624
669
} else if (
625
670
parent->elemAs <detail::ScheduleTreeElemFilter>() &&
626
671
parent->ancestor (root, 1 )->elemAs <detail::ScheduleTreeElemSequence>()) {
627
- seqTree = parent->ancestor (root , 1 );
672
+ seqTree = parent->ancestor (relativeRoot , 1 );
628
673
pos = parent->positionInParent (seqTree) + 1 ;
629
674
} else {
630
- seqTree = insertSequenceAbove (root, tree);
675
+ seqTree = insertSequenceAbove (root, relativeRoot, tree);
631
676
pos = 1 ;
632
677
}
633
- insertExtensionLabelAt (root, seqTree, pos, id);
678
+ insertExtensionAt (
679
+ root, relativeRoot, seqTree, pos, extension, std::move (filterNode));
680
+ }
681
+
682
+ void insertExtensionLabelAt (
683
+ ScheduleTree* root,
684
+ ScheduleTree* seqNode,
685
+ size_t pos,
686
+ isl::id id) {
687
+ auto extension = labelExtension (root, seqNode, id);
688
+ auto filterNode = labelFilterFromExtension (extension);
689
+ insertExtensionAt (root, root, seqNode, pos, extension, std::move (filterNode));
690
+ }
691
+
692
+ void insertExtensionLabelBefore (
693
+ ScheduleTree* root,
694
+ ScheduleTree* tree,
695
+ isl::id id) {
696
+ auto extension = labelExtension (root, tree, id);
697
+ auto filterNode = labelFilterFromExtension (extension);
698
+ insertExtensionBefore (root, root, tree, extension, std::move (filterNode));
699
+ }
700
+
701
+ void insertExtensionLabelAfter (
702
+ ScheduleTree* root,
703
+ ScheduleTree* tree,
704
+ isl::id id) {
705
+ auto extension = labelExtension (root, tree, id);
706
+ auto filterNode = labelFilterFromExtension (extension);
707
+ insertExtensionAfter (root, root, tree, extension, std::move (filterNode));
634
708
}
635
709
636
710
namespace {
0 commit comments