Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit f8d91dd

Browse files
Merge pull request #149 from facebookresearch/register-promotion
promotion to registers (aka private memory)
2 parents bb6dca4 + 3259e76 commit f8d91dd

15 files changed

+407
-107
lines changed

include/tc/core/polyhedral/codegen_cuda.h

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,18 @@ struct CodegenStatementContext : CodegenContext {
9696
isl::id statementId() const {
9797
return this->iteratorMaps.at(astNodeId).get_tuple_id(isl::dim_type::out);
9898
}
99+
isl::set domain() const {
100+
return isl::map::from(this->iteratorMaps.at(astNodeId)).range();
101+
}
99102
std::vector<Scop::PromotionInfo> activePromotions() const {
100-
auto stmtId = statementId();
101-
const auto& promotions = this->scop().activePromotions();
102-
if (promotions.count(stmtId) == 0) {
103-
return {};
103+
std::vector<Scop::PromotionInfo> result;
104+
auto dom = isl::union_set(this->domain());
105+
for (const auto& kvp : this->scop().activePromotions()) {
106+
if (!kvp.first.intersect(dom).is_empty()) {
107+
result.emplace_back(kvp.second);
108+
}
104109
}
105-
return promotions.at(stmtId);
110+
return result;
106111
}
107112

108113
isl::id astNodeId;

include/tc/core/polyhedral/memory_promotion.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,9 @@ inline std::ostream& operator<<(std::ostream& os, const TensorReference& tr) {
179179
inline std::ostream& operator<<(
180180
std::ostream& os,
181181
const TensorReferenceGroup& tg) {
182-
os << " with footprint BB: " << tg.approximation << " ";
182+
os << "Reference with footprint: " << tg.approximation << "\n";
183183
for (const auto& tr : tg.references) {
184-
os << *tr << " ";
184+
os << *tr << "\n";
185185
}
186186
return os;
187187
}

include/tc/core/polyhedral/memory_promotion_heuristic.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ using ThreadIdxxScheduleDepthState =
2626
std::vector<std::pair<isl::union_set, size_t>>;
2727

2828
class MappedScop;
29+
class Scop;
2930

3031
// In the given mapped scop "mscop",
3132
// promote to shared memory at "depth" until "sharedMemorySize" is used.
@@ -40,5 +41,10 @@ void promoteGreedilyAtDepth(
4041
std::size_t depth,
4142
std::size_t sharedMemorySize,
4243
bool unrollCopies);
44+
45+
void promoteToRegistersBelowThreads(
46+
Scop& scop,
47+
const ThreadIdxxScheduleDepthState& threadIdxxScheduleDepthState,
48+
std::size_t nRegisters);
4349
} // namespace polyhedral
4450
} // namespace tc

include/tc/core/polyhedral/schedule_transforms.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -284,13 +284,6 @@ isl::union_set activeDomainPoints(
284284
const detail::ScheduleTree* root,
285285
const detail::ScheduleTree* node);
286286

287-
// Get the set of statement identifiers whose domains have at least one active
288-
// point at the given node, i.e. the statements that were not filtered away on
289-
// the path from root to node.
290-
std::unordered_set<isl::id, isl::IslIdIslHash> activeStatements(
291-
const detail::ScheduleTree* root,
292-
const detail::ScheduleTree* node);
293-
294287
////////////////////////////////////////////////////////////////////////////////
295288
// Experimental
296289
////////////////////////////////////////////////////////////////////////////////

include/tc/core/polyhedral/scop.h

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -306,8 +306,11 @@ struct Scop {
306306
void promoteEverythingAt(std::vector<size_t> pos);
307307

308308
struct PromotedDecl {
309+
enum class Kind { SharedMem, Register };
310+
309311
isl::id tensorId;
310312
std::vector<size_t> sizes;
313+
Kind kind;
311314
};
312315

313316
struct PromotionInfo {
@@ -321,9 +324,8 @@ struct Scop {
321324
return promotedDecls_;
322325
}
323326

324-
const std::
325-
unordered_map<isl::id, std::vector<PromotionInfo>, isl::IslIdIslHash>&
326-
activePromotions() const {
327+
const std::vector<std::pair<isl::union_set, PromotionInfo>>&
328+
activePromotions() const {
327329
return activePromotions_;
328330
}
329331

@@ -356,7 +358,8 @@ struct Scop {
356358
// Assumes such argument exists.
357359
const Halide::OutputImageParam& findArgument(isl::id id) const;
358360

359-
// Promote a tensor reference group to shared memory, inserting the copy
361+
// Promote a tensor reference group to a storage of a given "kind",
362+
// inserting the copy
360363
// statements below the given node. Inserts an Extension node below the give
361364
// node, unless there is already another Extension node which introduces
362365
// copies. The Extension node has a unique Sequence child, whose children
@@ -368,11 +371,11 @@ struct Scop {
368371
// If "forceLastExtentOdd" is set, the last extent in the declaration is
369372
// incremented if it is even. This serves as a simple heuristic to reduce
370373
// shared memory bank conflicts.
371-
void promoteGroupToShared(
374+
void promoteGroup(
375+
PromotedDecl::Kind kind,
372376
isl::id tensorId,
373377
std::unique_ptr<TensorReferenceGroup>&& gr,
374378
detail::ScheduleTree* tree,
375-
const std::unordered_set<isl::id, isl::IslIdIslHash>& activeStmts,
376379
isl::union_map schedule,
377380
bool forceLastExtentOdd = false);
378381

@@ -463,9 +466,10 @@ struct Scop {
463466
std::unordered_map<isl::id, size_t, isl::IslIdIslHash> groupCounts_;
464467
// groupId -> (tensorId, groupSizes)
465468
std::unordered_map<isl::id, PromotedDecl, isl::IslIdIslHash> promotedDecls_;
466-
// stmtId -> (group, partial schedule, groupId)
467-
std::unordered_map<isl::id, std::vector<PromotionInfo>, isl::IslIdIslHash>
468-
activePromotions_;
469+
// (domain, group, partial schedule, groupId)
470+
// Note that domain is a non-unique key, i.e. multiple groups can be listed
471+
// for the same domain, or for partially intersecting domains.
472+
std::vector<std::pair<isl::union_set, PromotionInfo>> activePromotions_;
469473
};
470474

471475
std::ostream& operator<<(std::ostream& os, const Scop&);

include/tc/external/detail/islpp.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,10 @@ inline bool operator==(const isl::id& id1, const isl::id& id2) {
268268
return id1.get() == id2.get();
269269
}
270270

271+
inline bool operator!=(const isl::id& id1, const isl::id& id2) {
272+
return id1.get() != id2.get();
273+
}
274+
271275
///////////////////////////////////////////////////////////////////////////////
272276
// Helper functions
273277
///////////////////////////////////////////////////////////////////////////////
@@ -399,6 +403,21 @@ auto end(L& list) -> ListIter<decltype(list.get(0)), L> {
399403
using detail::begin;
400404
using detail::end;
401405

406+
template <typename T>
407+
isl::val getParamValIfFixed(T t, int pos) {
408+
auto val = isl::val::nan(t.get_ctx());
409+
for (auto set : isl::UnionAsVector<T>(t)) {
410+
auto currentVal = set.plain_get_val_if_fixed(isl::dim_type::param, pos);
411+
if (currentVal.is_nan()) {
412+
return currentVal;
413+
}
414+
if (!val.is_nan() && val != currentVal) {
415+
return isl::val::nan(t.get_ctx());
416+
}
417+
val = currentVal;
418+
}
419+
return val;
420+
}
402421
} // namespace isl
403422

404423
namespace isl {

src/core/polyhedral/codegen_cuda.cc

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -501,17 +501,6 @@ isl::space findDomainSpaceById(const CodegenStatementContext& context) {
501501
return isl::space();
502502
}
503503

504-
isl::map findScheduleByStmtId(isl::union_map schedule, isl::id stmtId) {
505-
for (auto s : isl::UnionAsVector<isl::union_map>(schedule)) {
506-
if (s.get_tuple_id(isl::dim_type::in) == stmtId) {
507-
return s;
508-
}
509-
}
510-
CHECK(false) << "could not find schedule for " << stmtId << " in "
511-
<< schedule;
512-
return isl::map();
513-
}
514-
515504
isl::multi_aff makeMultiAffAccess(
516505
isl::id tensorId,
517506
const std::vector<Halide::Expr>& subscripts,
@@ -631,9 +620,9 @@ void emitMappedTensorAccess(
631620
auto promotion = promotionInfo.group->promotion(); // MA :: [S -> O] -> P
632621
promotion = promotion.set_tuple_id(isl::dim_type::out, promotionInfo.groupId);
633622
auto iteratorMap = context.iteratorMap(); // PMA :: A -> D
634-
auto schedule = findScheduleByStmtId(
635-
promotionInfo.outerSchedule,
636-
context.statementId()); // map :: D -> S
623+
auto schedule =
624+
isl::map::from_union_map(promotionInfo.outerSchedule.intersect_domain(
625+
context.domain())); // map :: D -> S
637626

638627
CHECK(schedule.is_single_valued())
639628
<< "expected single-valued schedule, got " << schedule;
@@ -705,7 +694,10 @@ void emitPromotedArrayViewsHalide(stringstream& ss, const Scop& scop) {
705694
t = i.type();
706695
}
707696
}
708-
ss << "__shared__ " << t << " " << viewName;
697+
if (p.second.kind == Scop::PromotedDecl::Kind::SharedMem) {
698+
ss << "__shared__ ";
699+
}
700+
ss << t << " " << viewName;
709701
for (auto s : p.second.sizes) {
710702
ss << "[" << s << "]";
711703
}

src/core/polyhedral/mapped_scop.cc

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,11 @@ void fixThreadsBelowFilter(
176176

177177
for (size_t i = begin; i < end; ++i) {
178178
if (mapping::ThreadId::makeId(i) == mapping::ThreadId::x()) {
179+
// Mapping happend below filterTree, so we need points active for its
180+
// children. After insertion, filterTree is guaranteed to have at least
181+
// one child.
179182
mscop.threadIdxxScheduleDepthState.emplace_back(std::make_pair(
180-
activeDomainPoints(mscop.schedule(), filterTree),
183+
activeDomainPoints(mscop.schedule(), filterTree->child({0})),
181184
filterTree->scheduleDepth(mscop.schedule())));
182185
}
183186
}
@@ -686,10 +689,16 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
686689
}
687690
}
688691

689-
// 8. Insert mapping context
692+
// 8. Promote to registers below the loops mapped to threads.
693+
if (options.proto.use_private_memory()) {
694+
promoteToRegistersBelowThreads(
695+
mappedScop->scop(), mappedScop->threadIdxxScheduleDepthState, -1ull);
696+
}
697+
698+
// 9. Insert mapping context
690699
mappedScop->insertMappingContext();
691700

692-
// 9. Optionally insert reduction synchronizations
701+
// 10. Optionally insert reduction synchronizations
693702
for (auto bandUpdate : mappedScop->reductionBandUpdates_) {
694703
for (auto updateId : bandUpdate.second.ids) {
695704
scop->insertReductionSync1D(

src/core/polyhedral/memory_promotion.cc

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,10 +321,21 @@ void addSingletonReferenceGroups(
321321
// access relations have a shape :: [D -> ref] -> O
322322
// use currying to isolate the D part before intersecting with the domain
323323
// Compute initial groups with single reference per group.
324-
accesses = accesses.curry().intersect_domain(domain).uncurry();
324+
std::unordered_set<isl::id, isl::IslIdIslHash> unapproximatable;
325325
for (auto a : isl::UnionAsVector<isl::union_map>(accesses)) {
326+
if (isl::union_map(a.curry()).intersect_domain(domain).is_empty()) {
327+
continue;
328+
}
329+
326330
auto tensorId = a.get_tuple_id(isl::dim_type::out);
327-
addSingletonReferenceGroup(tensorGroups, tensorId, schedule, a, type);
331+
if (unapproximatable.count(tensorId) != 0) {
332+
continue;
333+
}
334+
try {
335+
addSingletonReferenceGroup(tensorGroups, tensorId, schedule, a, type);
336+
} catch (const promotion::GroupingError& err) {
337+
unapproximatable.insert(tensorId);
338+
}
328339
}
329340
}
330341
} // namespace

0 commit comments

Comments
 (0)