Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit ef644ba

Browse files
Merge pull request #265 from facebookresearch/pr/param_order
remove code that depends on order of parameters in isl objects
2 parents c60eff3 + ca98a28 commit ef644ba

File tree

4 files changed

+22
-41
lines changed

4 files changed

+22
-41
lines changed

include/tc/core/polyhedral/scop.h

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -118,22 +118,12 @@ struct Scop {
118118
writes = writes.intersect_params(globalParameterContext);
119119
}
120120

121-
// Returns a set that specializes (all) the scop's parameter space to the
122-
// integer values passed to the function.
123-
// WARNING: this version relies on parameter ordering, be sure you know what
124-
// you are doing.
125-
template <typename T>
126-
isl::set makeContext(const std::vector<T>& sizes = std::vector<T>()) const {
127-
auto s = domain().get_space().params();
128-
return makeSpecializationSet(s, sizes);
129-
}
130-
131-
// Returns a set that specializes the (positional) scop's subset of
121+
// Returns a set that specializes the named scop's subset of
132122
// parameter space to the integer values passed to the function.
133123
template <typename T>
134124
isl::set makeContext(
135-
const std::unordered_map<int, T>& sizes =
136-
std::unordered_map<int, T>()) const {
125+
const std::unordered_map<std::string, T>& sizes =
126+
std::unordered_map<std::string, T>()) const {
137127
auto s = domain().get_space().params();
138128
return makeSpecializationSet(s, sizes);
139129
}
@@ -142,8 +132,7 @@ struct Scop {
142132
// parameter space to the integer values passed to the function.
143133
template <typename T>
144134
isl::set makeContext(
145-
const std::unordered_map<std::string, T>& sizes =
146-
std::unordered_map<std::string, T>()) const {
135+
std::initializer_list<std::pair<const std::string, T>> sizes) {
147136
auto s = domain().get_space().params();
148137
return makeSpecializationSet(s, sizes);
149138
}

include/tc/external/detail/islpp.h

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -317,17 +317,19 @@ inline isl::set makeParameterContext(
317317

318318
// Given a space and values for parameters, this function creates the set
319319
// that ties the space parameter to the values.
320-
// This assumes space.dim(isl::dim_type::param) == paramValues.size()
321320
//
322321
template <typename T>
323322
inline isl::set makeSpecializationSet(
324323
isl::space space,
325-
const std::unordered_map<int, T>& paramValues) {
326-
CHECK_GE(space.dim(isl::dim_type::param), paramValues.size());
327-
auto lspace = isl::local_space(space);
324+
const std::unordered_map<std::string, T>& paramValues) {
325+
auto ctx = space.get_ctx();
326+
for (auto kvp : paramValues) {
327+
space = space.add_param(isl::id(ctx, kvp.first));
328+
}
328329
auto set = isl::set::universe(space);
329330
for (auto kvp : paramValues) {
330-
auto affParam = isl::aff(lspace, isl::dim_type::param, kvp.first);
331+
auto id = isl::id(ctx, kvp.first);
332+
isl::aff affParam(isl::aff::param_on_domain_space(space, id));
331333
set = set & (isl::aff_set(affParam) == kvp.second);
332334
}
333335
return set;
@@ -336,30 +338,20 @@ inline isl::set makeSpecializationSet(
336338
template <typename T>
337339
inline isl::set makeSpecializationSet(
338340
isl::space space,
339-
const std::unordered_map<std::string, T>& paramValues) {
340-
CHECK_GE(space.dim(isl::dim_type::param), paramValues.size());
341-
std::unordered_map<int, T> aux;
342-
for (auto kvp : paramValues) {
343-
auto pos = space.find_dim_by_name(isl::dim_type::param, kvp.first);
344-
CHECK_LE(0, pos) << "No " << kvp.first << " in: " << space;
345-
CHECK_EQ(0, aux.count(pos));
346-
aux[pos] = kvp.second;
347-
}
348-
return makeSpecializationSet(space, aux);
341+
std::initializer_list<std::pair<const std::string, T>> paramValues) {
342+
std::unordered_map<std::string, T> map(paramValues);
343+
return makeSpecializationSet(space, map);
349344
}
350345

351-
// WARNING: this version relies on parameter ordering, be sure you know what
352-
// you are doing.
353346
template <typename T>
354347
inline isl::set makeSpecializationSet(
355348
isl::space space,
356-
const std::vector<T>& paramValues) {
357-
CHECK_EQ(space.dim(isl::dim_type::param), paramValues.size());
358-
std::unordered_map<int, T> paramValuesMap;
359-
for (int i = 0; i < paramValues.size(); ++i) {
360-
paramValuesMap[i] = paramValues[i];
349+
std::initializer_list<std::pair<isl::id, T>> paramValues) {
350+
std::unordered_map<std::string, T> map;
351+
for (auto kvp : paramValues) {
352+
map.emplace(kvp.first.get_name(), kvp.second);
361353
}
362-
return makeSpecializationSet(space, paramValuesMap);
354+
return makeSpecializationSet(space, map);
363355
}
364356

365357
namespace detail {

test/test_mapper.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ TEST_F(PolyhedralMapperTest, MergedContexts) {
333333
auto scop = PrepareAndJoinBands(makeMatmulTc());
334334

335335
// Unit test claims to use scop->globalParameterContext properly
336-
auto context = scop->makeContext(std::vector<int>{64, 64, 64});
336+
auto context = scop->makeContext<int>({{"M", 64}, {"N", 64}, {"K", 64}});
337337
auto& globalParameterContext =
338338
const_cast<isl::set&>(scop->globalParameterContext);
339339
globalParameterContext = globalParameterContext.intersect(context);
@@ -349,7 +349,7 @@ TEST_F(PolyhedralMapperTest, FilterMerge) {
349349
auto schedule = scop->scheduleRoot();
350350

351351
// Unit test claims to use scop->globalParameterContext properly
352-
auto context = scop->makeContext(std::vector<int>{64, 64, 64});
352+
auto context = scop->makeContext<int>({{"M", 64}, {"N", 64}, {"K", 64}});
353353
auto& globalParameterContext =
354354
const_cast<isl::set&>(scop->globalParameterContext);
355355
globalParameterContext = globalParameterContext.intersect(context);

test/test_mapper_memory_promotion.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ def fun(float(N, M) A, float(N, M) B) -> (C) {
275275
blockSpace = blockSpace.set_dim_id(isl::dim_type::param, 0, BX)
276276
.set_dim_id(isl::dim_type::param, 1, BY);
277277
isl::set blockZero =
278-
makeSpecializationSet(blockSpace, std::vector<int>{{0, 0}});
278+
isl::makeSpecializationSet<int>(blockSpace, {{BX, 0}, {BY, 0}});
279279

280280
// Must have groups for these tensors, in arbitrary order.
281281
unordered_set<string> names{"A", "B", "C"};

0 commit comments

Comments
 (0)