[V0.9.1][BugFix] Fix load weight error and add new e2e case #1651

Merged
1 commit merged on Jul 8, 2025
1 change: 1 addition & 0 deletions .github/workflows/vllm_ascend_test.yaml
@@ -203,6 +203,7 @@ jobs:
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_topk
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_dbo
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_w8a8_ep_dbo
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeekV3_dbo
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py --ignore=tests/multicard/test_offline_inference_distributed.py --ignore=tests/multicard/test_w4a8_deepseek.py
fi
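
Note: the added CI line can be approximated locally with a short Python driver. This is a sketch only; it assumes a multi-NPU host with the vllm-ascend test environment installed, and simply mirrors the environment variable and test id used above.

# Hypothetical local equivalent of the new CI step.
import os
import pytest

os.environ["VLLM_USE_MODELSCOPE"] = "True"
pytest.main([
    "-sv",
    "tests/multicard/test_offline_inference_distributed.py"
    "::test_models_distributed_DeepSeek_w8a8_ep_dbo",
])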
24 changes: 24 additions & 0 deletions tests/multicard/test_offline_inference_distributed.py
@@ -109,6 +109,30 @@ def test_models_distributed_DeepSeek_dbo():
vllm_model.generate(example_prompts, sampling_params)


@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1"})
def test_models_distributed_DeepSeek_w8a8_ep_dbo():
example_prompts = ["The president of the United States is"] * 100
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
with VllmRunner(
snapshot_download("vllm-ascend/DeepSeek-V2-Lite-W8A8"),
dtype="auto",
quantization="ascend",
tensor_parallel_size=4,
enforce_eager=True,
enable_expert_parallel=True,
distributed_executor_backend="mp",
additional_config={"ascend_scheduler_config": {
"enabled": True,
}}) as vllm_model:
model_arch = 'DeepseekV2ForCausalLM'
registed_models = ModelRegistry.models
assert registed_models[
model_arch].module_name == "vllm_ascend.models.deepseek_dbo"
assert registed_models[
model_arch].class_name == "CustomDeepseekDBOForCausalLM"
vllm_model.generate(example_prompts, sampling_params)


@pytest.mark.skip(reason="Due to OOM, waiting for PR #1311 to be merged")
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1"})
def test_models_distributed_DeepSeekV3_dbo():
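
Note: the registry assertions in the new test depend on vllm-ascend overriding the DeepSeek architecture entry when DBO is enabled. A minimal sketch of that mechanism, assuming the plugin uses vllm's standard ModelRegistry.register_model hook:

# Sketch only: how an out-of-tree plugin can point an architecture name at a
# custom implementation, which is what the assertions above verify.
from vllm import ModelRegistry

ModelRegistry.register_model(
    "DeepseekV2ForCausalLM",
    "vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM")

entry = ModelRegistry.models["DeepseekV2ForCausalLM"]
assert entry.module_name == "vllm_ascend.models.deepseek_dbo"
assert entry.class_name == "CustomDeepseekDBOForCausalLM"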
120 changes: 20 additions & 100 deletions vllm_ascend/models/deepseek_dbo.py
@@ -33,35 +33,32 @@
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
get_current_vllm_config)
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (get_ep_group, get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_group, tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import get_dp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ReplicatedLinear,
UnquantizedLinearMethod)
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.models.deepseek_v2 import \
DeepseekV2ForCausalLM # noqa: E501
from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
from vllm.model_executor.models.utils import (
PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
from vllm.sequence import IntermediateTensors

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2MLAAttention,
CustomDeepseekV2MLP)
from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2DecoderLayer,
CustomDeepseekV2MLP,
CustomDeepseekV2MoE)
from vllm_ascend.multistream.base import MSEventKey
from vllm_ascend.multistream.context import (
advance_step_multistream_layer_context, get_multistream_comm_context,
@@ -71,7 +68,6 @@
from vllm_ascend.multistream.metadata import (MultiStreamConfig,
MultiStreamStepMetadata,
make_multistream_metadata_ds)
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.quantization.w8a8_dynamic import (
AscendW8A8DynamicLinearMethod, apply_mlp)
from vllm_ascend.utils import dispose_tensor
@@ -126,7 +122,7 @@ def _forward_ms_mlp(self, x):
return x


class CustomDeepseekDBOMoE(nn.Module):
class CustomDeepseekDBOMoE(CustomDeepseekV2MoE):

top_k: int

@@ -136,45 +132,9 @@ def __init__(
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.routed_scaling_factor = config.routed_scaling_factor
self.n_shared_experts = config.n_shared_experts
self.routed_scaling_factor = config.routed_scaling_factor
if self.tp_size > config.n_routed_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.n_routed_experts}.")

if config.hidden_act != "silu":
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
"Only silu is supported for now.")

self.gate = ReplicatedLinear(config.hidden_size,
config.n_routed_experts,
bias=False,
quant_config=None,
prefix=f"{prefix}.gate")
if config.topk_method == "noaux_tc":
self.gate.e_score_correction_bias = nn.Parameter(
torch.empty(config.n_routed_experts))
else:
self.gate.e_score_correction_bias = None

self.experts = AscendFusedMoE(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias)
super().__init__(config=config,
quant_config=quant_config,
prefix=prefix)

if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
@@ -189,19 +149,6 @@ def __init__(
)
CustomDeepseekDBOMoE.top_k = config.num_experts_per_tok

self.dp_size = get_dp_group().world_size

self.tp_group = get_tp_group().device_group
self.tp_rank = get_tp_group().rank_in_group
self.kv_consumer = None
transfer_config = get_current_vllm_config().kv_transfer_config
if transfer_config is not None:
self.kv_consumer = transfer_config.kv_role = "kv_consumer"
self.params_dtype = torch.get_default_dtype()

ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled

def forward(
self,
hidden_states: torch.Tensor,
@@ -254,7 +201,7 @@ def _forward_ms_op_gate(
return router_logits


class CustomDeepseekDBODecoderLayer(DeepseekV2DecoderLayer):
class CustomDeepseekDBODecoderLayer(CustomDeepseekV2DecoderLayer):

def __init__(
self,
@@ -264,43 +211,19 @@ def __init__(
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
nn.Module.__init__(self)
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
# DecoderLayers are created with `make_layers` which passes the prefix
# with the layer's index.
layer_idx = int(prefix.split(sep='.')[-1])
self.layer_idx = layer_idx
# TODO: enable mla in vllm-ascend
attn_cls = CustomDeepseekV2MLAAttention
self.self_attn = attn_cls(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
qk_nope_head_dim=config.qk_nope_head_dim,
qk_rope_head_dim=config.qk_rope_head_dim,
v_head_dim=config.v_head_dim,
q_lora_rank=config.q_lora_rank
if hasattr(config, "q_lora_rank") else None,
kv_lora_rank=config.kv_lora_rank,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
super().__init__(config=config,
prefix=prefix,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config)
self.tp_size = get_tensor_model_parallel_world_size()
self.dp_size = get_dp_group().world_size
self.tp_group = get_tp_group().device_group
self.global_num_experts = config.n_routed_experts

if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0):
and self.layer_idx >= config.first_k_dense_replace
and self.layer_idx % config.moe_layer_freq == 0):
self.mlp = CustomDeepseekDBOMoE(
config=config,
quant_config=quant_config,
@@ -314,11 +237,6 @@ def __init__(
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.routed_scaling_factor = config.routed_scaling_factor

def forward(
self,
@@ -926,7 +844,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "lm_head"))
else:
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
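
Note: the lm_head change above appears to be the load-weight fix named in the PR title: without a prefix, the head's parameters are registered under a name that may not match the checkpoint keys when the module is nested. A small sketch of the assumed maybe_prefix behavior (as in upstream vllm):

# Sketch: maybe_prefix joins a parent prefix and a submodule name, so weight
# names such as "lm_head.weight" resolve to the expected fully qualified keys
# during loading. Behavior assumed from upstream vllm utils.
from vllm.model_executor.models.utils import maybe_prefix

print(maybe_prefix("", "lm_head"))       # "lm_head"
print(maybe_prefix("model", "lm_head"))  # "model.lm_head"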
7 changes: 4 additions & 3 deletions vllm_ascend/multistream/ms_split.py
@@ -72,11 +72,12 @@ def model_input_split_v1_mla_attn(
attn_metadata.query_lens):
return [attn_metadata]

query_start_loc_cpu = np.zeros(shape=(len(attn_metadata.query_lens) + 1, ),
dtype=int)
query_start_loc_cpu: Any = np.zeros(shape=(len(attn_metadata.query_lens) +
1, ),
dtype=int)
np.cumsum(attn_metadata.query_lens, out=query_start_loc_cpu[1:])
if attn_metadata.num_prefills > 0:
prefill_query_start_loc = np.zeros(
prefill_query_start_loc: Any = np.zeros(
shape=(len(attn_metadata.prefill.query_lens) + 1, ), dtype=int)
np.cumsum(attn_metadata.prefill.query_lens,
out=prefill_query_start_loc[1:])
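
Note: the ms_split change only loosens the static type of the two offset arrays; the construction itself is unchanged. A standalone illustration with hypothetical query lengths:

# Illustration: a zero-filled array one slot longer than query_lens, with the
# cumulative sum written into positions 1..N, yields per-request start offsets.
import numpy as np

query_lens = np.array([4, 1, 3])                 # hypothetical per-request lengths
query_start_loc = np.zeros(len(query_lens) + 1, dtype=int)
np.cumsum(query_lens, out=query_start_loc[1:])
print(query_start_loc)                           # [0 4 5 8]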
4 changes: 2 additions & 2 deletions vllm_ascend/worker/model_runner_v1.py
@@ -26,7 +26,7 @@
import weakref
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import numpy as np
import numpy.typing as npt
@@ -1160,7 +1160,7 @@ def _calc_spec_decode_metadata(

# Compute the logits indices.
# [4, 1, 3, 1, 2]
num_sampled_tokens = num_draft_tokens + 1
num_sampled_tokens: Any = num_draft_tokens + 1
# Step 1. [4, 5, 8, 9, 11]
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
total_num_sampled_tokens = cu_num_sampled_tokens[-1]
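
Note: the model_runner change likewise only adds an Any annotation; the spec-decode bookkeeping is unchanged. The inline "[4, 1, 3, 1, 2]" and "[4, 5, 8, 9, 11]" comments can be reproduced with hypothetical draft-token counts:

# Worked example of the comments above, assuming five requests with these
# draft-token counts: each request samples its draft tokens plus one bonus
# token, and the cumulative sum gives each request's sampled-token offset.
import numpy as np

num_draft_tokens = np.array([3, 0, 2, 0, 1])
num_sampled_tokens = num_draft_tokens + 1                               # [4 1 3 1 2]
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)  # [4 5 8 9 11]
total_num_sampled_tokens = cu_num_sampled_tokens[-1]                   # 11
print(num_sampled_tokens, cu_num_sampled_tokens, total_num_sampled_tokens)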