Commit b4c169f

Initial commit with all but the MLA graph code done
1 parent 833e2b7 commit b4c169f

15 files changed (+201 lines, -24 lines)

common/arg.cpp

Lines changed: 7 additions & 0 deletions
@@ -1346,6 +1346,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.flash_attn = true;
         }
     ).set_env("LLAMA_ARG_FLASH_ATTN"));
+    add_opt(common_arg(
+        {"-mla", "--mla-attn"},
+        string_format("enable Multi-head Latent Attention (default: %s)", params.mla_attn ? "enabled" : "disabled"),
+        [](common_params & params) {
+            params.mla_attn = true;
+        }
+    ).set_env("LLAMA_ARG_MLA_ATTN"));
     add_opt(common_arg(
         {"-p", "--prompt"}, "PROMPT",
         "prompt to start generation with; for system message, use -sys",

common/common.cpp

Lines changed: 1 addition & 0 deletions
@@ -1098,6 +1098,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.cb_eval_user_data = params.cb_eval_user_data;
     cparams.offload_kqv       = !params.no_kv_offload;
     cparams.flash_attn        = params.flash_attn;
+    cparams.mla_attn          = params.mla_attn;
     cparams.no_perf           = params.no_perf;

     if (params.reranking) {

common/common.h

Lines changed: 1 addition & 0 deletions
@@ -319,6 +319,7 @@ struct common_params {
     bool simple_io     = false; // improves compatibility with subprocesses and limited consoles
     bool cont_batching = true;  // insert new sequences for decoding on-the-fly
     bool flash_attn    = false; // flash attention
+    bool mla_attn      = false; // MLA attention for deepseek2
     bool no_perf       = false; // disable performance metrics
     bool ctx_shift     = true;  // context shift on inifinite text generation

convert_hf_to_gguf.py

Lines changed: 22 additions & 0 deletions
@@ -330,6 +330,7 @@ def prepare_tensors(self):
                     gguf.MODEL_TENSOR.TIME_MIX_LERP_FUSED,
                     gguf.MODEL_TENSOR.POSNET_NORM1,
                     gguf.MODEL_TENSOR.POSNET_NORM2,
+                    gguf.MODEL_TENSOR.ATTN_K_B,
                 )
             )
             or not new_name.endswith(".weight")
@@ -4414,6 +4415,27 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             return []

         return [(self.map_tensor_name(name), data_torch)]
+        if name.endswith("kv_b_proj.weight"):
+            name_kb = name.replace("kv_b_proj", "k_b_proj")
+            name_vb = name.replace("kv_b_proj", "v_b_proj")
+
+            n_head_kv = self.hparams["num_key_value_heads"]
+            v_head_dim = self.hparams["v_head_dim"]
+            qk_nope_head_dim = self.hparams["qk_nope_head_dim"]
+
+            assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim)
+
+            kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1])
+            k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
+            k_b = k_b.transpose(1, 2)
+            k_b = k_b.reshape(n_head_kv * data_torch.shape[-1], qk_nope_head_dim)
+            v_b = v_b.reshape(n_head_kv * v_head_dim, data_torch.shape[-1])
+
+            return [
+                (self.map_tensor_name(name), data_torch),
+                (self.map_tensor_name(name_kb), k_b),
+                (self.map_tensor_name(name_vb), v_b)
+            ]

     def prepare_tensors(self):
         super().prepare_tensors()
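
The kv_b_proj handling above splits DeepSeek2's combined KV up-projection into separate k_b and v_b tensors so the MLA path can apply the K and V up-projections independently of the cached latent. The following is a minimal standalone sketch of the same reshaping; the dimensions are DeepSeek-V2-like assumptions for illustration, not values read from the commit:

import torch

# assumed hyper-parameters (DeepSeek-V2-like; not taken from the commit)
n_head_kv = 128
qk_nope_head_dim = 128
v_head_dim = 128
kv_lora_rank = 512  # last dim of kv_b_proj.weight (width of the compressed latent)

# stand-in for kv_b_proj.weight, with the row layout the converter asserts on
kv_b_proj = torch.randn(n_head_kv * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

# same steps as the diff: view per head, split into the K-nope and V halves,
# then transpose/flatten k_b and flatten v_b into 2-D tensors for GGUF
kv_b = kv_b_proj.view(n_head_kv, qk_nope_head_dim + v_head_dim, kv_lora_rank)
k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
k_b = k_b.transpose(1, 2).reshape(n_head_kv * kv_lora_rank, qk_nope_head_dim)
v_b = v_b.reshape(n_head_kv * v_head_dim, kv_lora_rank)

print(k_b.shape)  # torch.Size([65536, 128]) = (n_head_kv * kv_lora_rank, qk_nope_head_dim)
print(v_b.shape)  # torch.Size([16384, 512]) = (n_head_kv * v_head_dim, kv_lora_rank)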

examples/server/README.md

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@ The project is under active development, and we are [looking for feedback and co
 | `-ub, --ubatch-size N` | physical maximum batch size (default: 512)<br/>(env: LLAMA_ARG_UBATCH) |
 | `--keep N` | number of tokens to keep from the initial prompt (default: 0, -1 = all) |
 | `-fa, --flash-attn` | enable Flash Attention (default: disabled)<br/>(env: LLAMA_ARG_FLASH_ATTN) |
+| `-mla, --mla-attn` | enable Multi-head Latent Attention (default: disabled)<br/>(env: LLAMA_ARG_MLA_ATTN) |
 | `--no-perf` | disable internal libllama performance timings (default: false)<br/>(env: LLAMA_ARG_NO_PERF) |
 | `-e, --escape` | process escapes sequences (\n, \r, \t, \', \", \\) (default: true) |
 | `--no-escape` | do not process escape sequences |

gguf-py/gguf/constants.py

Lines changed: 6 additions & 0 deletions
@@ -377,6 +377,8 @@ class MODEL_TENSOR(IntEnum):
     ATTN_Q_B       = auto()
     ATTN_KV_A_MQA  = auto()
     ATTN_KV_B      = auto()
+    ATTN_K_B       = auto()
+    ATTN_V_B       = auto()
     ATTN_Q_A_NORM  = auto()
     ATTN_KV_A_NORM = auto()
     FFN_SUB_NORM   = auto()
@@ -581,6 +583,8 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.ATTN_Q_B:       "blk.{bid}.attn_q_b",
     MODEL_TENSOR.ATTN_KV_A_MQA:  "blk.{bid}.attn_kv_a_mqa",
     MODEL_TENSOR.ATTN_KV_B:      "blk.{bid}.attn_kv_b",
+    MODEL_TENSOR.ATTN_K_B:       "blk.{bid}.attn_k_b",
+    MODEL_TENSOR.ATTN_V_B:       "blk.{bid}.attn_v_b",
     MODEL_TENSOR.ATTN_Q_A_NORM:  "blk.{bid}.attn_q_a_norm",
     MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
     MODEL_TENSOR.ATTN_SUB_NORM:  "blk.{bid}.attn_sub_norm",
@@ -1451,6 +1455,8 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.ATTN_Q_B,
         MODEL_TENSOR.ATTN_KV_A_MQA,
         MODEL_TENSOR.ATTN_KV_B,
+        MODEL_TENSOR.ATTN_K_B,
+        MODEL_TENSOR.ATTN_V_B,
         MODEL_TENSOR.ATTN_Q_A_NORM,
         MODEL_TENSOR.ATTN_KV_A_NORM,
         MODEL_TENSOR.ATTN_OUT,
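
With these entries registered, a DeepSeek2 GGUF converted by the updated script should gain per-block blk.{bid}.attn_k_b and blk.{bid}.attn_v_b tensors alongside the existing blk.{bid}.attn_kv_b. A hedged sanity check using gguf-py's GGUFReader, assuming such a converted file exists (the path is hypothetical):

from gguf import GGUFReader

# hypothetical path to a DeepSeek2 model converted with the updated script
reader = GGUFReader("deepseek2-mla-f16.gguf")

# collect the MLA-specific split projection tensors, if the converter emitted them
mla_names = [t.name for t in reader.tensors
             if ".attn_k_b" in t.name or ".attn_v_b" in t.name]
print(len(mla_names), mla_names[:4])
# expected form: 'blk.0.attn_k_b.weight', 'blk.0.attn_v_b.weight', ...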

gguf-py/gguf/tensor_mapping.py

Lines changed: 8 additions & 0 deletions
@@ -656,6 +656,14 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.kv_b_proj", # deepseek2
         ),

+        MODEL_TENSOR.ATTN_K_B: (
+            "model.layers.{bid}.self_attn.k_b_proj", # deepseek2 (MLA specific)
+        ),
+
+        MODEL_TENSOR.ATTN_V_B: (
+            "model.layers.{bid}.self_attn.v_b_proj", # deepseek2 (MLA specific)
+        ),
+
         MODEL_TENSOR.ATTN_Q_A_NORM: (
             "model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2
         ),
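
These mappings are what let the converter resolve the synthetic k_b_proj / v_b_proj names produced by the kv_b_proj split. A small sketch of how the lookup would behave once this change is in gguf-py (the block count of 1 is arbitrary):

import gguf

# build the HF-name -> GGUF-name map for deepseek2 with a single block (arbitrary)
tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.DEEPSEEK2, 1)

# with the new ATTN_K_B / ATTN_V_B entries, the split projections resolve:
print(tmap.get_name("model.layers.0.self_attn.k_b_proj.weight",
                    try_suffixes=(".weight", ".bias")))  # -> blk.0.attn_k_b.weight
print(tmap.get_name("model.layers.0.self_attn.v_b_proj.weight",
                    try_suffixes=(".weight", ".bias")))  # -> blk.0.attn_v_b.weight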

include/llama.h

Lines changed: 1 addition & 0 deletions
@@ -355,6 +355,7 @@ extern "C" {
         bool embeddings;  // if true, extract embeddings (together with logits)
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
         bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]
+        bool mla_attn;    // MLA attention for deepseek2
         bool no_perf;     // whether to measure performance timings

         // Abort callback

src/llama-arch.cpp

Lines changed: 4 additions & 17 deletions
@@ -1030,6 +1030,8 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
     { LLM_TENSOR_ATTN_Q_B,      "blk.%d.attn_q_b" },
     { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
     { LLM_TENSOR_ATTN_KV_B,     "blk.%d.attn_kv_b" },
+    { LLM_TENSOR_ATTN_K_B,      "blk.%d.attn_k_b" },
+    { LLM_TENSOR_ATTN_V_B,      "blk.%d.attn_v_b" },
     { LLM_TENSOR_ATTN_OUT,      "blk.%d.attn_output" },
     { LLM_TENSOR_FFN_NORM,      "blk.%d.ffn_norm" },
     { LLM_TENSOR_FFN_GATE,      "blk.%d.ffn_gate" },
@@ -1471,23 +1473,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_ATTN_Q_B,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_KV_A_MQA,  {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_KV_B,      {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_DEC_ATTN_Q,     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_DEC_ATTN_K,     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_Q,         {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_K,         {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_V,         {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_QKV,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_OUT,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_GATE,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_DOWN,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_UP,         {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_UP_SHEXP,   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_Q_A,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_Q_B,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_KV_A_MQA,  {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_KV_B,      {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_K_B,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_V_B,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_DEC_ATTN_Q,     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_DEC_ATTN_K,     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_DEC_ATTN_V,     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},

src/llama-arch.h

Lines changed: 2 additions & 0 deletions
@@ -299,6 +299,8 @@ enum llm_tensor {
     LLM_TENSOR_ATTN_Q_B,
     LLM_TENSOR_ATTN_KV_A_MQA,
     LLM_TENSOR_ATTN_KV_B,
+    LLM_TENSOR_ATTN_K_B,
+    LLM_TENSOR_ATTN_V_B,
     LLM_TENSOR_ATTN_Q_A_NORM,
     LLM_TENSOR_ATTN_KV_A_NORM,
     LLM_TENSOR_ATTN_SUB_NORM,
