Skip to content

Commit 9e46a00

Browse files
authored
[0.9.1][Config] Add extra checking to torchair_graph_config. (#1675)
This PR adds validation checks to torchair_graph_config for better reliability. Signed-off-by: whx-sjtu <2952154980@qq.com>
1 parent 57664f0 commit 9e46a00

File tree

4 files changed

+95
-22
lines changed

4 files changed

+95
-22
lines changed

tests/singlecard/test_ascend_config.py

Lines changed: 64 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,12 @@ def test_run_with_ascend_config():
5454
# torchair graph only works with deepseek. The e2e test should be added
5555
# in multicard test with deepseek models.
5656
"enabled": False,
57-
"use_cached_graph": True,
58-
"graph_batch_sizes": [1, 2, 4, 8],
57+
"use_cached_graph": False,
58+
"graph_batch_sizes": [],
5959
"graph_batch_sizes_init": False,
60-
"enable_multistream_moe": True,
61-
"enable_multistream_mla": True,
60+
"enable_multistream_moe": False,
61+
"enable_multistream_mla": False,
62+
"enable_view_optimize": False,
6263
},
6364
"ascend_scheduler_config": {
6465
"enabled": True,
@@ -73,13 +74,12 @@ def test_run_with_ascend_config():
7374
ascend_config = get_ascend_config()
7475

7576
assert not ascend_config.torchair_graph_config.enabled
76-
assert ascend_config.torchair_graph_config.use_cached_graph
77-
assert ascend_config.torchair_graph_config.graph_batch_sizes == [
78-
1, 2, 4, 8
79-
]
77+
assert not ascend_config.torchair_graph_config.use_cached_graph
78+
assert ascend_config.torchair_graph_config.graph_batch_sizes == []
8079
assert not ascend_config.torchair_graph_config.graph_batch_sizes_init
81-
assert ascend_config.torchair_graph_config.enable_multistream_mla
82-
assert ascend_config.torchair_graph_config.enable_multistream_moe
80+
assert not ascend_config.torchair_graph_config.enable_multistream_mla
81+
assert not ascend_config.torchair_graph_config.enable_multistream_moe
82+
assert not ascend_config.torchair_graph_config.enable_view_optimize
8383
assert ascend_config.ascend_scheduler_config.enabled
8484
assert ascend_config.ascend_scheduler_config.enable_chunked_prefill
8585

@@ -142,6 +142,58 @@ def test_ascend_config_load_error():
142142
additional_config=input_additional_config_fake_3):
143143
pass
144144

145+
# use_cached_graph should not be enabled without torchair graph mode
146+
with pytest.raises(RuntimeError):
147+
input_additional_config_fake_4 = {
148+
"torchair_graph_config": {
149+
"enabled": False,
150+
"use_cached_graph": True,
151+
},
152+
}
153+
with VllmRunner("facebook/opt-125m",
154+
enforce_eager=True,
155+
additional_config=input_additional_config_fake_4):
156+
pass
157+
158+
# graph_batch_sizes_init should not be enabled without torchair graph mode
159+
with pytest.raises(RuntimeError):
160+
input_additional_config_fake_5 = {
161+
"torchair_graph_config": {
162+
"enabled": False,
163+
"graph_batch_sizes_init": True,
164+
},
165+
}
166+
with VllmRunner("facebook/opt-125m",
167+
enforce_eager=True,
168+
additional_config=input_additional_config_fake_5):
169+
pass
170+
171+
# enable_multistream_mla should not be enabled without torchair graph mode
172+
with pytest.raises(RuntimeError):
173+
input_additional_config_fake_6 = {
174+
"torchair_graph_config": {
175+
"enabled": False,
176+
"enable_multistream_mla": True,
177+
},
178+
}
179+
with VllmRunner("facebook/opt-125m",
180+
enforce_eager=True,
181+
additional_config=input_additional_config_fake_6):
182+
pass
183+
184+
# enable_multistream_moe should not be enabled without torchair graph mode
185+
with pytest.raises(RuntimeError):
186+
input_additional_config_fake_7 = {
187+
"torchair_graph_config": {
188+
"enabled": False,
189+
"enable_multistream_moe": True,
190+
},
191+
}
192+
with VllmRunner("facebook/opt-125m",
193+
enforce_eager=True,
194+
additional_config=input_additional_config_fake_7):
195+
pass
196+
145197

