
[QST] How does make_tiled_copy_A determine the source address for copying? #2232


Closed
ezioliao opened this issue Apr 10, 2025 · 3 comments

@ezioliao

Take the following code as an example:

  // s2r_copy_atom construction
  using s2r_copy_op = SM75_U32x4_LDSM_N; 
  using s2r_copy_traits = Copy_Traits<s2r_copy_op>;
  using s2r_copy_atom = Copy_Atom<s2r_copy_traits, T>; // T = cute::half_t
  using S2RCopyAtomA = s2r_copy_atom;

  // tiled_mma construction
  using mma_op = SM80_16x8x16_F16F16F16F16_TN;
  using mma_traits = MMA_Traits<mma_op>;
  using mma_atom = MMA_Atom<mma_traits>;
  using MMA_EU_RepeatT = decltype(make_layout(make_shape(
      Int<1>{}, Int<1>{}, Int<1>{})));
  using MMA_P_T = Tile<Int<16>, Int<8>, Int<16>>;
  using TiledMMA = decltype(make_tiled_mma(mma_atom{}, MMA_EU_RepeatT{}, MMA_P_T{}));

  // tAsA construction for each thread
  TiledMMA tiled_mma;
  auto s2r_tiled_copy_a = make_tiled_copy_A(S2RCopyAtomA{}, tiled_mma);
  auto s2r_thr_copy_a = s2r_tiled_copy_a.get_slice(idx);
  auto tAsA = s2r_thr_copy_a.partition_S(sA);  // ? (CPY, CPY_M, CPY_K, kStage)
  auto tCrA_view = s2r_thr_copy_a.retile_D(tCrA);  // ? (CPY, CPY_M, CPY_K)

In this implementation, SM80_16x8x16_F16F16F16F16_TN is the MMA operation used to build tiled_mma, and SM75_U32x4_LDSM_N is the copy operation. tiled_mma is then used to construct s2r_tiled_copy_a.

As we know, in the copy for matrix A, each thread supplies source addresses so that the warp as a whole copies a 16x16 tile. For SM80_16x8x16_F16F16F16F16_TN and SM75_U32x4_LDSM_N, the source-address layout of the copy operation is as follows:

[Image: source-address layout of the A copy]

I'm confused about how make_tiled_copy_A assigns source addresses to each thread. Specifically, what determines the thread-to-address mapping pattern shown in the upper diagram versus the alternative pattern below?

[Image: alternative thread-to-address pattern]

From what I understand, neither MMA_Atom nor Copy_Atom seems to provide the relevant information:

MMA_Atom
  ThrID:      _32:_1
  Shape_MNK:  (_16,_8,_16)
  LayoutA_TV: ((_4,_8),(_2,_2,_2)):((_32,_1),(_16,_8,_128))
  LayoutB_TV: ((_4,_8),(_2,_2)):((_16,_1),(_8,_64))
  LayoutC_TV: ((_4,_8),(_2,_2)):((_32,_1),(_16,_8))

Copy_Atom
  ThrID:        _32:_1
  ValLayoutSrc: (_32,_8):(_8,_1)
  ValLayoutDst: (_32,(_2,_4)):(_2,(_1,_64))
  ValLayoutRef: (_32,(_2,_4)):(_2,(_1,_64))
  ValueType:    16b
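To see what the LayoutA_TV printed above encodes, here is a minimal hand evaluation in plain C++. It uses no CuTe, and the function names are hypothetical. It reads the codomain as a column-major index m + 16*k into the 16x16 (M,K) A tile. Lane 1, for example, lands on row 0 at columns 2 and 3, which agrees with the PTX m16n8k16 A-fragment description (row from lane / 4, column pair from lane % 4). The layout only tells us which coordinates each thread owns. It says nothing about addresses.

```cpp
#include <cassert>
#include <utility>

// Hand evaluation of MMA_Atom LayoutA_TV:
//   ((_4,_8),(_2,_2,_2)):((_32,_1),(_16,_8,_128))
// Maps (thread, value) to a column-major index m + 16*k into the 16x16 A tile.
constexpr int layoutA_tv(int thr, int val) {
  int t0 = thr % 4, t1 = thr / 4;                      // thread mode (4,8):(32,1)
  int v0 = val % 2, v1 = (val / 2) % 2, v2 = val / 4;  // value mode (2,2,2):(16,8,128)
  return t0 * 32 + t1 * 1 + v0 * 16 + v1 * 8 + v2 * 128;
}

// Converts the column-major tile index back to an (m, k) coordinate.
constexpr std::pair<int, int> to_mk(int idx) { return {idx % 16, idx / 16}; }
```

Iterating `val` from 0 to 7 for a fixed `thr` lists the eight A elements that one thread must hold in registers before the `mma` instruction issues.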
@ccecka

ccecka commented Apr 10, 2025

The MMA specifies the thread-to-coordinate mapping and the data layout specifies the coordinate-to-address mapping.

In your example, the MMA is

print_latex(tiled_mma);

[Image: print_latex(tiled_mma) output]
and the CPY is

print_latex(s2r_tiled_copy_a);

[Image: print_latex(s2r_tiled_copy_a) output]
Note how the RHS (the destination) of the CPY is the same Thr-Val layout as the A-matrix in the MMA. That is, this CPY will produce exactly the registers that the MMA expects for its A-matrix. The LHS (the source) of the CPY is transformed according to the intrinsic pattern of SM75_U32x4_LDSM_N. This is the pattern that will be used to read from SMEM.
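For intuition about that intrinsic pattern, here is a hypothetical plain-C++ model of `ldmatrix.x4`. Per the PTX description, lanes 8i to 8i+7 supply the row addresses of the i-th 8x8 matrix, and each row is 8 consecutive 16-bit values. Which sub-tile of A each matrix covers is our assumption here: we chose the arrangement that matches the m16n8k16 A-fragment register order. The names are illustrative, not CuTe API.

```cpp
#include <cassert>

// Hypothetical model of the lane -> row-segment pattern of ldmatrix.x4
// (SM75_U32x4_LDSM_N) for a 16x16 A tile of 16-bit values.
// Assumed matrix placement (chosen to match the m16n8k16 A fragment):
//   matrix 0 = rows 0-7,  k 0-7    (lanes 0-7)
//   matrix 1 = rows 8-15, k 0-7    (lanes 8-15)
//   matrix 2 = rows 0-7,  k 8-15   (lanes 16-23)
//   matrix 3 = rows 8-15, k 8-15   (lanes 24-31)
struct RowSegment {
  int m;   // row of the A tile this lane points at
  int k0;  // first of the 8 consecutive K columns in the segment
};

constexpr RowSegment ldsm_x4_lane_segment(int lane) {
  return {lane % 16, (lane / 16) * 8};
}

// Lane 24 starts matrix 3: row 8, columns 8-15.
static_assert(ldsm_x4_lane_segment(24).m == 8 && ldsm_x4_lane_segment(24).k0 == 8,
              "lane 24 -> (8, 8)");
```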

There are no addresses yet. These are partitioning patterns specifying the mapping of thr-vals to coordinates. When these patterns are applied to tensors (which are coordinate-to-address mappings), we get the addresses:

  auto s2r_thr_copy_a = s2r_tiled_copy_a.get_slice(idx);
  Tensor tAsA = s2r_thr_copy_a.partition_S(sA);  // (CPY, CPY_M, CPY_K, kStage)
  Tensor tCrA_view = s2r_thr_copy_a.retile_D(tCrA);  // (CPY, CPY_M, CPY_K)
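The split between "thr-val to coordinate" and "coordinate to address" can be shown with two plain functions and their composition. This is a sketch under a stated assumption: the hypothetical `sA` holds a single 16x16 half_t tile stored K-major with no swizzle, i.e. layout (16,16):(16,1). For the thr-val pattern we reuse the MMA's LayoutA_TV. The real `sA` in a kernel will typically be swizzled and staged, and `partition_S` handles that for you.

```cpp
#include <cassert>

// Thr-val pattern: MMA LayoutA_TV ((4,8),(2,2,2)):((32,1),(16,8,128)),
// returning a column-major index m + 16*k into the 16x16 tile (no addresses).
constexpr int tv_to_coord(int thr, int val) {
  return (thr % 4) * 32 + (thr / 4) + (val % 2) * 16 + ((val / 2) % 2) * 8 +
         (val / 4) * 128;
}

// Hypothetical data layout of sA: (16,16):(16,1), i.e. K-major, no swizzle.
constexpr int sA_offset(int m, int k) { return m * 16 + k; }

// Composition: (thr, val) -> coordinate -> element offset in sA.
constexpr int tv_to_offset(int thr, int val) {
  int idx = tv_to_coord(thr, val);
  return sA_offset(idx % 16, idx / 16);
}
```

Swapping `sA_offset` for an M-major layout (16,16):(1,16) changes every address while `tv_to_coord` stays untouched, which is the separation described in the comment above.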

@ezioliao
Author

ezioliao commented Apr 12, 2025

@ccecka Thanks for your reply! I just have one remaining question about this:

The LHS (the source) of the CPY is transformed according to the intrinsic pattern of SM75_U32x4_LDSM_N

The Thr-Val layout of the LHS of this CPY seems to be ((16, 2),8):((1, 128),16) (if A is M-major). I can't deduce this layout from the information that SM75_U32x4_LDSM_N provides.

In other words, without using print_latex, how would I know what the src Thr-Val layout should look like? I believe this layout is derived from the information provided by Copy_Atom and MMA_Atom using layout algebra, but I don't know the specific derivation process.
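The quoted source layout can at least be checked by evaluating it directly. The following is a minimal plain-C++ sketch with hypothetical helper names. As above, it reads the codomain as the column-major index m + 16*k into the 16x16 A tile. Under that reading, thread t points at row t % 16 and owns the 8 consecutive K values starting at 8 * (t / 16). That is one 8-element row segment per lane, and the 32x8 thr-vals cover the tile exactly once.

```cpp
#include <array>
#include <cassert>

// Evaluates the source Thr-Val layout ((16,2),8):((1,128),16),
// returning a column-major index m + 16*k into the 16x16 A tile.
constexpr int src_tv(int thr, int val) {
  return (thr % 16) * 1 + (thr / 16) * 128 + val * 16;
}

// True if 32 threads x 8 values hit each of the 256 coordinates exactly once.
bool src_tv_is_bijective() {
  std::array<int, 256> hits{};
  for (int t = 0; t < 32; ++t)
    for (int v = 0; v < 8; ++v) ++hits[src_tv(t, v)];
  for (int h : hits)
    if (h != 1) return false;
  return true;
}
```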

MMA_Atom
  ThrID:      _32:_1
  Shape_MNK:  (_16,_8,_16)
  LayoutA_TV: ((_4,_8),(_2,_2,_2)):((_32,_1),(_16,_8,_128))
  LayoutB_TV: ((_4,_8),(_2,_2)):((_16,_1),(_8,_64))
  LayoutC_TV: ((_4,_8),(_2,_2)):((_32,_1),(_16,_8))

Copy_Atom
  ThrID:        _32:_1
  ValLayoutSrc: (_32,_8):(_8,_1)
  ValLayoutDst: (_32,(_2,_4)):(_2,(_1,_64))
  ValLayoutRef: (_32,(_2,_4)):(_2,(_1,_64))
  ValueType:    16b

ezioliao reopened this on Apr 12, 2025
@ccecka

ccecka commented Apr 27, 2025

The complete partitioning patterns can't be derived from the Atoms, only from the TiledMMAs and TiledCopys (this is why we call print_latex on the Tiled objects instead of the Atoms). I don't think it's useful to get into the representation that is used internally, but we used to track and propagate these patterns by hand in Excel, and now we use CuTe Layouts to do so.
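For this particular example, the source pattern can be reproduced by hand. The TiledMMA is a single atom with Tile<16,8,16>, so we assume its A Thr-Val layout coincides with the atom's LayoutA_TV. The sketch below follows our reading of how CuTe's TiledCopy builds its source partitioning (roughly `right_inverse(AtomLayoutRef)` composed with `AtomLayoutSrc`, then mapped through the tiled Thr-Val layout). It works in three steps:

1. Map each source (thr, val) to an element id of the atom.
2. Find the reference (thr, val) that holds that same element.
3. Send that pair through the tiled A layout.

The layouts are hand-coded from the printouts in 16-bit units, and all names are hypothetical. This is an illustration, not CuTe's implementation.

```cpp
#include <cassert>

struct ThrVal { int thr; int val; };

// Copy_Atom ValLayoutSrc (32,8):(8,1): src (thr, val) -> element id.
constexpr int atom_src(int t, int v) { return 8 * t + v; }

// Inverse of Copy_Atom ValLayoutRef (32,(2,4)):(2,(1,64)):
// element id -> ref (thr, val), with the value mode flattened as v0 + 2*v1.
constexpr ThrVal atom_ref_inverse(int e) {
  return {(e % 64) / 2, (e % 2) + 2 * (e / 64)};
}

// Tiled A layout, assumed equal to MMA LayoutA_TV here:
// ((4,8),(2,2,2)):((32,1),(16,8,128)) -> m + 16*k.
constexpr int mma_a_tv(int t, int v) {
  return (t % 4) * 32 + t / 4 + (v % 2) * 16 + ((v / 2) % 2) * 8 + (v / 4) * 128;
}

// Source pattern: src (thr, val) -> element id -> ref (thr, val) -> (m, k) index.
constexpr int derived_src_tv(int t, int v) {
  return mma_a_tv(atom_ref_inverse(atom_src(t, v)).thr,
                  atom_ref_inverse(atom_src(t, v)).val);
}

// Compares with the layout read off print_latex: ((16,2),8):((1,128),16).
bool matches_printed_src_layout() {
  for (int t = 0; t < 32; ++t)
    for (int v = 0; v < 8; ++v)
      if (derived_src_tv(t, v) != (t % 16) + (t / 16) * 128 + v * 16) return false;
  return true;
}
```

With a larger TiledMMA (more atoms or a permuted tile), `mma_a_tv` would have to be replaced by the tiled layout, which is why the Tiled objects rather than the Atoms are the right thing to inspect.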

ezioliao closed this as completed on May 4, 2025