Commit 58e4120

cleanup
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent 0e357ee commit 58e4120

File tree: 3 files changed, 42 additions, 60 deletions


vllm/v1/attention/backends/flashinfer.py

Lines changed: 1 addition & 1 deletion
@@ -14,6 +14,7 @@
 import vllm.envs as envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
@@ -27,7 +28,6 @@
 from vllm.v1.kv_cache_interface import AttentionSpec

 if TYPE_CHECKING:
-    from vllm.config import VllmConfig
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch

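The only change here is that VllmConfig moves out of the TYPE_CHECKING block into a regular import, so the class is available at runtime rather than only to static type checkers. A minimal sketch of the distinction (the builder class below is a hypothetical stand-in, not the vLLM source):

from typing import TYPE_CHECKING

from vllm.config import VllmConfig  # runtime import: the class can be used as a value

if TYPE_CHECKING:
    # Imports here exist only for static type checkers and cost nothing at
    # runtime, but the names cannot be referenced in executable code.
    from vllm.v1.core.sched.output import SchedulerOutput


class ExampleBuilder:  # hypothetical stand-in, not the vLLM class
    def __init__(self, vllm_config: VllmConfig) -> None:
        # Storing and inspecting the config at runtime requires the real
        # import above; a TYPE_CHECKING-only import would raise NameError.
        self.vllm_config = vllm_config
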
vllm/v1/attention/backends/mamba_attn.py

Lines changed: 6 additions & 11 deletions
@@ -11,12 +11,10 @@
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                               CommonAttentionMetadata)
 from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
-from vllm.v1.worker.block_table import BlockTable

 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch
-    from vllm.v1.worker.gpu_model_runner import GPUModelRunner


 def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int:
@@ -97,13 +95,11 @@ class Mamba2AttentionMetadata:
 class Mamba2AttentionMetadataBuilder(
         AttentionMetadataBuilder[Mamba2AttentionMetadata]):

-    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
-                 block_table: BlockTable):
+    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
+                 device: torch.device):
         assert isinstance(kv_cache_spec, MambaSpec)
-        self.runner = runner
         self.kv_cache_spec = kv_cache_spec
-        self.block_table = block_table
-        self.chunk_size = get_mamba2_chunk_size(runner.vllm_config)
+        self.chunk_size = get_mamba2_chunk_size(vllm_config)

     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -182,15 +178,14 @@ def build(self,
         has_initial_states = None
         prep_initial_states = False

-        state_indices_tensor = self.block_table.block_table[:num_reqs, 0]
+        state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]

         # Compute seq_idx, chunk_indices and chunk_offsets for prefill only
         if self._num_prefills > 0:
             #[batch,]
             has_initial_states_cpu = (
-                self.runner.input_batch.
-                num_computed_tokens_cpu_tensor[num_reqs -
-                                               self._num_prefills:num_reqs]
+                common_attn_metadata.
+                num_computed_tokens_cpu[num_reqs - self._num_prefills:num_reqs]
                 > 0)
             prep_initial_states = torch.any(has_initial_states_cpu).item()
             has_initial_states = has_initial_states_cpu.to(

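The net effect of these hunks: Mamba2AttentionMetadataBuilder no longer keeps a GPUModelRunner reference; the chunk size comes from VllmConfig and the per-request state comes from CommonAttentionMetadata. A rough sketch of that data flow, using a simplified stand-in dataclass (illustrative only, not the vLLM type):

from dataclasses import dataclass

import torch


@dataclass
class _CommonMetaSketch:
    # Simplified stand-in for CommonAttentionMetadata; only the two fields
    # used below are modeled.
    block_table_tensor: torch.Tensor       # [num_reqs, max_blocks], on device
    num_computed_tokens_cpu: torch.Tensor  # [num_reqs], on CPU


def mamba_state_inputs(meta: _CommonMetaSketch, num_prefills: int):
    num_reqs = meta.block_table_tensor.shape[0]
    # Each request's first block id doubles as its Mamba state slot.
    state_indices_tensor = meta.block_table_tensor[:, 0]
    # A prefill request carries an initial state iff it already has
    # computed tokens.
    has_initial_states_cpu = (
        meta.num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
    prep_initial_states = bool(torch.any(has_initial_states_cpu).item())
    return state_indices_tensor, has_initial_states_cpu, prep_initial_states
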
vllm/v1/attention/backends/rocm_aiter_fa.py

Lines changed: 35 additions & 48 deletions
@@ -2,26 +2,21 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention layer with AiterFlashAttention."""
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Optional
+from typing import Any, Optional

 import torch

 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import (
     make_local_attention_virtual_batches)
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import AttentionSpec
-from vllm.v1.worker.block_table import BlockTable
-
-if TYPE_CHECKING:
-    from vllm.v1.core.sched.output import SchedulerOutput
-    from vllm.v1.worker.gpu_input_batch import InputBatch
-    from vllm.v1.worker.gpu_model_runner import GPUModelRunner

 if current_platform.is_rocm():
     import aiter
@@ -172,56 +167,49 @@ def flash_attn_varlen_func_fake(

 class AiterFlashAttentionMetadataBuilder:

-    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
-                 block_table: BlockTable):
-        model_config = runner.model_config
-
-        self.runner = runner
-        self.num_heads_q = model_config.get_num_attention_heads(
-            runner.parallel_config)
-        self.num_heads_kv = model_config.get_num_kv_heads(
-            runner.parallel_config)
-        self.headdim = model_config.get_head_size()
+    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
+                 device: torch.device):
+        self.vllm_config = vllm_config
+        self.model_config = vllm_config.model_config
+        self.parallel_config = vllm_config.parallel_config
+        self.cache_config = vllm_config.cache_config
+        self.device = device
+
+        self.num_heads_q = self.model_config.get_num_attention_heads(
+            self.parallel_config)
+        self.num_heads_kv = self.model_config.get_num_kv_heads(
+            self.parallel_config)
+        self.headdim = self.model_config.get_head_size()
         self.block_size = kv_cache_spec.block_size
         self.kv_cache_spec = kv_cache_spec
-        self.block_table = block_table

         # Sliding window size to be used with the AOT scheduler will be
         # populated on first build() call.
         self.aot_sliding_window: Optional[tuple[int, int]] = None

-    def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput") -> bool:
+    def reorder_batch(self, input_batch, scheduler_output) -> bool:
         return False

     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
               fast_build: bool = False) -> 'AiterFlashAttentionMetadata':

-        num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len

-        max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
-        total_tokens = int(self.runner.seq_lens_np[:num_reqs].sum())
+        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max().item())
+        total_tokens = int(common_attn_metadata.seq_lens_cpu.sum().item())
         query_start_loc = common_attn_metadata.query_start_loc
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
         seq_lens = common_attn_metadata.seq_lens
-        block_table = self.block_table
-        block_table_tensor = block_table.get_device_tensor()[:num_reqs]
-
-        block_table.slot_mapping[:num_actual_tokens].copy_(
-            block_table.slot_mapping_cpu[:num_actual_tokens],
-            non_blocking=True)
-        # Fill unused with -1. Needed for reshape_and_cache in full cuda graph
-        # mode.
-        block_table.slot_mapping[num_actual_tokens:].fill_(-1)
-
-        slot_mapping = block_table.slot_mapping[:num_actual_tokens]
+        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
+        block_table_tensor = common_attn_metadata.block_table_tensor
+        slot_mapping = common_attn_metadata.slot_mapping

         cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
                                   dtype=torch.int32,
-                                  device="cuda")
+                                  device=self.device)
         torch.cumsum(seq_lens,
                      dim=0,
                      dtype=cu_seq_lens.dtype,
@@ -233,21 +221,21 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,

         # for local attention
         local_attn_metadata = None
-        if self.runner.attention_chunk_size is not None:
+        if self.model_config.attention_chunk_size is not None:
             seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
                 virt_block_table_tensor = make_local_attention_virtual_batches(
-                    self.runner.attention_chunk_size,
-                    self.runner.query_start_loc_np[:num_reqs + 1],
-                    self.runner.seq_lens_np[:num_reqs],
+                    self.model_config.attention_chunk_size,
+                    query_start_loc_cpu.numpy(),
+                    seq_lens_cpu.numpy(),
                     block_table_tensor,
                     self.block_size,
                 )
             local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
-                self.runner.device, non_blocking=True)
+                self.device, non_blocking=True)
             local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
-                self.runner.device, non_blocking=True)
-            local_max_query_len = int(seqlens_q_local_np.max())
-            local_max_seq_len = int(virt_k_seqlens_np.max())
+                self.device, non_blocking=True)
+            local_max_query_len = seqlens_q_local_np.max().item()
+            local_max_seq_len = virt_k_seqlens_np.max().item()
             local_scheduler_metadata = schedule(
                 batch_size=local_query_start_loc.shape[0] - 1,
                 cu_query_lens=local_query_start_loc,
@@ -258,12 +246,11 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,

             local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1,
                                             dtype=torch.int32,
-                                            device=self.runner.device)
+                                            device=self.device)
             local_cu_seq_lens[1:] = torch.cumsum(
-                torch.from_numpy(virt_k_seqlens_np).to(
-                    device=self.runner.device,
-                    dtype=torch.int32,
-                    non_blocking=True),
+                torch.from_numpy(virt_k_seqlens_np).to(device=self.device,
+                                                       dtype=torch.int32,
+                                                       non_blocking=True),
                 dim=0)

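Same refactor pattern as the other files: the builder now takes (kv_cache_spec, vllm_config, device), reads sequence lengths, block table, and slot mapping from CommonAttentionMetadata, and allocates work tensors on self.device instead of a hard-coded "cuda". A small sketch of that per-batch bookkeeping, assuming only the metadata fields shown in the diff (the helper itself is illustrative, not vLLM API):

import torch


def batch_stats(seq_lens: torch.Tensor, seq_lens_cpu: torch.Tensor,
                device: torch.device):
    # Scalar statistics come from the CPU copy of the sequence lengths,
    # avoiding a device-to-host sync on the hot path.
    max_seq_len = int(seq_lens_cpu.max().item())
    total_tokens = int(seq_lens_cpu.sum().item())
    # Cumulative sequence lengths are allocated on the builder's device
    # rather than a hard-coded "cuda" string.
    cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
                              dtype=torch.int32,
                              device=device)
    cu_seq_lens[1:] = torch.cumsum(seq_lens, dim=0)
    return max_seq_len, total_tokens, cu_seq_lens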