Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit aaf96f2

Browse files
Merge pull request #226 from facebookresearch/pr/fix-179
emitHalideExpr: print parentheses around compound nested expressions
2 parents 17b2bad + 07d1c00 commit aaf96f2

File tree

4 files changed: +75 -49 lines changed

include/tc/core/polyhedral/cuda/codegen.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,6 @@ struct CodegenStatementContext;
3131

3232
namespace detail {
3333

34-
void emitDirectSubscripts(
35-
isl::pw_multi_aff subscripts,
36-
const CodegenStatementContext& context);
37-
38-
std::string toString(isl::pw_aff subscript);
39-
4034
isl::pw_aff makeAffFromMappedExpr(
4135
const Halide::Expr& expr,
4236
const CodegenStatementContext& context);

src/core/polyhedral/cuda/codegen.cc

Lines changed: 38 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -359,28 +359,33 @@ void emitReductionInit(
359359
context.ss << ";" << endl;
360360
}
361361

362-
void emitCopyStmt(const CodegenStatementContext& context) {
363-
using detail::emitDirectSubscripts;
362+
namespace {
363+
template <typename AFF>
364+
void emitAccess(AFF access, const CodegenStatementContext& context) {
365+
// Use a temporary isl::ast_build to print the expression.
366+
// Ideally, this should use the build at the point
367+
// where the user statement was created.
368+
auto astBuild = isl::ast_build::from_context(access.domain());
369+
context.ss << astBuild.access_from(access).to_C_str();
370+
}
371+
} // namespace
364372

373+
void emitCopyStmt(const CodegenStatementContext& context) {
365374
auto stmtId = context.statementId();
366375

367376
auto iteratorMap = context.iteratorMap();
368377
auto promoted = iteratorMap.range_factor_range();
369378
auto original = iteratorMap.range_factor_domain().range_factor_range();
370379
auto isRead = stmtId.get_name() == kReadIdName;
371-
auto originalName = original.get_tuple_id(isl::dim_type::out).get_name();
372-
auto promotedName = promoted.get_tuple_id(isl::dim_type::out).get_name();
373380

374381
if (isRead) {
375-
context.ss << promotedName;
376-
emitDirectSubscripts(promoted, context);
377-
context.ss << " = " << originalName;
378-
emitDirectSubscripts(original, context);
382+
emitAccess(isl::multi_pw_aff(promoted), context);
383+
context.ss << " = ";
384+
emitAccess(isl::multi_pw_aff(original), context);
379385
} else {
380-
context.ss << originalName;
381-
emitDirectSubscripts(original, context);
382-
context.ss << " = " << promotedName;
383-
emitDirectSubscripts(promoted, context);
386+
emitAccess(isl::multi_pw_aff(original), context);
387+
context.ss << " = ";
388+
emitAccess(isl::multi_pw_aff(promoted), context);
384389
}
385390
context.ss << ";" << std::endl;
386391
}
@@ -447,14 +452,6 @@ void AstPrinter::emitAst(isl::ast_node node) {
447452

448453
namespace detail {
449454

450-
std::string toString(isl::pw_aff subscript) {
451-
// Use a temporary isl::ast_build to print the expression.
452-
// Ideally, this should use the build at the point
453-
// where the user statement was created.
454-
auto astBuild = isl::ast_build::from_context(subscript.domain());
455-
return astBuild.expr_from(subscript).to_C_str();
456-
}
457-
458455
isl::pw_aff makeAffFromMappedExpr(
459456
const Halide::Expr& expr,
460457
const CodegenStatementContext& context) {
@@ -498,18 +495,35 @@ isl::multi_aff makeMultiAffAccess(
498495
return ma;
499496
}
500497

498+
namespace {
499+
bool is_identifier_or_nonnegative_integer(isl::ast_expr expr) {
500+
if (isl_ast_expr_get_type(expr.get()) == isl_ast_expr_id)
501+
return true;
502+
if (isl_ast_expr_get_type(expr.get()) != isl_ast_expr_int)
503+
return false;
504+
return isl::manage(isl_ast_expr_get_val(expr.get())).is_nonneg();
505+
}
506+
} // namespace
507+
501508
void emitHalideExpr(
502509
const Halide::Expr& e,
503510
const CodegenStatementContext& context,
504511
const map<string, string>& substitutions) {
505512
class EmitHalide : public Halide::Internal::IRPrinter {
506513
using Halide::Internal::IRPrinter::visit;
507514
void visit(const Halide::Internal::Variable* op) {
508-
// This is probably needlessly indirect, given that we just have
509-
// a name to look up somewhere.
510515
auto pwAff = tc::polyhedral::detail::makeAffFromMappedExpr(
511516
Halide::Expr(op), context);
512-
context.ss << tc::polyhedral::detail::toString(pwAff);
517+
// Use a temporary isl::ast_build to print the expression.
518+
// Ideally, this should use the build at the point
519+
// where the user statement was created.
520+
auto astBuild = isl::ast_build::from_context(pwAff.domain());
521+
auto expr = astBuild.expr_from(pwAff);
522+
auto s = expr.to_C_str();
523+
if (!is_identifier_or_nonnegative_integer(expr)) {
524+
s = "(" + s + ")";
525+
}
526+
context.ss << s;
513527
}
514528
void visit(const Halide::Internal::Call* op) {
515529
if (substitutions.count(op->name)) {
@@ -613,19 +627,7 @@ void emitMappedTensorAccess(
613627
auto astToPromoted =
614628
isl::pw_multi_aff(promotion).pullback(astToScheduledOriginal);
615629

616-
auto astBuild = isl::ast_build::from_context(astToPromoted.domain());
617-
context.ss << astBuild.access_from(astToPromoted).to_C_str();
618-
}
619-
620-
void emitDirectSubscripts(
621-
isl::pw_multi_aff subscripts,
622-
const CodegenStatementContext& context) {
623-
auto mpa = isl::multi_pw_aff(subscripts); // this conversion is safe
624-
for (auto pa : isl::MPA(mpa)) {
625-
context.ss << "[";
626-
context.ss << toString(pa.pa);
627-
context.ss << "]";
628-
}
630+
emitAccess(astToPromoted, context);
629631
}
630632

631633
} // namespace detail

test/test_mapper.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def fun(float(N, M) A, float(N, M) B) -> (C) {
186186
float32 (*B)[M] = reinterpret_cast<float32 (*)[M]>(pB);
187187
for (int c1 = 16 * b1; c1 < M; c1 += 4096) {
188188
if (M >= t1 + c1 + 1) {
189-
C[t0 + 16 * b0][t1 + c1] = (A[t0 + 16 * b0][t1 + c1] + B[t0 + 16 * b0][t1 + c1]);
189+
C[(t0 + 16 * b0)][(t1 + c1)] = (A[(t0 + 16 * b0)][(t1 + c1)] + B[(t0 + 16 * b0)][(t1 + c1)]);
190190
}
191191
}
192192
}
@@ -318,9 +318,9 @@ constexpr auto kExpectedMatmul_64_64_64 =
318318
for (int c1 = 0; c1 <= 63; c1 += 16) {
319319
for (int c2 = t1; c2 <= 15; c2 += 8) {
320320
for (int c3 = 0; c3 <= 15; c3 += 1) {
321-
O[c0 + c2][c1 + c3] = 0.000000f;
321+
O[(c0 + c2)][(c1 + c3)] = 0.000000f;
322322
for (int c4 = t0; c4 <= 63; c4 += 32) {
323-
O[c0 + c2][c1 + c3] = (O[c0 + c2][c1 + c3] + (A[c0 + c2][c4]*B[c4][c1 + c3]));
323+
O[(c0 + c2)][(c1 + c3)] = (O[(c0 + c2)][(c1 + c3)] + (A[(c0 + c2)][c4]*B[c4][(c1 + c3)]));
324324
}
325325
}
326326
}
@@ -443,7 +443,7 @@ TEST_F(PolyhedralMapperTest, Unroll1D) {
443443
auto mscop = MappedScop::makeWithOuterBlockInnerThreadStrategy(
444444
std::move(scop), mappingOptions);
445445
auto code = std::get<0>(mscop->codegen(specializedName));
446-
std::string expected("C[64 * b0 + c2][t0 + 64 * b1]");
446+
std::string expected("C[(64 * b0 + c2)][(t0 + 64 * b1)]");
447447
ASSERT_TRUE(code.find(expected) != std::string::npos) << code;
448448
}
449449

@@ -462,7 +462,7 @@ TEST_F(PolyhedralMapperTest, Unroll2D) {
462462
auto mscop = MappedScop::makeWithOuterBlockInnerThreadStrategy(
463463
std::move(scop), mappingOptions);
464464
auto code = std::get<0>(mscop->codegen(specializedName));
465-
std::string expected("C[t1 + 64 * b0 + 32][t0 + 64 * b1 + 32]");
465+
std::string expected("C[(t1 + 64 * b0 + 32)][(t0 + 64 * b1 + 32)]");
466466
ASSERT_TRUE(code.find(expected) != std::string::npos);
467467
}
468468

@@ -712,7 +712,7 @@ TEST_F(PolyhedralMapperTest, ReductionMM1D) {
712712
auto code = codegenMapped(kTcMM, mappingOptions);
713713
using tc::code::cuda::kCUBReductionName;
714714
EXPECT_TRUE(code.find(kCUBReductionName) != std::string::npos);
715-
EXPECT_TRUE(code.find("C[c0 + c3][t0 + c1] = (C") != std::string::npos);
715+
EXPECT_TRUE(code.find("C[(c0 + c3)][(t0 + c1)] = (C") != std::string::npos);
716716
}
717717

718718
/*
@@ -730,7 +730,7 @@ TEST_F(PolyhedralMapperTest, ReductionMM2D) {
730730
auto code = codegenMapped(kTcMM, mappingOptions);
731731
using tc::code::cuda::kCUBReductionName;
732732
EXPECT_TRUE(code.find(kCUBReductionName) != std::string::npos);
733-
EXPECT_TRUE(code.find("C[t1 + c0][t0 + c1] = (C") != std::string::npos);
733+
EXPECT_TRUE(code.find("C[(t1 + c0)][(t0 + c1)] = (C") != std::string::npos);
734734
}
735735

736736
int main(int argc, char** argv) {

test/test_tc_mapper_bugs.cc

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,36 @@ TEST(Halide2Isl, MinInUpperBound) {
761761
atCompl.compile("graph2", inputs, options);
762762
}
763763

764+
// Check that nested expressions are properly formatted.
765+
// In particular, as soon as the tensor size X is larger than the tile size,
766+
// the expression for "xp" is a sum of multiple loop iterators
767+
// in the generated code. Parentheses need to be placed around
768+
// these expressions to ensure the end result is, say, "-(c1 + c3)"
769+
// rather than "-c1 + c3".
770+
// The actual convolution is one where the output is equal to the input.
771+
TEST(Convolution, NestedExpressions) {
772+
auto convolution = "convolution";
773+
auto TC = std::string(R"TC(
774+
def convolution(float(X) A, float(Xp) K) -> (B) {
775+
B(x) +=! A(xp) * K(X - 1 + x - xp) where xp in 0:X
776+
}
777+
)TC");
778+
int X = 33;
779+
at::Tensor A = at::CUDA(at::kFloat).zeros({X});
780+
at::Tensor K = at::CUDA(at::kFloat).zeros({2 * X - 1});
781+
A[10] = 1;
782+
K[X - 1] = 1;
783+
std::vector<at::Tensor> inputs = {A, K};
784+
std::vector<at::Tensor> outputs;
785+
tc::ATenCompilationUnit<tc::CudaTcExecutor> cu;
786+
cu.define(TC);
787+
auto options = tc::CudaMappingOptions::makeNaiveCudaMappingOptions();
788+
auto handle = cu.compile(convolution, inputs, options);
789+
cu.run(convolution, inputs, outputs, handle);
790+
auto B = outputs[0];
791+
CHECK_EQ(at::Scalar(B[10]).toFloat(), 1);
792+
}
793+
764794
int main(int argc, char** argv) {
765795
::testing::InitGoogleTest(&argc, argv);
766796
::gflags::ParseCommandLineFlags(&argc, &argv, true);

0 commit comments

Comments
 (0)