Support Data Parallel MOE on HPU #1022

Merged: 20 commits, Jun 23, 2025

Commits
66c9d53
Support DP Attention + TP/EP MOE for v0
xinyu-intel Mar 19, 2025
26156f0
Enable DP in HPU worker/runner
xinyu-intel Mar 19, 2025
6051b0a
DP Awared Padding
xinyu-intel Apr 8, 2025
c3ef5aa
Fix native_multicast for 3d input
xinyu-intel Apr 8, 2025
915c389
Optimize DP communication with allgather
xinyu-intel Apr 10, 2025
23b5942
Merge branch 'habana_main' into dev/xinyu/dpmoe-pr
madamczyk-intel May 7, 2025
452187c
Merge remote-tracking branch 'ssh/habana_main' into dev/xinyu/dpmoe-pr
xinyu-intel May 8, 2025
fbf39c3
Merge branch 'habana_main' into dev/xinyu/dpmoe-pr
xinyu-intel May 8, 2025
39eb1eb
Merge remote-tracking branch 'ssh/habana_main' into HEAD
xinyu-intel May 22, 2025
cbb0cb6
Merge remote-tracking branch 'ssh/habana_main' into dev/xinyu/dpmoe-pr
xinyu-intel Jun 4, 2025
0ef3aba
Merge remote-tracking branch 'ssh/habana_main' into dev/xinyu/dpmoe-pr
xinyu-intel Jun 5, 2025
ccd1832
Merge branch 'habana_main' into dev/xinyu/dpmoe-pr
michalkuligowski Jun 5, 2025
2fd2726
Merge remote-tracking branch 'ssh/habana_main' into dev/xinyu/dpmoe-pr
xinyu-intel Jun 12, 2025
a762460
Add some docstrings
xinyu-intel Jun 12, 2025
9514c6c
Merge remote-tracking branch 'ssh/master_next' into dev/xinyu/dpmoe-pr
xinyu-intel Jun 16, 2025
99fcdf7
Merge remote-tracking branch 'ssh/habana_main' into dev/xinyu/dpmoe-pr
xinyu-intel Jun 17, 2025
98cd804
ignore mypy for _dummy_run
xinyu-intel Jun 17, 2025
b474177
Fix granitemoe
xinyu-intel Jun 17, 2025
f182288
Fix dummy_run and barrier
xinyu-intel Jun 17, 2025
e03d775
Merge remote-tracking branch 'ssh/habana_main' into dev/xinyu/dpmoe-pr
xinyu-intel Jun 19, 2025
46 changes: 42 additions & 4 deletions vllm/engine/llm_engine.py
@@ -237,6 +237,14 @@ def __init__(
self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa
)

# Data Parallel Group
self.need_to_sync_across_dp = self.parallel_config.data_parallel_size > 1 # noqa
if self.need_to_sync_across_dp:
self.dp_group = self.parallel_config.stateless_init_dp_group()
# Data Parallel Ranks should execute the dummy batch if no real batch
# is scheduled.
self.should_execute_dummy_batch = False

logger.info(
"Initializing a V0 LLM engine (v%s) with config: %s, "
"use_cached_outputs=%s, ",
@@ -897,17 +905,31 @@ def get_num_unfinished_requests(self) -> int:
return sum(scheduler.get_num_unfinished_seq_groups()
for scheduler in self.scheduler)

def has_unfinished_requests(self) -> bool:
def has_unfinished_requests(self,
virtual_engine: Optional[int] = None) -> bool:
"""Returns True if there are unfinished requests."""
return any(scheduler.has_unfinished_seqs()
for scheduler in self.scheduler)
if virtual_engine is not None:
schedulers = [self.scheduler[virtual_engine]]
else:
schedulers = self.scheduler
has_unfinished = any(scheduler.has_unfinished_seqs()
for scheduler in schedulers)
if not self.need_to_sync_across_dp:
return has_unfinished
aggregated_has_unfinished = ParallelConfig.\
has_unfinished_dp(self.dp_group, has_unfinished)
if not has_unfinished and aggregated_has_unfinished:
# current rank has no unfinished seqs, but other ranks do,
# so we should execute a dummy batch to sync across ranks
self.should_execute_dummy_batch = True
return aggregated_has_unfinished

def has_unfinished_requests_for_virtual_engine(
self, virtual_engine: int) -> bool:
"""
Returns True if there are unfinished requests for the virtual engine.
"""
return self.scheduler[virtual_engine].has_unfinished_seqs()
return self.has_unfinished_requests(virtual_engine)

def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
"""Reset prefix cache for all devices."""
@@ -1308,6 +1330,22 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
"Pipeline parallelism is only supported through AsyncLLMEngine "
"as performance will be severely degraded otherwise.")

if self.should_execute_dummy_batch:
self.should_execute_dummy_batch = False
outputs = self.model_executor.execute_model(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=[], is_dummy_batch=True))
if not self.has_unfinished_requests():
# Stop the execute model loop in parallel workers until there
# are more requests to process. This avoids waiting indefinitely
# in torch.distributed ops which may otherwise timeout, and
# unblocks the RPC thread in the workers so that they can
# process any other queued control plane messages, such as
# add/remove lora adapters.
logger.debug("Stopping remote worker execution loop.")
self.model_executor.stop_remote_worker_execution_loop()
return []

# For llm_engine, there is no pipeline parallel support, so the engine
# used is always 0.
virtual_engine = 0
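The engine changes above are what keep data-parallel ranks in lockstep: each rank aggregates its local "has unfinished requests" flag across the DP group, and a rank with nothing scheduled sets should_execute_dummy_batch so it still joins the collectives issued by the busy ranks on the next step(). A minimal sketch of that aggregation, assuming a CPU process group and a MAX all-reduce (the helper name and signature are illustrative; the PR routes this through ParallelConfig.has_unfinished_dp):

import torch
import torch.distributed as dist

def any_rank_has_unfinished(cpu_group, has_unfinished: bool) -> bool:
    # Each DP rank contributes 1 if it still has scheduled work, else 0.
    flag = torch.tensor([int(has_unfinished)], dtype=torch.int32)
    # MAX reduction: the result is 1 as soon as any rank reports work.
    dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=cpu_group)
    return bool(flag.item())

If this returns True while the local flag is False, the rank must run a dummy batch; skipping it would leave the busy ranks blocked in their next collective.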
3 changes: 1 addition & 2 deletions vllm/executor/executor_base.py
@@ -288,7 +288,7 @@ def __init__(self, *args, **kwargs):
def execute_model(
self,
execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]:
) -> Optional[List[SamplerOutput]]:
# TODO: unify into collective_rpc
if self.parallel_worker_tasks is None:
self.parallel_worker_tasks = self._run_workers(
@@ -297,7 +297,6 @@ def execute_model(

# Only the driver worker returns the sampling results.
driver_outputs = self._driver_execute_model(execute_model_req)
assert driver_outputs is not None
return driver_outputs

def stop_remote_worker_execution_loop(self) -> None:
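The execute_model return type is widened to Optional, presumably because a dummy batch carries no sequence groups and may produce nothing to sample, which is also why the assert on driver_outputs is dropped. A hedged sketch of how a caller could absorb the new None case (function and parameter names are illustrative, not the engine's actual call site):

def collect_outputs(executor, execute_model_req):
    # A None result (e.g. from a dummy batch) simply means there is
    # nothing to hand back to the scheduler this step.
    outputs = executor.execute_model(execute_model_req)
    return outputs if outputs is not None else []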
4 changes: 2 additions & 2 deletions vllm/executor/ray_distributed_executor.py
@@ -465,8 +465,8 @@ def _driver_execute_model(
execute_model_req)

def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
self, execute_model_req: ExecuteModelRequest
) -> Optional[List[SamplerOutput]]:
if not self.use_ray_spmd_worker:
return super().execute_model(execute_model_req)

66 changes: 57 additions & 9 deletions vllm/forward_context.py
@@ -16,6 +16,7 @@
is_v1_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.logger import init_logger
from vllm.platforms import current_platform

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
@@ -32,6 +33,8 @@
@dataclass
class DPMetadata:
cu_tokens_across_dp_cpu: torch.Tensor
hidden_states_across_dp: Optional[torch.Tensor] = None
router_logits_across_dp: Optional[torch.Tensor] = None


@dataclass
@@ -61,7 +64,8 @@ def get_forward_context() -> ForwardContext:
def set_forward_context(attn_metadata: Any,
vllm_config: VllmConfig,
virtual_engine: int = 0,
num_tokens: int = 0):
num_tokens: int = 0,
dp_awared_padding: bool = False):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
Here we can inject common logic for every model forward pass.
@@ -79,18 +83,60 @@ def set_forward_context(attn_metadata: Any,
# for v0 attention backends
batchsize = attn_metadata.num_prefill_tokens + \
attn_metadata.num_decode_tokens
elif attn_metadata is not None and hasattr(attn_metadata,
"slot_mapping"):
batchsize = attn_metadata.slot_mapping.numel()
else:
# for v1 attention backends or no attn_metadata
batchsize = num_tokens
num_tokens_across_dp = [0] * dp_size
num_tokens_across_dp[dp_rank] = batchsize
num_tokens_tensor = torch.tensor(num_tokens_across_dp,
device="cpu",
dtype=torch.int32)
from vllm.distributed.parallel_state import get_dp_group
dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
if dp_awared_padding:
num_tokens_across_dp = [batchsize] * dp_size
num_tokens_tensor = torch.tensor(num_tokens_across_dp,
device="cpu",
dtype=torch.int32)
else:
num_tokens_across_dp = [0] * dp_size
num_tokens_across_dp[dp_rank] = batchsize
num_tokens_tensor = torch.tensor(num_tokens_across_dp,
device="cpu",
dtype=torch.int32)
from vllm.distributed.parallel_state import get_dp_group
dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
dp_metadata = DPMetadata(cu_tokens_across_dp_cpu)

if current_platform.is_hpu():
num_expert_names = [
"moe_num_experts", # Dbrx
"num_experts", # Jamba
"n_routed_experts", # DeepSeek
"num_local_experts", # Mixtral
]
num_experts = 0
for name in num_expert_names:
num_experts = getattr(vllm_config.model_config.hf_text_config,
name, 0)
if num_experts > 0:
break
assert num_experts > 0, \
"No expert found in the model config.\
Please check the model config."

request_batch_size = attn_metadata.slot_mapping.size(0)
padded_seq_length = attn_metadata.slot_mapping.size(1)
hidden_size = vllm_config.model_config.get_hidden_size()
device = attn_metadata.slot_mapping.device
dtype = vllm_config.model_config.dtype
hidden_states_across_dp = torch.empty(
(request_batch_size * dp_size, padded_seq_length, hidden_size),
device=device,
dtype=dtype)
router_logits_across_dp = torch.empty(
(batchsize * dp_size, num_experts), device=device, dtype=dtype)
dp_metadata = DPMetadata(cu_tokens_across_dp_cpu,
hidden_states_across_dp,
router_logits_across_dp)
else:
dp_metadata = DPMetadata(cu_tokens_across_dp_cpu)

global _forward_context
prev_context = _forward_context
@@ -120,6 +166,8 @@ def set_forward_context(attn_metadata: Any,
# for v0 attention backends
batchsize = attn_metadata.num_prefill_tokens + \
attn_metadata.num_decode_tokens
elif hasattr(attn_metadata, "slot_mapping"):
batchsize = attn_metadata.slot_mapping.numel()
else:
# for v1 attention backends
batchsize = num_tokens
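Two things happen in the HPU path of set_forward_context. First, when dp_awared_padding is set, every DP rank has already been padded to the same batch shape, so per-rank token counts are known locally and the CPU all-reduce is only needed on the unpadded path. Second, hidden_states_across_dp and router_logits_across_dp are pre-allocated with fixed, padding-derived shapes so the MoE layer can all-gather into static buffers. A condensed sketch of the token-count logic, with an illustrative helper name and signature:

import torch
import torch.distributed as dist

def cu_tokens_across_dp(batchsize: int, dp_size: int, dp_rank: int,
                        cpu_group, dp_aware_padding: bool) -> torch.Tensor:
    if dp_aware_padding:
        # Every rank is padded to the same size, so no collective is needed.
        counts = torch.full((dp_size,), batchsize, dtype=torch.int32)
    else:
        # Fill only this rank's slot; a SUM all-reduce supplies the rest.
        counts = torch.zeros(dp_size, dtype=torch.int32)
        counts[dp_rank] = batchsize
        dist.all_reduce(counts, group=cpu_group)
    # Cumulative counts give each rank its row range in gathered tensors.
    return torch.cumsum(counts, dim=0)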
52 changes: 44 additions & 8 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -506,6 +506,8 @@ def __init__(
self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias
self.activation = activation
self.multicast_fn = self.hpu_multicast if is_hpu\
else self.naive_multicast

if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for "
@@ -879,9 +881,34 @@ def select_experts(hidden_states: torch.Tensor,

return topk_weights, topk_ids

def naive_multicast(self, x: torch.Tensor,
cu_tokens_across_dp_cpu: torch.Tensor):
assert (len(x.shape) == 2)
def hpu_multicast(self,
x: torch.Tensor,
cu_tokens_across_dp_cpu: torch.Tensor,
output_tensor: Optional[torch.Tensor] = None):
if output_tensor is None:
world_size = get_dp_group().world_size
input_size = x.size()
# Allocate output tensor.
output_size = list(input_size)
output_size[0] *= world_size
output_tensor = torch.empty(output_size,
dtype=x.dtype,
device=x.device)
else:
if output_tensor.ndim == 3 and x.ndim == 2:
output_tensor.view(-1, x.size(1))
# All-gather.
torch.distributed.all_gather_into_tensor(
output_tensor, x, group=get_dp_group().device_group)
return output_tensor

def naive_multicast(self,
x: torch.Tensor,
cu_tokens_across_dp_cpu: torch.Tensor,
output_tensor: Optional[torch.Tensor] = None):
assert (len(x.shape) in [2, 3])
if len(x.shape) == 3:
x = x.view(-1, x.size(2))
buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
device=x.device,
dtype=x.dtype)
@@ -912,11 +939,17 @@ def forward_impl(self, hidden_states: torch.Tensor,
if self.dp_size > 1:
cu_tokens_across_dp_cpu = get_forward_context(
).dp_metadata.cu_tokens_across_dp_cpu

hidden_states = self.naive_multicast(hidden_states,
cu_tokens_across_dp_cpu)
router_logits = self.naive_multicast(router_logits,
cu_tokens_across_dp_cpu)
hidden_states_across_dp = get_forward_context(
).dp_metadata.hidden_states_across_dp
router_logits_across_dp = get_forward_context(
).dp_metadata.router_logits_across_dp

hidden_states = self.multicast_fn(hidden_states,
cu_tokens_across_dp_cpu,
hidden_states_across_dp)
router_logits = self.multicast_fn(router_logits,
cu_tokens_across_dp_cpu,
router_logits_across_dp)

# Matrix multiply.
final_hidden_states = self.quant_method.apply(
@@ -938,6 +971,9 @@
)

if self.dp_size > 1:
if final_hidden_states.ndim == 3:
final_hidden_states = final_hidden_states.view(
-1, final_hidden_states.size(2))
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
self.dp_rank - 1]
end = cu_tokens_across_dp_cpu[self.dp_rank]
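Because DP-aware padding guarantees that every rank contributes an identically shaped tensor on HPU, hpu_multicast can replace the per-rank broadcast loop of naive_multicast with a single all_gather_into_tensor, and after the fused MoE runs on the concatenated batch each rank slices its own rows back out using the cumulative token counts. A standalone sketch of that gather-and-slice round trip under the equal-shard assumption (function names are illustrative):

import torch
import torch.distributed as dist

def gather_equal_shards(x: torch.Tensor, group) -> torch.Tensor:
    # Assumes every DP rank passes a tensor of identical shape, which
    # DP-aware padding guarantees here.
    world_size = dist.get_world_size(group=group)
    out = torch.empty((world_size * x.size(0), *x.shape[1:]),
                      dtype=x.dtype, device=x.device)
    dist.all_gather_into_tensor(out, x, group=group)
    return out

def local_rows(gathered: torch.Tensor, cu_tokens_cpu: torch.Tensor,
               dp_rank: int) -> torch.Tensor:
    # Recover this rank's portion after the MoE ran on the gathered batch,
    # mirroring the start/end indexing in forward_impl above.
    start = 0 if dp_rank == 0 else int(cu_tokens_cpu[dp_rank - 1])
    end = int(cu_tokens_cpu[dp_rank])
    return gathered[start:end]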
9 changes: 7 additions & 2 deletions vllm/model_executor/layers/layernorm.py
@@ -177,11 +177,16 @@ def forward_hpu(
return self.forward_native(x, residual)
if residual is not None:
orig_shape = x.shape
residual = residual + x.view(residual.shape)
if orig_shape != residual.shape:
residual = residual + x.view(residual.shape)
else:
residual = residual + x
# Note: HPUFusedRMSNorm requires 3D tensors as inputs
x = HPUFusedRMSNorm.apply(residual, self.weight,
self.variance_epsilon)
return x.view(orig_shape), residual
if x.shape != orig_shape:
x = x.view(orig_shape)
return x, residual

x = HPUFusedRMSNorm.apply(x, self.weight, self.variance_epsilon)
return x
2 changes: 2 additions & 0 deletions vllm/model_executor/models/granitemoe.py
@@ -274,6 +274,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

self.split_qkv = cache_config.split_qkv

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)

2 changes: 2 additions & 0 deletions vllm/sequence.py
@@ -1404,6 +1404,8 @@ class ExecuteModelRequest(
last_sampled_token_ids: Optional[torch.Tensor] = None
# Async callback
async_callback: Optional[Callable] = None
# Dummy batch
is_dummy_batch: bool = False

@property
def is_first_multi_step(self) -> bool:
4 changes: 3 additions & 1 deletion vllm/worker/hpu_enc_dec_model_runner.py
@@ -259,6 +259,8 @@ def warmup_scenario( # type: ignore[override]
kv_caches,
is_pt_profiler_run=False,
temperature=0,
num_iters=3,
align_worker=False,
) -> None:
phase = 'prompt' if is_prompt else 'decode'
use_graphs = self._use_graphs()
@@ -269,7 +271,7 @@
f"ctx{ctx}_"
f"graphs{'T' if use_graphs else 'F'}")
self.profiler.start('internal', scenario_name)
times = 3 if use_graphs or is_pt_profiler_run else 1
times = num_iters if use_graphs or is_pt_profiler_run else 1
if is_prompt:
seqs = [
self.create_dummy_seq_group_metadata(i,