@@ -61,6 +61,13 @@ struct PolyhedralMapperTest : public ::testing::Test {
61
61
joinBandsIterative (scop->scheduleRoot ()->child ({0 }), true );
62
62
return scop;
63
63
}
64
+ std::unique_ptr<Scop> PrepareAndJoinBandsMatMul () {
65
+ auto scop = Prepare (makeMatmulTc ());
66
+ scop = Scop::makeScheduled (*scop, SchedulerOptions ().view );
67
+ auto root = scop->scheduleRoot ();
68
+ bandSplit (root, root->child ({0 }), 2 );
69
+ return scop;
70
+ }
64
71
65
72
std::unique_ptr<MappedScop> makeUnmapped (std::string tc) {
66
73
return MappedScop::makeOneBlockOneThread (Prepare (tc));
@@ -511,7 +518,7 @@ constexpr auto kExpectedMatmul_64_64_64 =
511
518
)CUDA" ;
512
519
513
520
TEST_F (PolyhedralMapperTest, MergedContexts) {
514
- auto scop = PrepareAndJoinBands ( makeMatmulTc () );
521
+ auto scop = PrepareAndJoinBandsMatMul ( );
515
522
516
523
// Unit test claims to use scop->globalParameterContext properly
517
524
auto context = scop->makeContext <int >({{" M" , 64 }, {" N" , 64 }, {" K" , 64 }});
@@ -526,16 +533,16 @@ TEST_F(PolyhedralMapperTest, MergedContexts) {
526
533
}
527
534
528
535
TEST_F (PolyhedralMapperTest, Match1) {
529
- auto scop = PrepareAndJoinBands ( makeMatmulTc () );
536
+ auto scop = PrepareAndJoinBandsMatMul ( );
530
537
auto schedule = scop->scheduleRoot ();
531
538
532
539
auto mscop = TileAndMapThreads (std::move (scop), {16 , 16 }, {32ul , 8ul });
533
540
auto f = match (
534
- sequence (
541
+ band ( sequence (
535
542
filter ([](isl::union_set f) {
536
543
return f.get_space ().dim (isl::dim_type::param) == 3 ;
537
544
}),
538
- filter (band ( ))),
545
+ filter ())),
539
546
schedule);
540
547
EXPECT_EQ (1u , f.size ());
541
548
}
@@ -553,37 +560,31 @@ def fun(float(M, N) I) -> (O) {
553
560
}
554
561
555
562
TEST_F (PolyhedralMapperTest, MatmulTC) {
556
- string tc = R"TC(
557
- def fun(float(M, K) A, float(K, N) B) -> (C) {
558
- C(m, n) +=! A(m, r_k) * B(r_k, n)
559
- }
560
- )TC" ;
561
-
562
- auto scop = PrepareAndJoinBands (tc);
563
+ auto scop = PrepareAndJoinBandsMatMul ();
563
564
auto tileOptions = TileOptions::ShiftPointLoops | TileOptions::ScaleTileLoops;
564
565
TileAndCheckStructuralEquality (*scop, tileOptions, {3ul , 4ul });
565
566
}
566
567
567
568
TEST_F (PolyhedralMapperTest, MatmulShiftScale) {
568
- auto scop = PrepareAndJoinBands ( makeMatmulTc () );
569
+ auto scop = PrepareAndJoinBandsMatMul ( );
569
570
auto tileOptions = TileOptions::ShiftPointLoops | TileOptions::ScaleTileLoops;
570
571
TileAndCheckStructuralEquality (*scop, tileOptions, {3ul , 4ul });
571
572
}
572
573
573
574
TEST_F (PolyhedralMapperTest, MatmulShift) {
574
- auto scop = PrepareAndJoinBands ( makeMatmulTc () );
575
+ auto scop = PrepareAndJoinBandsMatMul ( );
575
576
auto tileOptions = TileOptions::ShiftPointLoops;
576
577
TileAndCheckStructuralEquality (*scop, tileOptions, {3ul , 4ul });
577
578
}
578
579
579
580
TEST_F (PolyhedralMapperTest, MatmulScale) {
580
- auto scop = PrepareAndJoinBands ( makeMatmulTc () );
581
+ auto scop = PrepareAndJoinBandsMatMul ( );
581
582
auto tileOptions = TileOptions::ScaleTileLoops;
582
583
TileAndCheckStructuralEquality (*scop, tileOptions, {3ul , 4ul });
583
584
}
584
585
585
586
TEST_F (PolyhedralMapperTest, MatmulNoshiftNoscale) {
586
- auto scop = PrepareAndJoinBands ( makeMatmulTc () );
587
+ auto scop = PrepareAndJoinBandsMatMul ( );
587
588
auto tileOptions = TileOptions ();
588
589
TileAndCheckStructuralEquality (*scop, tileOptions, {3ul , 4ul });
589
590
}
0 commit comments