Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 2cd56da

Browse files
Merge pull request #320 from facebookresearch/pr/copies_under
insertCopiesUnder: use insertExtension{Before,After}
2 parents 2f98f2e + 5725f4f commit 2cd56da

File tree

6 files changed

+217
-112
lines changed

6 files changed

+217
-112
lines changed

tc/core/polyhedral/memory_promotion.cc

Lines changed: 13 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -391,34 +391,6 @@ TensorReferenceGroup::referenceIds() const {
391391
}
392392

393393
namespace {
394-
bool hasCopyExtensionSingleChild(const ScheduleTree* tree) {
395-
if (tree->numChildren() != 1) {
396-
return false;
397-
}
398-
399-
auto extensionNode =
400-
tree->child({0})->elemAs<detail::ScheduleTreeElemExtension>();
401-
if (!extensionNode) {
402-
return false;
403-
}
404-
405-
if ((tree->child({0})->numChildren() != 1) &&
406-
(tree->child({0, 0})->elemAs<detail::ScheduleTreeElemSequence>())) {
407-
return false;
408-
}
409-
410-
for (auto e : isl::UnionAsVector<isl::union_map>(extensionNode->extension_)) {
411-
if (!e.has_tuple_name(isl::dim_type::out)) {
412-
return false;
413-
}
414-
if (e.get_tuple_name(isl::dim_type::out) != kReadIdName &&
415-
e.get_tuple_name(isl::dim_type::out) != kWriteIdName) {
416-
return false;
417-
}
418-
}
419-
return true;
420-
}
421-
422394
// Construct the set containing all tensor elements.
423395
//
424396
// Find the Halide image corresponding to the given tensorId. Transform its
@@ -524,48 +496,26 @@ ScheduleTree* insertCopiesUnder(
524496
bool reads = !group.scopedReads().is_empty();
525497
bool writes = !group.scopedWrites().is_empty();
526498

527-
if (hasCopyExtensionSingleChild(tree)) {
528-
auto extensionNode = tree->child({0});
529-
auto sequenceNode = tree->child({0, 0});
530-
531-
auto& ext =
532-
extensionNode->elemAs<detail::ScheduleTreeElemExtension>()->extension_;
533-
if (reads) {
534-
ext = ext.unite(isl::union_map(readExtension));
535-
sequenceNode->insertChild(0, std::move(readFilterNode));
536-
}
537-
if (writes) {
538-
ext = ext.unite(isl::union_map(writeExtension));
539-
sequenceNode->appendChild(std::move(writeFilterNode));
540-
}
541-
return tree;
499+
if (tree->numChildren() == 0) {
500+
// The point underneath a leaf node cannot be referenced,
501+
// so insert a dummy sequence first. It will be extended
502+
// with the reads and/or writes.
503+
insertSequenceBelow(root, tree);
542504
}
543505

544-
auto mainCompFilter = activeDomainPoints(root, tree).universe();
545-
auto mainCompFilterNode =
546-
ScheduleTree::makeFilter(mainCompFilter, tree->detachChildren());
547-
548-
// XXX: I don't really like the syntax-imposed impossibility to create a
549-
// sequence node with no children.
550-
auto sequenceNode = ScheduleTree::makeSequence(
551-
reads ? std::move(readFilterNode) : std::move(mainCompFilterNode));
552506
if (reads) {
553-
sequenceNode->appendChild(std::move(mainCompFilterNode));
507+
insertExtensionBefore(
508+
root, tree, tree->child({0}), readExtension, std::move(readFilterNode));
554509
}
555510
if (writes) {
556-
sequenceNode->appendChild(std::move(writeFilterNode));
511+
insertExtensionAfter(
512+
root,
513+
tree,
514+
tree->child({0}),
515+
writeExtension,
516+
std::move(writeFilterNode));
557517
}
558518

559-
auto extensionUmap = isl::union_map::empty(promotionSpace.params());
560-
if (reads) {
561-
extensionUmap = extensionUmap.unite(readExtension);
562-
}
563-
if (writes) {
564-
extensionUmap = extensionUmap.unite(writeExtension);
565-
}
566-
auto extensionNode =
567-
ScheduleTree::makeExtension(extensionUmap, std::move(sequenceNode));
568-
tree->appendChild(std::move(extensionNode));
569519
return tree;
570520
}
571521
} // namespace polyhedral

tc/core/polyhedral/schedule_isl_conversion.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "tc/external/isl.h"
2323

2424
#include "tc/core/flags.h"
25+
#include "tc/core/polyhedral/schedule_transforms.h"
2526
#include "tc/external/isl.h"
2627

2728
using namespace std;

tc/core/polyhedral/schedule_transforms.cc

Lines changed: 146 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -99,23 +99,22 @@ isl::union_map partialSchedule(
9999
return partialScheduleImpl(root, node, true);
100100
}
101101

102-
// Get a set of domain elements that are active at the given node.
102+
namespace {
103+
// Get a set of domain elements that are active below
104+
// the given branch of nodes.
103105
//
104106
// Domain elements are introduced by the root domain node. Filter nodes
105107
// disable the points that do not intersect with the filter. Extension nodes
106108
// are considered to introduce additional domain points.
107-
isl::union_set activeDomainPoints(
109+
isl::union_set activeDomainPointsHelper(
108110
const ScheduleTree* root,
109-
const ScheduleTree* node) {
111+
const vector<const ScheduleTree*>& nodes) {
110112
auto domainElem = root->elemAs<ScheduleTreeElemDomain>();
111113
CHECK(domainElem) << "root must be a Domain node" << *root;
112114

113115
auto domain = domainElem->domain_;
114-
if (root == node) {
115-
return domain;
116-
}
117116

118-
for (auto anc : node->ancestors(root)) {
117+
for (auto anc : nodes) {
119118
if (auto filterElem = anc->elemAsBase<ScheduleTreeElemFilter>()) {
120119
domain = domain.intersect(filterElem->filter_);
121120
} else if (auto extensionElem = anc->elemAs<ScheduleTreeElemExtension>()) {
@@ -134,6 +133,21 @@ isl::union_set activeDomainPoints(
134133
}
135134
return domain;
136135
}
136+
} // namespace
137+
138+
isl::union_set activeDomainPoints(
139+
const ScheduleTree* root,
140+
const ScheduleTree* node) {
141+
return activeDomainPointsHelper(root, node->ancestors(root));
142+
}
143+
144+
isl::union_set activeDomainPointsBelow(
145+
const ScheduleTree* root,
146+
const ScheduleTree* node) {
147+
auto ancestors = node->ancestors(root);
148+
ancestors.emplace_back(node);
149+
return activeDomainPointsHelper(root, ancestors);
150+
}
137151

138152
vector<ScheduleTree*> collectScheduleTreesPath(
139153
std::function<ScheduleTree*(ScheduleTree*)> next,
@@ -473,8 +487,16 @@ void updateTopLevelContext(detail::ScheduleTree* root, isl::set context) {
473487
contextElem->context_ = contextElem->context_ & context;
474488
}
475489

476-
ScheduleTree* insertSequenceAbove(ScheduleTree* root, ScheduleTree* tree) {
477-
auto parent = tree->ancestor(root, 1);
490+
namespace {
491+
492+
// In a tree starting at "root", insert a sequence node with
493+
// as only child the node identified by "tree"
494+
// within the subtree at "relativeRoot".
495+
ScheduleTree* insertSequenceAbove(
496+
const ScheduleTree* root,
497+
ScheduleTree* relativeRoot,
498+
ScheduleTree* tree) {
499+
auto parent = tree->ancestor(relativeRoot, 1);
478500
auto childPos = tree->positionInParent(parent);
479501
auto filter = activeDomainPoints(root, tree).universe();
480502
parent->insertChild(
@@ -484,11 +506,27 @@ ScheduleTree* insertSequenceAbove(ScheduleTree* root, ScheduleTree* tree) {
484506
return parent->child({childPos});
485507
}
486508

509+
} // namespace
510+
511+
ScheduleTree* insertSequenceAbove(ScheduleTree* root, ScheduleTree* tree) {
512+
return insertSequenceAbove(root, root, tree);
513+
}
514+
515+
void insertSequenceBelow(
516+
const detail::ScheduleTree* root,
517+
detail::ScheduleTree* tree) {
518+
auto numChildren = tree->numChildren();
519+
CHECK_LE(numChildren, 1u);
520+
auto filter = activeDomainPointsBelow(root, tree).universe();
521+
auto node = ScheduleTree::makeFilter(filter, tree->detachChildren());
522+
tree->appendChild(ScheduleTree::makeSequence(std::move(node)));
523+
}
524+
487525
ScheduleTree* insertExtensionAbove(
488-
ScheduleTree* root,
526+
ScheduleTree* relativeRoot,
489527
ScheduleTree* tree,
490528
isl::union_map extension) {
491-
auto parent = tree->ancestor(root, 1);
529+
auto parent = tree->ancestor(relativeRoot, 1);
492530
auto childPos = tree->positionInParent(parent);
493531
auto child = parent->detachChild(childPos);
494532
parent->insertChild(
@@ -500,85 +538,153 @@ namespace {
500538
/*
501539
* Insert an empty extension node above "st" in a tree with the given root and
502540
* return a pointer to the inserted extension node.
541+
* The modification is performed within the subtree at "relativeRoot".
503542
*/
504543
detail::ScheduleTree* insertEmptyExtensionAbove(
505-
ScheduleTree* root,
544+
const ScheduleTree* root,
545+
ScheduleTree* relativeRoot,
506546
ScheduleTree* st) {
507547
auto domain = root->elemAs<ScheduleTreeElemDomain>();
508548
CHECK(domain);
509549
auto space = domain->domain_.get_space();
510550
auto extension = isl::union_map::empty(space);
511-
return insertExtensionAbove(root, st, extension);
551+
return insertExtensionAbove(relativeRoot, st, extension);
512552
}
513-
} // namespace
514553

515-
void insertExtensionLabelAt(
516-
ScheduleTree* root,
554+
/*
555+
* Construct an extension map for a zero-dimensional statement
556+
* with the given identifier.
557+
*/
558+
isl::map labelExtension(ScheduleTree* root, ScheduleTree* tree, isl::id id) {
559+
auto prefix = prefixScheduleMupa(root, tree);
560+
auto scheduleSpace = prefix.get_space();
561+
auto space = scheduleSpace.params().set_from_params().set_tuple_id(
562+
isl::dim_type::set, id);
563+
auto extensionSpace = scheduleSpace.map_from_domain_and_range(space);
564+
return isl::map::universe(extensionSpace);
565+
}
566+
567+
/*
568+
* Construct a filter node for a zero-dimensional extension statement
569+
* with the given extension map.
570+
*/
571+
ScheduleTreeUPtr labelFilterFromExtension(isl::map extension) {
572+
return detail::ScheduleTree::makeFilter(extension.range());
573+
}
574+
575+
/*
576+
* Given a sequence node in the schedule tree, insert
577+
* an extension with the given extension map and extension filter node
578+
* before the child at position "pos".
579+
* If "pos" is equal to the number of children, then
580+
* the statement is added after the last child.
581+
* The modification is performed within the subtree at "relativeRoot".
582+
*/
583+
void insertExtensionAt(
584+
const ScheduleTree* root,
585+
ScheduleTree* relativeRoot,
517586
ScheduleTree* seqNode,
518587
size_t pos,
519-
isl::id id) {
520-
auto extensionTree = seqNode->ancestor(root, 1);
588+
isl::union_map extension,
589+
ScheduleTreeUPtr&& filterNode) {
590+
auto extensionTree = seqNode->ancestor(relativeRoot, 1);
521591
auto extensionNode =
522592
extensionTree->elemAs<detail::ScheduleTreeElemExtension>();
523593
if (!extensionNode) {
524-
extensionTree = insertEmptyExtensionAbove(root, seqNode);
594+
extensionTree = insertEmptyExtensionAbove(root, relativeRoot, seqNode);
525595
extensionNode = extensionTree->elemAs<detail::ScheduleTreeElemExtension>();
526596
}
527597
CHECK(extensionNode);
528598
CHECK(seqNode->elemAs<detail::ScheduleTreeElemSequence>());
529-
auto prefix = prefixScheduleMupa(root, extensionTree);
530-
auto scheduleSpace = prefix.get_space();
531-
auto space = scheduleSpace.params().set_from_params().set_tuple_id(
532-
isl::dim_type::set, id);
533-
auto extensionSpace = scheduleSpace.map_from_domain_and_range(space);
534-
auto extension = isl::map::universe(extensionSpace);
535599
extensionNode->extension_ = extensionNode->extension_.unite(extension);
536-
auto filterNode = detail::ScheduleTree::makeFilter(extension.range());
537600
seqNode->insertChild(pos, std::move(filterNode));
538601
}
602+
} // namespace
539603

540-
void insertExtensionLabelBefore(
541-
ScheduleTree* root,
604+
void insertExtensionBefore(
605+
const ScheduleTree* root,
606+
ScheduleTree* relativeRoot,
542607
ScheduleTree* tree,
543-
isl::id id) {
608+
isl::union_map extension,
609+
ScheduleTreeUPtr&& filterNode) {
544610
size_t pos;
545-
auto parent = tree->ancestor(root, 1);
611+
auto parent = tree->ancestor(relativeRoot, 1);
546612
ScheduleTree* seqTree;
613+
if (tree->elemAs<detail::ScheduleTreeElemExtension>()) {
614+
tree = tree->child({0});
615+
parent = tree;
616+
}
547617
if (tree->elemAs<detail::ScheduleTreeElemSequence>()) {
548618
seqTree = tree;
549619
pos = 0;
550620
} else if (
551621
parent->elemAs<detail::ScheduleTreeElemFilter>() &&
552622
parent->ancestor(root, 1)->elemAs<detail::ScheduleTreeElemSequence>()) {
553-
seqTree = parent->ancestor(root, 1);
623+
seqTree = parent->ancestor(relativeRoot, 1);
554624
pos = parent->positionInParent(seqTree);
555625
} else {
556-
seqTree = insertSequenceAbove(root, tree);
626+
seqTree = insertSequenceAbove(root, relativeRoot, tree);
557627
pos = 0;
558628
}
559-
insertExtensionLabelAt(root, seqTree, pos, id);
629+
insertExtensionAt(
630+
root, relativeRoot, seqTree, pos, extension, std::move(filterNode));
560631
}
561632

562-
void insertExtensionLabelAfter(
563-
ScheduleTree* root,
633+
void insertExtensionAfter(
634+
const ScheduleTree* root,
635+
ScheduleTree* relativeRoot,
564636
ScheduleTree* tree,
565-
isl::id id) {
637+
isl::union_map extension,
638+
ScheduleTreeUPtr&& filterNode) {
566639
size_t pos;
567-
auto parent = tree->ancestor(root, 1);
640+
auto parent = tree->ancestor(relativeRoot, 1);
568641
ScheduleTree* seqTree;
642+
if (tree->elemAs<detail::ScheduleTreeElemExtension>()) {
643+
tree = tree->child({0});
644+
parent = tree;
645+
}
569646
if (tree->elemAs<detail::ScheduleTreeElemSequence>()) {
570647
seqTree = tree;
571648
pos = tree->numChildren();
572649
} else if (
573650
parent->elemAs<detail::ScheduleTreeElemFilter>() &&
574651
parent->ancestor(root, 1)->elemAs<detail::ScheduleTreeElemSequence>()) {
575-
seqTree = parent->ancestor(root, 1);
652+
seqTree = parent->ancestor(relativeRoot, 1);
576653
pos = parent->positionInParent(seqTree) + 1;
577654
} else {
578-
seqTree = insertSequenceAbove(root, tree);
655+
seqTree = insertSequenceAbove(root, relativeRoot, tree);
579656
pos = 1;
580657
}
581-
insertExtensionLabelAt(root, seqTree, pos, id);
658+
insertExtensionAt(
659+
root, relativeRoot, seqTree, pos, extension, std::move(filterNode));
660+
}
661+
662+
void insertExtensionLabelAt(
663+
ScheduleTree* root,
664+
ScheduleTree* seqNode,
665+
size_t pos,
666+
isl::id id) {
667+
auto extension = labelExtension(root, seqNode, id);
668+
auto filterNode = labelFilterFromExtension(extension);
669+
insertExtensionAt(root, root, seqNode, pos, extension, std::move(filterNode));
670+
}
671+
672+
void insertExtensionLabelBefore(
673+
ScheduleTree* root,
674+
ScheduleTree* tree,
675+
isl::id id) {
676+
auto extension = labelExtension(root, tree, id);
677+
auto filterNode = labelFilterFromExtension(extension);
678+
insertExtensionBefore(root, root, tree, extension, std::move(filterNode));
679+
}
680+
681+
void insertExtensionLabelAfter(
682+
ScheduleTree* root,
683+
ScheduleTree* tree,
684+
isl::id id) {
685+
auto extension = labelExtension(root, tree, id);
686+
auto filterNode = labelFilterFromExtension(extension);
687+
insertExtensionAfter(root, root, tree, extension, std::move(filterNode));
582688
}
583689

584690
namespace {

0 commit comments

Comments
 (0)