@@ -485,6 +485,28 @@ def fun(float(N,K) A, float(K,M) B, float(N,M) C) -> (O) {
485
485
EXPECT_TRUE (cDeclPos == std::string::npos)
486
486
<< " tensor C promoted to register but has no reuse" ;
487
487
}
488
+
489
+ void expectFourOElementsPromoted (const std::string& code) {
490
+ auto oDeclPos = code.find (" float32 _O_0[4][1];" );
491
+ EXPECT_TRUE (oDeclPos != std::string::npos)
492
+ << " expected O to be promoted to registers" ;
493
+
494
+ expectNoABCPromotion (code);
495
+
496
+ auto o00Pos = code.find (" _O_0[0][0]" );
497
+ auto o10Pos = code.find (" _O_0[1][0]" );
498
+ auto o20Pos = code.find (" _O_0[2][0]" );
499
+ auto o30Pos = code.find (" _O_0[3][0]" );
500
+
501
+ EXPECT_TRUE (o00Pos != std::string::npos)
502
+ << " expected constant subscripts in _O_0" ;
503
+ EXPECT_TRUE (o10Pos != std::string::npos)
504
+ << " expected constant subscripts in _O_0" ;
505
+ EXPECT_TRUE (o20Pos != std::string::npos)
506
+ << " expected constant subscripts in _O_0" ;
507
+ EXPECT_TRUE (o30Pos != std::string::npos)
508
+ << " expected constant subscripts in _O_0" ;
509
+ }
488
510
};
489
511
490
512
TEST_F (MatMulBias, RegisterPromotion) {
@@ -544,25 +566,7 @@ TEST_F(MatMulBias, RegistersAtRoot) {
544
566
545
567
// Expecting 4 elements because we map the loop i in O[i][j] to 8 threads
546
568
// after tiling by 32.
547
- auto oDeclPos = code.find (" float32 _O_0[4][1];" );
548
- EXPECT_TRUE (oDeclPos != std::string::npos)
549
- << " expected O to be promoted to registers" ;
550
-
551
- expectNoABCPromotion (code);
552
-
553
- auto o00Pos = code.find (" _O_0[0][0]" );
554
- auto o10Pos = code.find (" _O_0[1][0]" );
555
- auto o20Pos = code.find (" _O_0[2][0]" );
556
- auto o30Pos = code.find (" _O_0[3][0]" );
557
-
558
- EXPECT_TRUE (o00Pos != std::string::npos)
559
- << " expected constant subscripts in _O_0" ;
560
- EXPECT_TRUE (o10Pos != std::string::npos)
561
- << " expected constant subscripts in _O_0" ;
562
- EXPECT_TRUE (o20Pos != std::string::npos)
563
- << " expected constant subscripts in _O_0" ;
564
- EXPECT_TRUE (o30Pos != std::string::npos)
565
- << " expected constant subscripts in _O_0" ;
569
+ expectFourOElementsPromoted (code);
566
570
}
567
571
568
572
TEST_F (MatMulBias, RegistersAtRootNotEnoughUnroll) {
@@ -589,23 +593,24 @@ TEST_F(MatMulBias, RegistersBelowFirstBand) {
589
593
using namespace polyhedral ::detail;
590
594
591
595
// Disable automatic promotion to registers because we are going to call it
592
- // manually.
596
+ // manually. Use a large unroll size to unroll all loops below the first
597
+ // band and actually hit registers.
593
598
auto mappingOptions = CudaMappingOptions::makeNaiveMappingOptions ()
599
+ .unroll (512 )
594
600
.useSharedMemory (false )
595
601
.usePrivateMemory (false );
596
602
auto mscop = prepare ({{" N" , 42 }, {" M" , 56 }, {" K" , 37 }}, mappingOptions);
597
603
598
- auto nodes = ScheduleTree::collectDFSPostorder (
604
+ auto nodes = ScheduleTree::collectDFSPreorder (
599
605
mscop->scop ().scheduleRoot (), ScheduleTreeType::Band);
600
606
ASSERT_GT (nodes.size (), 0u );
601
607
auto node = nodes[0 ];
602
608
promoteToRegistersBelow (*mscop, node);
603
609
auto code = emitCode (mscop);
604
610
605
- auto oDeclPos = code.find (" float32 _O_0[1][1];" );
606
- EXPECT_TRUE (oDeclPos != std::string::npos)
607
- << " expected O to be promoted to registers" ;
608
- expectNoABCPromotion (code);
611
+ // Expecting 4 elements because we map the loop i in O[i][j] to 8 threads
612
+ // after tiling by 32.
613
+ expectFourOElementsPromoted (code);
609
614
}
610
615
611
616
class Strided : public TestMapper {
0 commit comments