
[cherry-pick] Backport multistream MLA fixes and TP communication optimizations #1474

Open · wants to merge 3 commits into base: v0.9.1-dev
115 changes: 72 additions & 43 deletions tests/multicard/test_torchair_graph_mode.py
@@ -20,6 +20,7 @@
Run `pytest tests/multicard/test_torchair_graph_mode.py`.
"""
import os
from typing import Dict

import pytest

@@ -28,6 +29,55 @@
os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"


def _deepseek_torchair_test_fixture(
additional_config: Dict,
*,
tensor_parallel_size=4,
):
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

# torchair currently only works without chunked-prefill
kwargs = {
"ascend_scheduler_config": {
"enabled": True,
},
"refresh": True,
}
additional_config.update(**kwargs)

with VllmRunner(
"vllm-ascend/DeepSeek-V3-Pruning",
dtype="half",
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend="mp",
enforce_eager=False,
additional_config=additional_config,
) as vllm_model:
# use a greedy sampler to make sure the generated results are deterministic
vllm_output = vllm_model.generate_greedy(example_prompts, 5)

# NOTE: vllm-ascend/DeepSeek-V3-Pruning is a randomly initialized
# DeepSeek-V3 checkpoint with 2 hidden layers, so the golden results
# look inaccurate. This will only change if accuracy improves with the
# official DeepSeek-V3 weights.
golden_results = [
'Hello, my name is下载早点向前很有่อง',
'The president of the United States isSender)## physiological Albany',
'The capital of France is Rocky转角 hospitalizedinterval sparked',
'The future of AI is её asegο BIOS一扫',
]

assert len(golden_results) == len(vllm_output)
for i in range(len(vllm_output)):
assert golden_results[i] == vllm_output[i][1]
print(f"Generated text: {vllm_output[i][1]!r}")


@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
reason="torchair graph is not supported on v0")
@pytest.mark.parametrize("VLLM_ASCEND_ENABLE_DBO", ["0", "1"])
@@ -38,46 +88,25 @@ def test_e2e_deepseekv3_with_torchair(monkeypatch: pytest.MonkeyPatch,
m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
m.setenv("VLLM_ASCEND_ENABLE_DBO", VLLM_ASCEND_ENABLE_DBO)

example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
dtype = "half"
max_tokens = 5
# torchair currently only works without chunked-prefill
with VllmRunner(
"vllm-ascend/DeepSeek-V3-Pruning",
dtype=dtype,
tensor_parallel_size=4,
distributed_executor_backend="mp",
additional_config={
"torchair_graph_config": {
"enabled": True,
},
"ascend_scheduler_config": {
"enabled": True,
},
"refresh": True,
},
enforce_eager=False,
) as vllm_model:
# use a greedy sampler to make sure the generated results are deterministic
vllm_output = vllm_model.generate_greedy(example_prompts,
max_tokens)
# NOTE: vllm-ascend/DeepSeek-V3-Pruning is a randomly initialized
# DeepSeek-V3 checkpoint with 2 hidden layers, so the golden results
# look inaccurate. This will only change if accuracy improves with the
# official DeepSeek-V3 weights.
golden_results = [
'Hello, my name is下载早点向前很有่อง',
'The president of the United States isSender)## physiological Albany',
'The capital of France is Rocky转角 hospitalizedinterval sparked',
'The future of AI is её asegο BIOS一扫',
]

assert len(golden_results) == len(vllm_output)
for i in range(len(vllm_output)):
assert golden_results[i] == vllm_output[i][1]
print(f"Generated text: {vllm_output[i][1]!r}")
additional_config = {
"torchair_graph_config": {
"enabled": True,
},
}
_deepseek_torchair_test_fixture(additional_config)


@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
reason="torchair graph is not supported on v0")
def test_e2e_deepseekv3_with_torchair_ms_mla(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_MODELSCOPE", "True")
m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")

additional_config = {
"torchair_graph_config": {
"enabled": True,
"enable_multistream_mla": True,
},
}
_deepseek_torchair_test_fixture(additional_config)
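
Both tests now delegate to the shared `_deepseek_torchair_test_fixture` helper, so a further torchair variant only needs to pass a different `torchair_graph_config`. Below is a minimal sketch of what such a follow-up test could look like; the test itself is hypothetical and not part of this PR, and `enable_kv_nz` (an existing torchair_graph_config option referenced in mla_v1.py) is used purely as an illustration.

@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
                    reason="torchair graph is not supported on v0")
def test_e2e_deepseekv3_with_torchair_kv_nz(monkeypatch: pytest.MonkeyPatch):
    # Hypothetical variant: only the torchair_graph_config changes;
    # prompts, scheduler config and golden-result checks all live in the
    # shared fixture above.
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_MODELSCOPE", "True")
        m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")

        additional_config = {
            "torchair_graph_config": {
                "enabled": True,
                "enable_kv_nz": True,
            },
        }
        _deepseek_torchair_test_fixture(additional_config)
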
24 changes: 13 additions & 11 deletions vllm_ascend/attention/mla_v1.py
@@ -9,6 +9,7 @@
MLAAttentionImpl)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.utils import cdiv, round_down
@@ -584,12 +585,11 @@ def __init__(
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.tp_size = get_tensor_model_parallel_world_size()

ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
self.enable_multistream_mla = \
ascend_config.torchair_graph_config.enable_multistream_mla

# Adapt torchair graph mode to spec decoding.
speculative_config = get_current_vllm_config().speculative_config
@@ -604,7 +604,7 @@ def _v_up_proj_and_o_proj(self, x):
x = torch.bmm(x, self.W_UV)
# Convert from (N, B, V) to (B, N * V)
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
return self.o_proj(x)[0]
return self.o_proj(x, is_prefill=False)[0]

# Return `ql_nope`, `q_pe`
def _q_proj_and_k_up_proj(self, x):
@@ -869,12 +869,12 @@ def _forward_prefill(

current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is None:
return self.o_proj(attn_output)[0]
return self.o_proj(attn_output, is_prefill=True)[0]
else:
current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream):
current_ms_metadata.before_comm_event.wait()
return self.o_proj(attn_output)[0]
return self.o_proj(attn_output, is_prefill=True)[0]

def exec_kv(
self,
@@ -883,6 +883,7 @@ def exec_kv(
sin: torch.Tensor,
kv_cache: Tuple,
slots: torch.Tensor,
enable_multistream_mla: bool = False,
):

B = hidden_states.shape[0]
@@ -894,7 +895,7 @@
cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
with npu_stream_switch("mla_secondary",
0,
enabled=self.enable_multistream_mla):
enabled=enable_multistream_mla):
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
kv,
self.kv_a_layernorm.weight,
@@ -1066,6 +1067,7 @@ def forward(
kv_cache: Tuple[torch.Tensor],
attn_metadata: M,
output: Optional[torch.Tensor] = None,
enable_multistream_mla: bool = False,
) -> torch.Tensor:
assert output is not None, "Output tensor must be provided."
if attn_metadata is None:
@@ -1127,22 +1129,22 @@
# KvRmsNormRopeCache and SingleRope.
npu_wait_tensor(decode_hs_or_q_c,
cos,
enabled=self.enable_multistream_mla)
enabled=enable_multistream_mla)
npu_wait_tensor(decode_hs_or_q_c,
sin,
enabled=self.enable_multistream_mla)
enabled=enable_multistream_mla)
decode_ql_nope, decode_q_pe = \
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
if self.running_in_graph:
decode_k_pe, decode_k_nope = self.exec_kv(
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
attn_metadata.slot_mapping)
attn_metadata.slot_mapping, enable_multistream_mla)
with npu_stream_switch("mla_secondary",
0,
enabled=self.enable_multistream_mla):
enabled=enable_multistream_mla):
npu_wait_tensor(decode_q_pe,
decode_k_pe,
enabled=self.enable_multistream_mla)
enabled=enable_multistream_mla)
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
else:
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
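
Taken together, the mla_v1.py changes thread `enable_multistream_mla` through `forward` and `exec_kv` as an explicit argument rather than reading it from the ascend config in `__init__`, and they gate the `npu_stream_switch` / `npu_wait_tensor` calls on that flag. The sketch below illustrates the general flag-gated secondary-stream pattern in plain PyTorch; it uses CUDA streams and events as a stand-in for the Ascend NPU helpers, and every function name here is illustrative, not the vllm-ascend API.

from contextlib import nullcontext

import torch


def maybe_secondary_stream(stream: torch.cuda.Stream, enabled: bool):
    # No-op context when the multistream flag is off, so all work stays on
    # the default stream and no extra synchronization is needed.
    return torch.cuda.stream(stream) if enabled else nullcontext()


def attention_step(q_input: torch.Tensor,
                   kv_input: torch.Tensor,
                   w_q: torch.Tensor,
                   w_kv: torch.Tensor,
                   enable_multistream: bool = False) -> torch.Tensor:
    secondary = torch.cuda.Stream()
    kv_ready = torch.cuda.Event()

    if enable_multistream:
        # Make inputs produced on the default stream visible to the
        # secondary stream before it starts consuming them.
        secondary.wait_stream(torch.cuda.current_stream())

    # KV-side work runs on the secondary stream when enabled, mirroring how
    # exec_kv wraps npu_kv_rmsnorm_rope_cache in
    # npu_stream_switch("mla_secondary", 0, enabled=enable_multistream_mla).
    with maybe_secondary_stream(secondary, enable_multistream):
        k = kv_input @ w_kv
        kv_ready.record()

    # The q-side projection proceeds on the default stream in parallel.
    q = q_input @ w_q

    # Re-synchronize before the two results are combined; this is the role
    # npu_wait_tensor plays in the real code.
    if enable_multistream:
        torch.cuda.current_stream().wait_event(kv_ready)
    return q @ k.transpose(-1, -2)

With the flag off, the helper degrades to a null context and no events are waited on, so the code path is identical to the single-stream version.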