-
Notifications
You must be signed in to change notification settings - Fork 16
Open
Description
这一部分有关三个 Layout 的代码我一直没有看明白,在这篇知乎中看到一样的内容后,我发现这里的代码实现可能有 bug。
按照定义:
- LayoutTile 是每个 Block 有 (LayoutTile::m, LayoutTile::n) 个 float
- LayoutBlock 是每个 Block 有 (LayoutBlock::m, LayoutBlock::n) 个 thread
- LayoutThread 是每个 thread 中,每个 submatrix 有 (LayoutThread::m, LayoutThread::n) 个 float, 因为用的 float4,可以理解为 4*4。
那么此处 gemm_use_tile.cu 第10行和第11行 中对于m 和 n 的定义就有问题了,应该如下:
unsigned m= threadIdx.x* LayoutTile::m/LayoutBlock::m+ LayoutTile::m* blockIdx.x;
unsigned n= threadIdx.y* LayoutTile::n/LayoutBlock::n+ LayoutTile::n* blockIdx.y
同样的, gemm_use_tile.cu 第19行和第20行 中,iterationA 和 iterationB 应该分别指的是每个 thread 有多少个 (4,4) 的 subMatrix,这里应该是 2*2 = 4 个,那么 gemm_use_tile.cu 第21行和第22行 intervalA 和 intervalB 的定义就有问题了,按照后续代码,intervalA 和 intervalB 指的分别应该是每个 subMatrix 有多大,也就是 (LayoutThread::m, LayoutThread::n)
Metadata
Metadata
Assignees
Labels
No labels