@@ -99,23 +99,22 @@ isl::union_map partialSchedule(
99
99
return partialScheduleImpl (root, node, true );
100
100
}
101
101
102
- // Get a set of domain elements that are active at the given node.
102
+ namespace {
103
+ // Get a set of domain elements that are active below
104
+ // the given branch of nodes.
103
105
//
104
106
// Domain elements are introduced by the root domain node. Filter nodes
105
107
// disable the points that do not intersect with the filter. Extension nodes
106
108
// are considered to introduce additional domain points.
107
- isl::union_set activeDomainPoints (
109
+ isl::union_set activeDomainPointsHelper (
108
110
const ScheduleTree* root,
109
- const ScheduleTree* node ) {
111
+ const vector< const ScheduleTree*>& nodes ) {
110
112
auto domainElem = root->elemAs <ScheduleTreeElemDomain>();
111
113
CHECK (domainElem) << " root must be a Domain node" << *root;
112
114
113
115
auto domain = domainElem->domain_ ;
114
- if (root == node) {
115
- return domain;
116
- }
117
116
118
- for (auto anc : node-> ancestors (root) ) {
117
+ for (auto anc : nodes ) {
119
118
if (auto filterElem = anc->elemAsBase <ScheduleTreeElemFilter>()) {
120
119
domain = domain.intersect (filterElem->filter_ );
121
120
} else if (auto extensionElem = anc->elemAs <ScheduleTreeElemExtension>()) {
@@ -134,6 +133,21 @@ isl::union_set activeDomainPoints(
134
133
}
135
134
return domain;
136
135
}
136
+ } // namespace
137
+
138
+ isl::union_set activeDomainPoints (
139
+ const ScheduleTree* root,
140
+ const ScheduleTree* node) {
141
+ return activeDomainPointsHelper (root, node->ancestors (root));
142
+ }
143
+
144
+ isl::union_set activeDomainPointsBelow (
145
+ const ScheduleTree* root,
146
+ const ScheduleTree* node) {
147
+ auto ancestors = node->ancestors (root);
148
+ ancestors.emplace_back (node);
149
+ return activeDomainPointsHelper (root, ancestors);
150
+ }
137
151
138
152
vector<ScheduleTree*> collectScheduleTreesPath (
139
153
std::function<ScheduleTree*(ScheduleTree*)> next,
@@ -473,8 +487,16 @@ void updateTopLevelContext(detail::ScheduleTree* root, isl::set context) {
473
487
contextElem->context_ = contextElem->context_ & context;
474
488
}
475
489
476
- ScheduleTree* insertSequenceAbove (ScheduleTree* root, ScheduleTree* tree) {
477
- auto parent = tree->ancestor (root, 1 );
490
+ namespace {
491
+
492
+ // In a tree starting at "root", insert a sequence node with
493
+ // as only child the node identified by "tree"
494
+ // within the subtree at "relativeRoot".
495
+ ScheduleTree* insertSequenceAbove (
496
+ const ScheduleTree* root,
497
+ ScheduleTree* relativeRoot,
498
+ ScheduleTree* tree) {
499
+ auto parent = tree->ancestor (relativeRoot, 1 );
478
500
auto childPos = tree->positionInParent (parent);
479
501
auto filter = activeDomainPoints (root, tree).universe ();
480
502
parent->insertChild (
@@ -484,11 +506,27 @@ ScheduleTree* insertSequenceAbove(ScheduleTree* root, ScheduleTree* tree) {
484
506
return parent->child ({childPos});
485
507
}
486
508
509
+ } // namespace
510
+
511
+ ScheduleTree* insertSequenceAbove (ScheduleTree* root, ScheduleTree* tree) {
512
+ return insertSequenceAbove (root, root, tree);
513
+ }
514
+
515
+ void insertSequenceBelow (
516
+ const detail::ScheduleTree* root,
517
+ detail::ScheduleTree* tree) {
518
+ auto numChildren = tree->numChildren ();
519
+ CHECK_LE (numChildren, 1u );
520
+ auto filter = activeDomainPointsBelow (root, tree).universe ();
521
+ auto node = ScheduleTree::makeFilter (filter, tree->detachChildren ());
522
+ tree->appendChild (ScheduleTree::makeSequence (std::move (node)));
523
+ }
524
+
487
525
ScheduleTree* insertExtensionAbove (
488
- ScheduleTree* root ,
526
+ ScheduleTree* relativeRoot ,
489
527
ScheduleTree* tree,
490
528
isl::union_map extension) {
491
- auto parent = tree->ancestor (root , 1 );
529
+ auto parent = tree->ancestor (relativeRoot , 1 );
492
530
auto childPos = tree->positionInParent (parent);
493
531
auto child = parent->detachChild (childPos);
494
532
parent->insertChild (
@@ -500,85 +538,153 @@ namespace {
500
538
/*
501
539
* Insert an empty extension node above "st" in a tree with the given root and
502
540
* return a pointer to the inserted extension node.
541
+ * The modification is performed within the subtree at "relativeRoot".
503
542
*/
504
543
detail::ScheduleTree* insertEmptyExtensionAbove (
505
- ScheduleTree* root,
544
+ const ScheduleTree* root,
545
+ ScheduleTree* relativeRoot,
506
546
ScheduleTree* st) {
507
547
auto domain = root->elemAs <ScheduleTreeElemDomain>();
508
548
CHECK (domain);
509
549
auto space = domain->domain_ .get_space ();
510
550
auto extension = isl::union_map::empty (space);
511
- return insertExtensionAbove (root , st, extension);
551
+ return insertExtensionAbove (relativeRoot , st, extension);
512
552
}
513
- } // namespace
514
553
515
- void insertExtensionLabelAt (
516
- ScheduleTree* root,
554
+ /*
555
+ * Construct an extension map for a zero-dimensional statement
556
+ * with the given identifier.
557
+ */
558
+ isl::map labelExtension (ScheduleTree* root, ScheduleTree* tree, isl::id id) {
559
+ auto prefix = prefixScheduleMupa (root, tree);
560
+ auto scheduleSpace = prefix.get_space ();
561
+ auto space = scheduleSpace.params ().set_from_params ().set_tuple_id (
562
+ isl::dim_type::set, id);
563
+ auto extensionSpace = scheduleSpace.map_from_domain_and_range (space);
564
+ return isl::map::universe (extensionSpace);
565
+ }
566
+
567
+ /*
568
+ * Construct a filter node for a zero-dimensional extension statement
569
+ * with the given extension map.
570
+ */
571
+ ScheduleTreeUPtr labelFilterFromExtension (isl::map extension) {
572
+ return detail::ScheduleTree::makeFilter (extension.range ());
573
+ }
574
+
575
+ /*
576
+ * Given a sequence node in the schedule tree, insert
577
+ * an extension with the given extension map and extension filter node
578
+ * before the child at position "pos".
579
+ * If "pos" is equal to the number of children, then
580
+ * the statement is added after the last child.
581
+ * The modification is performed within the subtree at "relativeRoot".
582
+ */
583
+ void insertExtensionAt (
584
+ const ScheduleTree* root,
585
+ ScheduleTree* relativeRoot,
517
586
ScheduleTree* seqNode,
518
587
size_t pos,
519
- isl::id id) {
520
- auto extensionTree = seqNode->ancestor (root, 1 );
588
+ isl::union_map extension,
589
+ ScheduleTreeUPtr&& filterNode) {
590
+ auto extensionTree = seqNode->ancestor (relativeRoot, 1 );
521
591
auto extensionNode =
522
592
extensionTree->elemAs <detail::ScheduleTreeElemExtension>();
523
593
if (!extensionNode) {
524
- extensionTree = insertEmptyExtensionAbove (root, seqNode);
594
+ extensionTree = insertEmptyExtensionAbove (root, relativeRoot, seqNode);
525
595
extensionNode = extensionTree->elemAs <detail::ScheduleTreeElemExtension>();
526
596
}
527
597
CHECK (extensionNode);
528
598
CHECK (seqNode->elemAs <detail::ScheduleTreeElemSequence>());
529
- auto prefix = prefixScheduleMupa (root, extensionTree);
530
- auto scheduleSpace = prefix.get_space ();
531
- auto space = scheduleSpace.params ().set_from_params ().set_tuple_id (
532
- isl::dim_type::set, id);
533
- auto extensionSpace = scheduleSpace.map_from_domain_and_range (space);
534
- auto extension = isl::map::universe (extensionSpace);
535
599
extensionNode->extension_ = extensionNode->extension_ .unite (extension);
536
- auto filterNode = detail::ScheduleTree::makeFilter (extension.range ());
537
600
seqNode->insertChild (pos, std::move (filterNode));
538
601
}
602
+ } // namespace
539
603
540
- void insertExtensionLabelBefore (
541
- ScheduleTree* root,
604
+ void insertExtensionBefore (
605
+ const ScheduleTree* root,
606
+ ScheduleTree* relativeRoot,
542
607
ScheduleTree* tree,
543
- isl::id id) {
608
+ isl::union_map extension,
609
+ ScheduleTreeUPtr&& filterNode) {
544
610
size_t pos;
545
- auto parent = tree->ancestor (root , 1 );
611
+ auto parent = tree->ancestor (relativeRoot , 1 );
546
612
ScheduleTree* seqTree;
613
+ if (tree->elemAs <detail::ScheduleTreeElemExtension>()) {
614
+ tree = tree->child ({0 });
615
+ parent = tree;
616
+ }
547
617
if (tree->elemAs <detail::ScheduleTreeElemSequence>()) {
548
618
seqTree = tree;
549
619
pos = 0 ;
550
620
} else if (
551
621
parent->elemAs <detail::ScheduleTreeElemFilter>() &&
552
622
parent->ancestor (root, 1 )->elemAs <detail::ScheduleTreeElemSequence>()) {
553
- seqTree = parent->ancestor (root , 1 );
623
+ seqTree = parent->ancestor (relativeRoot , 1 );
554
624
pos = parent->positionInParent (seqTree);
555
625
} else {
556
- seqTree = insertSequenceAbove (root, tree);
626
+ seqTree = insertSequenceAbove (root, relativeRoot, tree);
557
627
pos = 0 ;
558
628
}
559
- insertExtensionLabelAt (root, seqTree, pos, id);
629
+ insertExtensionAt (
630
+ root, relativeRoot, seqTree, pos, extension, std::move (filterNode));
560
631
}
561
632
562
- void insertExtensionLabelAfter (
563
- ScheduleTree* root,
633
+ void insertExtensionAfter (
634
+ const ScheduleTree* root,
635
+ ScheduleTree* relativeRoot,
564
636
ScheduleTree* tree,
565
- isl::id id) {
637
+ isl::union_map extension,
638
+ ScheduleTreeUPtr&& filterNode) {
566
639
size_t pos;
567
- auto parent = tree->ancestor (root , 1 );
640
+ auto parent = tree->ancestor (relativeRoot , 1 );
568
641
ScheduleTree* seqTree;
642
+ if (tree->elemAs <detail::ScheduleTreeElemExtension>()) {
643
+ tree = tree->child ({0 });
644
+ parent = tree;
645
+ }
569
646
if (tree->elemAs <detail::ScheduleTreeElemSequence>()) {
570
647
seqTree = tree;
571
648
pos = tree->numChildren ();
572
649
} else if (
573
650
parent->elemAs <detail::ScheduleTreeElemFilter>() &&
574
651
parent->ancestor (root, 1 )->elemAs <detail::ScheduleTreeElemSequence>()) {
575
- seqTree = parent->ancestor (root , 1 );
652
+ seqTree = parent->ancestor (relativeRoot , 1 );
576
653
pos = parent->positionInParent (seqTree) + 1 ;
577
654
} else {
578
- seqTree = insertSequenceAbove (root, tree);
655
+ seqTree = insertSequenceAbove (root, relativeRoot, tree);
579
656
pos = 1 ;
580
657
}
581
- insertExtensionLabelAt (root, seqTree, pos, id);
658
+ insertExtensionAt (
659
+ root, relativeRoot, seqTree, pos, extension, std::move (filterNode));
660
+ }
661
+
662
+ void insertExtensionLabelAt (
663
+ ScheduleTree* root,
664
+ ScheduleTree* seqNode,
665
+ size_t pos,
666
+ isl::id id) {
667
+ auto extension = labelExtension (root, seqNode, id);
668
+ auto filterNode = labelFilterFromExtension (extension);
669
+ insertExtensionAt (root, root, seqNode, pos, extension, std::move (filterNode));
670
+ }
671
+
672
+ void insertExtensionLabelBefore (
673
+ ScheduleTree* root,
674
+ ScheduleTree* tree,
675
+ isl::id id) {
676
+ auto extension = labelExtension (root, tree, id);
677
+ auto filterNode = labelFilterFromExtension (extension);
678
+ insertExtensionBefore (root, root, tree, extension, std::move (filterNode));
679
+ }
680
+
681
+ void insertExtensionLabelAfter (
682
+ ScheduleTree* root,
683
+ ScheduleTree* tree,
684
+ isl::id id) {
685
+ auto extension = labelExtension (root, tree, id);
686
+ auto filterNode = labelFilterFromExtension (extension);
687
+ insertExtensionAfter (root, root, tree, extension, std::move (filterNode));
582
688
}
583
689
584
690
namespace {
0 commit comments