Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 58cd495

Browse files
committed
rename ScheduleTreeElem* to ScheduleTree*
The preceding commits removed the notion of a schedule tree element. There is now no point why the specific node types should be called ScheduleTreeElem while the generic class is called ScheduleTree. Drop the "Elem" part from their names. Alternatively, we could use ScheduleTreeNode everywhere, but it would have required us to change around 400 occurrences of ScheduleTree in the codebase.
1 parent 18e8728 commit 58cd495

21 files changed

+253
-273
lines changed

tc/core/polyhedral/codegen.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ isl::id_list Codegen::makeLoopIterators(
3131
detail::ScheduleTree::collect(root, detail::ScheduleTreeType::Band);
3232
size_t n = 0;
3333
for (auto const& node : bands) {
34-
auto bandElem = node->as<detail::ScheduleTreeElemBand>();
34+
auto bandElem = node->as<detail::ScheduleTreeBand>();
3535
auto depth = node->scheduleDepth(root) + bandElem->nMember();
3636
if (depth > n) {
3737
n = depth;

tc/core/polyhedral/cuda/codegen.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -785,8 +785,8 @@ string emitCudaKernel(
785785
const std::string& specializedName,
786786
const MappedScop& mscop) {
787787
// Expecting a schedule with domain root and context first child.
788-
TC_CHECK(mscop.schedule()->as<detail::ScheduleTreeElemDomain>());
789-
TC_CHECK(mscop.schedule()->child({0})->as<detail::ScheduleTreeElemContext>());
788+
TC_CHECK(mscop.schedule()->as<detail::ScheduleTreeDomain>());
789+
TC_CHECK(mscop.schedule()->child({0})->as<detail::ScheduleTreeContext>());
790790
const auto& scop = mscop.scop();
791791

792792
// Make a map of the specialized scalar parameter values

tc/core/polyhedral/cuda/mapped_scop.cc

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ void validate(const detail::ScheduleTree* root) {
7272
root);
7373
}
7474

75-
bool anyNonCoincidentMember(const detail::ScheduleTreeElemBand* band) {
75+
bool anyNonCoincidentMember(const detail::ScheduleTreeBand* band) {
7676
return band->nOuterCoincident() < band->nMember();
7777
}
7878

@@ -140,7 +140,7 @@ detail::ScheduleTree* MappedScop::map(
140140
detail::ScheduleTree* MappedScop::mapBlocksForward(
141141
detail::ScheduleTree* band,
142142
size_t nToMap) {
143-
auto bandNode = band->as<detail::ScheduleTreeElemBand>();
143+
auto bandNode = band->as<detail::ScheduleTreeBand>();
144144
TC_CHECK(bandNode) << "expected a band, got " << *band;
145145

146146
auto list = bandNode->mupa_.get_union_pw_aff_list();
@@ -155,7 +155,7 @@ void MappedScop::mapToBlocksAndScaleBand(
155155
std::vector<size_t> tileSizes) {
156156
using namespace tc::polyhedral::detail;
157157

158-
auto bandNode = band->as<ScheduleTreeElemBand>();
158+
auto bandNode = band->as<ScheduleTreeBand>();
159159
TC_CHECK(bandNode->permutable_) << "cannot map non-permutable band to blocks";
160160

161161
auto nBlocksToMap = bandNode->nOuterCoincident();
@@ -235,7 +235,7 @@ bool MappedScop::detectReductions(detail::ScheduleTree* tree) {
235235
for (auto c : tree->children()) {
236236
found |= detectReductions(c);
237237
}
238-
auto band = tree->as<detail::ScheduleTreeElemBand>();
238+
auto band = tree->as<detail::ScheduleTreeBand>();
239239
// Nested reductions are not currently supported.
240240
if (!band || found) {
241241
return found;
@@ -296,7 +296,7 @@ bool MappedScop::needReductionSeparation(const detail::ScheduleTree* st) {
296296
isl::multi_union_pw_aff MappedScop::reductionMapSchedule(
297297
const detail::ScheduleTree* st) {
298298
TC_CHECK(reductionBandUpdates_.count(st) == 1);
299-
auto reductionBand = st->as<detail::ScheduleTreeElemBand>();
299+
auto reductionBand = st->as<detail::ScheduleTreeBand>();
300300
TC_CHECK(reductionBand);
301301

302302
auto nMember = reductionBand->nMember();
@@ -359,7 +359,7 @@ detail::ScheduleTree* MappedScop::separateReduction(detail::ScheduleTree* st) {
359359

360360
detail::ScheduleTree* MappedScop::mapThreadsBackward(
361361
detail::ScheduleTree* band) {
362-
auto bandNode = band->as<detail::ScheduleTreeElemBand>();
362+
auto bandNode = band->as<detail::ScheduleTreeBand>();
363363
TC_CHECK(bandNode);
364364
auto nMember = bandNode->nMember();
365365
auto nToMap = std::min(nMember, numThreads.view.size());
@@ -376,7 +376,7 @@ detail::ScheduleTree* MappedScop::mapThreadsBackward(
376376
size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
377377
using namespace tc::polyhedral::detail;
378378

379-
auto bandNode = band->as<ScheduleTreeElemBand>();
379+
auto bandNode = band->as<ScheduleTreeBand>();
380380
// Cannot map non-permutable bands.
381381
if (!bandNode->permutable_) {
382382
return 0;
@@ -417,7 +417,7 @@ size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
417417
reductionBandUpdates_.erase(band);
418418
}
419419
band = child;
420-
bandNode = band->as<ScheduleTreeElemBand>();
420+
bandNode = band->as<ScheduleTreeBand>();
421421
}
422422

423423
if (nMappedThreads < bandNode->nMember()) {
@@ -450,11 +450,11 @@ bool hasOuterSequentialMember(
450450
auto ancestors = st->ancestors(root);
451451
std::reverse(ancestors.begin(), ancestors.end());
452452
for (auto a : ancestors) {
453-
auto band = a->as<detail::ScheduleTreeElemBand>();
453+
auto band = a->as<detail::ScheduleTreeBand>();
454454
if (band && band->nMember() > band->nOuterCoincident()) {
455455
return true;
456456
}
457-
if (a->as<detail::ScheduleTreeElemSequence>()) {
457+
if (a->as<detail::ScheduleTreeSequence>()) {
458458
return false;
459459
}
460460
}
@@ -542,7 +542,7 @@ Scop::SyncLevel MappedScop::findBestSync(
542542

543543
TC_CHECK_LE(1u, scop_->scheduleRoot()->children().size());
544544
auto contextSt = scop_->scheduleRoot()->children()[0];
545-
auto contextElem = contextSt->as<detail::ScheduleTreeElemContext>();
545+
auto contextElem = contextSt->as<detail::ScheduleTreeContext>();
546546
TC_CHECK(nullptr != contextElem);
547547
dependences = dependences.intersect_params(contextElem->context_);
548548

@@ -705,7 +705,7 @@ std::vector<std::pair<int, int>> MappedScop::findBestSyncConfigInSeq(
705705
}
706706

707707
void MappedScop::insertBestSyncInSeq(detail::ScheduleTree* seq) {
708-
TC_CHECK(seq->as<detail::ScheduleTreeElemSequence>());
708+
TC_CHECK(seq->as<detail::ScheduleTreeSequence>());
709709

710710
auto children = seq->children();
711711
auto nChildren = children.size();
@@ -779,7 +779,7 @@ size_t MappedScop::mapInnermostBandsToThreads(detail::ScheduleTree* st) {
779779
}
780780
auto n = nChildren > 0 ? *std::max_element(nInner.begin(), nInner.end()) : 0;
781781
if (nChildren > 1) {
782-
auto needSync = st->as<detail::ScheduleTreeElemSequence>() && n > 0;
782+
auto needSync = st->as<detail::ScheduleTreeSequence>() && n > 0;
783783
if (n > 0) {
784784
for (size_t i = 0; i < nChildren; ++i) {
785785
fixThreadsBelow(*this, children[i], nInner[i]);
@@ -790,7 +790,7 @@ size_t MappedScop::mapInnermostBandsToThreads(detail::ScheduleTree* st) {
790790
}
791791
}
792792

793-
if (auto band = st->as<detail::ScheduleTreeElemBand>()) {
793+
if (auto band = st->as<detail::ScheduleTreeBand>()) {
794794
if (n == 0) {
795795
// If children were not mapped to threads, the current band can be mapped.
796796
// First, map the coincidence and reduction dimension to threads.
@@ -1040,7 +1040,7 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
10401040
sharedMemorySize -= reductionMemoryRequirement;
10411041
}
10421042

1043-
auto band = outerBand->as<ScheduleTreeElemBand>();
1043+
auto band = outerBand->as<ScheduleTreeBand>();
10441044
LOG_IF(WARNING, FLAGS_debug_tc_mapper && band->nMember() == 0)
10451045
<< "Aborting memory promotion because outer band has 0 members (NYI)";
10461046
if (band->nMember() > 0 && sharedMemorySize > 0) {

tc/core/polyhedral/cuda/memory_promotion_heuristic.cc

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ void mapCopiesToThreads(MappedScop& mscop, bool unroll) {
6464
}
6565

6666
auto bandNode = node->child({0});
67-
auto band = bandNode->as<ScheduleTreeElemBand>();
67+
auto band = bandNode->as<ScheduleTreeBand>();
6868
if (!band) {
6969
throw promotion::PromotionLogicError("no copy band");
7070
}
@@ -143,7 +143,7 @@ std::vector<T> collectBranchMarkers(T root, T node) {
143143
isl::union_map fullSchedule(const detail::ScheduleTree* root) {
144144
using namespace tc::polyhedral::detail;
145145

146-
if (!root->as<ScheduleTreeElemDomain>()) {
146+
if (!root->as<ScheduleTreeDomain>()) {
147147
throw promotion::PromotionLogicError("expected root to be a domain node");
148148
}
149149

@@ -157,18 +157,18 @@ isl::union_map fullSchedule(const detail::ScheduleTree* root) {
157157
// are innermost, the partial schedule can no longer be affected by deeper
158158
// nodes and hence is full.
159159
auto schedule = isl::union_map::empty(
160-
root->as<ScheduleTreeElemDomain>()->domain_.get_space());
160+
root->as<ScheduleTreeDomain>()->domain_.get_space());
161161
for (auto node : leaves) {
162-
auto domain = root->as<ScheduleTreeElemDomain>()->domain_;
162+
auto domain = root->as<ScheduleTreeDomain>()->domain_;
163163
auto prefixMupa = prefixScheduleMupa(root, node);
164-
if (auto band = node->as<ScheduleTreeElemBand>()) {
164+
if (auto band = node->as<ScheduleTreeBand>()) {
165165
prefixMupa = prefixMupa.flat_range_product(band->mupa_);
166166
}
167167

168168
auto pathToRoot = node->ancestors(root);
169169
pathToRoot.push_back(node);
170170
for (auto n : pathToRoot) {
171-
if (auto filterNode = n->as<ScheduleTreeElemFilter>()) {
171+
if (auto filterNode = n->as<ScheduleTreeFilter>()) {
172172
domain = domain.intersect(filterNode->filter_);
173173
}
174174
}
@@ -308,7 +308,7 @@ isl::union_set collectMappingsTo(const Scop& scop) {
308308
mappingFilters = functional::Filter(isMappingTo<MappingType>, mappingFilters);
309309
auto mapping = isl::union_set::empty(domain.get_space());
310310
for (auto mf : mappingFilters) {
311-
auto filterNode = mf->as<detail::ScheduleTreeElemMapping>();
311+
auto filterNode = mf->as<detail::ScheduleTreeMapping>();
312312
auto filter = filterNode->filter_.intersect(activeDomainPoints(root, mf));
313313
mapping = mapping.unite(filterNode->filter_);
314314
}
@@ -356,7 +356,7 @@ bool accessSubscriptsAreUnrolledLoops(
356356
auto leaves = functional::Filter(
357357
[](const ScheduleTree* tree) { return tree->numChildren() == 0; }, nodes);
358358

359-
auto domainNode = root->as<detail::ScheduleTreeElemDomain>();
359+
auto domainNode = root->as<detail::ScheduleTreeDomain>();
360360
TC_CHECK(domainNode);
361361
auto domain = domainNode->domain_;
362362

@@ -368,7 +368,7 @@ bool accessSubscriptsAreUnrolledLoops(
368368

369369
auto unrolledDims = isl::union_pw_aff_list(leaf->ctx_, 1);
370370
for (auto node : ancestors) {
371-
auto band = node->as<detail::ScheduleTreeElemBand>();
371+
auto band = node->as<detail::ScheduleTreeBand>();
372372
if (!band) {
373373
continue;
374374
}
@@ -455,7 +455,7 @@ std::vector<detail::ScheduleTree*> bandsContainingScheduleDepth(
455455
ScheduleTree::collectDFSPreorder(root, detail::ScheduleTreeType::Band);
456456
std::function<bool(ScheduleTree * st)> containsDepth = [&](ScheduleTree* st) {
457457
auto depthBefore = st->scheduleDepth(root);
458-
auto band = st->as<ScheduleTreeElemBand>();
458+
auto band = st->as<ScheduleTreeBand>();
459459
auto depthAfter = depthBefore + band->nMember();
460460
return depthBefore < depth && depthAfter >= depth;
461461
};
@@ -474,7 +474,7 @@ std::vector<detail::ScheduleTree*> bandsSplitAfterDepth(
474474

475475
std::function<ScheduleTree*(ScheduleTree*)> splitAtDepth =
476476
[&](ScheduleTree* st) {
477-
auto nMember = st->as<ScheduleTreeElemBand>()->nMember();
477+
auto nMember = st->as<ScheduleTreeBand>()->nMember();
478478
auto scheduleDepth = st->scheduleDepth(root);
479479
auto depthAfter = scheduleDepth + nMember;
480480
return depthAfter == depth ? st
@@ -641,15 +641,15 @@ void promoteGreedilyAtDepth(
641641
void promoteToRegistersBelow(MappedScop& mscop, detail::ScheduleTree* scope) {
642642
// Cannot promote below a sequence or a set node. Promotion may insert an
643643
// extension node, but sequence/set must be followed by filters.
644-
if (scope->as<detail::ScheduleTreeElemSequence>() ||
645-
scope->as<detail::ScheduleTreeElemSet>()) {
644+
if (scope->as<detail::ScheduleTreeSequence>() ||
645+
scope->as<detail::ScheduleTreeSet>()) {
646646
throw promotion::IncorrectScope("cannot promote under a sequence/set node");
647647
}
648648
// Cannot promote between a thread-mapped band and a thread-specific marker
649649
// node because the latter is used to identify thread-mapped bands as
650650
// immediate ancestors.
651651
if (scope->numChildren() == 1 &&
652-
scope->child({0})->as<detail::ScheduleTreeElemThreadSpecificMarker>()) {
652+
scope->child({0})->as<detail::ScheduleTreeThreadSpecificMarker>()) {
653653
throw promotion::IncorrectScope(
654654
"cannot promote above a thread-specific marker node");
655655
}
@@ -756,7 +756,7 @@ void promoteToRegistersAtDepth(MappedScop& mscop, size_t depth) {
756756
auto bands = bandsContainingScheduleDepth(root, depth);
757757
bands = functional::Filter(
758758
[root, depth](ScheduleTree* tree) {
759-
auto band = tree->as<ScheduleTreeElemBand>();
759+
auto band = tree->as<ScheduleTreeBand>();
760760
return !isThreadMappedBand(tree) ||
761761
tree->scheduleDepth(root) + band->nMember() == depth;
762762
},

tc/core/polyhedral/cuda/tighten_launch_bounds.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ size_t maxValue(const Scop& scop, const MappingIdType& id) {
7777
auto filters = root->collect(root, ScheduleTreeType::Mapping);
7878
filters = functional::Filter(isMappingTo<MappingIdType>, filters);
7979
for (auto p : filters) {
80-
auto mappingNode = p->as<ScheduleTreeElemMapping>();
80+
auto mappingNode = p->as<ScheduleTreeMapping>();
8181
auto active = activeDomainPoints(root, p).intersect_params(params);
8282
active = active.intersect(mappingNode->filter_);
8383
auto range = rangeOfMappingParameter(active.params(), id);

tc/core/polyhedral/memory_promotion.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ isl::multi_aff dropDummyTensorDimensions(
453453
return isl::multi_aff(space, list);
454454
}
455455

456-
inline void unrollAllMembers(detail::ScheduleTreeElemBand* band) {
456+
inline void unrollAllMembers(detail::ScheduleTreeBand* band) {
457457
band->unroll_ = std::vector<bool>(band->nMember(), true);
458458
}
459459

@@ -494,8 +494,8 @@ ScheduleTree* insertCopiesUnder(
494494
auto writeBandNode = ScheduleTree::makeBand(writeSchedule);
495495

496496
if (unrollAllCopies) {
497-
unrollAllMembers(readBandNode->as<detail::ScheduleTreeElemBand>());
498-
unrollAllMembers(writeBandNode->as<detail::ScheduleTreeElemBand>());
497+
unrollAllMembers(readBandNode->as<detail::ScheduleTreeBand>());
498+
unrollAllMembers(writeBandNode->as<detail::ScheduleTreeBand>());
499499
}
500500

501501
auto extension =

tc/core/polyhedral/reduction_matcher.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ namespace tc {
2626
namespace polyhedral {
2727

2828
using detail::ScheduleTree;
29-
using detail::ScheduleTreeElemBand;
30-
using detail::ScheduleTreeElemFilter;
29+
using detail::ScheduleTreeBand;
30+
using detail::ScheduleTreeFilter;
3131

3232
namespace {
3333

0 commit comments

Comments
 (0)