Commit 87ebaef

[perf]: support dual-batch overlap(dbo) for deepseek (#941)
### What this PR does / why we need it?

Based on the dual-batch overlap design proposed by the DeepSeek team, and on the fused MoE implementation in the vLLM project, we implement multi-stream (also known as dual-batch) overlap for DeepSeek + MLA on Ascend NPU. We split the model's input batch into two microbatches and overlap the computation/communication ops in the attention and MoE layers across two streams to improve performance. Our approach can be easily extended when dispatch/combine communications are added to the MoE layer.

Compared with the previously proposed [draft](#842), we use one stream for computation ops and the other for communication ops. In our opinion, this makes it easier to arrange the execution order of different ops and thus avoid contention for computation/communication resources.

ref: [overlap for llama](https://github.com/vllm-project/vllm/pull/15787/files)
ref: [dbo in sglang](https://github.com/sgl-project/sglang/pull/4068/files#diff-b4937569fc71f6ad215181b633b2f89c7183a2b4ac39e41fc22635599a9be7de)

### Does this PR introduce _any_ user-facing change?

Adds an environment variable, `VLLM_ASCEND_ENABLE_DBO`. Users can enable DBO by setting `VLLM_ASCEND_ENABLE_DBO=1`.

See `/examples/offline_dualbatch_overlap_npu.py` for more info.

### How was this patch tested?

This patch can be tested with vLLM 0.9.0 using its online service with benchmark tests. We have decoupled the DBO functionality from vLLM, so it should run without any modification to the vLLM code (although some of the modifications would be better implemented inside vLLM). Any advice/discussion is welcome.

### Performance Benchmark

We ran vLLM's benchmark_serving script to test performance with dual-batch overlap enabled.

Server:

`python -m vllm.entrypoints.openai.api_server \ --model=DeepSeek-R1-W8A8 \ --trust-remote-code \ --distributed-executor-backend=mp \ -tp=16 \ --port 8006 \ --max-num-seqs 390 \ --max-model-len 32768 \ --max-num-batched-tokens 65536 \ --block-size 128 \ --compilation_config 0 \ --gpu-memory-utilization 0.90 \ --disable-log-requests \ --additional-config '{"expert_tensor_parallel_size":1,"enable_inter_dp_scheduling":true,"init_torchair_graph_batch_sizes":true,"trace_recompiles":true,"ascend_scheduler_config":{},"enable_graph_mode":false}'`

Benchmark parameters:

`--dataset-name random --random-input-len 4096 --random-output-len 1 --num-prompts 200 --max-concurrency 8 --request-rate 5 --metric-percentiles 90`

1. Tested with the allgather+allreduce version on Ascend 910B (tp16 ep16 + DeepSeek R1 W8A8).
2. Tested with the alltoall version: prefill QPS 0.90 -> 1.01; mean TTFT 8226 ms -> 7432 ms.

The overlap approach with alltoall communication can be further optimized by overlapping micro-batch 1's MoE computation with micro-batch 2's dispatch (all-to-all) communication.

---------

Signed-off-by: zhuohuan <zxdu1997@gmail.com>
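As an illustrative aside (not part of this commit): a minimal conceptual sketch of the two-stream arrangement described above, assuming torch_npu mirrors the usual torch.cuda-style Stream/Event API. `overlapped_forward`, `compute_fn`, and `comm_fn` are hypothetical placeholders, not functions from this PR.

# Conceptual sketch only (not from this PR): overlap microbatch 0's
# communication with microbatch 1's computation on two NPU streams.
# Assumes torch_npu mirrors the torch.cuda Stream/Event API;
# `compute_fn` and `comm_fn` are hypothetical placeholders.
import torch
import torch_npu  # noqa: F401  # registers the torch.npu backend


def overlapped_forward(batch, compute_fn, comm_fn):
    micro0, micro1 = batch.chunk(2)      # split the input batch into two microbatches
    comm_stream = torch.npu.Stream()     # dedicated communication stream
    ready0 = torch.npu.Event()

    out0 = compute_fn(micro0)            # microbatch 0 compute on the default stream
    ready0.record()                      # mark the point where out0 is ready
    with torch.npu.stream(comm_stream):
        ready0.wait()                    # comm stream waits for microbatch 0's result
        out0 = comm_fn(out0)             # microbatch 0 communication ...
    out1 = compute_fn(micro1)            # ... overlaps with microbatch 1 compute
    torch.npu.current_stream().wait_stream(comm_stream)
    out1 = comm_fn(out1)                 # microbatch 1 communication
    return torch.cat([out0, out1])

The actual implementation drives this from inside the attention and MoE layers through the vllm_ascend.multistream helpers imported in the diffs below, rather than at the model boundary.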
1 parent 3640c60 commit 87ebaef

File tree

14 files changed: +1896, -11 lines

examples/offline_dualbatch_overlap_npu.py

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
import os
import time

from vllm import LLM, SamplingParams

# enable dual-batch overlap for vllm ascend
os.environ["VLLM_ASCEND_ENABLE_DBO"] = "1"
os.environ["VLLM_USE_V1"] = "1"

# Sample prompts.
prompts = ["The president of the United States is"] * 41
# Create a sampling params object.
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)


def main():
    # Create an LLM.
    llm = LLM(model="deepseek-ai/DeepSeek-V3-Lite-base-latest-w8a8-dynamic",
              enforce_eager=True,
              tensor_parallel_size=2,
              max_model_len=4096,
              trust_remote_code=True,
              additional_config={
                  "torchair_graph_config": {
                      "enabled": False
                  },
                  "ascend_scheduler_config": {
                      "enabled": True
                  },
                  "expert_tensor_parallel_size": 1
              })

    # Generate texts from the prompts. The output is a list of RequestOutput
    # objects that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    print("-" * 50)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
        print("-" * 50)

    # Add a buffer to wait for profiler in the background process
    # (in case MP is on) to finish writing profiling output.
    time.sleep(10)


if __name__ == "__main__":
    main()

tests/multicard/test_offline_inference_distributed.py

Lines changed: 14 additions & 0 deletions
@@ -81,3 +81,17 @@ def test_models_distributed_topk() -> None:
             distributed_executor_backend="mp",
     ) as vllm_model:
         vllm_model.generate(example_prompts, sampling_params)
+
+
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1"})
+def test_models_distributed_DeepSeek_dbo():
+    example_prompts = ["The president of the United States is"] * 41
+    dtype = "half"
+    sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
+    with VllmRunner(
+            "deepseek-ai/DeepSeek-V2-Lite",
+            dtype=dtype,
+            tensor_parallel_size=4,
+            distributed_executor_backend="mp",
+    ) as vllm_model:
+        vllm_model.generate(example_prompts, sampling_params)

vllm_ascend/attention/mla_v1.py

Lines changed: 61 additions & 8 deletions
@@ -13,6 +13,9 @@
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
+from vllm_ascend.multistream.context import get_multistream_comm_context
+from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
 
 if TYPE_CHECKING:
@@ -117,6 +120,7 @@ class AscendMLAMetadata:
 
     with_prefill_across_dp: bool = False
 
+    query_lens: Optional[list[int]] = None
     # The dimension of the attention heads
     head_dim: Optional[int] = None
     attn_mask: torch.Tensor = None
@@ -135,6 +139,17 @@ def __post_init__(self):
         #     f"Only {supported_head_sizes} are supported for head_dim,",
         #     f"received {self.head_dim}.")
 
+    def split_metadata_for_multistream(
+        self,
+        ms_split_config: MSAttentionMetadataSplitConfig,
+    ) -> list["AscendMLAMetadata"]:
+        """Split metadata for multi-stream with AscendMLAMetadata"""
+        return model_input_split_v1_mla_attn(
+            ms_split_config=ms_split_config,
+            attn_metadata=self,
+            _metadata_cls=AscendMLAMetadata,
+        )
+
 
 M = TypeVar("M", bound=AscendMLAMetadata)
 
@@ -386,6 +401,7 @@ def build(
 
         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
+            query_lens=query_lens.tolist(),
             slot_mapping=slot_mapping,
             head_dim=self.runner.model_config.get_head_size(),
             num_decodes=self._num_decodes,
@@ -585,7 +601,15 @@ def _forward_prefill(
         )
         attn_output = attn_output.reshape(
             [num_tokens, self.num_heads * self.v_head_dim])
-        return self.o_proj(attn_output)[0]
+
+        current_ms_metadata = get_multistream_comm_context()
+        if current_ms_metadata is None:
+            return self.o_proj(attn_output)[0]
+        else:
+            current_ms_metadata.before_comm_event.record()
+            with torch.npu.stream(current_ms_metadata.comm_stream):
+                current_ms_metadata.before_comm_event.wait()
+                return self.o_proj(attn_output)[0]
 
     def exec_kv(
         self,
@@ -685,7 +709,14 @@ def _forward_decode(
             context_lens=attn_metadata.decode.seq_lens,  # type:ignore
             mla_vheadsize=self.kv_lora_rank,
             out=attn_output)
-        return self._v_up_proj_and_o_proj(attn_output)
+        current_ms_metadata = get_multistream_comm_context()
+        if current_ms_metadata is None:
+            return self._v_up_proj_and_o_proj(attn_output)
+        else:
+            current_ms_metadata.before_comm_event.record()
+            with torch.npu.stream(current_ms_metadata.comm_stream):
+                current_ms_metadata.before_comm_event.wait()
+                return self._v_up_proj_and_o_proj(attn_output)
 
     def forward(
         self,
@@ -811,16 +842,38 @@ def forward(
                 key_cache=kv_cache,
                 slot_indices=attn_metadata.slot_mapping.flatten())
         if has_prefill:
-            output[num_decode_tokens:] = self._forward_prefill(
-                prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
-                attn_metadata)
+            # FIX: aicore move should be also placed on the comm stream in dbo,
+            # otherwise it may affect the accuracy
+            # TODO: use an elegant way to overlap
+            output_prefill = self._forward_prefill(prefill_q,
+                                                   prefill_k_c_normed,
+                                                   prefill_k_pe, kv_cache,
+                                                   attn_metadata)
+            current_ms_metadata = get_multistream_comm_context()
+            if current_ms_metadata is not None:
+                with torch.npu.stream(current_ms_metadata.comm_stream):
+                    output[num_decode_tokens:] = output_prefill
+                    current_ms_metadata.after_comm_event.record()
+            else:
+                output[num_decode_tokens:] = output_prefill
+
         if has_decode:
             if self.running_in_graph:
                 return self._forward_decode(decode_ql_nope, decode_q_pe,
                                             decode_k_nope, decode_k_pe,
                                             kv_cache, attn_metadata)
             else:
-                output[:num_decode_tokens] = self._forward_decode(
-                    decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
-                    kv_cache, attn_metadata)
+                output_decode = self._forward_decode(decode_ql_nope,
+                                                     decode_q_pe,
+                                                     decode_k_nope,
+                                                     decode_k_pe, kv_cache,
+                                                     attn_metadata)
+                current_ms_metadata = get_multistream_comm_context()
+                if current_ms_metadata is not None:
+                    with torch.npu.stream(current_ms_metadata.comm_stream):
+                        output[:num_decode_tokens] = output_decode
+                        current_ms_metadata.after_comm_event.record()
+                else:
+                    output[:num_decode_tokens] = output_decode
+
         return output_padded
vllm_ascend/envs.py

Lines changed: 2 additions & 0 deletions
@@ -107,6 +107,8 @@
     # Whether to enable the trace recompiles from pytorch.
     "VLLM_ASCEND_TRACE_RECOMPILES":
     lambda: bool(int(os.getenv("VLLM_ASCEND_TRACE_RECOMPILES", '0'))),
+    "VLLM_ASCEND_ENABLE_DBO":
+    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DBO", '0'))),
     # Whether to enable the model execute time observe profile. Disable it when
     # running vllm ascend in production environment.
     "VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE":

vllm_ascend/models/__init__.py

Lines changed: 11 additions & 3 deletions
@@ -1,7 +1,10 @@
 from vllm import ModelRegistry
 
+import vllm_ascend.envs as envs
+
 
 def register_model():
+    from .deepseek_dbo import CustomDeepseekDBOForCausalLM  # noqa: F401
     from .deepseek_mtp import CustomDeepSeekMTP  # noqa: F401
     from .deepseek_v2 import CustomDeepseekV2ForCausalLM  # noqa: F401
     from .deepseek_v2 import CustomDeepseekV3ForCausalLM  # noqa: F401
@@ -22,9 +25,14 @@ def register_model():
         "vllm_ascend.models.qwen2_5_vl:AscendQwen2_5_VLForConditionalGeneration"
     )
 
-    ModelRegistry.register_model(
-        "DeepseekV2ForCausalLM",
-        "vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM")
+    if envs.VLLM_ASCEND_ENABLE_DBO:
+        ModelRegistry.register_model(
+            "DeepseekV2ForCausalLM",
+            "vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM")
+    else:
+        ModelRegistry.register_model(
+            "DeepseekV2ForCausalLM",
+            "vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM")
 
     ModelRegistry.register_model(
         "DeepseekV3ForCausalLM",
