@@ -633,6 +633,7 @@ struct vk_flash_attn_push_constants {
     uint32_t nev2;
     uint32_t nev3;
     uint32_t nem1;
+    uint32_t nem2;
 
     uint32_t nb01;
     uint32_t nb02;
@@ -643,7 +644,6 @@ struct vk_flash_attn_push_constants {
     uint32_t nb21;
     uint32_t nb22;
     uint32_t nb23;
-    uint32_t nb31;
 
     float scale;
    float max_bias;
@@ -658,6 +658,7 @@ struct vk_flash_attn_push_constants {
     uint32_t split_kv;
     uint32_t k_num;
 };
+static_assert(sizeof(vk_flash_attn_push_constants) <= 128, "sizeof(vk_flash_attn_push_constants) must be <= 128");
 
 struct vk_op_push_constants {
     uint32_t KX;
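Note: the new static_assert pins the push-constant block at 128 bytes because that is all the Vulkan spec guarantees; VkPhysicalDeviceLimits::maxPushConstantsSize must be at least 128 on every conformant device, and the uint32_t nem2 added above is balanced by dropping nb31 below, so the struct stays within budget. A minimal host-side sketch of the guarantee being relied on (hypothetical function name, not part of this patch):

    #include <cassert>
    #include <vulkan/vulkan.h>

    void check_push_constant_budget(VkPhysicalDevice physical_device) {
        VkPhysicalDeviceProperties props;
        vkGetPhysicalDeviceProperties(physical_device, &props);
        // The spec-mandated minimum is 128 bytes, so any push-constant struct
        // that stays <= 128 bytes is portable without a per-device query.
        assert(props.limits.maxPushConstantsSize >= 128);
    }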
@@ -756,6 +757,14 @@ struct vk_op_rope_push_constants {
 struct vk_op_soft_max_push_constants {
     uint32_t KX;
     uint32_t KY;
+    uint32_t ne00;
+    uint32_t ne01;
+    uint32_t ne02;
+    uint32_t ne12;
+    uint32_t ne13;
+    uint32_t nb11;
+    uint32_t nb12;
+    uint32_t nb13;
     float scale;
     float max_bias;
     float m0;
@@ -6043,7 +6052,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
 
     const uint32_t nem1 = mask ? mask->ne[1] : 0;
-    const uint32_t nbm1 = mask ? mask->nb[1] : 0;
+    const uint32_t nem2 = mask ? mask->ne[2] : 0;
 
     const uint32_t D = neq0;
     uint32_t N = neq1;
@@ -6206,7 +6215,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     // Try to use split_k when KV is large enough to be worth the overhead
     if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
         // Try to run two workgroups per SM.
-        split_k = ctx->device->shader_core_count * 2 / workgroups_y;
+        split_k = ctx->device->shader_core_count * 2 / (workgroups_y * workgroups_z);
         if (split_k > 1) {
             // Try to evenly split KV into split_k chunks, but it needs to be a multiple
             // of "align", so recompute split_k based on that.
@@ -6216,9 +6225,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         }
     }
 
-    // Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1)
-    // and the per-row m and L values (ne1 rows).
-    const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0;
+    // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
+    // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
+    const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
     if (split_k_size > ctx->device->max_memory_allocation_size) {
        GGML_ABORT("Requested preallocation size is too large");
     }
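The updated comment describes the temporary-buffer layout: the split_k * ne3 partial O matrices are stored contiguously first, then the per-row m and L scalars. A standalone sketch of the same byte arithmetic (hypothetical helper, matching split_k_size above):

    #include <cstdint>

    // Hypothetical standalone version of the allocation arithmetic above.
    uint64_t flash_attn_split_k_bytes(uint32_t D, uint32_t ne1,
                                      uint32_t split_k, uint32_t ne3) {
        // All O matrices (D x ne1 floats) for every split and batch come first...
        uint64_t o_bytes  = (uint64_t)D * ne1 * sizeof(float) * split_k * ne3;
        // ...followed by the per-row m and L values (2 floats per row).
        uint64_t ml_bytes = (uint64_t)ne1 * 2 * sizeof(float) * split_k * ne3;
        return o_bytes + ml_bytes;  // equals split_k_size when split_k > 1
    }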
@@ -6310,11 +6319,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         (uint32_t)neq2, (uint32_t)neq3,
         (uint32_t)nek2, (uint32_t)nek3,
         (uint32_t)nev2, (uint32_t)nev3,
-        nem1,
+        nem1, nem2,
         q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
         k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
         v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
-        nbm1,
         scale, max_bias, logit_softcap,
         mask != nullptr, n_head_log2, m0, m1,
         gqa_ratio, split_kv, split_k };
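Here nem2 replaces the dropped nbm1 in the push constants: instead of a byte stride for the mask's second dimension, the shader now receives the mask's dim-2 extent so the mask can be broadcast. ggml's broadcast convention is modular indexing, so the shader side would presumably select its mask slice as some coordinate modulo nem2 (whether that coordinate is the head or the batch index is not visible in this hunk). An illustrative model of the convention, not code from this patch:

    #include <cstdint>

    // Illustrative model of ggml-style modulo broadcast: a tensor with n slices
    // along a dimension serves any larger index by wrapping around.
    uint32_t broadcast_index(uint32_t i, uint32_t n) {
        return i % n;  // n == 1 broadcasts a single slice to every consumer
    }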
@@ -6337,13 +6345,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
             pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
 
         ggml_vk_sync_buffers(subctx);
-        const std::array<uint32_t, 3> pc2 = { D, (uint32_t)ne1, split_k };
+        const std::array<uint32_t, 4> pc2 = { D, (uint32_t)ne1, (uint32_t)ne3, split_k };
         ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
             {
                 vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
                 vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
             },
-            pc2, { (uint32_t)ne1, 1, 1 });
+            pc2, { (uint32_t)ne1, 1, (uint32_t)ne3 });
     } else {
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
             {
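The split-k reduce pass widens pc2 from three to four push constants so the shader learns ne3, and the dispatch gains a z dimension of ne3 workgroups, one per batch. A hypothetical host-side mirror of what the reduce shader now receives (the real reduce shader is not part of this diff):

    #include <cstdint>

    // Hypothetical mirror of the reduce pass's push constants: with ne3 and the
    // workgroup's z index available, each z-slice can merge one batch's splits.
    struct split_k_reduce_pc {
        uint32_t D;        // head dimension
        uint32_t ne1;      // rows
        uint32_t ne3;      // batches (new)
        uint32_t split_k;  // number of partial results to merge per row
    };
    static_assert(sizeof(split_k_reduce_pc) == 4 * sizeof(uint32_t), "packed");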
@@ -7669,7 +7677,13 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
     const uint32_t nrows_x = (uint32_t)ggml_nrows(src0);
     const uint32_t nrows_y = (uint32_t)src0->ne[1];
 
-    const uint32_t n_head_kv   = nrows_x/nrows_y;
+    const uint32_t ne12 = src1 ? (uint32_t)(src1->ne[2]) : 0u;
+    const uint32_t ne13 = src1 ? (uint32_t)(src1->ne[3]) : 0u;
+    const uint32_t nb11 = src1 ? (uint32_t)(src1->nb[1] / src1->nb[0]) : 0u;
+    const uint32_t nb12 = src1 ? (uint32_t)(src1->nb[2] / src1->nb[0]) : 0u;
+    const uint32_t nb13 = src1 ? (uint32_t)(src1->nb[3] / src1->nb[0]) : 0u;
+
+    const uint32_t n_head_kv   = src0->ne[2];
     const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
 
     const float m0 = powf(2.0f, -(max_bias) / n_head_log2);
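Two things happen here. First, the mask's byte strides are converted to element strides by dividing by nb[0] (the element size), since the shader addresses the mask per element rather than per byte. Second, n_head_kv switches from nrows_x/nrows_y to src0->ne[2]: the old ratio equals ne[2] * ne[3], which only matched the head count while ne[3] == 1; with batching it would fold the batch dimension into the head count and skew the ALiBi slopes. A worked stride example with an assumed contiguous F32 mask (sizes invented for illustration):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Assumed contiguous F32 mask of shape ne = {64, 32, 4, 1}:
        uint32_t ne[4] = {64, 32, 4, 1};
        uint32_t nb[4];                  // byte strides, ggml layout
        nb[0] = sizeof(float);
        for (int i = 1; i < 4; i++) nb[i] = nb[i - 1] * ne[i - 1];
        // Element strides as passed in the push constants:
        printf("nb11=%u nb12=%u nb13=%u\n",
               nb[1] / nb[0], nb[2] / nb[0], nb[3] / nb[0]);  // 64 2048 8192
        return 0;
    }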
@@ -7678,6 +7692,9 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
     ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
         ncols,
         src1 != nullptr ? nrows_y : (uint32_t)0,
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
+        ne12, ne13,
+        nb11, nb12, nb13,
         scale, max_bias,
         m0, m1,
         n_head_log2,
@@ -10251,11 +10268,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
                 return false;
             }
-            // TODO: support broadcast
-            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
-            if (op->src[0]->ne[3] != 1) {
-                return false;
-            }
             // It's straightforward to support different K/V dequant, but would
             // significantly increase the number of pipelines
             if (op->src[1]->type != op->src[2]->type) {
@@ -10416,13 +10428,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_DIAG_MASK_INF:
             return true;
         case GGML_OP_SOFT_MAX:
-            // TODO: support batching
-            if (op->src[0]->ne[3] != 1) {
-                return false;
-            }
-            // TODO: support broadcast
-            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
-            return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
         case GGML_OP_SOFT_MAX_BACK:
         case GGML_OP_ARGSORT:
         case GGML_OP_SUM:
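With the batching and broadcast TODOs resolved by the changes above, both early-out checks are dropped: flash attention no longer rejects src0 with ne[3] != 1, and GGML_OP_SOFT_MAX now falls through to the case group below it. A hypothetical shape check mirroring ggml's usual repeat rule for the mask shapes this enables (the real validation lives in ggml's graph checks, not in this diff):

    #include <cstdint>

    // ggml broadcasts src1 to src0 along dims 2/3 when each src1 extent divides
    // the corresponding src0 extent (e.g. a {n_kv, n_q, 1, 1} mask applied to a
    // {n_kv, n_q, n_head, n_batch} src0). Assumed helper, for illustration only.
    bool mask_broadcastable(const int64_t src0_ne[4], const int64_t src1_ne[4]) {
        return src0_ne[2] % src1_ne[2] == 0 && src0_ne[3] % src1_ne[3] == 0;
    }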