Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 3abaacc

Browse files
authored
Merge pull request #472 from facebookresearch/strided-access
Support strided access in shared memory promotion
2 parents 3ed7f4a + 14d0ca9 commit 3abaacc

File tree

6 files changed

+213
-19
lines changed

6 files changed

+213
-19
lines changed

tc/core/polyhedral/memory_promotion.cc

Lines changed: 72 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,50 @@ namespace polyhedral {
3131
using detail::ScheduleTree;
3232

3333
namespace {
34+
// Remove strides specified by "strides" and "offsets" from the range of
35+
// "relation". In particular, if relation has the shape
36+
//
37+
// D -> O: o_i = offset_i + stride_i * f(D)
38+
//
39+
// transform it into
40+
//
41+
// D -> O: o_i = f(D)
42+
//
43+
// by subtracting "offsets" and by dividing the result by "strides".
44+
isl::map removeRangeStrides(
45+
isl::map relation,
46+
isl::multi_val strides,
47+
isl::multi_aff offsets) {
48+
CHECK_EQ(strides.size(), offsets.size());
49+
50+
auto space = relation.get_space();
51+
auto stridesMA = isl::multi_aff::identity(space.range().map_from_set());
52+
stridesMA = stridesMA / strides;
53+
54+
return relation.sum(isl::map(offsets.neg())).apply_range(isl::map(stridesMA));
55+
}
56+
57+
// Compute a box approximation of the range of the given relation,
58+
// including the lower bounds, the box sizes, and the strides.
59+
// If the range has strides, remove them first.
3460
ScopedFootprint outputRanges(isl::map access) {
61+
auto ctx = access.get_ctx();
62+
int nSubscripts = access.dim(isl::dim_type::out);
63+
64+
auto strides = isl::val_list(ctx, nSubscripts);
65+
auto strideOffsets = isl::aff_list(ctx, nSubscripts);
66+
for (int i = 0; i < nSubscripts; ++i) {
67+
auto si = access.get_range_stride_info(i);
68+
strides = strides.add(si.get_stride());
69+
strideOffsets = strideOffsets.add(si.get_offset());
70+
}
71+
3572
ScopedFootprint footprint;
73+
footprint.strideValues = isl::multi_val(access.get_space().range(), strides);
74+
footprint.strideOffsets = isl::multi_aff(access.get_space(), strideOffsets);
3675

37-
// TODO: also compute strides
76+
access = removeRangeStrides(
77+
access, footprint.strideValues, footprint.strideOffsets);
3878

3979
footprint.box = access.get_range_simple_fixed_box_hull();
4080
return footprint;
@@ -77,16 +117,23 @@ std::unique_ptr<TensorReferenceGroup> TensorReferenceGroup::makeSingleton(
77117
return group;
78118
}
79119

80-
isl::set ScopedFootprint::footprint(isl::set domain) const {
81-
auto space = box.get_space();
82-
auto accessed = isl::map::universe(space).intersect_domain(domain);
120+
isl::set TensorReferenceGroup::approximateFootprint() const {
121+
auto scopedDomain = scopedAccesses().domain();
122+
auto space = approximation.box.get_space();
123+
auto accessed = isl::map::universe(space).intersect_domain(scopedDomain);
83124
auto lspace = isl::local_space(accessed.get_space().range());
84125

85-
for (size_t i = 0; i < dim(); ++i) {
86-
auto dimLowerBound = lowerBound(i);
126+
for (size_t i = 0; i < approximation.dim(); ++i) {
127+
auto offset = approximation.lowerBound(i);
128+
auto stride = approximation.stride(i);
129+
auto strideOffset = approximation.strideOffset(i);
130+
auto size = approximation.size(i);
87131
auto rhs = isl::aff(lspace, isl::dim_type::set, i);
88-
isl::map partial = (isl::aff_map(dimLowerBound) <= rhs) &
89-
(isl::aff_map(dimLowerBound + size(i)) > rhs);
132+
auto lowerBound = offset * stride + strideOffset;
133+
auto upperBound = (offset + size) * stride + strideOffset;
134+
auto partial =
135+
(isl::aff_map(lowerBound) <= rhs) & (isl::aff_map(upperBound) > rhs);
136+
90137
accessed = accessed & partial;
91138
}
92139
return accessed.range();
@@ -303,7 +350,9 @@ TensorGroups TensorReferenceGroup::accessedBySubtree(
303350

304351
// Compute the relation between schedule dimensions, original and promoted array
305352
// subscripts, in the space
306-
// [S -> O] -> P
353+
// [S -> O] -> O.
354+
// The caller is in charge of updating the tuple of the target space with the
355+
// group identifier.
307356
// The mapping depends on the original schedule dimensions because the same
308357
// elements of the promoted array get assigned different values of the original
309358
// array in different outer loop iterations; it's impossible to project out the
@@ -313,10 +362,20 @@ isl::multi_aff TensorReferenceGroup::promotion() const {
313362
isl::map map = scopedAccesses();
314363
auto accessSpace = map.get_space();
315364

316-
// lower bounds space is S -> P; which we transform into [S -> O] -> P
317-
auto lowerBounds = approximation.lowerBounds().pullback(
318-
isl::multi_aff::domain_map(accessSpace));
319-
auto promotion = isl::multi_aff::range_map(accessSpace) - lowerBounds;
365+
// Construct a projection multi-aff in [S -> O] -> S
366+
// for further precomposition.
367+
auto originalSpaceInserter = isl::multi_aff::domain_map(accessSpace);
368+
369+
// Lower bounds and offsets space is S -> O; transform into [S -> O] -> O.
370+
auto lowerBounds =
371+
approximation.lowerBounds().pullback(originalSpaceInserter);
372+
auto offsets = approximation.strideOffsets.pullback(originalSpaceInserter);
373+
374+
// Create promotion starting from the identity in [S -> O] -> O.
375+
auto original = isl::multi_aff::range_map(accessSpace);
376+
auto promotion =
377+
(original - offsets) / approximation.strideValues - lowerBounds;
378+
320379
return promotion;
321380
}
322381

tc/core/polyhedral/memory_promotion.h

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,10 @@ enum class AccessType : short { Read, Write };
3333
// Rectangular overapproximation of the tensor elements accessed through a single
3434
// reference.
3535
// Each dimension is overapproximated by a lower bound, an affine function of
36-
// parameters and schedule dimensions visible around the scope, and by a
37-
// constant size.
36+
// parameters and schedule dimensions visible around the scope, by a
37+
// constant size, and by an offset/stride pair for strided accesses. If the
38+
// access is not strided, then "offset" is a zero expression and "stride" is 1.
39+
// The lowerBound and the size are computed after removing the potential stride.
3840
// The scope is defined by a specific position in a schedule tree (const
3941
// ScheduleTree*), the user is responsible for maintaining the correspondence
4042
// between schedule tree positions and footprints.
@@ -48,8 +50,17 @@ struct ScopedFootprint {
4850
isl::aff lowerBound(size_t pos) const {
4951
return box.get_offset().get_aff(pos);
5052
}
53+
isl::val stride(size_t pos) const {
54+
return strideValues.get_val(pos);
55+
}
56+
isl::aff strideOffset(size_t pos) const {
57+
return strideOffsets.get_aff(pos);
58+
}
59+
5160
isl::fixed_box box;
52-
isl::set footprint(isl::set domain) const;
61+
isl::multi_val strideValues;
62+
isl::multi_aff strideOffsets;
63+
5364
isl::multi_aff lowerBounds() const;
5465
};
5566

@@ -131,9 +142,7 @@ class TensorReferenceGroup {
131142

132143
// Rectangular overapproximation of the set of tensor elements accessed below
133144
// the scoping point.
134-
isl::set approximateFootprint() const {
135-
return approximation.footprint(scopedAccesses().domain());
136-
}
145+
isl::set approximateFootprint() const;
137146

138147
isl::multi_aff promotion() const;
139148
isl::set promotedFootprint() const;

tc/core/polyhedral/schedule_tree_matcher-inl.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,12 @@ inline ScheduleTreeMatcher context(Args... children) {
6161
return ScheduleTreeMatcher(detail::ScheduleTreeType::Context, children...);
6262
}
6363

64+
template <typename... Args>
65+
inline ScheduleTreeMatcher threadSpecific(Args... children) {
66+
return ScheduleTreeMatcher(
67+
detail::ScheduleTreeType::ThreadSpecificMarker, children...);
68+
}
69+
6470
template <typename... Args>
6571
inline ScheduleTreeMatcher filter(
6672
std::function<bool(isl::union_set)> propertyMatcher,

tc/external/detail/islpp-inl.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,13 @@ inline isl::map operator<=(isl::aff_map A, isl::aff B) {
185185
return A < B + 1;
186186
}
187187

188+
///////////////////////////////////////////////////////////////////////////////
189+
// Operations on isl::multi_aff
190+
///////////////////////////////////////////////////////////////////////////////
191+
inline isl::multi_aff operator/(isl::multi_aff left, isl::multi_val right) {
192+
return left.scale_down(right);
193+
}
194+
188195
///////////////////////////////////////////////////////////////////////////////
189196
// Operations on isl::set and isl::union_set
190197
///////////////////////////////////////////////////////////////////////////////

tc/external/detail/islpp.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,11 @@ isl::map operator<=(isl::aff_map A, isl::aff B);
176176
isl::map operator>(isl::aff_map A, isl::aff B);
177177
isl::map operator<(isl::aff_map A, isl::aff B);
178178

179+
///////////////////////////////////////////////////////////////////////////////
180+
// Operations on isl::multi_aff
181+
///////////////////////////////////////////////////////////////////////////////
182+
isl::multi_aff operator/(isl::multi_aff left, isl::multi_val right);
183+
179184
///////////////////////////////////////////////////////////////////////////////
180185
// Operations on isl::set and isl::union_set
181186
///////////////////////////////////////////////////////////////////////////////

test/test_cuda_mapper_memory_promotion.cc

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,114 @@ TEST_F(MatMulBias, RegisterPromotionSharedPreference) {
504504
<< "tensor A promoted to register but has elements accessed by multiple threads";
505505
}
506506

507+
class Strided : public TestMapper {
508+
public:
509+
std::unique_ptr<MappedScop> makeScopAndCheck(
510+
const std::string& tc,
511+
const std::unordered_map<std::string, size_t>& sizes,
512+
long groupISize,
513+
long groupIStride,
514+
long groupIConstOffset) {
515+
auto options = CudaMappingOptions::makeNaiveMappingOptions()
516+
.tile(32, 32)
517+
.mapToThreads(21, 21)
518+
.useSharedMemory(false);
519+
auto mscop = makeMappedScop(tc, options, sizes);
520+
auto& scop = mscop->scop();
521+
auto ctx = scop.domain().get_ctx();
522+
523+
auto groups = TensorReferenceGroup::accessedBySubtree(
524+
scop.scheduleRoot()->child({0, 0, 0}), scop);
525+
EXPECT_EQ(groups.size(), 2u) << "expected groups for both tensors";
526+
527+
for (const auto& g : groups) {
528+
auto name = g.first.get_name();
529+
if (name != "I") {
530+
continue;
531+
}
532+
533+
const auto& perTensorGroups = g.second;
534+
// One cannot use ASSERT_EQ in a function that returns something, because
535+
// it would trigger an immediate return without value. Use EXPECT_EQ and
536+
// return nullptr manually.
537+
EXPECT_EQ(perTensorGroups.size(), 1u) << "expected one group for I";
538+
if (perTensorGroups.size() != 1u) {
539+
return nullptr;
540+
}
541+
const auto& oneGroup = perTensorGroups[0];
542+
543+
EXPECT_EQ(oneGroup->references.size(), 1u)
544+
<< "expected one reference in the group for I";
545+
if (oneGroup->references.size() != 1u) {
546+
return nullptr;
547+
}
548+
const auto& ref = oneGroup->references[0];
549+
550+
EXPECT_EQ(oneGroup->approximation.dim(), 2u)
551+
<< "could not compute approximation for " << ref->scopedAccess;
552+
553+
EXPECT_EQ(oneGroup->approximation.size(1), isl::val(ctx, groupISize))
554+
<< "expected strides to be removed";
555+
556+
isl::val stride = isl::val(ctx, groupIStride);
557+
EXPECT_EQ(oneGroup->approximation.stride(1), stride);
558+
559+
auto expectedOffset =
560+
isl::aff::zero_on_domain(ref->scopedAccess.domain().get_space()) +
561+
groupIConstOffset;
562+
// Convert to pw_aff because it has is_equal whereas a simple aff only has
563+
// is_plain_equal that fails here.
564+
EXPECT_TRUE(
565+
isl::pw_aff(oneGroup->approximation.strideOffset(1).mod(stride))
566+
.is_equal(isl::pw_aff(expectedOffset.mod(stride))))
567+
<< oneGroup->approximation.strideOffset(1) << "\n"
568+
<< expectedOffset;
569+
}
570+
return mscop;
571+
}
572+
};
573+
574+
// Check that strides are effectively handled in memory promotion. In
575+
// particular, check that array elements that are jumped over
576+
// by the main computation are not copied into shared memory.
577+
TEST_F(Strided, Stride2) {
578+
std::string tc = R"TC(
579+
def strided(float(N,M) I) -> (O) {
580+
O(i, j) = I(j, 2 * i + 1)
581+
}
582+
)TC";
583+
584+
// Expect the promoted size to be 32x32, with stride 2 and offset -1 along
585+
// the second dimension.
586+
auto mscop = makeScopAndCheck(tc, {{"N", 42}, {"M", 420}}, 32, 2, -1);
587+
ASSERT_TRUE(mscop.get() != nullptr);
588+
auto& scop = mscop->scop();
589+
590+
// Additionally check that copies look fine.
591+
scop.promoteEverythingAt({0, 0, 0});
592+
auto code = std::get<0>(mscop->codegen("strided"));
593+
EXPECT_TRUE(
594+
code.find("_I_0[c2][c3] = I[32 * b1 + c2][64 * b0 + 2 * c3 + 1]") !=
595+
std::string::npos)
596+
<< "expected strided accesses to global array in copies";
597+
EXPECT_TRUE(code.find("= _I_0[c3][c2]") != std::string::npos)
598+
<< "expected non-strided access to promoted array in main computation";
599+
EXPECT_TRUE(code.find("= _I_0[c3][2 * c2") == std::string::npos)
600+
<< "did not expect strided access to promoted array in main computation";
601+
}
602+
603+
TEST_F(Strided, Stride5) {
604+
std::string tc = R"TC(
605+
def strided(float(N,M) I) -> (O) {
606+
O(i, j) = I(j, 5 * i)
607+
}
608+
)TC";
609+
610+
// Expect the promoted size to be 32x32, with stride 5 and offset 0 along
611+
// the second dimension.
612+
makeScopAndCheck(tc, {{"N", 42}, {"M", 420}}, 32, 5, 0);
613+
}
614+
507615
int main(int argc, char** argv) {
508616
::testing::InitGoogleTest(&argc, argv);
509617
::gflags::ParseCommandLineFlags(&argc, &argv, true);

0 commit comments

Comments
 (0)