@@ -695,8 +695,26 @@ def layernorm(float(T, B, C) I) -> (O, mean, centered, var) {
695
695
// (#200). It calls code generation on a schedule tree containing a
696
696
// disjunctive filter, which results in an expression with more than
697
697
// one disjunct that was not handled properly.
698
- // TODO: the disjunctive filter in the schedule is unexpected and its origin
699
- // should be identified and explained.
698
+ //
699
+ // A disjunctive filter is introduced by the full/partial tile separation
700
+ // of the reduction detection. The domain is of the form
701
+ // { S_1[O_s1_b, O_s1_y, O_s1_r_x] :
702
+ // 0 <= O_s1_b <= 127 and 0 <= O_s1_y <= 999 and 0 <= O_s1_r_x <= 1023 }.
703
+ // An outer tiling with size (1, 32, 63) is performed, which means
704
+ // in particular that O_s1_r_x is tiled in blocks of size 63.
705
+ // Within these blocks a full thread mapping tile is of size 2,
706
+ // which means that the final element of the even blocks and
707
+ // the initial element of the odd blocks do not belong to a full tile.
708
+ // In particular, O_s1_r_x = 62 (final element of block 0) and
709
+ // O_s1_r_x = 63 (initial element of block 1) do not belong to a full tile.
710
+ // The constraints
711
+ // -61 + O_s1_r_x <= 126*floor((63 + O_s1_r_x)/126) <= 62 + O_s1_r_x
712
+ // perform this filtering.
713
+ // In the O_s1_y direction, tiles are of size 32, resulting
714
+ // in the constraint O_s1_y <= 991 on full tiles.
715
+ // The partial tiles are described by the complement
716
+ // O_s1_y >= 992 or (63 + O_s1_r_x) mod 126 = 0 or (-62 + O_s1_r_x) mod 126 = 0
717
+ // which is disjunctive.
700
718
TEST (TMM_128_1024_1000, DisjunctiveFilter) {
701
719
at::Tensor I = at::CUDA (at::kFloat ).rand ({128 , 1024 });
702
720
at::Tensor W = at::CUDA (at::kFloat ).rand ({1000 , 1024 });
0 commit comments