From 3034e5a47376aec281d6a363532750051c5bd176 Mon Sep 17 00:00:00 2001 From: wangxiyuan Date: Fri, 4 Jul 2025 17:05:32 +0800 Subject: [PATCH 1/2] [Misc] Code clean up Signed-off-by: wangxiyuan --- tests/ut/attention/test_attention_mask.py | 107 ++++++ vllm_ascend/attention/attention.py | 105 +----- vllm_ascend/attention/attention_mask.py | 103 ++++++ vllm_ascend/worker/eagle_proposer_v1.py | 49 +-- vllm_ascend/worker/model_runner_v1.py | 383 ++++++++-------------- vllm_ascend/worker/mtp_proposer_v1.py | 39 +-- 6 files changed, 349 insertions(+), 437 deletions(-) create mode 100644 tests/ut/attention/test_attention_mask.py create mode 100644 vllm_ascend/attention/attention_mask.py diff --git a/tests/ut/attention/test_attention_mask.py b/tests/ut/attention/test_attention_mask.py new file mode 100644 index 0000000000..200c2a34c2 --- /dev/null +++ b/tests/ut/attention/test_attention_mask.py @@ -0,0 +1,107 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from tests.ut.base import TestBase +from vllm_ascend.attention.attention_mask import AttentionMaskBuilder + + +class TestAttentionMaskBuilder(TestBase): + + def test_init_attention_mask_builder(self): + # generate attention_mask_builder with float16 + attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024, + dtype=torch.float16) + self.assertEqual(attention_mask_builder._seq_len_cached, 1024) + self.assertEqual(attention_mask_builder.attn_mask_cache.dtype, + torch.float16) + self.assertEqual(attention_mask_builder.splitfuse_mask_value, -10000) + self.assertEqual(attention_mask_builder.attn_mask_cache.shape, + (1024, 1024)) + self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1], + torch.tensor(float("-inf"), dtype=torch.float16)) + + # generate attention_mask_builder with int8 + attention_mask_builder = AttentionMaskBuilder(max_seq_len=512, + dtype=torch.int8) + self.assertEqual(attention_mask_builder._seq_len_cached, 512) + self.assertEqual(attention_mask_builder.attn_mask_cache.dtype, + torch.int8) + self.assertEqual(attention_mask_builder.splitfuse_mask_value, -10000) + self.assertEqual(attention_mask_builder.attn_mask_cache.shape, + (512, 512)) + self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1], + torch.tensor(1, dtype=torch.int8)) + + def test_get_attn_mask(self): + # if the len is less than max_seq_len, the attn_mask_cache will not be updated + attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024, + dtype=torch.float16) + attn_mask = attention_mask_builder.get_attn_mask( + max_seq_len=512, dtype=torch.float16, device=torch.device("cpu")) + self.assertEqual(attn_mask.shape, (512, 512)) + self.assertEqual(attn_mask[0][-1], + torch.tensor(float("-inf"), dtype=torch.float16)) + self.assertEqual(attention_mask_builder._seq_len_cached, 1024) + self.assertEqual(attention_mask_builder.attn_mask_cache.shape, + (1024, 1024)) + self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1], + 
torch.tensor(float("-inf"), dtype=torch.float16)) + + # if the len is greater than max_seq_len, the attn_mask_cache will be updated + attn_mask = attention_mask_builder.get_attn_mask( + max_seq_len=2048, dtype=torch.float16, device=torch.device("cpu")) + self.assertEqual(attn_mask.shape, (2048, 2048)) + self.assertEqual(attn_mask[0][-1], + torch.tensor(float("-inf"), dtype=torch.float16)) + self.assertEqual(attention_mask_builder._seq_len_cached, 2048) + self.assertEqual(attention_mask_builder.attn_mask_cache.shape, + (2048, 2048)) + self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1], + torch.tensor(float("-inf"), dtype=torch.float16)) + + def test_get_splitfuse_attn_mask(self): + attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024, + dtype=torch.float16) + attn_mask = attention_mask_builder.get_splitfuse_attn_mask( + seq_lens=[512], + query_lens=[512], + position=torch.tensor([0]), + dtype=torch.float16, + device=torch.device("cpu"), + ) + self.assertEqual(attn_mask.shape, (1, 512)) + self.assertEqual(attention_mask_builder._seq_len_cached, 1024) + + attn_mask = attention_mask_builder.get_splitfuse_attn_mask( + seq_lens=[2048], + query_lens=[1024], + position=torch.tensor([0]), + dtype=torch.float16, + device=torch.device("cpu"), + ) + self.assertEqual(attn_mask.shape, (1024, 2048)) + + attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024, + dtype=torch.int8) + attn_mask = attention_mask_builder.get_splitfuse_attn_mask( + seq_lens=[512], + query_lens=[512], + position=torch.tensor([0]), + dtype=torch.int8, + device=torch.device("cpu"), + ) + self.assertEqual(attn_mask.shape, (1, 512)) diff --git a/vllm_ascend/attention/attention.py b/vllm_ascend/attention/attention.py index 4b545a1242..35cb624e19 100644 --- a/vllm_ascend/attention/attention.py +++ b/vllm_ascend/attention/attention.py @@ -35,6 +35,7 @@ from vllm.utils import async_tensor_h2d, make_tensor_with_pad from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.ops.cache import concat_and_cache_mla from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, enable_custom_op, is_310p, nd_to_nz_2d) @@ -44,108 +45,6 @@ _ALLOWED_NUM_QUERIES_PER_KV = [32, 64, 128] -def generate_attn_mask(max_seq_len: int, dtype=torch.float16, mask_value=None): - # Construct lower triangle matrix. - mask_flag = torch.tril( - torch.ones((max_seq_len, max_seq_len), - dtype=torch.bool)).view(max_seq_len, max_seq_len) - # Create upper triangle matrix used to mark mask positions. - mask_flag = ~mask_flag - # Currently for fp16 dtype, the mask value should be set to -inf. - # TODO: Eliminate this part in the future. 
- if mask_value is None: - if dtype == torch.float16: - mask_value = torch.finfo(torch.float32).min - else: - mask_value = 1 - attn_mask = torch.masked_fill(torch.zeros(size=(max_seq_len, max_seq_len)), - mask_flag, mask_value).to(dtype) - return attn_mask - - -class AttentionMaskBuilder: - - def __init__(self, attn_mask: torch.Tensor): - self._seq_len_cached = attn_mask.shape[0] - self.attn_mask_cache = attn_mask - self.splitfuse_mask_value = -10000 - - @classmethod - def initialize_from_len(cls, - max_seq_len: int, - dtype: torch.dtype = torch.float16, - mask_value: Optional[int] = None): - return cls(generate_attn_mask(max_seq_len, dtype, mask_value)) - - def update_attn_cache(self, seqlen: int, dtype: torch.dtype, - device: torch.device): - if seqlen > self._seq_len_cached or self.attn_mask_cache.dtype != dtype: - self._seq_len_cached = seqlen - self.attn_mask_cache = generate_attn_mask(seqlen, dtype) - if self.attn_mask_cache.device != device: - self.attn_mask_cache = self.attn_mask_cache.to(device) - - def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype, - device: torch.device): - self.update_attn_cache(max_seq_len, dtype, device) - return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous() - - def get_decode_attn_mask( - self, - input_lengths: torch.tensor, - max_s: int, - dtype: torch.dtype, - device: torch.device, - ): - self.update_attn_cache(max_s, dtype, device) - return (self.attn_mask_cache.index_select( - 0, input_lengths)[:, :max_s].view(-1, 1, max_s).contiguous()) - - def get_splitfuse_attn_mask( - self, - seq_lens, - query_lens, - position, - dtype, - device, - ) -> torch.Tensor: - max_seq_len = max(seq_lens, default=0) - if max_seq_len <= self._seq_len_cached: - self.update_attn_cache(max_seq_len, dtype, device) - # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation - # is not the same. Fix this in the future when kernel is ready. 
- if self.attn_mask_cache.numel( - ) > 1 and self.attn_mask_cache[0][1] > 0: - attn_mask = self.get_attn_mask( # type: ignore - max_seq_len, dtype, device) - attn_mask *= -10000 - else: - attn_mask = self.attn_mask_cache - return torch.index_select(attn_mask, dim=0, - index=position)[:, :max_seq_len] - total_q_len = sum(query_lens) - attn_mask = torch.zeros((total_q_len, max_seq_len), - dtype=dtype, - device="cpu") - - current_row = 0 - for i in range(len(query_lens)): - seq_len = seq_lens[i] - q_len = query_lens[i] - context_len = seq_len - q_len - - assert context_len >= 0 - attn_mask[current_row:current_row + q_len, - context_len:] = self.splitfuse_mask_value - right_tensor = attn_mask[current_row:current_row + q_len, - context_len:seq_len] - right_tensor.masked_fill_( - right_tensor.tril() == self.splitfuse_mask_value, 0) - current_row += q_len - - return attn_mask.to(device, non_blocking=True) - - class AscendAttentionBackend(AttentionBackend): @staticmethod @@ -524,7 +423,7 @@ def __init__(self, input_builder: "ModelInputForNPUBuilder"): self.compress_mask = None self.chunk_mask = None if AscendMetadataBuilder._attn_mask_builder is None: - AscendMetadataBuilder._attn_mask_builder = AttentionMaskBuilder.initialize_from_len( + AscendMetadataBuilder._attn_mask_builder = AttentionMaskBuilder( 128, self.input_builder.runner.model_config.dtype) def _add_seq_group( diff --git a/vllm_ascend/attention/attention_mask.py b/vllm_ascend/attention/attention_mask.py new file mode 100644 index 0000000000..66ab4140fa --- /dev/null +++ b/vllm_ascend/attention/attention_mask.py @@ -0,0 +1,103 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + + +def _generate_attn_mask(max_seq_len, dtype): + # Construct lower triangle matrix. + mask_flag = torch.tril( + torch.ones((max_seq_len, max_seq_len), + dtype=torch.bool)).view(max_seq_len, max_seq_len) + # Create upper triangle matrix used to mark mask positions. + mask_flag = ~mask_flag + # Currently for fp16 dtype, the mask value should be set to -inf. + # TODO: Eliminate this part in the future. 
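+    # NOTE: torch.finfo(torch.float32).min overflows to -inf once the mask is
+    # cast to float16 below, which matches the -inf value mentioned above;
+    # for non-float16 dtypes a mask value of 1 is used instead.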
+ if dtype == torch.float16: + mask_value = torch.finfo(torch.float32).min + else: + mask_value = 1 + attn_mask = torch.masked_fill(torch.zeros(size=(max_seq_len, max_seq_len)), + mask_flag, mask_value).to(dtype) + return attn_mask + + +class AttentionMaskBuilder: + + def __init__( + self, + max_seq_len: int, + dtype: torch.dtype, + ): + attn_mask = _generate_attn_mask(max_seq_len, dtype) + + self._seq_len_cached = attn_mask.shape[0] + self.attn_mask_cache = attn_mask + self.splitfuse_mask_value = -10000 + + def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype, + device: torch.device): + self._update_attn_cache(max_seq_len, dtype, device) + return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous() + + def get_splitfuse_attn_mask( + self, + seq_lens, + query_lens, + position, + dtype, + device, + ) -> torch.Tensor: + max_seq_len = max(seq_lens, default=0) + if max_seq_len <= self._seq_len_cached: + self._update_attn_cache(max_seq_len, dtype, device) + # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation + # is not the same. Fix this in the future when kernel is ready. + if self.attn_mask_cache.numel( + ) > 1 and self.attn_mask_cache[0][1] > 0: + attn_mask = self.get_attn_mask( # type: ignore + max_seq_len, dtype, device) + attn_mask *= -10000 + else: + attn_mask = self.attn_mask_cache + return torch.index_select(attn_mask, dim=0, + index=position)[:, :max_seq_len] + total_q_len = sum(query_lens) + attn_mask = torch.zeros((total_q_len, max_seq_len), + dtype=dtype, + device="cpu") + current_row = 0 + for i in range(len(query_lens)): + seq_len = seq_lens[i] + q_len = query_lens[i] + context_len = seq_len - q_len + + assert context_len >= 0 + attn_mask[current_row:current_row + q_len, + context_len:] = self.splitfuse_mask_value + right_tensor = attn_mask[current_row:current_row + q_len, + context_len:seq_len] + right_tensor.masked_fill_( + right_tensor.tril() == self.splitfuse_mask_value, 0) + current_row += q_len + + return attn_mask.to(device, non_blocking=True) + + def _update_attn_cache(self, seqlen: int, dtype: torch.dtype, + device: torch.device): + if seqlen > self._seq_len_cached: + self._seq_len_cached = seqlen + self.attn_mask_cache = _generate_attn_mask(seqlen, dtype) + if self.attn_mask_cache.device != device: + self.attn_mask_cache = self.attn_mask_cache.to(device) diff --git a/vllm_ascend/worker/eagle_proposer_v1.py b/vllm_ascend/worker/eagle_proposer_v1.py index 3a82018402..fc074d58c1 100644 --- a/vllm_ascend/worker/eagle_proposer_v1.py +++ b/vllm_ascend/worker/eagle_proposer_v1.py @@ -14,7 +14,7 @@ from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.v1.sample.metadata import SamplingMetadata -from vllm_ascend.attention.attention import AttentionMaskBuilder +from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState logger = init_logger(__name__) @@ -74,8 +74,8 @@ def __init__(self, mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000) self.attn_mask_len = min(self.model_config.max_model_len, int(mask_len)) - self.attn_mask_builder = AttentionMaskBuilder.initialize_from_len( - self.attn_mask_len, self.dtype) + self.attn_mask_builder = AttentionMaskBuilder(self.attn_mask_len, + self.dtype) def _make_attention_mask( self, @@ -384,46 +384,3 @@ def prepare_eagle_input_sequential(out_tensor: torch.Tensor, (target_indices < end_pos) & \ (offset_tensor < num_tokens) out_tensor[target_indices[mask]] = 
values_to_store[mask] - - -# NOTE(woosuk): Currently, the below code is not used and we always use argmax -# to sample the draft tokens. We will use this after we find a way to manage -# the draft prob tensor. -# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details. -# FIXME(woosuk): The logic here is duplicated with the main sampling code. -# We should refactor this to reuse the same sampling implementation. -def compute_probs_and_sample_next_token( - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, -) -> tuple[torch.Tensor, torch.Tensor]: - if sampling_metadata.all_greedy: - # For greedy requests, draft_probs is not used in rejection sampling. - # Therefore, we can just return the logits. - probs = logits - next_token_ids = logits.argmax(dim=-1) - return next_token_ids, probs - - is_greedy = sampling_metadata.temperature == -1 - temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature) - logits.div_(temperature.view(-1, 1)) - probs = logits.softmax(dim=-1, dtype=torch.float32) - - # NOTE(woosuk): Currently, we ignore most of the sampling parameters in - # generating the draft tokens. We only use the temperature. While this - # could degrade the acceptance rate, it does not affect the distribution - # of the generated tokens after rejection sampling. - - # TODO(woosuk): Consider seeds. - q = torch.empty_like(probs) - q.exponential_() - # NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs - # will be used later for rejection sampling. - next_token_ids = probs.div(q).argmax(dim=-1).view(-1) - if not sampling_metadata.all_random: - greedy_token_ids = probs.argmax(dim=-1) - next_token_ids = torch.where( - is_greedy, - greedy_token_ids, - next_token_ids, - ) - return next_token_ids, probs diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index e74ece3aea..e5e0ce4c63 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -68,10 +68,12 @@ scatter_mm_placeholders) from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.attention.attention import AttentionMaskBuilder +from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import (AscendAttentionState, AscendMetadata) -from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata +from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata +from vllm_ascend.attention.mla_v1 import (AscendMLAMetadata, + CommonAttentionMetadata) from vllm_ascend.platform import NPUPlatform from vllm_ascend.pool.metadata import PoolingMetadata from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler @@ -195,10 +197,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): device=self.device) self.graph_block_tables = np.zeros( - (self.max_num_reqs, - (self.model_config.max_model_len + self.block_size - 1) // - self.block_size), - dtype=np.int32) + (self.max_num_reqs, self.max_num_blocks_per_req), dtype=np.int32) # Set up Attention self.attn_backend = get_attn_backend( @@ -211,13 +210,17 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): ) self.attn_metadata_builder = self.attn_backend.get_builder_cls()( weakref.proxy(self)) + self.attn_mask_builder = AttentionMaskBuilder( + min(self.model_config.max_model_len, + int(os.getenv("PAGED_ATTENTION_MASK_LEN", 10000))), self.dtype) # Set up speculative decoding. 
         self.use_aux_hidden_state_outputs = False
         self.use_spec_decode = False
         self.spec_attn_mask = None
         self.use_eagle = False
-        self.drafter = None
+        self.drafter: Optional[Union[NgramProposer, EagleProposer,
+                                     MtpProposer]] = None
         if self.speculative_config:
             self.use_spec_decode = True
             self.spec_attn_mask = torch.triu(torch.ones(2048,
@@ -317,19 +320,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             reversed(
                 self.vllm_config.compilation_config.cudagraph_capture_sizes))
 
-        # NOTE: Pre-construct a mask matrix to improve the efficiency of
-        # attention mask construction during inference.
-        # Note that the length of the matrix needs to be carefully balanced: a
-        # matrix that is too large will consume excessive VRAM, while a matrix
-        # that is too small will require dynamic concatenation during inference,
-        # leading to performance degradation.
-        # Therefore, an environment variable is added here to dynamically set
-        # the size of the pre-constructed mask matrix based on requirements.
-        mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)
-        attn_mask_len = min(self.model_config.max_model_len, int(mask_len))
-        self.attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
-            attn_mask_len, self.dtype)
-
         self.new_kv_cache_bytes = -1
         self.torchair_compiled_model = None  # type: ignore
         self.torchair_compiled_models = {}  # type: ignore
@@ -643,7 +633,8 @@ def _get_forward_metadata_across_dp(
     def get_eagle_atten_dict(
         self,
         scheduler_output: "SchedulerOutput",
-    ) -> dict[str, AscendMetadata]:
+    ) -> dict[str, Union[AscendMetadata, AscendMLAMetadata,
+                         AscendTorchairMetadata]]:
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
         num_reqs = self.input_batch.num_reqs
@@ -754,7 +745,8 @@ def get_eagle_atten_dict(
         self.seq_lens[num_reqs:].fill_(0)
         self.query_start_loc[num_reqs + 1:].fill_(-1)
 
-        attn_metadata: dict[str, AscendMetadata] = {}
+        attn_metadata: dict[str, Union[AscendMetadata, AscendMLAMetadata,
+                                       AscendTorchairMetadata]] = {}
         # Prepare the attention metadata for each KV cache group and make layers
         # in the same group share the same metadata.
for kv_cache_group_id, kv_cache_group_spec in enumerate( @@ -969,7 +961,8 @@ def _process_reqs( self, scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> tuple[SpecDecodeMetadata, torch.Tensor, SpecDecodeMetadata, + ) -> tuple[Union[AscendMetadata, AscendMLAMetadata, + AscendTorchairMetadata], torch.Tensor, SpecDecodeMetadata, torch.Tensor, int, torch.Tensor, torch.Tensor, np.ndarray]: # Check input valid total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens @@ -1079,11 +1072,11 @@ def _process_reqs( else: attn_state = AscendAttentionState.PrefillCacheHit - attn_mask = self._make_attention_mask(seq_lens=seq_lens, - query_lens=num_scheduled_tokens, - position=positions, - attn_state=attn_state) - self.attn_mask = attn_mask + self.attn_mask = self._make_attention_mask( + seq_lens=seq_lens, + query_lens=num_scheduled_tokens, + position=positions, + attn_state=attn_state) self.attn_state = attn_state # type: ignore extra_builder_kwargs = {} @@ -1099,10 +1092,6 @@ def _process_reqs( self.seq_lens[num_reqs:].fill_(0) self.query_start_loc[num_reqs + 1:].fill_(-1) - query_start_loc = self.query_start_loc[:num_reqs + 1] - seq_lens = self.seq_lens[:num_reqs] - common_attn_metadata = CommonAttentionMetadata( - query_start_loc=query_start_loc, seq_lens=seq_lens) with_prefill = attn_state not in [ AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding ] @@ -1126,6 +1115,10 @@ def _process_reqs( extra_builder_kwargs['graph_pad_size'] = graph_pad_size if self.vllm_config.model_config.use_mla: + query_start_loc = self.query_start_loc[:num_reqs + 1] + seq_lens = self.seq_lens[:num_reqs] + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=query_start_loc, seq_lens=seq_lens) attn_metadata = self.attn_metadata_builder.build( # type: ignore num_reqs=num_reqs, num_actual_tokens=total_num_scheduled_tokens, @@ -1415,98 +1408,24 @@ def _get_spec_token_ids( positions: torch.Tensor, num_scheduled_tokens: int, hidden_states: torch.Tensor, - attn_metadata: SpecDecodeMetadata, + attn_metadata: Union[AscendMetadata, AscendMLAMetadata, + AscendTorchairMetadata], aux_hidden_states: torch.Tensor = None, ) -> Optional[list[list[int]]]: if not self.use_spec_decode: # Speculative decoding is not enabled. spec_token_ids = None elif self.speculative_config.method == "ngram": - assert isinstance(self.drafter, NgramProposer) - spec_token_ids = self._generate_draft_token_ids( - valid_sampled_token_ids, sampling_metadata) + spec_token_ids = self._generate_ngram_token_ids( + valid_sampled_token_ids) elif self.speculative_config.method == "eagle": raise NotImplementedError("Eagle Is Not Supported Yet.") elif self.speculative_config.method == "eagle3": - assert isinstance(self.drafter, EagleProposer) - if self.speculative_config.use_eagle(): - next_token_ids: list[int] = [] - for i, token_ids in enumerate(valid_sampled_token_ids): - if token_ids: - # Common case. - next_token_id = token_ids[-1] - else: - # Partial prefill (rare case). - # Get the next token id from the request state. 
- req_id = self.input_batch.req_ids[i] - req_state = self.requests[req_id] - seq_len = ( - req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - - next_token_id = req_state.get_token_id(seq_len) - next_token_ids.append(next_token_id) - next_token_ids = torch.tensor(next_token_ids, - dtype=torch.int32, - device=self.device) - eagle_attn_metadata = attn_metadata[ - self.drafter.attn_layer_name] - num_input_tokens = scheduler_output.total_num_scheduled_tokens - if spec_decode_metadata is None: - # input_ids can be None for multimodal models. - target_token_ids = self.input_ids[:num_scheduled_tokens] - target_positions = positions[:num_scheduled_tokens] - if self.use_aux_hidden_state_outputs: - target_hidden_states = torch.cat([ - h[:num_scheduled_tokens] for h in aux_hidden_states - ], - dim=-1) - else: - target_hidden_states = hidden_states[: - num_scheduled_tokens] - target_slot_mapping = eagle_attn_metadata.slot_mapping - cu_num_tokens = eagle_attn_metadata.query_start_loc - else: - num_draft_tokens = spec_decode_metadata.num_draft_tokens - num_rejected_tokens = [ - n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0 - for i, n in enumerate(num_draft_tokens) - ] - num_rejected_tokens = torch.tensor( - num_rejected_tokens, - dtype=torch.int32, - device=self.device, - ) - num_tokens = num_scheduled_tokens - sum( - num_rejected_tokens) - cu_num_tokens, token_indices = self.drafter.prepare_inputs( - eagle_attn_metadata.query_start_loc, - num_rejected_tokens, num_tokens) - target_token_ids = self.input_ids[token_indices] - target_positions = positions[token_indices] - if self.use_aux_hidden_state_outputs: - target_hidden_states = torch.cat( - [h[token_indices] for h in aux_hidden_states], - dim=-1) - else: - target_hidden_states = hidden_states[token_indices] - target_slot_mapping = eagle_attn_metadata.slot_mapping[ - token_indices] - - positions = self.positions[:num_input_tokens] - draft_token_ids = self.drafter.propose( - target_token_ids=target_token_ids, - target_positions=target_positions, - target_hidden_states=target_hidden_states, - target_slot_mapping=target_slot_mapping, - next_token_ids=next_token_ids, - cu_num_tokens=cu_num_tokens, - block_table=eagle_attn_metadata.block_tables, - sampling_metadata=sampling_metadata, - ) - spec_token_ids = draft_token_ids.tolist() + spec_token_ids = self._generate_eagle3_token_ids( + valid_sampled_token_ids, sampling_metadata, scheduler_output, + spec_decode_metadata, positions, num_scheduled_tokens, + hidden_states, aux_hidden_states) elif self.speculative_config.method == 'deepseek_mtp': - assert isinstance(self.drafter, MtpProposer) spec_token_ids = self._generate_mtp_token_ids( valid_sampled_token_ids, sampling_metadata, scheduler_output, spec_decode_metadata, positions, num_scheduled_tokens, @@ -1572,14 +1491,11 @@ def execute_model( scheduler_output, intermediate_tensors)) with ProfileExecuteDuration().capture_async("post process"): - if self.input_batch.pooling_params: return self._pool(hidden_states, num_scheduled_tokens, num_scheduled_tokens_np) logits = self.model.compute_logits(hidden_states[sample_indices], None) - if self.use_eagle: - attn_metadata = self.get_eagle_atten_dict(scheduler_output) # Apply structured output bitmasks if present if scheduler_output.grammar_bitmask is not None: logits = self.apply_grammar_bitmask(scheduler_output, logits) @@ -1729,96 +1645,12 @@ def execute_model( return model_runner_output - def _profile_multimodal(self) -> None: - # TODO: handle encoder-decoder models once we 
support them. - # NOTE: Currently model is profiled with a single non-text - # modality with the max possible input tokens even when - # it supports multiple. - - if (not self.is_multimodal_model - or self.max_num_encoder_input_tokens <= 0 - or self.encoder_cache_size <= 0): - return - - max_tokens_by_modality_dict = ( - MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality( - self.model_config)) - dummy_data_modality, max_tokens_per_mm_item = max( - max_tokens_by_modality_dict.items(), key=lambda item: item[1]) - - # Check how many items of this modality can be supported by - # the encoder budget. - encoder_budget = min(self.max_num_encoder_input_tokens, - self.encoder_cache_size) - - max_num_mm_items_encoder_budget = cdiv(encoder_budget, - max_tokens_per_mm_item) - - # Check how many items of this modality can be supported by - # the decoder budget. - max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt( - self.model_config)[dummy_data_modality] - - # NOTE: We do not consider max_num_batched_tokens on purpose - # because the multimodal embeddings can be generated in advance - # and chunked prefilled. - max_num_mm_items_decoder_budget = self.max_num_reqs * \ - max_mm_items_per_req - - max_num_mm_items = min(max_num_mm_items_encoder_budget, - max_num_mm_items_decoder_budget) - - logger.info( - "Encoder cache will be initialized with a budget of %s tokens," - " and profiled with %s %s items of the maximum feature size.", - encoder_budget, max_num_mm_items, dummy_data_modality) - - # Create dummy batch of multimodal inputs. - dummy_request_data = self.input_registry.dummy_data_for_profiling( - model_config=self.model_config, - seq_len=self.max_num_tokens, - mm_registry=self.mm_registry, - ) - dummy_mm_data = dummy_request_data.multi_modal_data - - if not isinstance(dummy_mm_data, MultiModalKwargs): - # TODO: Delete this check once input mapper is fully removed. - raise RuntimeError("Legacy input mapper is not supported in V1") - - # Dummy data definition in V0 may contain multiple multimodal items - # (e.g, multiple images) for a single request, therefore here we - # always replicate first item by max_num_mm_items times since in V1 - # they are scheduled to be processed separately. - - dummy_mm_item = dummy_mm_data.get_item(modality=dummy_data_modality, - item_index=0) - dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item]) - - batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] * - max_num_mm_items) - batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs( - batched_dummy_mm_inputs, device=self.device) - - # Run multimodal encoder. - dummy_encoder_outputs = self.model.get_multimodal_embeddings( - **batched_dummy_mm_inputs) - assert len(dummy_encoder_outputs) == max_num_mm_items, ( - "Expected dimension 0 of encoder outputs to match the number " - f"of multimodal data items: {max_num_mm_items}, got " - f"{len(dummy_encoder_outputs)=} instead. This is most likely " - "due to the 'get_multimodal_embeddings' method of the model " - "not implemented correctly.") - - # Cache the dummy encoder outputs. 
- self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) - @torch.inference_mode() def _dummy_run( self, num_tokens: int, is_compile: bool = False, with_prefill: bool = True, - skip_attn: bool = True, ) -> torch.Tensor: # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively @@ -1828,19 +1660,8 @@ def _dummy_run( min_tokens_per_req = num_tokens // num_reqs num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs num_scheduled_tokens_list[-1] += num_tokens % num_reqs - assert sum(num_scheduled_tokens_list) == num_tokens - assert len(num_scheduled_tokens_list) == num_reqs num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) - if skip_attn: - attn_metadata = None - else: - attn_metadata = self.attn_metadata_builder.build( - num_reqs=num_tokens, - num_actual_tokens=num_tokens, - max_query_len=num_tokens, - common_prefix_len=0, - ) with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): @@ -1918,48 +1739,32 @@ def _dummy_run( hidden_states, _ = hidden_states else: hidden_states = hidden_states - if self.use_spec_decode and \ - self.speculative_config.method in ('eagle', 'eagle3'): - assert isinstance(self.drafter, EagleProposer) + if self.use_spec_decode and isinstance( + self.drafter, EagleProposer): self.drafter.dummy_run(num_tokens) return hidden_states def profile_run(self) -> None: - # FIXME Profile with multimodal encoder & encoder cache. - # current _profile_multimodal() using PyTorch SDPA backend method not - # support for window/full attn to reduce Memcpy operations, so will cause - # Out Of Memory problem, so we currently don't use self._profile_multimodal() - # self._profile_multimodal() - - # For profile, have maximum num_reqs and that collectively have - # maximum num_tokens. - min_tokens_per_req = self.max_num_tokens // self.max_num_reqs - - num_scheduled_tokens_list = [min_tokens_per_req] * self.max_num_reqs - num_scheduled_tokens_list[ - -1] += self.max_num_tokens % self.max_num_reqs - assert sum(num_scheduled_tokens_list) == self.max_num_tokens - assert len(num_scheduled_tokens_list) == self.max_num_reqs - - num_scheduled_tokens = np.array(num_scheduled_tokens_list, - dtype=np.int32) - logit_indices = np.cumsum(num_scheduled_tokens) - 1 - - # assert self.lora_manager is not None, "LoRA is not enabled" - # TODO: call maybe_profile_with_lora() - # Trigger compilation for general shape. hidden_states = self._dummy_run(self.max_num_tokens) - + output = None if get_pp_group().is_last_rank: if self.is_pooling_model: output = self._dummy_pooler_run(hidden_states) else: + # For profile, have maximum num_reqs and that collectively have + # maximum num_tokens. 
+ min_tokens_per_req = self.max_num_tokens // self.max_num_reqs + num_scheduled_tokens_list = [min_tokens_per_req + ] * self.max_num_reqs + num_scheduled_tokens_list[ + -1] += self.max_num_tokens % self.max_num_reqs + num_scheduled_tokens = np.array(num_scheduled_tokens_list, + dtype=np.int32) + logit_indices = np.cumsum(num_scheduled_tokens) - 1 # TODO: need to rum a dummy sampler for generate task hidden_states = hidden_states[logit_indices] output = self.model.compute_logits(hidden_states, None) - else: - output = None NPUPlatform.synchronize() del hidden_states, output @@ -1978,8 +1783,6 @@ def _dummy_pooler_run( min_tokens_per_req = num_tokens // num_reqs num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs num_scheduled_tokens_list[-1] += num_tokens % num_reqs - assert sum(num_scheduled_tokens_list) == num_tokens - assert len(num_scheduled_tokens_list) == num_reqs hidden_states_list = list( torch.split(hidden_states, num_scheduled_tokens_list)) @@ -2035,7 +1838,7 @@ def load_model(self) -> None: pass if self.drafter: logger.info("Loading drafter model...") - if self.use_aux_hidden_state_outputs: + if isinstance(self.drafter, EagleProposer): self.drafter.load_model(self.model) self.model.set_aux_hidden_state_layers( self.model.get_eagle3_aux_hidden_state_layers()) @@ -2346,10 +2149,9 @@ def capture_model(self) -> None: logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", elapsed_time, npu_graph_size / (1 << 30)) - def _generate_draft_token_ids( + def _generate_ngram_token_ids( self, sampled_token_ids: list[list[int]], - sampling_metadata: SamplingMetadata, ) -> list[list[int]]: # TODO(woosuk): Optimize. draft_token_ids: list[list[int]] = [] @@ -2375,7 +2177,7 @@ def _generate_draft_token_ids( start_idx = self.input_batch.num_tokens_no_spec[i] end_idx = start_idx + num_sampled_ids self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids - assert self.drafter is not None + assert isinstance(self.drafter, NgramProposer) drafter_output = self.drafter.propose( self.input_batch.token_ids_cpu[i, :end_idx]) if drafter_output is None or len(drafter_output) == 0: @@ -2384,6 +2186,86 @@ def _generate_draft_token_ids( draft_token_ids.append(drafter_output.tolist()) return draft_token_ids + def _generate_eagle3_token_ids(self, + valid_sampled_token_ids: list[list[int]], + sampling_metadata: SamplingMetadata, + scheduler_output: "SchedulerOutput", + spec_decode_metadata: SpecDecodeMetadata, + positions: torch.Tensor, + num_scheduled_tokens: int, + hidden_states: torch.Tensor, + aux_hidden_states: torch.Tensor = None): + assert isinstance(self.drafter, EagleProposer) + attn_metadata = self.get_eagle_atten_dict(scheduler_output) + next_token_ids: list[int] = [] + for i, token_ids in enumerate(valid_sampled_token_ids): + if token_ids: + # Common case. + next_token_id = token_ids[-1] + else: + # Partial prefill (rare case). + # Get the next token id from the request state. + req_id = self.input_batch.req_ids[i] + req_state = self.requests[req_id] + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + + next_token_id = req_state.get_token_id(seq_len) + next_token_ids.append(next_token_id) + next_token_ids = torch.tensor(next_token_ids, + dtype=torch.int32, + device=self.device) + eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name] + if spec_decode_metadata is None: + # input_ids can be None for multimodal models. 
+ target_token_ids = self.input_ids[:num_scheduled_tokens] + target_positions = positions[:num_scheduled_tokens] + if self.use_aux_hidden_state_outputs: + target_hidden_states = torch.cat( + [h[:num_scheduled_tokens] for h in aux_hidden_states], + dim=-1) + else: + target_hidden_states = hidden_states[:num_scheduled_tokens] + target_slot_mapping = eagle_attn_metadata.slot_mapping + cu_num_tokens = eagle_attn_metadata.query_start_loc + else: + num_draft_tokens = spec_decode_metadata.num_draft_tokens + num_rejected_tokens = [ + n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0 + for i, n in enumerate(num_draft_tokens) + ] + num_rejected_tokens = torch.tensor( + num_rejected_tokens, + dtype=torch.int32, + device=self.device, + ) + num_tokens = num_scheduled_tokens - sum(num_rejected_tokens) + cu_num_tokens, token_indices = self.drafter.prepare_inputs( + eagle_attn_metadata.query_start_loc, num_rejected_tokens, + num_tokens) + target_token_ids = self.input_ids[token_indices] + target_positions = positions[token_indices] + if self.use_aux_hidden_state_outputs: + target_hidden_states = torch.cat( + [h[token_indices] for h in aux_hidden_states], dim=-1) + else: + target_hidden_states = hidden_states[token_indices] + target_slot_mapping = eagle_attn_metadata.slot_mapping[ + token_indices] + + draft_token_ids = self.drafter.propose( + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + target_slot_mapping=target_slot_mapping, + next_token_ids=next_token_ids, + cu_num_tokens=cu_num_tokens, + block_table=eagle_attn_metadata.block_tables, + sampling_metadata=sampling_metadata, + ) + spec_token_ids = draft_token_ids.tolist() + return spec_token_ids + def _generate_mtp_token_ids( self, valid_sampled_token_ids: list[list[int]], @@ -2393,8 +2275,10 @@ def _generate_mtp_token_ids( positions: torch.Tensor, num_scheduled_tokens: int, hidden_states: torch.Tensor, - attn_metadata: SpecDecodeMetadata, + attn_metadata: Union[AscendMetadata, AscendMLAMetadata, + AscendTorchairMetadata], ): + assert isinstance(self.drafter, MtpProposer) next_token_ids: list[int] = [] for i, token_ids in enumerate(valid_sampled_token_ids): if token_ids: @@ -2432,7 +2316,6 @@ def _generate_mtp_token_ids( dtype=torch.int32, device=self.device, ) - assert self.drafter is not None cu_num_tokens, token_indices = self.drafter.prepare_inputs( attn_metadata.query_start_loc, num_rejected_tokens, @@ -2441,7 +2324,7 @@ def _generate_mtp_token_ids( target_positions = positions[token_indices] target_hidden_states = hidden_states[token_indices] target_slot_mapping = attn_metadata.slot_mapping[token_indices] - assert self.drafter is not None + draft_token_ids = self.drafter.propose( target_token_ids=target_token_ids, target_positions=target_positions, diff --git a/vllm_ascend/worker/mtp_proposer_v1.py b/vllm_ascend/worker/mtp_proposer_v1.py index ba8406fa0a..5b88e7e2c1 100644 --- a/vllm_ascend/worker/mtp_proposer_v1.py +++ b/vllm_ascend/worker/mtp_proposer_v1.py @@ -12,43 +12,6 @@ from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP -# FIXME(woosuk): The logic here is duplicated with the main sampling code. -# We should refactor this to reuse the same sampling implementation. -def compute_probs_and_sample_next_token( - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, -) -> tuple[torch.Tensor, torch.Tensor]: - if sampling_metadata.all_greedy: - # For greedy requests, draft_probs is not used in rejection sampling. 
- # Therefore, we can just return the logits. - probs = logits - next_token_ids = logits.argmax(dim=-1) - return next_token_ids, probs - - is_greedy = sampling_metadata.temperature == -1 - temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature) - logits.div_(temperature.view(-1, 1)) - probs = logits.softmax(dim=-1, dtype=torch.float32) - - # NOTE(woosuk): Currently, we ignore most of the sampling parameters in - # generating the draft tokens. We only use the temperature. While this - # could degrade the acceptance rate, it does not affect the distribution - # of the generated tokens after rejection sampling. - - # TODO(woosuk): Consider seeds. - q = torch.empty_like(probs) - q.exponential_() - next_token_ids = probs.div_(q).argmax(dim=-1).view(-1) - if not sampling_metadata.all_random: - greedy_token_ids = probs.argmax(dim=-1) - next_token_ids = torch.where( - is_greedy, - greedy_token_ids, - next_token_ids, - ) - return next_token_ids, probs - - class MtpProposer: def __init__( @@ -121,7 +84,7 @@ def propose( # [batch_size, max_num_blocks_per_req] block_table: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] last_token_indices = cu_num_tokens[1:] - 1 From 00cd6542a074d4d1023f169eba1933948633762b Mon Sep 17 00:00:00 2001 From: wangxiyuan Date: Tue, 8 Jul 2025 14:41:38 +0800 Subject: [PATCH 2/2] Clean up v0.9.1 code Signed-off-by: wangxiyuan --- .../developer_guide/feature_guide/patch.md | 12 +- .../test_offline_inference_distributed.py | 22 -- .../ascend_scheduler/test_ascend_scheduler.py | 82 ++--- .../sample/test_rejection_sampler.py | 59 ++-- tests/e2e/singlecard/test_embedding.py | 4 - .../e2e/singlecard/test_offline_inference.py | 24 -- tests/e2e/singlecard/test_sampler.py | 152 ---------- tests/e2e/singlecard/test_scheduler.py | 26 +- .../worker/patch_common/test_patch_sampler.py | 31 -- vllm_ascend/core/scheduler.py | 34 +-- vllm_ascend/envs.py | 4 - vllm_ascend/models/deepseek_dbo.py | 21 +- vllm_ascend/models/deepseek_v2.py | 21 +- vllm_ascend/ops/fused_moe.py | 45 +-- vllm_ascend/patch/__init__.py | 32 +- vllm_ascend/patch/platform/__init__.py | 4 +- .../{patch_0_9_1 => patch_0_9_2}/__init__.py | 0 vllm_ascend/patch/worker/__init__.py | 4 +- .../patch/worker/patch_0_9_1/patch_sampler.py | 106 ------- .../{patch_0_9_1 => patch_0_9_2}/__init__.py | 1 - vllm_ascend/worker/model_runner_v1.py | 283 ++++++------------ vllm_ascend/worker/npu_input_batch.py | 88 ++---- 22 files changed, 208 insertions(+), 847 deletions(-) delete mode 100644 tests/e2e/singlecard/test_sampler.py delete mode 100644 tests/ut/patch/worker/patch_common/test_patch_sampler.py rename vllm_ascend/patch/platform/{patch_0_9_1 => patch_0_9_2}/__init__.py (100%) delete mode 100644 vllm_ascend/patch/worker/patch_0_9_1/patch_sampler.py rename vllm_ascend/patch/worker/{patch_0_9_1 => patch_0_9_2}/__init__.py (91%) diff --git a/docs/source/developer_guide/feature_guide/patch.md b/docs/source/developer_guide/feature_guide/patch.md index 7a422b836c..df4f40357f 100644 --- a/docs/source/developer_guide/feature_guide/patch.md +++ b/docs/source/developer_guide/feature_guide/patch.md @@ -20,11 +20,11 @@ In `vllm_ascend/patch`, you can see the code structure as follows: vllm_ascend ├── patch │ ├── platform -│ │ ├── patch_0_9_1 +│ │ ├── patch_0_9_2 │ │ ├── patch_common │ │ ├── patch_main │ ├── worker -│ │ ├── patch_0_9_1 +│ │ ├── patch_0_9_2 │ │ ├── patch_common │ │ ├── 
patch_main
 └───────────
 
@@ -38,15 +38,15 @@ vllm_ascend
 In both **platform** and **worker** folder, there are several patch modules. They are used for patching different version of vLLM.
 
-- `patch_0_9_1`: This module is used for patching vLLM 0.9.1. The version is always the nearest version of vLLM. Once vLLM is released, we will drop this patch module and bump to a new version. For example, `patch_0_9_2` is used for patching vLLM 0.9.2.
+- `patch_0_9_2`: This module is used for patching vLLM 0.9.2. The version always tracks the latest released version of vLLM: once a newer vLLM version is released, this patch module is dropped and replaced by one named after the new version.
 - `patch_main`: This module is used for patching the code in vLLM main branch.
-- `patch_common`: This module is used for patching both vLLM 0.9.1 and vLLM main branch.
+- `patch_common`: This module is used for patching both vLLM 0.9.2 and the vLLM main branch.
 
 ## How to write a patch
 
 Before writing a patch, following the principle above, we should patch the least code. If it's necessary, we can patch the code in either **platform** and **worker** folder. Here is an example to patch `distributed` module in vLLM.
 
-1. Decide which version of vLLM we should patch. For example, after analysis, here we want to patch both 0.9.1 and main of vLLM.
+1. Decide which version of vLLM we should patch. For example, after analysis, here we want to patch both vLLM 0.9.2 and the main branch.
 2. Decide which process we should patch. For example, here `distributed` belongs to the vLLM main process, so we should patch `platform`.
 3. Create the patch file in the right folder. The file should be named as `patch_{module_name}.py`. The example here is `vllm_ascend/patch/platform/patch_common/patch_distributed.py`.
 4. Write your patch code in the new file. Here is an example:
@@ -79,4 +79,4 @@ Before writing a patch, following the principle above, we should patch the least
 
 ## Limitation
 1. In V1 Engine, vLLM starts three kinds of process: Main process, EngineCore process and Worker process. Now vLLM Ascend only support patch the code in Main process and Worker process by default. If you want to patch the code runs in EngineCore process, you should patch EngineCore process entirely during setup, the entry code is here `vllm.v1.engine.core`. Please override `EngineCoreProc` and `DPEngineCoreProc` entirely.
-2. If you are running an edited vLLM code, the version of the vLLM may be changed automatically. For example, if you runs an edited vLLM based on v0.9.1, the version of vLLM may be change to v0.9.2xxx, in this case, the patch for v0.9.1 in vLLM Ascend would not work as expect, because that vLLM Ascend can't distinguish the version of vLLM you're using. In this case, you can set the environment variable `VLLM_VERSION` to specify the version of vLLM you're using, then the patch for v0.9.1 should work.
+2. If you are running a locally modified vLLM, its version may change automatically. For example, if you run an edited vLLM based on v0.9.2, the reported version may become v0.9.2xxx; in that case, the patch for v0.9.2 in vLLM Ascend would not work as expected, because vLLM Ascend cannot determine which version of vLLM you are using. You can set the environment variable `VLLM_VERSION` to specify the version of vLLM you are using; the patch for v0.9.2 should then work.
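The example file referenced in step 4 (`patch_distributed.py`) is unchanged by this hunk, so its content is not shown above. Purely as an illustration of the shape such a patch module usually takes — the target function `destroy_model_parallel` and the wrapper name below are assumptions for this sketch, not the actual content of `patch_distributed.py`:

```python
# Illustrative sketch of a vllm_ascend patch module (not the real file).
import vllm.distributed.parallel_state as parallel_state

# Keep a reference to the upstream implementation before replacing it.
_original_destroy_model_parallel = parallel_state.destroy_model_parallel


def ascend_destroy_model_parallel() -> None:
    """Run the upstream cleanup, then any Ascend-specific cleanup."""
    _original_destroy_model_parallel()
    # NPU-specific teardown would go here.


# Applying the patch is a module-level attribute replacement; importing this
# file (the patch packages typically do this from their __init__) activates it.
parallel_state.destroy_model_parallel = ascend_destroy_model_parallel
```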
diff --git a/tests/e2e/multicard/test_offline_inference_distributed.py b/tests/e2e/multicard/test_offline_inference_distributed.py index 47ff47eddd..6ccdbd7fdb 100644 --- a/tests/e2e/multicard/test_offline_inference_distributed.py +++ b/tests/e2e/multicard/test_offline_inference_distributed.py @@ -73,28 +73,6 @@ def test_models_distributed_DeepSeek_multistream_moe(): vllm_model.generate_greedy(example_prompts, max_tokens) -@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE": "1"}) -def test_models_distributed_topk() -> None: - example_prompts = [ - "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.", - "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.", - "Compare and contrast artificial intelligence with human intelligence in terms of processing information.", - ] - dtype = "half" - sampling_params = SamplingParams(max_tokens=5, - temperature=0.0, - top_k=50, - top_p=0.9) - - with VllmRunner( - "deepseek-ai/DeepSeek-V2-Lite", - dtype=dtype, - tensor_parallel_size=4, - distributed_executor_backend="mp", - ) as vllm_model: - vllm_model.generate(example_prompts, sampling_params) - - @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1"}) def test_models_distributed_DeepSeek_dbo(): example_prompts = ["The president of the United States is"] * 41 diff --git a/tests/e2e/singlecard/core/ascend_scheduler/test_ascend_scheduler.py b/tests/e2e/singlecard/core/ascend_scheduler/test_ascend_scheduler.py index e1fd16bda9..5669c3fae5 100644 --- a/tests/e2e/singlecard/core/ascend_scheduler/test_ascend_scheduler.py +++ b/tests/e2e/singlecard/core/ascend_scheduler/test_ascend_scheduler.py @@ -16,7 +16,6 @@ from vllm.v1.structured_output import StructuredOutputManager from vllm_ascend.core.scheduler import AscendScheduler -from vllm_ascend.utils import vllm_version_is EOS_TOKEN_ID = 50256 @@ -140,9 +139,7 @@ def create_requests(num_requests: int, multi_modal_placeholders=mm_position, multi_modal_hashes=None, eos_token_id=EOS_TOKEN_ID, - **({ - "pooling_params": None - } if not vllm_version_is("0.9.1") else {}), + pooling_params={}, ) requests.append(request) return requests @@ -201,10 +198,7 @@ def test_schedule(enable_prefix_caching: Optional[bool], # Test initial scheduling output = scheduler.schedule() assert len(output.scheduled_new_reqs) == len(requests) - if vllm_version_is("0.9.1"): - assert len(output.scheduled_cached_reqs) == 0 - else: - assert output.scheduled_cached_reqs.num_reqs == 0 + assert output.scheduled_cached_reqs.num_reqs == 0 assert len(output.finished_req_ids) == 0 # Verify all requests are scheduled. for req_id, num_tokens in output.num_scheduled_tokens.items(): @@ -241,10 +235,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): output = scheduler.schedule() assert len(output.scheduled_new_reqs) == 3 - if vllm_version_is("0.9.1"): - assert len(output.scheduled_cached_reqs) == 0 - else: - assert output.scheduled_cached_reqs.num_reqs == 0 + assert output.scheduled_cached_reqs.num_reqs == 0 assert len(output.finished_req_ids) == 0 # The first request is scheduled partially - 400. @@ -264,9 +255,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, - **({ - "pooler_output": [] - } if not vllm_version_is("0.9.1") else {})) + pooler_output=[]) scheduler.update_from_output(output, model_runner_output) # Schedule the next step. All three requests are running. 
@@ -274,10 +263,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): output1 = scheduler.schedule() assert len(scheduler.running) == 3 assert len(output1.scheduled_new_reqs) == 0 - if vllm_version_is("0.9.1"): - assert len(output1.scheduled_cached_reqs) == 3 - else: - assert output1.scheduled_cached_reqs.num_reqs == 3 + assert output1.scheduled_cached_reqs.num_reqs == 3 assert len(output1.finished_req_ids) == 0 assert output1.num_scheduled_tokens[requests[0].request_id] == 400 assert output1.num_scheduled_tokens[requests[1].request_id] == 400 @@ -293,18 +279,13 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, - **({ - "pooler_output": [] - } if not vllm_version_is("0.9.1") else {})) + pooler_output=[]) scheduler.update_from_output(output1, model_runner_output) output2 = scheduler.schedule() assert len(scheduler.running) == 3 assert len(output2.scheduled_new_reqs) == 0 - if vllm_version_is("0.9.1"): - assert len(output2.scheduled_cached_reqs) == 3 - else: - assert output2.scheduled_cached_reqs.num_reqs == 3 + assert output2.scheduled_cached_reqs.num_reqs == 3 assert len(output2.finished_req_ids) == 0 assert output2.num_scheduled_tokens[requests[0].request_id] == 1 assert output2.num_scheduled_tokens[requests[1].request_id] == 1 @@ -351,9 +332,7 @@ def test_stop_via_update_from_output(): spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, - **({ - "pooler_output": [] - } if not vllm_version_is("0.9.1") else {})) + pooler_output=[]) scheduler.update_from_output(scheduler_output, model_output) @@ -402,9 +381,7 @@ def test_stop_via_update_from_output(): spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, - **({ - "pooler_output": [] - } if not vllm_version_is("0.9.1") else {})) + pooler_output=[]) scheduler.update_from_output(scheduler_output, model_output) @@ -452,9 +429,7 @@ def test_stop_via_update_from_output(): spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, - **({ - "pooler_output": [] - } if not vllm_version_is("0.9.1") else {})) + pooler_output=[]) scheduler.update_from_output(scheduler_output, model_output) @@ -497,9 +472,7 @@ def test_stop_via_update_from_output(): spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, - **({ - "pooler_output": [] - } if not vllm_version_is("0.9.1") else {})) + pooler_output=[]) scheduler.update_from_output(scheduler_output, model_output) @@ -549,9 +522,7 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool], spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, - **({ - "pooler_output": [] - } if not vllm_version_is("0.9.1") else {})) + pooler_output=[]) scheduler.update_from_output(scheduler_output0, model_runner_output) @@ -569,9 +540,7 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool], spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, - **({ - "pooler_output": [] - } if not vllm_version_is("0.9.1") else {})) + pooler_output=[]) scheduler.update_from_output(scheduler_output1, model_runner_output) @@ -622,9 +591,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): spec_token_ids=spec_tokens, logprobs=None, prompt_logprobs_dict={}, - **({ - "pooler_output": [] - } if not vllm_version_is("0.9.1") else {})) + pooler_output=[]) engine_core_outputs = scheduler.update_from_output(output, model_runner_output) @@ -657,16 +624,13 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): 
else: assert req_id not in output.scheduled_spec_decode_tokens - model_runner_output = ModelRunnerOutput( - req_ids=req_ids, - req_id_to_index=req_to_index, - sampled_token_ids=output_tokens, - spec_token_ids=None, - logprobs=None, - prompt_logprobs_dict={}, - **({ - "pooler_output": [] - } if not vllm_version_is("0.9.1") else {})) + model_runner_output = ModelRunnerOutput(req_ids=req_ids, + req_id_to_index=req_to_index, + sampled_token_ids=output_tokens, + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[]) engine_core_outputs = scheduler.update_from_output(output, model_runner_output) @@ -695,9 +659,7 @@ def make_output(scheduler: AscendScheduler): spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, - **({ - "pooler_output": [] - } if not vllm_version_is("0.9.1") else {})) + pooler_output=[]) def assert_scheduler_empty(scheduler: AscendScheduler): diff --git a/tests/e2e/singlecard/sample/test_rejection_sampler.py b/tests/e2e/singlecard/sample/test_rejection_sampler.py index 3b48864cea..1b92aca19c 100644 --- a/tests/e2e/singlecard/sample/test_rejection_sampler.py +++ b/tests/e2e/singlecard/sample/test_rejection_sampler.py @@ -4,12 +4,12 @@ import pytest import torch import torch.nn.functional as F +from vllm.v1.sample.logits_processor import LogitsProcessorManager from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm_ascend.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID, AscendRejectionSampler) -from vllm_ascend.utils import vllm_version_is DEVICE = "npu" @@ -50,46 +50,23 @@ def create_sampling_metadata( temperature = None else: assert temperature is not None - if vllm_version_is("0.9.1"): - return SamplingMetadata( - temperature=temperature, - all_greedy=all_greedy, - all_random=not all_greedy, - top_p=top_p, - top_k=top_k, - min_p=torch.empty(1, ), - generators=generators, - max_num_logprobs=0, - no_penalties=False, - prompt_token_ids=None, - frequency_penalties=torch.tensor([]), - presence_penalties=torch.tensor([]), - repetition_penalties=torch.tensor([]), - output_token_ids=[], - min_tokens={}, - logit_bias=[None], - allowed_token_ids_mask=None, - bad_words_token_ids={}, - ) - else: - from vllm.v1.sample.logits_processor import LogitsProcessorManager - - return SamplingMetadata(temperature=temperature, - all_greedy=all_greedy, - all_random=not all_greedy, - top_p=top_p, - top_k=top_k, - generators=generators, - max_num_logprobs=0, - no_penalties=False, - prompt_token_ids=None, - frequency_penalties=torch.tensor([]), - presence_penalties=torch.tensor([]), - repetition_penalties=torch.tensor([]), - output_token_ids=[], - allowed_token_ids_mask=None, - bad_words_token_ids={}, - logitsprocs=LogitsProcessorManager()) + + return SamplingMetadata(temperature=temperature, + all_greedy=all_greedy, + all_random=not all_greedy, + top_p=top_p, + top_k=top_k, + generators=generators, + max_num_logprobs=0, + no_penalties=False, + prompt_token_ids=None, + frequency_penalties=torch.tensor([]), + presence_penalties=torch.tensor([]), + repetition_penalties=torch.tensor([]), + output_token_ids=[], + allowed_token_ids_mask=None, + bad_words_token_ids={}, + logitsprocs=LogitsProcessorManager()) ########################### Tests for Greedy Sampling ################### diff --git a/tests/e2e/singlecard/test_embedding.py b/tests/e2e/singlecard/test_embedding.py index 0ca07a017e..938f7cc3a6 100644 --- a/tests/e2e/singlecard/test_embedding.py +++ b/tests/e2e/singlecard/test_embedding.py @@ 
-19,12 +19,10 @@ from collections.abc import Sequence from typing import Optional -import pytest from modelscope import snapshot_download # type: ignore[import-untyped] from tests.conftest import HfRunner from tests.utils import check_embeddings_close, matryoshka_fy -from vllm_ascend.utils import vllm_version_is def run_embedding_correctness_test( @@ -51,8 +49,6 @@ def test_dummy(): assert True -@pytest.mark.skipif(vllm_version_is("0.9.1"), - reason="vLLM 0.9.1 does not support embed task for v1") def test_embed_models_correctness(hf_runner, vllm_runner): queries = ['What is the capital of China?', 'Explain gravity'] diff --git a/tests/e2e/singlecard/test_offline_inference.py b/tests/e2e/singlecard/test_offline_inference.py index de69612279..f6f9b04728 100644 --- a/tests/e2e/singlecard/test_offline_inference.py +++ b/tests/e2e/singlecard/test_offline_inference.py @@ -21,12 +21,9 @@ Run `pytest tests/test_offline_inference.py`. """ import os -from unittest.mock import patch import pytest -import vllm # noqa: F401 from modelscope import snapshot_download # type: ignore[import-untyped] -from vllm import SamplingParams from vllm.assets.image import ImageAsset import vllm_ascend # noqa: F401 @@ -106,24 +103,3 @@ def test_multimodal(model, prompt_template, vllm_runner): vllm_model.generate_greedy(prompts=prompts, images=images, max_tokens=64) - - -@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE": "1"}) -def test_models_topk() -> None: - example_prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - sampling_params = SamplingParams(max_tokens=5, - temperature=0.0, - top_k=50, - top_p=0.9) - - with VllmRunner("Qwen/Qwen2.5-0.5B-Instruct", - max_model_len=8192, - dtype="float16", - enforce_eager=True, - gpu_memory_utilization=0.7) as vllm_model: - vllm_model.generate(example_prompts, sampling_params) diff --git a/tests/e2e/singlecard/test_sampler.py b/tests/e2e/singlecard/test_sampler.py deleted file mode 100644 index d9584daeec..0000000000 --- a/tests/e2e/singlecard/test_sampler.py +++ /dev/null @@ -1,152 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# This file is a part of the vllm-ascend project. -# Adapted from vllm/tests/entrypoints/llm/test_guided_generate.py -# Copyright 2023 The vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from typing import Optional - -import pytest -import torch -from vllm.v1.sample.sampler import Sampler # noqa: F401 - -from vllm_ascend.utils import vllm_version_is - -# Set tolerance to 1 for quant ops -DEFAULT_ATOL = 1e-3 -DEFAULT_RTOL = 1e-3 - - -def apply_min_p_new( - logits: torch.Tensor, - min_p: torch.Tensor, -) -> torch.Tensor: - """ - Filters logits using adaptive probability thresholding. 
- """ - if min_p == 0: - return logits - # Convert logits to probability distribution - probability_values = torch.nn.functional.softmax(logits, dim=-1) - # Calculate maximum probabilities per sequence - max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True) - # Reshape min_p for broadcasting - adjusted_min_p = min_p.unsqueeze(1) * max_probabilities - # Identify valid tokens using threshold comparison - # Apply mask using boolean indexing - logits = logits.masked_fill(probability_values < adjusted_min_p, - -float('inf')) - return logits - - -def apply_top_k_top_p( - logits: torch.Tensor, - k: Optional[torch.Tensor], - p: Optional[torch.Tensor], -) -> torch.Tensor: - """Apply top-k and top-p masks to the logits. - - If a top-p is used, this function will sort the logits tensor, - which can be slow for large batches. - - The logits tensor may be updated in-place. - """ - logits_sort, logits_idx = logits.sort(dim=-1, descending=False) - - if k is not None: - # Apply top-k. - top_k_mask = logits_sort.size(1) - k.to(torch.long) # shape: B - # Get all the top_k values. - top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1)) - top_k_mask = logits_sort < top_k_mask - logits_sort.masked_fill_(top_k_mask, -float("inf")) - - if p is not None: - # Apply top-p. - probs_sort = logits_sort.softmax(dim=-1) - probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort) - top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) - # at least one - top_p_mask[:, -1] = False - logits_sort.masked_fill_(top_p_mask, -float("inf")) - - # Re-sort the probabilities. - logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) - return logits - - -def apply_top_k_top_p_new( - logits: torch.Tensor, - k: Optional[torch.Tensor], - p: Optional[torch.Tensor], -) -> torch.Tensor: - batch_size, vocab_size = logits.shape - logits_sort, logits_idx = logits.sort(dim=-1, descending=False) - - # Apply top-k. - boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1)) - top_k_mask = logits_sort < boundary - logits_sort.masked_fill_(top_k_mask, -float("inf")) - - if p is not None: - # Apply top-p. - cutoff = top_k_mask.sum(dim=-1).min() - probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:] - probs_sum = probs_sort.cumsum(dim=-1) - top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1) - top_p_mask[:, -1] = True - strides = torch.arange(0, - batch_size * vocab_size, - vocab_size, - device=logits.device) - flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1) - valid_idx = torch.masked_select(flatten_idx, top_p_mask) - logits_flatten = logits.flatten() - valid_logits = torch.index_select(logits_flatten, 0, valid_idx) - logits = torch.empty_like(logits_flatten).fill_(-float("inf")) - logits[valid_idx] = valid_logits - return logits.reshape(batch_size, vocab_size) - - -# test with leading dimension and merge seqlen and batch_size as num_tokens -@pytest.mark.skipif(not vllm_version_is("0.9.1"), - reason="apply_min_p has been removed after vllm 0.9.1") -@torch.inference_mode() -def test_apply_min_p() -> None: - logits = torch.randn((128, 7168)).npu() - min_p = torch.Tensor([0.01]).npu() - logits_new = apply_min_p_new(logits, min_p) - sampler = Sampler() - logits_old = sampler.apply_min_p(logits, min_p) - # Compare the results. 
- torch.testing.assert_close(logits_new, - logits_old, - atol=DEFAULT_ATOL, - rtol=DEFAULT_RTOL) - - -# test with leading dimension and merge seqlen and batch_size as num_tokens -@torch.inference_mode() -def test_apply_top_k_top_p() -> None: - logits = torch.randn((128, 7168)).npu() - k = torch.Tensor([-1]).int().npu() - p = torch.Tensor([1]).int().npu() - logits_new = apply_top_k_top_p_new(logits, k, p) - logits_old = apply_top_k_top_p(logits, k, p) - # Compare the results. - torch.testing.assert_close(logits_new, - logits_old, - atol=DEFAULT_ATOL, - rtol=DEFAULT_RTOL) diff --git a/tests/e2e/singlecard/test_scheduler.py b/tests/e2e/singlecard/test_scheduler.py index fba344afb4..c4c9b7f554 100644 --- a/tests/e2e/singlecard/test_scheduler.py +++ b/tests/e2e/singlecard/test_scheduler.py @@ -31,7 +31,6 @@ from vllm.v1.structured_output import StructuredOutputManager from vllm_ascend.core.scheduler import AscendScheduler -from vllm_ascend.utils import vllm_version_is EOS_TOKEN_ID = 50256 @@ -131,9 +130,7 @@ def create_requests(num_requests: int, multi_modal_placeholders=mm_position, multi_modal_hashes=None, eos_token_id=EOS_TOKEN_ID, - **({ - "pooling_params": None - } if not vllm_version_is("0.9.1") else {}), + pooling_params=None, ) requests.append(request) return requests @@ -192,10 +189,7 @@ def test_schedule(enable_prefix_caching: Optional[bool], # Test initial scheduling output = scheduler.schedule() assert len(output.scheduled_new_reqs) == len(requests) - if vllm_version_is("0.9.1"): - assert len(output.scheduled_cached_reqs) == 0 - else: - assert output.scheduled_cached_reqs.num_reqs == 0 + assert output.scheduled_cached_reqs.num_reqs == 0 assert len(output.finished_req_ids) == 0 # Verify all requests are scheduled. for req_id, num_tokens in output.num_scheduled_tokens.items(): @@ -245,9 +239,7 @@ def test_stop_via_update_from_output(): spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, - **({ - "pooler_output": [] - } if not vllm_version_is("0.9.1") else {})) + pooler_output=[]) scheduler.update_from_output(scheduler_output, model_output) @@ -294,9 +286,7 @@ def test_stop_via_update_from_output(): spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, - **({ - "pooler_output": [] - } if not vllm_version_is("0.9.1") else {})) + pooler_output=[]) scheduler.update_from_output(scheduler_output, model_output) @@ -342,9 +332,7 @@ def test_stop_via_update_from_output(): spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, - **({ - "pooler_output": [] - } if not vllm_version_is("0.9.1") else {})) + pooler_output=[]) scheduler.update_from_output(scheduler_output, model_output) @@ -386,9 +374,7 @@ def test_stop_via_update_from_output(): spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, - **({ - "pooler_output": [] - } if not vllm_version_is("0.9.1") else {})) + pooler_output=[]) scheduler.update_from_output(scheduler_output, model_output) diff --git a/tests/ut/patch/worker/patch_common/test_patch_sampler.py b/tests/ut/patch/worker/patch_common/test_patch_sampler.py deleted file mode 100644 index b87175d662..0000000000 --- a/tests/ut/patch/worker/patch_common/test_patch_sampler.py +++ /dev/null @@ -1,31 +0,0 @@ -import importlib -import os -from unittest import mock - -import torch -from vllm.v1.sample.ops import topk_topp_sampler - -from tests.ut.base import TestBase - - -class TestTopKTopPSamplerOptimize(TestBase): - - @mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE": "1"}) - @mock.patch("torch_npu.npu_top_k_top_p") - def 
test_npu_topk_topp_called_when_optimized(self, mock_npu_op): - # We have to patch and reload because the patch will take effect - # only after VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE is set. - import vllm_ascend.patch.worker.patch_0_9_1.patch_sampler - importlib.reload(vllm_ascend.patch.worker.patch_0_9_1.patch_sampler) - - mock_npu_op.return_value = (torch.randn(1, 3)) - sampler = topk_topp_sampler.TopKTopPSampler() - - logits = torch.tensor([[1.0, 2.0, 3.0]]) - k = torch.tensor([2]) - p = torch.tensor([0.9]) - generators = {0: torch.Generator()} - generators[0].manual_seed(42) - - sampler.forward_native(logits, generators, k, p) - mock_npu_op.assert_called_once_with(logits, p, k) diff --git a/vllm_ascend/core/scheduler.py b/vllm_ascend/core/scheduler.py index 76021a2173..00a17ddfc3 100644 --- a/vllm_ascend/core/scheduler.py +++ b/vllm_ascend/core/scheduler.py @@ -32,8 +32,6 @@ from vllm.v1.request import Request, RequestStatus from vllm.v1.structured_output import StructuredOutputManager -from vllm_ascend.utils import vllm_version_is - class AscendScheduler(Scheduler): """This Scheduler extends vllm's original v1 scheduler @@ -366,32 +364,12 @@ def skip_cur_request(): req_to_new_block_ids[req.request_id]) for req in scheduled_new_reqs ] - if vllm_version_is("0.9.1"): - resumed_reqs_data = [ - self._make_cached_request_data( - req, - num_scheduled_tokens[req.request_id], - len(scheduled_spec_decode_tokens.get(req.request_id, ())), - req_to_new_block_ids[req.request_id], - resumed_from_preemption=True, - ) for req in scheduled_resumed_reqs - ] - running_reqs_data = [ - self._make_cached_request_data( - req, - num_scheduled_tokens[req.request_id], - len(scheduled_spec_decode_tokens.get(req.request_id, ())), - req_to_new_block_ids[req.request_id], - resumed_from_preemption=False, - ) for req in scheduled_running_reqs - ] - scheduled_cached_reqs = resumed_reqs_data + running_reqs_data - else: - cached_reqs_data = self._make_cached_request_data( - scheduled_running_reqs, scheduled_resumed_reqs, - num_scheduled_tokens, scheduled_spec_decode_tokens, - req_to_new_block_ids) - scheduled_cached_reqs = cached_reqs_data + + cached_reqs_data = self._make_cached_request_data( + scheduled_running_reqs, scheduled_resumed_reqs, + num_scheduled_tokens, scheduled_spec_decode_tokens, + req_to_new_block_ids) + scheduled_cached_reqs = cached_reqs_data scheduler_output = SchedulerOutput( scheduled_new_reqs=new_reqs_data, diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 7bded5747a..ea43868606 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -50,10 +50,6 @@ # value is None, which means the system default C compiler will be used. "C_COMPILER": lambda: os.getenv("C_COMPILER", None), - # Whether to enable the topk optimization. It's disabled by default for experimental support - # We'll make it enabled by default in the future. - "VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE": - lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE", '0'))), # The version of the Ascend chip. If not set, the default value is # ASCEND910B1. It's used for package building. Please make sure that the # version is correct. 
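For a quick sanity check of the new required field, the following self-contained construction mirrors the ModelRunnerOutput calls used by the scheduler tests earlier in this patch; the concrete values are placeholders, and the import path is an assumption, since the tests' import lines sit outside the hunks shown.

    # Minimal sketch: on vLLM 0.9.2 ModelRunnerOutput takes a pooler_output
    # argument, which the updated tests above pass as an empty list.
    from vllm.v1.outputs import ModelRunnerOutput  # assumed import path

    model_runner_output = ModelRunnerOutput(
        req_ids=["req-0"],                # placeholder request id
        req_id_to_index={"req-0": 0},
        sampled_token_ids=[[123]],        # one sampled token for the request
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],                 # empty when no pooling request is scheduled
    )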
diff --git a/vllm_ascend/models/deepseek_dbo.py b/vllm_ascend/models/deepseek_dbo.py index b49b4e4dce..679bbc2c0e 100644 --- a/vllm_ascend/models/deepseek_dbo.py +++ b/vllm_ascend/models/deepseek_dbo.py @@ -78,7 +78,7 @@ make_multistream_metadata_ds) from vllm_ascend.multistream.ms_split import compute_split_seq_index from vllm_ascend.ops.fused_moe import AscendFusedMoE -from vllm_ascend.utils import dispose_tensor, vllm_version_is +from vllm_ascend.utils import dispose_tensor VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO @@ -1032,19 +1032,12 @@ def load_weights(self, weights: Iterable[tuple[str, param = params_dict[name] weight_loader = param.weight_loader - if vllm_version_is("0.9.1"): - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) - else: - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id, - return_success=False) + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + return_success=False) break else: # Skip loading extra bias for GPTQ models. diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index d7f68a12c7..1ff580f524 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -75,7 +75,7 @@ from vllm_ascend.quantization.quant_config import AscendLinearMethod from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod from vllm_ascend.utils import (dispose_tensor, npu_stream_switch, - npu_wait_tensor, vllm_version_is) + npu_wait_tensor) class CustomDeepseekV2SiluAndMul(SiluAndMul): @@ -936,19 +936,12 @@ def load_weights(self, weights: Iterable[tuple[str, param = params_dict[name] weight_loader = param.weight_loader - if vllm_version_is("0.9.1"): - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) - else: - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id, - return_success=False) + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + return_success=False) break else: # Skip loading extra bias for GPTQ models. 
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index aa189428a3..0197cb30bf 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -28,6 +28,10 @@ tensor_model_parallel_all_reduce) from vllm.distributed.parallel_state import get_dp_group, get_tp_group from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.fused_moe.config import \ + FusedMoEConfig # isort: skip +from vllm.model_executor.layers.fused_moe.config import \ + FusedMoEParallelConfig # isort: skip from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map) from vllm.model_executor.layers.quantization.base_config import \ @@ -39,16 +43,7 @@ from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer from vllm_ascend.utils import (FusedMoEState, dispose_tensor, get_fused_moe_state, is_310p, npu_stream_switch, - npu_wait_tensor, vllm_version_is) - -if vllm_version_is("0.9.1"): - from vllm.model_executor.layers.fused_moe.layer import \ - FusedMoEParallelConfig - from vllm.model_executor.layers.fused_moe.layer import \ - MoEConfig as FusedMoEConfig -else: - from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEConfig, FusedMoEParallelConfig) + npu_wait_tensor) MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER @@ -1177,27 +1172,15 @@ def __init__( if self.scoring_func != "softmax" and not self.use_grouped_topk: raise ValueError("Only softmax scoring function is supported for " "non-grouped topk.") - - if vllm_version_is("0.9.1"): - moe = FusedMoEConfig( - num_experts=self.global_num_experts, - experts_per_token=top_k, - hidden_dim=hidden_size, - num_local_experts=self.local_num_experts, - moe_parallel_config=self.moe_parallel_config, - # TODO (bnell): this needs to be fixed for quantized types. - in_dtype=params_dtype, - ) - else: - moe = FusedMoEConfig.make( - num_experts=self.global_num_experts, - experts_per_token=top_k, - hidden_dim=hidden_size, - num_local_experts=self.local_num_experts, - moe_parallel_config=self.moe_parallel_config, - # TODO (bnell): this needs to be fixed for quantized types. - in_dtype=params_dtype, - quant_config=quant_config) + moe = FusedMoEConfig.make( + num_experts=self.global_num_experts, + experts_per_token=top_k, + hidden_dim=hidden_size, + num_local_experts=self.local_num_experts, + moe_parallel_config=self.moe_parallel_config, + # TODO (bnell): this needs to be fixed for quantized types. + in_dtype=params_dtype, + quant_config=quant_config) if quant_config is None: self.quant_method = AscendUnquantizedFusedMoEMethod(moe) diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index 63b56fd7c0..b054fc6cb6 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -24,9 +24,9 @@ # each worker's `__init__` function. # # Then in each kind of patch, there are three folders: -# - patch_0_9_1: contains the patches applied when vllm version is 0.9.1. +# - patch_0_9_2: contains the patches applied when vllm version is 0.9.2. # - patch_main: contains the patches applied when vllm version is main branch. -# - patch_common: contains the patches applied in both 0.9.1 and main branch. +# - patch_common: contains the patches applied in both 0.9.2 and main branch. # # Once a new patch is added in vllm-ascend, please add the patch description into this file as well. 
# ---------------------------------------------------------------------------------- @@ -105,32 +105,6 @@ # Future Plan: # Revert it when the related pr is merged in vllm and vllm-ascend. # -# ** File: worker/patch_common/patch_sampler.py ** -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# 1. `vllm.v1.sample.sampler.Sampler.apply_top_k_top_p` -# Why: -# We need to use the patched `apply_top_k_top_p` in `sample`. -# The mainly reason to overwrite `apply_top_k_top_p` is -# to improve performance. -# How: -# Re-implementation the `apply_top_k_top_p` function by pytorch -# Related PR (if no, explain why): -# - https://github.com/vllm-project/vllm-ascend/pull/970 -# Future Plan: -# Revert it when the ascend scatter performance improves. -# -# 2. `vllm.v1.sample.sampler.Sampler.apply_min_p` -# Why: -# We need to use the patched `apply_min_p` in `sample`. -# The mainly reason to overwrite `apply_min_p` is -# to improve performance. -# How: -# Re-implementation the `apply_min_p` function by pytorch -# Related PR (if no, explain why): -# - https://github.com/vllm-project/vllm-ascend/pull/970 -# Future Plan: -# Revert it when the ascend indexput performance improves. -# # ** File: worker/patch_common/patch_distributed.py ** # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 1. `vllm.distributed.parallel_state.GroupCoordinator` @@ -154,4 +128,4 @@ # Related PR (if no, explain why): # This is the problem in vllm-ascend # Future Plan: -# Remove this patch once pytorch 2.7.0 is supported for vllm ascend. \ No newline at end of file +# Remove this patch once pytorch 2.7.0 is supported for vllm ascend. diff --git a/vllm_ascend/patch/platform/__init__.py b/vllm_ascend/patch/platform/__init__.py index 4ec38e38d0..4e0fed578b 100644 --- a/vllm_ascend/patch/platform/__init__.py +++ b/vllm_ascend/patch/platform/__init__.py @@ -17,8 +17,8 @@ from vllm_ascend.utils import vllm_version_is # Import specific patches for different versions -if vllm_version_is("0.9.1"): - from vllm_ascend.patch.platform import patch_0_9_1 # noqa: F401 +if vllm_version_is("0.9.2"): + from vllm_ascend.patch.platform import patch_0_9_2 # noqa: F401 from vllm_ascend.patch.platform import patch_common # noqa: F401 else: from vllm_ascend.patch.platform import patch_common # noqa: F401 diff --git a/vllm_ascend/patch/platform/patch_0_9_1/__init__.py b/vllm_ascend/patch/platform/patch_0_9_2/__init__.py similarity index 100% rename from vllm_ascend/patch/platform/patch_0_9_1/__init__.py rename to vllm_ascend/patch/platform/patch_0_9_2/__init__.py diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py index 3b29856d22..de7219ad2e 100644 --- a/vllm_ascend/patch/worker/__init__.py +++ b/vllm_ascend/patch/worker/__init__.py @@ -18,8 +18,8 @@ from vllm_ascend.utils import vllm_version_is # Import specific patches for different versions -if vllm_version_is("0.9.1"): - from vllm_ascend.patch.worker import patch_0_9_1 # noqa: F401 +if vllm_version_is("0.9.2"): + from vllm_ascend.patch.worker import patch_0_9_2 # noqa: F401 from vllm_ascend.patch.worker import patch_common # noqa: F401 else: from vllm_ascend.patch.worker import patch_common # noqa: F401 diff --git a/vllm_ascend/patch/worker/patch_0_9_1/patch_sampler.py b/vllm_ascend/patch/worker/patch_0_9_1/patch_sampler.py deleted file mode 100644 index 69fcd691b2..0000000000 --- a/vllm_ascend/patch/worker/patch_0_9_1/patch_sampler.py +++ /dev/null @@ -1,106 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. 
-# SPDX-License-Identifier: Apache-2.0 -# This file is a part of the vllm-ascend project. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from typing import Optional - -import torch -import torch_npu -from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample -from vllm.v1.sample.sampler import Sampler - -from vllm_ascend import envs - - -def apply_min_p( - self, - logits: torch.Tensor, - min_p: torch.Tensor, -) -> torch.Tensor: - """ - Filters logits using adaptive probability thresholding. - """ - # Convert logits to probability distribution - probability_values = torch.nn.functional.softmax(logits, dim=-1) - # Calculate maximum probabilities per sequence - max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True) - # Reshape min_p for broadcasting - adjusted_min_p = min_p.unsqueeze(1) * max_probabilities - # Identify valid tokens using threshold comparison - # Apply mask using boolean indexing - logits = logits.masked_fill(probability_values < adjusted_min_p, - -float('inf')) - return logits - - -def _apply_top_k_top_p( - logits: torch.Tensor, - k: torch.Tensor, - p: torch.Tensor, -) -> torch.Tensor: - if p is not None and k is not None: - # npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p) - return torch_npu.npu_top_k_top_p(logits, p, k) - - probs = logits.softmax(dim=-1) - probs_sort, _ = probs.sort(dim=-1, descending=False) - - if k is not None: - top_k_count = probs_sort.size(1) - k.to(torch.long) # shape: (batch, ) - top_k_count = top_k_count.unsqueeze(dim=1) - top_k_cutoff = probs_sort.gather(-1, top_k_count) - - # Make sure the no top-k rows are no-op. - no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1) - top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf")) - - elements_to_discard = probs < top_k_cutoff - logits.masked_fill_(elements_to_discard, -float("inf")) - - if p is not None: - cumprob = torch.cumsum(probs_sort, dim=-1) - top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1) - top_p_mask[:, -1] = False # at least one - - top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1) - top_p_cutoff = probs_sort.gather(-1, top_p_count) - elements_to_discard = probs < top_p_cutoff - logits.masked_fill_(elements_to_discard, -float("inf")) - - return logits - - -def topk_topp_forward_native( - self, - logits: torch.Tensor, - generators: dict[int, torch.Generator], - k: Optional[torch.Tensor], - p: Optional[torch.Tensor], -) -> torch.Tensor: - """ - PyTorch-native implementation of top-k and top-p sampling. - - The logits tensor may be updated in-place. 
- """ - logits = _apply_top_k_top_p(logits, k, p) - probs = logits.softmax(dim=-1, dtype=torch.float32) - return random_sample(probs, generators) - - -Sampler.apply_min_p = apply_min_p -if envs.VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE: - TopKTopPSampler.forward_native = topk_topp_forward_native diff --git a/vllm_ascend/patch/worker/patch_0_9_1/__init__.py b/vllm_ascend/patch/worker/patch_0_9_2/__init__.py similarity index 91% rename from vllm_ascend/patch/worker/patch_0_9_1/__init__.py rename to vllm_ascend/patch/worker/patch_0_9_2/__init__.py index 6b08ae9863..116c73c06c 100644 --- a/vllm_ascend/patch/worker/patch_0_9_1/__init__.py +++ b/vllm_ascend/patch/worker/patch_0_9_2/__init__.py @@ -14,4 +14,3 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import vllm_ascend.patch.worker.patch_0_9_1.patch_sampler # noqa diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index e5e0ce4c63..4f7b58d17a 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -44,6 +44,7 @@ from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import get_model +from vllm.model_executor.models.interfaces import has_step_pooler from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.utils import group_mm_inputs_by_modality @@ -81,7 +82,7 @@ ProfileExecuteDuration, check_torchair_cache_exist, is_310p, maybe_converting_weight_acl_format, - vllm_version_is, write_kv_cache_bytes_to_file) + write_kv_cache_bytes_to_file) from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch @@ -97,9 +98,6 @@ import vllm_ascend.envs as envs_ascend -if vllm_version_is("0.9.1"): - from vllm.v1.spec_decode.utils import is_spec_decode_supported - if is_310p(): torch_npu.npu.set_compile_mode(jit_compile=False) @@ -398,16 +396,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: else: generator = None - # For vllm v0.9.1 version compatibility, we check if - # `pooling_params` is present in the new request data. - pooling_params = getattr(new_req_data, "pooling_params", None) self.requests[req_id] = CachedRequestState( req_id=req_id, prompt_token_ids=new_req_data.prompt_token_ids, mm_inputs=new_req_data.mm_inputs, mm_positions=new_req_data.mm_positions, sampling_params=sampling_params, - pooling_params=pooling_params, + pooling_params=new_req_data.pooling_params, generator=generator, block_ids=new_req_data.block_ids, num_computed_tokens=new_req_data.num_computed_tokens, @@ -455,62 +450,59 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: req_ids_to_add.append(req_id) # Update the states of the running/resumed requests. - if vllm_version_is("0.9.1"): - for req_data in scheduler_output.scheduled_cached_reqs: - req_id = req_data.req_id - req_state = self.requests[req_id] + req_data = scheduler_output.scheduled_cached_reqs + is_last_rank = get_pp_group().is_last_rank + for i, req_id in enumerate(req_data.req_ids): + req_state = self.requests[req_id] + num_computed_tokens = req_data.num_computed_tokens[i] + new_block_ids = req_data.new_block_ids[i] + resumed_from_preemption = req_data.resumed_from_preemption[i] - # Update the cached states. 
- num_computed_tokens = req_data.num_computed_tokens - req_state.num_computed_tokens = num_computed_tokens + req_state.num_computed_tokens = num_computed_tokens + if not is_last_rank: + new_token_ids = req_data.new_token_ids[i] # Add the sampled token(s) from the previous step (if any). # This doesn't include "unverified" tokens like spec decode tokens. - num_new_tokens = (num_computed_tokens + - len(req_data.new_token_ids) - + num_new_tokens = (num_computed_tokens + len(new_token_ids) - req_state.num_tokens) if num_new_tokens == 1: # Avoid slicing list in most common case. - req_state.output_token_ids.append( - req_data.new_token_ids[-1]) + req_state.output_token_ids.append(new_token_ids[-1]) elif num_new_tokens > 0: req_state.output_token_ids.extend( - req_data.new_token_ids[-num_new_tokens:]) - # Update the block IDs. - if not req_data.resumed_from_preemption: - # Append the new blocks to the existing block IDs. - for block_ids, new_block_ids in zip( # type: ignore[call-overload] - req_state.block_ids, - req_data.new_block_ids, - strict=True): - block_ids.extend(new_block_ids) - else: - # The request is resumed from preemption. - # Replace the existing block IDs with the new ones. - req_state.block_ids = req_data.new_block_ids - - req_index = self.input_batch.req_id_to_index.get(req_id) - if req_index is None: - # The request is not in the persistent batch. - # The request was either preempted and resumed later, or was not - # scheduled in the previous step and needs to be added again. - req_ids_to_add.append(req_id) - continue + new_token_ids[-num_new_tokens:]) + # Update the block IDs. + if not resumed_from_preemption: + # Append the new blocks to the existing block IDs. + for block_ids, new_ids in zip( # type: ignore[call-overload] + req_state.block_ids, new_block_ids): + block_ids.extend(new_ids) + else: + # The request is resumed from preemption. + # Replace the existing block IDs with the new ones. + req_state.block_ids = new_block_ids + + req_index = self.input_batch.req_id_to_index.get(req_id) + if req_index is None: + # The request is not in the persistent batch. + # The request was either preempted and resumed later, or was not + # scheduled in the previous step and needs to be added again. + req_ids_to_add.append(req_id) + continue - # Update the persistent batch. - self.input_batch.num_computed_tokens_cpu[req_index] = ( - num_computed_tokens) + # Update the persistent batch. + self.input_batch.num_computed_tokens_cpu[req_index] = ( + num_computed_tokens) - start_index = (len(req_state.block_ids) - - len(req_data.new_block_ids)) - self.input_batch.block_table.append_row( - req_data.new_block_ids, req_index) + self.input_batch.block_table.append_row(new_block_ids, req_index) + + if not is_last_rank: # Add new_token_ids to token_ids_cpu. start_token_index = num_computed_tokens - end_token_index = num_computed_tokens + len( - req_data.new_token_ids) + end_token_index = num_computed_tokens + len(new_token_ids) self.input_batch.token_ids_cpu[ req_index, - start_token_index:end_token_index] = req_data.new_token_ids + start_token_index:end_token_index] = new_token_ids self.input_batch.num_tokens_no_spec[ req_index] = end_token_index # Add spec_token_ids to token_ids_cpu. @@ -524,75 +516,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: start_index:end_token_index] = spec_token_ids # NOTE(woosuk): `num_tokens` here may include spec decode tokens. 
self.input_batch.num_tokens[req_index] = end_token_index - else: - req_data = scheduler_output.scheduled_cached_reqs - is_last_rank = get_pp_group().is_last_rank - for i, req_id in enumerate(req_data.req_ids): - req_state = self.requests[req_id] - num_computed_tokens = req_data.num_computed_tokens[i] - new_block_ids = req_data.new_block_ids[i] - resumed_from_preemption = req_data.resumed_from_preemption[i] - - req_state.num_computed_tokens = num_computed_tokens - if not is_last_rank: - new_token_ids = req_data.new_token_ids[i] - # Add the sampled token(s) from the previous step (if any). - # This doesn't include "unverified" tokens like spec decode tokens. - num_new_tokens = (num_computed_tokens + - len(new_token_ids) - - req_state.num_tokens) - if num_new_tokens == 1: - # Avoid slicing list in most common case. - req_state.output_token_ids.append(new_token_ids[-1]) - elif num_new_tokens > 0: - req_state.output_token_ids.extend( - new_token_ids[-num_new_tokens:]) - # Update the block IDs. - if not resumed_from_preemption: - # Append the new blocks to the existing block IDs. - for block_ids, new_ids in zip( # type: ignore[call-overload] - req_state.block_ids, new_block_ids): - block_ids.extend(new_ids) - else: - # The request is resumed from preemption. - # Replace the existing block IDs with the new ones. - req_state.block_ids = new_block_ids - - req_index = self.input_batch.req_id_to_index.get(req_id) - if req_index is None: - # The request is not in the persistent batch. - # The request was either preempted and resumed later, or was not - # scheduled in the previous step and needs to be added again. - req_ids_to_add.append(req_id) - continue - - # Update the persistent batch. - self.input_batch.num_computed_tokens_cpu[req_index] = ( - num_computed_tokens) - - self.input_batch.block_table.append_row( - new_block_ids, req_index) - - if not is_last_rank: - # Add new_token_ids to token_ids_cpu. - start_token_index = num_computed_tokens - end_token_index = num_computed_tokens + len(new_token_ids) - self.input_batch.token_ids_cpu[ - req_index, - start_token_index:end_token_index] = new_token_ids - self.input_batch.num_tokens_no_spec[ - req_index] = end_token_index - # Add spec_token_ids to token_ids_cpu. - spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( - req_id, ()) - if spec_token_ids: - start_index = end_token_index - end_token_index += len(spec_token_ids) - self.input_batch.token_ids_cpu[ - req_index, - start_index:end_token_index] = spec_token_ids - # NOTE(woosuk): `num_tokens` here may include spec decode tokens. - self.input_batch.num_tokens[req_index] = end_token_index # Check if the batch has changed. If not, we can skip copying the # sampling metadata from CPU to GPU. 
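The rewritten loop above relies on the vLLM 0.9.2 scheduler output, where scheduled_cached_reqs is one batched object holding parallel per-request lists instead of a list of per-request entries. The sketch below illustrates that layout with a hypothetical stand-in class; the field names come from the hunk above, but the dataclass itself and its sample values are illustrative only, not vLLM's actual CachedRequestData.

    from dataclasses import dataclass, field

    @dataclass
    class BatchedCachedReqs:
        # Hypothetical stand-in mirroring the parallel-list layout indexed by i above.
        req_ids: list = field(default_factory=list)
        num_computed_tokens: list = field(default_factory=list)
        new_block_ids: list = field(default_factory=list)
        resumed_from_preemption: list = field(default_factory=list)
        new_token_ids: list = field(default_factory=list)

    reqs = BatchedCachedReqs(
        req_ids=["req-0", "req-1"],
        num_computed_tokens=[16, 8],
        new_block_ids=[[[3]], [[7, 8]]],       # one block-id list per KV-cache group
        resumed_from_preemption=[False, True],
        new_token_ids=[[101], [102, 103]],
    )
    for i, req_id in enumerate(reqs.req_ids):
        # Every per-request attribute is read at the same index, as in the loop above.
        print(req_id, reqs.num_computed_tokens[i], reqs.resumed_from_preemption[i])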
@@ -827,25 +750,13 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): # compute completion's mrope_positions on-the-fly dst_start = mrope_pos_ptr dst_end = mrope_pos_ptr + completion_part_len - - if vllm_version_is("0.9.1"): - self.mrope_positions_cpu[:, dst_start:dst_end] = \ - MRotaryEmbedding.get_next_input_positions_tensor( - req.mrope_position_delta, - context_len=num_computed_tokens + - prompt_part_len, - seq_len=num_computed_tokens + - prompt_part_len + - completion_part_len, - ) - else: - MRotaryEmbedding.get_next_input_positions_tensor( - out=self.mrope_positions_np, - out_offset=dst_start, - mrope_position_delta=req.mrope_position_delta, - context_len=num_computed_tokens + prompt_part_len, - num_new_tokens=completion_part_len, - ) + MRotaryEmbedding.get_next_input_positions_tensor( + out=self.mrope_positions_np, + out_offset=dst_start, + mrope_position_delta=req.mrope_position_delta, + context_len=num_computed_tokens + prompt_part_len, + num_new_tokens=completion_part_len, + ) mrope_pos_ptr += completion_part_len @@ -1577,30 +1488,29 @@ def execute_model( for i in discard_sampled_tokens_req_indices: valid_sampled_token_ids[i].clear() - if not vllm_version_is("0.9.1"): - # Cache the sampled tokens in the model runner, so that the scheduler - # doesn't need to send them back. - # NOTE(woosuk): As an exception, when using PP, the scheduler sends - # the sampled tokens back, because there's no direct communication - # between the first-stage worker and the last-stage worker. - for req_idx, sampled_ids in enumerate(valid_sampled_token_ids): - if not sampled_ids: - continue - - start_idx = self.input_batch.num_tokens_no_spec[req_idx] - end_idx = start_idx + len(sampled_ids) - assert end_idx <= self.model_config.max_model_len, ( - "Sampled token IDs exceed the max model length. " - f"Total number of tokens: {end_idx} > max_model_len: " - f"{self.model_config.max_model_len}") + # Cache the sampled tokens in the model runner, so that the scheduler + # doesn't need to send them back. + # NOTE(woosuk): As an exception, when using PP, the scheduler sends + # the sampled tokens back, because there's no direct communication + # between the first-stage worker and the last-stage worker. + for req_idx, sampled_ids in enumerate(valid_sampled_token_ids): + if not sampled_ids: + continue  start_idx = self.input_batch.num_tokens_no_spec[req_idx] + end_idx = start_idx + len(sampled_ids) + assert end_idx <= self.model_config.max_model_len, ( + "Sampled token IDs exceed the max model length. 
" + f"Total number of tokens: {end_idx} > max_model_len: " + f"{self.model_config.max_model_len}") + + self.input_batch.token_ids_cpu[req_idx, + start_idx:end_idx] = sampled_ids + self.input_batch.num_tokens_no_spec[req_idx] = end_idx + self.input_batch.num_tokens[req_idx] = end_idx + req_id = self.input_batch.req_ids[req_idx] + req_state = self.requests[req_id] + req_state.output_token_ids.extend(sampled_ids) spec_token_ids = self._get_spec_token_ids( valid_sampled_token_ids, @@ -1613,25 +1523,16 @@ def execute_model( attn_metadata, aux_hidden_states, ) - if vllm_version_is("0.9.1"): - model_runner_output = ModelRunnerOutput( - req_ids=self.input_batch.req_ids, - req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids=valid_sampled_token_ids, - spec_token_ids=spec_token_ids, - logprobs=logprobs_lists, - prompt_logprobs_dict=prompt_logprobs_dict, - ) - else: - model_runner_output = ModelRunnerOutput( - req_ids=self.input_batch.req_ids, - req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids=valid_sampled_token_ids, - spec_token_ids=spec_token_ids, - logprobs=logprobs_lists, - prompt_logprobs_dict=prompt_logprobs_dict, - pooler_output=[], - ) + + model_runner_output = ModelRunnerOutput( + req_ids=self.input_batch.req_ids, + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids=valid_sampled_token_ids, + spec_token_ids=spec_token_ids, + logprobs=logprobs_lists, + prompt_logprobs_dict=prompt_logprobs_dict, + pooler_output=[], + ) durations = ProfileExecuteDuration().pop_captured_sync() if durations: @@ -1827,15 +1728,8 @@ def load_model(self) -> None: QKVParallelLinear, RowParallelLinear)): module.weight.data = torch_npu.npu_format_cast( module.weight.data, ACL_FORMAT_FRACTAL_NZ) - - try: - # For version compatibility, remove this after we abort vllm v0.9.1 support - from vllm.model_executor.models.interfaces import \ - has_step_pooler # type: ignore - if has_step_pooler(self.model): - self.input_batch.logits_processing_needs_token_ids = True - except ImportError: - pass + if has_step_pooler(self.model): + self.input_batch.logits_processing_needs_token_ids = True if self.drafter: logger.info("Loading drafter model...") if isinstance(self.drafter, EagleProposer): @@ -2164,14 +2058,9 @@ def _generate_ngram_token_ids( # Skip requests that require top-p, top-k, etc. req_id = self.input_batch.req_ids[i] - if vllm_version_is("0.9.1"): - if not is_spec_decode_supported(req_id, self.input_batch): - draft_token_ids.append([]) - continue - else: - if req_id in self.input_batch.spec_decode_unsupported_reqs: - draft_token_ids.append([]) - continue + if req_id in self.input_batch.spec_decode_unsupported_reqs: + draft_token_ids.append([]) + continue # Add sampled_token_ids to token_ids_cpu. 
start_idx = self.input_batch.num_tokens_no_spec[i] diff --git a/vllm_ascend/worker/npu_input_batch.py b/vllm_ascend/worker/npu_input_batch.py index 5e7b2c0e6f..cb5b264d7e 100644 --- a/vllm_ascend/worker/npu_input_batch.py +++ b/vllm_ascend/worker/npu_input_batch.py @@ -28,15 +28,13 @@ from vllm.sampling_params import SamplingParams, SamplingType from vllm.utils import swap_dict_values from vllm.v1.outputs import LogprobsTensors +from vllm.v1.sample.logits_processor import init_builtin_logitsprocs from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.spec_decode.utils import is_spec_decode_unsupported from vllm.v1.utils import copy_slice from vllm.v1.worker.block_table import MultiGroupBlockTable from vllm_ascend.pool.metadata import PoolingMetadata -from vllm_ascend.utils import vllm_version_is - -if not vllm_version_is("0.9.1"): - from vllm.v1.spec_decode.utils import is_spec_decode_unsupported _SAMPLING_EPS = 1e-5 @@ -253,17 +251,13 @@ def __init__( self.req_output_token_ids: list[Optional[list[int]]] = [] - if not vllm_version_is("0.9.1"): - from vllm.v1.sample.logits_processor import \ - init_builtin_logitsprocs - - # Define logits processors. - # TODO(andy): logits processor list should be extensible via engine - # constructor argument; for now the list is fixed. - self.logitsprocs = init_builtin_logitsprocs( - pin_memory_available=pin_memory, - max_num_reqs=max_num_reqs + 1, - device=device) + # Define logits processors. + # TODO(andy): logits processor list should be extensible via engine + # constructor argument; for now the list is fixed. + self.logitsprocs = init_builtin_logitsprocs( + pin_memory_available=pin_memory, + max_num_reqs=max_num_reqs + 1, + device=device) # This is updated each time the batch constituents change. self.sampling_metadata = self._make_sampling_metadata() @@ -314,8 +308,8 @@ def add_request( self.block_table.add_row(request.block_ids, req_index) if sampling_params := request.sampling_params: - if ((not vllm_version_is("0.9.1")) and self.is_spec_decode - and is_spec_decode_unsupported(sampling_params)): + if self.is_spec_decode and is_spec_decode_unsupported( + sampling_params): self.spec_decode_unsupported_reqs.add(req_id) if sampling_params.sampling_type == SamplingType.GREEDY: # Avoid later division by zero. 
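The input batch now builds its logits processors unconditionally. Below is a small runnable sketch of that call, using only the keyword arguments visible in the hunk above; the CPU device and the request count are arbitrary placeholder values.

    import torch
    from vllm.v1.sample.logits_processor import init_builtin_logitsprocs

    # Same call shape as in InputBatch.__init__ above, with placeholder values.
    logitsprocs = init_builtin_logitsprocs(
        pin_memory_available=False,
        max_num_reqs=8 + 1,              # one spare slot, matching the "+ 1" above
        device=torch.device("cpu"),
    )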
@@ -641,48 +635,24 @@ def _make_sampling_metadata(self) -> SamplingMetadata: self.allowed_token_ids_mask, num_reqs) allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs] - if vllm_version_is("0.9.1"): - return SamplingMetadata( - temperature=temperature, - all_greedy=self.all_greedy, - all_random=self.all_random, - top_p=None if self.no_top_p else self.top_p[:num_reqs], - top_k=None if self.no_top_k else self.top_k[:num_reqs], - min_p=None if self.no_min_p else self.min_p[:num_reqs], - generators=self.generators, - max_num_logprobs=self.max_num_logprobs, - prompt_token_ids=prompt_token_ids, - frequency_penalties=self.frequency_penalties[:num_reqs], - presence_penalties=self.presence_penalties[:num_reqs], - repetition_penalties=self.repetition_penalties[:num_reqs], - output_token_ids=cast(list[list[int]], - self.req_output_token_ids), - min_tokens=self.min_tokens, - no_penalties=self.no_penalties, - logit_bias=self.logit_bias[:num_reqs], - allowed_token_ids_mask=allowed_token_ids_mask, - bad_words_token_ids=self.bad_words_token_ids, - ) - else: - return SamplingMetadata( - temperature=temperature, - all_greedy=self.all_greedy, - all_random=self.all_random, - top_p=None if self.no_top_p else self.top_p[:num_reqs], - top_k=None if self.no_top_k else self.top_k[:num_reqs], - generators=self.generators, - max_num_logprobs=self.max_num_logprobs, - prompt_token_ids=prompt_token_ids, - frequency_penalties=self.frequency_penalties[:num_reqs], - presence_penalties=self.presence_penalties[:num_reqs], - repetition_penalties=self.repetition_penalties[:num_reqs], - output_token_ids=cast(list[list[int]], - self.req_output_token_ids), - no_penalties=self.no_penalties, - allowed_token_ids_mask=allowed_token_ids_mask, - bad_words_token_ids=self.bad_words_token_ids, - logitsprocs=self.logitsprocs, - ) + return SamplingMetadata( + temperature=temperature, + all_greedy=self.all_greedy, + all_random=self.all_random, + top_p=None if self.no_top_p else self.top_p[:num_reqs], + top_k=None if self.no_top_k else self.top_k[:num_reqs], + generators=self.generators, + max_num_logprobs=self.max_num_logprobs, + prompt_token_ids=prompt_token_ids, + frequency_penalties=self.frequency_penalties[:num_reqs], + presence_penalties=self.presence_penalties[:num_reqs], + repetition_penalties=self.repetition_penalties[:num_reqs], + output_token_ids=cast(list[list[int]], self.req_output_token_ids), + no_penalties=self.no_penalties, + allowed_token_ids_mask=allowed_token_ids_mask, + bad_words_token_ids=self.bad_words_token_ids, + logitsprocs=self.logitsprocs, + ) @property def pooling_metadata(self) -> PoolingMetadata:
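As a companion to the metadata change above, here is a self-contained construction that mirrors the create_sampling_metadata helper from the rejection-sampler test earlier in this patch; every keyword argument comes from that test, and the concrete tensor values are placeholders.

    import torch
    from vllm.v1.sample.logits_processor import LogitsProcessorManager
    from vllm.v1.sample.metadata import SamplingMetadata

    # 0.9.2-style metadata: min_p, min_tokens and logit_bias are no longer passed,
    # and logits processors are supplied through LogitsProcessorManager instead.
    metadata = SamplingMetadata(
        temperature=torch.tensor([1.0]),   # placeholder per-request temperature
        all_greedy=False,
        all_random=True,
        top_p=None,
        top_k=None,
        generators={},
        max_num_logprobs=0,
        no_penalties=False,
        prompt_token_ids=None,
        frequency_penalties=torch.tensor([]),
        presence_penalties=torch.tensor([]),
        repetition_penalties=torch.tensor([]),
        output_token_ids=[],
        allowed_token_ids_mask=None,
        bad_words_token_ids={},
        logitsprocs=LogitsProcessorManager(),
    )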