This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 1c7f3f8

Sven Verdoolaege committed
[RFC] use templated isl types isPromotableToRegistersBelow
Templated isl types require the user to specify the domain and range universes of isl objects, which lets the compiler check whether it makes sense to combine a given pair of objects. This RFC converts only isPromotableToRegistersBelow and some related functions, to illustrate the effect. isPromotableToRegistersBelow was already applying operations correctly, so the code itself did not require any changes. However, one variable was reused to store intermediate results of different types, and it had to be split into several variables because those results now have different static types.
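As a rough illustration of the idea in the commit message, the following self-contained C++ sketch tags a toy relation with compile-time domain and range types. It does not use the real isl bindings: `TypedMap`, `Relation`, `Pair`, `Tensor`, `promotable` and the two example functions are hypothetical names, and `Statement`, `Outer` and `Thread` are empty stand-ins for the tags that appear in the diff. The operations only mimic the intended semantics of `range_product`, `apply_domain` and `is_injective` on small finite sets of string pairs.

```cpp
#include <map>
#include <set>
#include <string>
#include <type_traits>
#include <utility>

// Hypothetical tag types naming the universes (spaces) involved.
struct Statement {};
struct Outer {};
struct Thread {};
struct Tensor {};
template <typename A, typename B> struct Pair {};

// Untyped toy relation: a set of (domain element, range element) pairs.
using Relation = std::set<std::pair<std::string, std::string>>;

// Typed wrapper: Domain and Range exist only at compile time.
template <typename Domain, typename Range>
struct TypedMap {
  Relation rel;

  // { d -> [r1,r2] : d -> r1 in this, d -> r2 in other }
  template <typename Range2>
  TypedMap<Domain, Pair<Range, Range2>> range_product(
      const TypedMap<Domain, Range2>& other) const {
    TypedMap<Domain, Pair<Range, Range2>> out;
    for (const auto& a : rel)
      for (const auto& b : other.rel)
        if (a.first == b.first)
          out.rel.emplace(a.first, "[" + a.second + "," + b.second + "]");
    return out;
  }

  // { n -> r : d -> r in this, d -> n in f }
  template <typename NewDomain>
  TypedMap<NewDomain, Range> apply_domain(
      const TypedMap<Domain, NewDomain>& f) const {
    TypedMap<NewDomain, Range> out;
    for (const auto& a : rel)
      for (const auto& b : f.rel)
        if (a.first == b.first) out.rel.emplace(b.second, a.second);
    return out;
  }

  // True if no range element is related to two different domain elements.
  bool is_injective() const {
    std::map<std::string, std::string> owner;
    for (const auto& p : rel) {
      auto it = owner.emplace(p.second, p.first).first;
      if (it->second != p.first) return false;
    }
    return true;
  }
};

// Mirrors the second half of the converted function on toy data.
bool promotable(const Relation& outer, const Relation& accesses,
                const Relation& thread) {
  TypedMap<Statement, Outer> outerMap{outer};
  TypedMap<Statement, Tensor> originalAccesses{accesses};
  TypedMap<Statement, Thread> threadMap{thread};

  auto pair = outerMap.range_product(originalAccesses);  // Statement -> [Outer, Tensor]
  auto threadToPair = pair.apply_domain(threadMap);      // Thread -> [Outer, Tensor]
  static_assert(
      std::is_same<decltype(threadToPair),
                   TypedMap<Thread, Pair<Outer, Tensor>>>::value,
      "each step has its own static type");
  return threadToPair.is_injective();
}

// Two threads access different tensor elements in the same outer iteration.
bool example_distinct_elements() {
  return promotable({{"S0", "o0"}, {"S1", "o0"}},
                    {{"S0", "A0"}, {"S1", "A1"}},
                    {{"S0", "t0"}, {"S1", "t1"}});
}

// Two threads access the same tensor element in the same outer iteration.
bool example_shared_element() {
  return promotable({{"S0", "o0"}, {"S1", "o0"}},
                    {{"S0", "A0"}, {"S1", "A0"}},
                    {{"S0", "t0"}, {"S1", "t1"}});
}
```

In this sketch, `pair` and `threadToPair` have different static types, so assigning one to the other would be rejected by the compiler. That is the same reason the single reused `map` variable had to be split in the converted function.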
1 parent 0eac520 commit 1c7f3f8


1 file changed: +9 -7 lines changed


tc/core/polyhedral/cuda/memory_promotion_heuristic.cc

Lines changed: 9 additions & 7 deletions
@@ -370,23 +370,25 @@ bool accessSubscriptsAreUnrolledLoops(
  * thread associated to a given pair of tensor element and outer schedule
  * iteration.
  */
+template <typename Outer>
 bool isPromotableToRegistersBelow(
     const TensorReferenceGroup& group,
     const detail::ScheduleTree* root,
     const detail::ScheduleTree* scope,
-    isl::multi_union_pw_aff outer,
-    isl::multi_union_pw_aff thread) {
+    isl::MultiUnionPwAff<Statement, Outer> outer,
+    isl::MultiUnionPwAff<Statement, Thread> thread) {
   if (!accessSubscriptsAreUnrolledLoops(
-          group, root, scope, outer.flat_range_product(thread))) {
+          group, root, scope, outer.range_product(thread))) {
     return false;
   }
 
   auto originalAccesses = group.originalAccesses();
-  auto map = isl::union_map::from(outer);
-  map = map.range_product(originalAccesses);
-  map = map.apply_domain(isl::union_map::from(thread));
+  auto outerMap = isl::UnionMap<Statement, Outer>::from(outer);
+  auto pair = outerMap.range_product(originalAccesses);
+  auto threadMap = isl::UnionMap<Statement, Thread>::from(thread);
+  auto threadToPair = pair.apply_domain(threadMap);
 
-  return map.is_injective();
+  return threadToPair.is_injective();
 }
 
 /*
