Skip to content

Commit 919a26d

Browse files
authored
[Fix] fix the flash_attn_v1 when Br!=Bc (#68)
1 parent 582c219 commit 919a26d

File tree

2 files changed

+80
-59
lines changed

2 files changed

+80
-59
lines changed

docs/17_flash_attn/02_flash_attn_v1_part2.md

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,17 @@ printf("Max shared memory: %d, requested shared memory: %d \n", max_sram_size, s
3232
3333
如果请求的 sram_size 超过 max_sram_size,那么内核启动时将会失败,这时候我们需要调整 Bc、D、Br 的参数,找到平衡点,既能保证算法所需内存,又不会超过硬件限制。
3434
35+
这里为了简单起见,在代码中直接将 Bc 和 Br 写成了固定值。值得注意的是,这个 Br 和 Bc 的值是可以不一样的,并且一定有$Br \leq Bc$。
36+
37+
```cpp
38+
const int Bc = 32;
39+
const int Br = 16;
40+
```
41+
42+
至于为什么一定会有$Br \leq Bc$,则可以回到在 Flash Attention V1 的论文里,其计算方式为 $Bc=\lceil \frac{M}{4d} \rceil$,$Br= min(\lceil \frac{M}{4d} \rceil, d)$。其中$M$是设备每个 SM 所能使用的最大共享内存空间大小,$d$是每个向量的维度。$4d$表示的是 Q,K,V,S 使用共享内存的子块大小之和。这里会发现当$\lceil \frac{M}{4d} \rceil > d$时,$Br = d < \lceil \frac{M}{4d} \rceil = Bc$。当$\lceil \frac{M}{4d} \rceil \leq d$时,$Br = \lceil \frac{M}{4d} \rceil = Bc$。所以一定会有$Br \leq Bc$。
43+
44+
根据这个性质,当 Br 与 Bc 不相等时时,也可以只用简单的 if 语句就可以完成 Q 子块的加载,但设置 Bc 和 Br 的时候最好是相等的,可以提高 GPU 线程的利用率。
45+
3546
接下来,我们需要设置 CUDA 内核的执行维度(也就是 gridDim 和 blockDim):
3647

3748
```cpp
@@ -90,17 +101,18 @@ int lm_offset = (bx * gridDim.y * N) + (by * N);
90101

91102
```cpp
92103
extern __shared__ float sram[];
93-
int tile_size = Bc * d; // size of Qi, Kj, Vj
94-
float* Qi = sram;
95-
float* Kj = &sram[tile_size];
96-
float* Vj = &sram[tile_size * 2];
97-
float* S = &sram[tile_size * 3];
104+
const int KV_TILE_SIZE = Bc * d; // size of Kj, Vj
105+
const int Q_TILE_SIZE = Br * d; // size of Qi
106+
float *Qi = sram;
107+
float *Kj = &sram[Q_TILE_SIZE];
108+
float *Vj = &sram[Q_TILE_SIZE + KV_TILE_SIZE];
109+
float *S = &sram[Q_TILE_SIZE + KV_TILE_SIZE * 2];
98110
```
99111

100112
这里我们可以逐一拆解每个部分的作用:
101113

102-
- **Qi 区域**: 用来存储当前片(tile)中从 Q 张量 Q 载入的数据。在内核后续计算 $QK^T$ 时,线程需要利用共享内存里的 Qi 进行向量点乘操作。Qi 的大小为 `tile_size`,即 Bc 个向量,每个向量的维度为 d。Bc 一般对应块内线程数量,每个线程负责处理一个向量或一个向量的一部分。
103-
- **Kj 区域**: Kj 用于存储从全局内存中加载的键(Key)张量的一部分。外层循环中会把总数据分块,每次将一块键数据载入共享内存。同 Qi 一样,也有 tile_size 大小(Bc * d 的数据量)。
114+
- **Qi 区域**: 用来存储当前片(tile)中从 Q 张量 Q 载入的数据。在内核后续计算 $QK^T$ 时,线程需要利用共享内存里的 Qi 进行向量点乘操作。Qi 的大小为 `Q_TILE_SIZE`,即 Br 个向量,每个向量的维度为 d。Br 一般对应块内线程数量,每个线程负责处理一个向量或一个向量的一部分。
115+
- **Kj 区域**: Kj 用于存储从全局内存中加载的键(Key)张量的一部分。外层循环中会把总数据分块,每次将一块键数据载入共享内存。同 Qi 一样,也有 `KV_TILE_SIZE` 大小(Bc * d 的数据量)。
104116
- **Vj 区域**: Vj 与 Kj 类似,不过它存储的是值(Value)张量的一部分。运算中配合 softmax 后的注意力权重对每个线程所对应的值进行加权求和,最终生成输出。
105117
- **S 区域**: S 区域专门用来存储计算结果——也就是 $QK^T$ 相乘得到的分数 Matrix S。在执行 softmax 操作之前,每个线程对自己对应的输出行内的所有元素,将点乘结果保存到 S 里。
106118

@@ -115,12 +127,12 @@ float* S = &sram[tile_size * 3];
115127

116128
```cpp
117129
// 整个 K、V 张量被分成 Tc 个 tile
118-
// 每个 tile 大小为 tile_size(定义为 Bc * d)
130+
// 每个 KV tile 大小为 KV_TILE_SIZE(定义为 Bc * d)
119131
for (int j = 0; j < Tc; j++) {
120132
// Load Kj, Vj from HBM to SRAM
121133
for (int x = 0; x < d; x++) {
122-
Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];
123-
Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];
134+
Kj[(tx * d) + x] = K[qkv_offset + (KV_TILE_SIZE * j) + (tx * d) + x];
135+
Vj[(tx * d) + x] = V[qkv_offset + (KV_TILE_SIZE * j) + (tx * d) + x];
124136
}
125137
__syncthreads();
126138
```
@@ -137,7 +149,10 @@ for (int j = 0; j < Tc; j++) {
137149
138150
```cpp
139151
for (int i = 0; i < Tr; i++) {
140-
... // 内部代码
152+
// 这个就是处理Br和Bc不相等的情况
153+
if (tx < Br){
154+
... // 内部代码
155+
}
141156
}
142157
```
143158

@@ -147,7 +162,7 @@ for (int j = 0; j < Tc; j++) {
147162

148163
```cpp
149164
for (int x = 0; x < d; x++) {
150-
Qi[(tx * d) + x] = Q[qkv_offset + (tile_size * i) + (tx * d) + x];
165+
Qi[(tx * d) + x] = Q[qkv_offset + (Q_TILE_SIZE * i) + (tx * d) + x];
151166
}
152167
```
153168
@@ -250,8 +265,8 @@ for (int j = 0; j < Tc; j++) {
250265
for (int y = 0; y < Bc; y++) {
251266
pv += S[(Bc * tx) + y] * Vj[(y * d) + x];
252267
}
253-
O[qkv_offset + (tile_size * i) + (tx * d) + x] = (1 / row_l_new) \
254-
* ((row_l_prev * __expf(row_m_prev - row_m_new) * O[qkv_offset + (tile_size * i) + (tx * d) + x]) \
268+
O[qkv_offset + (Q_TILE_SIZE * i) + (tx * d) + x] = (1 / row_l_new) \
269+
* ((row_l_prev * __expf(row_m_prev - row_m_new) * O[qkv_offset + (Q_TILE_SIZE * i) + (tx * d) + x]) \
255270
+ (__expf(row_m - row_m_new) * pv));
256271
}
257272
```

docs/17_flash_attn/flash_attn_v1.cu

Lines changed: 51 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -60,73 +60,78 @@ __global__ void flash_attn_v1_kernel(const float *Q,
6060

6161
// Define SRAM for Q,K,V,S
6262
extern __shared__ float sram[];
63-
int tile_size = Bc * d; // size of Qi, Kj, Vj
63+
const int KV_TILE_SIZE = Bc * d; // size of Kj, Vj
64+
const int Q_TILE_SIZE = Br * d; // size of Qi
65+
// const int S_TILE_SIZE = Br * Bc; // size of Sij = softmax(Qi * Kj^T * softmax_scale)
6466
float *Qi = sram;
65-
float *Kj = &sram[tile_size];
66-
float *Vj = &sram[tile_size * 2];
67-
float *S = &sram[tile_size * 3];
67+
float *Kj = &sram[Q_TILE_SIZE];
68+
float *Vj = &sram[Q_TILE_SIZE + KV_TILE_SIZE];
69+
float *S = &sram[Q_TILE_SIZE + KV_TILE_SIZE * 2];
6870

6971
// outer loop
7072
for (int j = 0; j < Tc; j++)
7173
{
7274
// Load Kj, Vj from HBM to SRAM
7375
for (int x = 0; x < d; x++)
7476
{
75-
Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];
76-
Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];
77+
Kj[(tx * d) + x] = K[qkv_offset + (KV_TILE_SIZE * j) + (tx * d) + x];
78+
Vj[(tx * d) + x] = V[qkv_offset + (KV_TILE_SIZE * j) + (tx * d) + x];
7779
}
7880
__syncthreads();
7981

8082
for (int i = 0; i < Tr; i++)
8183
{
82-
// Load Qi to SRAM, l and m to registers
83-
for (int x = 0; x < d; x++)
84+
if (tx < Br)
8485
{
85-
Qi[(tx * d) + x] = Q[qkv_offset + (tile_size * i) + (tx * d) + x];
86-
}
87-
float row_m_prev = m[lm_offset + (Br * i) + tx];
88-
float row_l_prev = l[lm_offset + (Br * i) + tx];
89-
90-
// S = QK^T, row_m = rowmax(S)
91-
float row_m = -INFINITY;
92-
for (int y = 0; y < Bc; y++)
93-
{
94-
float sum = 0;
86+
// Load Qi to SRAM, l and m to registers
9587
for (int x = 0; x < d; x++)
9688
{
97-
sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
89+
Qi[(tx * d) + x] = Q[qkv_offset + (Q_TILE_SIZE * i) + (tx * d) + x];
9890
}
99-
sum *= softmax_scale;
100-
S[(Bc * tx) + y] = sum;
101-
102-
if (sum > row_m)
103-
row_m = sum;
104-
}
91+
float row_m_prev = m[lm_offset + (Br * i) + tx];
92+
float row_l_prev = l[lm_offset + (Br * i) + tx];
10593

106-
// P = exp(S - row_m), row_l = rowsum(P)
107-
float row_l = 0;
108-
for (int y = 0; y < Bc; y++)
109-
{
110-
S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - row_m);
111-
row_l += S[(Bc * tx) + y];
112-
}
94+
// S = QK^T, row_m = rowmax(S)
95+
float row_m = -INFINITY;
96+
for (int y = 0; y < Bc; y++)
97+
{
98+
float sum = 0;
99+
for (int x = 0; x < d; x++)
100+
{
101+
sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
102+
}
103+
sum *= softmax_scale;
104+
S[(Bc * tx) + y] = sum;
113105

114-
// Compute new m and l
115-
float row_m_new = max(row_m_prev, row_m);
116-
float row_l_new = (__expf(row_m_prev - row_m_new) * row_l_prev) + (__expf(row_m - row_m_new) * row_l);
106+
if (sum > row_m)
107+
row_m = sum;
108+
}
117109

118-
// Write O, l, m to HBM
119-
for (int x = 0; x < d; x++)
120-
{
121-
float pv = 0; // Pij * Vj
110+
// P = exp(S - row_m), row_l = rowsum(P)
111+
float row_l = 0;
122112
for (int y = 0; y < Bc; y++)
123113
{
124-
pv += S[(Bc * tx) + y] * Vj[(y * d) + x];
114+
S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - row_m);
115+
row_l += S[(Bc * tx) + y];
116+
}
117+
118+
// Compute new m and l
119+
float row_m_new = max(row_m_prev, row_m);
120+
float row_l_new = (__expf(row_m_prev - row_m_new) * row_l_prev) + (__expf(row_m - row_m_new) * row_l);
121+
122+
// Write O, l, m to HBM
123+
for (int x = 0; x < d; x++)
124+
{
125+
float pv = 0; // Pij * Vj
126+
for (int y = 0; y < Bc; y++)
127+
{
128+
pv += S[(Bc * tx) + y] * Vj[(y * d) + x];
129+
}
130+
O[qkv_offset + (Q_TILE_SIZE * i) + (tx * d) + x] = (1 / row_l_new) * ((row_l_prev * __expf(row_m_prev - row_m_new) * O[qkv_offset + (Q_TILE_SIZE * i) + (tx * d) + x]) + (__expf(row_m - row_m_new) * pv));
125131
}
126-
O[qkv_offset + (tile_size * i) + (tx * d) + x] = (1 / row_l_new) * ((row_l_prev * __expf(row_m_prev - row_m_new) * O[qkv_offset + (tile_size * i) + (tx * d) + x]) + (__expf(row_m - row_m_new) * pv));
132+
m[lm_offset + (Br * i) + tx] = row_m_new;
133+
l[lm_offset + (Br * i) + tx] = row_l_new;
127134
}
128-
m[lm_offset + (Br * i) + tx] = row_m_new;
129-
l[lm_offset + (Br * i) + tx] = row_l_new;
130135
}
131136
__syncthreads();
132137
}
@@ -234,7 +239,8 @@ int main()
234239

235240
// split kv seq_len to Tc and Q seq_len to Tr
236241
const int Bc = 32;
237-
const int Br = 32;
242+
// const int Br = 32;
243+
const int Br = 16;
238244
const int Tc = ceil((float)N / Bc);
239245
const int Tr = ceil((float)N / Br);
240246

@@ -305,7 +311,7 @@ int main()
305311

306312
if (max_diff < 0.0001)
307313
{
308-
printf("Results are correct! ");
314+
printf("Results are correct! \n");
309315
}
310316
else
311317
{

0 commit comments

Comments
 (0)