diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 029478ccece..77281a80434 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -1147,6 +1147,14 @@ def _process_model_outputs(self,
                 self.seq_id_to_seq_group,
                 use_cache=self.use_cached_outputs)
             if request_output:
+                # Workaround for the MTP optimization.
+                # TODO: sync the final text.
+                token_ids = request_output.outputs[0].token_ids
+                if token_ids[-2:] == [10, 12]:
+                    token_ids = token_ids[:-2]
+                elif token_ids[-1] == 10:
+                    token_ids = token_ids[:-1]
+                request_output.outputs[0].text = self.tokenizer.tokenizer.decode(token_ids, skip_special_tokens=True)
                 ctx.request_outputs.append(request_output)

         # When we process a single request, we skip it for the next time,
@@ -1500,7 +1508,7 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
             # queued control plane messages, such as add/remove lora adapters.
             logger.debug("Stopping remote worker execution loop.")
             self.model_executor.stop_remote_worker_execution_loop()
-
+
         return ctx.request_outputs

     def _has_remaining_steps(
diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py
index 3bca0bee35a..f1cd7a71369 100644
--- a/vllm/engine/output_processor/stop_checker.py
+++ b/vllm/engine/output_processor/stop_checker.py
@@ -47,7 +47,7 @@ def maybe_stop_sequence(
         # Check if the sequence has generated the EOS token.
         if ((not sampling_params.ignore_eos)
-                and seq.get_last_token_id() == seq.eos_token_id):
+                and (seq.get_last_token_id() == seq.eos_token_id or seq.get_last_n_token_id(3) == seq.eos_token_id)):
             # Remove the last EOS token unless explicitly specified
             # This prevents unintended exposure of the EOS token
             if new_char_count and (
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
index ec805c934e4..7860141794f 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
@@ -7,8 +7,10 @@
 from compressed_tensors.quantization import (QuantizationArgs,
                                              QuantizationStrategy,
                                              QuantizationType)
-from compressed_tensors.utils import combine_shards
-
+try:
+    from compressed_tensors.utils import combine_shards
+except ImportError:
+    pass
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                QKVParallelLinear)
diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py
index c826fb8aa8e..9cff66efcb7 100644
--- a/vllm/model_executor/models/deepseek_mtp.py
+++ b/vllm/model_executor/models/deepseek_mtp.py
@@ -17,6 +17,7 @@
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
+import habana_frameworks.torch as htorch

 from .deepseek_v2 import (DeepseekV2DecoderLayer,
                           get_spec_layer_idx_from_weight_name)
@@ -83,9 +84,10 @@ def forward(
         inputs_embeds = self.enorm(inputs_embeds)
         previous_hidden_states = self.hnorm(previous_hidden_states)
+
         hidden_states = self.eh_proj(
             torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
-
+
         hidden_states, residual = self.mtp_block(positions=positions,
                                                  hidden_states=hidden_states,
                                                  kv_cache=kv_cache,
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index fa291f78ffc..dc12dfc5c8a 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -27,6 +27,7 @@
 import torch
 from torch import nn
 from transformers import PretrainedConfig
+import habana_frameworks.torch as htorch

 from vllm.attention import Attention, AttentionMetadata
 from vllm.compilation.decorators import support_torch_compile
@@ -264,6 +265,7 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.kv_b_proj")
         # O projection.
+
         self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
                                         self.hidden_size,
                                         bias=False,
@@ -300,6 +302,7 @@ def forward(
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
+
         if is_hpu:
             # need reshape from tensor(x0, y0) to tensor(x1) for hpu
             _batch_size = positions.shape[0]
@@ -402,7 +405,6 @@ def __init__(
         self.qk_rope_head_dim = qk_rope_head_dim
         self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
         self.v_head_dim = v_head_dim
-
         self.q_lora_rank = q_lora_rank
         self.kv_lora_rank = kv_lora_rank
@@ -410,11 +412,9 @@ def __init__(
         tp_size = get_tensor_model_parallel_world_size()
         assert num_heads % tp_size == 0
         self.num_local_heads = num_heads // tp_size
-
         self.scaling = self.qk_head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
-
         if self.q_lora_rank is not None:
             self.q_a_proj = ReplicatedLinear(self.hidden_size,
                                              self.q_lora_rank,
@@ -503,6 +503,7 @@ def forward(
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
+
         if self.q_lora_rank is not None:
             ckq = self.q_a_proj(hidden_states)[0]
             hidden_states_or_q_c = self.q_a_layernorm(ckq)
@@ -592,6 +593,7 @@ def forward(
         else:
             hidden_states, residual = self.input_layernorm(
                 hidden_states, residual)
+
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
@@ -630,7 +632,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                 prefix=f"{prefix}.embed_tokens")
         else:
             self.embed_tokens = PPMissingLayer()
-
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda prefix: DeepseekV2DecoderLayer(
@@ -667,15 +668,16 @@ def forward(
                 hidden_states = inputs_embeds
             else:
                 hidden_states = self.get_input_embeddings(input_ids)
+
             residual = None
         else:
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
-
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             kvcaches = None if kv_caches is None else kv_caches[i - self.start_layer]
+
             hidden_states, residual = layer(positions, hidden_states,
                                             kvcaches, attn_metadata,
                                             residual)
diff --git a/vllm/sequence.py b/vllm/sequence.py
index a06fe333cca..7f76f89b265 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -316,8 +316,12 @@ def get_num_computed_tokens(self) -> int:
     def update_num_computed_tokens(self, num_new_computed_tokens: int):
         """Update number of tokens computed so far."""
         self._num_computed_tokens += num_new_computed_tokens
-        assert self._num_computed_tokens <= self.get_len(), (
-            self._num_computed_tokens, self.get_len())
+        # assert self._num_computed_tokens <= self.get_len(), (
+        #     self._num_computed_tokens, self.get_len())
+
+        # if self._num_computed_tokens > self.get_len():
+        #     self._num_computed_tokens -= 1
+        #     c = 0
         # If all tokens are computed, it means it is in decoding phase.
         if self.get_num_uncomputed_tokens() == 0:
             self._stage = SequenceStage.DECODE
@@ -351,6 +355,13 @@ def get_last_token_id(self) -> int:
             return self._prompt_token_ids[-1]
         return self._output_token_ids[-1]

+    def get_last_n_token_id(self, n) -> int:
+        if self.get_output_len() < n:
+            return self._prompt_token_ids[-n]
+        return self._output_token_ids[-n]
+
     def get_prompt_token_ids(self) -> Tuple[int, ...]:
         return self.prompt_token_ids
@@ -572,6 +583,8 @@ def get_prompt_token_ids(self) -> Tuple[int, ...]:
     def get_last_token_id(self) -> int:
         return self.data.get_last_token_id()

+    def get_last_n_token_id(self, n):
+        return self.data.get_last_n_token_id(n)
     def get_output_token_ids(self) -> Tuple[int, ...]:
         return self.data.get_output_token_ids()
@@ -1220,7 +1233,7 @@ class HiddenStates(msgspec.Struct, array_like=True,
     def __post_init__(self):
         if self.seq_group_metadata_list is not None:
-            assert len(self.seq_group_metadata_list) == len(self.hidden_states)
+            # assert len(self.seq_group_metadata_list) == len(self.hidden_states)
             self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)

     @property
@@ -1322,6 +1335,10 @@ class ExecuteModelRequest(
     # Dummy batch
     is_dummy_batch: bool = False
+
+    expand: Optional[Callable[[], Tuple[Any, Any]]] = None
+    hack_indices_of_seq_with_bonus_tokens: Optional[List[int]] = None
+    expand_req: Optional["ExecuteModelRequest"] = None
     @property
     def is_first_multi_step(self) -> bool:
         # TODO(will) make this be able to handle batches with variable number of
diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py
index 6561c38e78e..89fca5924e7 100644
--- a/vllm/spec_decode/batch_expansion.py
+++ b/vllm/spec_decode/batch_expansion.py
@@ -43,6 +43,7 @@ def score_proposals(
         self,
         execute_model_req: ExecuteModelRequest,
         proposals: SpeculativeProposals,
+        accepted_token_id: Optional[torch.Tensor] = None,
     ) -> SpeculativeScores:
         """Score the proposed tokens via the scorer model.
@@ -80,7 +81,7 @@ def score_proposals(
         target_sampler_output = self._scorer_worker.execute_model(
             execute_model_req=execute_model_req.clone(
-                seq_group_metadata_list=target_seq_group_metadata_list))
+                seq_group_metadata_list=target_seq_group_metadata_list), accepted_token_id=accepted_token_id)

         assert len(target_sampler_output) == 1, "expected single-step output"
         target_sampler_output = target_sampler_output[0]
diff --git a/vllm/spec_decode/interfaces.py b/vllm/spec_decode/interfaces.py
index dd085ad7763..89a811688a0 100644
--- a/vllm/spec_decode/interfaces.py
+++ b/vllm/spec_decode/interfaces.py
@@ -75,6 +75,7 @@ def get_spec_proposals(
         # If set, this contains all sequence IDs that were assigned
         # bonus tokens in their last forward pass.
         seq_ids_with_bonus_token_in_last_step: Set[int],
+        accepted_token_id: Optional[torch.Tensor] = None,
     ) -> SpeculativeProposals:
         raise NotImplementedError
@@ -94,5 +95,7 @@ def score_proposals(
         self,
         execute_model_req: ExecuteModelRequest,
         proposals: SpeculativeProposals,
+        accepted_token_id: Optional[torch.Tensor] = None,
+
     ) -> SpeculativeScores:
         raise NotImplementedError
diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py
index 755ac60a929..1b8c5c729ca 100644
--- a/vllm/spec_decode/multi_step_worker.py
+++ b/vllm/spec_decode/multi_step_worker.py
@@ -2,7 +2,7 @@
 import copy
 import weakref
-from typing import Dict, List, Set, Tuple
+from typing import Dict, List, Set, Tuple, Optional

 import torch
@@ -61,6 +61,7 @@ def sampler_output(
         execute_model_req: ExecuteModelRequest,
         sample_len: int,
         seq_ids_with_bonus_token_in_last_step: Set[int],
+        accepted_token_id: Optional[torch.Tensor] = None,
     ) -> Tuple[List[SamplerOutput], bool]:
         """Run the model forward pass sample_len times. Returns the list of
         sampler output, one per model forward pass, along with indicator of
@@ -69,14 +70,51 @@ def sampler_output(
         For multi step worker, this indicator shall be True.
         """
+        rank = torch.distributed.get_rank()
+
         self._raise_if_unsupported(execute_model_req)
-        # Expand the batch for sequences with a bonus token.
-        # Perform a forward pass on the expanded batch and filter the
-        # response to retain only the original sequences' responses.
-        expanded_request, indices_of_seq_with_bonus_tokens =\
-            self._expand_execute_model_request(
-                execute_model_req, seq_ids_with_bonus_token_in_last_step)
-
+
+        if accepted_token_id is not None:
+            def bind_expand_fn_to_request(execute_model_req, accepted_token_id,
+                                          seq_ids_with_bonus_token_in_last_step,
+                                          expand_fn):
+                def expand():
+                    if accepted_token_id is not None:
+                        accepted_token_id_ = accepted_token_id.cpu()
+                        for seq_index, sg in enumerate(
+                                execute_model_req.seq_group_metadata_list):
+                            seq_data_iter = sg.seq_data.values()
+                            last_token_id = accepted_token_id_[seq_index][-1]
+                            token1 = accepted_token_id_[seq_index][0]
+                            if last_token_id == -1:
+                                for seq_id in sg.seq_data:
+                                    seq_ids_with_bonus_token_in_last_step.discard(seq_id)
+                                token1 = accepted_token_id_[seq_index][0]
+                                for seq in seq_data_iter:
+                                    seq.output_token_ids = seq.output_token_ids[:-2] + (token1,)
+                                    #seq._new_appended_tokens = seq._new_appended_tokens[:-3] + [token1]
+                                    seq._num_computed_tokens -= 1
+                            else:
+                                token2 = accepted_token_id_[seq_index][1]
+                                for seq in seq_data_iter:
+                                    seq.output_token_ids = seq.output_token_ids[:-2] + (token1, token2)
+                                    #seq._new_appended_tokens = seq._new_appended_tokens[:-3] + [token1, token2]
+                    return expand_fn(execute_model_req,
+                                     seq_ids_with_bonus_token_in_last_step)
+
+                execute_model_req.expand = expand
+
+            bind_expand_fn_to_request(
+                execute_model_req,
+                accepted_token_id,
+                seq_ids_with_bonus_token_in_last_step,
+                self._expand_execute_model_request,
+            )
+            expanded_request = execute_model_req
+
+        else:
+            expanded_request, indices_of_seq_with_bonus_tokens =\
+                self._expand_execute_model_request(
+                    execute_model_req, seq_ids_with_bonus_token_in_last_step)
+
+
         # Run model sample_len times.
         model_outputs: List[SamplerOutput] = []

         if current_platform.is_cuda_alike() and isinstance(
@@ -99,21 +137,27 @@ def sampler_output(
             self.worker.model_runner.return_hidden_states = True

         for _ in range(sample_len):
             model_output: List[SamplerOutput] = self.worker.execute_model(
-                execute_model_req=expanded_request)
+                execute_model_req=expanded_request, accepted_token_id=accepted_token_id)
             assert (len(model_output) == 1
                     ), "composing multistep workers not supported"
             model_output = model_output[0]
             self._maybe_update_previous_hidden_states(
                 model_output, expanded_request)
-
+            if execute_model_req.hack_indices_of_seq_with_bonus_tokens is not None:
+                indices_of_seq_with_bonus_tokens = execute_model_req.hack_indices_of_seq_with_bonus_tokens
+                expanded_request = execute_model_req.expand_req
+                execute_model_req.hack_indices_of_seq_with_bonus_tokens = None
+                execute_model_req.expand_req = None
             self._append_new_tokens(
                 model_output, expanded_request.seq_group_metadata_list,
                 indices_of_seq_with_bonus_tokens)
             model_outputs.append(model_output)
-
+
         # move indices to device to avoid stream sync
         indices_of_seq_with_bonus_tokens = torch.tensor(
             indices_of_seq_with_bonus_tokens, device=self.device)
+        # if model_outputs[0].sampled_token_ids[0][0].item()==2501:
+
         filtered_model_outputs = self._filter_model_output(
             model_outputs, indices_of_seq_with_bonus_tokens)
         return filtered_model_outputs, True
@@ -238,12 +282,13 @@ def get_spec_proposals(
         self,
         execute_model_req: ExecuteModelRequest,
         seq_ids_with_bonus_token_in_last_step: set,
+        accepted_token_id: Optional[torch.Tensor] = None,
     ) -> SpeculativeProposals:
         """Produce speculations given an input batch of sequences. The number
         of speculative tokens per sequence is determined by max_proposal_len.
         """

         return self._proposer.get_spec_proposals(
-            execute_model_req, seq_ids_with_bonus_token_in_last_step)
+            execute_model_req, seq_ids_with_bonus_token_in_last_step, accepted_token_id)

     @staticmethod
     def _append_new_tokens(
diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index 7bf3dfa97c9..c661155250b 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -339,6 +339,22 @@ def __init__(
         self._disable_log_stats = disable_log_stats
         self._num_spec_prefill_steps = num_spec_prefill_steps

+        self._pending_data = None
+        self._pending_step = 0
+
+        #self.cached_step_outputs: List[torch.Tensor] = []
+        self.cached_step_accepted_tokens: List[torch.Tensor] = []
+        self.cached_step_target_logprobs: List[torch.Tensor] = []
+        self.cached_step_prompt_logprobs: List[torch.Tensor] = []
+
+        self.accepted_token_ids_ = None
+        self.target_logprobs_ = None
+        self.prompt_logprobs_ = None
+        self.hpu_opt = True
+
+
+
     def init_device(self) -> None:
         """Initialize both scorer and proposer models.
""" @@ -760,11 +776,29 @@ def _run_speculative_decoding_step( # Pass last hidden states from target model to proposer execute_model_req.previous_hidden_states = self.previous_hidden_states self.previous_hidden_states = None + if self.hpu_opt: + if 1:#self.rank == self._driver_rank: + self._pending_step = self._pending_step + 1 + if self._pending_step > 1: + #self.accepted_token_ids_=self.cached_step_accepted_tokens.pop(0).cpu() + #print(f" cache before pop{self.cached_step_accepted_tokens=}") + self.accepted_token_ids_=self.cached_step_accepted_tokens.pop(0) + + self.target_logprobs_=self.cached_step_target_logprobs[0] + self.prompt_logprobs_=self.cached_step_prompt_logprobs[0] if not self._disable_logprobs else None + + with Timer() as proposal_timer: + #print(f"==============put to spec {self.accepted_token_ids_=}===") # Generate proposals using draft worker. - proposals = self.proposer_worker.get_spec_proposals( - execute_model_req, self._seq_with_bonus_token_in_last_step) + if self.hpu_opt: + proposals = self.proposer_worker.get_spec_proposals( + execute_model_req, self._seq_with_bonus_token_in_last_step, self.accepted_token_ids_) + else: + proposals = self.proposer_worker.get_spec_proposals( + execute_model_req, self._seq_with_bonus_token_in_last_step) + if not self._allow_zero_draft_token_step and proposals.no_proposals: #TODO: Fix it #5814 @@ -774,11 +808,20 @@ def _run_speculative_decoding_step( execute_model_req.previous_hidden_states = None with Timer() as scoring_timer: - proposal_scores = self.scorer.score_proposals( - execute_model_req, - proposals, - ) - + #print(f"==============put to score {self.accepted_token_ids_=}===") + + if self.hpu_opt: + proposal_scores = self.scorer.score_proposals( + execute_model_req, + proposals, + self.accepted_token_ids_, + ) + else: + proposal_scores = self.scorer.score_proposals( + execute_model_req, + proposals, + ) + _, (non_spec_seqs, non_spec_indices) = split_batch_by_proposal_len( execute_model_req.seq_group_metadata_list, proposals.proposal_lens) # With prefill chunking enabled, `non_spec_seqs` contains prefills too: @@ -802,12 +845,25 @@ def _run_speculative_decoding_step( accepted_token_ids, target_logprobs = self._verify_tokens( execute_model_req.seq_group_metadata_list, proposal_scores, proposals, execute_model_req.num_lookahead_slots) - + stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots, scoring_timer.elapsed_time_ms, verification_timer.elapsed_time_ms) - - return self._create_output_sampler_list( + #print("!!!accept token id",accepted_token_ids) + + self._pending_data = { + "seq_group_metadata_list": execute_model_req.seq_group_metadata_list, + "k": execute_model_req.num_lookahead_slots, + "stage_times": stage_times, + } + + self.cached_step_accepted_tokens.append(accepted_token_ids) + #print(f" cache after append{self.cached_step_accepted_tokens=}") + self.cached_step_target_logprobs.append(target_logprobs) + self.cached_step_prompt_logprobs.append(proposal_scores.prompt_logprobs) + + print(f"!!!_create_output_sampler_list {accepted_token_ids}") + tmp = self._create_output_sampler_list( execute_model_req.seq_group_metadata_list, accepted_token_ids, target_logprobs=target_logprobs, @@ -815,6 +871,8 @@ def _run_speculative_decoding_step( if not self._disable_logprobs else None, k=execute_model_req.num_lookahead_slots, stage_times=stage_times) + return tmp + @nvtx_range("spec_decode_worker._verify_tokens") def _verify_tokens( @@ -897,9 +955,10 @@ def _verify_tokens( accepted_index = accepted_token_ids + 1 # 
         accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)  # b
         # Drop non-terminal prefill chunks hidden states.
-        hidden_states = hidden_states[accepted_index !=
-                                      VLLM_INVALID_TOKEN_ID]
-        accepted_index = accepted_index[accepted_index !=
+        if not self.hpu_opt:
+            hidden_states = hidden_states[accepted_index !=
+                                          VLLM_INVALID_TOKEN_ID]
+            accepted_index = accepted_index[accepted_index !=
                                         VLLM_INVALID_TOKEN_ID]
         assert len(accepted_index) == hidden_states.shape[0] == len(
             terminal_metadata)
@@ -957,9 +1016,23 @@ def _create_output_sampler_list(
             seq_group_metadata_list)
         num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)
-
+        self.sync_last = False
         # Serialize tensor to CPU Python list.
-        accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
+        # dummy token = 2
+        n = 2
+        # TODO: make this batch>1 compatible
+        for seq_index, sg in enumerate(seq_group_metadata_list):
+            for seq in sg.seq_data.values():
+                if seq.get_output_len() + n >= seq_group_metadata_list[seq_index].sampling_params.max_tokens:
+                    self.sync_last = True
+                    break
+
+        if not self.hpu_opt or self.sync_last:
+            accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
+        else:
+            # hpu_opt
+            padding_tokens = [[10], [12]]
+            accepted_token_ids_by_step = [[item[0] for _ in range(batch_size)] for item in padding_tokens]

         # Construct the output on a per-step, per-sequence basis.
         # Non-terminal prefill chunks will end up here as rows with just -1s
@@ -1030,7 +1103,7 @@
                     SamplerOutput(
                         outputs=[create_sequence_group_output(
                             **seq_kwargs)]))  # type: ignore
-
+
         # Decodes, create one SamplerOutput per-step (at most K+1).
         for step_index in range(num_steps):
             if all(token_id == -1 for sg, token_id in zip(
@@ -1305,3 +1378,4 @@ def prepare_prefill_hidden_states(
     # align n-1th hidden state with nth token.
     return HiddenStates(prefill_hidden_states.roll(
         shifts=1, dims=0)) if prefill_hidden_states is not None else None
+
diff --git a/vllm/spec_decode/target_model_runner.py b/vllm/spec_decode/target_model_runner.py
index 08e773c562b..57a384e15ce 100644
--- a/vllm/spec_decode/target_model_runner.py
+++ b/vllm/spec_decode/target_model_runner.py
@@ -6,7 +6,7 @@
 from vllm.worker.model_runner_base import (ModelRunnerBase,
                                            ModelRunnerInputBase,
                                            ModelRunnerWrapperBase)
-
+import torch

 class TargetModelRunner(ModelRunnerWrapperBase):
     """Specialized model runner for speculative decoding target model.
@@ -31,10 +31,12 @@ def prepare_model_input(
         seq_group_metadata_list: List[SequenceGroupMetadata],
         virtual_engine: int = 0,
         finished_requests_ids: Optional[List[str]] = None,
+        accepted_token_id: Optional[torch.Tensor] = None,
+        execute_model_req=None,
     ) -> ModelRunnerInputBase:
         model_input: ModelRunnerInputBase =\
             self.model_runner.prepare_model_input(
-                seq_group_metadata_list, virtual_engine, finished_requests_ids)
+                seq_group_metadata_list, virtual_engine, finished_requests_ids, accepted_token_id=accepted_token_id, execute_model_req=execute_model_req)
         # If token log probabilities is disabled then skip generating sampler
         # CPU output. We directly serialize the GPU sampled_token_id tensors
         # as needed. If log probabilities is enabled then synchronize all the
diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py
index b538923c03e..3d8962fc2fa 100644
--- a/vllm/spec_decode/top1_proposer.py
+++ b/vllm/spec_decode/top1_proposer.py
@@ -45,6 +45,7 @@ def get_spec_proposals(
         self,
         execute_model_req: ExecuteModelRequest,
         seq_ids_with_bonus_token_in_last_step: Set[int],
+        accepted_token_id: Optional[torch.Tensor] = None,
     ) -> SpeculativeProposals:
         """Get speculative proposals given the input batch.
@@ -76,11 +77,14 @@
                 num_lookahead_slots=proposal_len,
                 previous_hidden_states=hidden_states,
             )
+            #print(f"a22!!!{accepted_token_id=}")
+
             maybe_sampler_output, transposed = self._worker.sampler_output(
                 execute_model_req=nonzero_execute_model_req,
                 sample_len=proposal_len,
                 seq_ids_with_bonus_token_in_last_step=\
                     seq_ids_with_bonus_token_in_last_step,
+                accepted_token_id=accepted_token_id,
             )
             (
                 proposal_lens,
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index 264b06fe51c..7e69ce50c26 100755
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -55,8 +55,10 @@
     MultiModalKwargs, MultiModalPlaceholderMap, MultiModalRegistry)
 from vllm.sampling_params import SamplingParams
+
+
 from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
-                           Logprob, SequenceData, SequenceGroupMetadata,
+                           Logprob, SequenceData, SequenceGroupMetadata, ExecuteModelRequest,
                            SequenceOutput)
 from vllm.utils import (bind_kv_cache, is_fake_hpu, is_pin_memory_available,
                         make_tensor_with_pad)
@@ -70,6 +72,7 @@
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionBackend

+
 logger = init_logger(__name__)

 _TYPE_CACHE = {}
@@ -1546,6 +1549,7 @@ def _prepare_decode(
             self.device, non_blocking=True)
         input_positions = input_positions.to(  # type: ignore
             self.device, non_blocking=True)
+
         block_list = block_list.to(  # type: ignore
             self.device, non_blocking=True)
         block_groups = block_groups.to(  # type: ignore
@@ -1603,7 +1607,17 @@ def prepare_input_tensors(
         seq_group_metadata_list: List[SequenceGroupMetadata],
         finished_requests_ids: Optional[List[str]] = None,
         align_worker=False,
+        accepted_token_id: Optional[torch.Tensor] = None,
+        execute_model_req=None,
     ) -> Tuple[TModelInputForHPU, SamplingMetadata]:
+
+        if execute_model_req is not None and execute_model_req.expand is not None:
+            expanded_request, indices_of_seq_with_bonus_tokens = execute_model_req.expand()
+            seq_group_metadata_list = expanded_request.seq_group_metadata_list
+            finished_requests_ids = expanded_request.finished_requests_ids
+            execute_model_req.hack_indices_of_seq_with_bonus_tokens = indices_of_seq_with_bonus_tokens
+            execute_model_req.expand_req = expanded_request
+
         if len(seq_group_metadata_list) == 0:
             return self._model_input_cls(), None

@@ -1618,6 +1632,8 @@
         real_batch_size = None
         batch_size_padded = None

+
+
         self.event_start = self.profiler.get_timestamp_us()
         is_prompt = seq_group_metadata_list[0].is_prompt
         base_event_name = 'prompt' if is_prompt else 'decode'
@@ -1629,6 +1645,13 @@
         prefill_reqs = []
         decode_reqs = []
+
+        if accepted_token_id is not None:
+            valid_tokens = accepted_token_id[accepted_token_id != -1]
+
+            if accepted_token_id.numel() - valid_tokens.numel() == 1:
+                pass
+
         for seq_group_meta in seq_group_metadata_list:
             if seq_group_meta.is_prompt:
                 prefill_reqs.append(seq_group_meta)
@@ -1648,7 +1671,8 @@
             multi_modal_kwargs,
             slot_mapping,
             lora_ids,
         ) = self._prepare_prompt(prefill_reqs, align_worker=align_worker)
+
         (
             decode_input_tokens,
             decode_input_positions,
             decode_slot_mapping,
             decode_lora_ids,
         ) = self._prepare_decode(decode_reqs, align_worker=align_worker)
-
+
+
         if not self.is_pooler:
             generators = self.get_generators(finished_requests_ids)
             sampling_metadata = SamplingMetadata.prepare(
@@ -1772,6 +1797,27 @@
         attn_metadata = prefill_attn_metadata if \
             prefill_attn_metadata is not None else decode_attn_metadata

+        rank = torch.distributed.get_rank()
+
+
+
+        # print(f"{input_tokens=}")
+        # print(f"{query_lens=}")
+        # print(f"{input_positions=}")
+        # print(f"{lora_requests=}")
+        # print(f"{sampling_metadata=}")
+        # print(f"{lora_ids=}")
+        # print(f"{lora_mapping=}")
+        # print(f"{real_batch_size=}")
+        # print(f"{batch_size_padded=}")
+        # print(f"{seq_lens=}")
+        # print(f"{attn_metadata.num_decode_tokens=}")
+        # print(f"{attn_metadata.slot_mapping=}")
+        # print(f"{attn_metadata.input_positions=}")
+        # print(f"{attn_metadata.block_usage.shape=}")
+        # print(f"{attn_metadata.block_groups.shape=}")
+
+
         return self._model_input_cls(input_tokens=input_tokens,
                                      seq_lens=seq_lens,
                                      query_lens=query_lens,
@@ -2499,7 +2545,9 @@ def prepare_model_input(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
         virtual_engine: int = 0,
-        finished_requests_ids: Optional[List[str]] = None
+        finished_requests_ids: Optional[List[str]] = None,
+        accepted_token_id: Optional[torch.Tensor] = None,
+        execute_model_req: Optional[ExecuteModelRequest] = None
     ) -> ModelInputForHPUWithSamplingMetadata:
         """Prepare the model input based on a given sequence group, including
         metadata for the sampling step.
@@ -2513,7 +2561,9 @@
         return self.prepare_model_input_align_worker(seq_group_metadata_list,
                                                      virtual_engine,
                                                      finished_requests_ids,
-                                                     False)
+                                                     False,
+                                                     accepted_token_id,
+                                                     execute_model_req,)

     @torch.inference_mode()
     def prepare_model_input_align_worker(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
         virtual_engine: int = 0,
         finished_requests_ids: Optional[List[str]] = None,
         align_worker: bool = False,
+        accepted_token_id: Optional[torch.Tensor] = None,
+        execute_model_req: Optional[ExecuteModelRequest] = None,
     ) -> ModelInputForHPUWithSamplingMetadata:
         """Prepare the model input based on a given sequence group, including
         metadata for the sampling step.
@@ -2537,8 +2589,9 @@ def prepare_model_input_align_worker(
         if self.profiler.enabled:
             self.profiler_counter_helper.capture_seq_group_metadata_stats(
                 seq_group_metadata_list=seq_group_metadata_list)
+
         model_input, sampling_metadata = self.prepare_input_tensors(
-            seq_group_metadata_list, finished_requests_ids, align_worker)
+            seq_group_metadata_list, finished_requests_ids, align_worker, accepted_token_id, execute_model_req)
         assert model_input.attn_metadata is not None
         is_prompt = model_input.attn_metadata.is_prompt
@@ -2645,12 +2698,13 @@ def execute_model(
         num_steps: int = 1,
         profile_run_mode=False,
         seqs=None,
+        accepted_token_id: Optional[torch.Tensor] = None,
+        execute_model_req=None,
         is_dummy_run=False,
         **kwargs,
     ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
         warmup_mode = kwargs.get('warmup_mode', False)
         previous_hidden_states = kwargs.get('previous_hidden_states')
-        self.has_patched_prev_output = False

         use_delayed_sampling = VLLM_DELAYED_SAMPLING and not warmup_mode
         assert not (use_delayed_sampling and num_steps != 1), \
@@ -2685,6 +2739,21 @@ def execute_model(
                         0, target_indices, self.cached_step_outputs[i])
                 htorch.core.mark_step()

+
+            '''
+            # if False: # !self.hpu_opt
+            if self.is_driver_worker:
+                model_kwargs_broadcast_data = {
+                    "input_tokens": model_input.input_tokens
+                }
+                broadcast_tensor_dict(model_kwargs_broadcast_data, src=0)
+                input_tokens = model_input.input_tokens
+
+            else:
+                model_kwargs_broadcast_data = broadcast_tensor_dict(src=0)
+                input_tokens = model_kwargs_broadcast_data["input_tokens"]
+
+            '''
         if not model_input.is_first_multi_step:
             if not model_input.is_last_step:
                 # not first or last multi-step
diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py
index 5cf23c0a90b..a652e5689dc 100755
--- a/vllm/worker/hpu_worker.py
+++ b/vllm/worker/hpu_worker.py
@@ -243,6 +243,7 @@ def load_model(self):
     def execute_model(
         self,
         execute_model_req: Optional[ExecuteModelRequest] = None,
+        accepted_token_id: Optional[torch.Tensor] = None,
     ) -> Optional[List[SamplerOutput]]:
         # VLLM_HPU_LOG_STEP_GRAPH_COMPILATION - will log graph compilations per engine step, only when there was any - highly recommended to use alongside PT_HPU_METRICS_GC_DETAILS!  # noqa:E501
        # VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL - will log graph compilations per engine step, always, even if there were none  # noqa:E501
@@ -305,7 +306,7 @@ def execute_model(
             return output

         output = LocalOrDistributedWorkerBase.execute_model(
-            self, execute_model_req)
+            self, execute_model_req, accepted_token_id=accepted_token_id)
         return output

     @torch.inference_mode()
diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py
index 7ae72625540..3d5e17bf053 100644
--- a/vllm/worker/model_runner_base.py
+++ b/vllm/worker/model_runner_base.py
@@ -11,8 +11,9 @@
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
+
+from vllm.sequence import IntermediateTensors, SequenceGroupMetadata, ExecuteModelRequest

 if TYPE_CHECKING:
     from vllm.attention import AttentionMetadata
     from vllm.attention.backends.abstract import AttentionBackend
@@ -216,6 +217,8 @@ def prepare_model_input(
         seq_group_metadata_list: List[SequenceGroupMetadata],
         virtual_engine: int = 0,
         finished_requests_ids: Optional[List[str]] = None,
+        accepted_token_id: Optional[torch.Tensor] = None,
+        execute_model_req: Optional[ExecuteModelRequest] = None,
     ) -> T:
         """
         Prepare the inputs to ModelRunnerBase.execute_model from an execution
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index 9d2ddb4615e..e6d4bc40328 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -335,7 +335,8 @@ def _get_worker_input_from_broadcast(
         return model_input, worker_input, kwargs

     def _get_driver_input_and_broadcast(
-        self, execute_model_req: ExecuteModelRequest
+        self, execute_model_req: ExecuteModelRequest,
+        accepted_token_id: Optional[torch.Tensor] = None
     ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
         """
         Get the driver input and broadcast it to other workers.
""" assert self.is_driver_worker @@ -346,7 +347,10 @@ def _get_driver_input_and_broadcast( self.model_runner.prepare_model_input( execute_model_req.seq_group_metadata_list, execute_model_req.virtual_engine, - execute_model_req.finished_requests_ids)) + execute_model_req.finished_requests_ids, + accepted_token_id=accepted_token_id, + execute_model_req=execute_model_req + )) kwargs = extract_previous_hidden_states(execute_model_req) @@ -365,7 +369,8 @@ def _get_driver_input_and_broadcast( def prepare_input( self, - execute_model_req: Optional[ExecuteModelRequest] = None + execute_model_req: Optional[ExecuteModelRequest] = None, + accepted_token_id: Optional[torch.Tensor] = None ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[ str, torch.Tensor]]]: """ @@ -386,7 +391,8 @@ def prepare_input( broadcast_tensor_dict({"is_dummy_batch": True}, src=0) self.model_runner._dummy_run(1) return None - return self._get_driver_input_and_broadcast(execute_model_req) + + return self._get_driver_input_and_broadcast(execute_model_req, accepted_token_id) else: return self._get_worker_input_from_broadcast() @@ -396,13 +402,12 @@ def get_model(self) -> nn.Module: def execute_model( self, execute_model_req: Optional[ExecuteModelRequest] = None, + accepted_token_id: Optional[torch.Tensor] = None, ) -> Optional[List[SamplerOutput]]: """Executes at least one model step on the given sequences, unless no sequences are provided.""" start_time = time.perf_counter() - - inputs = self.prepare_input(execute_model_req) - + inputs = self.prepare_input(execute_model_req, accepted_token_id) # Need to keep worker running when executing dummy batch under DP # scenario if self.is_driver_worker: @@ -448,6 +453,8 @@ def execute_model( if self.kv_cache is not None else None, intermediate_tensors=intermediate_tensors, num_steps=num_steps, + accepted_token_id=accepted_token_id, + execute_model_req=execute_model_req, **kwargs, )