Commit 30bf701

[Bugfix] Add func swap_states to fix MLA attention (#1580)
### What this PR does / why we need it?

MLA attention was still relying on the GPU input batch's `swap_states` method, which leads to `AttributeError: 'InputBatch' object has no attribute 'swap_states'` because the NPU `InputBatch` never defined it. This PR fixes the MLA input-batch error by importing the NPU `InputBatch` in `mla_v1.py` and implementing `swap_states` on it.

### How was this patch tested?

Will be tested by #1136.

Signed-off-by: wangli <wangli858794774@gmail.com>
1 parent 59237ea · commit 30bf701

File tree: 3 files changed (+76, −1 lines)

vllm_ascend/attention/mla_v1.py (1 addition, 1 deletion)

```diff
@@ -22,10 +22,10 @@
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
 from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor
+from vllm_ascend.worker.npu_input_batch import InputBatch

 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
-    from vllm.v1.worker.gpu_input_batch import InputBatch


 @dataclass
```
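The old `TYPE_CHECKING` import is what let this bug hide: it exists only for type checkers and is erased at runtime, so annotating the batch as the GPU `InputBatch` never guaranteed that the object the NPU worker actually passes in has `swap_states`. A minimal sketch of that failure mode (the `NpuInputBatchWithoutSwap` and `reorder_batch` names are hypothetical, for illustration only):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Annotation-only import: invisible at runtime, so it says nothing
    # about which class the worker actually instantiates.
    from vllm.v1.worker.gpu_input_batch import InputBatch


class NpuInputBatchWithoutSwap:
    """Hypothetical stand-in for the NPU InputBatch before this patch."""

    def __init__(self) -> None:
        self._req_ids = ["req-0", "req-1"]


def reorder_batch(batch: "InputBatch") -> None:
    # The annotation satisfies the type checker, but at runtime `batch`
    # is whatever the worker built; without swap_states this raises.
    batch.swap_states(0, 1)


try:
    reorder_batch(NpuInputBatchWithoutSwap())  # type: ignore[arg-type]
except AttributeError as exc:
    print(exc)  # ... object has no attribute 'swap_states'
```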

vllm_ascend/pool/__init__.py (new file, 16 additions)

```diff
@@ -0,0 +1,16 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
```

vllm_ascend/worker/npu_input_batch.py (59 additions)

```diff
@@ -26,6 +26,7 @@
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams, SamplingType
+from vllm.utils import swap_dict_values
 from vllm.v1.outputs import LogprobsTensors
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.utils import copy_slice
@@ -423,6 +424,64 @@ def remove_request(self, req_id: str) -> Optional[int]:
         self.pooling_params.pop(req_id, None)
         return req_index

+    def swap_states(self, i1: int, i2: int) -> None:
+        old_id_i1 = self._req_ids[i1]
+        old_id_i2 = self._req_ids[i2]
+        self._req_ids[i1], self._req_ids[i2] = \
+            self._req_ids[i2], self._req_ids[i1]  # noqa
+        self.req_output_token_ids[i1], self.req_output_token_ids[i2] = \
+            self.req_output_token_ids[i2], self.req_output_token_ids[i1]
+        assert old_id_i1 is not None and old_id_i2 is not None
+        self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] = \
+            self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
+        self.num_tokens[i1], self.num_tokens[i2] = \
+            self.num_tokens[i2], self.num_tokens[i1]
+        self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] = \
+            self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1]
+        self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] = \
+            self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
+        self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] = \
+            self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
+        self.temperature_cpu[i1], self.temperature_cpu[i2] = \
+            self.temperature_cpu[i2], self.temperature_cpu[i1]
+        self.top_p_cpu[i1], self.top_p_cpu[i2] = \
+            self.top_p_cpu[i2], self.top_p_cpu[i1]
+        self.top_k_cpu[i1], self.top_k_cpu[i2] = \
+            self.top_k_cpu[i2], self.top_k_cpu[i1]
+        self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = \
+            self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
+        self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = \
+            self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
+        self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = \
+            self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
+        self.min_p_cpu[i1], self.min_p_cpu[i2] = \
+            self.min_p_cpu[i2], self.min_p_cpu[i1]
+
+        # NOTE: the following is unsafe:
+        #     self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...] = \
+        #         self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
+        # Instead, we need to temporarily copy the data for one of the indices.
+        # TODO(lucas): optimize this by only copying valid indices
+        tmp = self.token_ids_cpu[i1, ...].copy()
+        self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
+        self.token_ids_cpu[i2, ...] = tmp
+
+        swap_dict_values(self.generators, i1, i2)
+        swap_dict_values(self.min_tokens, i1, i2)
+        swap_dict_values(self.bad_words_token_ids, i1, i2)
+
+        self.request_lora_mapping[i1], self.request_lora_mapping[i2] = \
+            self.request_lora_mapping[i2], self.request_lora_mapping[i1]
+        self.logit_bias[i1], self.logit_bias[i2] = \
+            self.logit_bias[i2], self.logit_bias[i1]
+
+        if self.allowed_token_ids_mask_cpu_tensor is not None:
+            self.allowed_token_ids_mask_cpu_tensor[i1], \
+                self.allowed_token_ids_mask_cpu_tensor[i2] = \
+                self.allowed_token_ids_mask_cpu_tensor[i2], \
+                self.allowed_token_ids_mask_cpu_tensor[i1]
+        self.block_table.swap_row(i1, i2)
+
     def condense(self, empty_req_indices: list[int]) -> None:
         """Move non-empty requests down into lower, empty indices.
```
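`swap_states` follows a parallel-array pattern: every per-request column must be exchanged in lockstep, and the `req_id_to_index` map updated, or the batch's two views of a request fall out of sync. Below is a minimal, self-contained sketch of that pattern (the `TinyBatch` class is hypothetical, not the real `InputBatch`), including why the token-id rows need the temporary copy the patch uses:

```python
import numpy as np


class TinyBatch:
    """Hypothetical miniature of the parallel-array layout in InputBatch."""

    def __init__(self) -> None:
        self._req_ids = ["a", "b"]
        self.req_id_to_index = {"a": 0, "b": 1}
        self.num_tokens = np.array([3, 7])
        self.token_ids = np.array([[1, 2, 3, 0], [4, 5, 6, 7]])

    def swap_states(self, i1: int, i2: int) -> None:
        # Swap the ids, then point the reverse map at the new indices.
        old_id_i1, old_id_i2 = self._req_ids[i1], self._req_ids[i2]
        self._req_ids[i1], self._req_ids[i2] = \
            self._req_ids[i2], self._req_ids[i1]
        self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] = \
            self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
        # Scalar columns can be tuple-swapped safely.
        self.num_tokens[i1], self.num_tokens[i2] = \
            self.num_tokens[i2], self.num_tokens[i1]
        # Row slices are numpy *views*, so a tuple swap would alias both
        # sides to the same buffer; copy one row first, as the patch does.
        tmp = self.token_ids[i1, :].copy()
        self.token_ids[i1, :] = self.token_ids[i2, :]
        self.token_ids[i2, :] = tmp


batch = TinyBatch()
batch.swap_states(0, 1)
assert batch._req_ids == ["b", "a"]
assert batch.req_id_to_index == {"a": 1, "b": 0}
assert batch.num_tokens.tolist() == [7, 3]
assert batch.token_ids[0].tolist() == [4, 5, 6, 7]
```

If any one column were skipped, index `i1` would hold a mix of two requests' state, which is exactly the kind of silent corruption the all-at-once swap avoids.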