Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 18e8728

Browse files
committed
ScheduleTree: rename "elemAs" to "as"
The preceding commits removed the notion of schedule tree element (an object contained in the tree nodes), making specific tree node types inherit from a generic one. Rename the function to make it clear to the caller that it casts the object it is being called on.
1 parent 6b55c2f commit 18e8728

16 files changed

+121
-129
lines changed

tc/core/polyhedral/codegen.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ isl::id_list Codegen::makeLoopIterators(
3131
detail::ScheduleTree::collect(root, detail::ScheduleTreeType::Band);
3232
size_t n = 0;
3333
for (auto const& node : bands) {
34-
auto bandElem = node->elemAs<detail::ScheduleTreeElemBand>();
34+
auto bandElem = node->as<detail::ScheduleTreeElemBand>();
3535
auto depth = node->scheduleDepth(root) + bandElem->nMember();
3636
if (depth > n) {
3737
n = depth;

tc/core/polyhedral/cuda/codegen.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -785,9 +785,8 @@ string emitCudaKernel(
785785
const std::string& specializedName,
786786
const MappedScop& mscop) {
787787
// Expecting a schedule with domain root and context first child.
788-
TC_CHECK(mscop.schedule()->elemAs<detail::ScheduleTreeElemDomain>());
789-
TC_CHECK(
790-
mscop.schedule()->child({0})->elemAs<detail::ScheduleTreeElemContext>());
788+
TC_CHECK(mscop.schedule()->as<detail::ScheduleTreeElemDomain>());
789+
TC_CHECK(mscop.schedule()->child({0})->as<detail::ScheduleTreeElemContext>());
791790
const auto& scop = mscop.scop();
792791

793792
// Make a map of the specialized scalar parameter values

tc/core/polyhedral/cuda/mapped_scop.cc

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ detail::ScheduleTree* MappedScop::map(
140140
detail::ScheduleTree* MappedScop::mapBlocksForward(
141141
detail::ScheduleTree* band,
142142
size_t nToMap) {
143-
auto bandNode = band->elemAs<detail::ScheduleTreeElemBand>();
143+
auto bandNode = band->as<detail::ScheduleTreeElemBand>();
144144
TC_CHECK(bandNode) << "expected a band, got " << *band;
145145

146146
auto list = bandNode->mupa_.get_union_pw_aff_list();
@@ -155,7 +155,7 @@ void MappedScop::mapToBlocksAndScaleBand(
155155
std::vector<size_t> tileSizes) {
156156
using namespace tc::polyhedral::detail;
157157

158-
auto bandNode = band->elemAs<ScheduleTreeElemBand>();
158+
auto bandNode = band->as<ScheduleTreeElemBand>();
159159
TC_CHECK(bandNode->permutable_) << "cannot map non-permutable band to blocks";
160160

161161
auto nBlocksToMap = bandNode->nOuterCoincident();
@@ -235,7 +235,7 @@ bool MappedScop::detectReductions(detail::ScheduleTree* tree) {
235235
for (auto c : tree->children()) {
236236
found |= detectReductions(c);
237237
}
238-
auto band = tree->elemAs<detail::ScheduleTreeElemBand>();
238+
auto band = tree->as<detail::ScheduleTreeElemBand>();
239239
// Nested reductions are not currently supported.
240240
if (!band || found) {
241241
return found;
@@ -296,7 +296,7 @@ bool MappedScop::needReductionSeparation(const detail::ScheduleTree* st) {
296296
isl::multi_union_pw_aff MappedScop::reductionMapSchedule(
297297
const detail::ScheduleTree* st) {
298298
TC_CHECK(reductionBandUpdates_.count(st) == 1);
299-
auto reductionBand = st->elemAs<detail::ScheduleTreeElemBand>();
299+
auto reductionBand = st->as<detail::ScheduleTreeElemBand>();
300300
TC_CHECK(reductionBand);
301301

302302
auto nMember = reductionBand->nMember();
@@ -359,7 +359,7 @@ detail::ScheduleTree* MappedScop::separateReduction(detail::ScheduleTree* st) {
359359

360360
detail::ScheduleTree* MappedScop::mapThreadsBackward(
361361
detail::ScheduleTree* band) {
362-
auto bandNode = band->elemAs<detail::ScheduleTreeElemBand>();
362+
auto bandNode = band->as<detail::ScheduleTreeElemBand>();
363363
TC_CHECK(bandNode);
364364
auto nMember = bandNode->nMember();
365365
auto nToMap = std::min(nMember, numThreads.view.size());
@@ -376,7 +376,7 @@ detail::ScheduleTree* MappedScop::mapThreadsBackward(
376376
size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
377377
using namespace tc::polyhedral::detail;
378378

379-
auto bandNode = band->elemAs<ScheduleTreeElemBand>();
379+
auto bandNode = band->as<ScheduleTreeElemBand>();
380380
// Cannot map non-permutable bands.
381381
if (!bandNode->permutable_) {
382382
return 0;
@@ -417,7 +417,7 @@ size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
417417
reductionBandUpdates_.erase(band);
418418
}
419419
band = child;
420-
bandNode = band->elemAs<ScheduleTreeElemBand>();
420+
bandNode = band->as<ScheduleTreeElemBand>();
421421
}
422422

423423
if (nMappedThreads < bandNode->nMember()) {
@@ -450,11 +450,11 @@ bool hasOuterSequentialMember(
450450
auto ancestors = st->ancestors(root);
451451
std::reverse(ancestors.begin(), ancestors.end());
452452
for (auto a : ancestors) {
453-
auto band = a->elemAs<detail::ScheduleTreeElemBand>();
453+
auto band = a->as<detail::ScheduleTreeElemBand>();
454454
if (band && band->nMember() > band->nOuterCoincident()) {
455455
return true;
456456
}
457-
if (a->elemAs<detail::ScheduleTreeElemSequence>()) {
457+
if (a->as<detail::ScheduleTreeElemSequence>()) {
458458
return false;
459459
}
460460
}
@@ -542,7 +542,7 @@ Scop::SyncLevel MappedScop::findBestSync(
542542

543543
TC_CHECK_LE(1u, scop_->scheduleRoot()->children().size());
544544
auto contextSt = scop_->scheduleRoot()->children()[0];
545-
auto contextElem = contextSt->elemAs<detail::ScheduleTreeElemContext>();
545+
auto contextElem = contextSt->as<detail::ScheduleTreeElemContext>();
546546
TC_CHECK(nullptr != contextElem);
547547
dependences = dependences.intersect_params(contextElem->context_);
548548

@@ -705,7 +705,7 @@ std::vector<std::pair<int, int>> MappedScop::findBestSyncConfigInSeq(
705705
}
706706

707707
void MappedScop::insertBestSyncInSeq(detail::ScheduleTree* seq) {
708-
TC_CHECK(seq->elemAs<detail::ScheduleTreeElemSequence>());
708+
TC_CHECK(seq->as<detail::ScheduleTreeElemSequence>());
709709

710710
auto children = seq->children();
711711
auto nChildren = children.size();
@@ -779,7 +779,7 @@ size_t MappedScop::mapInnermostBandsToThreads(detail::ScheduleTree* st) {
779779
}
780780
auto n = nChildren > 0 ? *std::max_element(nInner.begin(), nInner.end()) : 0;
781781
if (nChildren > 1) {
782-
auto needSync = st->elemAs<detail::ScheduleTreeElemSequence>() && n > 0;
782+
auto needSync = st->as<detail::ScheduleTreeElemSequence>() && n > 0;
783783
if (n > 0) {
784784
for (size_t i = 0; i < nChildren; ++i) {
785785
fixThreadsBelow(*this, children[i], nInner[i]);
@@ -790,7 +790,7 @@ size_t MappedScop::mapInnermostBandsToThreads(detail::ScheduleTree* st) {
790790
}
791791
}
792792

793-
if (auto band = st->elemAs<detail::ScheduleTreeElemBand>()) {
793+
if (auto band = st->as<detail::ScheduleTreeElemBand>()) {
794794
if (n == 0) {
795795
// If children were not mapped to threads, the current band can be mapped.
796796
// First, map the coincidence and reduction dimension to threads.
@@ -1040,7 +1040,7 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
10401040
sharedMemorySize -= reductionMemoryRequirement;
10411041
}
10421042

1043-
auto band = outerBand->elemAs<ScheduleTreeElemBand>();
1043+
auto band = outerBand->as<ScheduleTreeElemBand>();
10441044
LOG_IF(WARNING, FLAGS_debug_tc_mapper && band->nMember() == 0)
10451045
<< "Aborting memory promotion because outer band has 0 members (NYI)";
10461046
if (band->nMember() > 0 && sharedMemorySize > 0) {

tc/core/polyhedral/cuda/memory_promotion_heuristic.cc

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ void mapCopiesToThreads(MappedScop& mscop, bool unroll) {
6464
}
6565

6666
auto bandNode = node->child({0});
67-
auto band = bandNode->elemAs<ScheduleTreeElemBand>();
67+
auto band = bandNode->as<ScheduleTreeElemBand>();
6868
if (!band) {
6969
throw promotion::PromotionLogicError("no copy band");
7070
}
@@ -143,7 +143,7 @@ std::vector<T> collectBranchMarkers(T root, T node) {
143143
isl::union_map fullSchedule(const detail::ScheduleTree* root) {
144144
using namespace tc::polyhedral::detail;
145145

146-
if (!root->elemAs<ScheduleTreeElemDomain>()) {
146+
if (!root->as<ScheduleTreeElemDomain>()) {
147147
throw promotion::PromotionLogicError("expected root to be a domain node");
148148
}
149149

@@ -157,18 +157,18 @@ isl::union_map fullSchedule(const detail::ScheduleTree* root) {
157157
// are innermost, the partial schedule can no longer be affected by deeper
158158
// nodes and hence is full.
159159
auto schedule = isl::union_map::empty(
160-
root->elemAs<ScheduleTreeElemDomain>()->domain_.get_space());
160+
root->as<ScheduleTreeElemDomain>()->domain_.get_space());
161161
for (auto node : leaves) {
162-
auto domain = root->elemAs<ScheduleTreeElemDomain>()->domain_;
162+
auto domain = root->as<ScheduleTreeElemDomain>()->domain_;
163163
auto prefixMupa = prefixScheduleMupa(root, node);
164-
if (auto band = node->elemAs<ScheduleTreeElemBand>()) {
164+
if (auto band = node->as<ScheduleTreeElemBand>()) {
165165
prefixMupa = prefixMupa.flat_range_product(band->mupa_);
166166
}
167167

168168
auto pathToRoot = node->ancestors(root);
169169
pathToRoot.push_back(node);
170170
for (auto n : pathToRoot) {
171-
if (auto filterNode = n->elemAs<ScheduleTreeElemFilter>()) {
171+
if (auto filterNode = n->as<ScheduleTreeElemFilter>()) {
172172
domain = domain.intersect(filterNode->filter_);
173173
}
174174
}
@@ -308,7 +308,7 @@ isl::union_set collectMappingsTo(const Scop& scop) {
308308
mappingFilters = functional::Filter(isMappingTo<MappingType>, mappingFilters);
309309
auto mapping = isl::union_set::empty(domain.get_space());
310310
for (auto mf : mappingFilters) {
311-
auto filterNode = mf->elemAs<detail::ScheduleTreeElemMapping>();
311+
auto filterNode = mf->as<detail::ScheduleTreeElemMapping>();
312312
auto filter = filterNode->filter_.intersect(activeDomainPoints(root, mf));
313313
mapping = mapping.unite(filterNode->filter_);
314314
}
@@ -356,7 +356,7 @@ bool accessSubscriptsAreUnrolledLoops(
356356
auto leaves = functional::Filter(
357357
[](const ScheduleTree* tree) { return tree->numChildren() == 0; }, nodes);
358358

359-
auto domainNode = root->elemAs<detail::ScheduleTreeElemDomain>();
359+
auto domainNode = root->as<detail::ScheduleTreeElemDomain>();
360360
TC_CHECK(domainNode);
361361
auto domain = domainNode->domain_;
362362

@@ -368,7 +368,7 @@ bool accessSubscriptsAreUnrolledLoops(
368368

369369
auto unrolledDims = isl::union_pw_aff_list(leaf->ctx_, 1);
370370
for (auto node : ancestors) {
371-
auto band = node->elemAs<detail::ScheduleTreeElemBand>();
371+
auto band = node->as<detail::ScheduleTreeElemBand>();
372372
if (!band) {
373373
continue;
374374
}
@@ -455,7 +455,7 @@ std::vector<detail::ScheduleTree*> bandsContainingScheduleDepth(
455455
ScheduleTree::collectDFSPreorder(root, detail::ScheduleTreeType::Band);
456456
std::function<bool(ScheduleTree * st)> containsDepth = [&](ScheduleTree* st) {
457457
auto depthBefore = st->scheduleDepth(root);
458-
auto band = st->elemAs<ScheduleTreeElemBand>();
458+
auto band = st->as<ScheduleTreeElemBand>();
459459
auto depthAfter = depthBefore + band->nMember();
460460
return depthBefore < depth && depthAfter >= depth;
461461
};
@@ -474,7 +474,7 @@ std::vector<detail::ScheduleTree*> bandsSplitAfterDepth(
474474

475475
std::function<ScheduleTree*(ScheduleTree*)> splitAtDepth =
476476
[&](ScheduleTree* st) {
477-
auto nMember = st->elemAs<ScheduleTreeElemBand>()->nMember();
477+
auto nMember = st->as<ScheduleTreeElemBand>()->nMember();
478478
auto scheduleDepth = st->scheduleDepth(root);
479479
auto depthAfter = scheduleDepth + nMember;
480480
return depthAfter == depth ? st
@@ -641,16 +641,15 @@ void promoteGreedilyAtDepth(
641641
void promoteToRegistersBelow(MappedScop& mscop, detail::ScheduleTree* scope) {
642642
// Cannot promote below a sequence or a set node. Promotion may insert an
643643
// extension node, but sequence/set must be followed by filters.
644-
if (scope->elemAs<detail::ScheduleTreeElemSequence>() ||
645-
scope->elemAs<detail::ScheduleTreeElemSet>()) {
644+
if (scope->as<detail::ScheduleTreeElemSequence>() ||
645+
scope->as<detail::ScheduleTreeElemSet>()) {
646646
throw promotion::IncorrectScope("cannot promote under a sequence/set node");
647647
}
648648
// Cannot promote between a thread-mapped band and a thread-specific marker
649649
// node because the latter is used to identify thread-mapped bands as
650650
// immediate ancestors.
651651
if (scope->numChildren() == 1 &&
652-
scope->child({0})
653-
->elemAs<detail::ScheduleTreeElemThreadSpecificMarker>()) {
652+
scope->child({0})->as<detail::ScheduleTreeElemThreadSpecificMarker>()) {
654653
throw promotion::IncorrectScope(
655654
"cannot promote above a thread-specific marker node");
656655
}
@@ -757,7 +756,7 @@ void promoteToRegistersAtDepth(MappedScop& mscop, size_t depth) {
757756
auto bands = bandsContainingScheduleDepth(root, depth);
758757
bands = functional::Filter(
759758
[root, depth](ScheduleTree* tree) {
760-
auto band = tree->elemAs<ScheduleTreeElemBand>();
759+
auto band = tree->as<ScheduleTreeElemBand>();
761760
return !isThreadMappedBand(tree) ||
762761
tree->scheduleDepth(root) + band->nMember() == depth;
763762
},

tc/core/polyhedral/cuda/tighten_launch_bounds.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ size_t maxValue(const Scop& scop, const MappingIdType& id) {
7777
auto filters = root->collect(root, ScheduleTreeType::Mapping);
7878
filters = functional::Filter(isMappingTo<MappingIdType>, filters);
7979
for (auto p : filters) {
80-
auto mappingNode = p->elemAs<ScheduleTreeElemMapping>();
80+
auto mappingNode = p->as<ScheduleTreeElemMapping>();
8181
auto active = activeDomainPoints(root, p).intersect_params(params);
8282
active = active.intersect(mappingNode->filter_);
8383
auto range = rangeOfMappingParameter(active.params(), id);

tc/core/polyhedral/memory_promotion.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -494,8 +494,8 @@ ScheduleTree* insertCopiesUnder(
494494
auto writeBandNode = ScheduleTree::makeBand(writeSchedule);
495495

496496
if (unrollAllCopies) {
497-
unrollAllMembers(readBandNode->elemAs<detail::ScheduleTreeElemBand>());
498-
unrollAllMembers(writeBandNode->elemAs<detail::ScheduleTreeElemBand>());
497+
unrollAllMembers(readBandNode->as<detail::ScheduleTreeElemBand>());
498+
unrollAllMembers(writeBandNode->as<detail::ScheduleTreeElemBand>());
499499
}
500500

501501
auto extension =

tc/core/polyhedral/schedule_isl_conversion.cc

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,11 @@ isl::schedule_node insertBranch(
4848
const std::vector<size_t>& pos) {
4949
auto filters = isl::union_set_list(node.get_ctx(), st->numChildren());
5050
for (size_t i = 0; i < pos.size(); ++i) {
51-
auto filter = st->child({pos[i]})->elemAs<ScheduleTreeElemFilter>();
51+
auto filter = st->child({pos[i]})->as<ScheduleTreeElemFilter>();
5252
TC_CHECK(filter);
5353
filters = filters.add(filter->filter_);
5454
}
55-
if (st->elemAs<ScheduleTreeElemSet>()) {
55+
if (st->as<ScheduleTreeElemSet>()) {
5656
node = node.insert_set(filters);
5757
} else {
5858
node = node.insert_sequence(filters);
@@ -83,9 +83,9 @@ std::vector<size_t> findCorePositions(
8383
const ScheduleTree* st,
8484
isl::union_set domain) {
8585
std::vector<size_t> positions;
86-
TC_CHECK(st->elemAs<ScheduleTreeElemSequence>());
86+
TC_CHECK(st->as<ScheduleTreeElemSequence>());
8787
for (size_t i = 0; i < st->numChildren(); ++i) {
88-
auto filter = st->child({i})->elemAs<ScheduleTreeElemFilter>();
88+
auto filter = st->child({i})->as<ScheduleTreeElemFilter>();
8989
TC_CHECK(filter);
9090
if (!filter->filter_.intersect(domain).is_empty()) {
9191
positions.emplace_back(i);
@@ -103,7 +103,7 @@ std::vector<size_t> findCorePositions(
103103
isl::schedule_node graftFromFilterSubtree(
104104
const ScheduleTree* st,
105105
isl::union_map extension) {
106-
auto filter = st->elemAs<ScheduleTreeElemFilter>();
106+
auto filter = st->as<ScheduleTreeElemFilter>();
107107
TC_CHECK(filter);
108108
auto filterExtension = extension.intersect_range(filter->filter_);
109109
auto extensionNode = isl::schedule_node::from_extension(filterExtension);
@@ -131,7 +131,7 @@ isl::schedule_node insertExtension(
131131
TC_CHECK(!corePos.empty());
132132
node = insertBranch(node, child, corePos);
133133

134-
auto extension = st->elemAs<ScheduleTreeElemExtension>()->extension_;
134+
auto extension = st->as<ScheduleTreeElemExtension>()->extension_;
135135
for (size_t i = 0; i < corePos.size() - 1; ++i) {
136136
auto depth0 = node.get_tree_depth();
137137
node = node.child(i).child(0);
@@ -162,7 +162,7 @@ isl::schedule_node insertExtension(
162162
* some extra functionality in isl.
163163
*/
164164
isl::schedule_node insert(isl::schedule_node node, const ScheduleTree* st) {
165-
if (auto band = st->elemAs<ScheduleTreeElemBand>()) {
165+
if (auto band = st->as<ScheduleTreeElemBand>()) {
166166
node = node.insert_partial_schedule(band->mupa_);
167167
auto bandNode = node.as<isl::schedule_node_band>();
168168
bandNode = bandNode.set_permutable(band->permutable_);
@@ -176,19 +176,18 @@ isl::schedule_node insert(isl::schedule_node node, const ScheduleTree* st) {
176176
}
177177
}
178178
node = bandNode;
179-
} else if (auto context = st->elemAs<ScheduleTreeElemContext>()) {
179+
} else if (auto context = st->as<ScheduleTreeElemContext>()) {
180180
node = node.insert_context(context->context_);
181-
} else if (auto filter = st->elemAs<ScheduleTreeElemFilter>()) {
181+
} else if (auto filter = st->as<ScheduleTreeElemFilter>()) {
182182
node = node.insert_filter(filter->filter_);
183-
} else if (auto filter = st->elemAs<ScheduleTreeElemMapping>()) {
183+
} else if (auto filter = st->as<ScheduleTreeElemMapping>()) {
184184
node = node.insert_filter(filter->filter_);
185185
} else if (
186-
st->elemAs<ScheduleTreeElemSet>() ||
187-
st->elemAs<ScheduleTreeElemSequence>()) {
186+
st->as<ScheduleTreeElemSet>() || st->as<ScheduleTreeElemSequence>()) {
188187
return insertBranch(node, st);
189-
} else if (st->elemAs<ScheduleTreeElemExtension>()) {
188+
} else if (st->as<ScheduleTreeElemExtension>()) {
190189
return insertExtension(node, st);
191-
} else if (st->elemAs<ScheduleTreeElemThreadSpecificMarker>()) {
190+
} else if (st->as<ScheduleTreeElemThreadSpecificMarker>()) {
192191
return insertChild(node, st);
193192
} else {
194193
LOG(FATAL) << "NYI: insert type: " << *st;
@@ -228,7 +227,7 @@ isl::schedule_node extendChild(
228227
* then recursively add nodes corresponding to the descendants of "root".
229228
*/
230229
isl::schedule toIslSchedule(const ScheduleTree* root) {
231-
auto domain = root->elemAs<ScheduleTreeElemDomain>();
230+
auto domain = root->as<ScheduleTreeElemDomain>();
232231
TC_CHECK(domain) << "Root node should be domain node" << *root;
233232
auto node = isl::schedule_node::from_domain(domain->domain_);
234233
node = extendChild(node, root);

0 commit comments

Comments
 (0)