Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit a106e7e

Browse files
author
Sven Verdoolaege
committed
TensorReferenceGroup::approximateFootprint: return relative footprint
The footprint is computed relative to the outer schedule dimensions and makes little sense when those outer schedule dimensions are projected out. In particular, the test case assumed an implicit connection with the block identifiers, but this will be removed in a subsequent commit. It's better to consider the relation with the tile (i.e., the outer schedule dimensions).
1 parent 2e8a117 commit a106e7e

File tree

3 files changed

+11
-13
lines changed

3 files changed

+11
-13
lines changed

tc/core/polyhedral/memory_promotion.cc

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ std::unique_ptr<TensorReferenceGroup> TensorReferenceGroup::makeSingleton(
115115
return group;
116116
}
117117

118-
isl::set TensorReferenceGroup::approximateFootprint() const {
118+
isl::map TensorReferenceGroup::approximateFootprint() const {
119119
auto scopedDomain = scopedAccesses().domain();
120120
auto space = approximation.box.get_space();
121121
auto accessed = isl::map::universe(space).intersect_domain(scopedDomain);
@@ -134,7 +134,7 @@ isl::set TensorReferenceGroup::approximateFootprint() const {
134134

135135
accessed = accessed & partial;
136136
}
137-
return accessed.range();
137+
return accessed;
138138
}
139139

140140
isl::multi_aff ScopedFootprint::lowerBounds() const {
@@ -517,9 +517,8 @@ ScheduleTree* insertCopiesUnder(
517517
isl::set::universe(promotionSpace.domain().unwrap().domain());
518518
auto arrayId =
519519
promotionSpace.domain().unwrap().get_tuple_id(isl::dim_type::out);
520-
auto approximatedRead = scheduleUniverse.product(
521-
group.approximateFootprint().set_tuple_id(arrayId).intersect(
522-
tensorElements));
520+
auto approximatedRead =
521+
group.approximateFootprint().intersect_range(tensorElements).wrap();
523522
approximatedRead = approximatedRead.product(promotedFootprint);
524523
auto readExtension = extension.intersect_range(approximatedRead)
525524
.set_tuple_id(isl::dim_type::out, readId);

tc/core/polyhedral/memory_promotion.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,8 @@ class TensorReferenceGroup {
142142
}
143143

144144
// Rectangular overapproximation of the set of tensor elements accessed below
145-
// the scoping point.
146-
isl::set approximateFootprint() const;
145+
// and relative to the scoping point.
146+
isl::map approximateFootprint() const;
147147

148148
isl::multi_aff promotion() const;
149149
isl::set promotedFootprint() const;

test/test_cuda_mapper_memory_promotion.cc

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -273,10 +273,10 @@ def fun(float(N, M) A, float(N, M) B) -> (C) {
273273

274274
EXPECT_EQ(groups.size(), 3u);
275275

276-
USING_MAPPING_SHORT_NAMES(BX, BY, BZ, TX, TY, TZ);
277-
isl::space blockSpace = isl::space(ctx, 0);
278-
isl::set blockZero =
279-
isl::makeSpecializationSet<int>(blockSpace, {{BX, 0}, {BY, 0}});
276+
isl::space tileSpace = isl::space(ctx, 0).unnamed_set_from_params(2);
277+
// Work around missing isl_set_from_multi_aff.
278+
auto tileZero =
279+
isl::map(isl::multi_aff::zero(tileSpace.from_range())).range();
280280

281281
// Must have groups for these tensors, in arbitrary order.
282282
unordered_set<string> names{"A", "B", "C"};
@@ -305,8 +305,7 @@ def fun(float(N, M) A, float(N, M) B) -> (C) {
305305
EXPECT_EQ(
306306
oneGroup->approximation.size(1),
307307
isl::val(ctx, std::min(tile2, problemSize2)));
308-
auto footprint =
309-
oneGroup->approximateFootprint().intersect_params(blockZero);
308+
auto footprint = tileZero.apply(oneGroup->approximateFootprint());
310309
size_t np = npoints(footprint);
311310
EXPECT_EQ(
312311
np, std::min(tile1, problemSize1) * std::min(tile2, problemSize2));

0 commit comments

Comments
 (0)