Commit f04c676

[Bugfix] fix env variable in dbo (#1284)

### What this PR does / why we need it?
Fix the env variable check in DBO so that DBO can be enabled for the DeepSeek-V3 model. Besides, we have fixed a known issue in deepseek-dbo.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
This patch can be tested with the newly added e2e tests in [tests/multicard/test_offline_inference_distributed.py](https://github.com/vllm-project/vllm-ascend/pull/1285/files#diff-7cd2e6b1bda6b8ad1bedb3276971fe7064aeae4dc0efd41c301c4ede2158c57e). It can be verified with pytest.

Signed-off-by: zhuohuan <zxdu1997@gmail.com>
1 parent 21fb68a commit f04c676
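
For context, a minimal sketch of how DBO is switched on from user code via the `VLLM_ASCEND_ENABLE_DBO` env variable this commit fixes. The model id, prompt, and parallel size below are illustrative rather than taken from the commit; the flag has to be set before vllm-ascend registers its models.

```python
# Sketch only: enabling dual-batch overlap (DBO) for a DeepSeek model on Ascend.
import os

# Must be set before vLLM loads the vllm-ascend plugin, which reads the flag
# at model-registration time.
os.environ["VLLM_ASCEND_ENABLE_DBO"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(
    model="vllm-ascend/DeepSeek-V3-Pruning",  # pruned test checkpoint used by the e2e tests
    tensor_parallel_size=4,                   # match your NPU count
    distributed_executor_backend="mp",
)
outputs = llm.generate(
    ["The president of the United States is"],
    SamplingParams(max_tokens=100, temperature=0.0),
)
print(outputs[0].outputs[0].text)
```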

File tree

4 files changed: +41 −7 lines changed

.github/workflows/vllm_ascend_test.yaml

Lines changed: 2 additions & 0 deletions
@@ -361,6 +361,8 @@ jobs:
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_topk
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8
+          pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_dbo
+          pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeekV3_dbo
           pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py --ignore=tests/e2e/multicard/test_offline_inference_distributed.py
 
       - name: Run vllm-project/vllm-ascend test on V0 engine

tests/e2e/multicard/test_offline_inference_distributed.py

Lines changed: 27 additions & 0 deletions
@@ -25,6 +25,7 @@
 
 from modelscope import snapshot_download  # type: ignore
 from vllm import SamplingParams
+from vllm.model_executor.models.registry import ModelRegistry
 
 from tests.conftest import VllmRunner
 
@@ -94,6 +95,32 @@ def test_models_distributed_DeepSeek_dbo():
             tensor_parallel_size=4,
             distributed_executor_backend="mp",
     ) as vllm_model:
+        model_arch = 'DeepseekV2ForCausalLM'
+        registed_models = ModelRegistry.models
+        assert registed_models[
+            model_arch].module_name == "vllm_ascend.models.deepseek_dbo"
+        assert registed_models[
+            model_arch].class_name == "CustomDeepseekDBOForCausalLM"
+        vllm_model.generate(example_prompts, sampling_params)
+
+
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1"})
+def test_models_distributed_DeepSeekV3_dbo():
+    example_prompts = ["The president of the United States is"] * 41
+    dtype = "half"
+    sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
+    with VllmRunner(
+            "vllm-ascend/DeepSeek-V3-Pruning",
+            dtype=dtype,
+            tensor_parallel_size=4,
+            distributed_executor_backend="mp",
+    ) as vllm_model:
+        model_arch = 'DeepseekV3ForCausalLM'
+        registed_models = ModelRegistry.models
+        assert registed_models[
+            model_arch].module_name == "vllm_ascend.models.deepseek_dbo"
+        assert registed_models[
+            model_arch].class_name == "CustomDeepseekDBOForCausalLM"
         vllm_model.generate(example_prompts, sampling_params)
 
 
vllm_ascend/models/__init__.py

Lines changed: 8 additions & 3 deletions
@@ -35,14 +35,19 @@ def register_model():
         ModelRegistry.register_model(
             "DeepseekV2ForCausalLM",
             "vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM")
+
+        ModelRegistry.register_model(
+            "DeepseekV3ForCausalLM",
+            "vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM")
+
     else:
         ModelRegistry.register_model(
             "DeepseekV2ForCausalLM",
             "vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM")
 
-    ModelRegistry.register_model(
-        "DeepseekV3ForCausalLM",
-        "vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM")
+        ModelRegistry.register_model(
+            "DeepseekV3ForCausalLM",
+            "vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM")
 
     ModelRegistry.register_model(
         "Qwen3MoeForCausalLM",

vllm_ascend/models/deepseek_dbo.py

Lines changed: 4 additions & 4 deletions
@@ -641,7 +641,7 @@ def _forward_ms_layer(
 
             if self.mlp.tp_size > 1:
                 num_token, _ = hidden_states[i].shape
-                padded_num_tokens = (self.mlp.tp_size - num_token %
+                padded_num_tokens = (self.mlp.tp_size - num_tokens[i] %
                                      self.mlp.tp_size) % self.mlp.tp_size
                 if padded_num_tokens > 0:
                     hidden_states[i] = nn.functional.pad(
@@ -851,16 +851,16 @@ def forward(
                              if VLLM_ASCEND_ENABLE_DBO and self.can_run_ms()
                              else self.end_layer - self.start_layer)
 
-        for i in range(self.start_layer, self.start_layer + num_normal_layers):
+        moe_start_layer = self.start_layer + num_normal_layers
+        for i in range(self.start_layer, min(moe_start_layer, self.end_layer)):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions, hidden_states, residual,
                 kv_caches[i -
                           self.start_layer] if kv_caches is not None else None,
                 attn_metadata)
 
-        moe_start_layer = self.start_layer + num_normal_layers
-        if moe_start_layer != self.end_layer:
+        if moe_start_layer < self.end_layer:
             # if we enable multistream/dbo, process sparse layers here
             hidden_states, residual = self._forward_ms_layers(
                 positions=positions,
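
The first hunk's padding arithmetic pads each micro-batch up to the next multiple of the MLP tensor-parallel size; the bug was computing the pad from a stale local `num_token` instead of the per-micro-batch `num_tokens[i]`. A self-contained sketch of the formula (tensor shapes and `tp_size` are made up for illustration):

```python
import torch
import torch.nn.functional as F


def pad_to_tp_multiple(hidden_states: torch.Tensor, tp_size: int) -> torch.Tensor:
    """Pad the token (first) dimension up to the next multiple of tp_size."""
    num_tokens = hidden_states.shape[0]
    # 0 when already aligned, otherwise the number of rows needed to reach
    # the next multiple of tp_size.
    padded_num_tokens = (tp_size - num_tokens % tp_size) % tp_size
    if padded_num_tokens > 0:
        # F.pad pads dimensions from the last one backwards: (left, right)
        # for the hidden dim, then (top, bottom) for the token dim.
        hidden_states = F.pad(hidden_states, (0, 0, 0, padded_num_tokens))
    return hidden_states


# e.g. 41 tokens with tp_size=4 -> pad by 3 rows to reach 44
x = torch.randn(41, 16)
assert pad_to_tp_multiple(x, 4).shape == (44, 16)
```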
