
Commit 3191183

[BugFix]dbo support torchair graph in decode (#1420)
### What this PR does / why we need it?
Lets DBO support the torchair graph mode in decode, making it possible to set `"torchair_graph_config": {"enabled": True}` when DBO mode is in use.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
Added a unit test case in `tests/multicard/test_torchair_graph_mode.py`.

Signed-off-by: shikang-hangzhou <459956190@qq.com>
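For reference, combining the two features looks roughly like this. A minimal sketch, assuming vllm-ascend's usual `additional_config` entry point and the `VLLM_ASCEND_ENABLE_DBO` environment variable used by the test below; the model name and token budget are illustrative:

```python
import os

from vllm import LLM, SamplingParams

# Assumption: DBO is toggled by this environment variable (as in the test
# below) and must be set before the engine is created.
os.environ["VLLM_ASCEND_ENABLE_DBO"] = "1"

# Torchair graph mode goes through vllm-ascend's additional_config, matching
# the snippet quoted above.
llm = LLM(
    model="deepseek-ai/DeepSeek-V3",  # illustrative model choice
    additional_config={"torchair_graph_config": {"enabled": True}},
)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))
```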
1 parent: bc546a9

File tree: 3 files changed, +72 −35 lines


tests/multicard/test_torchair_graph_mode.py

Lines changed: 4 additions & 1 deletion
@@ -30,10 +30,13 @@
 
 @pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
                     reason="torchair graph is not supported on v0")
-def test_e2e_deepseekv3_with_torchair(monkeypatch: pytest.MonkeyPatch):
+@pytest.mark.parametrize("VLLM_ASCEND_ENABLE_DBO", ["0", "1"])
+def test_e2e_deepseekv3_with_torchair(monkeypatch: pytest.MonkeyPatch,
+                                      VLLM_ASCEND_ENABLE_DBO):
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_MODELSCOPE", "True")
         m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
+        m.setenv("VLLM_ASCEND_ENABLE_DBO", VLLM_ASCEND_ENABLE_DBO)
 
         example_prompts = [
             "Hello, my name is",

vllm_ascend/models/deepseek_dbo.py

Lines changed: 54 additions & 27 deletions
@@ -43,7 +43,8 @@
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                ReplicatedLinear,
-                                               RowParallelLinear)
+                                               RowParallelLinear,
+                                               UnquantizedLinearMethod)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -75,31 +76,56 @@
                                              MultiStreamStepMetadata,
                                              make_multistream_metadata_ds)
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
+from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
 from vllm_ascend.utils import dispose_tensor
 
 VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO
 
 
 class CustomDeepseekDBOMLP(CustomDeepseekV2MLP):
 
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        reduce_results: bool = True,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(hidden_size=hidden_size,
+                         intermediate_size=intermediate_size,
+                         hidden_act=hidden_act,
+                         quant_config=quant_config,
+                         prefix=prefix)
+        self.is_dynamic_quant = not isinstance(
+            self.gate_up_proj.quant_method,
+            UnquantizedLinearMethod) and isinstance(
+                self.gate_up_proj.quant_method.quant_method,
+                AscendW8A8DynamicLinearMethod)
+
     def _forward_ms_mlp(self, x):
         current_ms_metadata = get_multistream_comm_context()
         assert current_ms_metadata is not None
         gate_up, _ = self.gate_up_proj(x)
-        x, dynamic_scale = self.act_fn(gate_up)
-        x = torch_npu.npu_quant_matmul(
-            x,
-            self.down_proj.weight,
-            self.down_proj.weight_scale,
-            pertoken_scale=dynamic_scale,
-            output_dtype=torch.bfloat16,
-        )
-        if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
-            current_ms_metadata.before_comm_event.record()
-            with torch.npu.stream(current_ms_metadata.comm_stream):
-                current_ms_metadata.before_comm_event.wait()
-                x = tensor_model_parallel_all_reduce(x)
-                current_ms_metadata.after_comm_event.record()
+        if self.is_dynamic_quant:
+            x, dynamic_scale = self.act_fn(gate_up)
+            x = torch_npu.npu_quant_matmul(
+                x,
+                self.down_proj.weight,
+                self.down_proj.weight_scale,
+                pertoken_scale=dynamic_scale,
+                output_dtype=torch.bfloat16,
+            )
+            if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
+                current_ms_metadata.before_comm_event.record()
+                with torch.npu.stream(current_ms_metadata.comm_stream):
+                    current_ms_metadata.before_comm_event.wait()
+                    x = tensor_model_parallel_all_reduce(x)
+                    current_ms_metadata.after_comm_event.record()
+        else:
+            x = self.act_fn(gate_up)
+            x, _ = self.down_proj(x)
         return x
 
 
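The substantive change above is that `_forward_ms_mlp` no longer assumes W8A8 dynamic quantization: the new `__init__` probes `gate_up_proj.quant_method` once, caches `is_dynamic_quant`, and the forward falls back to the plain `act_fn` + `down_proj` path otherwise. A minimal, self-contained sketch of the same detection pattern, using stand-in classes since the real vLLM layer types are heavy to construct:

```python
class UnquantizedLinearMethod:           # stand-in for the vLLM class
    pass


class AscendW8A8DynamicLinearMethod:     # stand-in for the vllm-ascend class
    pass


class QuantWrapper:
    """Stand-in for the outer quant-method object a quantized layer carries."""

    def __init__(self, inner):
        self.quant_method = inner


def is_dynamic_quant(layer_quant_method) -> bool:
    # Mirrors the check in CustomDeepseekDBOMLP.__init__: the layer must be
    # quantized at all, and its inner method must be the Ascend W8A8 dynamic one.
    return (not isinstance(layer_quant_method, UnquantizedLinearMethod)
            and isinstance(getattr(layer_quant_method, "quant_method", None),
                           AscendW8A8DynamicLinearMethod))


assert is_dynamic_quant(QuantWrapper(AscendW8A8DynamicLinearMethod()))
assert not is_dynamic_quant(UnquantizedLinearMethod())
```

Caching the result at construction time keeps the per-token forward free of `isinstance` checks.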

@@ -796,6 +822,7 @@ def forward(
         attn_metadata: Optional[AttentionMetadata] = None,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
+        graph_enable: Optional[bool] = True
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
             if inputs_embeds is not None:
@@ -809,8 +836,9 @@ def forward(
             residual = intermediate_tensors["residual"]
 
         num_normal_layers = (self.first_k_dense_replace
-                             if VLLM_ASCEND_ENABLE_DBO and self.can_run_ms()
-                             else self.end_layer - self.start_layer)
+                             if VLLM_ASCEND_ENABLE_DBO and not graph_enable
+                             and self.can_run_ms() else self.end_layer -
+                             self.start_layer)
 
         moe_start_layer = self.start_layer + num_normal_layers
         for i in range(self.start_layer, min(moe_start_layer, self.end_layer)):
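The effect of the new `graph_enable` flag is easiest to see in isolation: the multistream (DBO) layers are used only when DBO is enabled, the torchair graph is not active for this call, and `can_run_ms()` holds; otherwise every layer takes the normal path. A toy sketch of the split, with illustrative numbers (61 layers, a 3-layer dense prefix):

```python
def split_layers(start_layer: int, end_layer: int, first_k_dense_replace: int,
                 dbo_enabled: bool, graph_enable: bool, can_run_ms: bool):
    """Return (num_normal_layers, moe_start_layer) as in the forward above."""
    if dbo_enabled and not graph_enable and can_run_ms:
        num_normal = first_k_dense_replace    # only the dense prefix runs normally
    else:
        num_normal = end_layer - start_layer  # everything runs the normal path
    return num_normal, start_layer + num_normal


# Decode inside the torchair graph: DBO's multistream layers are skipped.
assert split_layers(0, 61, 3, dbo_enabled=True, graph_enable=True,
                    can_run_ms=True) == (61, 61)
# Prefill outside the graph: layers after the dense prefix run multistream.
assert split_layers(0, 61, 3, dbo_enabled=True, graph_enable=False,
                    can_run_ms=True) == (3, 3)
```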
@@ -847,15 +875,13 @@ def can_run_ms(self):
                 return False
         return True
 
-    def _forward_ms_layers(
-        self,
-        positions: torch.Tensor,
-        hidden_states: torch.Tensor,
-        residual: torch.Tensor,
-        moe_start_layer: int,
-        kv_caches: Optional[List[torch.Tensor]] = None,
-        is_prefill: bool = False,
-    ):
+    def _forward_ms_layers(self,
+                           positions: torch.Tensor,
+                           hidden_states: torch.Tensor,
+                           residual: torch.Tensor,
+                           moe_start_layer: int,
+                           kv_caches: Optional[List[torch.Tensor]] = None,
+                           is_prefill: bool = False):
 
         if moe_start_layer == self.end_layer:
             return hidden_states, residual
@@ -917,8 +943,9 @@ def forward(
         attn_metadata: Optional[AttentionMetadata] = None,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
+        graph_enable: Optional[bool] = True
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata, intermediate_tensors,
-                                   inputs_embeds)
+                                   inputs_embeds, graph_enable)
         return hidden_states

vllm_ascend/worker/model_runner_v1.py

Lines changed: 14 additions & 7 deletions
@@ -1037,6 +1037,8 @@ def _process_reqs(
         if self.torchair_graph_enabled:
             model_kwargs["kv_caches"] = self.kv_caches
             model_kwargs["attn_metadata"] = attn_metadata
+        if envs_ascend.VLLM_ASCEND_ENABLE_DBO and with_prefill:
+            model_kwargs["graph_enable"] = False  # type: ignore
         if self.torchair_graph_enabled and not with_prefill:
             compiled_model = self._get_torchair_lazy_compiled_model(
                 padded_num_tokens_across_dp)
@@ -1045,17 +1047,15 @@
                 positions=positions,
                 intermediate_tensors=intermediate_tensors,
                 inputs_embeds=inputs_embeds,
-                **model_kwargs,
-            )
+                **model_kwargs)
         else:
             assert self.model is not None
             hidden_states = self.model(
                 input_ids=input_ids,
                 positions=positions,
                 intermediate_tensors=intermediate_tensors,
                 inputs_embeds=inputs_embeds,
-                **model_kwargs,
-            )
+                **model_kwargs)
 
         self.maybe_wait_for_kv_save()
         finished_sending, finished_recving = self.get_finished_kv_transfer(
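The runner side of the handshake: with DBO on, a prefill step passes `graph_enable=False` so the model can choose the multistream path, while decode steps running inside the compiled graph rely on the model's `graph_enable=True` default. A condensed sketch of the kwargs assembly, with the runner state passed in as plain arguments:

```python
def build_model_kwargs(torchair_graph_enabled: bool, dbo_enabled: bool,
                       with_prefill: bool, kv_caches, attn_metadata) -> dict:
    """Condensed version of the model_kwargs assembly in _process_reqs."""
    model_kwargs = {}
    if torchair_graph_enabled:
        model_kwargs["kv_caches"] = kv_caches
        model_kwargs["attn_metadata"] = attn_metadata
    if dbo_enabled and with_prefill:
        # Prefill never executes inside the torchair graph, so DBO may engage.
        model_kwargs["graph_enable"] = False
    return model_kwargs
```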
@@ -1586,6 +1586,7 @@ def _dummy_run(
                 num_tokens_across_dp=num_tokens_across_dp,
                 with_prefill=with_prefill,
                 in_profile_run=self.in_profile_run):
+            model_kwargs = {}
             if self.torchair_graph_enabled and not with_prefill:
                 # Only mark static while compiling
                 if is_torchair_compile:
@@ -1603,20 +1604,26 @@ def _dummy_run(
                         torch._dynamo.mark_static(kv[1])
                 compiled_model = self._get_torchair_lazy_compiled_model(
                     num_tokens)
+                model_kwargs["kv_caches"] = self.kv_caches
+                model_kwargs["attn_metadata"] = attn_metadata
+                if envs_ascend.VLLM_ASCEND_ENABLE_DBO:
+                    model_kwargs["graph_enable"] = True  # type: ignore
                 hidden_states = compiled_model(
                     input_ids=input_ids,
                     positions=positions,
                     intermediate_tensors=intermediate_tensors,
                     inputs_embeds=None,
-                    kv_caches=self.kv_caches,
-                    attn_metadata=attn_metadata,
+                    **model_kwargs,
                 )
             else:
+                if envs_ascend.VLLM_ASCEND_ENABLE_DBO:
+                    model_kwargs["graph_enable"] = False  # type: ignore
                 hidden_states = model(
                     input_ids=input_ids,
                     positions=positions,
                     intermediate_tensors=intermediate_tensors,
-                    inputs_embeds=inputs_embeds)
+                    inputs_embeds=inputs_embeds,
+                    **model_kwargs)
         return hidden_states
 
     @contextmanager
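`_dummy_run` mirrors the same convention during warm-up: the compiled branch advertises `graph_enable=True`, the eager branch `graph_enable=False`, so a DBO model sees the same flags during graph capture and profiling as it will at serving time. The flag decision in isolation, under the same assumptions as the sketch above:

```python
def dummy_run_graph_flag(torchair_graph_enabled: bool, with_prefill: bool,
                         dbo_enabled: bool) -> dict:
    """graph_enable entry passed by _dummy_run, condensed."""
    model_kwargs: dict = {}
    if not dbo_enabled:
        return model_kwargs  # the flag is only plumbed through for DBO
    if torchair_graph_enabled and not with_prefill:
        model_kwargs["graph_enable"] = True   # warming the compiled graph
    else:
        model_kwargs["graph_enable"] = False  # eager warm-up / profiling
    return model_kwargs
```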
