
[QST] confusion about various shapes encountered #2534

@rehsuuuuuuuu

Description

What is your question?
Hi, sorry if this is a newbie question. While exploring CUTLASS and building my own project, I have encountered several different shapes that I can't quite reconcile. I will try to list them out with my own understanding. Any help would be appreciated.

ThreadBlockShape/TileShape ---> These two should be equivalent; both are passed to the CollectiveMainloop/CollectiveEpilogue builders and refer to the tile size per CTA (i.e., how the global memory is tiled, I suppose). In my case I use 128x128x128.
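For context, this is roughly how I picture the TileShape being used; a minimal host-side CuTe sketch, where the problem sizes and block coordinate are made up for illustration:

```cpp
#include <vector>
#include <cute/tensor.hpp>

using namespace cute;

int main() {
  // Made-up problem size; TileShape is the 128x128x128 per-CTA tile I pass to the builders.
  int M = 512, K = 1024;
  using TileShape = Shape<_128, _128, _128>;          // (BLK_M, BLK_N, BLK_K)

  std::vector<float> A(size_t(M) * K);
  // Global A viewed as an (M, K) row-major tensor.
  Tensor mA = make_tensor(make_gmem_ptr(A.data()),
                          make_layout(make_shape(M, K), make_stride(K, _1{})));

  // The (BLK_M, BLK_K) slice of A that the CTA at block coordinate (1, 2) would work on.
  Tensor gA = local_tile(mA, select<0,2>(TileShape{}), make_coord(1, 2));

  print(gA.layout()); print("\n");                    // 128x128 tile with A's global strides
}
```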

MMA atom shape ---> For example, 16x8x64 in cute::MMA_Atom<cute::SM120::BLOCKSCALED::SM120_16x8x64_TN_VS<...>>. This one is also fairly straightforward: it is the size of a single tensor core instruction.
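This is how I query an atom's shape; I'm using a generic SM80 HMMA atom as a stand-in here only because I'm sure of its headers, but the same queries apply to the SM120 block-scaled atom:

```cpp
#include <cute/atom/mma_atom.hpp>
#include <cute/atom/mma_traits_sm80.hpp>

using namespace cute;

int main() {
  // Stand-in atom: one 16x8x16 tensor-core instruction executed by 32 threads.
  using Atom = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>;

  print(typename Atom::Shape_MNK{});  print("\n");  // (_16,_8,_16) per instruction
  print(typename Atom::ThrID{});      print("\n");  // 32 threads (one warp) per atom
}
```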

The confusion comes in when I look at the PTX documentation, https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tcgen05-matrix-shape: the shape supported for, say, .kind::mxf4nvf4 with a dense MMA and a 1-CTA group is 128xNx64, where N is a multiple of 8 up to 256.

On the other hand, when I print out the TiledMMA from the kernel I am using, I get:

cute::TiledMMA<cute::MMA_Atom<cute::SM120::BLOCKSCALED::SM120_16x8x64_TN_VS<cutlass::float_e2m1_t,cutlass::float_e2m1_t, float, cutlass::float_ue4m3_t, 16> >
cute::Layout<cute::tuple<cute::C<4>, cute::C<2>, cute::C<1> >, cute::tuple<cute::C<1>, cute::C<4>, cute::C<0> > >, 
cute::tuple<cute::C<128>, cute::Layout<cute::tuple<cute::C<8>, cute::C<2>, cute::C<2> >, cute::tuple<cute::C<1>, cute::C<16>, cute::C<8> > >, cute::C<64> >

From here, the shape of the TiledMMA should be 128x32x64 (with a permutation on the N dimension).
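For what it's worth, here is the same TiledMMA written out with Shape/Stride aliases, plus how I read the 128x32x64 tile off it. I'm omitting the include for the SM120 block-scaled atom since its location may vary between CUTLASS versions:

```cpp
#include <cute/tensor.hpp>
#include <cutlass/numeric_types.h>
// ...plus whatever header in your CUTLASS tree defines the SM120 block-scaled atoms.

using namespace cute;

// The atom from the printout: 16x8x64, fp4 (e2m1) inputs, fp32 accumulator, ue4m3 scale factors.
using Atom = MMA_Atom<SM120::BLOCKSCALED::SM120_16x8x64_TN_VS<
    cutlass::float_e2m1_t, cutlass::float_e2m1_t, float, cutlass::float_ue4m3_t, 16>>;

// 4 x 2 x 1 atoms in (M, N, K): 8 warps = 256 threads per CTA.
using AtomLayoutMNK = Layout<Shape<_4, _2, _1>>;

// Permutation / tiler per mode, as printed:
//   M: _128                                       -> tile M = 128 (the 64-row atom grid iterates twice)
//   N: Layout<Shape<_8,_2,_2>, Stride<_1,_16,_8>> -> size 8*2*2 = 32, interleaved ordering
//   K: _64                                        -> tile K = 64
using PermMNK = Tile<_128,
                     Layout<Shape<_8, _2, _2>, Stride<_1, _16, _8>>,
                     _64>;

int main() {
  auto tiled_mma = make_tiled_mma(Atom{}, AtomLayoutMNK{}, PermMNK{});
  print(tiled_mma); print("\n");   // should match the printout above
}
```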

How are these shapes related?

Also some other questions:

  1. In the AtomLayout (4,2,1) above, we are adding 7 additional warps to expand the atom shape. Does each warp come from a different CTA, or are they all within the same CTA?
  2. Before I pass my SFA to the kernel (from PyTorch), I have to permute the tensor in memory to accommodate the interleaved layout. What I do is reshape (M, K) ---> (M//128, 4, 32, K//4, 4) and then permute the dimensions (0,1,2,3,4) -> (0,3,2,1,4); see the sketch after this list. It works, but it seems to contradict the M/N > 128 case described in https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tcgen05-mma-scale-factor-b-layout-4x, which lays the (32x16) blocks out in column-major order, whereas my permutation lays them out in row-major order. Yet the results produced with my permutation appear to be correct.
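For concreteness, here is a plain C++ restatement of that host-side permutation as I understand my own PyTorch code (uint8 stands in for the scale-factor storage type; the extents are assumed to be multiples of 128 and 4):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// (Msf, Ksf) row-major --> view as (Msf/128, 4, 32, Ksf/4, 4) --> permute dims
// (0,1,2,3,4) -> (0,3,2,1,4) --> make contiguous. Equivalent to the torch
// reshape + permute + contiguous described above.
std::vector<uint8_t> permute_sfa(const std::vector<uint8_t>& sfa, int Msf, int Ksf) {
  int M0 = Msf / 128, K0 = Ksf / 4;
  std::vector<uint8_t> out(sfa.size());
  size_t idx = 0;
  for (int m0 = 0; m0 < M0; ++m0)             // output dim 0: Msf/128
    for (int k0 = 0; k0 < K0; ++k0)           // output dim 1: Ksf/4   (was dim 3)
      for (int m2 = 0; m2 < 32; ++m2)         // output dim 2: 32      (was dim 2)
        for (int m1 = 0; m1 < 4; ++m1)        // output dim 3: 4       (was dim 1)
          for (int k1 = 0; k1 < 4; ++k1) {    // output dim 4: 4       (was dim 4)
            int m = (m0 * 4 + m1) * 32 + m2;  // row in the original (Msf, Ksf)
            int k = k0 * 4 + k1;              // column in the original
            out[idx++] = sfa[size_t(m) * Ksf + k];
          }
  return out;
}

int main() {
  int Msf = 256, Ksf = 8;                     // toy extents
  std::vector<uint8_t> sfa(size_t(Msf) * Ksf);
  for (size_t i = 0; i < sfa.size(); ++i) sfa[i] = uint8_t(i);
  auto out = permute_sfa(sfa, Msf, Ksf);
  std::printf("out[0..3] = %d %d %d %d\n", out[0], out[1], out[2], out[3]);
}
```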
