Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 04fa3ed

Browse files
Sven Verdoolaege authored and ftynse committed
emitHalideExpr: use isl AST expression generator to print expression
The printed expression is derived from plugging in the inverse schedule. This inverse schedule may have multiple disjuncts, especially if the original schedule tree has disjunctive filters, such as is the case in #200. The expression can therefore not be assumed to consist of a single disjunct. Use the AST expression generator instead of manually trying to pick the affine expression apart. Ideally, this should use the AST build at the point where the user statement was created, but this requires more invasive changes. Closes #200
1 parent 4798154 commit 04fa3ed

File tree

4 files changed

+58
-19
lines changed

4 files changed

+58
-19
lines changed

src/core/polyhedral/codegen_cuda.cc

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -471,13 +471,11 @@ std::string toString(isl::aff subscript) {
471471
}
472472

473473
std::string toString(isl::pw_aff subscript) {
474-
isl::aff subscriptAff = isl::null<isl::aff>();
475-
subscript.foreach_piece([&](isl::set domain, isl::aff aff) {
476-
CHECK(!subscriptAff.get()) << "expected one piece";
477-
subscriptAff = aff;
478-
});
479-
480-
return toString(subscriptAff);
474+
// Use a temporary isl::ast_build to print the expression.
475+
// Ideally, this should use the build at the point
476+
// where the user statement was created.
477+
auto astBuild = isl::ast_build::from_context(subscript.domain());
478+
return astBuild.expr_from(subscript).to_C_str();
481479
}
482480

