Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 5f456b3

Browse files
committed
test_cuda_mapper_memory_promotion: fix RegistersBelowFirstBand test
The original implementation took the first band in DFS postorder instead of preorder, thus selecting the last band instead of the first band. Use DFS preorder and update the test accordingly.
1 parent d384ed3 commit 5f456b3

File tree

1 file changed

+30
-25
lines changed

1 file changed

+30
-25
lines changed

test/test_cuda_mapper_memory_promotion.cc

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,28 @@ def fun(float(N,K) A, float(K,M) B, float(N,M) C) -> (O) {
485485
EXPECT_TRUE(cDeclPos == std::string::npos)
486486
<< "tensor C promoted to register but has no reuse";
487487
}
488+
489+
void expectFourOElementsPromoted(const std::string& code) {
490+
auto oDeclPos = code.find("float32 _O_0[4][1];");
491+
EXPECT_TRUE(oDeclPos != std::string::npos)
492+
<< "expected O to be promoted to registers";
493+
494+
expectNoABCPromotion(code);
495+
496+
auto o00Pos = code.find("_O_0[0][0]");
497+
auto o10Pos = code.find("_O_0[1][0]");
498+
auto o20Pos = code.find("_O_0[2][0]");
499+
auto o30Pos = code.find("_O_0[3][0]");
500+
501+
EXPECT_TRUE(o00Pos != std::string::npos)
502+
<< "expected constant subscripts in _O_0";
503+
EXPECT_TRUE(o10Pos != std::string::npos)
504+
<< "expected constant subscripts in _O_0";
505+
EXPECT_TRUE(o20Pos != std::string::npos)
506+
<< "expected constant subscripts in _O_0";
507+
EXPECT_TRUE(o30Pos != std::string::npos)
508+
<< "expected constant subscripts in _O_0";
509+
}
488510
};
489511

490512
TEST_F(MatMulBias, RegisterPromotion) {
@@ -544,25 +566,7 @@ TEST_F(MatMulBias, RegistersAtRoot) {
544566

545567
// Expecting 4 elements because we map the loop i in O[i][j] to 8 threads
546568
// after tiling by 32.
547-
auto oDeclPos = code.find("float32 _O_0[4][1];");
548-
EXPECT_TRUE(oDeclPos != std::string::npos)
549-
<< "expected O to be promoted to registers";
550-
551-
expectNoABCPromotion(code);
552-
553-
auto o00Pos = code.find("_O_0[0][0]");
554-
auto o10Pos = code.find("_O_0[1][0]");
555-
auto o20Pos = code.find("_O_0[2][0]");
556-
auto o30Pos = code.find("_O_0[3][0]");
557-
558-
EXPECT_TRUE(o00Pos != std::string::npos)
559-
<< "expected constant subscripts in _O_0";
560-
EXPECT_TRUE(o10Pos != std::string::npos)
561-
<< "expected constant subscripts in _O_0";
562-
EXPECT_TRUE(o20Pos != std::string::npos)
563-
<< "expected constant subscripts in _O_0";
564-
EXPECT_TRUE(o30Pos != std::string::npos)
565-
<< "expected constant subscripts in _O_0";
569+
expectFourOElementsPromoted(code);
566570
}
567571

568572
TEST_F(MatMulBias, RegistersAtRootNotEnoughUnroll) {
@@ -589,23 +593,24 @@ TEST_F(MatMulBias, RegistersBelowFirstBand) {
589593
using namespace polyhedral::detail;
590594

591595
// Disable automatic promotion to registers because we are going to call it
592-
// manually.
596+
// manually. Use a large unroll size to unroll all loops below the first
597+
// band and actually hit registers.
593598
auto mappingOptions = CudaMappingOptions::makeNaiveMappingOptions()
599+
.unroll(512)
594600
.useSharedMemory(false)
595601
.usePrivateMemory(false);
596602
auto mscop = prepare({{"N", 42}, {"M", 56}, {"K", 37}}, mappingOptions);
597603

598-
auto nodes = ScheduleTree::collectDFSPostorder(
604+
auto nodes = ScheduleTree::collectDFSPreorder(
599605
mscop->scop().scheduleRoot(), ScheduleTreeType::Band);
600606
ASSERT_GT(nodes.size(), 0u);
601607
auto node = nodes[0];
602608
promoteToRegistersBelow(*mscop, node);
603609
auto code = emitCode(mscop);
604610

605-
auto oDeclPos = code.find("float32 _O_0[1][1];");
606-
EXPECT_TRUE(oDeclPos != std::string::npos)
607-
<< "expected O to be promoted to registers";
608-
expectNoABCPromotion(code);
611+
// Expecting 4 elements because we map the loop i in O[i][j] to 8 threads
612+
// after tiling by 32.
613+
expectFourOElementsPromoted(code);
609614
}
610615

611616
class Strided : public TestMapper {

0 commit comments

Comments
 (0)