146198
@_clean_up_ascend_config
147199
def test_check_ascend_config_v0():
@@ -168,9 +220,7 @@ def test_ascend_config_refresh():
168220
input_additional_config = {
169221
"torchair_graph_config": {
170222
"enabled": False,
171-
"use_cached_graph": True,
172-
"graph_batch_sizes": [1, 2, 4, 8],
173-
"graph_batch_sizes_init": False,
223+
"enable_view_optimize": False
174224
},
175225
"refresh": True,
176226
}
@@ -180,9 +230,4 @@ def test_ascend_config_refresh():
180230
additional_config=input_additional_config):
181231
ascend_config = get_ascend_config()
182232

183-
assert not ascend_config.torchair_graph_config.enabled
184-
assert ascend_config.torchair_graph_config.use_cached_graph
185-
assert ascend_config.torchair_graph_config.graph_batch_sizes == [
186-
1, 2, 4, 8
187-
]
188-
assert not ascend_config.torchair_graph_config.graph_batch_sizes_init
233+
assert not ascend_config.torchair_graph_config.enable_view_optimize

vllm_ascend/ascend_config.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,31 @@ def __init__(self, torchair_graph_config):
7070
raise ValueError(
7171
"graph_batch_sizes_init is only valid when graph_batch_sizes is empty"
7272
)
73+
if not self.enabled:
74+
if self.use_cached_graph:
75+
raise RuntimeError(
76+
"use_cached_graph is valid only when Torchair graph mode is enabled"
77+
)
78+
if self.graph_batch_sizes:
79+
raise RuntimeError(
80+
"graph_batch_sizes is valid only when Torchair graph mode is enabled"
81+
)
82+
if self.graph_batch_sizes_init:
83+
raise RuntimeError(
84+
"graph_batch_sizes_init is valid only when Torchair graph mode is enabled"
85+
)
86+
if self.enable_multistream_mla:
87+
raise RuntimeError(
88+
"enable_multistream_mla is valid only when Torchair graph mode is enabled"
89+
)
90+
if self.enable_multistream_moe:
91+
raise RuntimeError(
92+
"enable_multistream_moe is valid only when Torchair graph mode is enabled"
93+
)
94+
if self.enable_kv_nz:
95+
raise RuntimeError(
96+
"enable_kv_nz is valid only when Torchair graph mode is enabled"
97+
)
7398

7499

75100
class AscendSchedulerConfig:

vllm_ascend/models/deepseek_v2.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,8 @@ def __init__(
236236
ascend_config = get_ascend_config()
237237
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
238238
self.enable_multistream_moe = \
239-
ascend_config.torchair_graph_config.enable_multistream_moe
239+
ascend_config.torchair_graph_config.enable_multistream_moe and \
240+
self.torchair_graph_enabled
240241

241242
self.gate = ReplicatedLinear(config.hidden_size,
242243
config.n_routed_experts,
@@ -462,7 +463,8 @@ def __init__(
462463
ascend_config = get_ascend_config()
463464
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
464465
self.enable_multistream_mla = \
465-
ascend_config.torchair_graph_config.enable_multistream_mla
466+
ascend_config.torchair_graph_config.enable_multistream_mla and \
467+
self.torchair_graph_enabled
466468

467469
def forward(
468470
self,

vllm_ascend/ops/fused_moe.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1090,7 +1090,8 @@ def __init__(
10901090

10911091
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
10921092
self.enable_multistream_moe = \
1093-
ascend_config.torchair_graph_config.enable_multistream_moe
1093+
ascend_config.torchair_graph_config.enable_multistream_moe and \
1094+
self.torchair_graph_enabled
10941095

10951096
if self.scoring_func != "softmax" and not self.use_grouped_topk:
10961097
raise ValueError("Only softmax scoring function is supported for "

0 commit comments

Comments (0)