483481
isl::pw_aff makeAffFromMappedExpr(

test/test_mapper.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def fun(float(N, M) A, float(N, M) B) -> (C) {
186186
float32 (*B)[M] = reinterpret_cast<float32 (*)[M]>(pB);
187187
for (int c1 = 16 * b1; c1 < M; c1 += 4096) {
188188
if (M >= t1 + c1 + 1) {
189-
C[t0 + 16*b0][t1 + c1] = (A[t0 + 16*b0][t1 + c1] + B[t0 + 16*b0][t1 + c1]);
189+
C[t0 + 16 * b0][t1 + c1] = (A[t0 + 16 * b0][t1 + c1] + B[t0 + 16 * b0][t1 + c1]);
190190
}
191191
}
192192
}
@@ -442,7 +442,7 @@ TEST_F(PolyhedralMapperTest, Unroll1D) {
442442
auto mscop = MappedScop::makeWithOuterBlockInnerThreadStrategy(
443443
std::move(scop), mappingOptions);
444444
auto code = std::get<0>(mscop->codegen(specializedName));
445-
std::string expected("C[64*b0 + c2][t0 + 64*b1]");
445+
std::string expected("C[64 * b0 + c2][t0 + 64 * b1]");
446446
ASSERT_TRUE(code.find(expected) != std::string::npos) << code;
447447
}
448448

@@ -461,7 +461,7 @@ TEST_F(PolyhedralMapperTest, Unroll2D) {
461461
auto mscop = MappedScop::makeWithOuterBlockInnerThreadStrategy(
462462
std::move(scop), mappingOptions);
463463
auto code = std::get<0>(mscop->codegen(specializedName));
464-
std::string expected("C[32 + t1 + 64*b0][32 + t0 + 64*b1]");
464+
std::string expected("C[t1 + 64 * b0 + 32][t0 + 64 * b1 + 32]");
465465
ASSERT_TRUE(code.find(expected) != std::string::npos);
466466
}
467467

test/test_mapper_memory_promotion.cc

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -118,13 +118,13 @@ TEST_F(Sum4D, CodeOuterBand) {
118118
"__shared__ float32 _C_0[16][16][16][16];"};
119119

120120
auto copyA =
121-
"_A_0[c4][c5][c6][c7] = A[16*b0 + c4][16*b1 + c5][c2 + c6][c3 + c7];";
121+
"_A_0[c4][c5][c6][c7] = A[16 * b0 + c4][16 * b1 + c5][c2 + c6][c3 + c7];";
122122
auto copyB =
123-
"_B_0[c4][c5][c6][c7] = B[16*b0 + c4][16*b1 + c5][c2 + c6][c3 + c7];";
123+
"_B_0[c4][c5][c6][c7] = B[16 * b0 + c4][16 * b1 + c5][c2 + c6][c3 + c7];";
124124
auto compute =
125125
"_C_0[c4][c5][c6][t0] = (_A_0[c4][c5][c6][t0] + _B_0[c4][c5][c6][t0]);";
126126
auto copyC =
127-
"C[16*b0 + c4][16*b1 + c5][c2 + c6][c3 + c7] = _C_0[c4][c5][c6][c7];";
127+
"C[16 * b0 + c4][16 * b1 + c5][c2 + c6][c3 + c7] = _C_0[c4][c5][c6][c7];";
128128
auto sync = "__syncthreads()";
129129

130130
auto code = emitCode({256, 128, 192, 224}, {16, 16, 16, 16}, {0, 0, 0, 0});
@@ -160,13 +160,13 @@ TEST_F(Sum4D, CodeBeforeThreadMapping) {
160160
"__shared__ float32 _B_0[16][16][16][1];",
161161
"__shared__ float32 _C_0[16][16][16][1];"};
162162
auto copyA =
163-
"_A_0[c4][c5][c6][0] = A[16*b0 + c4][16*b1 + c5][c2 + c6][t0 + c3];";
163+
"_A_0[c4][c5][c6][0] = A[16 * b0 + c4][16 * b1 + c5][c2 + c6][t0 + c3];";
164164
auto copyB =
165-
"_B_0[c4][c5][c6][0] = B[16*b0 + c4][16*b1 + c5][c2 + c6][t0 + c3];";
165+
"_B_0[c4][c5][c6][0] = B[16 * b0 + c4][16 * b1 + c5][c2 + c6][t0 + c3];";
166166
auto compute =
167167
"_C_0[c4][c5][c6][0] = (_A_0[c4][c5][c6][0] + _B_0[c4][c5][c6][0]);";
168168
auto copyC =
169-
"C[16*b0 + c4][16*b1 + c5][c2 + c6][t0 + c3] = _C_0[c4][c5][c6][0];";
169+
"C[16 * b0 + c4][16 * b1 + c5][c2 + c6][t0 + c3] = _C_0[c4][c5][c6][0];";
170170
auto sync = "__syncthreads()";
171171

172172
auto code =
@@ -204,12 +204,12 @@ TEST_F(Sum4D, CodeInnerBand) {
204204
"__shared__ float32 _A_0[1][1][1][1];",
205205
"__shared__ float32 _B_0[1][1][1][1];"};
206206
auto copyA =
207-
"_A_0[0][0][0][0] = A[16*b0 + c4][16*b1 + c5][c2 + c6][t0 + c3];";
207+
"_A_0[0][0][0][0] = A[16 * b0 + c4][16 * b1 + c5][c2 + c6][t0 + c3];";
208208
auto copyB =
209-
"_B_0[0][0][0][0] = B[16*b0 + c4][16*b1 + c5][c2 + c6][t0 + c3];";
209+
"_B_0[0][0][0][0] = B[16 * b0 + c4][16 * b1 + c5][c2 + c6][t0 + c3];";
210210
auto compute = "_C_0[0][0][0][0] = (_A_0[0][0][0][0] + _B_0[0][0][0][0]);";
211211
auto copyC =
212-
"C[16*b0 + c4][16*b1 + c5][c2 + c6][t0 + c3] = _C_0[0][0][0][0];";
212+
"C[16 * b0 + c4][16 * b1 + c5][c2 + c6][t0 + c3] = _C_0[0][0][0][0];";
213213
auto sync = "__syncthreads()";
214214

215215
auto code =

test/test_tc_mapper_bugs.cc

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,47 @@ TEST(LayerNorm, ReferenceBelongsToTwoGroups) {
697697
atCompl.compile("layernorm", inputs, options);
698698
}
699699

700+
// This case was observed when running the autotuner on example_MLP_model
701+
// (#200). It calls code generation on a schedule tree containing a
702+
// disjunctive filter, which results in an expression with more than one disjunct
703+
// that was not handled properly.
704+
// TODO: the disjunctive filter in the schedule is unexpected and its origin
705+
// should be identified and explained.
706+
TEST(TMM_128_1024_1000, DisjunctiveFilter) {
707+
at::Tensor I = at::CUDA(at::kFloat).rand({128, 1024});
708+
at::Tensor W = at::CUDA(at::kFloat).rand({1000, 1024});
709+
std::vector<at::Tensor> inputs = {I, W};
710+
std::vector<at::Tensor> outputs;
711+
712+
auto TC = std::string(R"TC(
713+
def tmm_naive(float(B, X) I, float(Y, X) W) -> (O) {
714+
O(b, y) +=! I(b, rx) * W(y, rx)
715+
}
716+
)TC");
717+
auto options =
718+
tc::MappingOptions::makeNaiveMappingOptions()
719+
.outerScheduleFusionStrategy(tc::FusionStrategy::Preserve3Coincident)
720+
.outerScheduleAllowSkewing(false)
721+
.outerSchedulePositiveOrthant(true)
722+
.intraTileScheduleFusionStrategy(tc::FusionStrategy::Min)
723+
.intraTileScheduleAllowSkewing(false)
724+
.intraTileSchedulePositiveOrthant(true)
725+
.tile(1, 32, 63)
726+
.mapToThreads(2, 32)
727+
.mapToBlocks(64, 128, 1024)
728+
.unroll(128)
729+
.tileImperfectlyNested(false)
730+
.useSharedMemory(false)
731+
.usePrivateMemory(false)
732+
.unrollCopyShared(false)
733+
.matchLibraryCalls(true);
734+
735+
tc::ATenCompilationUnit<tc::CudaTcExecutor> atCompl;
736+
atCompl.define(TC);
737+
// Expecting this to compile without dying.
738+
atCompl.compile("tmm_naive", inputs, options);
739+
}
740+
700741
TEST(Halide2Isl, MinInUpperBound) {
701742
at::Tensor mat1 = at::CUDA(at::kFloat).rand({1, 100, 184, 184});
702743
at::Tensor mat1_pad = at::CUDA(at::kFloat).rand({1, 100, 186, 186});

0 commit comments

Comments
 (0)