From b32ab4c1727712aa87898b1e07612d6983f5a97a Mon Sep 17 00:00:00 2001
From: Xinyu Chen
Date: Mon, 23 Jun 2025 11:35:16 +0800
Subject: [PATCH 1/3] DP: Optimize combine with ReduceScatter

---
 vllm/forward_context.py                       |  5 ++++-
 vllm/model_executor/layers/fused_moe/layer.py | 14 +++++++++-----
 2 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index 20494e86c62..9fda0ebfcbd 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -35,6 +35,7 @@ class DPMetadata:
     cu_tokens_across_dp_cpu: torch.Tensor
     hidden_states_across_dp: Optional[torch.Tensor] = None
     router_logits_across_dp: Optional[torch.Tensor] = None
+    hidden_states: Optional[torch.Tensor] = None
 
 
 @dataclass
@@ -132,9 +133,11 @@ def set_forward_context(attn_metadata: Any,
             dtype=dtype)
         router_logits_across_dp = torch.empty(
             (batchsize * dp_size, num_experts), device=device, dtype=dtype)
+        hidden_states = torch.empty((batchsize, hidden_size),\
+            device=device, dtype=dtype)
         dp_metadata = DPMetadata(cu_tokens_across_dp_cpu,
                                  hidden_states_across_dp,
-                                 router_logits_across_dp)
+                                 router_logits_across_dp, hidden_states)
     else:
         dp_metadata = DPMetadata(cu_tokens_across_dp_cpu)
 
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index a9dc52c7cca..6b0f0854230 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -974,12 +974,16 @@ def forward_impl(self, hidden_states: torch.Tensor,
             if final_hidden_states.ndim == 3:
                 final_hidden_states = final_hidden_states.view(
                     -1, final_hidden_states.size(2))
-            start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
-                self.dp_rank - 1]
-            end = cu_tokens_across_dp_cpu[self.dp_rank]
-            all_hidden_states = get_dp_group().all_reduce(final_hidden_states)
-            final_hidden_states = all_hidden_states[start:end, :]
+            import habana_frameworks.torch as htorch
+            htorch.core.mark_step()
+            local_hidden_states = get_forward_context(
+            ).dp_metadata.hidden_states
+            torch.distributed.reduce_scatter_tensor(
+                local_hidden_states,
+                final_hidden_states,
+                group=get_dp_group().device_group)
+            final_hidden_states = local_hidden_states
 
         if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
             # Default set to False. (May have to add shared expert outputs.)
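Note on PATCH 1/3: the DP combine previously all-reduced the full (tokens-across-all-DP-ranks, hidden_size) tensor on every rank and then sliced out the local rows; the patch preallocates a per-rank buffer in DPMetadata and uses torch.distributed.reduce_scatter_tensor so each rank only materializes its own slice. The sketch below is a minimal standalone illustration of that change, not the vLLM code: it assumes an already-initialized process group and an identical padded token count on every DP rank (which is what makes the equal reduce-scatter split line up with the per-rank slice), and the function names are made up for the example.

import torch
import torch.distributed as dist


def combine_all_reduce(stacked: torch.Tensor, tokens_per_rank: int,
                       dp_rank: int, group=None) -> torch.Tensor:
    # Old scheme: reduce the full (tokens_per_rank * dp_size, hidden) tensor
    # on every rank, then keep only this rank's rows.
    dist.all_reduce(stacked, group=group)
    start = dp_rank * tokens_per_rank
    return stacked[start:start + tokens_per_rank, :]


def combine_reduce_scatter(stacked: torch.Tensor, tokens_per_rank: int,
                           group=None) -> torch.Tensor:
    # New scheme: reduce-scatter sums across ranks and writes only this
    # rank's (tokens_per_rank, hidden) slice into a preallocated buffer,
    # avoiding the full-size output tensor and the follow-up slicing.
    local = torch.empty((tokens_per_rank, stacked.size(-1)),
                        dtype=stacked.dtype, device=stacked.device)
    dist.reduce_scatter_tensor(local, stacked, group=group)
    return local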
From 60d1df5145933a807a0212e164eda2b50d800a10 Mon Sep 17 00:00:00 2001
From: Xinyu Chen
Date: Mon, 23 Jun 2025 11:53:23 +0800
Subject: [PATCH 2/3] DP: AllGather topk weights and ids after select_experts on each rank

---
 vllm/forward_context.py                       | 30 ++++++++-----------
 .../layers/fused_moe/fused_moe.py             | 11 ++++---
 vllm/model_executor/layers/fused_moe/layer.py | 26 +++++++++++-----
 .../model_executor/layers/quantization/fp8.py | 22 ++++++++++++--
 4 files changed, 57 insertions(+), 32 deletions(-)

diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index 9fda0ebfcbd..d46fb634ccb 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -34,7 +34,8 @@ class DPMetadata:
     cu_tokens_across_dp_cpu: torch.Tensor
     hidden_states_across_dp: Optional[torch.Tensor] = None
-    router_logits_across_dp: Optional[torch.Tensor] = None
+    topk_ids_across_dp: Optional[torch.Tensor] = None
+    topk_weights_across_dp: Optional[torch.Tensor] = None
     hidden_states: Optional[torch.Tensor] = None
 
 
@@ -106,19 +107,11 @@ def set_forward_context(attn_metadata: Any,
         cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
 
         if current_platform.is_hpu():
-            num_expert_names = [
-                "moe_num_experts",  # Dbrx
-                "num_experts",  # Jamba
-                "n_routed_experts",  # DeepSeek
-                "num_local_experts",  # Mixtral
-            ]
-            num_experts = 0
-            for name in num_expert_names:
-                num_experts = getattr(vllm_config.model_config.hf_text_config,
-                                      name, 0)
-                if num_experts > 0:
-                    break
-            assert num_experts > 0, \
+            num_experts_per_tok = 0
+            num_experts_per_tok = getattr(
+                vllm_config.model_config.hf_text_config, "num_experts_per_tok",
+                0)
+            assert num_experts_per_tok > 0, \
                 "No expert found in the model config.\
                 Please check the model config."
 
@@ -131,13 +124,16 @@ def set_forward_context(attn_metadata: Any,
             (request_batch_size * dp_size, padded_seq_length, hidden_size),
             device=device,
             dtype=dtype)
-        router_logits_across_dp = torch.empty(
-            (batchsize * dp_size, num_experts), device=device, dtype=dtype)
+        topk_ids_across_dp = torch.empty((batchsize * dp_size,\
+            num_experts_per_tok), device=device, dtype=torch.int64)
+        topk_weights_across_dp = torch.empty((batchsize * dp_size,\
+            num_experts_per_tok), device=device, dtype=dtype)
         hidden_states = torch.empty((batchsize, hidden_size),\
            device=device, dtype=dtype)
         dp_metadata = DPMetadata(cu_tokens_across_dp_cpu,
                                  hidden_states_across_dp,
-                                 router_logits_across_dp, hidden_states)
+                                 topk_ids_across_dp,
+                                 topk_weights_across_dp, hidden_states)
     else:
         dp_metadata = DPMetadata(cu_tokens_across_dp_cpu)
 
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 8e34a35ebef..e6bf6a79bf0 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -9,6 +9,7 @@
 
 import vllm.envs as envs
 from vllm import _custom_ops as ops
+from vllm.distributed import get_dp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
     _valid_deep_gemm, deep_gemm_moe_fp8)
@@ -858,8 +859,9 @@ def fused_topk(
     topk: int,
     renormalize: bool,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    assert hidden_states.shape[0] == gating_output.shape[0], (
-        "Number of tokens mismatch")
+    if not (get_dp_group().world_size > 1 and current_platform.is_hpu()):
+        assert hidden_states.shape[0] == gating_output.shape[0], (
+            "Number of tokens mismatch")
 
     M, _ = hidden_states.shape
 
@@ -899,8 +901,9 @@ def grouped_topk(
     e_score_correction_bias: Optional[torch.Tensor] = None
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    assert hidden_states.shape[0] == gating_output.shape[0], (
-        "Number of tokens mismatch")
+    if not (get_dp_group().world_size > 1 and current_platform.is_hpu()):
+        assert hidden_states.shape[0] == gating_output.shape[0], (
+            "Number of tokens mismatch")
 
     gating_output = gating_output.float()
     if e_score_correction_bias is not None:
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 6b0f0854230..94eb937f839 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -295,13 +295,28 @@ def forward_hpu(
         topk_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
         topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1)
         topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
-        topk_weights = topk_weights.to(x.dtype)
+        topk_ids = topk_ids.to(torch.int64)
+        topk_weights = topk_weights.to(x.dtype)
+        if layer.dp_size > 1:
+            cu_tokens_across_dp_cpu = get_forward_context(
+            ).dp_metadata.cu_tokens_across_dp_cpu
+
+            topk_ids_across_dp = get_forward_context(
+            ).dp_metadata.topk_ids_across_dp
+            topk_ids = layer.multicast_fn(topk_ids, cu_tokens_across_dp_cpu,
+                                          topk_ids_across_dp)
+
+            topk_weights_across_dp = get_forward_context(
+            ).dp_metadata.topk_weights_across_dp
+            topk_weights = layer.multicast_fn(topk_weights,
+                                              cu_tokens_across_dp_cpu,
+                                              topk_weights_across_dp)
         topk_ids = topk_ids.view(*x.shape[:-1], -1)
         topk_weights = topk_weights.view(*x.shape[:-1], -1)
         return layer.moe_op(
             x,
-            topk_ids.to(torch.int64),
-            topk_weights.to(x.dtype),
+            topk_ids,
+            topk_weights,
             permuted_weights=True,
             activation=activation,
         ).view(*input_shape)
@@ -941,15 +956,10 @@ def forward_impl(self, hidden_states: torch.Tensor,
             ).dp_metadata.cu_tokens_across_dp_cpu
             hidden_states_across_dp = get_forward_context(
             ).dp_metadata.hidden_states_across_dp
-            router_logits_across_dp = get_forward_context(
-            ).dp_metadata.router_logits_across_dp
 
             hidden_states = self.multicast_fn(hidden_states,
                                               cu_tokens_across_dp_cpu,
                                               hidden_states_across_dp)
-            router_logits = self.multicast_fn(router_logits,
-                                              cu_tokens_across_dp_cpu,
-                                              router_logits_across_dp)
 
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 3e6e1c8b783..bf5ba39a0f1 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -11,6 +11,7 @@
 import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                   FusedMoeWeightScaleSupported)
@@ -951,13 +952,28 @@ def forward_hpu(
         topk_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
         topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1)
         topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
-        topk_weights = topk_weights.to(x.dtype)
+        topk_ids = topk_ids.to(torch.int64)
+        topk_weights = topk_weights.to(x.dtype)
+        if layer.dp_size > 1:
+            cu_tokens_across_dp_cpu = get_forward_context(
+            ).dp_metadata.cu_tokens_across_dp_cpu
+
+            topk_ids_across_dp = get_forward_context(
+            ).dp_metadata.topk_ids_across_dp
+            topk_ids = layer.multicast_fn(topk_ids, cu_tokens_across_dp_cpu,
+                                          topk_ids_across_dp)
+
+            topk_weights_across_dp = get_forward_context(
+            ).dp_metadata.topk_weights_across_dp
+            topk_weights = layer.multicast_fn(topk_weights,
+                                              cu_tokens_across_dp_cpu,
+                                              topk_weights_across_dp)
         topk_ids = topk_ids.view(*x.shape[:-1], -1)
         topk_weights = topk_weights.view(*x.shape[:-1], -1)
         output = layer.moe_op(
             x,
-            topk_ids.to(torch.int64),
-            topk_weights.to(x.dtype),
+            topk_ids,
+            topk_weights,
             permuted_weights=True,
             activation=activation,
         )

From 2cc162dfdb54c25f6f445bd10684daa04fc24c47 Mon Sep 17 00:00:00 2001
From: Xinyu Chen
Date: Mon, 23 Jun 2025 12:55:56 +0800
Subject: [PATCH 3/3] DP: Optimize dummy run

---
 vllm/worker/hpu_model_runner.py | 30 +++++++++++++++++++++---------
 1 file changed, 21 insertions(+), 9 deletions(-)

diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index b3b6bb4749d..c593f5af640 100755
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -2481,7 +2481,8 @@ def _dummy_run(self, max_num_batched_tokens: int) -> None:
             num_patches=UNSET_NUM_PATCHES,
             is_lora_profile_run=True,
             num_iters=1,
-            align_worker=True)
+            align_worker=True,
+            is_dummy_run=True)
         return
 
     def _remove_duplicate_submodules(self):
@@ -2509,9 +2510,10 @@ def warmup_scenario(self,
                         temperature=0,
                         num_patches=None,
                         num_iters=3,
-                        align_worker=False) -> None:
+                        align_worker=False,
+                        is_dummy_run=False) -> None:
         phase = 'prompt' if is_prompt else 'decode'
-        use_graphs = self._use_graphs(num_patches)
+        use_graphs = is_dummy_run or self._use_graphs(num_patches)
         scenario_name = ("warmup_"
                          f"{phase}_"
                          f"bs{batch_size}_"
@@ -2569,7 +2571,8 @@ def warmup_scenario(self,
                     temperature=temperature,
                     ctx=ctx) for i, b in enumerate(blocks)
         ]
-        torch.hpu.synchronize()
+        if not is_dummy_run:
+            torch.hpu.synchronize()
         profiler = None
         if is_pt_profiler_run and self.is_driver_worker:
             profiler = setup_profiler()
@@ -2597,7 +2600,8 @@ def warmup_scenario(self,
                     kv_caches,
                     intermediate_tensors=intermediate_tensors,
                     warmup_mode=True,
-                    ctx_blocks=ctx)
+                    ctx_blocks=ctx,
+                    is_dummy_run=is_dummy_run)
             else:  # decode with multi-step
                 inputs = dataclasses.replace(inputs,
                                              is_first_multi_step=True,
@@ -2617,13 +2621,15 @@ def warmup_scenario(self,
                     num_steps=2,
                     seqs=seqs,
                     ctx_blocks=ctx)
-            torch.hpu.synchronize()
+            if not is_dummy_run:
+                torch.hpu.synchronize()
             if profiler:
                 profiler.step()
         if profiler:
             profiler.stop()
         self.profiler.end()
-        gc.collect()
+        if not is_dummy_run:
+            gc.collect()
 
     def remove_all_loras(self):
         if not self.lora_manager:
@@ -3235,7 +3241,8 @@ def execute_model(
         warmup_mode=False,
         previous_hidden_states: Optional[torch.Tensor] = None,
         seqs=None,
-        ctx_blocks: int = 1
+        ctx_blocks: int = 1,
+        is_dummy_run: bool = False,
     ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
         use_delayed_sampling = self.use_delayed_sampling and not warmup_mode
         assert not (use_delayed_sampling and num_steps != 1), \
@@ -3489,7 +3496,7 @@ def try_revert_dummy_output_tokens():
                         **execute_model_kwargs,
                         selected_token_indices=sampling_metadata.
                         selected_token_indices)
-                    if warmup_mode:
+                    if warmup_mode and not is_dummy_run:
                         torch.hpu.synchronize()
                         import torch.distributed as dist
                         if dist.is_initialized():
@@ -3515,6 +3522,11 @@ def try_revert_dummy_output_tokens():
                 LoraMask.setLoraMask(
                     lora_logits_mask.index_select(
                         0, sampling_metadata.selected_token_indices))
+
+            if is_dummy_run:
+                fake_output = self._delayed_sampler_outputs(model_input)
+                return [fake_output]
+
             if not get_pp_group().is_last_rank:
                 return hidden_states
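Note on PATCH 2/3 and 3/3: with expert selection now done locally, each DP rank exchanges only the compact (tokens, top_k) ids and weights instead of the full (tokens, num_experts) router logits, and dummy runs skip the HPU synchronize/GC work and return a fake sampler output early. The sketch below illustrates the smaller routing exchange using plain torch.distributed collectives; it is not the implementation of vLLM's multicast_fn, it assumes an already-initialized process group and equal padded token counts per DP rank, and the name gather_topk_across_dp is hypothetical.

import torch
import torch.distributed as dist


def gather_topk_across_dp(topk_ids: torch.Tensor, topk_weights: torch.Tensor,
                          dp_size: int, group=None):
    # Each rank contributes its local (tokens, top_k) routing decision; the
    # all-gather yields (tokens * dp_size, top_k) tensors, which is far less
    # traffic than gathering (tokens, num_experts) router logits and
    # re-running top-k selection on every rank.
    ids_out = torch.empty((topk_ids.size(0) * dp_size, topk_ids.size(1)),
                          dtype=topk_ids.dtype, device=topk_ids.device)
    weights_out = torch.empty(
        (topk_weights.size(0) * dp_size, topk_weights.size(1)),
        dtype=topk_weights.dtype, device=topk_weights.device)
    dist.all_gather_into_tensor(ids_out, topk_ids, group=group)
    dist.all_gather_into_tensor(weights_out, topk_weights, group=group)
    return ids_out, weights_out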