# SPDX-License-Identifier: Apache-2.0
"""Attention layer with AiterFlashAttention."""
from dataclasses import dataclass
- from typing import TYPE_CHECKING, Any, Optional
+ from typing import Any, Optional

import torch

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionType,
                                              is_quantized_kv_cache)
+ from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import (
    make_local_attention_virtual_batches)
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import AttentionSpec
- from vllm.v1.worker.block_table import BlockTable
-
- if TYPE_CHECKING:
-     from vllm.v1.core.sched.output import SchedulerOutput
-     from vllm.v1.worker.gpu_input_batch import InputBatch
-     from vllm.v1.worker.gpu_model_runner import GPUModelRunner

if current_platform.is_rocm():
    import aiter
@@ -171,56 +166,49 @@ def flash_attn_varlen_func_fake(

class AiterFlashAttentionMetadataBuilder:

-     def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
-                  block_table: BlockTable):
-         model_config = runner.model_config
-
-         self.runner = runner
-         self.num_heads_q = model_config.get_num_attention_heads(
-             runner.parallel_config)
-         self.num_heads_kv = model_config.get_num_kv_heads(
-             runner.parallel_config)
-         self.headdim = model_config.get_head_size()
+     def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
+                  device: torch.device):
+         self.vllm_config = vllm_config
+         self.model_config = vllm_config.model_config
+         self.parallel_config = vllm_config.parallel_config
+         self.cache_config = vllm_config.cache_config
+         self.device = device
+
+         self.num_heads_q = self.model_config.get_num_attention_heads(
+             self.parallel_config)
+         self.num_heads_kv = self.model_config.get_num_kv_heads(
+             self.parallel_config)
+         self.headdim = self.model_config.get_head_size()
        self.block_size = kv_cache_spec.block_size
        self.kv_cache_spec = kv_cache_spec
-         self.block_table = block_table

        # Sliding window size to be used with the AOT scheduler will be
        # populated on first build() call.
        self.aot_sliding_window: Optional[tuple[int, int]] = None

-     def reorder_batch(self, input_batch: "InputBatch",
-                       scheduler_output: "SchedulerOutput") -> bool:
+     def reorder_batch(self, input_batch, scheduler_output) -> bool:
        return False

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> 'AiterFlashAttentionMetadata':

-         num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len

-         max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
-         total_tokens = int(self.runner.seq_lens_np[:num_reqs].sum())
+         max_seq_len = int(common_attn_metadata.seq_lens_cpu.max().item())
+         total_tokens = int(common_attn_metadata.seq_lens_cpu.sum().item())
        query_start_loc = common_attn_metadata.query_start_loc
+         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        seq_lens = common_attn_metadata.seq_lens
-         block_table = self.block_table
-         block_table_tensor = block_table.get_device_tensor()[:num_reqs]
-
-         block_table.slot_mapping[:num_actual_tokens].copy_(
-             block_table.slot_mapping_cpu[:num_actual_tokens],
-             non_blocking=True)
-         # Fill unused with -1. Needed for reshape_and_cache in full cuda graph
-         # mode.
-         block_table.slot_mapping[num_actual_tokens:].fill_(-1)
-
-         slot_mapping = block_table.slot_mapping[:num_actual_tokens]
+         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
+         block_table_tensor = common_attn_metadata.block_table_tensor
+         slot_mapping = common_attn_metadata.slot_mapping

        cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
                                  dtype=torch.int32,
-                                   device="cuda")
+                                   device=self.device)
        torch.cumsum(seq_lens,
                     dim=0,
                     dtype=cu_seq_lens.dtype,
@@ -232,21 +220,21 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,

        # for local attention
        local_attn_metadata = None
-         if self.runner.attention_chunk_size is not None:
+         if self.model_config.attention_chunk_size is not None:
            seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
                virt_block_table_tensor = make_local_attention_virtual_batches(
-                     self.runner.attention_chunk_size,
-                     self.runner.query_start_loc_np[:num_reqs + 1],
-                     self.runner.seq_lens_np[:num_reqs],
+                     self.model_config.attention_chunk_size,
+                     query_start_loc_cpu.numpy(),
+                     seq_lens_cpu.numpy(),
                    block_table_tensor,
                    self.block_size,
                )
            local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
-                 self.runner.device, non_blocking=True)
+                 self.device, non_blocking=True)
            local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
-                 self.runner.device, non_blocking=True)
-             local_max_query_len = int(seqlens_q_local_np.max())
-             local_max_seq_len = int(virt_k_seqlens_np.max())
+                 self.device, non_blocking=True)
+             local_max_query_len = seqlens_q_local_np.max().item()
+             local_max_seq_len = virt_k_seqlens_np.max().item()
            local_scheduler_metadata = schedule(
                batch_size=local_query_start_loc.shape[0] - 1,
                cu_query_lens=local_query_start_loc,
@@ -257,12 +245,11 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,

            local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1,
                                            dtype=torch.int32,
-                                             device=self.runner.device)
+                                             device=self.device)
            local_cu_seq_lens[1:] = torch.cumsum(
-                 torch.from_numpy(virt_k_seqlens_np).to(
-                     device=self.runner.device,
-                     dtype=torch.int32,
-                     non_blocking=True),
+                 torch.from_numpy(virt_k_seqlens_np).to(device=self.device,
+                                                        dtype=torch.int32,
+                                                        non_blocking=True),
                dim=0)
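
For context, a minimal sketch (not part of the diff) of how the refactored builder would now be wired up: it takes a `VllmConfig` and an explicit `torch.device` instead of a `GPUModelRunner`, and `build()` pulls everything it needs (seq_lens_cpu, query_start_loc_cpu, block_table_tensor, slot_mapping) from `CommonAttentionMetadata`. The `make_builder` helper, the concrete device string, and the import path are illustrative assumptions; only the constructor and `build()` signatures come from the diff above.

```python
# Sketch under assumptions: module path and helper name are illustrative.
import torch

from vllm.config import VllmConfig
from vllm.v1.attention.backends.rocm_aiter_fa import (
    AiterFlashAttentionMetadataBuilder)  # path assumed for this backend
from vllm.v1.kv_cache_interface import AttentionSpec


def make_builder(vllm_config: VllmConfig,
                 kv_cache_spec: AttentionSpec):
    # The device is now passed explicitly rather than read off the runner.
    device = torch.device("cuda:0")  # assumed device for illustration
    return AiterFlashAttentionMetadataBuilder(kv_cache_spec=kv_cache_spec,
                                              vllm_config=vllm_config,
                                              device=device)
```

Per-step metadata is then produced purely from the shared metadata object, e.g. `builder.build(common_prefix_len=0, common_attn_metadata=common_attn_metadata)`, which is what lets this change drop the `runner`, `BlockTable`, and slot-mapping bookkeeping from the builder itself.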