
DP: Optimizations for Data Parallel Attention #1463


Open · wants to merge 3 commits into base: habana_main
33 changes: 16 additions & 17 deletions vllm/forward_context.py
@@ -34,7 +34,9 @@
class DPMetadata:
cu_tokens_across_dp_cpu: torch.Tensor
hidden_states_across_dp: Optional[torch.Tensor] = None
router_logits_across_dp: Optional[torch.Tensor] = None
topk_ids_across_dp: Optional[torch.Tensor] = None
topk_weights_across_dp: Optional[torch.Tensor] = None
hidden_states: Optional[torch.Tensor] = None


@dataclass
@@ -105,19 +107,11 @@ def set_forward_context(attn_metadata: Any,
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)

if current_platform.is_hpu():
num_expert_names = [
"moe_num_experts", # Dbrx
"num_experts", # Jamba
"n_routed_experts", # DeepSeek
"num_local_experts", # Mixtral
]
num_experts = 0
for name in num_expert_names:
num_experts = getattr(vllm_config.model_config.hf_text_config,
name, 0)
if num_experts > 0:
break
assert num_experts > 0, \
num_experts_per_tok = 0
num_experts_per_tok = getattr(
vllm_config.model_config.hf_text_config, "num_experts_per_tok",
0)
assert num_experts_per_tok > 0, \
"No expert found in the model config.\
Please check the model config."

@@ -130,11 +124,16 @@
(request_batch_size * dp_size, padded_seq_length, hidden_size),
device=device,
dtype=dtype)
router_logits_across_dp = torch.empty(
(batchsize * dp_size, num_experts), device=device, dtype=dtype)
topk_ids_across_dp = torch.empty((batchsize * dp_size,\
num_experts_per_tok), device=device, dtype=torch.int64)
topk_weights_across_dp = torch.empty((batchsize * dp_size,\
num_experts_per_tok), device=device, dtype=dtype)
hidden_states = torch.empty((batchsize, hidden_size),\
device=device, dtype=dtype)
dp_metadata = DPMetadata(cu_tokens_across_dp_cpu,
hidden_states_across_dp,
router_logits_across_dp)
topk_ids_across_dp,
topk_weights_across_dp, hidden_states)
else:
dp_metadata = DPMetadata(cu_tokens_across_dp_cpu)
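
A note for readers, not part of the diff: the DPMetadata change swaps the full router-logits buffer for two much smaller top-k buffers, plus a local hidden_states buffer used by the reduce-scatter further down. A minimal sketch of the element counts involved; all sizes below are hypothetical.

batch_size = 32           # padded tokens per DP rank (hypothetical)
dp_size = 4
num_experts = 64          # router width read from the HF config (hypothetical)
num_experts_per_tok = 8   # top-k routing fan-out (hypothetical)

# Old buffer: router_logits_across_dp, one logit per expert per gathered token.
old_elems = batch_size * dp_size * num_experts
# New buffers: topk_ids_across_dp + topk_weights_across_dp, top-k entries each.
new_elems = 2 * batch_size * dp_size * num_experts_per_tok

print(old_elems, new_elems)  # 8192 vs. 2048 elements per MoE layer in this sketch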

11 changes: 7 additions & 4 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -9,6 +9,7 @@

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed import get_dp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm, deep_gemm_moe_fp8)
@@ -858,8 +859,9 @@ def fused_topk(
topk: int,
renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
if not (get_dp_group().world_size > 1 and current_platform.is_hpu()):
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")

M, _ = hidden_states.shape
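
A note for readers, not part of the diff: a self-contained illustration of the shape mismatch the relaxed asserts appear to accommodate when DP > 1 on HPU. hidden_states arrives already gathered across the DP group while the router logits are still per-rank, so the leading dimensions can legitimately differ. All sizes below are hypothetical.

import torch

batch_size, hidden_size, num_experts, dp_size = 4, 16, 8, 2

# hidden_states has already been multicast/gathered across the DP group ...
hidden_states = torch.randn(batch_size * dp_size, hidden_size)
# ... while the router logits for this rank are still local.
gating_output = torch.randn(batch_size, num_experts)

# With DP > 1 on HPU the token counts differ by design, hence the skipped assert.
print(hidden_states.shape[0], gating_output.shape[0])  # 8 vs. 4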

@@ -899,8 +901,9 @@ def grouped_topk(
e_score_correction_bias: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:

assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
if not (get_dp_group().world_size > 1 and current_platform.is_hpu()):
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")

gating_output = gating_output.float()
if e_score_correction_bias is not None:
40 changes: 27 additions & 13 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -295,13 +295,28 @@ def forward_hpu(
topk_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1)
topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
topk_weights = topk_weights.to(x.dtype)
topk_ids = topk_ids.to(torch.int64)
topk_weights = topk_weights.to(x.dtype)
if layer.dp_size > 1:
cu_tokens_across_dp_cpu = get_forward_context(
).dp_metadata.cu_tokens_across_dp_cpu

topk_ids_across_dp = get_forward_context(
).dp_metadata.topk_ids_across_dp
topk_ids = layer.multicast_fn(topk_ids, cu_tokens_across_dp_cpu,
topk_ids_across_dp)

topk_weights_across_dp = get_forward_context(
).dp_metadata.topk_weights_across_dp
topk_weights = layer.multicast_fn(topk_weights,
cu_tokens_across_dp_cpu,
topk_weights_across_dp)
topk_ids = topk_ids.view(*x.shape[:-1], -1)
topk_weights = topk_weights.view(*x.shape[:-1], -1)
return layer.moe_op(
x,
topk_ids.to(torch.int64),
topk_weights.to(x.dtype),
topk_ids,
topk_weights,
permuted_weights=True,
activation=activation,
).view(*input_shape)
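
A note for readers, not part of the diff: the reordering above (top-k locally, then multicast only topk_ids and topk_weights) selects the same experts with the same weights as the previous scheme of multicasting the full router logits and re-running top-k on the gathered tensor, because softmax and top-k operate row-wise. A small single-process sketch with hypothetical sizes, no collectives involved:

import torch
import torch.nn.functional as F

dp_size, batch_size, num_experts, top_k = 2, 4, 8, 2
per_rank_logits = [torch.randn(batch_size, num_experts) for _ in range(dp_size)]

# Previous scheme: gather the logits, then softmax/top-k on the gathered tensor.
gathered = torch.cat(per_rank_logits, dim=0)
w_old = F.softmax(gathered, dim=1, dtype=torch.float32)
w_old, ids_old = torch.topk(w_old, top_k, dim=-1)
w_old /= w_old.sum(dim=-1, keepdim=True)

# New scheme: softmax/top-k per rank, then gather only the ids and weights.
ids_new, w_new = [], []
for logits in per_rank_logits:
    w = F.softmax(logits, dim=1, dtype=torch.float32)
    w, ids = torch.topk(w, top_k, dim=-1)
    w /= w.sum(dim=-1, keepdim=True)
    ids_new.append(ids)
    w_new.append(w)

assert torch.equal(ids_old, torch.cat(ids_new))
assert torch.allclose(w_old, torch.cat(w_new))
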
@@ -941,15 +956,10 @@ def forward_impl(self, hidden_states: torch.Tensor,
).dp_metadata.cu_tokens_across_dp_cpu
hidden_states_across_dp = get_forward_context(
).dp_metadata.hidden_states_across_dp
router_logits_across_dp = get_forward_context(
).dp_metadata.router_logits_across_dp

hidden_states = self.multicast_fn(hidden_states,
cu_tokens_across_dp_cpu,
hidden_states_across_dp)
router_logits = self.multicast_fn(router_logits,
cu_tokens_across_dp_cpu,
router_logits_across_dp)

# Matrix multiply.
final_hidden_states = self.quant_method.apply(
@@ -974,12 +984,16 @@
if final_hidden_states.ndim == 3:
final_hidden_states = final_hidden_states.view(
-1, final_hidden_states.size(2))
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
self.dp_rank - 1]
end = cu_tokens_across_dp_cpu[self.dp_rank]

all_hidden_states = get_dp_group().all_reduce(final_hidden_states)
final_hidden_states = all_hidden_states[start:end, :]
import habana_frameworks.torch as htorch
htorch.core.mark_step()
local_hidden_states = get_forward_context(
).dp_metadata.hidden_states
torch.distributed.reduce_scatter_tensor(
local_hidden_states,
final_hidden_states,
group=get_dp_group().device_group)
final_hidden_states = local_hidden_states

if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
# Default set to False. (May have to add shared expert outputs.)
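
A note for readers, not part of the diff: the combiner step above replaces all_reduce followed by slicing out this rank's rows with a reduce-scatter into the preallocated hidden_states buffer, so no rank has to materialize the full summed tensor. A single-process sketch of the numerical equivalence; sizes are hypothetical and tokens per rank are assumed equal (padded), as in the diff.

import torch

tokens_per_rank, hidden_size, dp_size = 4, 8, 2

# Each rank holds a partial result covering all gathered tokens before the collective.
partials = [torch.randn(tokens_per_rank * dp_size, hidden_size) for _ in range(dp_size)]
summed = sum(partials)  # what all_reduce would leave on every rank

for rank in range(dp_size):
    start, end = rank * tokens_per_rank, (rank + 1) * tokens_per_rank
    old_local = summed[start:end]                    # previous: all_reduce + slice
    new_local = sum(p[start:end] for p in partials)  # reduce_scatter's per-rank output
    assert torch.allclose(old_local, new_local)
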
22 changes: 19 additions & 3 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -11,6 +11,7 @@
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
@@ -951,13 +952,28 @@ def forward_hpu(
topk_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1)
topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
topk_weights = topk_weights.to(x.dtype)
topk_ids = topk_ids.to(torch.int64)
topk_weights = topk_weights.to(x.dtype)
if layer.dp_size > 1:
cu_tokens_across_dp_cpu = get_forward_context(
).dp_metadata.cu_tokens_across_dp_cpu

topk_ids_across_dp = get_forward_context(
).dp_metadata.topk_ids_across_dp
topk_ids = layer.multicast_fn(topk_ids, cu_tokens_across_dp_cpu,
topk_ids_across_dp)

topk_weights_across_dp = get_forward_context(
).dp_metadata.topk_weights_across_dp
topk_weights = layer.multicast_fn(topk_weights,
cu_tokens_across_dp_cpu,
topk_weights_across_dp)
topk_ids = topk_ids.view(*x.shape[:-1], -1)
topk_weights = topk_weights.view(*x.shape[:-1], -1)
output = layer.moe_op(
x,
topk_ids.to(torch.int64),
topk_weights.to(x.dtype),
topk_ids,
topk_weights,
permuted_weights=True,
activation=activation,
)
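
A note for readers, not part of the diff: the fp8 path mirrors the unquantized layer above. The int64 and activation-dtype casts now happen before the DP multicast so the tensors match the preallocated topk_*_across_dp buffers, and the trailing .view() restores x's leading token dimensions for moe_op. A tiny shape/dtype sketch with hypothetical sizes:

import torch

gathered_tokens, hidden_size, num_experts, top_k = 8, 16, 4, 2
x = torch.randn(gathered_tokens, hidden_size, dtype=torch.bfloat16)

topk_weights = torch.rand(gathered_tokens, top_k)                   # float32 after softmax
topk_ids = torch.randint(0, num_experts, (gathered_tokens, top_k))  # int64 by default

topk_ids = topk_ids.to(torch.int64).view(*x.shape[:-1], -1)
topk_weights = topk_weights.to(x.dtype).view(*x.shape[:-1], -1)
assert topk_ids.dtype == torch.int64 and topk_weights.dtype == x.dtype
assert topk_ids.shape == (gathered_tokens, top_k)
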
30 changes: 21 additions & 9 deletions vllm/worker/hpu_model_runner.py
@@ -2481,7 +2481,8 @@ def _dummy_run(self, max_num_batched_tokens: int) -> None:
num_patches=UNSET_NUM_PATCHES,
is_lora_profile_run=True,
num_iters=1,
align_worker=True)
align_worker=True,
is_dummy_run=True)
return

def _remove_duplicate_submodules(self):
@@ -2509,9 +2510,10 @@ def warmup_scenario(self,
temperature=0,
num_patches=None,
num_iters=3,
align_worker=False) -> None:
align_worker=False,
is_dummy_run=False) -> None:
phase = 'prompt' if is_prompt else 'decode'
use_graphs = self._use_graphs(num_patches)
use_graphs = is_dummy_run or self._use_graphs(num_patches)
scenario_name = ("warmup_"
f"{phase}_"
f"bs{batch_size}_"
@@ -2569,7 +2571,8 @@ def warmup_scenario(self,
temperature=temperature,
ctx=ctx) for i, b in enumerate(blocks)
]
torch.hpu.synchronize()
if not is_dummy_run:
torch.hpu.synchronize()
profiler = None
if is_pt_profiler_run and self.is_driver_worker:
profiler = setup_profiler()
@@ -2597,7 +2600,8 @@ def warmup_scenario(self,
kv_caches,
intermediate_tensors=intermediate_tensors,
warmup_mode=True,
ctx_blocks=ctx)
ctx_blocks=ctx,
is_dummy_run=is_dummy_run)
else: # decode with multi-step
inputs = dataclasses.replace(inputs,
is_first_multi_step=True,
@@ -2617,13 +2621,15 @@
num_steps=2,
seqs=seqs,
ctx_blocks=ctx)
torch.hpu.synchronize()
if not is_dummy_run:
torch.hpu.synchronize()
if profiler:
profiler.step()
if profiler:
profiler.stop()
self.profiler.end()
gc.collect()
if not is_dummy_run:
gc.collect()

def remove_all_loras(self):
if not self.lora_manager:
@@ -3235,7 +3241,8 @@ def execute_model(
warmup_mode=False,
previous_hidden_states: Optional[torch.Tensor] = None,
seqs=None,
ctx_blocks: int = 1
ctx_blocks: int = 1,
is_dummy_run: bool = False,
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
use_delayed_sampling = self.use_delayed_sampling and not warmup_mode
assert not (use_delayed_sampling and num_steps != 1), \
@@ -3489,7 +3496,7 @@ def try_revert_dummy_output_tokens():
**execute_model_kwargs,
selected_token_indices=sampling_metadata.
selected_token_indices)
if warmup_mode:
if warmup_mode and not is_dummy_run:
torch.hpu.synchronize()
import torch.distributed as dist
if dist.is_initialized():
@@ -3515,6 +3522,11 @@
LoraMask.setLoraMask(
lora_logits_mask.index_select(
0, sampling_metadata.selected_token_indices))

if is_dummy_run:
fake_output = self._delayed_sampler_outputs(model_input)
return [fake_output]

if not get_pp_group().is_last_rank:
return hidden_states
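
A note for readers, not part of the diff: the runner changes thread a new is_dummy_run flag from _dummy_run through warmup_scenario into execute_model. A minimal decision-table sketch of what the flag toggles; everything except the flag name itself is a simplification of the real methods.

def dummy_run_decisions(warmup_mode: bool, is_dummy_run: bool = False) -> dict:
    return {
        # Dummy runs always take the graph path (the real code also consults _use_graphs()).
        "use_graphs": is_dummy_run or warmup_mode,
        # torch.hpu.synchronize() is reserved for real warmup iterations.
        "synchronize": warmup_mode and not is_dummy_run,
        # gc.collect() after the scenario is likewise skipped for dummy runs.
        "collect_garbage": not is_dummy_run,
        # Dummy runs short-circuit with a fake sampler output from
        # _delayed_sampler_outputs() instead of running the sampler.
        "return_fake_output": is_dummy_run,
    }


assert dummy_run_decisions(warmup_mode=True, is_dummy_run=True) == {
    "use_graphs": True,
    "synchronize": False,
    "collect_garbage": False,
    "return_fake_output": True,
}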
