diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 8020cd31704..21cb93cce7a 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -237,6 +237,14 @@ def __init__(
         self.observability_config = vllm_config.observability_config or ObservabilityConfig(  # noqa
         )
 
+        # Data Parallel Group
+        self.need_to_sync_across_dp = self.parallel_config.data_parallel_size > 1  # noqa
+        if self.need_to_sync_across_dp:
+            self.dp_group = self.parallel_config.stateless_init_dp_group()
+        # Data Parallel Ranks should execute the dummy batch if no real batch
+        # is scheduled.
+        self.should_execute_dummy_batch = False
+
         logger.info(
             "Initializing a V0 LLM engine (v%s) with config: %s, "
             "use_cached_outputs=%s, ",
@@ -897,17 +905,31 @@ def get_num_unfinished_requests(self) -> int:
         return sum(scheduler.get_num_unfinished_seq_groups()
                    for scheduler in self.scheduler)
 
-    def has_unfinished_requests(self) -> bool:
+    def has_unfinished_requests(self,
+                                virtual_engine: Optional[int] = None) -> bool:
         """Returns True if there are unfinished requests."""
-        return any(scheduler.has_unfinished_seqs()
-                   for scheduler in self.scheduler)
+        if virtual_engine is not None:
+            schedulers = [self.scheduler[virtual_engine]]
+        else:
+            schedulers = self.scheduler
+        has_unfinished = any(scheduler.has_unfinished_seqs()
+                             for scheduler in schedulers)
+        if not self.need_to_sync_across_dp:
+            return has_unfinished
+        aggregated_has_unfinished = ParallelConfig.\
+            has_unfinished_dp(self.dp_group, has_unfinished)
+        if not has_unfinished and aggregated_has_unfinished:
+            # current rank has no unfinished seqs, but other ranks do,
+            # so we should execute a dummy batch to sync across ranks
+            self.should_execute_dummy_batch = True
+        return aggregated_has_unfinished
 
     def has_unfinished_requests_for_virtual_engine(
             self, virtual_engine: int) -> bool:
         """
         Returns True if there are unfinished requests for the virtual engine.
         """
-        return self.scheduler[virtual_engine].has_unfinished_seqs()
+        return self.has_unfinished_requests(virtual_engine)
 
     def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
         """Reset prefix cache for all devices."""
@@ -1308,6 +1330,22 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
                 "Pipeline parallelism is only supported through AsyncLLMEngine "
                 "as performance will be severely degraded otherwise.")
 
+        if self.should_execute_dummy_batch:
+            self.should_execute_dummy_batch = False
+            outputs = self.model_executor.execute_model(
+                execute_model_req=ExecuteModelRequest(
+                    seq_group_metadata_list=[], is_dummy_batch=True))
+            if not self.has_unfinished_requests():
+                # Stop the execute model loop in parallel workers until
+                # there are more requests to process. This avoids waiting
+                # indefinitely in torch.distributed ops which may otherwise
+                # timeout, and unblocks the RPC thread in the workers so
+                # that they can process any other queued control plane
+                # messages, such as add/remove lora adapters.
+                logger.debug("Stopping remote worker execution loop.")
+                self.model_executor.stop_remote_worker_execution_loop()
+            return []
+
         # For llm_engine, there is no pipeline parallel support, so the engine
         # used is always 0.
         virtual_engine = 0
diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py
index 522bd940211..7ee836fdef8 100644
--- a/vllm/executor/executor_base.py
+++ b/vllm/executor/executor_base.py
@@ -288,7 +288,7 @@ def __init__(self, *args, **kwargs):
     def execute_model(
         self,
         execute_model_req: ExecuteModelRequest,
-    ) -> List[SamplerOutput]:
+    ) -> Optional[List[SamplerOutput]]:
         # TODO: unify into collective_rpc
         if self.parallel_worker_tasks is None:
             self.parallel_worker_tasks = self._run_workers(
@@ -297,7 +297,6 @@ def execute_model(
 
         # Only the driver worker returns the sampling results.
         driver_outputs = self._driver_execute_model(execute_model_req)
-        assert driver_outputs is not None
         return driver_outputs
 
     def stop_remote_worker_execution_loop(self) -> None:
diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py
index 9231850fc0c..1b2c9609abe 100644
--- a/vllm/executor/ray_distributed_executor.py
+++ b/vllm/executor/ray_distributed_executor.py
@@ -465,8 +465,8 @@ def _driver_execute_model(
             execute_model_req)
 
     def execute_model(
-        self,
-        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        self, execute_model_req: ExecuteModelRequest
+    ) -> Optional[List[SamplerOutput]]:
         if not self.use_ray_spmd_worker:
             return super().execute_model(execute_model_req)
 
diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index c75d8f088c5..20494e86c62 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -16,6 +16,7 @@
     is_v1_kv_transfer_group)
 from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
@@ -32,6 +33,8 @@
 @dataclass
 class DPMetadata:
     cu_tokens_across_dp_cpu: torch.Tensor
+    hidden_states_across_dp: Optional[torch.Tensor] = None
+    router_logits_across_dp: Optional[torch.Tensor] = None
 
 
 @dataclass
@@ -61,7 +64,8 @@ def get_forward_context() -> ForwardContext:
 def set_forward_context(attn_metadata: Any,
                         vllm_config: VllmConfig,
                         virtual_engine: int = 0,
-                        num_tokens: int = 0):
+                        num_tokens: int = 0,
+                        dp_awared_padding: bool = False):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     Here we can inject common logic for every model forward pass.
@@ -79,18 +83,60 @@ def set_forward_context(attn_metadata: Any,
             # for v0 attention backends
             batchsize = attn_metadata.num_prefill_tokens + \
                 attn_metadata.num_decode_tokens
+        elif attn_metadata is not None and hasattr(attn_metadata,
+                                                   "slot_mapping"):
+            batchsize = attn_metadata.slot_mapping.numel()
         else:
             # for v1 attention backends or no attn_metadata
             batchsize = num_tokens
-        num_tokens_across_dp = [0] * dp_size
-        num_tokens_across_dp[dp_rank] = batchsize
-        num_tokens_tensor = torch.tensor(num_tokens_across_dp,
-                                         device="cpu",
-                                         dtype=torch.int32)
-        from vllm.distributed.parallel_state import get_dp_group
-        dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
+        if dp_awared_padding:
+            num_tokens_across_dp = [batchsize] * dp_size
+            num_tokens_tensor = torch.tensor(num_tokens_across_dp,
+                                             device="cpu",
+                                             dtype=torch.int32)
+        else:
+            num_tokens_across_dp = [0] * dp_size
+            num_tokens_across_dp[dp_rank] = batchsize
+            num_tokens_tensor = torch.tensor(num_tokens_across_dp,
+                                             device="cpu",
+                                             dtype=torch.int32)
+            from vllm.distributed.parallel_state import get_dp_group
+            dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
         cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
-        dp_metadata = DPMetadata(cu_tokens_across_dp_cpu)
+
+        if current_platform.is_hpu():
+            num_expert_names = [
+                "moe_num_experts",  # Dbrx
+                "num_experts",  # Jamba
+                "n_routed_experts",  # DeepSeek
+                "num_local_experts",  # Mixtral
+            ]
+            num_experts = 0
+            for name in num_expert_names:
+                num_experts = getattr(vllm_config.model_config.hf_text_config,
+                                      name, 0)
+                if num_experts > 0:
+                    break
+            assert num_experts > 0, \
+                "No expert found in the model config.\
+                Please check the model config."
+
+            request_batch_size = attn_metadata.slot_mapping.size(0)
+            padded_seq_length = attn_metadata.slot_mapping.size(1)
+            hidden_size = vllm_config.model_config.get_hidden_size()
+            device = attn_metadata.slot_mapping.device
+            dtype = vllm_config.model_config.dtype
+            hidden_states_across_dp = torch.empty(
+                (request_batch_size * dp_size, padded_seq_length, hidden_size),
+                device=device,
+                dtype=dtype)
+            router_logits_across_dp = torch.empty(
+                (batchsize * dp_size, num_experts), device=device, dtype=dtype)
+            dp_metadata = DPMetadata(cu_tokens_across_dp_cpu,
+                                     hidden_states_across_dp,
+                                     router_logits_across_dp)
+        else:
+            dp_metadata = DPMetadata(cu_tokens_across_dp_cpu)
 
     global _forward_context
     prev_context = _forward_context
@@ -120,6 +166,8 @@ def set_forward_context(attn_metadata: Any,
                 # for v0 attention backends
                 batchsize = attn_metadata.num_prefill_tokens + \
                     attn_metadata.num_decode_tokens
+            elif hasattr(attn_metadata, "slot_mapping"):
+                batchsize = attn_metadata.slot_mapping.numel()
             else:
                 # for v1 attention backends
                 batchsize = num_tokens
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 94bb9327f55..a9dc52c7cca 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -506,6 +506,8 @@ def __init__(
         self.scoring_func = scoring_func
         self.e_score_correction_bias = e_score_correction_bias
         self.activation = activation
+        self.multicast_fn = self.hpu_multicast if is_hpu\
+            else self.naive_multicast
 
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
@@ -879,9 +881,34 @@ def select_experts(hidden_states: torch.Tensor,
 
         return topk_weights, topk_ids
 
-    def naive_multicast(self, x: torch.Tensor,
-                        cu_tokens_across_dp_cpu: torch.Tensor):
-        assert (len(x.shape) == 2)
+    def hpu_multicast(self,
+                      x: torch.Tensor,
+                      cu_tokens_across_dp_cpu: torch.Tensor,
+                      output_tensor: Optional[torch.Tensor] = None):
+        if output_tensor is None:
+            world_size = get_dp_group().world_size
+            input_size = x.size()
+            # Allocate output tensor.
+            output_size = list(input_size)
+            output_size[0] *= world_size
+            output_tensor = torch.empty(output_size,
+                                        dtype=x.dtype,
+                                        device=x.device)
+        else:
+            if output_tensor.ndim == 3 and x.ndim == 2:
+                output_tensor = output_tensor.view(-1, x.size(1))
+        # All-gather.
+        torch.distributed.all_gather_into_tensor(
+            output_tensor, x, group=get_dp_group().device_group)
+        return output_tensor
+
+    def naive_multicast(self,
+                        x: torch.Tensor,
+                        cu_tokens_across_dp_cpu: torch.Tensor,
+                        output_tensor: Optional[torch.Tensor] = None):
+        assert (len(x.shape) in [2, 3])
+        if len(x.shape) == 3:
+            x = x.view(-1, x.size(2))
         buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
                              device=x.device,
                              dtype=x.dtype)
@@ -912,11 +939,17 @@ def forward_impl(self, hidden_states: torch.Tensor,
         if self.dp_size > 1:
             cu_tokens_across_dp_cpu = get_forward_context(
             ).dp_metadata.cu_tokens_across_dp_cpu
-
-            hidden_states = self.naive_multicast(hidden_states,
-                                                 cu_tokens_across_dp_cpu)
-            router_logits = self.naive_multicast(router_logits,
-                                                 cu_tokens_across_dp_cpu)
+            hidden_states_across_dp = get_forward_context(
+            ).dp_metadata.hidden_states_across_dp
+            router_logits_across_dp = get_forward_context(
+            ).dp_metadata.router_logits_across_dp
+
+            hidden_states = self.multicast_fn(hidden_states,
+                                              cu_tokens_across_dp_cpu,
+                                              hidden_states_across_dp)
+            router_logits = self.multicast_fn(router_logits,
+                                              cu_tokens_across_dp_cpu,
+                                              router_logits_across_dp)
 
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
@@ -938,6 +971,9 @@ def forward_impl(self, hidden_states: torch.Tensor,
         )
 
         if self.dp_size > 1:
+            if final_hidden_states.ndim == 3:
+                final_hidden_states = final_hidden_states.view(
+                    -1, final_hidden_states.size(2))
             start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
                 self.dp_rank - 1]
             end = cu_tokens_across_dp_cpu[self.dp_rank]
diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index 6ac54b2e84e..7f95d6336a6 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -177,11 +177,16 @@ def forward_hpu(
             return self.forward_native(x, residual)
         if residual is not None:
             orig_shape = x.shape
-            residual = residual + x.view(residual.shape)
+            if orig_shape != residual.shape:
+                residual = residual + x.view(residual.shape)
+            else:
+                residual = residual + x
             # Note: HPUFusedRMSNorm requires 3D tensors as inputs
             x = HPUFusedRMSNorm.apply(residual, self.weight,
                                       self.variance_epsilon)
-            return x.view(orig_shape), residual
+            if x.shape != orig_shape:
+                x = x.view(orig_shape)
+            return x, residual
 
         x = HPUFusedRMSNorm.apply(x, self.weight, self.variance_epsilon)
         return x
diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py
index 7fff14cb9f1..34d16601035 100644
--- a/vllm/model_executor/models/granitemoe.py
+++ b/vllm/model_executor/models/granitemoe.py
@@ -274,6 +274,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
+        self.split_qkv = cache_config.split_qkv
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 1259a79319d..dcab8a4422e 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -1404,6 +1404,8 @@ class ExecuteModelRequest(
     last_sampled_token_ids: Optional[torch.Tensor] = None
     # Async callback
     async_callback: Optional[Callable] = None
+    # Dummy batch
+    is_dummy_batch: bool = False
 
     @property
     def is_first_multi_step(self) -> bool:
diff --git a/vllm/worker/hpu_enc_dec_model_runner.py b/vllm/worker/hpu_enc_dec_model_runner.py
index 9c2a5034905..e68e76ae76e 100644
--- a/vllm/worker/hpu_enc_dec_model_runner.py
+++ b/vllm/worker/hpu_enc_dec_model_runner.py
@@ -259,6 +259,8 @@ def warmup_scenario(  # type: ignore[override]
         kv_caches,
         is_pt_profiler_run=False,
         temperature=0,
+        num_iters=3,
+        align_worker=False,
     ) -> None:
         phase = 'prompt' if is_prompt else 'decode'
         use_graphs = self._use_graphs()
@@ -269,7 +271,7 @@ def warmup_scenario(  # type: ignore[override]
                          f"ctx{ctx}_"
                          f"graphs{'T' if use_graphs else 'F'}")
         self.profiler.start('internal', scenario_name)
-        times = 3 if use_graphs or is_pt_profiler_run else 1
+        times = num_iters if use_graphs or is_pt_profiler_run else 1
         if is_prompt:
             seqs = [
                 self.create_dummy_seq_group_metadata(i,
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index d3123df0ec3..b3b6bb4749d 100755
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -36,7 +36,8 @@
 from vllm.config import DeviceConfig, VllmConfig
 from vllm.distributed import broadcast_tensor_dict, get_pp_group
 from vllm.distributed.kv_transfer import get_kv_transfer_group
-from vllm.distributed.parallel_state import get_world_group
+from vllm.distributed.parallel_state import (get_dp_group, get_tp_group,
+                                             get_world_group)
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY, InputRegistry
 from vllm.logger import init_logger
@@ -175,6 +176,23 @@ def custom_tuple_replace(obj: object, typename: str, **to_override):
     return cached_type(**values)  # type: ignore
 
 
+def align_dp_groups(value, op):
+    group = get_dp_group().cpu_group
+    value_t = torch.tensor(value, device="cpu", dtype=torch.int32)
+    torch.distributed.all_reduce(value_t, op=op, group=group)
+    return value_t.item()
+
+
+def align_tp_groups(value, op):
+    group = get_tp_group().cpu_group
+    world_size = get_tp_group().world_size
+    if world_size <= 1:
+        return value
+    value_t = torch.tensor(value, device='cpu')
+    torch.distributed.all_reduce(value_t, op=op, group=group)
+    return value_t.item()
+
+
 def align_workers(value, op):
     group = get_world_group().cpu_group
     world_size = torch.distributed.get_world_size()
@@ -320,6 +338,8 @@ def __init__(self, model, vllm_config, layer_names, is_causal, sampler):
         self.is_pooler = hasattr(self.model, "_pooler")
         self.is_causal = is_causal
         self.use_merged_prefill = get_config().merged_prefill
+        self.dp_awared_padding = \
+            self.vllm_config.parallel_config.data_parallel_size > 1
 
         model_config = getattr(self.model, "config", None)
         self.model_is_mrope = uses_mrope(model_config)
@@ -536,7 +556,10 @@ def forward(self, *args, **kwargs):
         attn_meta = kwargs.pop('attn_metadata')
         if 'kv_caches' in kwargs:
             kwargs.pop('kv_caches')
-        with set_forward_context(attn_meta, self.vllm_config, virtual_engine):
+        with set_forward_context(attn_meta,
+                                 self.vllm_config,
+                                 virtual_engine,
+                                 dp_awared_padding=self.dp_awared_padding):
             hidden_states = self.model(*args, **kwargs)
             if not get_pp_group().is_last_rank:
                 return hidden_states
@@ -844,6 +867,10 @@ def __init__(
         self.graphed_multimodal_buckets: Set[Any] = set()
         self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH
 
+        # Data Parallel
+        self.dp_size = vllm_config.parallel_config.data_parallel_size
+        self.dp_awared_padding = self.dp_size > 1
+
         self._set_gc_threshold()
         self.use_contiguous_pa = get_config().use_contiguous_pa
         if vllm_config.speculative_config is not None \
@@ -1071,10 +1098,20 @@ def load_model(self) -> None:
             # need to be warmed up. Current tested for MRoPE models only.
             self.add_vision_buckets_to_mrope_models()
 
-    def _add_dummy_seq(self, seq_group_metadata_list, is_prompt):
+    def _add_dummy_seq(self,
+                       seq_group_metadata_list,
+                       is_prompt,
+                       align_worker=False):
         real_batch_size = len(seq_group_metadata_list)
         batch_size_padded = self.bucketing_ctx.get_padded_batch_size(
             real_batch_size, is_prompt)
+        if self.dp_awared_padding:
+            if self.is_driver_worker:
+                batch_size_padded = align_dp_groups(
+                    batch_size_padded, torch.distributed.ReduceOp.MAX)
+            if align_worker:
+                batch_size_padded = align_tp_groups(
+                    batch_size_padded, torch.distributed.ReduceOp.MAX)
         batch_size_padding = batch_size_padded - real_batch_size
 
         seq_group_metadata_list = seq_group_metadata_list.copy()
@@ -1270,6 +1307,7 @@ def add_vision_buckets_to_mrope_models(self):
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
+        align_worker=False,
     ) -> PreparePromptMetadata:
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
@@ -1457,6 +1495,14 @@ def _prepare_prompt(
             self.bucketing_ctx.get_padded_prompt_seq_len(target_query_len),
             self.block_size)
 
+        if self.dp_awared_padding:
+            if self.is_driver_worker:
+                max_prompt_len = align_dp_groups(
+                    max_prompt_len, torch.distributed.ReduceOp.MAX)
+            if align_worker:
+                max_prompt_len = align_tp_groups(
+                    max_prompt_len, torch.distributed.ReduceOp.MAX)
+
         lora_ids: List[int] = []
         for seq_group_metadata, context_len in zip(seq_group_metadata_list,
                                                    context_lens):
@@ -1617,6 +1663,7 @@ def _prepare_decode(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
         output=None,
+        align_worker=False,
     ) -> PrepareDecodeMetadata:
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
@@ -1766,6 +1813,13 @@ def _prepare_decode(
             block_bucket_size = max(max(block_list) + 1, len(block_list))
             block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks(
                 block_bucket_size)
+            if self.dp_awared_padding:
+                if self.is_driver_worker:
+                    block_bucket_size = align_dp_groups(
+                        block_bucket_size, torch.distributed.ReduceOp.MAX)
+                if align_worker:
+                    block_bucket_size = align_tp_groups(
+                        block_bucket_size, torch.distributed.ReduceOp.MAX)
             indices: List[Any]
             indices = [None] * block_bucket_size
             for i, bid in enumerate(block_list):
@@ -1775,6 +1829,13 @@ def _prepare_decode(
         else:
             block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks(
                 len(block_list))
+            if self.dp_awared_padding:
+                if self.is_driver_worker:
+                    block_bucket_size = align_dp_groups(
+                        block_bucket_size, torch.distributed.ReduceOp.MAX)
+                if align_worker:
+                    block_bucket_size = align_tp_groups(
+                        block_bucket_size, torch.distributed.ReduceOp.MAX)
             padding_fn = lambda tensor, pad_value: pad_list(
                 tensor, block_bucket_size, pad_value)
 
@@ -1805,6 +1866,13 @@ def _prepare_decode(
         real_batch_size = len(seq_group_metadata_list)
         batch_size_padded = self.bucketing_ctx.get_padded_batch_size(
             real_batch_size, False)
+        if self.dp_awared_padding:
+            if self.is_driver_worker:
+                batch_size_padded = align_dp_groups(
+                    batch_size_padded, torch.distributed.ReduceOp.MAX)
+            if align_worker:
+                batch_size_padded = align_tp_groups(
+                    batch_size_padded, torch.distributed.ReduceOp.MAX)
         batch_size_padding = batch_size_padded - real_batch_size
         if batch_size_padding > 0:
             encoder_seq_lens.extend(encoder_seq_lens[0]
@@ -1966,7 +2034,8 @@ def _compute_alibi_block(self, block_tables, seq_lens, num_blocks):
     def prepare_input_tensors(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-        finished_requests_ids: Optional[List[str]] = None
+        finished_requests_ids: Optional[List[str]] = None,
+        align_worker=False,
     ) -> Tuple[TModelInputForHPU, SamplingMetadata]:
         if len(seq_group_metadata_list) == 0:
             return self._model_input_cls(), None
@@ -1988,7 +2057,8 @@ def prepare_input_tensors(
         self.profiler.start('internal', base_event_name)
 
         seq_group_metadata_list, real_batch_size, batch_size_padded = (
-            self._add_dummy_seq(seq_group_metadata_list, is_prompt))
+            self._add_dummy_seq(seq_group_metadata_list, is_prompt,
+                                align_worker))
 
         prefill_reqs = []
         decode_reqs = []
@@ -2011,7 +2081,7 @@ def prepare_input_tensors(
             multi_modal_kwargs,
             slot_mapping,
             lora_ids,
-        ) = self._prepare_prompt(prefill_reqs)
+        ) = self._prepare_prompt(prefill_reqs, align_worker=align_worker)
         (
             decode_input_tokens,
             decode_input_positions,
@@ -2021,7 +2091,7 @@ def prepare_input_tensors(
             decode_lora_requests,
             decode_slot_mapping,
             decode_lora_ids,
-        ) = self._prepare_decode(decode_reqs)
+        ) = self._prepare_decode(decode_reqs, align_worker=align_worker)
 
         selected_token_indices = None
         if not self.is_pooler:
@@ -2400,6 +2470,20 @@ def profile_run(self) -> None:
         )
         return
 
+    def _dummy_run(self, max_num_batched_tokens: int) -> None:
+        assert max_num_batched_tokens == 1
+        self.warmup_scenario(batch_size=max_num_batched_tokens,
+                             seq_len=1,
+                             ctx=1,
+                             is_prompt=False,
+                             kv_caches=None,
+                             is_pt_profiler_run=False,
+                             num_patches=UNSET_NUM_PATCHES,
+                             is_lora_profile_run=True,
+                             num_iters=1,
+                             align_worker=True)
+        return
+
     def _remove_duplicate_submodules(self):
         model = self.get_model()
         if hasattr(model, "model"):
@@ -2423,7 +2507,9 @@ def warmup_scenario(self,
                         is_pt_profiler_run=False,
                         is_lora_profile_run=False,
                         temperature=0,
-                        num_patches=None) -> None:
+                        num_patches=None,
+                        num_iters=3,
+                        align_worker=False) -> None:
         phase = 'prompt' if is_prompt else 'decode'
         use_graphs = self._use_graphs(num_patches)
         scenario_name = ("warmup_"
@@ -2457,7 +2543,7 @@ def warmup_scenario(self,
                 for idx in range(batch_size)
             ]
         self.profiler.start('internal', scenario_name)
-        times = 3 if use_graphs or is_pt_profiler_run else 1
+        times = num_iters if use_graphs or is_pt_profiler_run else 1
         if is_prompt:
             seqs = [
                 self.create_dummy_seq_group_metadata(
@@ -2489,7 +2575,8 @@ def warmup_scenario(self,
             profiler = setup_profiler()
             profiler.start()
         for time_index in range(times):
-            inputs = self.prepare_model_input(seqs)
+            inputs = self.prepare_model_input_align_worker(
+                seqs, align_worker=align_worker)
             # Chendi: Necessary fix for warmup with TP>1
             if time_index == 0:
                 if self.is_driver_worker:
@@ -3057,6 +3144,28 @@ def prepare_model_input(
         seq_group_metadata_list: List[SequenceGroupMetadata],
         virtual_engine: int = 0,
         finished_requests_ids: Optional[List[str]] = None
+    ) -> ModelInputForHPUWithSamplingMetadata:
+        """Prepare the model input based on a given sequence group, including
+        metadata for the sampling step.
+        The API assumes seq_group_metadata_list is sorted by prefill -> decode.
+        The result tensors and data structure also batches input in prefill
+        -> decode order. For example,
+        - input_tokens[:num_prefill_tokens] contains prefill tokens.
+        - input_tokens[num_prefill_tokens:] contains decode tokens.
+        If cuda graph is required, this API automatically pads inputs.
+        """
+        return self.prepare_model_input_align_worker(seq_group_metadata_list,
+                                                     virtual_engine,
+                                                     finished_requests_ids,
+                                                     False)
+
+    @torch.inference_mode()
+    def prepare_model_input_align_worker(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        virtual_engine: int = 0,
+        finished_requests_ids: Optional[List[str]] = None,
+        align_worker: bool = False,
     ) -> ModelInputForHPUWithSamplingMetadata:
         """Prepare the model input based on a given sequence group, including
         metadata for the sampling step.
@@ -3073,7 +3182,7 @@ def prepare_model_input(
         self.profiler_counter_helper.capture_seq_group_metadata_stats(
             seq_group_metadata_list=seq_group_metadata_list)
         model_input, sampling_metadata = self.prepare_input_tensors(
-            seq_group_metadata_list, finished_requests_ids)
+            seq_group_metadata_list, finished_requests_ids, align_worker)
         assert model_input.attn_metadata is not None
         is_prompt = model_input.attn_metadata.is_prompt
 
@@ -3384,7 +3493,7 @@ def try_revert_dummy_output_tokens():
                     torch.hpu.synchronize()
                     import torch.distributed as dist
                     if dist.is_initialized():
-                        dist.barrier()
+                        get_tp_group().barrier()
         else:
             logger.debug("Bypassing model execution")
diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py
index 2100fe9b7d2..543ee3de949 100755
--- a/vllm/worker/hpu_worker.py
+++ b/vllm/worker/hpu_worker.py
@@ -536,11 +536,14 @@ def init_worker_distributed_environment(
         get_pp_group().all_reduce(torch.zeros(1).to('hpu'))
     if torch.distributed.is_initialized():
         torch_world_size = torch.distributed.get_world_size()
-        if torch_world_size != parallel_config.world_size:
+        expected_size = parallel_config.world_size *\
+            parallel_config.data_parallel_size
+        if torch_world_size != expected_size:
             raise RuntimeError(
                 "torch.distributed is already initialized but the torch world "
-                "size does not match parallel_config.world_size "
-                f"({torch_world_size} vs. {parallel_config.world_size}).")
+                "size does not match parallel_config.world_size * "
+                "parallel_config.data_parallel_size "
+                f"({torch_world_size} vs. {expected_size}).")
     elif not distributed_init_method:
         raise ValueError(
             "distributed_init_method must be set if torch.distributed "
@@ -558,7 +561,8 @@ def init_worker_distributed_environment(
     device = hpu_device_string()
     dummy_tensor_hpu = torch.ones(1).to(device)
    torch.distributed.all_reduce(dummy_tensor_hpu)
-    assert dummy_tensor_hpu.item() == parallel_config.world_size
+    assert dummy_tensor_hpu.item(
+    ) == parallel_config.world_size * parallel_config.data_parallel_size
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)
     ensure_kv_transfer_initialized(vllm_config)
diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py
index 0af44df6f08..f4a794265d3 100644
--- a/vllm/worker/model_runner_base.py
+++ b/vllm/worker/model_runner_base.py
@@ -218,6 +218,20 @@ def prepare_model_input(
         """
         raise NotImplementedError
 
+    def prepare_model_input_align_worker(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        virtual_engine: int = 0,
+        finished_requests_ids: Optional[List[str]] = None,
+        align_worker: bool = False,
+    ) -> T:
+        """
+        Prepare the inputs to ModelRunnerBase.execute_model from an execution
+        request. This method may move data to the worker's local device and,
+        when align_worker is True, align padded sizes across DP/TP workers.
+ """ + raise NotImplementedError + @abstractmethod def get_model(self) -> nn.Module: raise NotImplementedError diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index e5662e69343..42ecfec65ad 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -320,6 +320,13 @@ def _get_worker_input_from_broadcast( if not broadcast_data: return None + # This is for Data Parallel (DP) workers that run dummy batch + # execution. + if "is_dummy_batch" in broadcast_data and broadcast_data[ + "is_dummy_batch"]: + self.model_runner._dummy_run(1) # type: ignore[attr-defined] + return None + worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data) model_input = ( self.model_runner.make_model_input_from_broadcasted_tensor_dict( @@ -376,6 +383,11 @@ def prepare_input( # notify all other workers to stop their execution loop. broadcast_tensor_dict({}, src=0) return None + elif execute_model_req.is_dummy_batch: + if self.do_metadata_broadcast: + broadcast_tensor_dict({"is_dummy_batch": True}, src=0) + self.model_runner._dummy_run(1) # type: ignore[attr-defined] + return None return self._get_driver_input_and_broadcast(execute_model_req) else: return self._get_worker_input_from_broadcast() @@ -392,6 +404,21 @@ def execute_model( start_time = time.perf_counter() inputs = self.prepare_input(execute_model_req) + + # Need to keep worker running when executing dummy batch under DP + # scenario + if self.is_driver_worker: + if self.do_metadata_broadcast: + is_dummy_batch = execute_model_req and\ + execute_model_req.is_dummy_batch + broadcast_tensor_dict({"is_dummy_batch": is_dummy_batch}, + src=0) + else: + broadcast_data = broadcast_tensor_dict(src=0) + if "is_dummy_batch" in broadcast_data and broadcast_data[ + "is_dummy_batch"]: + return SamplerOutput(outputs=[], sampled_token_ids=None) + if inputs is None: return None