@@ -753,6 +753,39 @@ def perforatedConvolution(float(N, C, H, W) input, float(M, C, KH, KW) weights,
753
753
Prepare (tc);
754
754
}
755
755
756
+ TEST_F (PolyhedralMapperTest, ModulusConstantRHS) {
757
+ string tc = R"TC(
758
+ def fun(float(N) a) -> (b) { b(i) = a(i % 2) where i in 0:N }
759
+ )TC" ;
760
+ // This triggers tc2halide conversion and should not throw.
761
+ auto scop = Prepare (tc);
762
+ for (auto r : scop->reads .range ().get_set_list ()) {
763
+ // skip irrelevant reads, if any
764
+ if (r.get_tuple_name () != std::string (" a" )) {
765
+ continue ;
766
+ }
767
+ // std::cout << "Got stride: " << r.get_stride() << std::endl;
768
+ // EXPECT_EQ(r.get_stride(), 2);
769
+ }
770
+ }
771
+
772
+ TEST_F (PolyhedralMapperTest, ModulusVariableRHS) {
773
+ string tc = R"TC(
774
+ def local_sparse_convolution(float(N, C, H, W) I, float(O, KC, KH, KW) W1) -> (O1) {
775
+ O1(n, o, h, w) +=! I(n, kc % c, h + kh, w + kw) * W1(o, kc, kh, kw) where c in 1:C
776
+ }
777
+ )TC" ;
778
+ // This triggers tc2halide conversion and should not throw.
779
+ auto scop = Prepare (tc);
780
+ for (auto r : scop->reads .range ().get_set_list ()) {
781
+ // skip irrelevant reads, if any
782
+ if (r.get_tuple_name () != std::string (" I" )) {
783
+ continue ;
784
+ }
785
+ EXPECT_TRUE (r.plain_is_universe ());
786
+ }
787
+ }
788
+
756
789
int main (int argc, char ** argv) {
757
790
::testing::InitGoogleTest (&argc, argv);
758
791
::gflags::ParseCommandLineFlags (&argc, &argv, true );
0 commit comments