Commit 5956ef0

Author: weijinqian_v1

handle code clean

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>

Merge commit 5956ef0 (2 parents: adf3f74 + 2351977)

File tree: 9 files changed (+307, -36 lines)

tests/singlecard/test_ascend_config.py

Lines changed: 64 additions & 19 deletions
@@ -54,11 +54,12 @@ def test_run_with_ascend_config():
             # torchair graph only works with deepseek. The e2e test should be added
             # in multicard test with deepseek models.
             "enabled": False,
-            "use_cached_graph": True,
-            "graph_batch_sizes": [1, 2, 4, 8],
+            "use_cached_graph": False,
+            "graph_batch_sizes": [],
             "graph_batch_sizes_init": False,
-            "enable_multistream_moe": True,
-            "enable_multistream_mla": True,
+            "enable_multistream_moe": False,
+            "enable_multistream_mla": False,
+            "enable_view_optimize": False,
         },
         "ascend_scheduler_config": {
             "enabled": True,
@@ -73,13 +74,12 @@ def test_run_with_ascend_config():
         ascend_config = get_ascend_config()

         assert not ascend_config.torchair_graph_config.enabled
-        assert ascend_config.torchair_graph_config.use_cached_graph
-        assert ascend_config.torchair_graph_config.graph_batch_sizes == [
-            1, 2, 4, 8
-        ]
+        assert not ascend_config.torchair_graph_config.use_cached_graph
+        assert ascend_config.torchair_graph_config.graph_batch_sizes == []
         assert not ascend_config.torchair_graph_config.graph_batch_sizes_init
-        assert ascend_config.torchair_graph_config.enable_multistream_mla
-        assert ascend_config.torchair_graph_config.enable_multistream_moe
+        assert not ascend_config.torchair_graph_config.enable_multistream_mla
+        assert not ascend_config.torchair_graph_config.enable_multistream_moe
+        assert not ascend_config.torchair_graph_config.enable_view_optimize
         assert ascend_config.ascend_scheduler_config.enabled
         assert ascend_config.ascend_scheduler_config.enable_chunked_prefill

@@ -142,6 +142,58 @@ def test_ascend_config_load_error():
                         additional_config=input_additional_config_fake_3):
            pass

+    # use_cached_graph should not be enabled without torchair graph mode
+    with pytest.raises(RuntimeError):
+        input_additional_config_fake_4 = {
+            "torchair_graph_config": {
+                "enabled": False,
+                "use_cached_graph": True,
+            },
+        }
+        with VllmRunner("facebook/opt-125m",
+                        enforce_eager=True,
+                        additional_config=input_additional_config_fake_4):
+            pass
+
+    # graph_batch_sizes_init should not be enabled without torchair graph mode
+    with pytest.raises(RuntimeError):
+        input_additional_config_fake_5 = {
+            "torchair_graph_config": {
+                "enabled": False,
+                "graph_batch_sizes_init": True,
+            },
+        }
+        with VllmRunner("facebook/opt-125m",
+                        enforce_eager=True,
+                        additional_config=input_additional_config_fake_5):
+            pass
+
+    # enable_multistream_mla should not be enabled without torchair graph mode
+    with pytest.raises(RuntimeError):
+        input_additional_config_fake_6 = {
+            "torchair_graph_config": {
+                "enabled": False,
+                "enable_multistream_mla": True,
+            },
+        }
+        with VllmRunner("facebook/opt-125m",
+                        enforce_eager=True,
+                        additional_config=input_additional_config_fake_6):
+            pass
+
+    # enable_multistream_moe should not be enabled without torchair graph mode
+    with pytest.raises(RuntimeError):
+        input_additional_config_fake_7 = {
+            "torchair_graph_config": {
+                "enabled": False,
+                "enable_multistream_moe": True,
+            },
+        }
+        with VllmRunner("facebook/opt-125m",
+                        enforce_eager=True,
+                        additional_config=input_additional_config_fake_7):
+            pass
+

 @_clean_up_ascend_config
 def test_check_ascend_config_v0():
@@ -168,9 +220,7 @@ def test_ascend_config_refresh():
     input_additional_config = {
         "torchair_graph_config": {
             "enabled": False,
-            "use_cached_graph": True,
-            "graph_batch_sizes": [1, 2, 4, 8],
-            "graph_batch_sizes_init": False,
+            "enable_view_optimize": False
         },
         "refresh": True,
     }
@@ -180,9 +230,4 @@ def test_ascend_config_refresh():
                     additional_config=input_additional_config):
         ascend_config = get_ascend_config()

-        assert not ascend_config.torchair_graph_config.enabled
-        assert ascend_config.torchair_graph_config.use_cached_graph
-        assert ascend_config.torchair_graph_config.graph_batch_sizes == [
-            1, 2, 4, 8
-        ]
-        assert not ascend_config.torchair_graph_config.graph_batch_sizes_init
+        assert not ascend_config.torchair_graph_config.enable_view_optimize
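
By contrast, these options remain legal once torchair graph mode is actually on. A minimal sketch of an additional_config that passes the new validation; the specific batch sizes are illustrative, and only the key names come from this diff:

# Illustrative only: the graph-only options below are accepted because
# "enabled" is True; with "enabled": False each of them now raises RuntimeError.
valid_additional_config = {
    "torchair_graph_config": {
        "enabled": True,
        "use_cached_graph": True,
        "graph_batch_sizes": [1, 2, 4, 8],  # illustrative sizes
        "enable_multistream_mla": True,
        "enable_multistream_moe": True,
    },
    "ascend_scheduler_config": {
        "enabled": True,
    },
}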

vllm_ascend/ascend_config.py

Lines changed: 25 additions & 0 deletions
@@ -70,6 +70,31 @@ def __init__(self, torchair_graph_config):
             raise ValueError(
                 "graph_batch_sizes_init is only valid when graph_batch_sizes is empty"
             )
+        if not self.enabled:
+            if self.use_cached_graph:
+                raise RuntimeError(
+                    "use_cached_graph is valid only when Torchair graph mode is enabled"
+                )
+            if self.graph_batch_sizes:
+                raise RuntimeError(
+                    "graph_batch_sizes is valid only when Torchair graph mode is enabled"
+                )
+            if self.graph_batch_sizes_init:
+                raise RuntimeError(
+                    "graph_batch_sizes_init is valid only when Torchair graph mode is enabled"
+                )
+            if self.enable_multistream_mla:
+                raise RuntimeError(
+                    "enable_multistream_mla is valid only when Torchair graph mode is enabled"
+                )
+            if self.enable_multistream_moe:
+                raise RuntimeError(
+                    "enable_multistream_moe is valid only when Torchair graph mode is enabled"
+                )
+            if self.enable_kv_nz:
+                raise RuntimeError(
+                    "enable_kv_nz is valid only when Torchair graph mode is enabled"
+                )


 class AscendSchedulerConfig:
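
The added block is correct but repetitive; the rule it enforces is simply "no graph-only option may be truthy while graph mode itself is off". A compact, functionally equivalent sketch of that rule (option names are taken from the diff; the loop form is only an illustration, not the merged code):

# Illustrative refactor of the guard above, not the merged implementation.
_GRAPH_ONLY_OPTIONS = (
    "use_cached_graph",
    "graph_batch_sizes",
    "graph_batch_sizes_init",
    "enable_multistream_mla",
    "enable_multistream_moe",
    "enable_kv_nz",
)

def _check_graph_only_options(self) -> None:
    if self.enabled:
        return
    for name in _GRAPH_ONLY_OPTIONS:
        # Truthy check covers both boolean flags and the graph_batch_sizes list.
        if getattr(self, name):
            raise RuntimeError(
                f"{name} is valid only when Torchair graph mode is enabled")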

vllm_ascend/attention/mla_v1.py

Lines changed: 7 additions & 9 deletions
@@ -352,13 +352,13 @@ def build_torchair_graph_dummy(
         else:
             attn_state = AscendAttentionState.DecodeOnly
             num_decode_tokens = 1
-        sin = torch.ones(num_reqs,
+        sin = torch.ones(num_tokens,
                          1,
                          1,
                          self.rope_dim,
                          dtype=self.runner.dtype,
                          device=device)
-        cos = torch.ones(num_reqs,
+        cos = torch.ones(num_tokens,
                          1,
                          1,
                          self.rope_dim,
@@ -547,15 +547,13 @@ def build(
                 actual_seq_q_lens = query_start_loc[1:].tolist(
                 ) + self.runner.actual_seq_q_lens[num_reqs:num_reqs +
                                                   num_reqs_pad_size]
-                cos = self.cos_cache[
-                    input_positions].unsqueeze(  # type: ignore
-                        1).unsqueeze(2)
-                sin = self.sin_cache[
-                    input_positions].unsqueeze(  # type: ignore
-                        1).unsqueeze(2)
             else:
                 seq_lens_list = seq_lens.tolist()
-                cos, sin = None, None
+
+            cos = self.cos_cache[input_positions].unsqueeze(  # type: ignore
+                1).unsqueeze(2)
+            sin = self.sin_cache[input_positions].unsqueeze(  # type: ignore
+                1).unsqueeze(2)
             mc2_mask = self.generate_activate_mask(
                 num_actual_tokens, num_reqs + num_reqs_pad_size)
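
Both hunks follow from the same shape invariant: the rotary cos/sin tensors are per token rather than per request, and indexing the cached table with the flat position ids followed by two unsqueezes produces a [num_tokens, 1, 1, rope_dim] tensor. A small standalone illustration of that indexing pattern (sizes and names here are made up for the example):

import torch

rope_dim, max_pos, num_tokens = 64, 4096, 10          # illustrative sizes
cos_cache = torch.randn(max_pos, rope_dim)             # stand-in for self.cos_cache
input_positions = torch.randint(0, max_pos, (num_tokens,))

# Same gather + unsqueeze pattern as the diff: one row per token,
# then two broadcast dimensions are added.
cos = cos_cache[input_positions].unsqueeze(1).unsqueeze(2)
assert cos.shape == (num_tokens, 1, 1, rope_dim)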

vllm_ascend/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@ def register_model():
         AscendQwen2_5_VLForConditionalGeneration  # noqa: F401
     from .qwen2_vl import AscendQwen2VLForConditionalGeneration  # noqa: F401
     from .moe_block import AscendSparseMoeBlock  # noqa: F401
+    from .qwen3 import CustomQwen3ForCausalLM  # noqa: F401

     ModelRegistry.register_model(
         "DeepSeekMTPModel",

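The import only makes CustomQwen3ForCausalLM visible inside register_model(); presumably it is then registered with ModelRegistry like the other entries in this file. The hunk does not show that call, so the snippet below is an assumption that follows the visible DeepSeekMTPModel pattern:

# Assumed registration call (not shown in this hunk): map the HF architecture
# name to the Ascend-specific implementation via a lazy "module:Class" string.
ModelRegistry.register_model(
    "Qwen3ForCausalLM",
    "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM")
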
vllm_ascend/models/deepseek_v2.py

Lines changed: 4 additions & 2 deletions
@@ -236,7 +236,8 @@ def __init__(
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         self.enable_multistream_moe = \
-            ascend_config.torchair_graph_config.enable_multistream_moe
+            ascend_config.torchair_graph_config.enable_multistream_moe and \
+            self.torchair_graph_enabled

         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.n_routed_experts,
@@ -462,7 +463,8 @@ def __init__(
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         self.enable_multistream_mla = \
-            ascend_config.torchair_graph_config.enable_multistream_mla
+            ascend_config.torchair_graph_config.enable_multistream_mla and \
+            self.torchair_graph_enabled

     def forward(
         self,
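
Both hunks (and the matching change in vllm_ascend/ops/fused_moe.py further down) apply the same defensive pattern: the multistream flag is AND-ed with the graph-mode flag at construction time, so the feature stays off outside torchair graph mode even if a stray config slipped past validation. A minimal sketch of the pattern in isolation (class and argument names are illustrative):

# Illustrative pattern only: derive an "effective" flag that cannot be on
# unless its prerequisite mode is on.
class _Example:
    def __init__(self, torchair_graph_enabled: bool,
                 enable_multistream_moe: bool) -> None:
        self.torchair_graph_enabled = torchair_graph_enabled
        self.enable_multistream_moe = (enable_multistream_moe
                                       and self.torchair_graph_enabled)

assert _Example(False, True).enable_multistream_moe is False
assert _Example(True, True).enable_multistream_moe is True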

vllm_ascend/models/qwen3.py

Lines changed: 156 additions & 0 deletions
@@ -0,0 +1,156 @@ (new file; all 156 lines added)

from collections.abc import Iterable
from typing import Optional, Union

import torch
from torch import nn
from transformers import Qwen3Config
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer
from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                              PPMissingLayer, maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from vllm_ascend.ops.layernorm import AddRMSNormQuant


class CustomQwen3DecoderLayer(Qwen3DecoderLayer):

    def __init__(
        self,
        config: Qwen3Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config=config,
                         cache_config=cache_config,
                         quant_config=quant_config,
                         prefix=prefix)
        if quant_config is None:
            return

        from vllm_ascend.quantization.quant_config import AscendQuantConfig
        from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod

        assert isinstance(quant_config, AscendQuantConfig), \
            "Expected quant_config to be an instance of AscendQuantConfig"

        if isinstance(self.self_attn.qkv_proj.quant_method,
                      AscendW8A8LinearMethod):
            self.input_layernorm = AddRMSNormQuant(
                config.hidden_size,
                layer=self.self_attn.qkv_proj,
                eps=config.rms_norm_eps)
        if isinstance(self.mlp.gate_up_proj.quant_method,
                      AscendW8A8LinearMethod):
            self.post_attention_layernorm = AddRMSNormQuant(
                config.hidden_size,
                layer=self.mlp.gate_up_proj,
                eps=config.rms_norm_eps)


ALL_DECODER_LAYER_TYPES = {
    "attention": CustomQwen3DecoderLayer,
}


@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
        # otherwise (seq_len, ).
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    })
class CustomQwen3Model(Qwen2Model):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         decoder_layer_type=CustomQwen3DecoderLayer)


class CustomQwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    # add `CustomQwen3Model` to init self.model
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config
        self.model = CustomQwen3Model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))

        if get_pp_group().is_last_rank:
            if config.tie_word_embeddings:
                self.lm_head = self.model.embed_tokens
            else:
                self.lm_head = ParallelLMHead(config.vocab_size,
                                              config.hidden_size,
                                              quant_config=quant_config,
                                              prefix=maybe_prefix(
                                                  prefix, "lm_head"))
        else:
            self.lm_head = PPMissingLayer()

        self.logits_processor = LogitsProcessor(config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)
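
The Ascend-specific twist in this file is the norm swap: when the qkv_proj or gate_up_proj layers use the W8A8 linear method, the preceding RMSNorm is replaced by AddRMSNormQuant so that the residual add, normalization, and activation quantization can run as one fused step feeding int8 activations into the following linear. A rough eager-mode sketch of what such a fused step computes; this is an assumption about the semantics for illustration only, and the real AddRMSNormQuant signature and Ascend kernel may differ:

import torch

def add_rmsnorm_quant_reference(x: torch.Tensor,
                                residual: torch.Tensor,
                                weight: torch.Tensor,
                                eps: float,
                                scale: torch.Tensor):
    """Eager-mode sketch of fused add + RMSNorm + per-tensor int8 quantization."""
    hidden = x + residual                                    # residual add
    variance = hidden.pow(2).mean(-1, keepdim=True)
    normed = hidden * torch.rsqrt(variance + eps) * weight   # RMSNorm
    quantized = torch.clamp(torch.round(normed / scale),
                            -128, 127).to(torch.int8)        # int8 activations for W8A8
    return quantized, hidden                                 # new residual stays in float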

vllm_ascend/ops/fused_moe.py

Lines changed: 2 additions & 1 deletion
@@ -1110,7 +1110,8 @@ def __init__(

         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         self.enable_multistream_moe = \
-            ascend_config.torchair_graph_config.enable_multistream_moe
+            ascend_config.torchair_graph_config.enable_multistream_moe and \
+            self.torchair_graph_enabled

         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
