Commit d785e78

Authored by Yizhou Liu

[V1] Make V1 engine backward compatible (#637)
### What this PR does / why we need it?

Enforce eager mode in the V1 engine ahead of the upcoming CANN and torch_npu releases.

### Does this PR introduce _any_ user-facing change?

After this change, users no longer need to manually set enforce_eager=True.

### How was this patch tested?

Tested with regular offline inference examples.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
1 parent bd70ce8 commit d785e78
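
As a quick illustration of the user-facing change described above, the snippet below is a minimal offline-inference sketch, not part of this commit: the model name and the VLLM_USE_V1 toggle are illustrative assumptions, and the point is simply that enforce_eager no longer has to be passed when running the V1 engine on Ascend NPU.

```python
# Minimal sketch, assuming vllm-ascend is installed on an NPU host and that
# VLLM_USE_V1=1 selects the V1 engine; the model name is an arbitrary example.
import os

os.environ["VLLM_USE_V1"] = "1"

from vllm import LLM, SamplingParams

prompts = ["Hello, my name is"]
sampling_params = SamplingParams(temperature=0.0, max_tokens=32)

# Note: no enforce_eager=True here -- the Ascend platform plugin now forces
# eager mode itself until CANN/torch_npu add NPU graph compilation support.
llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")

for output in llm.generate(prompts, sampling_params):
    print(output.outputs[0].text)
```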

4 files changed: +44 −47 lines changed

tests/multicard/test_offline_inference_distributed.py

Lines changed: 0 additions & 1 deletion
@@ -47,7 +47,6 @@ def test_models_distributed(model: str,
             dtype=dtype,
             tensor_parallel_size=4,
             distributed_executor_backend=distributed_executor_backend,
-            enforce_eager=True,
     ) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)
 

tests/ops/test_fused_moe.py

Lines changed: 23 additions & 29 deletions
@@ -22,7 +22,6 @@
 
 import pytest
 import torch
-from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
 
 from vllm_ascend.ops.fused_moe import fused_experts
@@ -68,36 +67,31 @@ def test_fused_experts(
     dtype: torch.dtype,
     device: str,
 ):
-    vllm_config = VllmConfig()
-    with set_current_vllm_config(vllm_config):
-        a = torch.randn((m, k), device=device, dtype=dtype) / 10
-        w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10
-        w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10
+    a = torch.randn((m, k), device=device, dtype=dtype) / 10
+    w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10
+    w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10
 
-        score = torch.randn((m, e), device=device, dtype=dtype)
+    score = torch.randn((m, e), device=device, dtype=dtype)
 
-        if ep_size > 1:
-            local_e = e // ep_size
-            e_ids = torch.randint(0,
-                                  e, (local_e, ),
-                                  device=device,
-                                  dtype=torch.int32)
-            e_map = torch.full((e, ), -1, device=device, dtype=torch.int32)
-            e_map[e_ids] = torch.arange(local_e,
-                                        device=device,
-                                        dtype=torch.int32)
-            w1 = w1[e_ids]
-            w2 = w2[e_ids]
-        else:
-            e_map = None
+    if ep_size > 1:
+        local_e = e // ep_size
+        e_ids = torch.randint(0,
+                              e, (local_e, ),
+                              device=device,
+                              dtype=torch.int32)
+        e_map = torch.full((e, ), -1, device=device, dtype=torch.int32)
+        e_map[e_ids] = torch.arange(local_e, device=device, dtype=torch.int32)
+        w1 = w1[e_ids]
+        w2 = w2[e_ids]
+    else:
+        e_map = None
 
-        score = torch.softmax(score, dim=-1, dtype=dtype)
-        topk_weights, topk_ids = torch.topk(score, topk)
-        topk_ids = topk_ids.to(torch.int32)
+    score = torch.softmax(score, dim=-1, dtype=dtype)
+    topk_weights, topk_ids = torch.topk(score, topk)
+    topk_ids = topk_ids.to(torch.int32)
 
-        output = fused_experts(a, w1, w2, topk_weights, topk_ids, topk, e_map)
-        torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk,
-                                 e_map)
-        # TODO: The native params are: atol=2e-2, rtol=0, maybe related to the nan problem
-        torch.testing.assert_close(output, torch_output, atol=4e-2, rtol=1)
+    output = fused_experts(a, w1, w2, topk_weights, topk_ids, topk, e_map)
+    torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk, e_map)
+    # TODO: The native params are: atol=2e-2, rtol=0, maybe related to the nan problem
+    torch.testing.assert_close(output, torch_output, atol=4e-2, rtol=1)
 
     torch.npu.empty_cache()

tests/singlecard/test_offline_inference.py

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ def test_models(model: str, dtype: str, max_tokens: int) -> None:
     with VllmRunner(model,
                     max_model_len=8192,
                     dtype=dtype,
-                    enforce_eager=True,
+                    enforce_eager=False,
                     gpu_memory_utilization=0.7) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)
 

vllm_ascend/platform.py

Lines changed: 20 additions & 16 deletions
@@ -115,29 +115,33 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         from vllm.config import CompilationLevel  # noqa: E402
         compilation_config = vllm_config.compilation_config
 
-        enforce_eager_flag = False
-        # Check whether the eager mode is configured
-        try:
-            enforce_eager_flag = vllm_config.model_config.enforce_eager
-        except Exception:
-            logger.warning(
-                "There is currently no enforce_eager mode configured, the default value of enforce_eager=False is used"
-            )
-
-        if enforce_eager_flag or compilation_config.level == CompilationLevel.NO_COMPILATION:
-            logger.warning(
-                "Compilation level PIECEWISE is not enable on NPU now, current compilation level to NO_COMPILATION"
-            )
+        if vllm_config.model_config is None:
+            logger.warning("Model config is missing. This may indicate "
+                           "that we are running a test case")
+            enforce_eager = False
+        else:
+            enforce_eager = getattr(vllm_config.model_config, "enforce_eager",
+                                    False)
+
+        # TODO(Yizhou): Override the value of enforce_eager to True before
+        # the CANN and torch_npu support NPU compilation.
+        enforce_eager = True
+        logger.warning(
+            "NPU compilation support pending. Will be available in future CANN and "
+            "torch_npu releases. Using default: enforce_eager=True")
+
+        if enforce_eager or compilation_config.level == CompilationLevel.NO_COMPILATION:
+            logger.info("Compilation disabled, using eager mode by default")
             compilation_config.level = CompilationLevel.NO_COMPILATION
         elif compilation_config.level != CompilationLevel.PIECEWISE:
             logger.warning(
-                "Compilation level %s is not enable on NPU now, forcing compilation level to NO_COMPILATION",
+                "NPU does not support %s compilation level. Setting level to NO_COMPILATION",
                 compilation_config.level)
             compilation_config.level = CompilationLevel.NO_COMPILATION
         else:
             logger.info(
-                "Compilation level PIECEWISE is enable on NPU now, But use_inductor is no support, only use npu_graph now"
-            )
+                "PIECEWISE compilation enabled on NPU. use_inductor not supported - "
+                "using only ACL Graph mode")
             compilation_config.use_inductor = False
             compilation_config.splitting_ops.extend(
                 ["vllm.unified_ascend_attention_with_output"])
