diff --git a/tests/long_term/spec_decode_v1/test_v1_mtp_correctness.py b/tests/long_term/spec_decode_v1/test_v1_mtp_correctness.py
index 3b5e1986f2..68736c0928 100644
--- a/tests/long_term/spec_decode_v1/test_v1_mtp_correctness.py
+++ b/tests/long_term/spec_decode_v1/test_v1_mtp_correctness.py
@@ -114,7 +114,8 @@ def test_mtp_torchair_correctness(
             enforce_eager=False,
             additional_config={
                 "torchair_graph_config": {
-                    "enabled": True
+                    "enabled": True,
+                    "graph_batch_size": [256]
                 },
                 "ascend_scheduler_config": {
                     "enabled": True
@@ -132,7 +133,8 @@ def test_mtp_torchair_correctness(
             },
             additional_config={
                 "torchair_graph_config": {
-                    "enabled": True
+                    "enabled": True,
+                    "graph_batch_size": [256]
                 },
                 "ascend_scheduler_config": {
                     "enabled": True
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 98f0a3389c..50826f1493 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -329,7 +329,7 @@ def build_torchair_graph_dummy(
             num_reqs, block_table)
         num_tokens = num_reqs * self.runner.decode_token_per_req
         seq_lens = torch.zeros(num_reqs, dtype=torch.int32, device=device)
-        seq_lens_list = seq_lens.tolist()
+        seq_lens_list = [0] * num_reqs
         input_positions = torch.zeros(num_tokens,
                                       dtype=torch.int32,
                                       device=device).long()
@@ -471,7 +471,7 @@ def build(
         decode_metadata = None
         use_torchair_graph = num_token_pad_size != -1
         if self._num_decodes > 0:
-            actual_seq_q_lens = None
+            actual_seq_q_lens = query_start_loc[1:].tolist()
             max_seq_lens = seq_lens[:self._num_decodes].max().item()
             seq_lens = seq_lens[:self._num_decode_tokens]
             input_positions = input_positions[:self._num_decode_tokens]
@@ -980,11 +980,13 @@ def _forward_decode(
                                  self.qk_rope_head_dim)
             input_layout = "BNSD"

-            # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
             if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
                 assert num_tokens % self.spec_token_num == 0
+                if self.enable_kv_nz:
+                    input_layout = "TND_NTD"
+                else:
+                    input_layout = "TND"
                 # [bs * q_seq_len, num_heads_per_rank, dim]
-                input_layout = "TND"
                 q_nope = q_nope.view(num_tokens, self.num_heads, -1)
                 q_pe = q_pe.view(num_tokens, self.num_heads, -1)
                 sparse_mode = 3
diff --git a/vllm_ascend/sample/rejection_sampler.py b/vllm_ascend/sample/rejection_sampler.py
index 384787be01..c738410d0c 100644
--- a/vllm_ascend/sample/rejection_sampler.py
+++ b/vllm_ascend/sample/rejection_sampler.py
@@ -432,7 +432,7 @@ def sample_recovered_tokens_pytorch(

         if IS_NGRAM:
             draft_token_id = draft_token_ids[token_idx]
-            orig_prob = target_probs[token_idx, draft_token_id]
+            orig_prob = target_probs[token_idx, draft_token_id].item()
             target_probs[token_idx, draft_token_id] = 0
             prob = target_probs[token_idx].clone()
         else:
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index e0ab79be45..1d4bb09310 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1677,7 +1677,7 @@ def _dummy_run(
                     **model_kwargs)
             if self.speculative_config and self.speculative_config.method == "deepseek_mtp":
                 assert isinstance(self.drafter, MtpProposer)
-                self.drafter.dummy_run(num_reqs, with_prefill=with_prefill)
+                self.drafter.dummy_run(num_reqs)
         return hidden_states

     @contextmanager
@@ -2140,6 +2140,12 @@ def check_torchair_graph_batch_sizes(self):
         if self.torchair_graph_batch_sizes[-1] < self.max_num_reqs:
             self.torchair_graph_batch_sizes.append(self.max_num_reqs)

+        # Make sure we can handle max_num_reqs when `self.decode_token_per_req` is not 1.
+        self.torchair_graph_batch_sizes = [
+            graph_batch_size * self.decode_token_per_req
+            for graph_batch_size in self.torchair_graph_batch_sizes
+        ]
+
         # NOTE: when enable_expert_parallel, we need to check if `graph_batch_size` is divisible by `tp_size`
         tp_size = self.parallel_config.tensor_parallel_size
         if self.parallel_config.enable_expert_parallel:
@@ -2147,9 +2153,13 @@ def check_torchair_graph_batch_sizes(self):
             for graph_batch_size in self.torchair_graph_batch_sizes:
                 cur_graph_batch_size = (graph_batch_size + tp_size -
                                         1) // tp_size * tp_size
-                # `graph_batch_size` need to be divisible by `self.decode_token_per_req`
-                cur_graph_batch_size = cur_graph_batch_size * self.decode_token_per_req
                 if cur_graph_batch_size not in new_graph_batch_sizes and \
                     cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
                     new_graph_batch_sizes.append(cur_graph_batch_size)
+                elif cur_graph_batch_size > self.scheduler_config.max_num_batched_tokens \
+                    and self.decode_token_per_req > 1:
+                    logger.warning(
+                        f"torchair_graph_batch_sizes {cur_graph_batch_size} is bigger than "
+                        f"max_num_batched_tokens {self.scheduler_config.max_num_batched_tokens}, "
+                        "skipping this batch size.")
             self.torchair_graph_batch_sizes = new_graph_batch_sizes
diff --git a/vllm_ascend/worker/mtp_proposer_v1.py b/vllm_ascend/worker/mtp_proposer_v1.py
index 1f0f75fda7..370c65ffa9 100644
--- a/vllm_ascend/worker/mtp_proposer_v1.py
+++ b/vllm_ascend/worker/mtp_proposer_v1.py
@@ -307,14 +307,9 @@ def load_model(self) -> None:
     def dummy_run(
             self,
             num_tokens: int,
-            with_prefill: bool = False,
     ) -> None:
-        if self.runner.torchair_graph_enabled and not with_prefill:
-            attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(
-                num_reqs=num_tokens, num_actual_tokens=1, is_mtp_model=True)
-        else:
-            attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(
-                num_reqs=num_tokens, num_actual_tokens=1, is_mtp_model=True)
+        attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(
+            num_reqs=num_tokens, num_actual_tokens=1, is_mtp_model=True)
         with set_ascend_forward_context(None,
                                         self.vllm_config,
                                         num_tokens=num_tokens):