This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 7595c66

Author: Sven Verdoolaege
TMM_128_1024_1000.DisjunctiveFilter: explain disjunctive filter
Closes #503
1 parent c53955c commit 7595c66

1 file changed: 20 additions, 2 deletions

test/cuda/test_tc_mapper_bugs.cc

Lines changed: 20 additions & 2 deletions
@@ -695,8 +695,26 @@ def layernorm(float(T, B, C) I) -> (O, mean, centered, var) {
 // (#200). It calls code generation on a schedule tree containing a
 // disjunctive filter, which results in an expression with more than
 // one disjunct that was not handled properly.
-// TODO: the disjunctive filter in the schedule is unexpected and its origin
-// should be identified and explained.
+//
+// A disjunctive filter is introduced by the full/partial tile separation
+// of the reduction detection. The domain is of the form
+// { S_1[O_s1_b, O_s1_y, O_s1_r_x] :
+// 0 <= O_s1_b <= 127 and 0 <= O_s1_y <= 999 and 0 <= O_s1_r_x <= 1023 }.
+// An outer tiling with size (1, 32, 63) is performed, which means
+// in particular that O_s1_r_x is tiled in blocks of size 63.
+// Within these blocks a full thread mapping tile is of size 2,
+// which means that the final element of the even blocks and
+// the initial element of the odd blocks do not belong to a full tile.
+// In particular, O_s1_r_x = 62 (final element of block 0) and
+// O_s1_r_x = 63 (initial element of block 1) do not belong to a full tile.
+// The constraints
+// -61 + O_s1_r_x <= 126*floor((63 + O_s1_r_x)/126) <= 62 + O_s1_r_x
+// perform this filtering.
+// In the O_s1_y direction, tiles are of size 32, resulting
+// in the constraint O_s1_y <= 991 on full tiles.
+// The partial tiles are described by the complement
+// O_s1_y >= 992 or (63 + O_s1_r_x) mod 126 = 0 or (-62 + O_s1_r_x) mod 126 = 0
+// which is disjunctive.
 TEST(TMM_128_1024_1000, DisjunctiveFilter) {
   at::Tensor I = at::CUDA(at::kFloat).rand({128, 1024});
   at::Tensor W = at::CUDA(at::kFloat).rand({1000, 1024});
