
Commit 5559443

[V0.9.1][BugFix] Fix load weight error and add new e2e case (#1651)
### What this PR does / why we need it?
1. The quant parameters were modified, but the DBO file was not updated to match them, causing the weight-loading error.
2. Remove useless init code; mostly reuse the v2 init code.
3. Add a DBO e2e case and remove the case skip.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
`tests/multicard/test_offline_inference_distributed.py`

Signed-off-by: shikang-hangzhou <459956190@qq.com>
1 parent ffd1d9a commit 5559443

File tree: 5 files changed (+51 −105 lines)

.github/workflows/vllm_ascend_test.yaml

Lines changed: 1 addition & 0 deletions

```diff
@@ -203,6 +203,7 @@ jobs:
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_topk
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_dbo
+          VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_w8a8_ep_dbo
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeekV3_dbo
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py --ignore=tests/multicard/test_offline_inference_distributed.py --ignore=tests/multicard/test_w4a8_deepseek.py
         fi
```
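The added CI step simply runs the new e2e case with the same ModelScope-backed pytest invocation as its neighbours. A minimal sketch of reproducing just that step locally, assuming a machine with the vllm-ascend test environment and enough NPUs (the `pytest.main` call is illustrative and not part of the commit):

```python
# Illustrative local reproduction of the new CI step (not part of the commit).
# Assumes the vllm-ascend test environment and required NPUs are available.
import os

import pytest

os.environ["VLLM_USE_MODELSCOPE"] = "True"  # same variable the workflow sets inline

# Run only the newly added W8A8 + expert-parallel + DBO case.
raise SystemExit(
    pytest.main([
        "-sv",
        "tests/multicard/test_offline_inference_distributed.py"
        "::test_models_distributed_DeepSeek_w8a8_ep_dbo",
    ]))
```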

tests/multicard/test_offline_inference_distributed.py

Lines changed: 24 additions & 0 deletions

```diff
@@ -109,6 +109,30 @@ def test_models_distributed_DeepSeek_dbo():
         vllm_model.generate(example_prompts, sampling_params)


+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1"})
+def test_models_distributed_DeepSeek_w8a8_ep_dbo():
+    example_prompts = ["The president of the United States is"] * 100
+    sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
+    with VllmRunner(
+            snapshot_download("vllm-ascend/DeepSeek-V2-Lite-W8A8"),
+            dtype="auto",
+            quantization="ascend",
+            tensor_parallel_size=4,
+            enforce_eager=True,
+            enable_expert_parallel=True,
+            distributed_executor_backend="mp",
+            additional_config={"ascend_scheduler_config": {
+                "enabled": True,
+            }}) as vllm_model:
+        model_arch = 'DeepseekV2ForCausalLM'
+        registed_models = ModelRegistry.models
+        assert registed_models[
+            model_arch].module_name == "vllm_ascend.models.deepseek_dbo"
+        assert registed_models[
+            model_arch].class_name == "CustomDeepseekDBOForCausalLM"
+        vllm_model.generate(example_prompts, sampling_params)
+
+
 @pytest.mark.skip(reason="Due to OOM,waiting for 1311pr to merge in")
 @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1"})
 def test_models_distributed_DeepSeekV3_dbo():
```
vllm_ascend/models/deepseek_dbo.py

Lines changed: 20 additions & 100 deletions

```diff
@@ -33,35 +33,32 @@
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.attention import AttentionMetadata
-from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
-                         get_current_vllm_config)
+from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.distributed import (get_ep_group, get_pp_group,
                               get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               get_tp_group, tensor_model_parallel_all_reduce)
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (ReplicatedLinear,
-                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.models.deepseek_v2 import \
     DeepseekV2ForCausalLM  # noqa: E501
-from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
 from vllm.model_executor.models.utils import (
     PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers,
     maybe_prefix)
 from vllm.sequence import IntermediateTensors

 import vllm_ascend.envs as envs_ascend
-from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import FusedMoEState
-from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2MLAAttention,
-                                            CustomDeepseekV2MLP)
+from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2DecoderLayer,
+                                            CustomDeepseekV2MLP,
+                                            CustomDeepseekV2MoE)
 from vllm_ascend.multistream.base import MSEventKey
 from vllm_ascend.multistream.context import (
     advance_step_multistream_layer_context, get_multistream_comm_context,
@@ -71,7 +68,6 @@
 from vllm_ascend.multistream.metadata import (MultiStreamConfig,
                                               MultiStreamStepMetadata,
                                               make_multistream_metadata_ds)
-from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.w8a8_dynamic import (
     AscendW8A8DynamicLinearMethod, apply_mlp)
 from vllm_ascend.utils import dispose_tensor
@@ -126,7 +122,7 @@ def _forward_ms_mlp(self, x):
         return x


-class CustomDeepseekDBOMoE(nn.Module):
+class CustomDeepseekDBOMoE(CustomDeepseekV2MoE):

     top_k: int

@@ -136,45 +132,9 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
-        super().__init__()
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.routed_scaling_factor = config.routed_scaling_factor
-        self.n_shared_experts = config.n_shared_experts
-        self.routed_scaling_factor = config.routed_scaling_factor
-        if self.tp_size > config.n_routed_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {config.n_routed_experts}.")
-
-        if config.hidden_act != "silu":
-            raise ValueError(f"Unsupported activation: {config.hidden_act}. "
-                             "Only silu is supported for now.")
-
-        self.gate = ReplicatedLinear(config.hidden_size,
-                                     config.n_routed_experts,
-                                     bias=False,
-                                     quant_config=None,
-                                     prefix=f"{prefix}.gate")
-        if config.topk_method == "noaux_tc":
-            self.gate.e_score_correction_bias = nn.Parameter(
-                torch.empty(config.n_routed_experts))
-        else:
-            self.gate.e_score_correction_bias = None
-
-        self.experts = AscendFusedMoE(
-            num_experts=config.n_routed_experts,
-            top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
-            renormalize=config.norm_topk_prob,
-            quant_config=quant_config,
-            use_grouped_topk=True,
-            num_expert_group=config.n_group,
-            topk_group=config.topk_group,
-            prefix=f"{prefix}.experts",
-            scoring_func=config.scoring_func,
-            e_score_correction_bias=self.gate.e_score_correction_bias)
+        super().__init__(config=config,
+                         quant_config=quant_config,
+                         prefix=prefix)

         if config.n_shared_experts is not None:
             intermediate_size = (config.moe_intermediate_size *
@@ -189,19 +149,6 @@ def __init__(
             )
         CustomDeepseekDBOMoE.top_k = config.num_experts_per_tok

-        self.dp_size = get_dp_group().world_size
-
-        self.tp_group = get_tp_group().device_group
-        self.tp_rank = get_tp_group().rank_in_group
-        self.kv_consumer = None
-        transfer_config = get_current_vllm_config().kv_transfer_config
-        if transfer_config is not None:
-            self.kv_consumer = transfer_config.kv_role = "kv_consumer"
-        self.params_dtype = torch.get_default_dtype()
-
-        ascend_config = get_ascend_config()
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -254,7 +201,7 @@ def _forward_ms_op_gate(
         return router_logits


-class CustomDeepseekDBODecoderLayer(DeepseekV2DecoderLayer):
+class CustomDeepseekDBODecoderLayer(CustomDeepseekV2DecoderLayer):

     def __init__(
         self,
@@ -264,43 +211,19 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
-        nn.Module.__init__(self)
-        self.hidden_size = config.hidden_size
-        rope_theta = getattr(config, "rope_theta", 10000)
-        rope_scaling = getattr(config, "rope_scaling", None)
-        max_position_embeddings = getattr(config, "max_position_embeddings",
-                                          8192)
-        # DecoderLayers are created with `make_layers` which passes the prefix
-        # with the layer's index.
-        layer_idx = int(prefix.split(sep='.')[-1])
-        self.layer_idx = layer_idx
-        # TODO: enable mla in vllm-ascend
-        attn_cls = CustomDeepseekV2MLAAttention
-        self.self_attn = attn_cls(
-            config=config,
-            hidden_size=self.hidden_size,
-            num_heads=config.num_attention_heads,
-            qk_nope_head_dim=config.qk_nope_head_dim,
-            qk_rope_head_dim=config.qk_rope_head_dim,
-            v_head_dim=config.v_head_dim,
-            q_lora_rank=config.q_lora_rank
-            if hasattr(config, "q_lora_rank") else None,
-            kv_lora_rank=config.kv_lora_rank,
-            rope_theta=rope_theta,
-            rope_scaling=rope_scaling,
-            max_position_embeddings=max_position_embeddings,
-            cache_config=cache_config,
-            quant_config=quant_config,
-            prefix=f"{prefix}.self_attn",
-        )
+        super().__init__(config=config,
+                         prefix=prefix,
+                         model_config=model_config,
+                         cache_config=cache_config,
+                         quant_config=quant_config)
         self.tp_size = get_tensor_model_parallel_world_size()
         self.dp_size = get_dp_group().world_size
         self.tp_group = get_tp_group().device_group
         self.global_num_experts = config.n_routed_experts

         if (config.n_routed_experts is not None
-                and layer_idx >= config.first_k_dense_replace
-                and layer_idx % config.moe_layer_freq == 0):
+                and self.layer_idx >= config.first_k_dense_replace
+                and self.layer_idx % config.moe_layer_freq == 0):
             self.mlp = CustomDeepseekDBOMoE(
                 config=config,
                 quant_config=quant_config,
@@ -314,11 +237,6 @@ def __init__(
                 quant_config=quant_config,
                 prefix=f"{prefix}.mlp",
             )
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(config.hidden_size,
-                                                eps=config.rms_norm_eps)
-        self.routed_scaling_factor = config.routed_scaling_factor

     def forward(
         self,
@@ -926,7 +844,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         if get_pp_group().is_last_rank:
             self.lm_head = ParallelLMHead(config.vocab_size,
                                           config.hidden_size,
-                                          quant_config=quant_config)
+                                          quant_config=quant_config,
+                                          prefix=maybe_prefix(
+                                              prefix, "lm_head"))
         else:
             self.lm_head = PPMissingLayer()
         self.logits_processor = LogitsProcessor(config.vocab_size)
```
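Most of this diff replaces hand-copied initialization in the DBO classes with inheritance from the vllm-ascend v2 classes (`CustomDeepseekV2MoE`, `CustomDeepseekV2DecoderLayer`) plus a `super().__init__(...)` call, so quantization-related parameters and weight names are created by exactly one code path and stay consistent with what weight loading expects. A minimal sketch of that pattern, using stand-in classes rather than the real vllm modules:

```python
# Stand-in sketch of the "reuse v2 init" refactor (hypothetical classes, not the
# real vllm/vllm-ascend ones).
from torch import nn


class BaseMoE(nn.Module):
    """Plays the role of CustomDeepseekV2MoE: builds gate/experts in one place."""

    def __init__(self, config, quant_config=None, prefix: str = ""):
        super().__init__()
        # Shared construction; in the real class this is where quantized
        # parameters get their names and shapes.
        self.gate = nn.Linear(config.hidden_size, config.n_routed_experts, bias=False)


class DBOMoE(BaseMoE):
    """Plays the role of CustomDeepseekDBOMoE: no duplicated init, only DBO extras."""

    def __init__(self, config, quant_config=None, prefix: str = ""):
        # Delegate all shared setup so the parameter layout matches the non-DBO path.
        super().__init__(config=config, quant_config=quant_config, prefix=prefix)
        # DBO-specific state (multistream metadata, shared-expert wiring, ...) goes here.
```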

vllm_ascend/multistream/ms_split.py

Lines changed: 4 additions & 3 deletions

```diff
@@ -72,11 +72,12 @@ def model_input_split_v1_mla_attn(
             attn_metadata.query_lens):
         return [attn_metadata]

-    query_start_loc_cpu = np.zeros(shape=(len(attn_metadata.query_lens) + 1, ),
-                                   dtype=int)
+    query_start_loc_cpu: Any = np.zeros(shape=(len(attn_metadata.query_lens) +
+                                               1, ),
+                                        dtype=int)
     np.cumsum(attn_metadata.query_lens, out=query_start_loc_cpu[1:])
     if attn_metadata.num_prefills > 0:
-        prefill_query_start_loc = np.zeros(
+        prefill_query_start_loc: Any = np.zeros(
             shape=(len(attn_metadata.prefill.query_lens) + 1, ), dtype=int)
         np.cumsum(attn_metadata.prefill.query_lens,
                   out=prefill_query_start_loc[1:])
```
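The behaviour here is unchanged; the `Any` annotations only silence the type checker when `np.cumsum(..., out=array[1:])` writes into a slice of the freshly allocated array. A standalone illustration of the prefix-sum these lines compute, in plain numpy and with made-up example values:

```python
# Standalone illustration of the query_start_loc construction (example values only,
# not tied to vllm's AttentionMetadata).
from typing import Any

import numpy as np

query_lens = [3, 1, 5]  # per-request query lengths

# One extra leading slot so query_start_loc[i] is the start offset of request i.
query_start_loc: Any = np.zeros(shape=(len(query_lens) + 1, ), dtype=int)
np.cumsum(query_lens, out=query_start_loc[1:])

print(query_start_loc)  # [0 3 4 9]
```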

vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -26,7 +26,7 @@
 import weakref
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

 import numpy as np
 import numpy.typing as npt
@@ -1160,7 +1160,7 @@ def _calc_spec_decode_metadata(

         # Compute the logits indices.
         # [4, 1, 3, 1, 2]
-        num_sampled_tokens = num_draft_tokens + 1
+        num_sampled_tokens: Any = num_draft_tokens + 1
         # Step 1. [4, 5, 8, 9, 11]
         cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
         total_num_sampled_tokens = cu_num_sampled_tokens[-1]
```
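Likewise, the only change here is the `Any` import and annotation; the speculative-decoding bookkeeping around it is untouched. The worked example from the code comments, reproduced as standalone numpy (the `num_draft_tokens` values are inferred from those comments):

```python
# Reproduces the arithmetic from the code comments above (standalone numpy;
# num_draft_tokens values are inferred from the commented example).
from typing import Any

import numpy as np

num_draft_tokens = np.array([3, 0, 2, 0, 1], dtype=np.int32)

# Each request samples one token per draft token plus one bonus token.
num_sampled_tokens: Any = num_draft_tokens + 1                          # [4 1 3 1 2]
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)  # [4 5 8 9 11]
total_num_sampled_tokens = cu_num_sampled_tokens[-1]                   # 11

print(num_sampled_tokens, cu_num_sampled_tokens, total_num_sampled_tokens)
```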
