Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 2495c6f

Browse files
author
Sven Verdoolaege
committed
extract out insertExtensionBefore and insertExtensionAfter
This will be used in the next commit. The functions have both a root and a relative root, because the root is only needed to derive the active domain points, while the relative root indicates the part that may be modified.
1 parent f33a7ec commit 2495c6f

File tree

2 files changed

+135
-31
lines changed

2 files changed

+135
-31
lines changed

tc/core/polyhedral/schedule_transforms.cc

Lines changed: 105 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -507,8 +507,16 @@ void updateTopLevelContext(detail::ScheduleTree* root, isl::set context) {
507507
contextElem->context_ = contextElem->context_ & context;
508508
}
509509

510-
ScheduleTree* insertSequenceAbove(ScheduleTree* root, ScheduleTree* tree) {
511-
auto parent = tree->ancestor(root, 1);
510+
namespace {
511+
512+
// In a tree starting at "root", insert a sequence node with
513+
// as only child the node identified by "tree"
514+
// within the subtree at "relativeRoot".
515+
ScheduleTree* insertSequenceAbove(
516+
const ScheduleTree* root,
517+
ScheduleTree* relativeRoot,
518+
ScheduleTree* tree) {
519+
auto parent = tree->ancestor(relativeRoot, 1);
512520
auto childPos = tree->positionInParent(parent);
513521
auto filter = activeDomainPoints(root, tree).universe();
514522
parent->insertChild(
@@ -518,6 +526,12 @@ ScheduleTree* insertSequenceAbove(ScheduleTree* root, ScheduleTree* tree) {
518526
return parent->child({childPos});
519527
}
520528

529+
} // namespace
530+
531+
ScheduleTree* insertSequenceAbove(ScheduleTree* root, ScheduleTree* tree) {
532+
return insertSequenceAbove(root, root, tree);
533+
}
534+
521535
void insertSequenceBelow(
522536
const detail::ScheduleTree* root,
523537
detail::ScheduleTree* tree) {
@@ -544,49 +558,77 @@ namespace {
544558
/*
545559
* Insert an empty extension node above "st" in a tree with the given root and
546560
* return a pointer to the inserted extension node.
561+
* The modification is performed within the subtree at "relativeRoot".
547562
*/
548563
detail::ScheduleTree* insertEmptyExtensionAbove(
549-
ScheduleTree* root,
564+
const ScheduleTree* root,
565+
ScheduleTree* relativeRoot,
550566
ScheduleTree* st) {
551567
auto domain = root->elemAs<ScheduleTreeElemDomain>();
552568
CHECK(domain);
553569
auto space = domain->domain_.get_space();
554570
auto extension = isl::union_map::empty(space);
555-
return insertExtensionAbove(root, st, extension);
571+
return insertExtensionAbove(relativeRoot, st, extension);
556572
}
557-
} // namespace
558573

559-
void insertExtensionLabelAt(
560-
ScheduleTree* root,
574+
/*
575+
* Construct an extension map for a zero-dimensional statement
576+
* with the given identifier.
577+
*/
578+
isl::map labelExtension(ScheduleTree* root, ScheduleTree* tree, isl::id id) {
579+
auto prefix = prefixScheduleMupa(root, tree);
580+
auto scheduleSpace = prefix.get_space();
581+
auto space = scheduleSpace.params().set_from_params().set_tuple_id(
582+
isl::dim_type::set, id);
583+
auto extensionSpace = scheduleSpace.map_from_domain_and_range(space);
584+
return isl::map::universe(extensionSpace);
585+
}
586+
587+
/*
588+
* Construct a filter node for a zero-dimensional extension statement
589+
* with the given extension map.
590+
*/
591+
ScheduleTreeUPtr labelFilterFromExtension(isl::map extension) {
592+
return detail::ScheduleTree::makeFilter(extension.range());
593+
}
594+
595+
/*
596+
* Given a sequence node in the schedule tree, insert
597+
* an extension with the given extension map and extension filter node
598+
* before the child at position "pos".
599+
* If "pos" is equal to the number of children, then
600+
* the statement is added after the last child.
601+
* The modification is performed within the subtree at "relativeRoot".
602+
*/
603+
void insertExtensionAt(
604+
const ScheduleTree* root,
605+
ScheduleTree* relativeRoot,
561606
ScheduleTree* seqNode,
562607
size_t pos,
563-
isl::id id) {
564-
auto extensionTree = seqNode->ancestor(root, 1);
608+
isl::union_map extension,
609+
ScheduleTreeUPtr&& filterNode) {
610+
auto extensionTree = seqNode->ancestor(relativeRoot, 1);
565611
auto extensionNode =
566612
extensionTree->elemAs<detail::ScheduleTreeElemExtension>();
567613
if (!extensionNode) {
568-
extensionTree = insertEmptyExtensionAbove(root, seqNode);
614+
extensionTree = insertEmptyExtensionAbove(root, relativeRoot, seqNode);
569615
extensionNode = extensionTree->elemAs<detail::ScheduleTreeElemExtension>();
570616
}
571617
CHECK(extensionNode);
572618
CHECK(seqNode->elemAs<detail::ScheduleTreeElemSequence>());
573-
auto prefix = prefixScheduleMupa(root, extensionTree);
574-
auto scheduleSpace = prefix.get_space();
575-
auto space = scheduleSpace.params().set_from_params().set_tuple_id(
576-
isl::dim_type::set, id);
577-
auto extensionSpace = scheduleSpace.map_from_domain_and_range(space);
578-
auto extension = isl::map::universe(extensionSpace);
579619
extensionNode->extension_ = extensionNode->extension_.unite(extension);
580-
auto filterNode = detail::ScheduleTree::makeFilter(extension.range());
581620
seqNode->insertChild(pos, std::move(filterNode));
582621
}
622+
} // namespace
583623

584-
void insertExtensionLabelBefore(
585-
ScheduleTree* root,
624+
void insertExtensionBefore(
625+
const ScheduleTree* root,
626+
ScheduleTree* relativeRoot,
586627
ScheduleTree* tree,
587-
isl::id id) {
628+
isl::union_map extension,
629+
ScheduleTreeUPtr&& filterNode) {
588630
size_t pos;
589-
auto parent = tree->ancestor(root, 1);
631+
auto parent = tree->ancestor(relativeRoot, 1);
590632
ScheduleTree* seqTree;
591633
if (tree->elemAs<detail::ScheduleTreeElemExtension>()) {
592634
tree = tree->child({0});
@@ -598,21 +640,24 @@ void insertExtensionLabelBefore(
598640
} else if (
599641
parent->elemAs<detail::ScheduleTreeElemFilter>() &&
600642
parent->ancestor(root, 1)->elemAs<detail::ScheduleTreeElemSequence>()) {
601-
seqTree = parent->ancestor(root, 1);
643+
seqTree = parent->ancestor(relativeRoot, 1);
602644
pos = parent->positionInParent(seqTree);
603645
} else {
604-
seqTree = insertSequenceAbove(root, tree);
646+
seqTree = insertSequenceAbove(root, relativeRoot, tree);
605647
pos = 0;
606648
}
607-
insertExtensionLabelAt(root, seqTree, pos, id);
649+
insertExtensionAt(
650+
root, relativeRoot, seqTree, pos, extension, std::move(filterNode));
608651
}
609652

610-
void insertExtensionLabelAfter(
611-
ScheduleTree* root,
653+
void insertExtensionAfter(
654+
const ScheduleTree* root,
655+
ScheduleTree* relativeRoot,
612656
ScheduleTree* tree,
613-
isl::id id) {
657+
isl::union_map extension,
658+
ScheduleTreeUPtr&& filterNode) {
614659
size_t pos;
615-
auto parent = tree->ancestor(root, 1);
660+
auto parent = tree->ancestor(relativeRoot, 1);
616661
ScheduleTree* seqTree;
617662
if (tree->elemAs<detail::ScheduleTreeElemExtension>()) {
618663
tree = tree->child({0});
@@ -624,13 +669,42 @@ void insertExtensionLabelAfter(
624669
} else if (
625670
parent->elemAs<detail::ScheduleTreeElemFilter>() &&
626671
parent->ancestor(root, 1)->elemAs<detail::ScheduleTreeElemSequence>()) {
627-
seqTree = parent->ancestor(root, 1);
672+
seqTree = parent->ancestor(relativeRoot, 1);
628673
pos = parent->positionInParent(seqTree) + 1;
629674
} else {
630-
seqTree = insertSequenceAbove(root, tree);
675+
seqTree = insertSequenceAbove(root, relativeRoot, tree);
631676
pos = 1;
632677
}
633-
insertExtensionLabelAt(root, seqTree, pos, id);
678+
insertExtensionAt(
679+
root, relativeRoot, seqTree, pos, extension, std::move(filterNode));
680+
}
681+
682+
void insertExtensionLabelAt(
683+
ScheduleTree* root,
684+
ScheduleTree* seqNode,
685+
size_t pos,
686+
isl::id id) {
687+
auto extension = labelExtension(root, seqNode, id);
688+
auto filterNode = labelFilterFromExtension(extension);
689+
insertExtensionAt(root, root, seqNode, pos, extension, std::move(filterNode));
690+
}
691+
692+
void insertExtensionLabelBefore(
693+
ScheduleTree* root,
694+
ScheduleTree* tree,
695+
isl::id id) {
696+
auto extension = labelExtension(root, tree, id);
697+
auto filterNode = labelFilterFromExtension(extension);
698+
insertExtensionBefore(root, root, tree, extension, std::move(filterNode));
699+
}
700+
701+
void insertExtensionLabelAfter(
702+
ScheduleTree* root,
703+
ScheduleTree* tree,
704+
isl::id id) {
705+
auto extension = labelExtension(root, tree, id);
706+
auto filterNode = labelFilterFromExtension(extension);
707+
insertExtensionAfter(root, root, tree, extension, std::move(filterNode));
634708
}
635709

636710
namespace {

tc/core/polyhedral/schedule_transforms.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,36 @@ inline void insertMappingFilterBelow(
212212
const std::unordered_set<MappingIdType, typename MappingIdType::Hash>&
213213
mappingIds);
214214

215+
// Insert an extension with the given extension map and extension filter node
216+
// before node "tree".
217+
// If "tree" is a sequence node, an extension node with a sequence child,
218+
// or a grandchild of a sequence node,
219+
// then the new statement is inserted in the right position
220+
// of that sequence node.
221+
// Otherwise, a new sequence node is inserted.
222+
// The modification is performed within the subtree at "relativeRoot".
223+
void insertExtensionBefore(
224+
const detail::ScheduleTree* root,
225+
detail::ScheduleTree* relativeRoot,
226+
detail::ScheduleTree* tree,
227+
isl::union_map extension,
228+
ScheduleTreeUPtr&& filterNode);
229+
230+
// Insert an extension with the given extension map and extension filter node
231+
// after node "tree".
232+
// If "tree" is a sequence node, an extension node with a sequence child,
233+
// or a grandchild of a sequence node,
234+
// then the new statement is inserted in the right position
235+
// of that sequence node.
236+
// Otherwise, a new sequence node is inserted.
237+
// The modification is performed within the subtree at "relativeRoot".
238+
void insertExtensionAfter(
239+
const detail::ScheduleTree* root,
240+
detail::ScheduleTree* relativeRoot,
241+
detail::ScheduleTree* tree,
242+
isl::union_map extension,
243+
ScheduleTreeUPtr&& filterNode);
244+
215245
// Given a sequence node in the schedule tree, insert
216246
// a zero-dimensional extension statement with the given identifier
217247
// before the child at position "pos".

0 commit comments

Comments
 (0)