Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 2aaa675

Browse files
Merge pull request #535 from facebookresearch/pr/constant
check that all accesses to registers have constant index expressions
2 parents fed4151 + 4603750 commit 2aaa675

File tree

2 files changed

+49
-21
lines changed

2 files changed

+49
-21
lines changed

tc/core/polyhedral/cuda/codegen.cc

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -396,10 +396,49 @@ struct LdgWrapper {
396396
std::ostream& out_;
397397
};
398398

399+
// Build the AST expression for "access" using the AST build obtained
// from the statement context (context.build().access_from()).
// AFF is the isl type describing the access; in this file the callers
// pass isl::multi_pw_aff and isl::pw_multi_aff objects.
template <typename AFF>
400+
isl::ast_expr buildAccess(AFF access, const CodegenStatementContext& context) {
401+
return context.build().access_from(access);
402+
}
403+
404+
// Print an already built access expression to context.ss,
// using its C representation (to_C_str()).
// No checks are performed on the expression here.
void emitAccess(isl::ast_expr access, const CodegenStatementContext& context) {
405+
context.ss << access.to_C_str();
406+
}
407+
399408
// Build the AST expression for "access" and print it to context.ss,
// without any additional checks or wrapping of the printed access.
template <typename AFF>
400409
void emitAccess(AFF access, const CodegenStatementContext& context) {
410+
emitAccess(buildAccess(access, context), context);
411+
}
412+
413+
// Check that the given expression is an access with constant index expressions
// That is, "expr" needs to be an operation of the access kind
// (isl::ast_op_access) and every argument from position 1 onwards
// needs to be an integer expression (isl::ast_expr_int).
// Argument 0 is not inspected; according to the isl manual, it represents
// the array being accessed, while the remaining arguments are the subscripts.
// A violation triggers TC_CHECK; for a non-constant subscript, the message
// includes the C representation of the offending subscript.
// NOTE(review): "op" is not checked on its own; presumably a failed
// conversion of "expr" carries over to "access" and is caught by
// TC_CHECK(access) - confirm against the isl bindings.
414+
void checkConstantAccess(isl::ast_expr expr) {
415+
auto op = expr.as<isl::ast_expr_op>();
416+
auto access = op.as<isl::ast_op_access>();
417+
TC_CHECK(access);
418+
for (int i = 1; i < access.get_n_arg(); ++i) {
419+
auto arg = access.get_arg(i);
420+
TC_CHECK(arg.as<isl::ast_expr_int>())
421+
<< "expected constant subscript, got " << arg.to_C_str();
422+
}
423+
}
424+
425+
// Print an access to a(n array of) register(s), checking that
426+
// the index expressions are constant.
// Subscripts involving iterators would make such arrays get placed
// in local memory instead of registers.
// The AST expression is built only once; the same expression is first
// checked and then printed, such that the check applies to exactly
// what is emitted to context.ss.
427+
void emitRegisterAccess(
428+
isl::pw_multi_aff access,
429+
const CodegenStatementContext& context) {
430+
auto expr = buildAccess(access, context);
431+
checkConstantAccess(expr);
432+
emitAccess(expr, context);
433+
}
434+
435+
// Print an access to global memory, wrapping the access in an "__ldg()"
436+
// call if the accessed tensor is known to be read-only.
// The accessed tensor is identified by the output tuple identifier
// of "access".
// The LdgWrapper object remains in scope until the end of the function,
// i.e., for the entire duration of printing the access.
// NOTE(review): presumably its constructor and destructor print the
// opening and closing parts of the call; the definition of LdgWrapper
// is not fully visible here - confirm.
437+
void emitGlobalAccess(
438+
isl::multi_pw_aff access,
439+
const CodegenStatementContext& context) {
401440
LdgWrapper ldgWrapper(context, access.get_tuple_id(isl::dim_type::out));
402-
context.ss << context.build().access_from(access).to_C_str();
441+
emitAccess(access, context);
403442
}
404443
} // namespace
405444

