Description
What is your question?
Hi, sorry if this is a newbie question. While exploring CUTLASS and building my own project, I ran into several different shapes that I can't reconcile correctly. I will list them below along with my own understanding. Any help would be appreciated.
ThreadBlockShape/TileShape ---> These two should be equivalent, since both are used in the CollectiveMainloop/CollectiveEpilogue builders and refer to the size per CTA (i.e. how you tile global memory, I suppose). In my case, I used 128x128x128.
MMA atom shape ---> e.g. 16x8x64 in cute::MMA_Atom<cute::SM120::BLOCKSCALED::SM120_16x8x64. This one is also fairly straightforward: it is the size of a single tensor core operation.
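To make the relationship between the CTA tile and the atom concrete, here is a small pure-Python sketch (not CUTLASS code) that counts how many 16x8x64 atom instances are needed to cover one 128x128x128 CTA tile. It assumes every tile extent is an exact multiple of the corresponding atom extent; `atoms_per_tile` is a hypothetical helper.

```python
# Sketch: how many MMA atom instances cover one CTA tile, assuming every
# tile extent is an exact multiple of the atom extent in that mode.
from math import prod

CTA_TILE = (128, 128, 128)   # (M, N, K) TileShape used in the question
MMA_ATOM = (16, 8, 64)       # SM120_16x8x64 atom shape (M, N, K)

def atoms_per_tile(tile, atom):
    """Return the per-mode repeat counts of `atom` needed to cover `tile`."""
    for t, a in zip(tile, atom):
        if t % a != 0:
            raise ValueError(f"tile extent {t} is not a multiple of atom extent {a}")
    return tuple(t // a for t, a in zip(tile, atom))

repeats = atoms_per_tile(CTA_TILE, MMA_ATOM)
print(repeats)        # (8, 16, 2)
print(prod(repeats))  # 256 atom instances per CTA tile
```

How those 256 instances are split across warps and iterations is what the TiledMMA (atom layout plus permutation) decides.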
The confusion comes from the PTX docs: according to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tcgen05-matrix-shape, the shape supported for, say, .kind::mxf4nvf4 with dense / 1 CTA group is 128xNx64, where N is a multiple of 8 up to 256.
On the other hand, when I print out the TiledMMA info from the kernel I am using, I get:
```
cute::TiledMMA<cute::MMA_Atom<cute::SM120::BLOCKSCALED::SM120_16x8x64_TN_VS<cutlass::float_e2m1_t,cutlass::float_e2m1_t, float, cutlass::float_ue4m3_t, 16> >
cute::Layout<cute::tuple<cute::C<4>, cute::C<2>, cute::C<1> >, cute::tuple<cute::C<1>, cute::C<4>, cute::C<0> > >,
cute::tuple<cute::C<128>, cute::Layout<cute::tuple<cute::C<8>, cute::C<2>, cute::C<2> >, cute::tuple<cute::C<1>, cute::C<16>, cute::C<8> > >, cute::C<64> >
```
From this, the shape of the TiledMMA should be 128x32x64 (with a permutation on the N dimension).
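The arithmetic behind the 128x32x64 figure can be sketched in plain Python. This assumes my reading of CuTe's `TiledMMA::tile_size_mnk`, where a mode's tile extent is the larger of (atom extent × AtomLayout extent) and the size of the matching Permutation entry, and it assumes each SM120 atom is executed by one 32-thread warp. `layout_eval` is a hypothetical helper that evaluates a CuTe-style layout at a 1-D index (colexicographic split, then inner product with the strides).

```python
# Sketch: shape bookkeeping for the printed TiledMMA. Assumes a mode's tile
# extent is max(atom extent * AtomLayout extent, Permutation size).
from math import prod

ATOM_MNK = (16, 8, 64)        # SM120_16x8x64 atom (M, N, K)
ATOM_LAYOUT = (4, 2, 1)       # AtomLayout shape from the printout
THREADS_PER_ATOM = 32         # assumption: one warp per SM120 atom
CTA_TILE = (128, 128, 128)    # TileShape from the question

# Permutation tuple from the printout: M -> C<128>, N -> (8,2,2):(1,16,8), K -> C<64>
PERM_N_SHAPE, PERM_N_STRIDE = (8, 2, 2), (1, 16, 8)
PERM_SIZES = (128, prod(PERM_N_SHAPE), 64)

def layout_eval(i, shape, stride):
    """Evaluate a CuTe-style layout at 1-D index i: split i colexicographically
    (leftmost mode varies fastest), then take the inner product with the strides."""
    offset = 0
    for s, d in zip(shape, stride):
        offset += (i % s) * d
        i //= s
    return offset

core = tuple(a * l for a, l in zip(ATOM_MNK, ATOM_LAYOUT))         # (64, 16, 64)
tile_mnk = tuple(max(c, p) for c, p in zip(core, PERM_SIZES))      # (128, 32, 64)
threads = prod(ATOM_LAYOUT) * THREADS_PER_ATOM                     # 8 warps -> 256
tiles_per_cta = tuple(t // m for t, m in zip(CTA_TILE, tile_mnk))  # (1, 4, 2)

# Offsets produced by the N-mode permutation layout for indices 0..31.
perm_n = [layout_eval(i, PERM_N_SHAPE, PERM_N_STRIDE) for i in range(PERM_SIZES[1])]

print(core, tile_mnk, threads, tiles_per_cta)
print(perm_n)
```

Under these assumptions the permutation layout maps indices 0..7 to 0..7, 8..15 to 16..23, 16..23 to 8..15 and 24..31 to 24..31, so it only reorders N within a 32-wide tile; the 128x128x128 CTA tile is then covered by 1x4x2 repetitions of that 128x32x64 TiledMMA tile.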
How are these shapes related?
Also some other questions:
- In the AtomLayout (4,2,1) above, we are adding 7 additional warps to expand the atom shape. Does each warp come from a different CTA, or are they all within the same CTA?
- Before I pass my SFA to the kernel (from PyTorch), I have to permute the tensor's memory layout to match the interleaved layout. What I did is (M; K) ---> (M//128, 4, 32; K//4, 4) ---> permute (0,1,2,3,4) -> (0,3,2,1,4). This works, but it seems to contradict the case M/N > 128 described in https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tcgen05-mma-scale-factor-b-layout-4x, which is column-major over (32x16) blocks, whereas my permutation is row-major over (32x16) blocks. Still, the results I get with my permutation seem to be correct.
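The SFA reordering in the last bullet can be spelled out index by index. This is a pure-Python sketch, not the PyTorch code actually used; it assumes a row-major (M, K_sf) scale-factor tensor (K_sf = number of scale-factor columns) with M divisible by 128 and K_sf divisible by 4, and it mirrors `sf.reshape(M // 128, 4, 32, K_sf // 4, 4).permute(0, 3, 2, 1, 4).contiguous()`. `permuted_offset` and `reorder` are hypothetical helpers.

```python
# Sketch of the SFA reshape/permute described above, with the index math explicit.
# Assumes a row-major (M, K_sf) tensor with M % 128 == 0 and K_sf % 4 == 0.

def permuted_offset(m, k, M, K_sf):
    """Linear offset of SFA[m, k] after reshape(M//128, 4, 32, K_sf//4, 4),
    permute(0, 3, 2, 1, 4) and contiguous()."""
    m_blk, m4, m32 = m // 128, (m % 128) // 32, m % 32
    k_blk, k4 = k // 4, k % 4
    # New dim order (m_blk, k_blk, m32, m4, k4), sizes (M//128, K_sf//4, 32, 4, 4).
    return (((m_blk * (K_sf // 4) + k_blk) * 32 + m32) * 4 + m4) * 4 + k4

def reorder(flat, M, K_sf):
    """Emulate reshape + permute + contiguous on a flat row-major list."""
    out = [None] * (M * K_sf)
    for m in range(M):
        for k in range(K_sf):
            out[permuted_offset(m, k, M, K_sf)] = flat[m * K_sf + k]
    return out

M, K_SF = 256, 8                 # small example sizes (hypothetical)
src = list(range(M * K_SF))      # src[m * K_SF + k] is the id of SFA[m, k]
dst = reorder(src, M, K_SF)
# First 16 outputs: rows m = 0, 32, 64, 96, each with scale-factor columns 0..3.
print(dst[:16])  # [0, 1, 2, 3, 256, 257, 258, 259, 512, 513, 514, 515, 768, 769, 770, 771]
```

Viewed this way, each 128x4 chunk of SFA is written as 32 contiguous 16-element rows (row index m % 32, column ((m % 128) // 32) * 4 + k % 4), i.e. the row-major 32x16 arrangement described in the bullet.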