-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Closed
Labels
Description
What is your question?
A toy example (I am a newbie, so some of the atom choices may be naive):
// Toy setup: build a TiledMMA, partition placeholder shared-memory tensors for
// A and B into per-thread fragments, build matching tiled copies, and print the
// resulting layouts for thread 0.
using ELM = cutlass::half_t;
// Static block-tile extents (M, N, K) = (128, 128, 16).
using bM = decltype(Int<128>{});
using bN = decltype(Int<128>{});
using bK = decltype(Int<16>{});
// TiledMMA from the SM70 8x8x4 F32F16F16F32 NT atom with a 2x2x2 atom layout
// and a 32x32x16 tile argument.
TiledMMA tmma =
    make_tiled_mma(SM70_8x8x4_F32F16F16F32_NT{}, Layout<Shape<_2, _2, _2>>{},
                   Tile<_32, _32, _16>{});
auto thr_mma = tmma.get_thread_slice(0);
// The tensors wrap a null placeholder pointer; this snippet only inspects layouts.
auto sA = make_tensor(make_smem_ptr((ELM *)(0)), Layout<Shape<bM, bK>>{}); // Let's assume A is somehow copied to this sA
auto sB = make_tensor(make_smem_ptr((ELM *)(0)), Layout<Shape<bN, bK>>{}); // Let's assume the same as well
// Per-thread MMA operand fragments and the accumulator fragment.
Tensor tSrA = thr_mma.partition_fragment_A(sA);
Tensor tSrB = thr_mma.partition_fragment_B(sB);
Tensor acc = partition_fragment_C(tmma, Shape<bM, bN>{});
// One copy atom is reused to build the tiled copies for both operands.
auto cp_atom = Copy_Atom<SM75_U32x4_LDSM_N, ELM>{};
auto smem_tiled_cp_A = make_tiled_copy_A(cp_atom, tmma);
auto smem_thr_cp_A = smem_tiled_cp_A.get_thread_slice(0);
Tensor tSsA = smem_thr_cp_A.partition_S(sA);
auto smem_tiled_cp_B = make_tiled_copy_B(cp_atom, tmma);
auto smem_thr_cp_B = smem_tiled_cp_B.get_thread_slice(0);
Tensor tSsB = smem_thr_cp_B.partition_S(sB);
// Re-tile each MMA fragment as the destination view of its own operand's copy.
Tensor tSrA_copy_view = smem_thr_cp_A.retile_D(tSrA);
// Fixed: the B fragment is retiled with B's thread copy (was smem_thr_cp_A).
Tensor tSrB_copy_view = smem_thr_cp_B.retile_D(tSrB);
// Print, in order: the MMA fragment layout of A, the copy-source partition of
// sA, and the retiled copy-destination view of the A fragment.
printf("\n");
cute::print(layout<>(tSrA));
printf("\n");
cute::print(layout<>(tSsA));
printf("\n");
cute::print(layout<>(tSrA_copy_view));
printf("\n");
Stdout then gives me this:
(_4,_8,_2):(_1,_4,_32)
(((_2,_4),_2),_4,_1):(((_1,_128),_1024),_32,_0)
((_8,_2),_4,_1):((_1,_32),_8,_0)
And then many examples will launch a pipeline iterating the K-mode like this:
// Prologue: copy the k = 0 slice of A and B into the register copy views.
cute::copy(smem_tiled_cp_A, tSsA(_, _, _0{}), tSrA_copy_view(_, _, _0{}));
cute::copy(smem_tiled_cp_B, tSsB(_, _, _0{}), tSrB_copy_view(_, _, _0{}));
// NOTE(review): the loop bound is the K-mode extent of the MMA fragment tSrA,
// but `i + 1` indexes the K-mode of the copy views (tSsA / tSrA_copy_view).
// If those extents differ, the prefetch index can exceed the copy views'
// K-mode -- confirm the extents match before relying on this pattern.
for (int i = 0; i < size<2>(tSrA); ++i) {
  if (i < size<2>(tSrA) - 1) { // prefetch
    // Fixed: use the tiled copies declared as smem_tiled_cp_A / smem_tiled_cp_B
    // (the original referenced undeclared smem_tiled_copy_A / smem_tiled_copy_B).
    cute::copy(smem_tiled_cp_A, tSsA(_, _, i + 1), tSrA_copy_view(_, _, i + 1));
    cute::copy(smem_tiled_cp_B, tSsB(_, _, i + 1), tSrB_copy_view(_, _, i + 1));
  }
  // Consume the current k-slice while the next one is being fetched.
  cute::gemm(tmma, tSrA(_, _, i), tSrB(_, _, i), acc);
}
The question is: the K-mode of tSsA and tSrA_copy_view has size 1, but that of tSrA has size 2. It seems a single copy from smem to registers is sufficient for 2 gemms in this case, so won't tSsA(_, _, i + 1) and tSrA_copy_view(_, _, i + 1) go out of bounds when i == 0?
I hope someone can guide me through this, thanks!