Skip to content

Commit b726564

Browse files
jeffbolznv authored and ggerganov committed
vulkan: support softmax/FA batch and broadcast (#14449)
1 parent 3236670 commit b726564

File tree

7 files changed

+80
-44
lines changed

7 files changed

+80
-44
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 28 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -627,6 +627,7 @@ struct vk_flash_attn_push_constants {
627627
uint32_t nev2;
628628
uint32_t nev3;
629629
uint32_t nem1;
630+
uint32_t nem2;
630631

631632
uint32_t nb01;
632633
uint32_t nb02;
@@ -637,7 +638,6 @@ struct vk_flash_attn_push_constants {
637638
uint32_t nb21;
638639
uint32_t nb22;
639640
uint32_t nb23;
640-
uint32_t nb31;
641641

642642
float scale;
643643
float max_bias;
@@ -652,6 +652,7 @@ struct vk_flash_attn_push_constants {
652652
uint32_t split_kv;
653653
uint32_t k_num;
654654
};
655+
static_assert(sizeof(vk_flash_attn_push_constants) <= 128, "sizeof(vk_flash_attn_push_constants) must be <= 128");
655656

656657
struct vk_op_push_constants {
657658
uint32_t KX;
@@ -743,6 +744,14 @@ struct vk_op_rope_push_constants {
743744
struct vk_op_soft_max_push_constants {
744745
uint32_t KX;
745746
uint32_t KY;
747+
uint32_t ne00;
748+
uint32_t ne01;
749+
uint32_t ne02;
750+
uint32_t ne12;
751+
uint32_t ne13;
752+
uint32_t nb11;
753+
uint32_t nb12;
754+
uint32_t nb13;
746755
float scale;
747756
float max_bias;
748757
float m0;
@@ -5977,7 +5986,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
59775986
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
59785987

59795988
const uint32_t nem1 = mask ? mask->ne[1] : 0;
5980-
const uint32_t nbm1 = mask ? mask->nb[1] : 0;
5989+
const uint32_t nem2 = mask ? mask->ne[2] : 0;
59815990

59825991
const uint32_t D = neq0;
59835992
uint32_t N = neq1;
@@ -6140,7 +6149,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
61406149
// Try to use split_k when KV is large enough to be worth the overhead
61416150
if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
61426151
// Try to run two workgroups per SM.
6143-
split_k = ctx->device->shader_core_count * 2 / workgroups_y;
6152+
split_k = ctx->device->shader_core_count * 2 / (workgroups_y * workgroups_z);
61446153
if (split_k > 1) {
61456154
// Try to evenly split KV into split_k chunks, but it needs to be a multiple
61466155
// of "align", so recompute split_k based on that.
@@ -6150,9 +6159,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
61506159
}
61516160
}
61526161

6153-
// Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1)
6154-
// and the per-row m and L values (ne1 rows).
6155-
const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0;
6162+
// Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
6163+
// and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
6164+
const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
61566165
if (split_k_size > ctx->device->max_memory_allocation_size) {
61576166
GGML_ABORT("Requested preallocation size is too large");
61586167
}
@@ -6244,11 +6253,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
62446253
(uint32_t)neq2, (uint32_t)neq3,
62456254
(uint32_t)nek2, (uint32_t)nek3,
62466255
(uint32_t)nev2, (uint32_t)nev3,
6247-
nem1,
6256+
nem1, nem2,
62486257
q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
62496258
k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
62506259
v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
6251-
nbm1,
62526260
scale, max_bias, logit_softcap,
62536261
mask != nullptr, n_head_log2, m0, m1,
62546262
gqa_ratio, split_kv, split_k };
@@ -6271,13 +6279,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
62716279
pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
62726280

62736281
ggml_vk_sync_buffers(subctx);
6274-
const std::array<uint32_t, 3> pc2 = { D, (uint32_t)ne1, split_k };
6282+
const std::array<uint32_t, 4> pc2 = { D, (uint32_t)ne1, (uint32_t)ne3, split_k };
62756283
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
62766284
{
62776285
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
62786286
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
62796287
},
6280-
pc2, { (uint32_t)ne1, 1, 1 });
6288+
pc2, { (uint32_t)ne1, 1, (uint32_t)ne3 });
62816289
} else {
62826290
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
62836291
{
@@ -7562,7 +7570,13 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
75627570
const uint32_t nrows_x = (uint32_t)ggml_nrows(src0);
75637571
const uint32_t nrows_y = (uint32_t)src0->ne[1];
75647572

7565-
const uint32_t n_head_kv = nrows_x/nrows_y;
7573+
const uint32_t ne12 = src1 ? (uint32_t)(src1->ne[2]) : 0u;
7574+
const uint32_t ne13 = src1 ? (uint32_t)(src1->ne[3]) : 0u;
7575+
const uint32_t nb11 = src1 ? (uint32_t)(src1->nb[1] / src1->nb[0]) : 0u;
7576+
const uint32_t nb12 = src1 ? (uint32_t)(src1->nb[2] / src1->nb[0]) : 0u;
7577+
const uint32_t nb13 = src1 ? (uint32_t)(src1->nb[3] / src1->nb[0]) : 0u;
7578+
7579+
const uint32_t n_head_kv = src0->ne[2];
75667580
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
75677581

75687582
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
@@ -7571,6 +7585,9 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
75717585
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
75727586
ncols,
75737587
src1 != nullptr ? nrows_y : (uint32_t)0,
7588+
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
7589+
ne12, ne13,
7590+
nb11, nb12, nb13,
75747591
scale, max_bias,
75757592
m0, m1,
75767593
n_head_log2,
@@ -10066,11 +10083,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1006610083
if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
1006710084
return false;
1006810085
}
10069-
// TODO: support broadcast
10070-
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
10071-
if (op->src[0]->ne[3] != 1) {
10072-
return false;
10073-
}
1007410086
// It's straightforward to support different K/V dequant, but would
1007510087
// significantly increase the number of pipelines
1007610088
if (op->src[1]->type != op->src[2]->type) {
@@ -10231,13 +10243,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1023110243
case GGML_OP_DIAG_MASK_INF:
1023210244
return true;
1023310245
case GGML_OP_SOFT_MAX:
10234-
// TODO: support batching
10235-
if (op->src[0]->ne[3] != 1) {
10236-
return false;
10237-
}
10238-
// TODO: support broadcast
10239-
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
10240-
return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
1024110246
case GGML_OP_SOFT_MAX_BACK:
1024210247
case GGML_OP_ARGSORT:
1024310248
case GGML_OP_SUM:

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -99,6 +99,10 @@ void main() {
9999
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
100100
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
101101
#endif
102+
uint32_t m_offset = 0;
103+
if (p.nem2 != 1) {
104+
m_offset = (iq3 % p.nem2) * p.nem1 * KV;
105+
}
102106

103107
[[dont_unroll]]
104108
for (uint32_t j = start_j; j < end_j; ++j) {
@@ -150,7 +154,7 @@ void main() {
150154
uint32_t c = (idx + tid) % Bc;
151155
uint32_t r = (idx + tid) / Bc;
152156
if (idx + tid < Bc * Br) {
153-
masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]);
157+
masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
154158
}
155159
}
156160
barrier();
@@ -277,7 +281,7 @@ void main() {
277281
// If there is split_k, then the split_k resolve shader does the final
278282
// division by L. Store the intermediate O value and per-row m and L values.
279283
if (p.k_num > 1) {
280-
uint32_t o_offset = D * p.ne1 * split_k_index;
284+
uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num);
281285

282286
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
283287
if (r < N) {
@@ -289,7 +293,7 @@ void main() {
289293
}
290294
}
291295

292-
o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
296+
o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
293297
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
294298
if (r < N) {
295299
perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
@@ -311,7 +315,7 @@ void main() {
311315
}
312316
}
313317

314-
uint32_t o_offset = iq3*p.ne2*p.ne1;
318+
uint32_t o_offset = iq3*p.ne2*p.ne1*D;
315319

316320
if (p.gqa_ratio > 1) {
317321
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@ layout (push_constant) uniform parameter {
2424
uint32_t nev2;
2525
uint32_t nev3;
2626
uint32_t nem1;
27+
uint32_t nem2;
2728

2829
uint32_t nb01;
2930
uint32_t nb02;
@@ -34,7 +35,6 @@ layout (push_constant) uniform parameter {
3435
uint32_t nb21;
3536
uint32_t nb22;
3637
uint32_t nb23;
37-
uint32_t nb31;
3838

3939
float scale;
4040
float max_bias;

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -123,6 +123,10 @@ void main() {
123123
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
124124
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
125125
#endif
126+
uint32_t m_offset = 0;
127+
if (p.nem2 != 1) {
128+
m_offset = (iq3 % p.nem2) * p.nem1 * KV;
129+
}
126130

127131
[[dont_unroll]]
128132
for (uint32_t j = start_j; j < end_j; ++j) {
@@ -181,7 +185,7 @@ void main() {
181185
uint32_t c = (idx + tid) % Bc;
182186
uint32_t r = (idx + tid) / Bc;
183187
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
184-
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]));
188+
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
185189
}
186190
}
187191
barrier();
@@ -300,7 +304,7 @@ void main() {
300304
// If there is split_k, then the split_k resolve shader does the final
301305
// division by L. Store the intermediate O value and per-row m and L values.
302306
if (p.k_num > 1) {
303-
uint32_t o_offset = D * p.ne1 * split_k_index;
307+
uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num);
304308

305309
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
306310
if (tile_row(r) < N) {
@@ -312,7 +316,7 @@ void main() {
312316
}
313317
}
314318

315-
o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
319+
o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
316320
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
317321
if (tile_row(r) < N) {
318322
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
@@ -334,7 +338,7 @@ void main() {
334338
}
335339
}
336340

337-
uint32_t o_offset = iq3*p.ne2*p.ne1;
341+
uint32_t o_offset = iq3*p.ne2*p.ne1*D;
338342

339343
if (p.gqa_ratio > 1) {
340344
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -130,6 +130,11 @@ void main() {
130130
coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
131131
}
132132

133+
uint32_t m_offset = 0;
134+
if (p.nem2 != 1) {
135+
m_offset = (iq3 % p.nem2) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
136+
}
137+
133138
[[dont_unroll]]
134139
for (uint32_t j = start_j; j < end_j; ++j) {
135140

@@ -155,7 +160,7 @@ void main() {
155160

156161
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
157162

158-
coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
163+
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
159164

160165
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
161166
}
@@ -229,10 +234,10 @@ void main() {
229234
if (p.k_num > 1) {
230235
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
231236

232-
uint32_t o_offset = D * p.ne1 * split_k_index;
237+
uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num);
233238
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
234239

235-
o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
240+
o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
236241
coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
237242
coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
238243
return;
@@ -250,7 +255,7 @@ void main() {
250255

251256
O = Ldiag*O;
252257

253-
uint32_t o_offset = iq3*p.ne2*p.ne1;
258+
uint32_t o_offset = iq3*p.ne2*p.ne1*D;
254259

255260
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
256261
if (p.gqa_ratio > 1) {

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -12,20 +12,22 @@ layout (binding = 1) writeonly buffer D {float data_d[];};
1212
layout (push_constant) uniform parameter {
1313
uint D;
1414
uint N;
15+
uint ne3;
1516
uint k_num;
1617
} p;
1718

1819
void main() {
1920
// Each workgroup handles a row
2021
const uint n = gl_WorkGroupID.x;
2122
const uint tid = gl_LocalInvocationID.x;
23+
const uint iq3 = gl_WorkGroupID.z;
2224

2325
uint D = p.D;
2426
uint N = p.N;
2527
uint k_num = p.k_num;
2628

27-
uint l_offset = D * N * k_num + n;
28-
uint m_offset = D * N * k_num + N + n;
29+
uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n;
30+
uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n;
2931
uint lm_stride = N * 2;
3032

3133
// Compute the max m value for the row
@@ -49,11 +51,11 @@ void main() {
4951
for (uint d = tid; d < D; d += BLOCK_SIZE) {
5052
float O = 0.0;
5153
[[unroll]] for (uint k = 0; k < k_num; ++k) {
52-
uint o_offset = D * N * k + D * n + d;
54+
uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;
5355
float m = data_a[m_offset + k * lm_stride];
5456
O += exp(m - m_max) * data_a[o_offset];
5557
}
5658
O *= L;
57-
data_d[D * n + d] = O;
59+
data_d[iq3 * D * N + D * n + d] = O;
5860
}
5961
}

0 commit comments

Comments
 (0)