Commit 5c5bb07

fix estimate
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent ccd6331 commit 5c5bb07

File tree: 1 file changed (+6 −3)

vllm/v1/attention/backends/mla/rocm_aiter_mla.py

Lines changed: 6 additions & 3 deletions
@@ -9,6 +9,7 @@
 import vllm.envs as envs
 from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd
 from vllm.config import VllmConfig
+from vllm.utils import cdiv
 # yapf conflicts with isort for this docstring
 # yapf: disable
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
@@ -72,16 +73,18 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
             "only supports block size 1."
 
         self.compilation_config = vllm_config.compilation_config
+        max_num_pages_per_req = cdiv(vllm_config.model_config.max_model_len,
+                                     self.kv_cache_spec.block_size)
+        max_num_req = vllm_config.scheduler_config.max_num_seqs
+        max_num_pages = max_num_req * max_num_pages_per_req
 
         # Preparing persistent buffers
         if vllm_config.compilation_config.full_cuda_graph:
             max_num_reqs = vllm_config.scheduler_config.max_num_seqs
             self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
                                                dtype=torch.int32,
                                                device=device)
-            # We'll assume a reasonable max number of pages
-            max_pages = max_num_reqs * 1024  # Rough estimate
-            self.paged_kv_indices = torch.zeros(max_pages,
+            self.paged_kv_indices = torch.zeros(max_num_pages,
                                                 dtype=torch.int32,
                                                 device=device)
             self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
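
The dropped heuristic capped the persistent buffer at 1024 pages per request regardless of context length; the new code derives the exact worst case from the config, since a request can span at most ceil(max_model_len / block_size) KV-cache pages. A minimal sketch of the sizing arithmetic, with illustrative values standing in for the real VllmConfig fields (only cdiv itself corresponds to vllm.utils.cdiv):

    # Ceiling division, matching vllm.utils.cdiv.
    def cdiv(a: int, b: int) -> int:
        return -(-a // b)

    # Illustrative values only; the real ones come from VllmConfig.
    max_model_len = 32768   # model_config.max_model_len
    block_size = 16         # kv_cache_spec.block_size
    max_num_seqs = 256      # scheduler_config.max_num_seqs

    # Exact worst case: pages per request, times concurrent requests.
    max_num_pages_per_req = cdiv(max_model_len, block_size)  # 2048
    max_num_pages = max_num_seqs * max_num_pages_per_req     # 524288

    # The old "rough estimate" under-allocates whenever a request can
    # span more than 1024 pages, as it does with these values.
    old_estimate = max_num_seqs * 1024                       # 262144
    assert max_num_pages > old_estimate

With these (hypothetical) numbers the old estimate would have been half the required size, so paged_kv_indices could overflow once full_cuda_graph mode pinned the buffer; sizing from the config removes the guesswork at the cost of allocating for the true maximum.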
