diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml
index 97d316e755..d9856bef3c 100644
--- a/.github/workflows/vllm_ascend_test.yaml
+++ b/.github/workflows/vllm_ascend_test.yaml
@@ -203,6 +203,7 @@ jobs:
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_topk
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_dbo
+          VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_w8a8_ep_dbo
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeekV3_dbo
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py --ignore=tests/multicard/test_offline_inference_distributed.py --ignore=tests/multicard/test_w4a8_deepseek.py
         fi
diff --git a/tests/multicard/test_offline_inference_distributed.py b/tests/multicard/test_offline_inference_distributed.py
index df69ff8e3d..d4af282efe 100644
--- a/tests/multicard/test_offline_inference_distributed.py
+++ b/tests/multicard/test_offline_inference_distributed.py
@@ -109,6 +109,30 @@ def test_models_distributed_DeepSeek_dbo():
         vllm_model.generate(example_prompts, sampling_params)
 
 
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1"})
+def test_models_distributed_DeepSeek_w8a8_ep_dbo():
+    example_prompts = ["The president of the United States is"] * 100
+    sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
+    with VllmRunner(
+            snapshot_download("vllm-ascend/DeepSeek-V2-Lite-W8A8"),
+            dtype="auto",
+            quantization="ascend",
+            tensor_parallel_size=4,
+            enforce_eager=True,
+            enable_expert_parallel=True,
+            distributed_executor_backend="mp",
+            additional_config={"ascend_scheduler_config": {
+                "enabled": True,
+            }}) as vllm_model:
+        model_arch = 'DeepseekV2ForCausalLM'
+        registered_models = ModelRegistry.models
+        assert registered_models[
+            model_arch].module_name == "vllm_ascend.models.deepseek_dbo"
+        assert registered_models[
+            model_arch].class_name == "CustomDeepseekDBOForCausalLM"
+        vllm_model.generate(example_prompts, sampling_params)
+
+
 @pytest.mark.skip(reason="Due to OOM,waiting for 1311pr to merge in")
 @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1"})
 def test_models_distributed_DeepSeekV3_dbo():
diff --git a/vllm_ascend/models/deepseek_dbo.py b/vllm_ascend/models/deepseek_dbo.py
index 000bd39ed5..7d375c3519 100644
--- a/vllm_ascend/models/deepseek_dbo.py
+++ b/vllm_ascend/models/deepseek_dbo.py
@@ -33,8 +33,7 @@
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.attention import AttentionMetadata
-from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
-                         get_current_vllm_config)
+from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.distributed import (get_ep_group, get_pp_group,
                               get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -42,8 +41,7 @@
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (ReplicatedLinear,
-                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import get_sampler
@@ -51,17 +49,16 @@
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.models.deepseek_v2 import \
     DeepseekV2ForCausalLM  # noqa: E501
-from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
 from vllm.model_executor.models.utils import (
     PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers,
     maybe_prefix)
 from vllm.sequence import IntermediateTensors
 
 import vllm_ascend.envs as envs_ascend
-from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import FusedMoEState
-from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2MLAAttention,
-                                            CustomDeepseekV2MLP)
+from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2DecoderLayer,
+                                            CustomDeepseekV2MLP,
+                                            CustomDeepseekV2MoE)
 from vllm_ascend.multistream.base import MSEventKey
 from vllm_ascend.multistream.context import (
     advance_step_multistream_layer_context, get_multistream_comm_context,
@@ -71,7 +68,6 @@
 from vllm_ascend.multistream.metadata import (MultiStreamConfig,
                                               MultiStreamStepMetadata,
                                               make_multistream_metadata_ds)
-from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.w8a8_dynamic import (
     AscendW8A8DynamicLinearMethod, apply_mlp)
 from vllm_ascend.utils import dispose_tensor
@@ -126,7 +122,7 @@ def _forward_ms_mlp(self, x):
         return x
 
 
-class CustomDeepseekDBOMoE(nn.Module):
+class CustomDeepseekDBOMoE(CustomDeepseekV2MoE):
 
     top_k: int
 
@@ -136,45 +132,9 @@ def __init__(
         self,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
-        super().__init__()
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.routed_scaling_factor = config.routed_scaling_factor
-        self.n_shared_experts = config.n_shared_experts
-        self.routed_scaling_factor = config.routed_scaling_factor
-        if self.tp_size > config.n_routed_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {config.n_routed_experts}.")
-
-        if config.hidden_act != "silu":
-            raise ValueError(f"Unsupported activation: {config.hidden_act}. "
-                             "Only silu is supported for now.")
-
-        self.gate = ReplicatedLinear(config.hidden_size,
-                                     config.n_routed_experts,
-                                     bias=False,
-                                     quant_config=None,
-                                     prefix=f"{prefix}.gate")
-        if config.topk_method == "noaux_tc":
-            self.gate.e_score_correction_bias = nn.Parameter(
-                torch.empty(config.n_routed_experts))
-        else:
-            self.gate.e_score_correction_bias = None
-
-        self.experts = AscendFusedMoE(
-            num_experts=config.n_routed_experts,
-            top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
-            renormalize=config.norm_topk_prob,
-            quant_config=quant_config,
-            use_grouped_topk=True,
-            num_expert_group=config.n_group,
-            topk_group=config.topk_group,
-            prefix=f"{prefix}.experts",
-            scoring_func=config.scoring_func,
-            e_score_correction_bias=self.gate.e_score_correction_bias)
+        super().__init__(config=config,
+                         quant_config=quant_config,
+                         prefix=prefix)
 
         if config.n_shared_experts is not None:
             intermediate_size = (config.moe_intermediate_size *
@@ -189,19 +149,6 @@ def __init__(
             )
 
         CustomDeepseekDBOMoE.top_k = config.num_experts_per_tok
-        self.dp_size = get_dp_group().world_size
-
-        self.tp_group = get_tp_group().device_group
-        self.tp_rank = get_tp_group().rank_in_group
-        self.kv_consumer = None
-        transfer_config = get_current_vllm_config().kv_transfer_config
-        if transfer_config is not None:
-            self.kv_consumer = transfer_config.kv_role = "kv_consumer"
-        self.params_dtype = torch.get_default_dtype()
-
-        ascend_config = get_ascend_config()
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -254,7 +201,7 @@ def _forward_ms_op_gate(
         return router_logits
 
 
-class CustomDeepseekDBODecoderLayer(DeepseekV2DecoderLayer):
+class CustomDeepseekDBODecoderLayer(CustomDeepseekV2DecoderLayer):
 
     def __init__(
         self,
@@ -264,43 +211,19 @@ def __init__(
         config: PretrainedConfig,
         prefix: str,
         model_config: ModelConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
-        nn.Module.__init__(self)
-        self.hidden_size = config.hidden_size
-        rope_theta = getattr(config, "rope_theta", 10000)
-        rope_scaling = getattr(config, "rope_scaling", None)
-        max_position_embeddings = getattr(config, "max_position_embeddings",
-                                          8192)
-        # DecoderLayers are created with `make_layers` which passes the prefix
-        # with the layer's index.
-        layer_idx = int(prefix.split(sep='.')[-1])
-        self.layer_idx = layer_idx
-        # TODO: enable mla in vllm-ascend
-        attn_cls = CustomDeepseekV2MLAAttention
-        self.self_attn = attn_cls(
-            config=config,
-            hidden_size=self.hidden_size,
-            num_heads=config.num_attention_heads,
-            qk_nope_head_dim=config.qk_nope_head_dim,
-            qk_rope_head_dim=config.qk_rope_head_dim,
-            v_head_dim=config.v_head_dim,
-            q_lora_rank=config.q_lora_rank
-            if hasattr(config, "q_lora_rank") else None,
-            kv_lora_rank=config.kv_lora_rank,
-            rope_theta=rope_theta,
-            rope_scaling=rope_scaling,
-            max_position_embeddings=max_position_embeddings,
-            cache_config=cache_config,
-            quant_config=quant_config,
-            prefix=f"{prefix}.self_attn",
-        )
+        super().__init__(config=config,
+                         prefix=prefix,
+                         model_config=model_config,
+                         cache_config=cache_config,
+                         quant_config=quant_config)
         self.tp_size = get_tensor_model_parallel_world_size()
         self.dp_size = get_dp_group().world_size
         self.tp_group = get_tp_group().device_group
         self.global_num_experts = config.n_routed_experts
         if (config.n_routed_experts is not None
-                and layer_idx >= config.first_k_dense_replace
-                and layer_idx % config.moe_layer_freq == 0):
+                and self.layer_idx >= config.first_k_dense_replace
+                and self.layer_idx % config.moe_layer_freq == 0):
             self.mlp = CustomDeepseekDBOMoE(
                 config=config,
                 quant_config=quant_config,
@@ -314,11 +237,6 @@ def __init__(
                 quant_config=quant_config,
                 prefix=f"{prefix}.mlp",
             )
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(config.hidden_size,
-                                                eps=config.rms_norm_eps)
-        self.routed_scaling_factor = config.routed_scaling_factor
 
     def forward(
         self,
@@ -926,7 +844,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         if get_pp_group().is_last_rank:
             self.lm_head = ParallelLMHead(config.vocab_size,
                                           config.hidden_size,
-                                          quant_config=quant_config)
+                                          quant_config=quant_config,
+                                          prefix=maybe_prefix(
+                                              prefix, "lm_head"))
         else:
             self.lm_head = PPMissingLayer()
         self.logits_processor = LogitsProcessor(config.vocab_size)
diff --git a/vllm_ascend/multistream/ms_split.py b/vllm_ascend/multistream/ms_split.py
index fd32a18abb..b80446e324 100644
--- a/vllm_ascend/multistream/ms_split.py
+++ b/vllm_ascend/multistream/ms_split.py
@@ -72,11 +72,12 @@ def model_input_split_v1_mla_attn(
             attn_metadata.query_lens):
         return [attn_metadata]
 
-    query_start_loc_cpu = np.zeros(shape=(len(attn_metadata.query_lens) + 1, ),
-                                   dtype=int)
+    query_start_loc_cpu: Any = np.zeros(shape=(len(attn_metadata.query_lens) +
+                                               1, ),
+                                        dtype=int)
     np.cumsum(attn_metadata.query_lens, out=query_start_loc_cpu[1:])
     if attn_metadata.num_prefills > 0:
-        prefill_query_start_loc = np.zeros(
+        prefill_query_start_loc: Any = np.zeros(
             shape=(len(attn_metadata.prefill.query_lens) + 1, ), dtype=int)
         np.cumsum(attn_metadata.prefill.query_lens,
                   out=prefill_query_start_loc[1:])
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 4b2feefa90..e3b81a076b 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -26,7 +26,7 @@
 import weakref
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 import numpy as np
 import numpy.typing as npt
@@ -1160,7 +1160,7 @@ def _calc_spec_decode_metadata(
 
         # Compute the logits indices.
         # [4, 1, 3, 1, 2]
-        num_sampled_tokens = num_draft_tokens + 1
+        num_sampled_tokens: Any = num_draft_tokens + 1
         # Step 1. [4, 5, 8, 9, 11]
         cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
         total_num_sampled_tokens = cu_num_sampled_tokens[-1]
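
Note on the refactor: the core of this patch is de-duplication. CustomDeepseekDBOMoE and CustomDeepseekDBODecoderLayer previously re-implemented the gate/expert and attention construction inline; they now subclass CustomDeepseekV2MoE and CustomDeepseekV2DecoderLayer from vllm_ascend/models/deepseek_v2.py and delegate to super().__init__(), keeping only DBO-specific state (the shared-expert MLP and CustomDeepseekDBOMoE.top_k). The dropped imports (ReplicatedLinear, AscendFusedMoE, get_ascend_config, get_current_vllm_config, DeepseekV2DecoderLayer) follow directly from that. The new CI entry can be reproduced locally with the command added to the workflow: VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_w8a8_ep_dbo

Below is a minimal, self-contained sketch of the delegation pattern; the classes and attribute names are illustrative stand-ins, not the real vllm-ascend API.

# refactor_sketch.py -- hypothetical classes illustrating the pattern above:
# the subclass calls the parent's constructor instead of copying its body,
# so shared setup lives in exactly one place.
class BaseMoE:
    def __init__(self, num_experts: int, top_k: int):
        # Shared construction (cf. CustomDeepseekV2MoE.__init__ upstream).
        self.num_experts = num_experts
        self.top_k = top_k


class DBOMoE(BaseMoE):
    # cf. CustomDeepseekDBOMoE(CustomDeepseekV2MoE) in the diff above.
    def __init__(self, num_experts: int, top_k: int, n_shared_experts: int = 0):
        super().__init__(num_experts, top_k)  # reuse, don't re-implement
        # Only DBO-specific state is layered on top of the parent's setup.
        self.n_shared_experts = n_shared_experts


if __name__ == "__main__":
    moe = DBOMoE(num_experts=64, top_k=6, n_shared_experts=2)
    assert (moe.num_experts, moe.top_k, moe.n_shared_experts) == (64, 6, 2)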