
Commit 643fb3f

cleanup
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent 56ea170 commit 643fb3f

3 files changed: 56 additions & 80 deletions


vllm/v1/attention/backends/flashinfer.py

Lines changed: 15 additions & 21 deletions
@@ -14,7 +14,8 @@
 import vllm.envs as envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)
-from vllm.config import VllmConfig
+from vllm.attention.layer import Attention
+from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 from vllm.v1.attention.backends.utils import (
@@ -23,7 +24,6 @@
 from vllm.v1.kv_cache_interface import AttentionSpec

 if TYPE_CHECKING:
-    from vllm.config import VllmConfig
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch

@@ -97,27 +97,21 @@ def get_per_layer_parameters(
     Scan all attention layers and determine some hyperparameters
     to use during `plan`.
     """
-    model_config = vllm_config.model_config
+    layers = get_layers_from_vllm_config(vllm_config, Attention)
+    per_layer_params: dict[str, PerLayerParameters] = {}

-    # This is a workaround for the fact that the attention backend
-    # in this standalone test does not have access to the full model,
-    # so we mock the layer access.
-    if not hasattr(model_config, 'get_num_layers'):
-        raise RuntimeError(
-            "The model config does not have a get_num_layers method. "
-            "This is required for the FlashInfer backend.")
+    for key, layer in layers.items():
+        impl = layer.impl
+        assert isinstance(impl, FlashInferImpl)

-    per_layer_params: dict[str, PerLayerParameters] = {}
-    for i in range(model_config.get_num_layers()):
-        sliding_window = model_config.get_sliding_window_for_layer(i)
-        logits_soft_cap = model_config.get_logits_soft_cap_for_layer(i)
-        sm_scale = model_config.get_sm_scale_for_layer(i)
-
-        per_layer_params[str(i)] = PerLayerParameters(
-            sm_scale=sm_scale,
-            logits_soft_cap=logits_soft_cap,
-            window_left=sliding_window if sliding_window is not None else -1,
-        )
+        # Infer hyperparameters from the attention layer
+        window_size = impl.sliding_window
+        window_left = window_size[0] if window_size is not None else -1
+        logits_soft_cap = impl.logits_soft_cap
+        sm_scale = impl.scale
+
+        per_layer_params[key] = PerLayerParameters(window_left,
+                                                   logits_soft_cap, sm_scale)

     return per_layer_params
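Note: the rewritten `get_per_layer_parameters` derives each layer's `plan` hyperparameters from the attention layer's `impl` rather than from `ModelConfig` helpers. Below is a minimal, self-contained sketch of that pattern; `ToyImpl` and the `layers` dict are stand-ins for the `Attention` layers returned by `get_layers_from_vllm_config`, not the real vLLM objects.

from dataclasses import dataclass
from typing import Optional


@dataclass
class PerLayerParameters:
    # Mirrors the fields used in the diff: window_left, logits_soft_cap, sm_scale.
    window_left: int
    logits_soft_cap: Optional[float]
    sm_scale: float


@dataclass
class ToyImpl:
    # Stand-in for FlashInferImpl: only the attributes the diff reads.
    sliding_window: Optional[tuple[int, int]]
    logits_soft_cap: Optional[float]
    scale: float


def collect_per_layer_params(
        layers: dict[str, ToyImpl]) -> dict[str, PerLayerParameters]:
    """Scan layer impls and collect per-layer hyperparameters, as in the diff."""
    per_layer_params: dict[str, PerLayerParameters] = {}
    for key, impl in layers.items():
        window_size = impl.sliding_window
        window_left = window_size[0] if window_size is not None else -1
        per_layer_params[key] = PerLayerParameters(window_left,
                                                   impl.logits_soft_cap,
                                                   impl.scale)
    return per_layer_params


if __name__ == "__main__":
    layers = {
        "model.layers.0.self_attn.attn": ToyImpl(None, None, 0.125),
        "model.layers.1.self_attn.attn": ToyImpl((4096, 0), 30.0, 0.125),
    }
    print(collect_per_layer_params(layers))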

vllm/v1/attention/backends/mamba_attn.py

Lines changed: 6 additions & 11 deletions
@@ -12,12 +12,10 @@
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                               CommonAttentionMetadata)
 from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
-from vllm.v1.worker.block_table import BlockTable

 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch
-    from vllm.v1.worker.gpu_model_runner import GPUModelRunner


 def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int:
@@ -58,13 +56,11 @@ class Mamba2AttentionMetadata:
 class Mamba2AttentionMetadataBuilder(
         AttentionMetadataBuilder[Mamba2AttentionMetadata]):

-    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
-                 block_table: BlockTable):
+    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
+                 device: torch.device):
         assert isinstance(kv_cache_spec, MambaSpec)
-        self.runner = runner
         self.kv_cache_spec = kv_cache_spec
-        self.block_table = block_table
-        self.chunk_size = get_mamba2_chunk_size(runner.vllm_config)
+        self.chunk_size = get_mamba2_chunk_size(vllm_config)

     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -143,15 +139,14 @@ def build(self,
         has_initial_states = None
         prep_initial_states = False

-        state_indices_tensor = self.block_table.block_table[:num_reqs, 0]
+        state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]

         # Compute seq_idx, chunk_indices and chunk_offsets for prefill only
         if self._num_prefills > 0:
             #[batch,]
             has_initial_states_cpu = (
-                self.runner.input_batch.
-                num_computed_tokens_cpu_tensor[num_reqs -
-                                               self._num_prefills:num_reqs]
+                common_attn_metadata.
+                num_computed_tokens_cpu[num_reqs - self._num_prefills:num_reqs]
                 > 0)
             prep_initial_states = torch.any(has_initial_states_cpu).item()
             has_initial_states = has_initial_states_cpu.to(
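Note: the Mamba2 builder no longer reaches into the runner; the state indices and the initial-state mask now come from `CommonAttentionMetadata`. The sketch below reproduces just that computation with plain tensors standing in for `block_table_tensor` and `num_computed_tokens_cpu` (illustrative shapes and values only, not vLLM's real metadata objects).

import torch

# Toy stand-ins for the fields the builder now reads from CommonAttentionMetadata.
num_reqs, num_prefills = 4, 2
block_table_tensor = torch.arange(num_reqs * 8).reshape(num_reqs, 8)
num_computed_tokens_cpu = torch.tensor([0, 5, 0, 7])  # decodes first, prefills last

# The first block of each request indexes the Mamba state, as in the diff.
state_indices_tensor = block_table_tensor[:, 0]

# A prefill request has an initial state iff it already has computed tokens.
has_initial_states_cpu = (
    num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
prep_initial_states = bool(torch.any(has_initial_states_cpu).item())

print(state_indices_tensor, has_initial_states_cpu, prep_initial_states)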

vllm/v1/attention/backends/rocm_aiter_fa.py

Lines changed: 35 additions & 48 deletions
@@ -1,26 +1,21 @@
 # SPDX-License-Identifier: Apache-2.0
 """Attention layer with AiterFlashAttention."""
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Optional
+from typing import Any, Optional

 import torch

 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import (
     make_local_attention_virtual_batches)
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import AttentionSpec
-from vllm.v1.worker.block_table import BlockTable
-
-if TYPE_CHECKING:
-    from vllm.v1.core.sched.output import SchedulerOutput
-    from vllm.v1.worker.gpu_input_batch import InputBatch
-    from vllm.v1.worker.gpu_model_runner import GPUModelRunner

 if current_platform.is_rocm():
     import aiter
@@ -171,56 +166,49 @@ def flash_attn_varlen_func_fake(

 class AiterFlashAttentionMetadataBuilder:

-    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
-                 block_table: BlockTable):
-        model_config = runner.model_config
-
-        self.runner = runner
-        self.num_heads_q = model_config.get_num_attention_heads(
-            runner.parallel_config)
-        self.num_heads_kv = model_config.get_num_kv_heads(
-            runner.parallel_config)
-        self.headdim = model_config.get_head_size()
+    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
+                 device: torch.device):
+        self.vllm_config = vllm_config
+        self.model_config = vllm_config.model_config
+        self.parallel_config = vllm_config.parallel_config
+        self.cache_config = vllm_config.cache_config
+        self.device = device
+
+        self.num_heads_q = self.model_config.get_num_attention_heads(
+            self.parallel_config)
+        self.num_heads_kv = self.model_config.get_num_kv_heads(
+            self.parallel_config)
+        self.headdim = self.model_config.get_head_size()
         self.block_size = kv_cache_spec.block_size
         self.kv_cache_spec = kv_cache_spec
-        self.block_table = block_table

         # Sliding window size to be used with the AOT scheduler will be
         # populated on first build() call.
         self.aot_sliding_window: Optional[tuple[int, int]] = None

-    def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput") -> bool:
+    def reorder_batch(self, input_batch, scheduler_output) -> bool:
         return False

     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
               fast_build: bool = False) -> 'AiterFlashAttentionMetadata':

-        num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len

-        max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
-        total_tokens = int(self.runner.seq_lens_np[:num_reqs].sum())
+        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max().item())
+        total_tokens = int(common_attn_metadata.seq_lens_cpu.sum().item())
         query_start_loc = common_attn_metadata.query_start_loc
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
         seq_lens = common_attn_metadata.seq_lens
-        block_table = self.block_table
-        block_table_tensor = block_table.get_device_tensor()[:num_reqs]
-
-        block_table.slot_mapping[:num_actual_tokens].copy_(
-            block_table.slot_mapping_cpu[:num_actual_tokens],
-            non_blocking=True)
-        # Fill unused with -1. Needed for reshape_and_cache in full cuda graph
-        # mode.
-        block_table.slot_mapping[num_actual_tokens:].fill_(-1)
-
-        slot_mapping = block_table.slot_mapping[:num_actual_tokens]
+        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
+        block_table_tensor = common_attn_metadata.block_table_tensor
+        slot_mapping = common_attn_metadata.slot_mapping

         cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
                                   dtype=torch.int32,
-                                  device="cuda")
+                                  device=self.device)
         torch.cumsum(seq_lens,
                      dim=0,
                      dtype=cu_seq_lens.dtype,
@@ -232,21 +220,21 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,

         # for local attention
         local_attn_metadata = None
-        if self.runner.attention_chunk_size is not None:
+        if self.model_config.attention_chunk_size is not None:
             seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
                 virt_block_table_tensor = make_local_attention_virtual_batches(
-                    self.runner.attention_chunk_size,
-                    self.runner.query_start_loc_np[:num_reqs + 1],
-                    self.runner.seq_lens_np[:num_reqs],
+                    self.model_config.attention_chunk_size,
+                    query_start_loc_cpu.numpy(),
+                    seq_lens_cpu.numpy(),
                     block_table_tensor,
                     self.block_size,
                 )
             local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
-                self.runner.device, non_blocking=True)
+                self.device, non_blocking=True)
             local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
-                self.runner.device, non_blocking=True)
-            local_max_query_len = int(seqlens_q_local_np.max())
-            local_max_seq_len = int(virt_k_seqlens_np.max())
+                self.device, non_blocking=True)
+            local_max_query_len = seqlens_q_local_np.max().item()
+            local_max_seq_len = virt_k_seqlens_np.max().item()
             local_scheduler_metadata = schedule(
                 batch_size=local_query_start_loc.shape[0] - 1,
                 cu_query_lens=local_query_start_loc,
@@ -257,12 +245,11 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,

             local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1,
                                             dtype=torch.int32,
-                                            device=self.runner.device)
+                                            device=self.device)
             local_cu_seq_lens[1:] = torch.cumsum(
-                torch.from_numpy(virt_k_seqlens_np).to(
-                    device=self.runner.device,
-                    dtype=torch.int32,
-                    non_blocking=True),
+                torch.from_numpy(virt_k_seqlens_np).to(device=self.device,
+                                                       dtype=torch.int32,
+                                                       non_blocking=True),
                 dim=0)
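Note: with the runner dependency removed, sequence-length statistics come from `common_attn_metadata.seq_lens_cpu`, and tensors are allocated on the builder's `device` instead of a hard-coded "cuda". A small sketch of the `cu_seq_lens` construction under those assumptions (made-up lengths, device chosen at runtime, not the real builder):

import torch

# Illustrative values; in the builder these come from CommonAttentionMetadata.
seq_lens_cpu = torch.tensor([3, 7, 5], dtype=torch.int32)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

max_seq_len = int(seq_lens_cpu.max().item())
total_tokens = int(seq_lens_cpu.sum().item())

# cu_seq_lens as built in the diff: prefix sum over sequence lengths with a
# leading zero, allocated on the builder's device rather than "cuda".
seq_lens = seq_lens_cpu.to(device)
cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1, dtype=torch.int32, device=device)
torch.cumsum(seq_lens, dim=0, dtype=cu_seq_lens.dtype, out=cu_seq_lens[1:])

print(max_seq_len, total_tokens, cu_seq_lens.tolist())  # 7 15 [0, 3, 10, 15]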
