
Commit 92aa467

remove repeat init
Signed-off-by: shikang-hangzhou <459956190@qq.com>
1 parent 04e6169 commit 92aa467

File tree: 5 files changed, +51 -105 lines

.github/workflows/vllm_ascend_test.yaml
tests/multicard/test_offline_inference_distributed.py
vllm_ascend/models/deepseek_dbo.py
vllm_ascend/multistream/ms_split.py
vllm_ascend/worker/model_runner_v1.py

.github/workflows/vllm_ascend_test.yaml

Lines changed: 1 addition & 0 deletions

@@ -203,6 +203,7 @@ jobs:
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_topk
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_dbo
+          VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_w8a8_ep_dbo
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeekV3_dbo
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py --ignore=tests/multicard/test_offline_inference_distributed.py --ignore=tests/multicard/test_w4a8_deepseek.py
         fi

tests/multicard/test_offline_inference_distributed.py

Lines changed: 24 additions & 0 deletions

@@ -109,6 +109,30 @@ def test_models_distributed_DeepSeek_dbo():
         vllm_model.generate(example_prompts, sampling_params)
 
 
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1"})
+def test_models_distributed_DeepSeek_w8a8_ep_dbo():
+    example_prompts = ["The president of the United States is"] * 100
+    sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
+    with VllmRunner(
+            snapshot_download("vllm-ascend/DeepSeek-V2-Lite-W8A8"),
+            dtype="auto",
+            quantization="ascend",
+            tensor_parallel_size=4,
+            enforce_eager=True,
+            enable_expert_parallel=True,
+            distributed_executor_backend="mp",
+            additional_config={"ascend_scheduler_config": {
+                "enabled": True,
+            }}) as vllm_model:
+        model_arch = 'DeepseekV2ForCausalLM'
+        registed_models = ModelRegistry.models
+        assert registed_models[
+            model_arch].module_name == "vllm_ascend.models.deepseek_dbo"
+        assert registed_models[
+            model_arch].class_name == "CustomDeepseekDBOForCausalLM"
+        vllm_model.generate(example_prompts, sampling_params)
+
+
 @pytest.mark.skip(reason="Due to OOM,waiting for 1311pr to merge in")
 @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1"})
 def test_models_distributed_DeepSeekV3_dbo():
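The new test checks that, with VLLM_ASCEND_ENABLE_DBO=1, the DeepseekV2ForCausalLM architecture resolves to the vllm-ascend DBO implementation before generation runs. A minimal sketch of the kind of registration that assertion relies on; the register_dbo_override helper is hypothetical and not part of this commit, only the ModelRegistry.register_model call and the module/class names asserted above are taken from the diff:

import os

from vllm import ModelRegistry


def register_dbo_override() -> None:
    # Hypothetical helper: when the DBO switch is on, re-point the
    # DeepseekV2ForCausalLM architecture at the DBO model class, which is
    # exactly what the assertions in the test above look for.
    if os.environ.get("VLLM_ASCEND_ENABLE_DBO") == "1":
        ModelRegistry.register_model(
            "DeepseekV2ForCausalLM",
            "vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM")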

vllm_ascend/models/deepseek_dbo.py

Lines changed: 20 additions & 100 deletions

@@ -33,35 +33,32 @@
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.attention import AttentionMetadata
-from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
-                         get_current_vllm_config)
+from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.distributed import (get_ep_group, get_pp_group,
                               get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               get_tp_group, tensor_model_parallel_all_reduce)
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (ReplicatedLinear,
-                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.models.deepseek_v2 import \
     DeepseekV2ForCausalLM  # noqa: E501
-from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
 from vllm.model_executor.models.utils import (
     PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers,
     maybe_prefix)
 from vllm.sequence import IntermediateTensors
 
 import vllm_ascend.envs as envs_ascend
-from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import FusedMoEState
-from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2MLAAttention,
-                                            CustomDeepseekV2MLP)
+from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2DecoderLayer,
+                                            CustomDeepseekV2MLP,
+                                            CustomDeepseekV2MoE)
 from vllm_ascend.multistream.base import MSEventKey
 from vllm_ascend.multistream.context import (
     advance_step_multistream_layer_context, get_multistream_comm_context,
@@ -71,7 +68,6 @@
 from vllm_ascend.multistream.metadata import (MultiStreamConfig,
                                               MultiStreamStepMetadata,
                                               make_multistream_metadata_ds)
-from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.w8a8_dynamic import (
     AscendW8A8DynamicLinearMethod, apply_mlp)
 from vllm_ascend.utils import dispose_tensor
@@ -126,7 +122,7 @@ def _forward_ms_mlp(self, x):
         return x
 
 
-class CustomDeepseekDBOMoE(nn.Module):
+class CustomDeepseekDBOMoE(CustomDeepseekV2MoE):
 
     top_k: int
 
@@ -136,45 +132,9 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
-        super().__init__()
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.routed_scaling_factor = config.routed_scaling_factor
-        self.n_shared_experts = config.n_shared_experts
-        self.routed_scaling_factor = config.routed_scaling_factor
-        if self.tp_size > config.n_routed_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {config.n_routed_experts}.")
-
-        if config.hidden_act != "silu":
-            raise ValueError(f"Unsupported activation: {config.hidden_act}. "
-                             "Only silu is supported for now.")
-
-        self.gate = ReplicatedLinear(config.hidden_size,
-                                     config.n_routed_experts,
-                                     bias=False,
-                                     quant_config=None,
-                                     prefix=f"{prefix}.gate")
-        if config.topk_method == "noaux_tc":
-            self.gate.e_score_correction_bias = nn.Parameter(
-                torch.empty(config.n_routed_experts))
-        else:
-            self.gate.e_score_correction_bias = None
-
-        self.experts = AscendFusedMoE(
-            num_experts=config.n_routed_experts,
-            top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
-            renormalize=config.norm_topk_prob,
-            quant_config=quant_config,
-            use_grouped_topk=True,
-            num_expert_group=config.n_group,
-            topk_group=config.topk_group,
-            prefix=f"{prefix}.experts",
-            scoring_func=config.scoring_func,
-            e_score_correction_bias=self.gate.e_score_correction_bias)
+        super().__init__(config=config,
+                         quant_config=quant_config,
+                         prefix=prefix)
 
         if config.n_shared_experts is not None:
             intermediate_size = (config.moe_intermediate_size *
@@ -189,19 +149,6 @@ def __init__(
         )
         CustomDeepseekDBOMoE.top_k = config.num_experts_per_tok
 
-        self.dp_size = get_dp_group().world_size
-
-        self.tp_group = get_tp_group().device_group
-        self.tp_rank = get_tp_group().rank_in_group
-        self.kv_consumer = None
-        transfer_config = get_current_vllm_config().kv_transfer_config
-        if transfer_config is not None:
-            self.kv_consumer = transfer_config.kv_role = "kv_consumer"
-        self.params_dtype = torch.get_default_dtype()
-
-        ascend_config = get_ascend_config()
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -254,7 +201,7 @@ def _forward_ms_op_gate(
         return router_logits
 
 
-class CustomDeepseekDBODecoderLayer(DeepseekV2DecoderLayer):
+class CustomDeepseekDBODecoderLayer(CustomDeepseekV2DecoderLayer):
 
     def __init__(
         self,
@@ -264,43 +211,19 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
-        nn.Module.__init__(self)
-        self.hidden_size = config.hidden_size
-        rope_theta = getattr(config, "rope_theta", 10000)
-        rope_scaling = getattr(config, "rope_scaling", None)
-        max_position_embeddings = getattr(config, "max_position_embeddings",
-                                          8192)
-        # DecoderLayers are created with `make_layers` which passes the prefix
-        # with the layer's index.
-        layer_idx = int(prefix.split(sep='.')[-1])
-        self.layer_idx = layer_idx
-        # TODO: enable mla in vllm-ascend
-        attn_cls = CustomDeepseekV2MLAAttention
-        self.self_attn = attn_cls(
-            config=config,
-            hidden_size=self.hidden_size,
-            num_heads=config.num_attention_heads,
-            qk_nope_head_dim=config.qk_nope_head_dim,
-            qk_rope_head_dim=config.qk_rope_head_dim,
-            v_head_dim=config.v_head_dim,
-            q_lora_rank=config.q_lora_rank
-            if hasattr(config, "q_lora_rank") else None,
-            kv_lora_rank=config.kv_lora_rank,
-            rope_theta=rope_theta,
-            rope_scaling=rope_scaling,
-            max_position_embeddings=max_position_embeddings,
-            cache_config=cache_config,
-            quant_config=quant_config,
-            prefix=f"{prefix}.self_attn",
-        )
+        super().__init__(config=config,
+                         prefix=prefix,
+                         model_config=model_config,
+                         cache_config=cache_config,
+                         quant_config=quant_config)
         self.tp_size = get_tensor_model_parallel_world_size()
         self.dp_size = get_dp_group().world_size
         self.tp_group = get_tp_group().device_group
         self.global_num_experts = config.n_routed_experts
 
         if (config.n_routed_experts is not None
-                and layer_idx >= config.first_k_dense_replace
-                and layer_idx % config.moe_layer_freq == 0):
+                and self.layer_idx >= config.first_k_dense_replace
+                and self.layer_idx % config.moe_layer_freq == 0):
             self.mlp = CustomDeepseekDBOMoE(
                 config=config,
                 quant_config=quant_config,
@@ -314,11 +237,6 @@ def __init__(
                 quant_config=quant_config,
                 prefix=f"{prefix}.mlp",
             )
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(config.hidden_size,
-                                                eps=config.rms_norm_eps)
-        self.routed_scaling_factor = config.routed_scaling_factor
 
     def forward(
         self,
@@ -926,7 +844,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         if get_pp_group().is_last_rank:
             self.lm_head = ParallelLMHead(config.vocab_size,
                                           config.hidden_size,
-                                          quant_config=quant_config)
+                                          quant_config=quant_config,
+                                          prefix=maybe_prefix(
+                                              prefix, "lm_head"))
         else:
             self.lm_head = PPMissingLayer()
         self.logits_processor = LogitsProcessor(config.vocab_size)
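Taken together, these hunks implement the commit title: CustomDeepseekDBOMoE and CustomDeepseekDBODecoderLayer now inherit from the customized vllm-ascend base classes and delegate construction to super().__init__(), instead of rebuilding the gate, experts, attention, and layernorm modules that the parents already create. A minimal, generic sketch of that refactor pattern follows; the class and attribute names are illustrative only, not the real ones:

from torch import nn


class BaseMoE(nn.Module):
    """Stand-in for the parent class that already builds the shared parts."""

    def __init__(self, hidden_size: int, num_experts: int) -> None:
        super().__init__()
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)
        self.num_experts = num_experts


class DualBatchMoE(BaseMoE):
    """After the refactor: reuse the parent init, keep only the extras."""

    def __init__(self, hidden_size: int, num_experts: int) -> None:
        # Before: this class repeated the gate/expert construction verbatim.
        # After: it delegates to the parent and adds only what differs.
        super().__init__(hidden_size, num_experts)
        self.dual_batch_overlap = True

The decoder-layer change follows the same shape: the duplicated attention and RMSNorm setup now lives only in the shared CustomDeepseekV2DecoderLayer, and the subclass keeps just the MoE selection logic, reading self.layer_idx from the parent.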

vllm_ascend/multistream/ms_split.py

Lines changed: 4 additions & 3 deletions

@@ -72,11 +72,12 @@ def model_input_split_v1_mla_attn(
                    attn_metadata.query_lens):
         return [attn_metadata]
 
-    query_start_loc_cpu = np.zeros(shape=(len(attn_metadata.query_lens) + 1, ),
-                                   dtype=int)
+    query_start_loc_cpu: Any = np.zeros(shape=(len(attn_metadata.query_lens) +
+                                               1, ),
+                                        dtype=int)
     np.cumsum(attn_metadata.query_lens, out=query_start_loc_cpu[1:])
     if attn_metadata.num_prefills > 0:
-        prefill_query_start_loc = np.zeros(
+        prefill_query_start_loc: Any = np.zeros(
             shape=(len(attn_metadata.prefill.query_lens) + 1, ), dtype=int)
         np.cumsum(attn_metadata.prefill.query_lens,
                   out=prefill_query_start_loc[1:])
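The only change here is loosening the annotations to Any so that filling the prefix-sum buffers in place through a slice view passes static type checking. A small, self-contained sketch of the underlying pattern; the example values are made up for illustration:

from typing import Any

import numpy as np

query_lens = [4, 1, 3]
# One extra leading zero; the rest is filled in place via the slice view.
query_start_loc: Any = np.zeros(len(query_lens) + 1, dtype=int)
np.cumsum(query_lens, out=query_start_loc[1:])
print(query_start_loc)  # [0 4 5 8]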

vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 2 deletions

@@ -26,7 +26,7 @@
 import weakref
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 import numpy as np
 import numpy.typing as npt
@@ -1160,7 +1160,7 @@ def _calc_spec_decode_metadata(
 
         # Compute the logits indices.
         # [4, 1, 3, 1, 2]
-        num_sampled_tokens = num_draft_tokens + 1
+        num_sampled_tokens: Any = num_draft_tokens + 1
         # Step 1. [4, 5, 8, 9, 11]
         cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
         total_num_sampled_tokens = cu_num_sampled_tokens[-1]
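Same typing fix on the speculative-decode path. For reference, the arithmetic the surrounding comments describe, with draft counts chosen to be consistent with the [4, 1, 3, 1, 2] example in those comments:

import numpy as np

num_draft_tokens = np.array([3, 0, 2, 0, 1])
num_sampled_tokens = num_draft_tokens + 1          # [4 1 3 1 2]
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
print(cu_num_sampled_tokens)                       # [ 4  5  8  9 11]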
