Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 41279c8

Browse files
Merge pull request #518 from facebookresearch/pr/filter
no longer consider ScheduleTreeElemMappingFilter to be a subclass of filter
2 parents 2e9b79e + 3273fc8 commit 41279c8

22 files changed

+206
-273
lines changed

tc/core/polyhedral/cuda/mapped_scop.cc

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -65,7 +65,7 @@ void validate(const detail::ScheduleTree* root) {
6565
filter(
6666
[](isl::union_set uset) { return !uset || uset.is_empty(); }, any()),
6767
root);
68-
throwIfHasPattern<EmptyMappingFilterException>(
68+
throwIfHasPattern<EmptyMappingException>(
6969
mapping_filter(
7070
[](isl::union_set uset) { return !uset || uset.is_empty(); }, any()),
7171
root);
@@ -130,7 +130,7 @@ detail::ScheduleTree* MappedScop::map(
130130
idList.emplace_back(id);
131131
}
132132

133-
auto mapping = detail::ScheduleTree::makeMappingFilter(idList, affList);
133+
auto mapping = detail::ScheduleTree::makeMapping(idList, affList);
134134
tree = insertNodeAbove(root, tree, std::move(mapping))->child({0});
135135

136136
return tree;

tc/core/polyhedral/cuda/mapped_scop.h

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -190,17 +190,13 @@ class MappedScop {
190190
// to the thread identifiers, where all branches in "tree"
191191
// are assumed to have been mapped to thread identifiers.
192192
// The result lives in a space of the form block[x, ...].
193-
//
194-
// Note: this function ignores statements introduced by extension nodes.
195193
isl::multi_union_pw_aff threadMappingSchedule(
196194
const detail::ScheduleTree* tree) const;
197195

198196
// Extract a mapping from the domain elements active at "tree"
199197
// to the block identifiers, where all branches in "tree"
200198
// are assumed to have been mapped to block identifiers.
201199
// The result lives in a space of the form grid[x, ...].
202-
//
203-
// Note: this function ignores statements introduced by extension nodes.
204200
isl::multi_union_pw_aff blockMappingSchedule(
205201
const detail::ScheduleTree* tree) const;
206202

tc/core/polyhedral/cuda/memory_promotion_heuristic.cc

Lines changed: 7 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -34,37 +34,6 @@ namespace tc {
3434
namespace polyhedral {
3535
namespace {
3636

37-
/*
38-
* Is "id" a mapping of the type provided as template argument?
39-
*/
40-
template <typename MappingType>
41-
bool isMappingIdType(const mapping::MappingId& id) {
42-
for (size_t i = 0; i < MappingType::kMaxDim; ++i) {
43-
if (id == MappingType::makeId(i)) {
44-
return true;
45-
}
46-
}
47-
return false;
48-
}
49-
50-
/*
51-
* Is "tree" a mapping filter that maps identifiers of the type provided as
52-
* template argument?
53-
*/
54-
template <typename MappingType>
55-
bool isMappingTo(const detail::ScheduleTree* tree) {
56-
using namespace detail;
57-
58-
if (auto filterNode = tree->elemAs<ScheduleTreeElemMappingFilter>()) {
59-
for (auto& kvp : filterNode->mapping) {
60-
if (isMappingIdType<MappingType>(kvp.first)) {
61-
return true;
62-
}
63-
}
64-
}
65-
return false;
66-
}
67-
6837
// Map global<->shared copy bands to threads, starting from the innermost
6938
// loop as it iterates over the last subscript and will result in coalescing.
7039
void mapCopiesToThreads(MappedScop& mscop, bool unroll) {
@@ -258,10 +227,12 @@ isl::map makeNextElementMap(isl::space setSpace, unsigned dim) {
258227
const detail::ScheduleTree* findThreadMappingAncestor(
259228
const detail::ScheduleTree* root,
260229
const detail::ScheduleTree* node) {
230+
using namespace tc::polyhedral::detail;
231+
261232
auto ancestors = node->ancestors(root);
262233
ancestors = functional::Filter(isMappingTo<mapping::ThreadId>, ancestors);
263234
if (ancestors.size() < 1) {
264-
throw promotion::PromotionLogicError("missing MappingFilter");
235+
throw promotion::PromotionLogicError("missing Mapping");
265236
}
266237
return ancestors[0];
267238
}
@@ -325,22 +296,18 @@ bool promotionImprovesCoalescing(
325296

326297
/*
327298
* Returns the union of all mapping filters to "MappingType" in "scop".
328-
*
329-
* Note: similarly to MappedScop::[thread|block]MappingSchedule, this function
330-
* does not take into account elements introduced by extension nodes.
331299
*/
332300
template <typename MappingType>
333301
isl::union_set collectMappingsTo(const Scop& scop) {
334302
auto root = scop.scheduleRoot();
335303
auto domain = scop.domain();
336-
auto mappingFilters = detail::ScheduleTree::collect(
337-
root, detail::ScheduleTreeType::MappingFilter);
304+
auto mappingFilters =
305+
detail::ScheduleTree::collect(root, detail::ScheduleTreeType::Mapping);
338306
mappingFilters = functional::Filter(isMappingTo<MappingType>, mappingFilters);
339307
auto mapping = isl::union_set::empty(domain.get_space());
340308
for (auto mf : mappingFilters) {
341-
auto filterNode = mf->elemAs<detail::ScheduleTreeElemMappingFilter>();
342-
auto filter = filterNode->filter_.intersect(
343-
activeDomainPointsNoMappingNoExtension(root, mf));
309+
auto filterNode = mf->elemAs<detail::ScheduleTreeElemMapping>();
310+
auto filter = filterNode->filter_.intersect(activeDomainPoints(root, mf));
344311
mapping = mapping.unite(filterNode->filter_);
345312
}
346313
return mapping;

tc/core/polyhedral/cuda/tighten_launch_bounds.cc

Lines changed: 57 additions & 71 deletions
Original file line number | Diff line number | Diff line change
@@ -26,13 +26,10 @@
2626
namespace tc {
2727
namespace polyhedral {
2828
namespace {
29-
// This returns the (inclusive) range of the mapping parameter that is active
30-
// at node under root given:
31-
// 1. a context that is the intersection of the specialization context and
32-
// the mapping context
33-
// 2. a MappingId
34-
// This range corresponds to the blocks/threads active at that particular
35-
// location in the tree.
29+
// This returns the (inclusive) range of the mapping parameter "mappingId"
30+
// within the context "mappingContext".
31+
// This range corresponds to the blocks/threads active at the particular
32+
// location in the tree where this mapping is active.
3633
//
3734
// This is used to tighten the kernel to only launch on the necessary amount
3835
// of resources.
@@ -43,23 +40,20 @@ namespace {
4340
// Otherwise, the range is asserted bounded on the left and to lie in the
4441
// positive half of the integer axis.
4542
std::pair<size_t, size_t> rangeOfMappingParameter(
46-
const detail::ScheduleTree* root,
47-
const detail::ScheduleTree* node,
48-
isl::set context,
43+
isl::set mappingContext,
4944
mapping::MappingId mappingId) {
50-
auto active =
51-
activeDomainPoints(root, node).intersect_params(context).params();
52-
if (!active.involves_param(mappingId)) {
45+
if (!mappingContext.involves_param(mappingId)) {
5346
return std::make_pair(0, std::numeric_limits<size_t>::max());
5447
}
55-
isl::aff a(isl::aff::param_on_domain_space(active.get_space(), mappingId));
56-
auto max = active.max_val(a);
48+
auto space = mappingContext.get_space();
49+
isl::aff a(isl::aff::param_on_domain_space(space, mappingId));
50+
auto max = mappingContext.max_val(a);
5751
if (max.is_nan() || max.is_infty()) {
5852
return std::make_pair(0, std::numeric_limits<size_t>::max());
5953
}
6054
TC_CHECK(max.is_int()) << max.to_str();
6155
TC_CHECK(max.is_nonneg()) << max.to_str();
62-
auto min = active.min_val(a);
56+
auto min = mappingContext.min_val(a);
6357
TC_CHECK(min.is_int()) << max.to_str();
6458
TC_CHECK(min.is_nonneg()) << max.to_str();
6559

@@ -68,13 +62,52 @@ std::pair<size_t, size_t> rangeOfMappingParameter(
6862
static_cast<size_t>(max.get_num_si()));
6963
}
7064

71-
// Look for nodes with no children.
72-
inline std::vector<const detail::ScheduleTree*> leaves(
73-
const detail::ScheduleTree* tree) {
74-
return functional::Filter(
75-
[](const detail::ScheduleTree* st) { return st->numChildren() == 0; },
76-
detail::ScheduleTree::collect(tree));
65+
/*
66+
* Compute the maximal value attained by the mapping parameter "id".
67+
*/
68+
template <typename MappingIdType>
69+
size_t maxValue(const Scop& scop, const MappingIdType& id) {
70+
using namespace polyhedral::detail;
71+
72+
auto root = scop.scheduleRoot();
73+
auto params = scop.context();
74+
size_t sizetMax = std::numeric_limits<size_t>::max();
75+
size_t max = 0;
76+
size_t min = sizetMax;
77+
auto filters = root->collect(root, ScheduleTreeType::Mapping);
78+
filters = functional::Filter(isMappingTo<MappingIdType>, filters);
79+
for (auto p : filters) {
80+
auto mappingNode = p->elemAs<ScheduleTreeElemMapping>();
81+
auto active = activeDomainPoints(root, p).intersect_params(params);
82+
active = active.intersect(mappingNode->filter_);
83+
auto range = rangeOfMappingParameter(active.params(), id);
84+
min = std::min(min, range.first);
85+
max = std::max(max, range.second);
86+
}
87+
// Ignore min for now but there is a future possibility for shifting
88+
LOG_IF(WARNING, min > 0)
89+
<< "Opportunity for tightening launch bounds with shifting -> min:"
90+
<< min;
91+
TC_CHECK(max < sizetMax) << "missing mapping to " << id << *root;
92+
// Inclusive range needs + 1 to translate to sizes
93+
return max + 1;
94+
}
95+
96+
/*
97+
* Take grid or block launch bounds "size" and replace them
98+
* by the tightened, actual, launch bounds used in practice.
99+
*/
100+
template <typename MappingIdType, typename Size>
101+
Size launchBounds(const Scop& scop, Size size) {
102+
Size tightened;
103+
104+
for (size_t i = 0; i < size.view.size(); ++i) {
105+
tightened.view[i] = maxValue(scop, MappingIdType::makeId(i));
106+
}
107+
108+
return tightened;
77109
}
110+
78111
} // namespace
79112

80113
// Takes grid/block launch bounds that have been passed to mapping and
@@ -84,56 +117,9 @@ std::pair<tc::Grid, tc::Block> tightenLaunchBounds(
84117
const Scop& scop,
85118
const tc::Grid& grid,
86119
const tc::Block& block) {
87-
auto root = scop.scheduleRoot();
88-
auto params = scop.context();
89-
90-
auto max = [root, params](const mapping::MappingId& id) -> size_t {
91-
size_t sizetMax = std::numeric_limits<size_t>::max();
92-
size_t max = 0;
93-
size_t min = sizetMax;
94-
auto nonSyncLeaves = functional::Filter(
95-
[root, params](const detail::ScheduleTree* node) {
96-
auto f = node->elemAsBase<detail::ScheduleTreeElemFilter>();
97-
if (!f) {
98-
return true;
99-
}
100-
if (f->filter_.n_set() != 1) {
101-
std::stringstream ss;
102-
ss << "In tree:\n"
103-
<< *root << "\nnot a single set in filter: " << f->filter_;
104-
throw tightening::TighteningException(ss.str());
105-
}
106-
auto single = isl::set::from_union_set(f->filter_);
107-
auto single_id = single.get_tuple_id();
108-
return !Scop::isSyncId(single_id) && !Scop::isWarpSyncId(single_id);
109-
},
110-
leaves(root));
111-
for (auto p : nonSyncLeaves) {
112-
auto range = rangeOfMappingParameter(root, p, params, id);
113-
min = std::min(min, range.first);
114-
max = std::max(max, range.second);
115-
}
116-
// Ignore min for now but there is a future possibility for shifting
117-
LOG_IF(WARNING, min > 0)
118-
<< "Opportunity for tightening launch bounds with shifting -> min:"
119-
<< min;
120-
// Inclusive range needs + 1 to translate to sizes
121-
if (max < sizetMax) { // avoid overflow
122-
return max + 1;
123-
}
124-
return sizetMax;
125-
};
126-
127-
USING_MAPPING_SHORT_NAMES(BX, BY, BZ, TX, TY, TZ);
128-
// Corner case: take the min with the current size to avoid degenerate
129-
// range in the unbounded case.
130120
return std::make_pair(
131-
tc::Grid({std::min(max(BX), BX.mappingSize(grid)),
132-
std::min(max(BY), BY.mappingSize(grid)),
133-
std::min(max(BZ), BZ.mappingSize(grid))}),
134-
tc::Block({std::min(max(TX), TX.mappingSize(block)),
135-
std::min(max(TY), TY.mappingSize(block)),
136-
std::min(max(TZ), TZ.mappingSize(block))}));
121+
launchBounds<mapping::BlockId>(scop, grid),
122+
launchBounds<mapping::ThreadId>(scop, block));
137123
}
138124
} // namespace polyhedral
139125
} // namespace tc

tc/core/polyhedral/exceptions.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,8 +25,8 @@ struct EmptyFilterException : public std::runtime_error {
2525
explicit EmptyFilterException(const std::string& s) : std::runtime_error(s) {}
2626
};
2727

28-
struct EmptyMappingFilterException : public std::runtime_error {
29-
explicit EmptyMappingFilterException(const std::string& s)
28+
struct EmptyMappingException : public std::runtime_error {
29+
explicit EmptyMappingException(const std::string& s)
3030
: std::runtime_error(s) {}
3131
};
3232

tc/core/polyhedral/mapping_types.h

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,17 @@ struct MappingId : public isl::id {
2828
public:
2929
MappingId(const MappingId& id) : isl::id(id), dim(id.dim) {}
3030

31+
// Is "id" a mapping of the type provided as template argument?
32+
template <typename MappingType>
33+
bool is() const {
34+
for (size_t i = 0; i < MappingType::kMaxDim; ++i) {
35+
if (*this == MappingType::makeId(i)) {
36+
return true;
37+
}
38+
}
39+
return false;
40+
}
41+
3142
// For indexing into positional arrays
3243
// TODO: this should go away but this probably requires tinkering with
3344
// mapping_options.h::Grid/Block.

tc/core/polyhedral/memory_promotion.cc

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -115,7 +115,7 @@ std::unique_ptr<TensorReferenceGroup> TensorReferenceGroup::makeSingleton(
115115
return group;
116116
}
117117

118-
isl::set TensorReferenceGroup::approximateFootprint() const {
118+
isl::map TensorReferenceGroup::approximateFootprint() const {
119119
auto scopedDomain = scopedAccesses().domain();
120120
auto space = approximation.box.get_space();
121121
auto accessed = isl::map::universe(space).intersect_domain(scopedDomain);
@@ -134,7 +134,7 @@ isl::set TensorReferenceGroup::approximateFootprint() const {
134134

135135
accessed = accessed & partial;
136136
}
137-
return accessed.range();
137+
return accessed;
138138
}
139139

140140
isl::multi_aff ScopedFootprint::lowerBounds() const {
@@ -517,9 +517,8 @@ ScheduleTree* insertCopiesUnder(
517517
isl::set::universe(promotionSpace.domain().unwrap().domain());
518518
auto arrayId =
519519
promotionSpace.domain().unwrap().get_tuple_id(isl::dim_type::out);
520-
auto approximatedRead = scheduleUniverse.product(
521-
group.approximateFootprint().set_tuple_id(arrayId).intersect(
522-
tensorElements));
520+
auto approximatedRead =
521+
group.approximateFootprint().intersect_range(tensorElements).wrap();
523522
approximatedRead = approximatedRead.product(promotedFootprint);
524523
auto readExtension = extension.intersect_range(approximatedRead)
525524
.set_tuple_id(isl::dim_type::out, readId);

tc/core/polyhedral/memory_promotion.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -142,8 +142,8 @@ class TensorReferenceGroup {
142142
}
143143

144144
// Rectangular overapproximation of the set of tensor elements accessed below
145-
// the scoping point.
146-
isl::set approximateFootprint() const;
145+
// and relative to the scoping point.
146+
isl::map approximateFootprint() const;
147147

148148
isl::multi_aff promotion() const;
149149
isl::set promotedFootprint() const;

0 commit comments

Comments
 (0)