Skip to content

Commit 5a1689f

Browse files
authored
[Fix] Fix update_aclgraph_sizes when running MoE models (#913)
### What this PR does / why we need it?
Fixes `update_aclgraph_sizes` when running MoE models.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
1 parent 3442fbd commit 5a1689f

File tree

5 files changed

+47
-35
lines changed

5 files changed

+47
-35
lines changed

vllm_ascend/distributed/parallel_state.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,17 @@ def get_etp_group() -> GroupCoordinator:
2222

2323

2424
def init_ascend_model_parallel(
25-
tensor_model_parallel_size: int = 1,
26-
pipeline_model_parallel_size: int = 1,
25+
expert_parallel_size: int = 1,
2726
expert_tensor_parallel_size: int = 1,
27+
world_size: Optional[int] = None,
2828
backend: Optional[str] = None,
2929
):
3030
assert torch.distributed.is_initialized()
31-
world_size: int = torch.distributed.get_world_size()
31+
world_size = world_size or torch.distributed.get_world_size()
3232
backend = backend or torch.distributed.get_backend(
3333
get_world_group().device_group)
34-
num_expert_parallel_groups: int = expert_tensor_parallel_size
35-
num_expert_tensor_parallel_groups: int = (world_size //
36-
expert_tensor_parallel_size)
34+
num_expert_parallel_groups = expert_tensor_parallel_size
35+
num_expert_tensor_parallel_groups = expert_parallel_size
3736

3837
global _EP
3938
group_ranks = []

vllm_ascend/platform.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,26 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
119119
from vllm.config import CompilationLevel # noqa: E402
120120
compilation_config = vllm_config.compilation_config
121121
model_config = vllm_config.model_config
122+
additional_config = vllm_config.additional_config
123+
parallel_config = vllm_config.parallel_config
124+
cache_config = vllm_config.cache_config
125+
126+
if parallel_config:
127+
# Default value for expert tensor parallel size
128+
parallel_config.expert_tensor_parallel_size = parallel_config.tensor_parallel_size
129+
130+
# NOTE: When enable_expert_parallel is True, we follow vLLM convention:
131+
# ep_size = world_size, which means expert_tensor_parallel_size must be 1
132+
if (additional_config
133+
and "expert_tensor_parallel_size" in additional_config
134+
and not parallel_config.enable_expert_parallel):
135+
parallel_config.expert_tensor_parallel_size = int(
136+
additional_config["expert_tensor_parallel_size"])
137+
138+
# Calculate expert parallel size based on world size
139+
parallel_config.expert_parallel_size = (
140+
parallel_config.world_size //
141+
parallel_config.expert_tensor_parallel_size)
122142

123143
if model_config is None:
124144
logger.warning("Model config is missing. This may indicate "
@@ -127,9 +147,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
127147
else:
128148
enforce_eager = getattr(model_config, "enforce_eager", False)
129149

130-
if vllm_config.additional_config is not None:
131-
enable_graph_mode = vllm_config.additional_config.get(
132-
"enable_graph_mode", False)
150+
if additional_config is not None:
151+
enable_graph_mode = additional_config.get("enable_graph_mode",
152+
False)
133153
if enable_graph_mode:
134154
if enforce_eager:
135155
raise RuntimeError(
@@ -139,7 +159,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
139159
logger.warning(
140160
"NPU graph mode is still experimental and not supported for V1 without mla currently, "
141161
"it has been disabled automatically.")
142-
vllm_config.additional_config["enable_graph_mode"] = False
162+
additional_config["enable_graph_mode"] = False
143163
if model_config:
144164
model_type = model_config.hf_config.model_type
145165
if "deepseek" not in model_type:
@@ -178,7 +198,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
178198
["vllm.unified_ascend_attention_with_output"])
179199
update_aclgraph_sizes(vllm_config)
180200

181-
parallel_config = vllm_config.parallel_config
182201
if parallel_config and parallel_config.worker_cls == "auto":
183202
if envs.VLLM_USE_V1:
184203
parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"
@@ -190,7 +209,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
190209
else:
191210
parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker"
192211

193-
cache_config = vllm_config.cache_config
194212
if cache_config:
195213
if cache_config.block_size is None:
196214
cache_config.block_size = 128
@@ -202,11 +220,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
202220

203221
if envs.VLLM_USE_V1:
204222
# Activate custom ops for v1.
205-
vllm_config.compilation_config.custom_ops = ["all"]
223+
compilation_config.custom_ops = ["all"]
206224
# If ascend_scheduler_config exists in additional_config,
207225
# extents original scheduler_config to use AscendScheduler.
208226

209-
additional_config = vllm_config.additional_config
210227
if additional_config and additional_config.get(
211228
"ascend_scheduler_config", None) is not None:
212229
additional_scheduler_config = additional_config.get(

vllm_ascend/utils.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,14 +126,16 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
126126
original_sizes, compilation_config.cudagraph_capture_sizes = \
127127
compilation_config.cudagraph_capture_sizes, None
128128

129-
# Calculate parallel configuration factor (increases with DP or TP)
130-
# TODO(Yizhou): This is a temporary solution, need to be improved
131-
# in the future, taking into account the other parallel configurations.
129+
# Calculate parallel configuration factor
132130
num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
133131
parallel_config = vllm_config.parallel_config
132+
133+
# TODO: Find out whether we need to take into account the pp_size
134134
parallel_factor = 1 + sum(size > 1 for size in [
135-
parallel_config.data_parallel_size,
136-
parallel_config.tensor_parallel_size
135+
parallel_config.data_parallel_size_local,
136+
parallel_config.tensor_parallel_size,
137+
parallel_config.expert_parallel_size,
138+
parallel_config.expert_tensor_parallel_size,
137139
])
138140

139141
# Calculate maximum supported batch sizes considering model architecture

vllm_ascend/worker/worker.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -534,21 +534,18 @@ def _init_worker_distributed_environment(
534534
backend: str = "hccl") -> None:
535535
"""Initialize the distributed environment."""
536536
parallel_config = self.parallel_config
537-
additional_config = self.vllm_config.additional_config
538537
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
539538
init_distributed_environment(parallel_config.world_size, rank,
540539
distributed_init_method, local_rank,
541540
backend)
542541
ensure_model_parallel_initialized(
543542
parallel_config.tensor_parallel_size,
544543
parallel_config.pipeline_parallel_size)
545-
expert_tensor_parallel_size = 1
546-
if additional_config:
547-
expert_tensor_parallel_size = additional_config.get(
548-
"expert_tensor_parallel_size", 1)
549-
init_ascend_model_parallel(parallel_config.tensor_parallel_size,
550-
parallel_config.pipeline_parallel_size,
551-
expert_tensor_parallel_size)
544+
init_ascend_model_parallel(
545+
parallel_config.expert_parallel_size,
546+
parallel_config.expert_tensor_parallel_size,
547+
parallel_config.world_size,
548+
)
552549
ensure_kv_transfer_initialized(vllm_config)
553550

554551

vllm_ascend/worker/worker_v1.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,6 @@ def execute_dummy_batch(self) -> None:
234234

235235
def _init_worker_distributed_environment(self) -> None:
236236
"""Initialize the distributed environment."""
237-
additional_config = self.vllm_config.additional_config
238237
parallel_config = self.vllm_config.parallel_config
239238
set_custom_all_reduce(
240239
not self.parallel_config.disable_custom_all_reduce)
@@ -244,13 +243,11 @@ def _init_worker_distributed_environment(self) -> None:
244243
ensure_model_parallel_initialized(
245244
self.parallel_config.tensor_parallel_size,
246245
self.parallel_config.pipeline_parallel_size)
247-
expert_tensor_parallel_size = 1
248-
if additional_config is not None and "expert_tensor_parallel_size" in additional_config:
249-
expert_tensor_parallel_size = int(
250-
additional_config["expert_tensor_parallel_size"])
251-
init_ascend_model_parallel(parallel_config.tensor_parallel_size,
252-
parallel_config.pipeline_parallel_size,
253-
expert_tensor_parallel_size)
246+
init_ascend_model_parallel(
247+
parallel_config.expert_parallel_size,
248+
parallel_config.expert_tensor_parallel_size,
249+
parallel_config.world_size,
250+
)
254251
ensure_kv_transfer_initialized(self.vllm_config)
255252

256253
def _init_profiler(self):

0 commit comments

Comments (0)