@@ -414,9 +453,9 @@ void emitCopyStmt(const CodegenStatementContext& context) {
414453
if (isRead) {
415454
emitAccess(isl::multi_pw_aff(promoted), context);
416455
context.ss << " = ";
417-
emitAccess(isl::multi_pw_aff(original), context);
456+
emitGlobalAccess(isl::multi_pw_aff(original), context);
418457
} else {
419-
emitAccess(isl::multi_pw_aff(original), context);
458+
emitGlobalAccess(isl::multi_pw_aff(original), context);
420459
context.ss << " = ";
421460
emitAccess(isl::multi_pw_aff(promoted), context);
422461
}
@@ -625,7 +664,8 @@ void emitMappedTensorAccess(
625664
return;
626665
}
627666

628-
auto tensorId = context.scop().promotedDecl(promotionInfo.groupId).tensorId;
667+
auto decl = context.scop().promotedDecl(promotionInfo.groupId);
668+
auto tensorId = decl.tensorId;
629669

630670
// Here and below in comments: D = domain, O = original tensor, P = promoted
631671
// tensor, S = partial schedule, A = AST loops;
@@ -651,7 +691,11 @@ void emitMappedTensorAccess(
651691
auto astToPromoted =
652692
isl::pw_multi_aff(promotion).pullback(astToScheduledOriginal);
653693

654-
emitAccess(astToPromoted, context);
694+
if (decl.kind == Scop::PromotedDecl::Kind::Register) {
695+
emitRegisterAccess(astToPromoted, context);
696+
} else {
697+
emitAccess(astToPromoted, context);
698+
}
655699
}
656700

657701
} // namespace detail

test/test_cuda_mapper_memory_promotion.cc

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -485,19 +485,6 @@ def fun(float(N,K) A, float(K,M) B, float(N,M) C) -> (O) {
485485
EXPECT_TRUE(cDeclPos == std::string::npos)
486486
<< "tensor C promoted to register but has no reuse";
487487
}
488-
489-
void expectNoSymbolicSubscript(const std::string& code) {
490-
// We don't know the exact name of the iterator, but it starts with c.
491-
auto oWithIteratorPos = code.find("_O_0[c");
492-
auto oWithThreadPos = code.find("_O_0[t1");
493-
494-
EXPECT_TRUE(oWithIteratorPos == std::string::npos)
495-
<< "accessing local arrays with iterators in subscripts makes "
496-
<< "these arrays placed in local memory instead of registers";
497-
EXPECT_TRUE(oWithThreadPos == std::string::npos)
498-
<< "expected per-thread groups to be computed, i.e. thread "
499-
<< "identifiers should not appear in the subscripts";
500-
}
501488
};
502489

503490
TEST_F(MatMulBias, RegisterPromotion) {
@@ -562,7 +549,6 @@ TEST_F(MatMulBias, RegistersAtRoot) {
562549
<< "expected O to be promoted to registers";
563550

564551
expectNoABCPromotion(code);
565-
expectNoSymbolicSubscript(code);
566552

567553
auto o00Pos = code.find("_O_0[0][0]");
568554
auto o10Pos = code.find("_O_0[1][0]");
@@ -597,7 +583,6 @@ TEST_F(MatMulBias, RegistersAtRootNotEnoughUnroll) {
597583
<< "not expected O to be promoted to registers";
598584

599585
expectNoABCPromotion(code);
600-
expectNoSymbolicSubscript(code);
601586
}
602587

603588
TEST_F(MatMulBias, RegistersBelowFirstBand) {
@@ -621,7 +606,6 @@ TEST_F(MatMulBias, RegistersBelowFirstBand) {
621606
EXPECT_TRUE(oDeclPos != std::string::npos)
622607
<< "expected O to be promoted to registers";
623608
expectNoABCPromotion(code);
624-
expectNoSymbolicSubscript(code);
625609
}
626610

627611
class Strided : public TestMapper {

0 commit comments

Comments
 (0)