@@ -627,6 +627,7 @@ struct vk_flash_attn_push_constants {
     uint32_t nev2;
     uint32_t nev3;
     uint32_t nem1;
+    uint32_t nem2;
 
     uint32_t nb01;
     uint32_t nb02;
@@ -637,7 +638,6 @@ struct vk_flash_attn_push_constants {
     uint32_t nb21;
     uint32_t nb22;
     uint32_t nb23;
-    uint32_t nb31;
 
     float scale;
     float max_bias;
@@ -652,6 +652,7 @@ struct vk_flash_attn_push_constants {
     uint32_t split_kv;
     uint32_t k_num;
 };
+static_assert(sizeof(vk_flash_attn_push_constants) <= 128, "sizeof(vk_flash_attn_push_constants) must be <= 128");
 
 struct vk_op_push_constants {
     uint32_t KX;
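Note: the 128-byte bound in the new static_assert matches the minimum maxPushConstantsSize the Vulkan spec guarantees on every conformant device, so a struct that fits in 128 bytes never needs a fallback path. A minimal sketch of how that limit can be queried at runtime (standard Vulkan API, not part of this patch):

```cpp
#include <vulkan/vulkan.h>

// Returns the device's push-constant budget in bytes; the Vulkan spec
// guarantees this is at least 128, which is what the static_assert relies on.
uint32_t max_push_constant_bytes(VkPhysicalDevice phys) {
    VkPhysicalDeviceProperties props;
    vkGetPhysicalDeviceProperties(phys, &props);
    return props.limits.maxPushConstantsSize;
}
```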
@@ -743,6 +744,14 @@ struct vk_op_rope_push_constants {
 struct vk_op_soft_max_push_constants {
     uint32_t KX;
     uint32_t KY;
+    uint32_t ne00;
+    uint32_t ne01;
+    uint32_t ne02;
+    uint32_t ne12;
+    uint32_t ne13;
+    uint32_t nb11;
+    uint32_t nb12;
+    uint32_t nb13;
     float scale;
     float max_bias;
     float m0;
@@ -5977,7 +5986,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
 
     const uint32_t nem1 = mask ? mask->ne[1] : 0;
-    const uint32_t nbm1 = mask ? mask->nb[1] : 0;
+    const uint32_t nem2 = mask ? mask->ne[2] : 0;
 
     const uint32_t D = neq0;
     uint32_t N = neq1;
@@ -6140,7 +6149,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     // Try to use split_k when KV is large enough to be worth the overhead
     if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
         // Try to run two workgroups per SM.
-        split_k = ctx->device->shader_core_count * 2 / workgroups_y;
+        split_k = ctx->device->shader_core_count * 2 / (workgroups_y * workgroups_z);
         if (split_k > 1) {
             // Try to evenly split KV into split_k chunks, but it needs to be a multiple
             // of "align", so recompute split_k based on that.
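As a worked example of the adjusted heuristic (illustrative numbers only): on a device with 32 shader cores, a dispatch with workgroups_y = 8 and workgroups_z = 2 now gives split_k = 32 * 2 / (8 * 2) = 4. The old formula ignored workgroups_z and would have chosen 8, oversubscribing the device once batched inputs (ne3 > 1) contribute workgroups in the z dimension.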
@@ -6150,9 +6159,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         }
     }
 
-    // Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1)
-    // and the per-row m and L values (ne1 rows).
-    const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0;
+    // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
+    // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
+    const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
     if (split_k_size > ctx->device->max_memory_allocation_size) {
         GGML_ABORT("Requested preallocation size is too large");
     }
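A sketch of the scratch-buffer layout the updated comment describes, with all O matrices packed before the per-row L/m values. The struct and the (split, batch) ordering are illustrative assumptions, not code from the patch:

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical helper mirroring the "matrices first, then rows" layout.
// D = head size, ne1 = number of rows, ne3 = batch, split_k = KV splits.
struct split_k_scratch {
    uint32_t D, ne1, ne3, split_k;

    // All O matrices come first: one D x ne1 float matrix per split x batch.
    size_t matrices_bytes() const {
        return sizeof(float) * size_t(D) * ne1 * split_k * ne3;
    }
    // The per-row m and L values (2 floats per row per split x batch) follow.
    size_t total_bytes() const {
        return matrices_bytes() + sizeof(float) * 2 * size_t(ne1) * split_k * ne3;
    }
};
```

total_bytes() reproduces the split_k_size expression above: (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3.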
@@ -6244,11 +6253,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         (uint32_t)neq2, (uint32_t)neq3,
         (uint32_t)nek2, (uint32_t)nek3,
         (uint32_t)nev2, (uint32_t)nev3,
-        nem1,
+        nem1, nem2,
         q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
         k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
         v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
-        nbm1,
         scale, max_bias, logit_softcap,
         mask != nullptr, n_head_log2, m0, m1,
         gqa_ratio, split_kv, split_k };
@@ -6271,13 +6279,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
                                   pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
 
         ggml_vk_sync_buffers(subctx);
-        const std::array<uint32_t, 3> pc2 = { D, (uint32_t)ne1, split_k };
+        const std::array<uint32_t, 4> pc2 = { D, (uint32_t)ne1, (uint32_t)ne3, split_k };
         ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
                                   {
                                       vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
                                       vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
                                   },
-                                  pc2, { (uint32_t)ne1, 1, 1 });
+                                  pc2, { (uint32_t)ne1, 1, (uint32_t)ne3 });
     } else {
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
                                   {
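Since pc2 now carries ne3 and the reduce grid gained a z dimension, each reduce workgroup can locate its batch's partial results. A hedged sketch of the offset arithmetic, reusing the hypothetical split_k_scratch above (the actual shader-side indexing is not shown in this diff):

```cpp
// Illustrative only: byte offset of the O-matrix row that the reduce
// workgroup at (row = x, batch = z) would read for a given split,
// assuming matrices are ordered batch-major then split-major.
size_t o_row_offset(const split_k_scratch & l, uint32_t split,
                    uint32_t row, uint32_t batch) {
    const size_t matrix = size_t(batch) * l.split_k + split; // assumed order
    return sizeof(float) * (matrix * size_t(l.D) * l.ne1 + size_t(row) * l.D);
}
```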
@@ -7562,7 +7570,13 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
     const uint32_t nrows_x = (uint32_t)ggml_nrows(src0);
     const uint32_t nrows_y = (uint32_t)src0->ne[1];
 
-    const uint32_t n_head_kv = nrows_x/nrows_y;
+    const uint32_t ne12 = src1 ? (uint32_t)(src1->ne[2]) : 0u;
+    const uint32_t ne13 = src1 ? (uint32_t)(src1->ne[3]) : 0u;
+    const uint32_t nb11 = src1 ? (uint32_t)(src1->nb[1] / src1->nb[0]) : 0u;
+    const uint32_t nb12 = src1 ? (uint32_t)(src1->nb[2] / src1->nb[0]) : 0u;
+    const uint32_t nb13 = src1 ? (uint32_t)(src1->nb[3] / src1->nb[0]) : 0u;
+
+    const uint32_t n_head_kv = src0->ne[2];
     const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
 
     const float m0 = powf(2.0f, -(max_bias) / n_head_log2);
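Note the division by src1->nb[0] in the new stride computations: ggml tensor strides (nb) are byte strides, while the shader indexes the mask in elements, so each byte stride is converted to an element stride by dividing by the element size. A small self-contained illustration (hypothetical shapes, not from the patch):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // A contiguous f32 mask of ggml shape [ne0=64, ne1=32, ne2=4, ne3=1]
    // has byte strides nb0 = 4, nb1 = 256, nb2 = 8192, nb3 = 32768.
    const uint32_t nb0 = 4, nb1 = 64 * 4, nb2 = 64 * 32 * 4, nb3 = 64 * 32 * 4 * 4;
    // Dividing by nb0 yields the element strides the shader receives
    // as nb11/nb12/nb13: 64, 2048, 8192.
    printf("nb11=%u nb12=%u nb13=%u\n", nb1 / nb0, nb2 / nb0, nb3 / nb0);
    return 0;
}
```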
@@ -7571,6 +7585,9 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
     ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
         ncols,
         src1 != nullptr ? nrows_y : (uint32_t)0,
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
+        ne12, ne13,
+        nb11, nb12, nb13,
         scale, max_bias,
         m0, m1,
         n_head_log2,
@@ -10066,11 +10083,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
                 return false;
             }
-            // TODO: support broadcast
-            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
-            if (op->src[0]->ne[3] != 1) {
-                return false;
-            }
             // It's straightforward to support different K/V dequant, but would
             // significantly increase the number of pipelines
             if (op->src[1]->type != op->src[2]->type) {
@@ -10231,13 +10243,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_DIAG_MASK_INF:
             return true;
         case GGML_OP_SOFT_MAX:
-            // TODO: support batching
-            if (op->src[0]->ne[3] != 1) {
-                return false;
-            }
-            // TODO: support broadcast
-            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
-            return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
         case GGML_OP_SOFT_MAX_BACK:
         case GGML_OP_ARGSORT:
         case GGML_OP_SUM: