Skip to content

Commit 38692b5

Browse files
henryxuxu0716 and 刘哲续
authored
fix torchair execute issue on padding data, and mtp padding logic -- by#1160 (#1214)
### What this PR does / why we need it? Fixes the torchair execute issue on padding data, and the MTP padding logic (ported from #1160). ### How was this patch tested? It has been tested and merged in main. Signed-off-by: 刘哲续 <liuzhexu1@huawei.com> Co-authored-by: 刘哲续 <liuzhexu1@huawei.com>
1 parent 75c10ce commit 38692b5

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

vllm_ascend/attention/mla_v1.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,10 @@ def build(
376376
seq_lens = seq_lens[:self._num_decode_tokens]
377377
input_positions = input_positions[:self._num_decode_tokens]
378378
block_table = block_table[:self._num_decode_tokens, ...]
379-
if use_torchair_graph and self.runner.attn_state == AscendAttentionState.DecodeOnly:
379+
if use_torchair_graph and self.runner.attn_state in [
380+
AscendAttentionState.DecodeOnly,
381+
AscendAttentionState.SpecDecoding
382+
]:
380383
num_seqs = len(seq_lens)
381384
if graph_pad_size != 0:
382385
pad_value = 1

vllm_ascend/worker/model_runner_v1.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -943,11 +943,6 @@ def _process_reqs(
943943
self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
944944
input_ids = self.input_ids[:num_input_tokens]
945945

946-
if (envs_ascend.VLLM_ENABLE_MC2
947-
or self.torchair_graph_enabled) and not with_prefill:
948-
input_ids = self.input_ids[:padded_batch_size]
949-
positions = self.positions[:padded_batch_size]
950-
951946
# prepare the MRoPE for mllm if using multimodal
952947
num_input_tokens = total_num_scheduled_tokens
953948
# _prepare_inputs may reorder the batch, so we must gather multi
@@ -985,6 +980,11 @@ def _process_reqs(
985980
else:
986981
positions = self.positions[:num_input_tokens]
987982

983+
if (envs_ascend.VLLM_ENABLE_MC2
984+
or self.torchair_graph_enabled) and not with_prefill:
985+
input_ids = self.input_ids[:padded_batch_size]
986+
positions = self.positions[:padded_batch_size]
987+
988988
# Run forward pass
989989
with set_forward_context(attn_metadata,
990990
self.vllm_config,

0 commit comments

Comments (0)