"""Inference-only DeepseekV2/DeepseekV3 model."""
import os
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union

import torch
import torch.distributed as dist
from torch import nn
from transformers import PretrainedConfig
-from vllm.attention import Attention
+from vllm.attention import Attention, AttentionMetadata
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
                         get_current_vllm_config)
from vllm.distributed import (get_dp_group, get_pp_group,
@@ ... @@
from vllm.sequence import IntermediateTensors

from vllm_ascend.ops.fused_moe import AscendFusedMoE
-from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE


class CustomDeepseekV2MoE(nn.Module):
@@ -133,7 +132,7 @@ def __init__(
        vllm_config = get_current_vllm_config()
        self.dp_size = get_dp_group().world_size
        batch_size = vllm_config.scheduler_config.max_num_seqs
-        self.enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", 0)) == 1
+        self.enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", '0')) == 1

        params_dtype = torch.get_default_dtype()
        self.final_hidden_states = torch.zeros(
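
# --- Aside (illustrative, not part of the diff): why the default changed
# from 0 to '0'. os.environ only ever holds strings, so a string default
# keeps the int() parse uniform; the old int default also worked, since
# int(0) == 0, but mixing types here is easy to get wrong. A minimal,
# runnable check, independent of vLLM:
import os

os.environ.pop("VLLM_ENABLE_MC2", None)              # unset -> default '0' -> off
assert int(os.environ.get("VLLM_ENABLE_MC2", '0')) == 0
os.environ["VLLM_ENABLE_MC2"] = "1"                  # set -> on
assert int(os.environ.get("VLLM_ENABLE_MC2", '0')) == 1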
@@ -309,38 +308,36 @@ def __init__(
        self.prefix = prefix
        self.debug_layer_idx = int(self.prefix.split(".")[-2])
-        if VLLM_ENABLE_GRAPH_MODE == "1":
-            self.forward = self.forward_torchair
-        else:
-            self.forward = self.forward_eager  # type: ignore
+        self.enable_graph_mode = False
+        additional_config = get_current_vllm_config().additional_config
+        if additional_config:
+            self.enable_graph_mode = additional_config.get(
+                "enable_graph_mode", False)

-    def forward_torchair(self,
-                         positions: torch.Tensor,
-                         hidden_states: torch.Tensor,
-                         kv_cache: torch.Tensor = None,
-                         attn_metadata=None):
+    def forward(
+            self,
+            positions: torch.Tensor,
+            hidden_states: torch.Tensor,
+            kv_cache: Optional[torch.Tensor] = None,
+            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
        if self.q_lora_rank is not None:
            ckq = self.q_a_proj(hidden_states)[0]
            hidden_states_or_q_c = self.q_a_layernorm(ckq)
        else:
            hidden_states_or_q_c = hidden_states
-        return self.mla_attn(hidden_states_or_q_c, hidden_states, None,
-                             kv_cache, attn_metadata)
-
-    def forward_eager(self, positions: torch.Tensor,
-                      hidden_states: torch.Tensor):
-        if self.q_lora_rank is not None:
-            ckq = self.q_a_proj(hidden_states)[0]
-            hidden_states_or_q_c = self.q_a_layernorm(ckq)
+        if self.enable_graph_mode:
+            return self.mla_attn.impl.forward(self.mla_attn,
+                                              hidden_states_or_q_c,
+                                              hidden_states, None, kv_cache,
+                                              attn_metadata)
        else:
-            hidden_states_or_q_c = hidden_states
-        kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
-            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
-        return self.mla_attn(hidden_states_or_q_c,
-                             kv_c_normed,
-                             k_pe,
-                             output_shape=hidden_states.shape)
+            kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
+                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+            kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+            return self.mla_attn(hidden_states_or_q_c,
+                                 kv_c_normed,
+                                 k_pe,
+                                 output_shape=hidden_states.shape)
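
# --- Usage aside (hypothetical, not part of the diff): with this change the
# torchair graph path is selected via vLLM's additional_config rather than
# the VLLM_ENABLE_GRAPH_MODE env var. The key name comes from the diff;
# whether LLM(...) forwards additional_config to the engine config is an
# assumption based on standard vLLM plumbing:
from vllm import LLM

llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite",          # any DeepseekV2 checkpoint
          additional_config={"enable_graph_mode": True})  # key read in __init__ above
# In graph mode the layer calls self.mla_attn.impl.forward(self.mla_attn, ...)
# directly, skipping the Attention wrapper's dispatch during graph capture;
# eager mode keeps the regular MLA entry point.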


class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
@@ -408,6 +405,54 @@ def __init__(
                                                eps=config.rms_norm_eps)
        self.routed_scaling_factor = config.routed_scaling_factor

+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        kv_cache: Optional[torch.Tensor] = None,
+        attn_metadata: Optional[AttentionMetadata] = None,
+    ) -> torch.Tensor:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+        )
+
+        if hidden_states.dtype == torch.float16:
+            # Fix FP16 overflow
+            # We scale both hidden_states and residual before rmsnorm;
+            # the rmsnorm result is not affected by the scale.
+            hidden_states *= 1. / self.routed_scaling_factor
+            if self.layer_idx == 0:
+                # The residual is shared by all layers, so we only scale it
+                # on the first layer.
+                residual *= 1. / self.routed_scaling_factor
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+
+        if isinstance(self.mlp,
+                      DeepseekV2MLP) and hidden_states.dtype == torch.float16:
+            # Fix FP16 overflow
+            # Scale the DeepseekV2MLP output, which is the input to the
+            # input_layernorm of the next decoder layer.
+            # The scaling of the DeepseekV2MoE output is done in the forward
+            # of DeepseekV2MoE.
+            hidden_states *= 1. / self.routed_scaling_factor
+
+        return hidden_states, residual
+
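
# --- Aside (not part of the diff): the FP16 branch above relies on RMSNorm
# being invariant to a uniform rescaling of its input, so dividing
# hidden_states (and, once, the shared residual) by routed_scaling_factor
# tames FP16 overflow without changing the normalized activations.
# A minimal runnable check of that property (plain RMSNorm, no learned weight):
import torch

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

x = torch.randn(4, 8)
scale = 2.5  # stands in for routed_scaling_factor
assert torch.allclose(rms_norm(x), rms_norm(x / scale), atol=1e-5)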


class CustomDeepseekV2Model(nn.Module):

@@ -459,7 +504,9 @@ def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors],
+        kv_caches: Optional[List[torch.Tensor]] = None,
+        attn_metadata: Optional[AttentionMetadata] = None,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
@@ -473,8 +520,13 @@ def forward(
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

-        for layer in self.layers[self.start_layer:self.end_layer]:
-            hidden_states, residual = layer(positions, hidden_states, residual)
+        for i in range(self.start_layer, self.end_layer):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions, hidden_states, residual,
+                kv_caches[i -
+                          self.start_layer] if kv_caches is not None else None,
+                attn_metadata)
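
# --- Aside (illustrative values, not from the diff): kv_caches holds one
# cache per layer owned by this pipeline-parallel rank, while layer indices
# are global, hence the i - self.start_layer offset. A tiny sketch:
start_layer, end_layer = 30, 60                       # this rank owns layers 30..59
kv_caches = [f"cache_{j}" for j in range(end_layer - start_layer)]
for i in range(start_layer, end_layer):
    cache = kv_caches[i - start_layer]                # global index -> local slot
assert cache == "cache_29"                            # last local cache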

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
@@ -514,6 +566,20 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: Optional[List[torch.Tensor]] = None,
+        attn_metadata: Optional[AttentionMetadata] = None,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
+        return hidden_states
+
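
# --- Aside (illustrative, not from the diff): the override above is a thin
# pass-through that widens the public forward signature so a runner can pass
# per-layer kv_caches and attn_metadata explicitly. The shape of the pattern:
import torch
from torch import nn

class _Inner(nn.Module):
    def forward(self, x: torch.Tensor, extra=None) -> torch.Tensor:
        return x if extra is None else x + extra

class _Outer(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = _Inner()

    def forward(self, x: torch.Tensor, extra=None) -> torch.Tensor:
        # Delegate everything; the subclass adds no computation of its own.
        return self.model(x, extra)

assert torch.equal(_Outer()(torch.ones(2), torch.ones(2)), torch.full((2,), 2.))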

class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM):
    pass