
[QST] Why it won't OOB in tiled_copy pipeline #2018

@ZhZhang711

What is your question?
A toy example (I am a newbie, so some of the atom choices may be "brainless"):

  using ELM = cutlass::half_t;
  using bM = decltype(Int<128>{});
  using bN = decltype(Int<128>{});
  using bK = decltype(Int<16>{});
  TiledMMA tmma =
      make_tiled_mma(SM70_8x8x4_F32F16F16F32_NT{}, Layout<Shape<_2, _2, _2>>{},
                     Tile<_32, _32, _16>{});
  auto thr_mma = tmma.get_thread_slice(0);

  auto sA = make_tensor(make_smem_ptr((ELM *)(0)), Layout<Shape<bM, bK>>{});  // Let's assume A is somehow copied to this sA
  auto sB = make_tensor(make_smem_ptr((ELM *)(0)), Layout<Shape<bN, bK>>{});  // Let's assume the same as well

  Tensor tSrA = thr_mma.partition_fragment_A(sA);
  Tensor tSrB = thr_mma.partition_fragment_B(sB);
  Tensor acc = partition_fragment_C(tmma, Shape<bM, bN>{});

  auto cp_atom = Copy_Atom<SM75_U32x4_LDSM_N, ELM>{};
  auto smem_tiled_cp_A = make_tiled_copy_A(cp_atom, tmma);
  auto smem_thr_cp_A = smem_tiled_cp_A.get_thread_slice(0);
  Tensor tSsA = smem_thr_cp_A.partition_S(sA);
  auto smem_tiled_cp_B = make_tiled_copy_B(cp_atom, tmma);
  auto smem_thr_cp_B = smem_tiled_cp_B.get_thread_slice(0);
  Tensor tSsB = smem_thr_cp_B.partition_S(sB);

  Tensor tSrA_copy_view = smem_thr_cp_A.retile_D(tSrA);
  Tensor tSrB_copy_view = smem_thr_cp_B.retile_D(tSrB);  // retile B with the B copy slice

  printf("\n");
  cute::print(layout<>(tSrA));
  printf("\n");
  cute::print(layout<>(tSsA));
  printf("\n");
  cute::print(layout<>(tSrA_copy_view));
  printf("\n");

This prints the layouts of tSrA, tSsA, and tSrA_copy_view, in that order:

(_4,_8,_2):(_1,_4,_32)
(((_2,_4),_2),_4,_1):(((_1,_128),_1024),_32,_0)
((_8,_2),_4,_1):((_1,_32),_8,_0)
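
To see how the fragment and its copy view relate, the two register layouts above can be enumerated by hand. The following is a minimal sketch in plain C++ (not CuTe code): `Mode`, `layout_size`, `offset_of`, and `covers_exactly` are hypothetical helpers, the shapes and strides are copied from the printout with nested modes flattened, and the offset is assumed to be the inner product of coordinate and stride.

```cpp
#include <cstddef>

// One flattened mode of a layout: extent and stride, copied from the printout.
struct Mode { int shape; int stride; };

// Number of coordinates in a layout: the product of all mode extents.
template <std::size_t N>
constexpr int layout_size(const Mode (&modes)[N]) {
  int n = 1;
  for (std::size_t m = 0; m < N; ++m) n *= modes[m].shape;
  return n;
}

// Offset of the idx-th coordinate (first mode varies fastest), computed as
// the inner product of the coordinate with the strides.
template <std::size_t N>
constexpr int offset_of(const Mode (&modes)[N], int idx) {
  int off = 0;
  for (std::size_t m = 0; m < N; ++m) {
    off += (idx % modes[m].shape) * modes[m].stride;
    idx /= modes[m].shape;
  }
  return off;
}

// True if the layout produces every offset in [0, Count) exactly once.
template <int Count, std::size_t N>
constexpr bool covers_exactly(const Mode (&modes)[N]) {
  if (layout_size(modes) != Count) return false;
  bool seen[Count] = {};
  for (int i = 0; i < Count; ++i) {
    const int off = offset_of(modes, i);
    if (off < 0 || off >= Count || seen[off]) return false;
    seen[off] = true;
  }
  return true;
}

// tSrA:           (_4,_8,_2):(_1,_4,_32)
constexpr Mode kFragA[] = {{4, 1}, {8, 4}, {2, 32}};
// tSrA_copy_view: ((_8,_2),_4,_1):((_1,_32),_8,_0), nested modes flattened
constexpr Mode kCopyViewA[] = {{8, 1}, {2, 32}, {4, 8}, {1, 0}};

// Both layouts address the same 64 register elements, each exactly once.
static_assert(covers_exactly<64>(kFragA), "tSrA covers offsets 0..63");
static_assert(covers_exactly<64>(kCopyViewA), "copy view covers offsets 0..63");
```

Under these assumptions both layouts address the same 64 elements. In the copy view, the stride-32 step that is the K-mode of tSrA shows up as the `_2` inside the first mode, so the copy view's last mode has extent 1.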

Many examples then run a pipeline that iterates over the K-mode like this:

  cute::copy(smem_tiled_cp_A, tSsA(_, _, _0{}), tSrA_copy_view(_, _, _0{}));
  cute::copy(smem_tiled_cp_B, tSsB(_, _, _0{}), tSrB_copy_view(_, _, _0{}));

  for (int i = 0; i < size<2>(tSrA); ++i) {
    if (i < size<2>(tSrA) - 1) {  // prefetch
      cute::copy(smem_tiled_cp_A, tSsA(_, _, i + 1), tSrA_copy_view(_, _, i + 1));
      cute::copy(smem_tiled_cp_B, tSsB(_, _, i + 1), tSrB_copy_view(_, _, i + 1));
    }
    cute::gemm(tmma, tSrA(_, _, i), tSrB(_, _, i), acc);
  }

My question: the K-mode (mode 2) of tSsA and tSrA_copy_view has extent 1, but that of tSrA has extent 2. A single copy from shared memory to registers seems to be enough for both gemm calls in this case. So won't tSsA(_, _, i + 1) and tSrA_copy_view(_, _, i + 1) go out of bounds when i == 0?
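
The indices involved can be replayed in plain C++. This is a model of the loop, not CuTe code: `Access` and `replay_copy_indices` are hypothetical names, the extents and the K stride are read off the printed layouts, and the sketch assumes that a K index `k` contributes `k * stride` to the offset. It makes no claim about what CuTe itself checks.

```cpp
#include <vector>

// One requested K index on the copy side.
struct Access {
  int k;            // K index passed to tSsA(_, _, k) / tSrA_copy_view(_, _, k)
  bool in_range;    // whether k is below the copy-side K extent
  int offset_term;  // k * stride, the assumed contribution to the offset
};

constexpr int kMmaK = 2;         // size<2>(tSrA), from (_4,_8,_2)
constexpr int kCopyK = 1;        // size<2>(tSsA) and size<2>(tSrA_copy_view)
constexpr int kCopyKStride = 0;  // K stride of both copy-side layouts (_0)

// Replays only the copy calls of the pipeline and records each K index.
std::vector<Access> replay_copy_indices() {
  std::vector<Access> log;
  // Prologue copy with K index 0.
  log.push_back({0, 0 < kCopyK, 0 * kCopyKStride});
  for (int i = 0; i < kMmaK; ++i) {
    if (i < kMmaK - 1) {  // prefetch guard uses the MMA-side K extent
      const int k = i + 1;
      log.push_back({k, k < kCopyK, k * kCopyKStride});
    }
    // cute::gemm on K block i would run here.
  }
  return log;
}
```

In this model the loop issues two copies, with K indices 0 and 1. Index 1 is outside the copy-side extent of 1, and because the K stride is 0 its offset term is also 0, so under the stated assumption it would address the same elements as the prologue copy.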
I hope someone can guide me through this. Thanks!
