[BUG] examples/cute/tutorial/blackwell/04_mma_tma_2sm_sm100.cu GridDim miscalculated #2493

@JJXiangJiaoJun

Bug Report

Describe the bug
The kernel launch configuration in examples/cute/tutorial/blackwell/04_mma_tma_2sm_sm100.cu computes GridDim incorrectly.

04_mma_tma_2sm_sm100.cu#L576

Steps/Code to reproduce bug
./04_mma_tma_2sm_sm100 1024 1024 256

For example:

  • By default this example uses Gemm_M = 512, Gemm_N = 1024, Gemm_K = 256. The 2-SM UMMA MMATileShape is 256x256x16 and the cluster shape is 4x4x1.
  • Because we use the 2-SM UMMA instruction, the CTATileShape is 128x256x16, i.e. CTATileShapeM == MMATileShapeM / 2. Before any rounding up to the ClusterShape, we need to launch with GridDim = (ceil_div(Gemm_M, CTATileShapeM), ceil_div(Gemm_N, CTATileShapeN), 1) == (4, 4, 1).
  • However, the example uses the MMATileShape to calculate GridDim: GridDim = (ceil_div(Gemm_M, MMATileShapeM), ceil_div(Gemm_N, MMATileShapeN), 1) == (2, 4, 1). For this problem size, the example then rounds GridDim up to the ClusterShape, which gives GridDim = (round_up(2, 4), round_up(4, 4)) == (4, 4), so we get the right GridDim by accident.
  • For problems where (Gemm_M / CTATileShapeM) > 4, GridDim is wrong. For example, with Gemm_M = 1024, Gemm_N = 1024, Gemm_K = 256 we should launch with GridDim (8, 4, 1), but the example launches with (4, 4, 1).
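The arithmetic above can be reproduced with a small host-only sketch. It is not the code from the example or from the PR: `ceil_div`, `round_up`, `Grid`, `grid_from_mma_tile`, and `grid_from_cta_tile` are hypothetical helpers written for this illustration, and the sketch assumes the fixed version still rounds GridDim up to the ClusterShape.

```cpp
// Hypothetical stand-ins for illustration only (CuTe provides its own utilities).
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }
constexpr int round_up(int a, int b) { return ceil_div(a, b) * b; }

struct Grid { int x, y, z; };

// Calculation described in the report: tiles are counted with the MMA tile,
// then rounded up to the cluster shape.
constexpr Grid grid_from_mma_tile(int gemm_m, int gemm_n,
                                  int mma_tile_m, int mma_tile_n,
                                  int cluster_m, int cluster_n) {
  return Grid{round_up(ceil_div(gemm_m, mma_tile_m), cluster_m),
              round_up(ceil_div(gemm_n, mma_tile_n), cluster_n), 1};
}

// Expected calculation: tiles are counted with the per-CTA tile, where
// CTATileShapeM == MMATileShapeM / 2 for the 2-SM instruction.
// Assumption: rounding up to the cluster shape is kept.
constexpr Grid grid_from_cta_tile(int gemm_m, int gemm_n,
                                  int mma_tile_m, int mma_tile_n,
                                  int cluster_m, int cluster_n) {
  return Grid{round_up(ceil_div(gemm_m, mma_tile_m / 2), cluster_m),
              round_up(ceil_div(gemm_n, mma_tile_n), cluster_n), 1};
}

// Default problem (512 x 1024): both calculations agree, which hides the bug.
static_assert(grid_from_mma_tile(512, 1024, 256, 256, 4, 4).x == 4, "MMA-tile grid, default M");
static_assert(grid_from_cta_tile(512, 1024, 256, 256, 4, 4).x == 4, "CTA-tile grid, default M");

// Gemm_M = 1024: the MMA-tile version launches 4 CTAs along M, but 8 are needed.
static_assert(grid_from_mma_tile(1024, 1024, 256, 256, 4, 4).x == 4, "MMA-tile grid, M = 1024");
static_assert(grid_from_cta_tile(1024, 1024, 256, 256, 4, 4).x == 8, "CTA-tile grid, M = 1024");
```

The checks are `static_assert`s, so compiling the file is enough to confirm the (4, 4, 1) versus (8, 4, 1) numbers quoted above.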

I created a PR to fix this problem: fix: examples/cute/tutorial/blackwell/04_mma_tma_2sm_sm100.cu GridDim miscalculated
