
Commit cdece86

Potabk and Yikun authored
[Bugfix] Add max_num_batched_tokens to InputBatch to make main CI pass (vllm-project#806)
### What this PR does / why we need it?
1. Fix the V1 error found by [nightly_ci](https://github.com/vllm-project/vllm-ascend/actions/runs/14950004754/job/41998136610), broken by [[v1] Pass BlockTable and KVCacheSpec to AttentionMetadataBuilders #17483](vllm-project/vllm#17483), by making the `InputBatch` parameters consistent with vLLM.
2. Disable the benchmark tests and fix them upstream.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI passed

---------

Signed-off-by: wangli <wangli858794774@gmail.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Co-authored-by: Yikun Jiang <yikunkero@gmail.com>
1 parent 218f21d commit cdece86

2 files changed: 23 additions and 8 deletions


pytest.ini

Lines changed: 2 additions & 0 deletions
```diff
@@ -39,6 +39,8 @@ norecursedirs =
     vllm-empty/tests/neuron
     ; fastsafetensors not support npu now
     vllm-empty/tests/fastsafetensors_loader
+    ; Enable after https://github.com/vllm-project/vllm-ascend/issues/808 resolved
+    vllm-empty/tests/benchmarks

 addopts = --ignore=vllm-empty/tests/test_utils.py
     --ignore=vllm-empty/tests/test_config.py
```

vllm_ascend/worker/model_runner_v1.py

Lines changed: 21 additions & 8 deletions
```diff
@@ -55,6 +55,7 @@
 from vllm_ascend.attention.attention import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.platform import NPUPlatform
+from vllm_ascend.utils import vllm_version_is

 if TYPE_CHECKING:
     import xgrammar as xgr  # type: ignore[import-untyped]
@@ -187,14 +188,26 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         # Request states.
         self.requests: Dict[str, CachedRequestState] = {}
         # Persistent batch.
-        self.input_batch = InputBatch(
-            max_num_reqs=self.max_num_reqs,
-            max_model_len=self.model_config.max_model_len,
-            max_num_blocks_per_req=self.max_num_blocks_per_req,
-            device=self.device,
-            pin_memory=True,
-            vocab_size=self.model_config.get_vocab_size(),
-        )
+        # Remove this after we drop 0.8.5 support
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            self.input_batch = InputBatch(
+                max_num_reqs=self.max_num_reqs,
+                max_model_len=self.model_config.max_model_len,
+                max_num_blocks_per_req=self.max_num_blocks_per_req,
+                device=self.device,
+                pin_memory=True,
+                vocab_size=self.model_config.get_vocab_size(),
+            )
+        else:
+            self.input_batch = InputBatch(
+                max_num_reqs=self.max_num_reqs,
+                max_model_len=self.model_config.max_model_len,
+                max_num_blocks_per_req=self.max_num_blocks_per_req,
+                max_num_batched_tokens=self.max_num_tokens,
+                device=self.device,
+                pin_memory=True,
+                vocab_size=self.model_config.get_vocab_size(),
+            )

         self.input_ids = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int32,
```
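For readers maintaining similar version shims, the change above boils down to building the shared `InputBatch` keyword arguments once and adding `max_num_batched_tokens` only when the installed vLLM is newer than 0.8.5.post1. Below is a minimal, illustrative sketch of that pattern; the `build_input_batch` helper and the `importlib.metadata` version check are assumptions for illustration, not part of vllm-ascend (the project's real check is `vllm_ascend.utils.vllm_version_is`).

```python
# Illustrative sketch only: build_input_batch and this version check are
# hypothetical; the real project uses vllm_ascend.utils.vllm_version_is.
from importlib.metadata import version


def vllm_version_is(target: str) -> bool:
    # Assumption: exact string match against the installed vLLM version.
    return version("vllm") == target


def build_input_batch(input_batch_cls, runner):
    """Construct InputBatch, passing max_num_batched_tokens only on vLLM
    versions whose InputBatch signature accepts it (newer than 0.8.5.post1)."""
    kwargs = dict(
        max_num_reqs=runner.max_num_reqs,
        max_model_len=runner.model_config.max_model_len,
        max_num_blocks_per_req=runner.max_num_blocks_per_req,
        device=runner.device,
        pin_memory=True,
        vocab_size=runner.model_config.get_vocab_size(),
    )
    if not (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")):
        # Parameter introduced by vllm-project/vllm#17483; the older
        # InputBatch constructor would reject it.
        kwargs["max_num_batched_tokens"] = runner.max_num_tokens
    return input_batch_cls(**kwargs)
```

Keeping the shared keyword arguments in one dict means only the version branch has to be deleted once 0.8.5 support is dropped, rather than two near-duplicate call sites.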
