
Commit 14d0ca9

ftynse authored and Sven Verdoolaege committed
memory promotion: handle strides
The original implementation of memory promotion ignored strides in accesses to simplify the code. isl recently introduced support for extracting stride information from sets, making stride manipulation easy in TC. Introduce special handling for strided accesses into TensorReferenceGroup and related classes.

An access is strided if the access function has the shape

    (a_i - offset_i) = 0 mod stride_i

where stride_i is a constant and offset_i is some affine expression on the iteration domain. Use isl to compute offsets and strides in access relations. Use this information to promote to shared memory only those tensor elements that are actually read in the case of strided accesses. This decreases the amount of shared memory used by a kernel with such accesses. It also prepares for the introduction of register promotion, where the accesses of each individual thread are strided with a stride equal to the number of threads.

Note that references accessing disjoint sets of elements with strides are not grouped even if their non-strided footprints overlap, e.g. A[2*i] and A[2*i + 1] belong to different groups. This may decrease the benefit of coalesced reads when copying between global and shared memory. At the same time, it also decreases the required shared memory size, making it possible to promote one of the references in cases where a group with two references would not fit. The profitability of such grouping requires further exploration.
1 parent d1673af · commit 14d0ca9
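For concreteness, here is a minimal sketch of the stride form described above on the access A[2*i + 1], where (a - 1) = 0 mod 2. It is an illustration only, not part of the commit; it assumes the isl C++ bindings generated for TC expose the string constructor of isl::map and the get_range_stride_info entry point that the diff below relies on.

#include <isl/cpp.h>  // assumed header for the isl C++ bindings

void strideInfoExample() {
  isl::ctx ctx(isl_ctx_alloc());
  // Access relation of A[2*i + 1] for 0 <= i < 32: every accessed subscript
  // a satisfies (a - 1) = 0 mod 2, i.e. stride 2 with offset 1.
  isl::map access(ctx, "{ S[i] -> A[a] : a = 2 * i + 1 and 0 <= i < 32 }");
  auto si = access.get_range_stride_info(0);
  auto stride = si.get_stride(); // isl::val: 2
  auto offset = si.get_offset(); // isl::aff: { S[i] -> [(1)] }
  // Promoting only the elements actually read needs 32 slots for
  // {1, 3, ..., 63}, rather than 63 slots for the full box [1, 63].
  (void)stride;
  (void)offset;
}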

3 files changed: 189 additions, 11 deletions

tc/core/polyhedral/memory_promotion.cc

Lines changed: 67 additions & 9 deletions
@@ -31,10 +31,50 @@ namespace polyhedral {
 using detail::ScheduleTree;

 namespace {
+// Remove strides specified by "strides" and "offsets" from the range of
+// "relation". In particular, relation has a shape
+//
+//   D -> O: o_i = offset_i + stride_i * f(D)
+//
+// transform it into
+//
+//   D -> O: o_i = f(D)
+//
+// by subtracting "offsets" and by dividing the result by "strides".
+isl::map removeRangeStrides(
+    isl::map relation,
+    isl::multi_val strides,
+    isl::multi_aff offsets) {
+  CHECK_EQ(strides.size(), offsets.size());
+
+  auto space = relation.get_space();
+  auto stridesMA = isl::multi_aff::identity(space.range().map_from_set());
+  stridesMA = stridesMA / strides;
+
+  return relation.sum(isl::map(offsets.neg())).apply_range(isl::map(stridesMA));
+}
+
+// Compute a box approximation of the range of the given relation,
+// including the lower bounds, the box sizes, and the strides.
+// If the range has strides, remove them first.
 ScopedFootprint outputRanges(isl::map access) {
+  auto ctx = access.get_ctx();
+  int nSubscripts = access.dim(isl::dim_type::out);
+
+  auto strides = isl::val_list(ctx, nSubscripts);
+  auto strideOffsets = isl::aff_list(ctx, nSubscripts);
+  for (int i = 0; i < nSubscripts; ++i) {
+    auto si = access.get_range_stride_info(i);
+    strides = strides.add(si.get_stride());
+    strideOffsets = strideOffsets.add(si.get_offset());
+  }
+
   ScopedFootprint footprint;
+  footprint.strideValues = isl::multi_val(access.get_space().range(), strides);
+  footprint.strideOffsets = isl::multi_aff(access.get_space(), strideOffsets);

-  // TODO: also compute strides
+  access = removeRangeStrides(
+      access, footprint.strideValues, footprint.strideOffsets);

   footprint.box = access.get_range_simple_fixed_box_hull();
   return footprint;
@@ -84,10 +124,16 @@ isl::set TensorReferenceGroup::approximateFootprint() const {
   auto lspace = isl::local_space(accessed.get_space().range());

   for (size_t i = 0; i < approximation.dim(); ++i) {
-    auto dimLowerBound = approximation.lowerBound(i);
+    auto offset = approximation.lowerBound(i);
+    auto stride = approximation.stride(i);
+    auto strideOffset = approximation.strideOffset(i);
+    auto size = approximation.size(i);
     auto rhs = isl::aff(lspace, isl::dim_type::set, i);
-    isl::map partial = (isl::aff_map(dimLowerBound) <= rhs) &
-        (isl::aff_map(dimLowerBound + approximation.size(i)) > rhs);
+    auto lowerBound = offset * stride + strideOffset;
+    auto upperBound = (offset + size) * stride + strideOffset;
+    auto partial =
+        (isl::aff_map(lowerBound) <= rhs) & (isl::aff_map(upperBound) > rhs);
+
     accessed = accessed & partial;
   }
   return accessed.range();
@@ -304,7 +350,9 @@ TensorGroups TensorReferenceGroup::accessedBySubtree(

 // Compute the relation between schedule dimensions, original and promoted array
 // subscripts, in the space
-//   [S -> O] -> P
+//   [S -> O] -> O.
+// The caller is in charge of updating the tuple of the target space with the
+// group identifier.
 // The mapping depends on the original schedule dimensions because the same
 // elements of the promoted array get assigned different values of the original
 // array in different outer loop iterations; it's impossible to project out the
@@ -314,10 +362,20 @@ isl::multi_aff TensorReferenceGroup::promotion() const {
   isl::map map = scopedAccesses();
   auto accessSpace = map.get_space();

-  // lower bounds space is S -> P; which we transform into [S -> O] -> P
-  auto lowerBounds = approximation.lowerBounds().pullback(
-      isl::multi_aff::domain_map(accessSpace));
-  auto promotion = isl::multi_aff::range_map(accessSpace) - lowerBounds;
+  // Construct a projection multi-aff in [S -> O] -> S
+  // for further precomposition.
+  auto originalSpaceInserter = isl::multi_aff::domain_map(accessSpace);
+
+  // Lower bounds and offsets space is S -> O; transform into [S -> O] -> O.
+  auto lowerBounds =
+      approximation.lowerBounds().pullback(originalSpaceInserter);
+  auto offsets = approximation.strideOffsets.pullback(originalSpaceInserter);
+
+  // Create promotion starting by identity in [S -> O] -> O.
+  auto original = isl::multi_aff::range_map(accessSpace);
+  auto promotion =
+      (original - offsets) / approximation.strideValues - lowerBounds;
+
   return promotion;
 }

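To make the arithmetic of removeRangeStrides concrete, the following hedged trace applies its two steps to a single strided relation. It is an illustration only, not part of the commit, under the same binding assumptions as the sketch above.

#include <isl/cpp.h>  // assumed header for the isl C++ bindings

void removeRangeStridesTrace() {
  isl::ctx ctx(isl_ctx_alloc());
  // D -> O: o = 1 + 2 * f(D) with f(D) = i, so offsets = (1), strides = (2).
  isl::map relation(ctx, "{ S[i] -> A[o] : o = 2 * i + 1 and 0 <= i < 32 }");
  // Step 1, relation.sum(isl::map(offsets.neg())), subtracts the offset:
  //   { S[i] -> A[o] : o = 2 * i }
  // Step 2, apply_range with identity / strides, divides by the stride:
  //   { S[i] -> A[o] : o = i }
  isl::map destrided(ctx, "{ S[i] -> A[o] : o = i and 0 <= i < 32 }");
  // The destrided range is dense, so get_range_simple_fixed_box_hull later
  // computes a box of size 32 rather than 63.
  (void)relation;
  (void)destrided;
}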
tc/core/polyhedral/memory_promotion.h

Lines changed: 14 additions & 2 deletions
@@ -33,8 +33,10 @@ enum class AccessType : short { Read, Write };
 // Rectangular overapproximation of a tensor elements accessed through a single
 // reference.
 // Each dimension is overapproximated by a lower bound, an affine function of
-// parameters and schedule dimensions visible around the scope, and by a
-// constant size.
+// parameters and schedule dimensions visible around the scope, by a
+// constant size, and by a pair offset/stride for strided accesses. If the
+// access is not strided, then "offset" is a zero expression and "stride" is 1.
+// The lowerBound and the size are computed after removing the potential stride.
 // The scope is defined by a specific position in a schedule tree (const
 // ScheduleTree*), the user is responsible for maintaining the correspondance
 // between schedule tree positions and footprints.
@@ -48,7 +50,17 @@ struct ScopedFootprint {
   isl::aff lowerBound(size_t pos) const {
     return box.get_offset().get_aff(pos);
   }
+  isl::val stride(size_t pos) const {
+    return strideValues.get_val(pos);
+  }
+  isl::aff strideOffset(size_t pos) const {
+    return strideOffsets.get_aff(pos);
+  }
+
   isl::fixed_box box;
+  isl::multi_val strideValues;
+  isl::multi_aff strideOffsets;
+
   isl::multi_aff lowerBounds() const;
 };

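The stride fields added here compose with the box as follows; the sketch below restates the index arithmetic from promotion() and approximateFootprint() in the .cc file above, with plain integers standing in for isl affine expressions (illustration only, not part of the commit):

// Per dimension: promoted subscript p and original subscript o are related by
//   p = (o - strideOffset) / stride - lowerBound    (promotion)
//   o = stride * (p + lowerBound) + strideOffset    (footprint bounds)
// For a non-strided access, stride == 1 and strideOffset == 0, recovering the
// previous behavior p = o - lowerBound.
long promote(long o, long stride, long strideOffset, long lowerBound) {
  return (o - strideOffset) / stride - lowerBound;
}
long unpromote(long p, long stride, long strideOffset, long lowerBound) {
  return stride * (p + lowerBound) + strideOffset;
}
// Round trip with stride 2, offset 1, lower bound 0:
//   promote(7, 2, 1, 0) == 3 and unpromote(3, 2, 1, 0) == 7.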
test/test_cuda_mapper_memory_promotion.cc

Lines changed: 108 additions & 0 deletions
@@ -504,6 +504,114 @@ TEST_F(MatMulBias, RegisterPromotionSharedPreference) {
       << "tensor A promoted to register but has elements accessed by multiple threads";
 }

+class Strided : public TestMapper {
+ public:
+  std::unique_ptr<MappedScop> makeScopAndCheck(
+      const std::string& tc,
+      const std::unordered_map<std::string, size_t>& sizes,
+      long groupISize,
+      long groupIStride,
+      long groupIConstOffset) {
+    auto options = CudaMappingOptions::makeNaiveMappingOptions()
+                       .tile(32, 32)
+                       .mapToThreads(21, 21)
+                       .useSharedMemory(false);
+    auto mscop = makeMappedScop(tc, options, sizes);
+    auto& scop = mscop->scop();
+    auto ctx = scop.domain().get_ctx();
+
+    auto groups = TensorReferenceGroup::accessedBySubtree(
+        scop.scheduleRoot()->child({0, 0, 0}), scop);
+    EXPECT_EQ(groups.size(), 2u) << "expected groups for both tensors";
+
+    for (const auto& g : groups) {
+      auto name = g.first.get_name();
+      if (name != "I") {
+        continue;
+      }
+
+      const auto& perTensorGroups = g.second;
+      // One cannot use ASSERT_EQ in a function that returns something, because
+      // it would trigger an immediate return without value. Use EXPECT_EQ and
+      // return nullptr manually.
+      EXPECT_EQ(perTensorGroups.size(), 1u) << "expected one group for I";
+      if (perTensorGroups.size() != 1u) {
+        return nullptr;
+      }
+      const auto& oneGroup = perTensorGroups[0];
+
+      EXPECT_EQ(oneGroup->references.size(), 1u)
+          << "expected one reference in the group for I";
+      if (oneGroup->references.size() != 1u) {
+        return nullptr;
+      }
+      const auto& ref = oneGroup->references[0];
+
+      EXPECT_EQ(oneGroup->approximation.dim(), 2u)
+          << "could not compute approximation for " << ref->scopedAccess;
+
+      EXPECT_EQ(oneGroup->approximation.size(1), isl::val(ctx, groupISize))
+          << "expected strides to be removed";
+
+      isl::val stride = isl::val(ctx, groupIStride);
+      EXPECT_EQ(oneGroup->approximation.stride(1), stride);
+
+      auto expectedOffset =
+          isl::aff::zero_on_domain(ref->scopedAccess.domain().get_space()) +
+          groupIConstOffset;
+      // Convert to pw_aff because it has is_equal whereas a simple aff only has
+      // is_plain_equal that fails here.
+      EXPECT_TRUE(
+          isl::pw_aff(oneGroup->approximation.strideOffset(1).mod(stride))
+              .is_equal(isl::pw_aff(expectedOffset.mod(stride))))
+          << oneGroup->approximation.strideOffset(1) << "\n"
+          << expectedOffset;
+    }
+    return mscop;
+  }
+};
+
+// Check that strides are effectively handled in memory promotion. In
+// particular, check that array elements that are jumped over
+// by the main computation are not copied into shared memory.
+TEST_F(Strided, Stride2) {
+  std::string tc = R"TC(
+def strided(float(N,M) I) -> (O) {
+  O(i, j) = I(j, 2 * i + 1)
+}
+)TC";
+
+  // Expect the promoted size to be 32x32, with stride 2 and offset -1 along
+  // the second dimension.
+  auto mscop = makeScopAndCheck(tc, {{"N", 42}, {"M", 420}}, 32, 2, -1);
+  ASSERT_TRUE(mscop.get() != nullptr);
+  auto& scop = mscop->scop();
+
+  // Additionally check that copies look fine.
+  scop.promoteEverythingAt({0, 0, 0});
+  auto code = std::get<0>(mscop->codegen("strided"));
+  EXPECT_TRUE(
+      code.find("_I_0[c2][c3] = I[32 * b1 + c2][64 * b0 + 2 * c3 + 1]") !=
+      std::string::npos)
+      << "expected strided accesses to global array in copies";
+  EXPECT_TRUE(code.find("= _I_0[c3][c2]") != std::string::npos)
+      << "expected non-strided access to promoted array in main computation";
+  EXPECT_TRUE(code.find("= _I_0[c3][2 * c2") == std::string::npos)
+      << "did not expect strided access to promoted array in main computation";
+}
+
+TEST_F(Strided, Stride5) {
+  std::string tc = R"TC(
+def strided(float(N,M) I) -> (O) {
+  O(i, j) = I(j, 5 * i)
+}
+)TC";
+
+  // Expect the promoted size to be 32x32, with stride 5 and offset 0 along
+  // the second dimension.
+  makeScopAndCheck(tc, {{"N", 42}, {"M", 420}}, 32, 5, 0);
+}
+
 int main(int argc, char** argv) {
   ::testing::InitGoogleTest(&argc, &argv);
   ::gflags::ParseCommandLineFlags(&argc, &argv, true);
