Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 590c0ae

Browse files
authored
Merge pull request #543 from facebookresearch/schedule-xyz-subclasses
ScheduleTree subclasses
2 parents 7f4ed01 + 96d058f commit 590c0ae

23 files changed

+452
-543
lines changed

tc/core/polyhedral/codegen.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <sstream>
2020

2121
#include "tc/core/polyhedral/schedule_tree.h"
22+
#include "tc/core/polyhedral/schedule_tree_elem.h"
2223

2324
namespace tc {
2425
namespace polyhedral {
@@ -30,7 +31,7 @@ isl::id_list Codegen::makeLoopIterators(
3031
detail::ScheduleTree::collect(root, detail::ScheduleTreeType::Band);
3132
size_t n = 0;
3233
for (auto const& node : bands) {
33-
auto bandElem = node->elemAs<detail::ScheduleTreeElemBand>();
34+
auto bandElem = node->as<detail::ScheduleTreeBand>();
3435
auto depth = node->scheduleDepth(root) + bandElem->nMember();
3536
if (depth > n) {
3637
n = depth;

tc/core/polyhedral/cuda/codegen.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -785,9 +785,8 @@ string emitCudaKernel(
785785
const std::string& specializedName,
786786
const MappedScop& mscop) {
787787
// Expecting a schedule with domain root and context first child.
788-
TC_CHECK(mscop.schedule()->elemAs<detail::ScheduleTreeElemDomain>());
789-
TC_CHECK(
790-
mscop.schedule()->child({0})->elemAs<detail::ScheduleTreeElemContext>());
788+
TC_CHECK(mscop.schedule()->as<detail::ScheduleTreeDomain>());
789+
TC_CHECK(mscop.schedule()->child({0})->as<detail::ScheduleTreeContext>());
791790
const auto& scop = mscop.scop();
792791

793792
// Make a map of the specialized scalar parameter values

tc/core/polyhedral/cuda/mapped_scop.cc

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ void validate(const detail::ScheduleTree* root) {
7272
root);
7373
}
7474

75-
bool anyNonCoincidentMember(const detail::ScheduleTreeElemBand* band) {
75+
bool anyNonCoincidentMember(const detail::ScheduleTreeBand* band) {
7676
return band->nOuterCoincident() < band->nMember();
7777
}
7878

@@ -140,7 +140,7 @@ detail::ScheduleTree* MappedScop::map(
140140
detail::ScheduleTree* MappedScop::mapBlocksForward(
141141
detail::ScheduleTree* band,
142142
size_t nToMap) {
143-
auto bandNode = band->elemAs<detail::ScheduleTreeElemBand>();
143+
auto bandNode = band->as<detail::ScheduleTreeBand>();
144144
TC_CHECK(bandNode) << "expected a band, got " << *band;
145145

146146
auto list = bandNode->mupa_.get_union_pw_aff_list();
@@ -155,7 +155,7 @@ void MappedScop::mapToBlocksAndScaleBand(
155155
std::vector<size_t> tileSizes) {
156156
using namespace tc::polyhedral::detail;
157157

158-
auto bandNode = band->elemAs<ScheduleTreeElemBand>();
158+
auto bandNode = band->as<ScheduleTreeBand>();
159159
TC_CHECK(bandNode->permutable_) << "cannot map non-permutable band to blocks";
160160

161161
auto nBlocksToMap = bandNode->nOuterCoincident();
@@ -235,7 +235,7 @@ bool MappedScop::detectReductions(detail::ScheduleTree* tree) {
235235
for (auto c : tree->children()) {
236236
found |= detectReductions(c);
237237
}
238-
auto band = tree->elemAs<detail::ScheduleTreeElemBand>();
238+
auto band = tree->as<detail::ScheduleTreeBand>();
239239
// Nested reductions are not currently supported.
240240
if (!band || found) {
241241
return found;
@@ -296,7 +296,7 @@ bool MappedScop::needReductionSeparation(const detail::ScheduleTree* st) {
296296
isl::multi_union_pw_aff MappedScop::reductionMapSchedule(
297297
const detail::ScheduleTree* st) {
298298
TC_CHECK(reductionBandUpdates_.count(st) == 1);
299-
auto reductionBand = st->elemAs<detail::ScheduleTreeElemBand>();
299+
auto reductionBand = st->as<detail::ScheduleTreeBand>();
300300
TC_CHECK(reductionBand);
301301

302302
auto nMember = reductionBand->nMember();
@@ -359,7 +359,7 @@ detail::ScheduleTree* MappedScop::separateReduction(detail::ScheduleTree* st) {
359359

360360
detail::ScheduleTree* MappedScop::mapThreadsBackward(
361361
detail::ScheduleTree* band) {
362-
auto bandNode = band->elemAs<detail::ScheduleTreeElemBand>();
362+
auto bandNode = band->as<detail::ScheduleTreeBand>();
363363
TC_CHECK(bandNode);
364364
auto nMember = bandNode->nMember();
365365
auto nToMap = std::min(nMember, numThreads.view.size());
@@ -376,7 +376,7 @@ detail::ScheduleTree* MappedScop::mapThreadsBackward(
376376
size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
377377
using namespace tc::polyhedral::detail;
378378

379-
auto bandNode = band->elemAs<ScheduleTreeElemBand>();
379+
auto bandNode = band->as<ScheduleTreeBand>();
380380
// Cannot map non-permutable bands.
381381
if (!bandNode->permutable_) {
382382
return 0;
@@ -417,7 +417,7 @@ size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
417417
reductionBandUpdates_.erase(band);
418418
}
419419
band = child;
420-
bandNode = band->elemAs<ScheduleTreeElemBand>();
420+
bandNode = band->as<ScheduleTreeBand>();
421421
}
422422

423423
if (nMappedThreads < bandNode->nMember()) {
@@ -450,11 +450,11 @@ bool hasOuterSequentialMember(
450450
auto ancestors = st->ancestors(root);
451451
std::reverse(ancestors.begin(), ancestors.end());
452452
for (auto a : ancestors) {
453-
auto band = a->elemAs<detail::ScheduleTreeElemBand>();
453+
auto band = a->as<detail::ScheduleTreeBand>();
454454
if (band && band->nMember() > band->nOuterCoincident()) {
455455
return true;
456456
}
457-
if (a->elemAs<detail::ScheduleTreeElemSequence>()) {
457+
if (a->as<detail::ScheduleTreeSequence>()) {
458458
return false;
459459
}
460460
}
@@ -542,7 +542,7 @@ Scop::SyncLevel MappedScop::findBestSync(
542542

543543
TC_CHECK_LE(1u, scop_->scheduleRoot()->children().size());
544544
auto contextSt = scop_->scheduleRoot()->children()[0];
545-
auto contextElem = contextSt->elemAs<detail::ScheduleTreeElemContext>();
545+
auto contextElem = contextSt->as<detail::ScheduleTreeContext>();
546546
TC_CHECK(nullptr != contextElem);
547547
dependences = dependences.intersect_params(contextElem->context_);
548548

@@ -705,7 +705,7 @@ std::vector<std::pair<int, int>> MappedScop::findBestSyncConfigInSeq(
705705
}
706706

707707
void MappedScop::insertBestSyncInSeq(detail::ScheduleTree* seq) {
708-
TC_CHECK(seq->elemAs<detail::ScheduleTreeElemSequence>());
708+
TC_CHECK(seq->as<detail::ScheduleTreeSequence>());
709709

710710
auto children = seq->children();
711711
auto nChildren = children.size();
@@ -779,7 +779,7 @@ size_t MappedScop::mapInnermostBandsToThreads(detail::ScheduleTree* st) {
779779
}
780780
auto n = nChildren > 0 ? *std::max_element(nInner.begin(), nInner.end()) : 0;
781781
if (nChildren > 1) {
782-
auto needSync = st->elemAs<detail::ScheduleTreeElemSequence>() && n > 0;
782+
auto needSync = st->as<detail::ScheduleTreeSequence>() && n > 0;
783783
if (n > 0) {
784784
for (size_t i = 0; i < nChildren; ++i) {
785785
fixThreadsBelow(*this, children[i], nInner[i]);
@@ -790,7 +790,7 @@ size_t MappedScop::mapInnermostBandsToThreads(detail::ScheduleTree* st) {
790790
}
791791
}
792792

793-
if (auto band = st->elemAs<detail::ScheduleTreeElemBand>()) {
793+
if (auto band = st->as<detail::ScheduleTreeBand>()) {
794794
if (n == 0) {
795795
// If children were not mapped to threads, the current band can be mapped.
796796
// First, map the coincidence and reduction dimension to threads.
@@ -1040,7 +1040,7 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
10401040
sharedMemorySize -= reductionMemoryRequirement;
10411041
}
10421042

1043-
auto band = outerBand->elemAs<ScheduleTreeElemBand>();
1043+
auto band = outerBand->as<ScheduleTreeBand>();
10441044
LOG_IF(WARNING, FLAGS_debug_tc_mapper && band->nMember() == 0)
10451045
<< "Aborting memory promotion because outer band has 0 members (NYI)";
10461046
if (band->nMember() > 0 && sharedMemorySize > 0) {

tc/core/polyhedral/cuda/memory_promotion_heuristic.cc

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ void mapCopiesToThreads(MappedScop& mscop, bool unroll) {
6464
}
6565

6666
auto bandNode = node->child({0});
67-
auto band = bandNode->elemAs<ScheduleTreeElemBand>();
67+
auto band = bandNode->as<ScheduleTreeBand>();
6868
if (!band) {
6969
throw promotion::PromotionLogicError("no copy band");
7070
}
@@ -143,7 +143,7 @@ std::vector<T> collectBranchMarkers(T root, T node) {
143143
isl::union_map fullSchedule(const detail::ScheduleTree* root) {
144144
using namespace tc::polyhedral::detail;
145145

146-
if (!root->elemAs<ScheduleTreeElemDomain>()) {
146+
if (!root->as<ScheduleTreeDomain>()) {
147147
throw promotion::PromotionLogicError("expected root to be a domain node");
148148
}
149149

@@ -157,18 +157,18 @@ isl::union_map fullSchedule(const detail::ScheduleTree* root) {
157157
// are innermost, the partial schedule can no longer be affected by deeper
158158
// nodes and hence is full.
159159
auto schedule = isl::union_map::empty(
160-
root->elemAs<ScheduleTreeElemDomain>()->domain_.get_space());
160+
root->as<ScheduleTreeDomain>()->domain_.get_space());
161161
for (auto node : leaves) {
162-
auto domain = root->elemAs<ScheduleTreeElemDomain>()->domain_;
162+
auto domain = root->as<ScheduleTreeDomain>()->domain_;
163163
auto prefixMupa = prefixScheduleMupa(root, node);
164-
if (auto band = node->elemAs<ScheduleTreeElemBand>()) {
164+
if (auto band = node->as<ScheduleTreeBand>()) {
165165
prefixMupa = prefixMupa.flat_range_product(band->mupa_);
166166
}
167167

168168
auto pathToRoot = node->ancestors(root);
169169
pathToRoot.push_back(node);
170170
for (auto n : pathToRoot) {
171-
if (auto filterNode = n->elemAs<ScheduleTreeElemFilter>()) {
171+
if (auto filterNode = n->as<ScheduleTreeFilter>()) {
172172
domain = domain.intersect(filterNode->filter_);
173173
}
174174
}
@@ -308,7 +308,7 @@ isl::union_set collectMappingsTo(const Scop& scop) {
308308
mappingFilters = functional::Filter(isMappingTo<MappingType>, mappingFilters);
309309
auto mapping = isl::union_set::empty(domain.get_space());
310310
for (auto mf : mappingFilters) {
311-
auto filterNode = mf->elemAs<detail::ScheduleTreeElemMapping>();
311+
auto filterNode = mf->as<detail::ScheduleTreeMapping>();
312312
auto filter = filterNode->filter_.intersect(activeDomainPoints(root, mf));
313313
mapping = mapping.unite(filterNode->filter_);
314314
}
@@ -356,7 +356,7 @@ bool accessSubscriptsAreUnrolledLoops(
356356
auto leaves = functional::Filter(
357357
[](const ScheduleTree* tree) { return tree->numChildren() == 0; }, nodes);
358358

359-
auto domainNode = root->elemAs<detail::ScheduleTreeElemDomain>();
359+
auto domainNode = root->as<detail::ScheduleTreeDomain>();
360360
TC_CHECK(domainNode);
361361
auto domain = domainNode->domain_;
362362

@@ -368,7 +368,7 @@ bool accessSubscriptsAreUnrolledLoops(
368368

369369
auto unrolledDims = isl::union_pw_aff_list(leaf->ctx_, 1);
370370
for (auto node : ancestors) {
371-
auto band = node->elemAs<detail::ScheduleTreeElemBand>();
371+
auto band = node->as<detail::ScheduleTreeBand>();
372372
if (!band) {
373373
continue;
374374
}
@@ -455,7 +455,7 @@ std::vector<detail::ScheduleTree*> bandsContainingScheduleDepth(
455455
ScheduleTree::collectDFSPreorder(root, detail::ScheduleTreeType::Band);
456456
std::function<bool(ScheduleTree * st)> containsDepth = [&](ScheduleTree* st) {
457457
auto depthBefore = st->scheduleDepth(root);
458-
auto band = st->elemAs<ScheduleTreeElemBand>();
458+
auto band = st->as<ScheduleTreeBand>();
459459
auto depthAfter = depthBefore + band->nMember();
460460
return depthBefore < depth && depthAfter >= depth;
461461
};
@@ -474,7 +474,7 @@ std::vector<detail::ScheduleTree*> bandsSplitAfterDepth(
474474

475475
std::function<ScheduleTree*(ScheduleTree*)> splitAtDepth =
476476
[&](ScheduleTree* st) {
477-
auto nMember = st->elemAs<ScheduleTreeElemBand>()->nMember();
477+
auto nMember = st->as<ScheduleTreeBand>()->nMember();
478478
auto scheduleDepth = st->scheduleDepth(root);
479479
auto depthAfter = scheduleDepth + nMember;
480480
return depthAfter == depth ? st
@@ -641,16 +641,15 @@ void promoteGreedilyAtDepth(
641641
void promoteToRegistersBelow(MappedScop& mscop, detail::ScheduleTree* scope) {
642642
// Cannot promote below a sequence or a set node. Promotion may insert an
643643
// extension node, but sequence/set must be followed by filters.
644-
if (scope->elemAs<detail::ScheduleTreeElemSequence>() ||
645-
scope->elemAs<detail::ScheduleTreeElemSet>()) {
644+
if (scope->as<detail::ScheduleTreeSequence>() ||
645+
scope->as<detail::ScheduleTreeSet>()) {
646646
throw promotion::IncorrectScope("cannot promote under a sequence/set node");
647647
}
648648
// Cannot promote between a thread-mapped band and a thread-specific marker
649649
// node because the latter is used to identify thread-mapped bands as
650650
// immediate ancestors.
651651
if (scope->numChildren() == 1 &&
652-
scope->child({0})
653-
->elemAs<detail::ScheduleTreeElemThreadSpecificMarker>()) {
652+
scope->child({0})->as<detail::ScheduleTreeThreadSpecificMarker>()) {
654653
throw promotion::IncorrectScope(
655654
"cannot promote above a thread-specific marker node");
656655
}
@@ -757,7 +756,7 @@ void promoteToRegistersAtDepth(MappedScop& mscop, size_t depth) {
757756
auto bands = bandsContainingScheduleDepth(root, depth);
758757
bands = functional::Filter(
759758
[root, depth](ScheduleTree* tree) {
760-
auto band = tree->elemAs<ScheduleTreeElemBand>();
759+
auto band = tree->as<ScheduleTreeBand>();
761760
return !isThreadMappedBand(tree) ||
762761
tree->scheduleDepth(root) + band->nMember() == depth;
763762
},

tc/core/polyhedral/cuda/tighten_launch_bounds.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ size_t maxValue(const Scop& scop, const MappingIdType& id) {
7777
auto filters = root->collect(root, ScheduleTreeType::Mapping);
7878
filters = functional::Filter(isMappingTo<MappingIdType>, filters);
7979
for (auto p : filters) {
80-
auto mappingNode = p->elemAs<ScheduleTreeElemMapping>();
80+
auto mappingNode = p->as<ScheduleTreeMapping>();
8181
auto active = activeDomainPoints(root, p).intersect_params(params);
8282
active = active.intersect(mappingNode->filter_);
8383
auto range = rangeOfMappingParameter(active.params(), id);

tc/core/polyhedral/memory_promotion.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ isl::multi_aff dropDummyTensorDimensions(
453453
return isl::multi_aff(space, list);
454454
}
455455

456-
inline void unrollAllMembers(detail::ScheduleTreeElemBand* band) {
456+
inline void unrollAllMembers(detail::ScheduleTreeBand* band) {
457457
band->unroll_ = std::vector<bool>(band->nMember(), true);
458458
}
459459

@@ -494,8 +494,8 @@ ScheduleTree* insertCopiesUnder(
494494
auto writeBandNode = ScheduleTree::makeBand(writeSchedule);
495495

496496
if (unrollAllCopies) {
497-
unrollAllMembers(readBandNode->elemAs<detail::ScheduleTreeElemBand>());
498-
unrollAllMembers(writeBandNode->elemAs<detail::ScheduleTreeElemBand>());
497+
unrollAllMembers(readBandNode->as<detail::ScheduleTreeBand>());
498+
unrollAllMembers(writeBandNode->as<detail::ScheduleTreeBand>());
499499
}
500500

501501
auto extension =

tc/core/polyhedral/reduction_matcher.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ namespace tc {
2626
namespace polyhedral {
2727

2828
using detail::ScheduleTree;
29-
using detail::ScheduleTreeElemBand;
30-
using detail::ScheduleTreeElemFilter;
29+
using detail::ScheduleTreeBand;
30+
using detail::ScheduleTreeFilter;
3131

3232
namespace {
3333

0 commit comments

Comments
 (0)