
Commit f9da13f

[Model] Initial weight absorption impl for Deepseek-v2/v3 (#3092)
This PR introduces the initial weight-absorption implementation for the DeepSeek-V2/V3 models. It will be further enhanced with new KV cache attention interfaces for MLA.
1 parent 00f5303 commit f9da13f
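
For context on what weight absorption buys here: MLA's kv_b_proj decompresses the per-token latent into full K/V heads, and because that decompression is linear, its K half (w_uk) can be folded into the query once per step while its V half (w_uv) is applied once to the attention-weighted latents, letting attention run directly at width kv_lora_rank + qk_rope_head_dim. A minimal NumPy sketch of the identity for a single head and a single cached token (sizes are illustrative, not a specific DeepSeek config):

import numpy as np

kv_lora_rank, qk_nope_head_dim, v_head_dim = 512, 128, 128  # illustrative sizes only

w_uk = np.random.randn(qk_nope_head_dim, kv_lora_rank)  # K slice of kv_b_proj for one head
w_uv = np.random.randn(v_head_dim, kv_lora_rank)        # V slice of kv_b_proj for one head
q_nope = np.random.randn(qk_nope_head_dim)              # no-RoPE query part for one head
c_kv = np.random.randn(kv_lora_rank)                    # compressed KV latent for one token

# Normal path: decompress K, then dot with the query.
score_normal = q_nope @ (w_uk @ c_kv)
# Absorbed path: fold w_uk into the query, then score against the latent directly.
score_absorbed = (w_uk.T @ q_nope) @ c_kv
assert np.isclose(score_normal, score_absorbed)

# Same trick on the value side: apply w_uv after the attention-weighted sum of latents.
attn_weight = 0.7  # stand-in attention weight for the single token
out_normal = attn_weight * (w_uv @ c_kv)
out_absorbed = w_uv @ (attn_weight * c_kv)
assert np.allclose(out_normal, out_absorbed)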

File tree

3 files changed: +154, -3 lines changed

python/mlc_llm/model/deepseek_v2/deepseek_v2_loader.py

Lines changed: 37 additions & 0 deletions
@@ -117,6 +117,43 @@ def combine_expert_gate_up(*hf_params, dtype):
             ),
         )

+        # map MLA kv_b_proj weight
+        attn = f"model.layers.{i}.self_attn"
+        mapping.add_mapping(
+            f"{attn}.w_uk",
+            [f"{attn}.kv_b_proj.weight"],
+            functools.partial(
+                lambda kv_b_proj, dtype: np.split(
+                    kv_b_proj.reshape(
+                        model_config.num_key_value_heads,
+                        model_config.qk_nope_head_dim + model_config.v_head_dim,
+                        model_config.kv_lora_rank,
+                    ),
+                    indices_or_sections=[model_config.qk_nope_head_dim],
+                    axis=1,
+                )[0]
+                .transpose(0, 2, 1)
+                .astype(dtype),
+                dtype=mlc_param.dtype,
+            ),
+        )
+        mapping.add_mapping(
+            f"{attn}.w_uv",
+            [f"{attn}.kv_b_proj.weight"],
+            functools.partial(
+                lambda kv_b_proj, dtype: np.split(
+                    kv_b_proj.reshape(
+                        model_config.num_key_value_heads,
+                        model_config.qk_nope_head_dim + model_config.v_head_dim,
+                        model_config.kv_lora_rank,
+                    ),
+                    indices_or_sections=[model_config.qk_nope_head_dim],
+                    axis=1,
+                )[1].astype(dtype),
+                dtype=mlc_param.dtype,
+            ),
+        )
+
     for mlc_name, mlc_param in named_parameters.items():
         if mlc_name not in mapping.param_map:
             mapping.add_mapping(
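
The two mappings above can be sanity-checked in plain NumPy; the dimensions below are placeholders (the loader reads the real values from model_config), but the reshape/split/transpose sequence mirrors the code:

import numpy as np

# Placeholder dimensions; the loader reads the real values from model_config.
num_key_value_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 16, 128, 128, 512

kv_b_proj_weight = np.zeros(
    (num_key_value_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank), dtype="float16"
)

reshaped = kv_b_proj_weight.reshape(
    num_key_value_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank
)
uk_part, uv_part = np.split(reshaped, indices_or_sections=[qk_nope_head_dim], axis=1)

w_uk = uk_part.transpose(0, 2, 1)  # matches the (num_heads, kv_lora_rank, qk_nope_head_dim) parameter
w_uv = uv_part                     # matches the (num_heads, v_head_dim, kv_lora_rank) parameter

assert w_uk.shape == (num_key_value_heads, kv_lora_rank, qk_nope_head_dim)
assert w_uv.shape == (num_key_value_heads, v_head_dim, kv_lora_rank)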

python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py

Lines changed: 114 additions & 3 deletions
@@ -220,6 +220,8 @@ def __init__(self, config: DeepseekV2Config):
             self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
             bias=False,
         )
+        self.w_uk = nn.Parameter((self.num_heads, config.kv_lora_rank, self.qk_nope_head_dim))
+        self.w_uv = nn.Parameter((self.num_heads, self.v_head_dim, config.kv_lora_rank))

         self.o_proj = nn.Linear(
             self.num_heads * self.v_head_dim,
@@ -241,6 +243,106 @@ def forward(
         paged_kv_cache: PagedKVCache,
         layer_id: int,
         query_positions: Tensor,
+    ):
+        return self.forward_absorb(hidden_states, paged_kv_cache, layer_id, query_positions)
+
+    def forward_absorb(
+        self,
+        hidden_states: Tensor,
+        paged_kv_cache: PagedKVCache,
+        layer_id: int,
+        query_positions: Tensor,
+    ):
+        b, s, _ = hidden_states.shape
+
+        if self.q_lora_rank is None:
+            q = self.q_proj(hidden_states)
+        else:
+            q = self.q_b_proj(
+                self.q_a_layernorm(self.q_a_proj(hidden_states))
+            )  # (b, s, num_heads * q_head_dim)
+        q = op.reshape(q, (b, s, self.num_heads, self.q_head_dim))  # (b, s, num_heads, q_head_dim)
+        q_nope, q_pe = op.split(
+            q, [self.qk_nope_head_dim], axis=-1
+        )  # (b, s, num_heads, qk_nope_head_dim), (b, s, num_heads, qk_rope_head_dim)
+        q_nope = (
+            op.matmul(
+                q_nope.reshape(b * s, self.num_heads, self.qk_nope_head_dim).permute_dims(1, 0, 2),
+                self.w_uk.permute_dims(0, 2, 1),
+            )
+            .permute_dims(1, 0, 2)
+            .reshape(b, s, self.num_heads, self.kv_lora_rank)
+        )  # (b, s, num_heads, kv_lora_rank)
+
+        compressed_kv = self.kv_a_proj_with_mqa(hidden_states).reshape(
+            b, s, 1, self.kv_lora_rank + self.qk_rope_head_dim
+        )  # (b, s, 1, kv_lora_rank + qk_rope_head_dim)
+        compressed_kv, k_pe = op.split(
+            compressed_kv, [self.config.kv_lora_rank], axis=-1
+        )  # (b, s, 1, kv_lora_rank), (b, s, 1, qk_rope_head_dim)
+
+        compressed_kv = self.kv_a_layernorm(compressed_kv)
+        k_nope = compressed_kv  # (b, s, 1, kv_lora_rank)
+        value_states = compressed_kv  # (b, s, 1, kv_lora_rank)
+
+        q_pe, k_pe = self.rotary_emb(q_pe, k_pe, query_positions)
+
+        def concat_nope_pe(num_heads: int):
+            def f_concat_nope_pe(var_nope: te.Tensor, var_pe: te.Tensor):
+                return te.compute(
+                    (b, s, num_heads, self.kv_lora_rank + self.qk_rope_head_dim),
+                    lambda _b, _s, _h, _d: te.if_then_else(
+                        _d < self.kv_lora_rank,
+                        var_nope[_b, _s, _h, _d],
+                        var_pe[_b, _s, _h, _d - self.kv_lora_rank],
+                    ),
+                )
+
+            return f_concat_nope_pe
+
+        query_states = op.tensor_expr_op(
+            concat_nope_pe(num_heads=self.num_heads), "concat_q", [q_nope, q_pe]
+        )  # (b, s, num_heads, kv_lora_rank + qk_rope_head_dim)
+        key_states = op.tensor_expr_op(
+            concat_nope_pe(num_heads=1), "concat_k", [k_nope, k_pe]
+        )  # (b, s, 1, kv_lora_rank + qk_rope_head_dim)
+        value_states = op.pad(
+            value_states, [0, 0, 0, 0, 0, 0, 0, self.qk_rope_head_dim]
+        )  # (b, s, 1, kv_lora_rank + qk_rope_head_dim)
+
+        qkv = op.concat(
+            [query_states, key_states, value_states], dim=2
+        )  # (b, s, num_heads + 2, kv_lora_rank + qk_rope_head_dim)
+        output, _ = op.split(
+            paged_kv_cache.attention_with_fused_qkv(
+                layer_id,
+                qkv,
+                self.num_heads,
+                self.softmax_scale
+                * math.sqrt(
+                    self.kv_lora_rank + self.qk_rope_head_dim
+                ),  # This is to cancel out the 1/sqrt(d) in normal attention
+            ),
+            indices_or_sections=[self.kv_lora_rank],
+            axis=-1,
+        )  # (b, s, num_heads, kv_lora_rank)
+        output = (
+            op.matmul(
+                output.reshape(b * s, self.num_heads, self.kv_lora_rank).permute_dims(1, 0, 2),
+                self.w_uv.permute_dims(0, 2, 1),
+            )
+            .permute_dims(1, 0, 2)
+            .reshape(b, s, self.num_heads * self.v_head_dim)
+        )
+
+        return self.o_proj(output)
+
+    def forward_normal(
+        self,
+        hidden_states: Tensor,
+        paged_kv_cache: PagedKVCache,
+        layer_id: int,
+        query_positions: Tensor,
     ):
         b, s, _ = hidden_states.shape

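One non-obvious detail in forward_absorb is the scale passed to attention_with_fused_qkv. Going by the inline comment, the paged KV cache kernel applies its own 1/sqrt(head_dim) with head_dim = kv_lora_rank + qk_rope_head_dim, so multiplying softmax_scale by sqrt(kv_lora_rank + qk_rope_head_dim) cancels that factor and leaves the MLA-specific scale. A small arithmetic check under that assumption (sizes are illustrative):

import math

kv_lora_rank, qk_rope_head_dim, q_head_dim = 512, 64, 192  # illustrative sizes only

softmax_scale = q_head_dim ** -0.5  # stand-in for self.softmax_scale
kernel_scale = 1 / math.sqrt(kv_lora_rank + qk_rope_head_dim)  # assumed kernel-side scaling

passed_scale = softmax_scale * math.sqrt(kv_lora_rank + qk_rope_head_dim)
assert math.isclose(passed_scale * kernel_scale, softmax_scale)
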
@@ -450,6 +552,14 @@ def _set(layer, hint):
            self.self_attn.kv_b_proj.weight,
            tp.ShardSingleDim("_shard_kv_b_weight", dim=0),
        )
+        _set(
+            self.self_attn.w_uk,
+            tp.ShardSingleDim("_shard_kv_b_weight_w_uk", dim=0),
+        )
+        _set(
+            self.self_attn.w_uv,
+            tp.ShardSingleDim("_shard_kv_b_weight_w_uv", dim=0),
+        )
        _set(self.self_attn.o_proj.weight, tp.ShardSingleDim("_shard_o", dim=1))

        if isinstance(self.mlp, DeepseekV2MoE):
@@ -517,7 +627,6 @@ def __init__(self, config: DeepseekV2Config):

     def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache):
         hidden_states = inputs
-        print(f"inputs.shape = {inputs.shape}")
         query_positions = paged_kv_cache.get_query_positions(inputs.shape[0] * inputs.shape[1])
         for layer_id, layer in enumerate(self.layers):
             hidden_states = layer(hidden_states, paged_kv_cache, layer_id, query_positions)
@@ -535,6 +644,8 @@ def __init__(self, config: DeepseekV2Config):
         self.intermediate_size = config.intermediate_size
         self.num_attention_heads = config.num_attention_heads
         self.num_key_value_heads = config.num_key_value_heads
+        self.kv_lora_rank = config.kv_lora_rank
+        self.qk_rope_head_dim = config.qk_rope_head_dim
         self.rms_norm_eps = config.rms_norm_eps
         self.rope_theta = config.rope_theta
         self.vocab_size = config.vocab_size
@@ -621,8 +732,8 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
             support_sliding_window=support_sliding_window,
             num_hidden_layers=self.num_hidden_layers,
             num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,
-            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,
-            head_dim=256,
+            num_key_value_heads=1,
+            head_dim=self.kv_lora_rank + self.qk_rope_head_dim,
             rope_mode=RopeMode.NONE,
             rope_scale=1,
             rope_theta=self.rope_theta,
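
The cache-shape change is where absorption pays off: every head now attends over the same per-token latent, so the cache holds one KV "head" of width kv_lora_rank + qk_rope_head_dim instead of per-head K/V at head_dim 256. A rough per-token element count, with a placeholder head count (real values come from the model config):

# Elements cached per token per layer, counting both the K and V planes.
num_key_value_heads = 16  # placeholder; not taken from a specific DeepSeek config
kv_lora_rank, qk_rope_head_dim = 512, 64

before = 2 * num_key_value_heads * 256             # previous layout: per-head K/V, head_dim=256
after = 2 * 1 * (kv_lora_rank + qk_rope_head_dim)  # new layout: one shared latent "head"
print(before, after)  # 8192 vs 1152 with these placeholder numbers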

tests/python/integration/test_model_compile.py

Lines changed: 3 additions & 0 deletions
@@ -114,6 +114,9 @@ def test_model_compile():  # pylint: disable=too-many-locals
        if not target.startswith("cuda") and quant == "q4f16_ft":
            # FasterTransformer only works with cuda
            continue
+        if "deepseek_v2" in model and "32" in quant:
+            # Skip f32 for deepseek v2 model for now.
+            continue
        log_file = os.path.join(tmp_dir, f"lib{idx}.log")
        cmd = [
            sys.executable,
