
Commit 812087b

remove repeat init
Signed-off-by: shikang-hangzhou <459956190@qq.com>
1 parent 04e6169 commit 812087b

File tree

3 files changed: +46 additions, -100 deletions


.github/workflows/vllm_ascend_test.yaml

Lines changed: 1 addition & 0 deletions
@@ -203,6 +203,7 @@ jobs:
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_topk
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_dbo
+          VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_w8a8_ep_dbo
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeekV3_dbo
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py --ignore=tests/multicard/test_offline_inference_distributed.py --ignore=tests/multicard/test_w4a8_deepseek.py
         fi

tests/multicard/test_offline_inference_distributed.py

Lines changed: 24 additions & 4 deletions
@@ -85,10 +85,6 @@ def test_models_distributed_topk() -> None:
         vllm_model.generate(example_prompts, sampling_params)
 
 
-@pytest.mark.skip(
-    reason=
-    "deepseek dbo dose not consider the support on half precision float, will enable this ut after we actually support it"
-)
 @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1"})
 def test_models_distributed_DeepSeek_dbo():
     example_prompts = ["The president of the United States is"] * 41
@@ -109,6 +105,30 @@ def test_models_distributed_DeepSeek_dbo():
         vllm_model.generate(example_prompts, sampling_params)
 
 
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1"})
+def test_models_distributed_DeepSeek_w8a8_ep_dbo():
+    example_prompts = ["The president of the United States is"] * 100
+    sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
+    with VllmRunner(
+            snapshot_download("vllm-ascend/DeepSeek-V2-Lite-W8A8"),
+            dtype="auto",
+            quantization="ascend",
+            tensor_parallel_size=4,
+            enforce_eager=True,
+            enable_expert_parallel=True,
+            distributed_executor_backend="mp",
+            additional_config={"ascend_scheduler_config": {
+                "enabled": True,
+            }}) as vllm_model:
+        model_arch = 'DeepseekV2ForCausalLM'
+        registed_models = ModelRegistry.models
+        assert registed_models[
+            model_arch].module_name == "vllm_ascend.models.deepseek_dbo"
+        assert registed_models[
+            model_arch].class_name == "CustomDeepseekDBOForCausalLM"
+        vllm_model.generate(example_prompts, sampling_params)
+
+
 @pytest.mark.skip(reason="Due to OOM,waiting for 1311pr to merge in")
 @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1"})
 def test_models_distributed_DeepSeekV3_dbo():
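
The new test_models_distributed_DeepSeek_w8a8_ep_dbo case asserts that, with VLLM_ASCEND_ENABLE_DBO=1, the Ascend plugin has re-registered the DeepseekV2ForCausalLM architecture to its DBO implementation before running generation. A minimal standalone sketch of that registry check (assuming an environment where vLLM and the vllm-ascend plugin are installed and the plugin has already been loaded, as it is inside the VllmRunner used by the test):

# Hedged sketch: mirrors the registry assertions from the new test.
# Assumes the vllm-ascend plugin has already patched ModelRegistry.
from vllm import ModelRegistry

arch = "DeepseekV2ForCausalLM"
entry = ModelRegistry.models[arch]
assert entry.module_name == "vllm_ascend.models.deepseek_dbo"
assert entry.class_name == "CustomDeepseekDBOForCausalLM"
print(f"{arch} -> {entry.module_name}.{entry.class_name}")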

vllm_ascend/models/deepseek_dbo.py

Lines changed: 21 additions & 96 deletions
@@ -51,7 +51,6 @@
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.models.deepseek_v2 import \
     DeepseekV2ForCausalLM  # noqa: E501
-from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
 from vllm.model_executor.models.utils import (
     PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers,
     maybe_prefix)
@@ -60,8 +59,10 @@
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import FusedMoEState
-from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2MLAAttention,
-                                            CustomDeepseekV2MLP)
+from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2DecoderLayer,
+                                            CustomDeepseekV2MLAAttention,
+                                            CustomDeepseekV2MLP,
+                                            CustomDeepseekV2MoE)
 from vllm_ascend.multistream.base import MSEventKey
 from vllm_ascend.multistream.context import (
     advance_step_multistream_layer_context, get_multistream_comm_context,
@@ -126,7 +127,7 @@ def _forward_ms_mlp(self, x):
         return x
 
 
-class CustomDeepseekDBOMoE(nn.Module):
+class CustomDeepseekDBOMoE(CustomDeepseekV2MoE):
 
     top_k: int
 
@@ -136,45 +137,9 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
-        super().__init__()
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.routed_scaling_factor = config.routed_scaling_factor
-        self.n_shared_experts = config.n_shared_experts
-        self.routed_scaling_factor = config.routed_scaling_factor
-        if self.tp_size > config.n_routed_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {config.n_routed_experts}.")
-
-        if config.hidden_act != "silu":
-            raise ValueError(f"Unsupported activation: {config.hidden_act}. "
-                             "Only silu is supported for now.")
-
-        self.gate = ReplicatedLinear(config.hidden_size,
-                                     config.n_routed_experts,
-                                     bias=False,
-                                     quant_config=None,
-                                     prefix=f"{prefix}.gate")
-        if config.topk_method == "noaux_tc":
-            self.gate.e_score_correction_bias = nn.Parameter(
-                torch.empty(config.n_routed_experts))
-        else:
-            self.gate.e_score_correction_bias = None
-
-        self.experts = AscendFusedMoE(
-            num_experts=config.n_routed_experts,
-            top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
-            renormalize=config.norm_topk_prob,
-            quant_config=quant_config,
-            use_grouped_topk=True,
-            num_expert_group=config.n_group,
-            topk_group=config.topk_group,
-            prefix=f"{prefix}.experts",
-            scoring_func=config.scoring_func,
-            e_score_correction_bias=self.gate.e_score_correction_bias)
+        super().__init__(config=config,
+                         quant_config=quant_config,
+                         prefix=prefix)
 
         if config.n_shared_experts is not None:
             intermediate_size = (config.moe_intermediate_size *
@@ -189,19 +154,6 @@ def __init__(
         )
         CustomDeepseekDBOMoE.top_k = config.num_experts_per_tok
 
-        self.dp_size = get_dp_group().world_size
-
-        self.tp_group = get_tp_group().device_group
-        self.tp_rank = get_tp_group().rank_in_group
-        self.kv_consumer = None
-        transfer_config = get_current_vllm_config().kv_transfer_config
-        if transfer_config is not None:
-            self.kv_consumer = transfer_config.kv_role = "kv_consumer"
-        self.params_dtype = torch.get_default_dtype()
-
-        ascend_config = get_ascend_config()
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -254,7 +206,7 @@ def _forward_ms_op_gate(
         return router_logits
 
 
-class CustomDeepseekDBODecoderLayer(DeepseekV2DecoderLayer):
+class CustomDeepseekDBODecoderLayer(CustomDeepseekV2DecoderLayer):
 
     def __init__(
         self,
@@ -264,43 +216,19 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
-        nn.Module.__init__(self)
-        self.hidden_size = config.hidden_size
-        rope_theta = getattr(config, "rope_theta", 10000)
-        rope_scaling = getattr(config, "rope_scaling", None)
-        max_position_embeddings = getattr(config, "max_position_embeddings",
-                                          8192)
-        # DecoderLayers are created with `make_layers` which passes the prefix
-        # with the layer's index.
-        layer_idx = int(prefix.split(sep='.')[-1])
-        self.layer_idx = layer_idx
-        # TODO: enable mla in vllm-ascend
-        attn_cls = CustomDeepseekV2MLAAttention
-        self.self_attn = attn_cls(
-            config=config,
-            hidden_size=self.hidden_size,
-            num_heads=config.num_attention_heads,
-            qk_nope_head_dim=config.qk_nope_head_dim,
-            qk_rope_head_dim=config.qk_rope_head_dim,
-            v_head_dim=config.v_head_dim,
-            q_lora_rank=config.q_lora_rank
-            if hasattr(config, "q_lora_rank") else None,
-            kv_lora_rank=config.kv_lora_rank,
-            rope_theta=rope_theta,
-            rope_scaling=rope_scaling,
-            max_position_embeddings=max_position_embeddings,
-            cache_config=cache_config,
-            quant_config=quant_config,
-            prefix=f"{prefix}.self_attn",
-        )
+        super().__init__(config=config,
+                         prefix=prefix,
+                         model_config=model_config,
+                         cache_config=cache_config,
+                         quant_config=quant_config)
         self.tp_size = get_tensor_model_parallel_world_size()
         self.dp_size = get_dp_group().world_size
         self.tp_group = get_tp_group().device_group
         self.global_num_experts = config.n_routed_experts
 
         if (config.n_routed_experts is not None
-                and layer_idx >= config.first_k_dense_replace
-                and layer_idx % config.moe_layer_freq == 0):
+                and self.layer_idx >= config.first_k_dense_replace
+                and self.layer_idx % config.moe_layer_freq == 0):
             self.mlp = CustomDeepseekDBOMoE(
                 config=config,
                 quant_config=quant_config,
@@ -314,11 +242,6 @@ def __init__(
                 quant_config=quant_config,
                 prefix=f"{prefix}.mlp",
             )
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(config.hidden_size,
-                                                eps=config.rms_norm_eps)
-        self.routed_scaling_factor = config.routed_scaling_factor
 
     def forward(
         self,
@@ -921,12 +844,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.config = config
         self.quant_config = quant_config
         self.model = CustomDeepseekDBOModel(vllm_config=vllm_config,
-                                            prefix=maybe_prefix(
-                                                prefix, "model"))
+                                            prefix=maybe_prefix(
+                                                prefix, "model"))
         if get_pp_group().is_last_rank:
             self.lm_head = ParallelLMHead(config.vocab_size,
                                           config.hidden_size,
-                                          quant_config=quant_config)
+                                          quant_config=quant_config,
+                                          prefix=maybe_prefix(
+                                              prefix, "lm_head"))
         else:
             self.lm_head = PPMissingLayer()
         self.logits_processor = LogitsProcessor(config.vocab_size)
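
Taken together, the deepseek_dbo.py changes implement the "remove repeat init" described in the commit message: CustomDeepseekDBOMoE and CustomDeepseekDBODecoderLayer now inherit from the corresponding CustomDeepseekV2 classes and delegate to super().__init__() instead of duplicating the parent constructors' setup. A self-contained sketch of the same pattern (hypothetical class names, not the actual vllm_ascend code):

# Hypothetical illustration of the refactor pattern: delegate shared setup to
# the parent __init__ rather than repeating it line by line in the subclass.
class ParentLayer:
    def __init__(self, hidden_size: int, prefix: str = ""):
        self.hidden_size = hidden_size
        # Parent derives the layer index from the prefix, as the decoder
        # layers in this file do.
        self.layer_idx = int(prefix.split(".")[-1]) if prefix else 0


class ChildLayer(ParentLayer):
    def __init__(self, hidden_size: int, prefix: str = ""):
        # Reuse the parent's initialization (single source of truth) ...
        super().__init__(hidden_size, prefix=prefix)
        # ... and keep only the subclass-specific extras, e.g. a flag the
        # dual-batch-overlap path needs.
        self.dbo_enabled = True


layer = ChildLayer(hidden_size=4096, prefix="model.layers.3")
assert layer.layer_idx == 3 and layer.dbo_enabled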
