Commit ca927c7

[Model] Support weight absorption for DeepSeek-v2 (#3115)
This PR adds weight absorption support for the DeepSeek-v2/v3 models. At the moment, absorption is always enabled, even though the first prefill is usually more efficient without it; the logic that switches between absorbed and normal computation is still in progress.
1 parent 2cd8da8 commit ca927c7
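
For readers new to the trick: in DeepSeek's multi-head latent attention (MLA), the per-head keys and values are up-projections of a shared low-rank latent that the KV cache stores, so the key up-projection can be folded ("absorbed") into the query and the value up-projection into the output side; queries then attend directly over the cached latent. A minimal NumPy sketch of the score equivalence (toy sizes and hypothetical weight names, not MLC-LLM code):

```python
# A minimal NumPy sketch of MLA weight absorption (illustrative only; the
# weight names below are hypothetical stand-ins, not MLC-LLM identifiers).
import numpy as np

n_heads, d_nope, rank, seq = 4, 8, 16, 5  # toy sizes; `rank` plays the role of kv_lora_rank
rng = np.random.default_rng(0)
c_kv = rng.standard_normal((seq, rank))              # compressed KV latent (what the cache stores)
q_nope = rng.standard_normal((n_heads, d_nope))      # per-head query (non-RoPE part), one step
w_uk = rng.standard_normal((n_heads, d_nope, rank))  # key up-projection: latent -> per-head key

# Normal path: decompress per-head keys, then take query-key dot products.
keys = np.einsum("hdr,sr->hsd", w_uk, c_kv)          # (heads, seq, d_nope)
scores_normal = np.einsum("hd,hsd->hs", q_nope, keys)

# Absorbed path: fold w_uk into the query and score against the latent directly.
q_absorbed = np.einsum("hd,hdr->hr", q_nope, w_uk)   # (heads, rank)
scores_absorbed = q_absorbed @ c_kv.T                # (heads, seq)

assert np.allclose(scores_normal, scores_absorbed)
```

Absorption avoids materializing per-head keys and values at every decode step, at the cost of a larger effective head dimension (kv_lora_rank + qk_rope_head_dim) in the attention kernel, which helps explain why the first prefill is often faster without it.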

File tree

31 files changed: +200 -114 lines
python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py

Lines changed: 106 additions & 54 deletions

@@ -1,59 +1,91 @@
 """A pass that rewrites KV cache creation functions in IRModule."""
 
 import json
-from typing import Any, Dict
+from typing import Any, Dict, Literal, Tuple
 
 import tvm
 from tvm import IRModule, relax
 from tvm.relax.frontend.nn.llm import kv_cache
 from tvm.relax.frontend.nn.llm.kv_cache import RopeMode
 
 
-def extract_creation_args(func: relax.Function) -> Dict[str, Any]:
+def extract_creation_args(func: relax.Function) -> Tuple[Literal["mha", "mla"], Dict[str, Any]]:
     """Extract the KV cache creation args from the given generic creation func."""
     assert isinstance(func.body, relax.SeqExpr)
     assert len(func.body.blocks) == 1
     assert isinstance(func.body.blocks[0], relax.DataflowBlock)
     assert isinstance(func.body.blocks[0].bindings[0], relax.VarBinding)
     assert isinstance(func.body.blocks[0].bindings[0].value, relax.Call)
     assert func.body.blocks[0].bindings[0].value.op == tvm.ir.Op.get("relax.call_pure_packed")
-    args = func.body.blocks[0].bindings[0].value.args
-    assert isinstance(args[0], relax.ExternFunc)
-    assert args[0].global_symbol == "mlc.create_paged_kv_cache_generic"
-
-    assert len(args) == 15
-    assert isinstance(args[1], relax.ShapeExpr)
-    assert len(args[1].values) == 5
-    assert isinstance(args[2], relax.ShapeExpr)
-    for i in range(3, 14):
-        if i in [10, 11]:
-            continue
-        assert isinstance(args[i], relax.PrimValue)
-        assert isinstance(args[i].value, (tvm.tir.IntImm, tvm.tir.FloatImm))
-    assert isinstance(args[10], relax.StringImm)
-    assert isinstance(args[11], (relax.Constant, relax.PrimValue))
-    assert isinstance(args[14], relax.DataTypeImm)
-
-    return {
-        "max_batch_size": args[1].values[0],
-        "max_total_seq_len": args[1].values[1],
-        "prefill_chunk_size": args[1].values[2],
-        "page_size": args[1].values[3],
-        "support_sliding_window": args[1].values[4],
-        "layer_partition": args[2],
-        "num_hidden_layers": args[3].value.value,
-        "num_attention_heads": args[4].value.value,
-        "num_key_value_heads": args[5].value.value,
-        "head_dim": args[6].value.value,
-        "rope_mode": args[7].value.value,
-        "rope_scale": args[8].value.value,
-        "rope_theta": args[9].value.value,
-        "rope_scaling": json.loads(args[10].value),
-        "rope_ext_factors": args[11],
-        "rotary_dim": args[12].value.value,
-        "enable_disaggregation": bool(args[13].value.value),
-        "dtype": args[14].value,
-    }
+    call_args = func.body.blocks[0].bindings[0].value.args
+    assert isinstance(call_args[0], relax.ExternFunc)
+    assert call_args[0].global_symbol == "mlc.create_paged_kv_cache_generic"
+    assert isinstance(call_args[1], relax.StringImm)
+
+    args = call_args[1:]
+    if args[0].value == "mha":
+        assert len(args) == 15
+        assert isinstance(args[1], relax.ShapeExpr)
+        assert len(args[1].values) == 5
+        assert isinstance(args[2], relax.ShapeExpr)
+        for i in range(3, 14):
+            if i in [10, 11]:
+                continue
+            assert isinstance(args[i], relax.PrimValue)
+            assert isinstance(args[i].value, (tvm.tir.IntImm, tvm.tir.FloatImm))
+        assert isinstance(args[10], relax.StringImm)
+        assert isinstance(args[11], (relax.Constant, relax.PrimValue))
+        assert isinstance(args[14], relax.DataTypeImm)
+
+        return "mha", {
+            "max_batch_size": args[1].values[0],
+            "max_total_seq_len": args[1].values[1],
+            "prefill_chunk_size": args[1].values[2],
+            "page_size": args[1].values[3],
+            "support_sliding_window": args[1].values[4],
+            "layer_partition": args[2],
+            "num_hidden_layers": args[3].value.value,
+            "num_attention_heads": args[4].value.value,
+            "num_key_value_heads": args[5].value.value,
+            "head_dim": args[6].value.value,
+            "rope_mode": args[7].value.value,
+            "rope_scale": args[8].value.value,
+            "rope_theta": args[9].value.value,
+            "rope_scaling": json.loads(args[10].value),
+            "rope_ext_factors": args[11],
+            "rotary_dim": args[12].value.value,
+            "enable_disaggregation": bool(args[13].value.value),
+            "dtype": args[14].value,
+        }
+    if call_args[1].value == "mla":
+        assert len(args) == 12
+        assert isinstance(args[1], relax.ShapeExpr)
+        assert len(args[1].values) == 5
+        assert isinstance(args[2], relax.ShapeExpr)
+        for i in range(3, 11):
+            assert isinstance(args[i], relax.PrimValue)
+            assert isinstance(args[i].value, tvm.tir.IntImm)
+        assert isinstance(args[11], relax.DataTypeImm)
+
+        return "mla", {
+            "max_batch_size": args[1].values[0],
+            "max_total_seq_len": args[1].values[1],
+            "prefill_chunk_size": args[1].values[2],
+            "page_size": args[1].values[3],
+            "support_sliding_window": args[1].values[4],
+            "layer_partition": args[2],
+            "num_hidden_layers": args[3].value.value,
+            "num_attention_heads": args[4].value.value,
+            "num_key_value_heads": args[5].value.value,
+            "qk_nope_head_dim": args[6].value.value,
+            "qk_rope_head_dim": args[7].value.value,
+            "v_head_dim": args[8].value.value,
+            "kv_lora_rank": args[9].value.value,
+            "enable_disaggregation": bool(args[10].value.value),
+            "dtype": args[11].value,
+        }
+
+    raise ValueError("Cannot reach here")
 
 
 @tvm.transform.module_pass(opt_level=0, name="DispatchKVCacheCreation")
@@ -100,24 +132,38 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
         if mod.attrs is not None:
             new_mod = new_mod.with_attrs(mod.attrs)
 
-        kwargs = extract_creation_args(creation_func)
-        self.attach_kv_cache_metadata(kwargs)
+        kv_cache_kind, kwargs = extract_creation_args(creation_func)
+        self.attach_kv_cache_metadata(kv_cache_kind, kwargs)
 
         bb = relax.BlockBuilder(new_mod)
-        self.create_tir_paged_kv_cache(bb, kwargs)
-        self.create_flashinfer_paged_kv_cache(bb, kwargs)
+        self.create_tir_paged_kv_cache(bb, kv_cache_kind, kwargs)
+        self.create_flashinfer_paged_kv_cache(bb, kv_cache_kind, kwargs)
         return bb.finalize()
 
-    def attach_kv_cache_metadata(self, kwargs: Dict[str, Any]):
+    def attach_kv_cache_metadata(
+        self, kv_cache_kind: Literal["mha", "mla"], kwargs: Dict[str, Any]
+    ):
         """Attach the KV cache metadata to model metadata."""
-        self.metadata["kv_cache"] = {
-            "num_hidden_layers": kwargs["num_hidden_layers"],
-            "num_attention_heads": kwargs["num_attention_heads"],
-            "num_key_value_heads": kwargs["num_key_value_heads"],
-            "head_dim": kwargs["head_dim"],
-        }
-
-    def create_tir_paged_kv_cache(self, bb: relax.BlockBuilder, kwargs: Dict[str, Any]) -> None:
+        if kv_cache_kind == "mha":
+            self.metadata["kv_cache"] = {
+                "num_hidden_layers": kwargs["num_hidden_layers"],
+                "num_attention_heads": kwargs["num_attention_heads"],
+                "num_key_value_heads": kwargs["num_key_value_heads"],
+                "head_dim": kwargs["head_dim"],
+            }
+        elif kv_cache_kind == "mla":
+            self.metadata["kv_cache"] = {
+                "num_hidden_layers": kwargs["num_hidden_layers"],
+                "num_attention_heads": kwargs["num_attention_heads"],
+                "num_key_value_heads": 1,
+                "head_dim": kwargs["kv_lora_rank"] + kwargs["qk_rope_head_dim"],
+            }
+        else:
+            raise ValueError("Cannot reach here.")
+
+    def create_tir_paged_kv_cache(
+        self, bb: relax.BlockBuilder, kv_cache_kind: Literal["mha", "mla"], kwargs: Dict[str, Any]
+    ) -> None:
         """Create the TIR-based PagedKVCache"""
         max_batch_size = relax.Var(
             "max_batch_size_", relax.ShapeStructInfo([kwargs["max_batch_size"]])
@@ -143,16 +189,22 @@ def create_tir_paged_kv_cache(self, bb: relax.BlockBuilder, kwargs: Dict[str, Any]) -> None:
                 support_sliding_window,
             ],
         ):
-            cache = kv_cache.TIRPagedKVCache(target=self.target, **kwargs)
+            if kv_cache_kind == "mha":
+                cache = kv_cache.TIRPagedKVCache(target=self.target, **kwargs)
+            elif kv_cache_kind == "mla":
+                cache = kv_cache.TIRPagedKVCache.create_mla_kv_cache(target=self.target, **kwargs)
+            else:
+                raise ValueError("Cannot reach here")
             bb.emit_func_output(cache._expr) # pylint: disable=protected-access
 
     def create_flashinfer_paged_kv_cache(
-        self, bb: relax.BlockBuilder, kwargs: Dict[str, Any]
+        self, bb: relax.BlockBuilder, kv_cache_kind: Literal["mha", "mla"], kwargs: Dict[str, Any]
     ) -> None:
         """Create the FlashInfer-based PagedKVCache"""
         # Filter the cases which FlashInfer does not support.
         if ( # pylint: disable=too-many-boolean-expressions
             not self.flashinfer
+            or kv_cache_kind != "mha"
             or str(kwargs["dtype"]) != "float16"
             or kwargs["head_dim"] != 128
             or (

python/mlc_llm/model/baichuan/baichuan_model.py

Lines changed: 1 addition & 1 deletion

@@ -277,7 +277,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
            max_batch_size=max_batch_size,
            max_total_seq_len=max_total_seq_len,
            prefill_chunk_size=prefill_chunk_size,

python/mlc_llm/model/chatglm3/chatglm3_model.py

Lines changed: 1 addition & 1 deletion

@@ -353,7 +353,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
            max_batch_size=max_batch_size,
            max_total_seq_len=max_total_seq_len,
            prefill_chunk_size=prefill_chunk_size,

python/mlc_llm/model/cohere/cohere_model.py

Lines changed: 1 addition & 1 deletion

@@ -322,7 +322,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
            max_batch_size=max_batch_size,
            max_total_seq_len=max_total_seq_len,
            prefill_chunk_size=prefill_chunk_size,

python/mlc_llm/model/deepseek/deepseek_model.py

Lines changed: 1 addition & 1 deletion

@@ -428,7 +428,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
            max_batch_size=max_batch_size,
            max_total_seq_len=max_total_seq_len,
            prefill_chunk_size=prefill_chunk_size,

python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py

Lines changed: 11 additions & 30 deletions

@@ -13,7 +13,7 @@
 from tvm.script import tir as T
 
 from mlc_llm import op as op_ext
-from mlc_llm.nn import PagedKVCache, RopeMode
+from mlc_llm.nn import PagedKVCache
 from mlc_llm.nn.expert import MixtralExperts
 from mlc_llm.support import logging
 from mlc_llm.support import tensor_parallel as tp
@@ -282,8 +282,6 @@ def forward_absorb(
         )  # (b, s, 1, kv_lora_rank), (b, s, 1, qk_rope_head_dim)
 
         compressed_kv = self.kv_a_layernorm(compressed_kv)
-        k_nope = compressed_kv  # (b, s, 1, kv_lora_rank)
-        value_states = compressed_kv  # (b, s, 1, kv_lora_rank)
 
         q_pe, k_pe = self.rotary_emb(q_pe, k_pe, query_positions)
 
@@ -303,28 +301,9 @@ def f_concat_nope_pe(var_nope: te.Tensor, var_pe: te.Tensor):
         query_states = op.tensor_expr_op(
             concat_nope_pe(num_heads=self.num_heads), "concat_q", [q_nope, q_pe]
         )  # (b, s, num_heads, kv_lora_rank + qk_rope_head_dim)
-        key_states = op.tensor_expr_op(
-            concat_nope_pe(num_heads=1), "concat_k", [k_nope, k_pe]
-        )  # (b, s, 1, kv_lora_rank + qk_rope_head_dim)
-        value_states = op.pad(
-            value_states, [0, 0, 0, 0, 0, 0, 0, self.qk_rope_head_dim]
-        )  # (b, s, 1, kv_lora_rank + qk_rope_head_dim)
 
-        qkv = op.concat(
-            [query_states, key_states, value_states], dim=2
-        )  # (b, s, num_heads + 2, kv_lora_rank + qk_rope_head_dim)
-        output, _ = op.split(
-            paged_kv_cache.attention_with_fused_qkv(
-                layer_id,
-                qkv,
-                self.num_heads,
-                self.softmax_scale
-                * math.sqrt(
-                    self.kv_lora_rank + self.qk_rope_head_dim
-                ),  # This is to cancel out the 1/sqrt(d) in normal attention
-            ),
-            indices_or_sections=[self.kv_lora_rank],
-            axis=-1,
+        output = paged_kv_cache.mla_absorbed(
+            layer_id, query_states, compressed_kv, k_pe, self.softmax_scale
         )  # (b, s, num_heads, kv_lora_rank)
         output = (
             op.matmul(
@@ -645,7 +624,9 @@ def __init__(self, config: DeepseekV2Config):
         self.num_attention_heads = config.num_attention_heads
         self.num_key_value_heads = config.num_key_value_heads
         self.kv_lora_rank = config.kv_lora_rank
+        self.qk_nope_head_dim = config.qk_nope_head_dim
         self.qk_rope_head_dim = config.qk_rope_head_dim
+        self.v_head_dim = config.v_head_dim
         self.rms_norm_eps = config.rms_norm_eps
         self.rope_theta = config.rope_theta
         self.vocab_size = config.vocab_size
@@ -724,19 +705,19 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mla(
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
             prefill_chunk_size=prefill_chunk_size,
             page_size=page_size,
             support_sliding_window=support_sliding_window,
             num_hidden_layers=self.num_hidden_layers,
             num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,
-            num_key_value_heads=1,
-            head_dim=self.kv_lora_rank + self.qk_rope_head_dim,
-            rope_mode=RopeMode.NONE,
-            rope_scale=1,
-            rope_theta=self.rope_theta,
+            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,
+            qk_nope_head_dim=self.qk_nope_head_dim,
+            qk_rope_head_dim=self.qk_rope_head_dim,
+            v_head_dim=self.v_head_dim,
+            kv_lora_rank=self.kv_lora_rank,
             dtype=self.dtype,
         )
 
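For context on the `mla_absorbed` call introduced above: judging from the path it replaces (pad the latent to key width, run fused-QKV attention, then split the result at kv_lora_rank), its reference semantics are roughly the sketch below. This is an inference from the removed code, not the actual paged-cache kernel; batching, the causal mask, and head-broadcasting details are omitted.

```python
# Reference semantics (NumPy) inferred from the removed concat/pad/split path;
# not the actual PagedKVCache kernel. Causal masking and batching are omitted.
import numpy as np

def mla_absorbed_reference(q, c_kv, k_pe, softmax_scale):
    """q: (seq, heads, rank + d_rope); c_kv: (seq, rank); k_pe: (seq, d_rope)."""
    rank = c_kv.shape[-1]
    keys = np.concatenate([c_kv, k_pe], axis=-1)             # keys live in latent + RoPE space
    scores = np.einsum("qhd,kd->hqk", q, keys) * softmax_scale
    scores -= scores.max(axis=-1, keepdims=True)              # numerical stabilization
    probs = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
    # Values are the cached latent itself, i.e. the first `rank` dims of the keys.
    return np.einsum("hqk,kr->qhr", probs, c_kv)              # (seq, heads, rank)
```

Note that the removed path rescaled softmax_scale by sqrt(kv_lora_rank + qk_rope_head_dim) only to cancel the fused kernel's built-in 1/sqrt(d), whereas the new entry point takes softmax_scale directly. The output stays in the kv_lora_rank space until the trailing matmul, mirroring how the value up-projection is absorbed into the output side.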

python/mlc_llm/model/eagle/eagle_model.py

Lines changed: 1 addition & 1 deletion

@@ -164,7 +164,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
            max_batch_size=max_batch_size,
            max_total_seq_len=max_total_seq_len,
            prefill_chunk_size=prefill_chunk_size,

python/mlc_llm/model/gemma/gemma_model.py

Lines changed: 1 addition & 1 deletion

@@ -306,7 +306,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
            max_batch_size=max_batch_size,
            max_total_seq_len=max_total_seq_len,
            prefill_chunk_size=prefill_chunk_size,

python/mlc_llm/model/gpt2/gpt2_model.py

Lines changed: 1 addition & 1 deletion

@@ -297,7 +297,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
            max_batch_size=max_batch_size,
            max_total_seq_len=max_total_seq_len,
            prefill_chunk_size=prefill_chunk_size,

python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py

Lines changed: 1 addition & 1 deletion

@@ -264,7 +264,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
         page_size: tir.Var,
         support_sliding_window: tir.Var,
     ) -> PagedKVCache:
-        return PagedKVCache.create_generic(
+        return PagedKVCache.create_generic_mha(
            max_batch_size=max_batch_size,
            max_total_seq_len=max_total_seq_len,
            prefill_chunk_size=prefill_chunk_size,
