
Commit b596539

Scope relevant polyhedral/cuda classes under cuda namespace
Similarly to the previous commit, which scoped MappedScop under the cpu namespace, we do the same for the relevant CUDA classes.
1 parent 7ecfcc2 · commit b596539
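
The change is mechanical and follows one pattern throughout: CUDA-specific declarations gain an extra namespace cuda { ... } level inside tc::polyhedral, call sites add the cuda:: qualifier (as in cuda_tc_executor.cc below), and the .cc files shorten repeated tc::polyhedral::detail:: qualifications with using declarations. The following is a minimal, self-contained sketch of that pattern only; the MappedScop placeholder, its make() factory, and the main() driver are illustrative stand-ins, not the repository's actual interface.

#include <memory>

namespace tc {
namespace polyhedral {
namespace cuda { // new nesting level for CUDA-specific classes

// Placeholder standing in for the real MappedScop; the namespace
// structure is what this sketch demonstrates.
class MappedScop {
 public:
  static std::unique_ptr<MappedScop> make() {
    return std::unique_ptr<MappedScop>(new MappedScop());
  }
};

} // namespace cuda
} // namespace polyhedral
} // namespace tc

// Inside CUDA-specific translation units, repeated qualification is shortened
// with using declarations, mirroring those added in mapped_scop.cc and codegen.cc.
using tc::polyhedral::cuda::MappedScop;

int main() {
  // Call sites either spell out the extra cuda:: level explicitly...
  auto a = tc::polyhedral::cuda::MappedScop::make();
  // ...or rely on the using declaration above.
  auto b = MappedScop::make();
  return (a && b) ? 0 : 1;
}

Scoping the CPU and CUDA mappers in sibling namespaces lets each backend keep its own MappedScop (and related helpers) without the names colliding.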

12 files changed, +77 -60 lines changed

tc/core/cuda/cuda_tc_executor.cc

Lines changed: 1 addition & 1 deletion
@@ -86,7 +86,7 @@ CudaCompilationResult CudaBackend::compileWithTcMapper(
 
   // Now we can build stuff
   auto mappedScop =
-      polyhedral::MappedScop::makeWithOuterBlockInnerThreadStrategy(
+      polyhedral::cuda::MappedScop::makeWithOuterBlockInnerThreadStrategy(
          std::move(scop), options);
   LOG_IF(INFO, FLAGS_debug_tc_mapper) << "Mapped schedule:" << std::endl
                                       << *(mappedScop->schedule());

tc/core/polyhedral/cuda/codegen.cc

Lines changed: 12 additions & 7 deletions
@@ -31,11 +31,17 @@
 #include "tc/core/polyhedral/memory_promotion.h"
 #include "tc/core/polyhedral/schedule_isl_conversion.h"
 #include "tc/core/polyhedral/schedule_transforms.h"
+#include "tc/core/polyhedral/scop.h"
 
 using namespace std;
 
+using tc::polyhedral::detail::ScheduleTreeContext;
+using tc::polyhedral::detail::ScheduleTreeDomain;
+using tc::polyhedral::detail::toIslSchedule;
+
 namespace tc {
 namespace polyhedral {
+namespace cuda {
 
 namespace {
 
@@ -612,8 +618,7 @@ void emitHalideExpr(
 class EmitHalide : public Halide::Internal::IRPrinter {
   using Halide::Internal::IRPrinter::visit;
   void visit(const Halide::Internal::Variable* op) {
-    auto pwAff = tc::polyhedral::detail::makeAffFromMappedExpr(
-        Halide::Expr(op), context);
+    auto pwAff = detail::makeAffFromMappedExpr(Halide::Expr(op), context);
     auto expr = context.build().expr_from(pwAff);
     auto s = expr.to_C_str();
     if (!is_identifier_or_nonnegative_integer(expr)) {
@@ -627,8 +632,7 @@
     } else if (
         op->call_type == Halide::Internal::Call::CallType::Halide ||
         op->call_type == Halide::Internal::Call::CallType::Image) {
-      tc::polyhedral::detail::emitMappedTensorAccess(
-          op->name, op, op->args, context);
+      detail::emitMappedTensorAccess(op->name, op, op->args, context);
     } else if (op->is_intrinsic(tc2halide::kReductionUpdate)) {
       op->args[0].accept(this);
     } else {
@@ -831,8 +835,8 @@ string emitCudaKernel(
     const std::string& specializedName,
     const MappedScop& mscop) {
   // Expecting a schedule with domain root and context first child.
-  TC_CHECK(mscop.schedule()->as<detail::ScheduleTreeDomain>());
-  TC_CHECK(mscop.schedule()->child({0})->as<detail::ScheduleTreeContext>());
+  TC_CHECK(mscop.schedule()->as<ScheduleTreeDomain>());
+  TC_CHECK(mscop.schedule()->child({0})->as<ScheduleTreeContext>());
   const auto& scop = mscop.scop();
 
   // Make a map of the specialized scalar parameter values
@@ -876,7 +880,7 @@
     return collectIteratorMaps(n, b, &nodeInfoMap);
   };
 
-  auto schedule = detail::toIslSchedule(mscop.schedule());
+  auto schedule = toIslSchedule(mscop.schedule());
   auto astBuild = isl::ast_build(schedule.get_ctx());
   astBuild = astBuild.set_at_each_domain(collect);
   auto root = mscop.schedule();
@@ -890,5 +894,6 @@
   return ss.str();
 }
 
+} // namespace cuda
 } // namespace polyhedral
 } // namespace tc

tc/core/polyhedral/cuda/codegen.h

Lines changed: 2 additions & 0 deletions
@@ -26,6 +26,7 @@
 
 namespace tc {
 namespace polyhedral {
+namespace cuda {
 
 struct CodegenContext;
 struct CodegenStatementContext;
@@ -146,5 +147,6 @@ std::string emitCudaKernel(
     const std::string& specializedName,
     const MappedScop& scop);
 
+} // namespace cuda
 } // namespace polyhedral
 } // namespace tc

tc/core/polyhedral/cuda/mapped_scop.cc

Lines changed: 43 additions & 49 deletions
@@ -43,8 +43,14 @@
 
 #include <glog/logging.h>
 
+using tc::polyhedral::detail::ScheduleTree;
+using tc::polyhedral::detail::ScheduleTreeBand;
+using tc::polyhedral::detail::ScheduleTreeContext;
+using tc::polyhedral::detail::ScheduleTreeSequence;
+
 namespace tc {
 namespace polyhedral {
+namespace cuda {
 
 namespace {
 
@@ -77,7 +83,7 @@ static void checkMinimum(isl::union_set domain, isl::union_pw_aff_list list) {
 template <typename ExceptionType>
 inline void throwIfHasPattern(
     ScheduleTreeMatcher matcher,
-    const detail::ScheduleTree* root) {
+    const ScheduleTree* root) {
   auto candidates = match(matcher, root);
   if (candidates.size() > 0) {
     std::stringstream ss;
@@ -87,7 +93,7 @@ inline void throwIfHasPattern(
   }
 }
 
-void validate(const detail::ScheduleTree* root) {
+void validate(const ScheduleTree* root) {
   throwIfHasPattern<EmptyFilterException>(
       filter(
           [](isl::union_set uset) { return !uset || uset.is_empty(); }, any()),
@@ -98,7 +104,7 @@ void validate(const detail::ScheduleTree* root) {
       root);
 }
 
-bool anyNonCoincidentMember(const detail::ScheduleTreeBand* band) {
+bool anyNonCoincidentMember(const ScheduleTreeBand* band) {
   return band->nOuterCoincident() < band->nMember();
 }
 
@@ -129,9 +135,7 @@ const CudaDim& mappingSize<mapping::ThreadId>(const MappedScop* mscop) {
 // Return a pointer to the updated node (below the inserted filter)
 // for call chaining purposes.
 template <typename MappingTypeId>
-detail::ScheduleTree* MappedScop::map(
-    detail::ScheduleTree* tree,
-    isl::union_pw_aff_list list) {
+ScheduleTree* MappedScop::map(ScheduleTree* tree, isl::union_pw_aff_list list) {
   size_t nToMap = list.size();
   const auto& extent = mappingSize<MappingTypeId>(this).view;
   TC_CHECK_LE(nToMap, extent.size()) << "dimension overflow";
@@ -160,16 +164,14 @@ detail::ScheduleTree* MappedScop::map(
 
   checkMinimum(domain, affList);
 
-  auto mapping = detail::ScheduleTree::makeMapping(idList, affList);
+  auto mapping = ScheduleTree::makeMapping(idList, affList);
   tree = insertNodeAbove(root, tree, std::move(mapping))->child({0});
 
   return tree;
 }
 
-detail::ScheduleTree* MappedScop::mapBlocksForward(
-    detail::ScheduleTree* band,
-    size_t nToMap) {
-  auto bandNode = band->as<detail::ScheduleTreeBand>();
+ScheduleTree* MappedScop::mapBlocksForward(ScheduleTree* band, size_t nToMap) {
+  auto bandNode = band->as<ScheduleTreeBand>();
   TC_CHECK(bandNode) << "expected a band, got " << *band;
 
   auto list = bandNode->mupa_.get_union_pw_aff_list();
@@ -180,7 +182,7 @@ detail::ScheduleTree* MappedScop::mapBlocksForward(
 // Uses as many blockSizes elements as outer coincident dimensions in the
 // outermost band
 void MappedScop::mapToBlocksAndScaleBand(
-    detail::ScheduleTree* band,
+    ScheduleTree* band,
     std::vector<size_t> tileSizes) {
   using namespace tc::polyhedral::detail;
 
@@ -205,16 +207,13 @@
  * the remaining thread identifiers starting at "begin" to zero.
  * Add a marker underneath that marks the subtree that is thread specific.
  */
-void fixThreadsBelow(
-    MappedScop& mscop,
-    detail::ScheduleTree* tree,
-    size_t begin) {
+void fixThreadsBelow(MappedScop& mscop, ScheduleTree* tree, size_t begin) {
   size_t end = mscop.numThreads.view.size();
   if (begin == end) {
     return;
   }
 
-  auto band = detail::ScheduleTree::makeEmptyBand(mscop.scop().scheduleRoot());
+  auto band = ScheduleTree::makeEmptyBand(mscop.scop().scheduleRoot());
   auto bandTree = insertNodeBelow(tree, std::move(band));
   mscop.mapThreadsBackward(bandTree);
 }
@@ -226,10 +225,7 @@ void fixThreadsBelow(
  * Anything that depends on an update statement is ordered after
  * the update statements. Anything else is ordered before.
  */
-bool separatedOut(
-    Scop& scop,
-    detail::ScheduleTree* tree,
-    isl::union_set updates) {
+bool separatedOut(Scop& scop, ScheduleTree* tree, isl::union_set updates) {
   auto domain = activeDomainPoints(scop.scheduleRoot(), tree);
   auto other = domain.subtract(updates);
   if (other.is_empty()) {
@@ -254,7 +250,7 @@
 
 } // namespace
 
-bool MappedScop::detectReductions(detail::ScheduleTree* tree) {
+bool MappedScop::detectReductions(ScheduleTree* tree) {
   // Do not bother with reductions if block is of size 1 in the x direction.
   if (numThreads.view.size() == 0 || numThreads.view[0] == 1) {
     return false;
@@ -264,7 +260,7 @@ bool MappedScop::detectReductions(detail::ScheduleTree* tree) {
   for (auto c : tree->children()) {
     found |= detectReductions(c);
   }
-  auto band = tree->as<detail::ScheduleTreeBand>();
+  auto band = tree->as<ScheduleTreeBand>();
   // Nested reductions are not currently supported.
   if (!band || found) {
     return found;
@@ -314,7 +310,7 @@ bool MappedScop::detectReductions(detail::ScheduleTree* tree) {
   return true;
 }
 
-bool MappedScop::needReductionSeparation(const detail::ScheduleTree* st) {
+bool MappedScop::needReductionSeparation(const ScheduleTree* st) {
   if (reductionBandUpdates_.count(st) != 1) {
     return false;
   }
@@ -323,9 +319,9 @@ bool MappedScop::needReductionSeparation(const detail::ScheduleTree* st) {
 }
 
 isl::multi_union_pw_aff MappedScop::reductionMapSchedule(
-    const detail::ScheduleTree* st) {
+    const ScheduleTree* st) {
   TC_CHECK(reductionBandUpdates_.count(st) == 1);
-  auto reductionBand = st->as<detail::ScheduleTreeBand>();
+  auto reductionBand = st->as<ScheduleTreeBand>();
   TC_CHECK(reductionBand);
 
   auto nMember = reductionBand->nMember();
@@ -337,7 +333,7 @@ isl::multi_union_pw_aff MappedScop::reductionMapSchedule(
   return reductionBand->memberRange(first, nMappedThreads);
 }
 
-detail::ScheduleTree* MappedScop::separateReduction(detail::ScheduleTree* st) {
+ScheduleTree* MappedScop::separateReduction(ScheduleTree* st) {
   auto reduction = st;
   // This function either separates full blocks (if needed) or
   // disables the reduction handling.
@@ -386,23 +382,22 @@ detail::ScheduleTree* MappedScop::separateReduction(detail::ScheduleTree* st) {
   return st->ancestor(root, 2);
 }
 
-detail::ScheduleTree* MappedScop::mapThreadsBackward(
-    detail::ScheduleTree* band) {
-  auto bandNode = band->as<detail::ScheduleTreeBand>();
+ScheduleTree* MappedScop::mapThreadsBackward(ScheduleTree* band) {
+  auto bandNode = band->as<ScheduleTreeBand>();
   TC_CHECK(bandNode);
   auto nMember = bandNode->nMember();
   auto nToMap = std::min(nMember, numThreads.view.size());
   TC_CHECK_LE(nToMap, 3u) << "mapping to too many threads";
 
   auto ctx = band->ctx_;
-  insertNodeBelow(band, detail::ScheduleTree::makeThreadSpecificMarker(ctx));
+  insertNodeBelow(band, ScheduleTree::makeThreadSpecificMarker(ctx));
 
   auto list = bandNode->mupa_.get_union_pw_aff_list().reverse();
   list = list.drop(nToMap, list.size() - nToMap);
   return map<mapping::ThreadId>(band, list);
 }
 
-size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
+size_t MappedScop::mapToThreads(ScheduleTree* band) {
   using namespace tc::polyhedral::detail;
 
   auto bandNode = band->as<ScheduleTreeBand>();
@@ -473,17 +468,15 @@
  * That is, assuming "st" is a sequence node, does the last child
  * need to be protected from the next iteration of the first child?
  */
-bool hasOuterSequentialMember(
-    const detail::ScheduleTree* root,
-    detail::ScheduleTree* st) {
+bool hasOuterSequentialMember(const ScheduleTree* root, ScheduleTree* st) {
   auto ancestors = st->ancestors(root);
   std::reverse(ancestors.begin(), ancestors.end());
   for (auto a : ancestors) {
-    auto band = a->as<detail::ScheduleTreeBand>();
+    auto band = a->as<ScheduleTreeBand>();
     if (band && band->nMember() > band->nOuterCoincident()) {
       return true;
     }
-    if (a->as<detail::ScheduleTreeSequence>()) {
+    if (a->as<ScheduleTreeSequence>()) {
       return false;
     }
   }
@@ -532,7 +525,7 @@ isl::multi_aff constructThreadToWarp(
 } // namespace
 
 isl::multi_union_pw_aff MappedScop::threadMappingSchedule(
-    const detail::ScheduleTree* tree) const {
+    const ScheduleTree* tree) const {
   std::vector<mapping::MappingId> ids;
   for (size_t i = 0; i < numThreads.view.size(); ++i) {
     ids.emplace_back(mapping::ThreadId::makeId(i));
@@ -542,7 +535,7 @@ isl::multi_union_pw_aff MappedScop::threadMappingSchedule(
 }
 
 isl::multi_union_pw_aff MappedScop::blockMappingSchedule(
-    const detail::ScheduleTree* tree) const {
+    const ScheduleTree* tree) const {
   std::vector<mapping::MappingId> ids;
   for (size_t i = 0; i < numBlocks.view.size(); ++i) {
     ids.emplace_back(mapping::BlockId::makeId(i));
@@ -552,8 +545,8 @@ isl::multi_union_pw_aff MappedScop::blockMappingSchedule(
 }
 
 Scop::SyncLevel MappedScop::findBestSync(
-    detail::ScheduleTree* st1,
-    detail::ScheduleTree* st2,
+    ScheduleTree* st1,
+    ScheduleTree* st2,
     isl::multi_union_pw_aff domainToThread,
     isl::multi_union_pw_aff domainToWarp) {
   // Active points in the two schedule trees
@@ -571,7 +564,7 @@
 
   TC_CHECK_LE(1u, scop_->scheduleRoot()->children().size());
   auto contextSt = scop_->scheduleRoot()->children()[0];
-  auto contextElem = contextSt->as<detail::ScheduleTreeContext>();
+  auto contextElem = contextSt->as<ScheduleTreeContext>();
   TC_CHECK(nullptr != contextElem);
   dependences = dependences.intersect_params(contextElem->context_);
 
@@ -733,8 +726,8 @@ std::vector<std::pair<int, int>> MappedScop::findBestSyncConfigInSeq(
   return solutionWithBestBegining;
 }
 
-void MappedScop::insertBestSyncInSeq(detail::ScheduleTree* seq) {
-  TC_CHECK(seq->as<detail::ScheduleTreeSequence>());
+void MappedScop::insertBestSyncInSeq(ScheduleTree* seq) {
+  TC_CHECK(seq->as<ScheduleTreeSequence>());
 
   auto children = seq->children();
   auto nChildren = children.size();
@@ -796,7 +789,7 @@ void MappedScop::insertBestSyncInSeq(detail::ScheduleTree* seq) {
 // the next iteration of the first child if there may be such
 // a next iteration that is not already covered by synchronization
 // on an outer node.
-size_t MappedScop::mapInnermostBandsToThreads(detail::ScheduleTree* st) {
+size_t MappedScop::mapInnermostBandsToThreads(ScheduleTree* st) {
   if (needReductionSeparation(st)) {
     st = separateReduction(st);
   }
@@ -808,7 +801,7 @@ size_t MappedScop::mapInnermostBandsToThreads(detail::ScheduleTree* st) {
   }
   auto n = nChildren > 0 ? *std::max_element(nInner.begin(), nInner.end()) : 0;
   if (nChildren > 1) {
-    auto needSync = st->as<detail::ScheduleTreeSequence>() && n > 0;
+    auto needSync = st->as<ScheduleTreeSequence>() && n > 0;
     if (n > 0) {
       for (size_t i = 0; i < nChildren; ++i) {
         fixThreadsBelow(*this, children[i], nInner[i]);
@@ -819,7 +812,7 @@ size_t MappedScop::mapInnermostBandsToThreads(detail::ScheduleTree* st) {
     }
   }
 
-  if (auto band = st->as<detail::ScheduleTreeBand>()) {
+  if (auto band = st->as<ScheduleTreeBand>()) {
     if (n == 0) {
       // If children were not mapped to threads, the current band can be mapped.
      // First, map the coincidence and reduction dimension to threads.
@@ -957,8 +950,8 @@ std::tuple<std::string, tc::Grid, tc::Block> MappedScop::codegen(
 // Split out a single reduction tile (in the directions other than
 // the reduction) and insert reduction synchronizations outside this tile.
 // Return a pointer to the split off tile.
-detail::ScheduleTree* MappedScop::splitOutReductionTileAndInsertSyncs(
-    detail::ScheduleTree* band) {
+ScheduleTree* MappedScop::splitOutReductionTileAndInsertSyncs(
+    ScheduleTree* band) {
   using namespace polyhedral::detail;
   size_t n = numThreads.view.size();
 
@@ -1103,5 +1096,6 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
   return mappedScop;
 }
 
+} // namespace cuda
 } // namespace polyhedral
 } // namespace tc
