Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 4603750

Browse files
author
Sven Verdoolaege
committed
check that all accesses to registers have constant index expressions
The index expressions of accesses to arrays of registers should be constant since registers are not addressable. Failing to use constant index expressions would result in accesses to "local" memory, which is much slower than registers. Some MatMulBias tests checked for particular patterns that should not appear when the index expressions are constant, but they did not check that the index expressions are actually constant. Replace these checks by generic checks during emitCudaKernel that explicitly check for constant index expressions and that get applied to all accesses to registers.
1 parent e5d9904 commit 4603750

File tree

2 files changed

+32
-20
lines changed

2 files changed

+32
-20
lines changed

tc/core/polyhedral/cuda/codegen.cc

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -410,10 +410,33 @@ void emitAccess(AFF access, const CodegenStatementContext& context) {
410410
emitAccess(buildAccess(access, context), context);
411411
}
412412

413+
// Check that the given expression is an access with constant index expressions
414+
void checkConstantAccess(isl::ast_expr expr) {
415+
auto op = expr.as<isl::ast_expr_op>();
416+
auto access = op.as<isl::ast_op_access>();
417+
TC_CHECK(access);
418+
for (int i = 1; i < access.get_n_arg(); ++i) {
419+
auto arg = access.get_arg(i);
420+
TC_CHECK(arg.as<isl::ast_expr_int>())
421+
<< "expected constant subscript, got " << arg.to_C_str();
422+
}
423+
}
424+
425+
// Print an access to a(n array of) register(s), checking that
426+
// the index expressions are constant.
427+
void emitRegisterAccess(
428+
isl::pw_multi_aff access,
429+
const CodegenStatementContext& context) {
430+
auto expr = buildAccess(access, context);
431+
checkConstantAccess(expr);
432+
emitAccess(expr, context);
433+
}
434+
413435
// Print an access to global memory, wrapping the access in an "__ldg()"
414436
// call if the accessed tensor is known to be read-only.
415-
template <typename AFF>
416-
void emitGlobalAccess(AFF access, const CodegenStatementContext& context) {
437+
void emitGlobalAccess(
438+
isl::multi_pw_aff access,
439+
const CodegenStatementContext& context) {
417440
LdgWrapper ldgWrapper(context, access.get_tuple_id(isl::dim_type::out));
418441
emitAccess(access, context);
419442
}
@@ -641,7 +664,8 @@ void emitMappedTensorAccess(
641664
return;
642665
}
643666

644-
auto tensorId = context.scop().promotedDecl(promotionInfo.groupId).tensorId;
667+
auto decl = context.scop().promotedDecl(promotionInfo.groupId);
668+
auto tensorId = decl.tensorId;
645669

646670
// Here and below in comments: D = domain, O = original tensor, P = promoted
647671
// tensor, S = partial schedule, A = AST loops;
@@ -667,7 +691,11 @@ void emitMappedTensorAccess(
667691
auto astToPromoted =
668692
isl::pw_multi_aff(promotion).pullback(astToScheduledOriginal);
669693

670-
emitAccess(astToPromoted, context);
694+
if (decl.kind == Scop::PromotedDecl::Kind::Register) {
695+
emitRegisterAccess(astToPromoted, context);
696+
} else {
697+
emitAccess(astToPromoted, context);
698+
}
671699
}
672700

673701
} // namespace detail

test/test_cuda_mapper_memory_promotion.cc

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -485,19 +485,6 @@ def fun(float(N,K) A, float(K,M) B, float(N,M) C) -> (O) {
485485
EXPECT_TRUE(cDeclPos == std::string::npos)
486486
<< "tensor C promoted to register but has no reuse";
487487
}
488-
489-
void expectNoSymbolicSubscript(const std::string& code) {
490-
// We don't know the exact name of the iterator, but it starts with c.
491-
auto oWithIteratorPos = code.find("_O_0[c");
492-
auto oWithThreadPos = code.find("_O_0[t1");
493-
494-
EXPECT_TRUE(oWithIteratorPos == std::string::npos)
495-
<< "accessing local arrays with iterators in subscripts makes "
496-
<< "these arrays placed in local memory instead of registers";
497-
EXPECT_TRUE(oWithThreadPos == std::string::npos)
498-
<< "expected per-thread groups to be computed, i.e. thread "
499-
<< "identifiers should not appear in the subscripts";
500-
}
501488
};
502489

503490
TEST_F(MatMulBias, RegisterPromotion) {
@@ -562,7 +549,6 @@ TEST_F(MatMulBias, RegistersAtRoot) {
562549
<< "expected O to be promoted to registers";
563550

564551
expectNoABCPromotion(code);
565-
expectNoSymbolicSubscript(code);
566552

567553
auto o00Pos = code.find("_O_0[0][0]");
568554
auto o10Pos = code.find("_O_0[1][0]");
@@ -597,7 +583,6 @@ TEST_F(MatMulBias, RegistersAtRootNotEnoughUnroll) {
597583
<< "not expected O to be promoted to registers";
598584

599585
expectNoABCPromotion(code);
600-
expectNoSymbolicSubscript(code);
601586
}
602587

603588
TEST_F(MatMulBias, RegistersBelowFirstBand) {
@@ -621,7 +606,6 @@ TEST_F(MatMulBias, RegistersBelowFirstBand) {
621606
EXPECT_TRUE(oDeclPos != std::string::npos)
622607
<< "expected O to be promoted to registers";
623608
expectNoABCPromotion(code);
624-
expectNoSymbolicSubscript(code);
625609
}
626610

627611
class Strided : public TestMapper {

0 commit comments

Comments
 (0)