Commit 0b84bd5

Merge remote-tracking branch 'origin/compilade/refactor-kv-cache' into GraniteFour
* origin/compilade/refactor-kv-cache:
  model : use ggml_swiglu_split for Mamba
  model : remove unnecessary prefix for tensor loading constants
  jamba : remove redundant nullptr initializations
  vulkan: optimize flash attention split_k_reduce (ggml-org#14554)
2 parents 12c50f1 + f7c7a92 commit 0b84bd5

File tree

3 files changed, 49 insertions(+), 46 deletions(-)

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 4 additions & 4 deletions
@@ -2706,7 +2706,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 3 * sizeof(uint32_t), {1, 1, 1}, {}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
 
     for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
@@ -6252,13 +6252,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
 
     // Try to use split_k when KV is large enough to be worth the overhead
-    if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
+    if (workgroups_x == 1 && shader_core_count > 0) {
         // Try to run two workgroups per SM.
         split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
         if (split_k > 1) {
             // Try to evenly split KV into split_k chunks, but it needs to be a multiple
             // of "align", so recompute split_k based on that.
-            split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align);
+            split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), pipelines[1]->align);
             split_k = CEIL_DIV(KV, split_kv);
             workgroups_x = split_k;
         }
@@ -6392,7 +6392,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
         vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
     },
-    pc2, { (uint32_t)ne1, 1, (uint32_t)ne3 });
+    pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
 } else {
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
     {
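
The split_k sizing in the second hunk can be illustrated in isolation. The following standalone sketch re-states CEIL_DIV and ROUNDUP_POW2 with their usual ggml-vulkan meanings (an assumption on my part) and uses made-up values for shader_core_count, workgroup counts, KV, and align; it is not code from the commit. It shows why the std::max(1u, ...) guard matters now that the KV >= 512 threshold is gone: for a small KV, KV / split_k could round down to 0 before the alignment round-up.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <initializer_list>

// Re-stated helpers matching the usual ggml-vulkan macros (assumption).
static uint32_t CEIL_DIV(uint32_t a, uint32_t b)     { return (a + b - 1) / b; }
static uint32_t ROUNDUP_POW2(uint32_t a, uint32_t b) { return (a + b - 1) & ~(b - 1); } // b must be a power of two

int main() {
    // Hypothetical example values, not taken from the diff.
    const uint32_t shader_core_count = 32;
    const uint32_t workgroups_y = 8, workgroups_z = 1;
    const uint32_t align = 64;

    for (uint32_t KV : {64u, 256u, 4096u}) {
        uint32_t split_k = shader_core_count * 2 / (workgroups_y * workgroups_z); // aim for two workgroups per SM
        uint32_t split_kv = KV;
        uint32_t workgroups_x = 1;
        if (split_k > 1) {
            // std::max keeps the chunk size at least 1 before rounding up to "align".
            split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), align);
            split_k  = CEIL_DIV(KV, split_kv);
            workgroups_x = split_k;
        }
        printf("KV=%u -> split_kv=%u split_k=%u workgroups_x=%u\n", KV, split_kv, split_k, workgroups_x);
    }
    return 0;
}

Under these assumed values, KV=64 still yields a valid split_kv of 64 (one chunk) instead of a zero-sized chunk, while larger KV values are split across several workgroups in x.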

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp

Lines changed: 38 additions & 8 deletions
@@ -2,9 +2,9 @@
 
 #extension GL_EXT_control_flow_attributes : enable
 
-#define BLOCK_SIZE 32
+layout(constant_id = 0) const uint BLOCK_SIZE = 32;
 
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer A {float data_a[];};
 layout (binding = 1) writeonly buffer D {float data_d[];};
@@ -16,6 +16,8 @@ layout (push_constant) uniform parameter {
     uint k_num;
 } p;
 
+shared float tmpsh[BLOCK_SIZE];
+
 void main() {
     // Each workgroup handles a row
     const uint n = gl_WorkGroupID.x;
@@ -32,23 +34,51 @@
 
     // Compute the max m value for the row
     float m_max = -1.0/0.0;
-    [[unroll]] for (uint k = 0; k < k_num; ++k) {
-        float m = data_a[m_offset + k * lm_stride];
+    for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
+        float m = data_a[m_offset + (k + tid) * lm_stride];
         m_max = max(m_max, m);
     }
 
+    // reduce across the workgroup
+    tmpsh[tid] = m_max;
+    barrier();
+    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
+        if (tid < s) {
+            m_max = max(m_max, tmpsh[tid + s]);
+            tmpsh[tid] = m_max;
+        }
+        barrier();
+    }
+    m_max = tmpsh[0];
+
+    barrier();
+
     // Compute L based on m_max
     float L = 0;
-    [[unroll]] for (uint k = 0; k < k_num; ++k) {
-        float l = data_a[l_offset + k * lm_stride];
-        float m = data_a[m_offset + k * lm_stride];
+    for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
+        float l = data_a[l_offset + (k + tid) * lm_stride];
+        float m = data_a[m_offset + (k + tid) * lm_stride];
         L += exp(m - m_max) * l;
     }
 
+    // reduce across the workgroup
+    tmpsh[tid] = L;
+    barrier();
+    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
+        if (tid < s) {
+            L += tmpsh[tid + s];
+            tmpsh[tid] = L;
+        }
+        barrier();
+    }
+    L = tmpsh[0];
+
     L = 1.0 / L;
 
+    // D dimension is split across workgroups in the y dimension
+    uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;
     // Scale and sum the O contributions based on m_max and store the result to memory
-    for (uint d = tid; d < D; d += BLOCK_SIZE) {
+    if (d < D) {
         float O = 0.0;
         [[unroll]] for (uint k = 0; k < k_num; ++k) {
             uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;
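
For reference, the reduction pattern the shader adopts above (a strided per-thread partial result followed by a BLOCK_SIZE-wide tree reduction through shared memory) can be mimicked on the CPU. This is a minimal C++ sketch of the max step only; the input values and sizes are arbitrary, the "threads" run sequentially, so the shader's barrier() calls have no counterpart here, and it is an illustration of the technique rather than a port of the shader.

#include <algorithm>
#include <cstdio>
#include <limits>
#include <vector>

int main() {
    const unsigned BLOCK_SIZE = 32;   // mirrors the specialization constant in the shader
    const unsigned k_num = 100;       // arbitrary example value
    std::vector<float> m(k_num);
    for (unsigned k = 0; k < k_num; ++k) m[k] = 0.01f * k;   // made-up data, max is m[99]

    // Step 1: each "thread" tid accumulates a strided partial max, like the per-lane loop.
    std::vector<float> tmpsh(BLOCK_SIZE, -std::numeric_limits<float>::infinity());
    for (unsigned tid = 0; tid < BLOCK_SIZE; ++tid) {
        for (unsigned k = 0; k + tid < k_num; k += BLOCK_SIZE) {
            tmpsh[tid] = std::max(tmpsh[tid], m[k + tid]);
        }
    }

    // Step 2: tree reduction over the shared array (barriers elided in this serial model).
    for (unsigned s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
        for (unsigned tid = 0; tid < s; ++tid) {
            tmpsh[tid] = std::max(tmpsh[tid], tmpsh[tid + s]);
        }
    }

    printf("m_max = %f (expected %f)\n", tmpsh[0], m[k_num - 1]);
    return 0;
}

The L (sum) reduction in the shader follows the same shape, with max replaced by addition.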

src/llama-model.cpp

Lines changed: 7 additions & 34 deletions
@@ -3267,10 +3267,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 {
     output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
 
-    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
     // if output is NULL, init from the input tok embed, duplicated to allow offloading
     if (output == NULL) {
-        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
+        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
     }
 }
 
@@ -3313,10 +3313,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 {
     output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
 
-    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
     // if output is NULL, init from the input tok embed, duplicated to allow offloading
     if (output == NULL) {
-        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
+        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
     }
 }
 
@@ -3352,56 +3352,29 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
     // out_proj
     layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
-
-    layer.wq = nullptr;
-    layer.wk = nullptr;
-    layer.wv = nullptr;
-    layer.wo = nullptr;
-
 } else {
     // Attention layers
 
     layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
     layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
     layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
     layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-    layer.ssm_in = nullptr;
-    layer.ssm_conv1d = nullptr;
-    layer.ssm_conv1d_b = nullptr;
-    layer.ssm_x = nullptr;
-    layer.ssm_dt_norm = nullptr;
-    layer.ssm_dt = nullptr;
-    layer.ssm_dt_b = nullptr;
-    layer.ssm_b_norm = nullptr;
-    layer.ssm_c_norm = nullptr;
-    layer.ssm_a = nullptr;
-    layer.ssm_d = nullptr;
-    layer.ssm_out = nullptr;
 }
 
 layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
 
-layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
+layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED);
 
 if (layer.ffn_gate_inp) {
     // MoE
     layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
     layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0);
     layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
-
-    layer.ffn_gate = nullptr;
-    layer.ffn_down = nullptr;
-    layer.ffn_up = nullptr;
 } else {
     // FFN (no MoE)
     layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
     layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
     layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
-
-    layer.ffn_gate_exps = nullptr;
-    layer.ffn_down_exps = nullptr;
-    layer.ffn_up_exps = nullptr;
 }
 }
 } break;
@@ -10228,7 +10201,7 @@ struct llm_graph_context_mamba : public virtual llm_graph_context {
     // TODO: skip computing output earlier for unused tokens
 
     y = ggml_add(ctx0, y, ggml_mul(ctx0, cur, layer.ssm_d));
-    y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z)));
+    y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
 
     // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
     cur = build_lora_mm(layer.ssm_out, y);
@@ -10352,7 +10325,7 @@ struct llm_graph_context_mamba : public virtual llm_graph_context {
     // TODO: skip computing output earlier for unused tokens
 
     y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
-    y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z)));
+    y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
 
     // grouped RMS norm
     y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
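
The two ggml_swiglu_split hunks above replace an explicit silu-then-multiply with the fused op, so the per-element value should be unchanged. As a plain scalar illustration (my own sketch using the standard silu definition, not ggml code), both forms compute silu(z) * y:

#include <cmath>
#include <cstdio>

// silu(x) = x * sigmoid(x) = x / (1 + exp(-x)); the standard definition used by ggml_silu.
static float silu(float x) { return x / (1.0f + std::exp(-x)); }

int main() {
    // Arbitrary sample values for the gate (z) and the SSM output (y).
    const float z[3] = {-1.5f, 0.0f,  2.0f};
    const float y[3] = { 0.3f, 1.0f, -0.7f};

    for (int i = 0; i < 3; ++i) {
        const float old_form = y[i] * silu(z[i]);   // ggml_mul(y, ggml_silu(z)) path, element-wise
        const float fused    = silu(z[i]) * y[i];   // what the swiglu-with-split-gate form evaluates to
        printf("i=%d old=%f fused=%f\n", i, old_form, fused);
    }
    return 0;
}

The fused form saves the intermediate silu tensor in the graph while producing the same values.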
