[Fix] fix the flash_attn_v1 when Br!=Bc #68

Merged: 3 commits, Apr 12, 2025
43 changes: 29 additions & 14 deletions docs/17_flash_attn/02_flash_attn_v1_part2.md
@@ -32,6 +32,17 @@ printf("Max shared memory: %d, requested shared memory: %d \n", max_sram_size, sram_size);

If the requested sram_size exceeds max_sram_size, the kernel launch will fail. In that case we need to tune the Bc, d, and Br parameters to find a balance point that gives the algorithm the memory it needs without exceeding the hardware limit.
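
For illustration, here is a minimal sketch of how that check could look before the launch (this is not the repository's exact code; it assumes `sram_size` has already been computed from Bc, Br, and d, and it queries the per-block limit with `cudaDeviceGetAttribute`):

```cpp
// Hypothetical pre-launch guard; sram_size is assumed to be the dynamic
// shared-memory request derived from Bc, Br and d.
int max_sram_size = 0;
cudaDeviceGetAttribute(&max_sram_size, cudaDevAttrMaxSharedMemoryPerBlock, 0);
if (sram_size > max_sram_size)
{
    printf("Requested shared memory %d exceeds the per-block limit %d, adjust Bc/Br/d.\n",
           sram_size, max_sram_size);
    return -1; // assuming this runs inside main(): bail out instead of launching
}
```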

To keep things simple, the code hard-codes Bc and Br to fixed values. Note that Br and Bc are allowed to differ, and it always holds that $Br \leq Bc$.

```cpp
const int Bc = 32;
const int Br = 16;
```

As for why $Br \leq Bc$ always holds, we can go back to the Flash Attention V1 paper, which chooses the block sizes as $Bc = \lceil \frac{M}{4d} \rceil$ and $Br = \min(\lceil \frac{M}{4d} \rceil, d)$. Here $M$ is the maximum shared memory available to each SM on the device, $d$ is the dimension of each vector, and $4d$ represents the combined footprint of the Q, K, V, and S sub-blocks in shared memory. When $\lceil \frac{M}{4d} \rceil > d$, we get $Br = d < \lceil \frac{M}{4d} \rceil = Bc$; when $\lceil \frac{M}{4d} \rceil \leq d$, we get $Br = \lceil \frac{M}{4d} \rceil = Bc$. So $Br \leq Bc$ in every case.
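
To make the two cases concrete, here is a small illustrative snippet (not from this PR) that evaluates the paper's formulas for an assumed shared-memory budget M and head dimension d:

```cpp
// Illustrative only: the paper's block-size choice, written out in code.
// M is the shared-memory budget in float elements, d the head dimension
// (both are example values here, not queried from the device).
const int M  = 48 * 1024 / sizeof(float);  // e.g. a 48 KB shared-memory budget
const int d  = 64;                         // example head dimension
const int Bc = (M + 4 * d - 1) / (4 * d);  // Bc = ceil(M / (4d))
const int Br = (Bc < d) ? Bc : d;          // Br = min(ceil(M / (4d)), d)
```

With these example numbers, $\lceil \frac{M}{4d} \rceil = 48 \leq d = 64$, so $Br = Bc = 48$; with a smaller head dimension such as $d = 32$, $\lceil \frac{M}{4d} \rceil = 96 > 32$, so $Br = 32 < Bc = 96$.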

Because of this property, even when Br and Bc are not equal, a simple if statement is enough to handle loading the Q sub-block. That said, it is best to set Bc and Br equal when possible, since this improves GPU thread utilization.

Next, we need to set the CUDA kernel's launch dimensions (that is, gridDim and blockDim):

@@ -90,17 +101,18 @@ int lm_offset = (bx * gridDim.y * N) + (by * N);
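
The launch-configuration code itself is collapsed in this diff, so the following is only a rough sketch of what a typical Flash Attention V1 launch looks like; the names `B` (batch size), `nh` (number of heads), and `sram_size` are assumptions, and the kernel's argument list is elided:

```cpp
// Illustrative launch shape only, not the PR's exact code.
dim3 grid_dim(B, nh);  // one thread block per (batch, head) pair
dim3 block_dim(Bc);    // Bc threads per block: one thread per K/V row of a tile
flash_attn_v1_kernel<<<grid_dim, block_dim, sram_size>>>(/* Q, K, V, ... as in the kernel signature */);
```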

```cpp
extern __shared__ float sram[];
int tile_size = Bc * d; // size of Qi, Kj, Vj
float* Qi = sram;
float* Kj = &sram[tile_size];
float* Vj = &sram[tile_size * 2];
float* S = &sram[tile_size * 3];
const int KV_TILE_SIZE = Bc * d; // size of Kj, Vj
const int Q_TILE_SIZE = Br * d; // size of Qi
float *Qi = sram;
float *Kj = &sram[Q_TILE_SIZE];
float *Vj = &sram[Q_TILE_SIZE + KV_TILE_SIZE];
float *S = &sram[Q_TILE_SIZE + KV_TILE_SIZE * 2];
```

Let's break down the role of each of these regions one by one:

- **Qi region**: stores the data for the current tile loaded from the Q tensor. When the kernel later computes $QK^T$, threads use the Qi held in shared memory for their vector dot products. Qi has size `tile_size`, i.e. Bc vectors of dimension d each. Bc generally corresponds to the number of threads in the block, and each thread handles one vector or part of a vector.
- **Kj region**: Kj stores a portion of the Key tensor loaded from global memory. The outer loop splits the full data into blocks and loads one block of key data into shared memory at a time. Like Qi, it also has size tile_size (Bc * d values).
- **Qi region**: stores the data for the current tile loaded from the Q tensor. When the kernel later computes $QK^T$, threads use the Qi held in shared memory for their vector dot products. Qi has size `Q_TILE_SIZE`, i.e. Br vectors of dimension d each. Br generally corresponds to the number of threads that process Q rows, and each of those threads handles one vector or part of a vector.
- **Kj region**: Kj stores a portion of the Key tensor loaded from global memory. The outer loop splits the full data into blocks and loads one block of key data into shared memory at a time. Its size is `KV_TILE_SIZE` (Bc * d values).
- **Vj region**: Vj is similar to Kj, except that it holds a portion of the Value tensor. During the computation it is weighted by the post-softmax attention scores and accumulated for each thread's row, which ultimately produces the output.
- **S region**: S is dedicated to the intermediate result, namely the score matrix $S$ obtained from the $QK^T$ product. Before the softmax step, each thread writes the dot-product results for every element of its own output row into S. (The sizing sketch right after this list adds these four regions up.)
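
Adding the four regions up gives the dynamic shared-memory request mentioned earlier. A minimal sketch of the host-side sizing is shown below (the exact expression in the repository may differ):

```cpp
// Qi needs Br * d floats, Kj and Vj need Bc * d floats each,
// and S needs one Bc-wide score row per Q row, i.e. Br * Bc floats.
const int sram_size = (Br * d + 2 * Bc * d + Br * Bc) * sizeof(float);
```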

@@ -115,12 +127,12 @@

```cpp
// The full K and V tensors are split into Tc tiles
// Each tile has size tile_size (defined as Bc * d)
// Each KV tile has size KV_TILE_SIZE (defined as Bc * d)
for (int j = 0; j < Tc; j++) {
// Load Kj, Vj from HBM to SRAM
for (int x = 0; x < d; x++) {
Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];
Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];
Kj[(tx * d) + x] = K[qkv_offset + (KV_TILE_SIZE * j) + (tx * d) + x];
Vj[(tx * d) + x] = V[qkv_offset + (KV_TILE_SIZE * j) + (tx * d) + x];
}
__syncthreads();
```
@@ -137,7 +149,10 @@ for (int j = 0; j < Tc; j++) {

```cpp
for (int i = 0; i < Tr; i++) {
... // inner body
// this is what handles the case where Br and Bc are not equal
if (tx < Br){
... // inner body
}
}
```

@@ -147,7 +162,7 @@

```cpp
for (int x = 0; x < d; x++) {
Qi[(tx * d) + x] = Q[qkv_offset + (tile_size * i) + (tx * d) + x];
Qi[(tx * d) + x] = Q[qkv_offset + (Q_TILE_SIZE * i) + (tx * d) + x];
}
```

@@ -250,8 +265,8 @@ for (int j = 0; j < Tc; j++) {

```cpp
for (int y = 0; y < Bc; y++) {
pv += S[(Bc * tx) + y] * Vj[(y * d) + x];
}
O[qkv_offset + (tile_size * i) + (tx * d) + x] = (1 / row_l_new) \
* ((row_l_prev * __expf(row_m_prev - row_m_new) * O[qkv_offset + (tile_size * i) + (tx * d) + x]) \
O[qkv_offset + (Q_TILE_SIZE * i) + (tx * d) + x] = (1 / row_l_new) \
* ((row_l_prev * __expf(row_m_prev - row_m_new) * O[qkv_offset + (Q_TILE_SIZE * i) + (tx * d) + x]) \
+ (__expf(row_m - row_m_new) * pv));
}
```
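
Written out as math, this last assignment is the standard online-softmax rescaling from the Flash Attention V1 paper. For the output row $O_i$ handled by a thread:

$$
O_i \leftarrow \frac{1}{\ell_i^{new}} \left( \ell_i^{prev} \, e^{m_i^{prev} - m_i^{new}} \, O_i + e^{\tilde m_{ij} - m_i^{new}} \, \tilde P_{ij} V_j \right)
$$

where $m_i^{new} = \max(m_i^{prev}, \tilde m_{ij})$ and $\ell_i^{new} = e^{m_i^{prev} - m_i^{new}} \ell_i^{prev} + e^{\tilde m_{ij} - m_i^{new}} \tilde \ell_{ij}$. In the code, `row_m` and `row_l` are $\tilde m_{ij}$ and $\tilde \ell_{ij}$, `pv` accumulates one element of $\tilde P_{ij} V_j$, and `row_m_new` / `row_l_new` are the merged statistics that are also written back to `m` and `l`.
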
96 changes: 51 additions & 45 deletions docs/17_flash_attn/flash_attn_v1.cu
@@ -60,73 +60,78 @@ __global__ void flash_attn_v1_kernel(const float *Q,

// Define SRAM for Q,K,V,S
extern __shared__ float sram[];
int tile_size = Bc * d; // size of Qi, Kj, Vj
const int KV_TILE_SIZE = Bc * d; // size of Kj, Vj
const int Q_TILE_SIZE = Br * d; // size of Qi
// const int S_TILE_SIZE = Br * Bc; // size of Sij = softmax(Qi * Kj^T * softmax_scale)
float *Qi = sram;
float *Kj = &sram[tile_size];
float *Vj = &sram[tile_size * 2];
float *S = &sram[tile_size * 3];
float *Kj = &sram[Q_TILE_SIZE];
float *Vj = &sram[Q_TILE_SIZE + KV_TILE_SIZE];
float *S = &sram[Q_TILE_SIZE + KV_TILE_SIZE * 2];

// outer loop
for (int j = 0; j < Tc; j++)
{
// Load Kj, Vj from HBM to SRAM
for (int x = 0; x < d; x++)
{
Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];
Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];
Kj[(tx * d) + x] = K[qkv_offset + (KV_TILE_SIZE * j) + (tx * d) + x];
Vj[(tx * d) + x] = V[qkv_offset + (KV_TILE_SIZE * j) + (tx * d) + x];
}
__syncthreads();

for (int i = 0; i < Tr; i++)
{
// Load Qi to SRAM, l and m to registers
for (int x = 0; x < d; x++)
if (tx < Br)
{
Qi[(tx * d) + x] = Q[qkv_offset + (tile_size * i) + (tx * d) + x];
}
float row_m_prev = m[lm_offset + (Br * i) + tx];
float row_l_prev = l[lm_offset + (Br * i) + tx];

// S = QK^T, row_m = rowmax(S)
float row_m = -INFINITY;
for (int y = 0; y < Bc; y++)
{
float sum = 0;
// Load Qi to SRAM, l and m to registers
for (int x = 0; x < d; x++)
{
sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
Qi[(tx * d) + x] = Q[qkv_offset + (Q_TILE_SIZE * i) + (tx * d) + x];
}
sum *= softmax_scale;
S[(Bc * tx) + y] = sum;

if (sum > row_m)
row_m = sum;
}
float row_m_prev = m[lm_offset + (Br * i) + tx];
float row_l_prev = l[lm_offset + (Br * i) + tx];

// P = exp(S - row_m), row_l = rowsum(P)
float row_l = 0;
for (int y = 0; y < Bc; y++)
{
S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - row_m);
row_l += S[(Bc * tx) + y];
}
// S = QK^T, row_m = rowmax(S)
float row_m = -INFINITY;
for (int y = 0; y < Bc; y++)
{
float sum = 0;
for (int x = 0; x < d; x++)
{
sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
}
sum *= softmax_scale;
S[(Bc * tx) + y] = sum;

// Compute new m and l
float row_m_new = max(row_m_prev, row_m);
float row_l_new = (__expf(row_m_prev - row_m_new) * row_l_prev) + (__expf(row_m - row_m_new) * row_l);
if (sum > row_m)
row_m = sum;
}

// Write O, l, m to HBM
for (int x = 0; x < d; x++)
{
float pv = 0; // Pij * Vj
// P = exp(S - row_m), row_l = rowsum(P)
float row_l = 0;
for (int y = 0; y < Bc; y++)
{
pv += S[(Bc * tx) + y] * Vj[(y * d) + x];
S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - row_m);
row_l += S[(Bc * tx) + y];
}

// Compute new m and l
float row_m_new = max(row_m_prev, row_m);
float row_l_new = (__expf(row_m_prev - row_m_new) * row_l_prev) + (__expf(row_m - row_m_new) * row_l);

// Write O, l, m to HBM
for (int x = 0; x < d; x++)
{
float pv = 0; // Pij * Vj
for (int y = 0; y < Bc; y++)
{
pv += S[(Bc * tx) + y] * Vj[(y * d) + x];
}
O[qkv_offset + (Q_TILE_SIZE * i) + (tx * d) + x] = (1 / row_l_new) * ((row_l_prev * __expf(row_m_prev - row_m_new) * O[qkv_offset + (Q_TILE_SIZE * i) + (tx * d) + x]) + (__expf(row_m - row_m_new) * pv));
}
O[qkv_offset + (tile_size * i) + (tx * d) + x] = (1 / row_l_new) * ((row_l_prev * __expf(row_m_prev - row_m_new) * O[qkv_offset + (tile_size * i) + (tx * d) + x]) + (__expf(row_m - row_m_new) * pv));
m[lm_offset + (Br * i) + tx] = row_m_new;
l[lm_offset + (Br * i) + tx] = row_l_new;
}
m[lm_offset + (Br * i) + tx] = row_m_new;
l[lm_offset + (Br * i) + tx] = row_l_new;
}
__syncthreads();
}
@@ -234,7 +239,8 @@ int main()

// split kv seq_len to Tc and Q seq_len to Tr
const int Bc = 32;
const int Br = 32;
// const int Br = 32;
const int Br = 16;
const int Tc = ceil((float)N / Bc);
const int Tr = ceil((float)N / Br);

@@ -305,7 +311,7 @@ int main()

if (max_diff < 0.0001)
{
printf("Results are correct! ");
printf("Results are correct! \n");
}
else
{