@@ -60,73 +60,78 @@ __global__ void flash_attn_v1_kernel(const float *Q,
60
60
61
61
// Define SRAM for Q,K,V,S
62
62
extern __shared__ float sram[];
63
- int tile_size = Bc * d; // size of Qi, Kj, Vj
63
+ const int KV_TILE_SIZE = Bc * d; // size of Kj, Vj
64
+ const int Q_TILE_SIZE = Br * d; // size of Qi
65
+ // const int S_TILE_SIZE = Br * Bc; // size of S: first holds Qi * Kj^T * softmax_scale, then is overwritten in place with exp(S - row_m)
64
66
float *Qi = sram;
65
- float *Kj = &sram[tile_size ];
66
- float *Vj = &sram[tile_size * 2 ];
67
- float *S = &sram[tile_size * 3 ];
67
+ float *Kj = &sram[Q_TILE_SIZE ];
68
+ float *Vj = &sram[Q_TILE_SIZE + KV_TILE_SIZE ];
69
+ float *S = &sram[Q_TILE_SIZE + KV_TILE_SIZE * 2 ];
68
70
69
71
// outer loop
70
72
for (int j = 0 ; j < Tc; j++)
71
73
{
72
74
// Load Kj, Vj from HBM to SRAM
73
75
for (int x = 0 ; x < d; x++)
74
76
{
75
- Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];
76
- Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];
77
+ Kj[(tx * d) + x] = K[qkv_offset + (KV_TILE_SIZE * j) + (tx * d) + x];
78
+ Vj[(tx * d) + x] = V[qkv_offset + (KV_TILE_SIZE * j) + (tx * d) + x];
77
79
}
78
80
__syncthreads ();
79
81
80
82
for (int i = 0 ; i < Tr; i++)
81
83
{
82
- // Load Qi to SRAM, l and m to registers
83
- for (int x = 0 ; x < d; x++)
84
+ if (tx < Br)
84
85
{
85
- Qi[(tx * d) + x] = Q[qkv_offset + (tile_size * i) + (tx * d) + x];
86
- }
87
- float row_m_prev = m[lm_offset + (Br * i) + tx];
88
- float row_l_prev = l[lm_offset + (Br * i) + tx];
89
-
90
- // S = QK^T, row_m = rowmax(S)
91
- float row_m = -INFINITY;
92
- for (int y = 0 ; y < Bc; y++)
93
- {
94
- float sum = 0 ;
86
+ // Load Qi to SRAM, l and m to registers
95
87
for (int x = 0 ; x < d; x++)
96
88
{
97
- sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
89
+ Qi[(tx * d) + x] = Q[qkv_offset + (Q_TILE_SIZE * i) + (tx * d) + x];
98
90
}
99
- sum *= softmax_scale;
100
- S[(Bc * tx) + y] = sum;
101
-
102
- if (sum > row_m)
103
- row_m = sum;
104
- }
91
+ float row_m_prev = m[lm_offset + (Br * i) + tx];
92
+ float row_l_prev = l[lm_offset + (Br * i) + tx];
105
93
106
- // P = exp(S - row_m), row_l = rowsum(P)
107
- float row_l = 0 ;
108
- for (int y = 0 ; y < Bc; y++)
109
- {
110
- S[(Bc * tx) + y] = __expf (S[(Bc * tx) + y] - row_m);
111
- row_l += S[(Bc * tx) + y];
112
- }
94
+ // S = QK^T, row_m = rowmax(S)
95
+ float row_m = -INFINITY;
96
+ for (int y = 0 ; y < Bc; y++)
97
+ {
98
+ float sum = 0 ;
99
+ for (int x = 0 ; x < d; x++)
100
+ {
101
+ sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
102
+ }
103
+ sum *= softmax_scale;
104
+ S[(Bc * tx) + y] = sum;
113
105
114
- // Compute new m and l
115
- float row_m_new = max (row_m_prev, row_m) ;
116
- float row_l_new = ( __expf (row_m_prev - row_m_new) * row_l_prev) + ( __expf (row_m - row_m_new) * row_l);
106
+ if (sum > row_m)
107
+ row_m = sum ;
108
+ }
117
109
118
- // Write O, l, m to HBM
119
- for (int x = 0 ; x < d; x++)
120
- {
121
- float pv = 0 ; // Pij * Vj
110
+ // P = exp(S - row_m), row_l = rowsum(P)
111
+ float row_l = 0 ;
122
112
for (int y = 0 ; y < Bc; y++)
123
113
{
124
- pv += S[(Bc * tx) + y] * Vj[(y * d) + x];
114
+ S[(Bc * tx) + y] = __expf (S[(Bc * tx) + y] - row_m);
115
+ row_l += S[(Bc * tx) + y];
116
+ }
117
+
118
+ // Compute new m and l
119
+ float row_m_new = max (row_m_prev, row_m);
120
+ float row_l_new = (__expf (row_m_prev - row_m_new) * row_l_prev) + (__expf (row_m - row_m_new) * row_l);
121
+
122
+ // Write O, l, m to HBM
123
+ for (int x = 0 ; x < d; x++)
124
+ {
125
+ float pv = 0 ; // Pij * Vj
126
+ for (int y = 0 ; y < Bc; y++)
127
+ {
128
+ pv += S[(Bc * tx) + y] * Vj[(y * d) + x];
129
+ }
130
+ O[qkv_offset + (Q_TILE_SIZE * i) + (tx * d) + x] = (1 / row_l_new) * ((row_l_prev * __expf (row_m_prev - row_m_new) * O[qkv_offset + (Q_TILE_SIZE * i) + (tx * d) + x]) + (__expf (row_m - row_m_new) * pv));
125
131
}
126
- O[qkv_offset + (tile_size * i) + (tx * d) + x] = (1 / row_l_new) * ((row_l_prev * __expf (row_m_prev - row_m_new) * O[qkv_offset + (tile_size * i) + (tx * d) + x]) + (__expf (row_m - row_m_new) * pv));
132
+ m[lm_offset + (Br * i) + tx] = row_m_new;
133
+ l[lm_offset + (Br * i) + tx] = row_l_new;
127
134
}
128
- m[lm_offset + (Br * i) + tx] = row_m_new;
129
- l[lm_offset + (Br * i) + tx] = row_l_new;
130
135
}
131
136
__syncthreads ();
132
137
}
@@ -234,7 +239,8 @@ int main()
234
239
235
240
// split kv seq_len to Tc and Q seq_len to Tr
236
241
const int Bc = 32 ;
237
- const int Br = 32 ;
242
+ // const int Br = 32;
243
+ const int Br = 16 ;
238
244
const int Tc = ceil ((float )N / Bc);
239
245
const int Tr = ceil ((float )N / Br);
240
246
@@ -305,7 +311,7 @@ int main()
305
311
306
312
if (max_diff < 0.0001 )
307
313
{
308
- printf (" Results are correct! " );
314
+ printf (" Results are correct! \n " );
309
315
}
310
316
else
311
317
{
0 commit comments