# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with AiterFlashAttention."""
from dataclasses import dataclass
- from typing import TYPE_CHECKING, Any, Optional
+ from typing import Any, Optional

import torch

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionType,
                                              is_quantized_kv_cache)
+ from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import (
    make_local_attention_virtual_batches)
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import AttentionSpec
- from vllm.v1.worker.block_table import BlockTable
-
- if TYPE_CHECKING:
-     from vllm.v1.core.sched.output import SchedulerOutput
-     from vllm.v1.worker.gpu_input_batch import InputBatch
-     from vllm.v1.worker.gpu_model_runner import GPUModelRunner

if current_platform.is_rocm():
    import aiter
@@ -172,56 +167,49 @@ def flash_attn_varlen_func_fake(

class AiterFlashAttentionMetadataBuilder:

-     def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
-                  block_table: BlockTable):
-         model_config = runner.model_config
-
-         self.runner = runner
-         self.num_heads_q = model_config.get_num_attention_heads(
-             runner.parallel_config)
-         self.num_heads_kv = model_config.get_num_kv_heads(
-             runner.parallel_config)
-         self.headdim = model_config.get_head_size()
+     def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
+                  device: torch.device):
+         self.vllm_config = vllm_config
+         self.model_config = vllm_config.model_config
+         self.parallel_config = vllm_config.parallel_config
+         self.cache_config = vllm_config.cache_config
+         self.device = device
+
+         self.num_heads_q = self.model_config.get_num_attention_heads(
+             self.parallel_config)
+         self.num_heads_kv = self.model_config.get_num_kv_heads(
+             self.parallel_config)
+         self.headdim = self.model_config.get_head_size()
        self.block_size = kv_cache_spec.block_size
        self.kv_cache_spec = kv_cache_spec
-         self.block_table = block_table

        # Sliding window size to be used with the AOT scheduler will be
        # populated on first build() call.
        self.aot_sliding_window: Optional[tuple[int, int]] = None

-     def reorder_batch(self, input_batch: "InputBatch",
-                       scheduler_output: "SchedulerOutput") -> bool:
+     def reorder_batch(self, input_batch, scheduler_output) -> bool:
        return False
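
For orientation, the per-rank head counts cached in the new constructor come from dividing the model's total head counts across tensor-parallel ranks. A minimal standalone sketch of that arithmetic (plain Python, not vLLM code, assuming the usual divide-by-TP-size behaviour of `get_num_attention_heads()` / `get_num_kv_heads()`; the numbers are illustrative only):

```python
# Illustrative sketch: how many query/KV heads each tensor-parallel rank sees,
# assuming heads are split evenly across ranks (example numbers are made up).
total_q_heads, total_kv_heads, tp_size = 32, 8, 4

num_heads_q = total_q_heads // tp_size            # 8 query heads per rank
num_heads_kv = max(1, total_kv_heads // tp_size)  # 2 KV heads per rank (GQA)

print(num_heads_q, num_heads_kv)  # 8 2
```
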
    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> 'AiterFlashAttentionMetadata':

-         num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len

-         max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
-         total_tokens = int(self.runner.seq_lens_np[:num_reqs].sum())
+         max_seq_len = int(common_attn_metadata.seq_lens_cpu.max().item())
+         total_tokens = int(common_attn_metadata.seq_lens_cpu.sum().item())
        query_start_loc = common_attn_metadata.query_start_loc
+         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        seq_lens = common_attn_metadata.seq_lens
-         block_table = self.block_table
-         block_table_tensor = block_table.get_device_tensor()[:num_reqs]
-
-         block_table.slot_mapping[:num_actual_tokens].copy_(
-             block_table.slot_mapping_cpu[:num_actual_tokens],
-             non_blocking=True)
-         # Fill unused with -1. Needed for reshape_and_cache in full cuda graph
-         # mode.
-         block_table.slot_mapping[num_actual_tokens:].fill_(-1)
-
-         slot_mapping = block_table.slot_mapping[:num_actual_tokens]
+         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
+         block_table_tensor = common_attn_metadata.block_table_tensor
+         slot_mapping = common_attn_metadata.slot_mapping

        cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
                                  dtype=torch.int32,
-                                   device="cuda")
+                                   device=self.device)
        torch.cumsum(seq_lens,
                     dim=0,
                     dtype=cu_seq_lens.dtype,
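
The `cu_seq_lens` tensor built above is the usual prefix sum over per-request KV lengths, now sourced from `common_attn_metadata` instead of the runner. A standalone sketch with made-up lengths (illustrative only, not taken from the commit):

```python
import torch

# Three illustrative requests with KV lengths 3, 5 and 2 tokens.
seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)

max_seq_len = int(seq_lens.max().item())   # 5
total_tokens = int(seq_lens.sum().item())  # 10

# Leading zero followed by the running total of sequence lengths.
cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1, dtype=torch.int32)
cu_seq_lens[1:] = torch.cumsum(seq_lens, dim=0, dtype=cu_seq_lens.dtype)

print(cu_seq_lens.tolist())  # [0, 3, 8, 10]
```
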
@@ -233,21 +221,21 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,

        # for local attention
        local_attn_metadata = None
-         if self.runner.attention_chunk_size is not None:
+         if self.model_config.attention_chunk_size is not None:
            seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
                virt_block_table_tensor = make_local_attention_virtual_batches(
-                     self.runner.attention_chunk_size,
-                     self.runner.query_start_loc_np[:num_reqs + 1],
-                     self.runner.seq_lens_np[:num_reqs],
+                     self.model_config.attention_chunk_size,
+                     query_start_loc_cpu.numpy(),
+                     seq_lens_cpu.numpy(),
                    block_table_tensor,
                    self.block_size,
                )
            local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
-                 self.runner.device, non_blocking=True)
+                 self.device, non_blocking=True)
            local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
-                 self.runner.device, non_blocking=True)
-             local_max_query_len = int(seqlens_q_local_np.max())
-             local_max_seq_len = int(virt_k_seqlens_np.max())
+                 self.device, non_blocking=True)
+             local_max_query_len = seqlens_q_local_np.max().item()
+             local_max_seq_len = virt_k_seqlens_np.max().item()
            local_scheduler_metadata = schedule(
                batch_size=local_query_start_loc.shape[0] - 1,
                cu_query_lens=local_query_start_loc,
@@ -258,12 +246,11 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,

            local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1,
                                            dtype=torch.int32,
-                                             device=self.runner.device)
+                                             device=self.device)
            local_cu_seq_lens[1:] = torch.cumsum(
-                 torch.from_numpy(virt_k_seqlens_np).to(
-                     device=self.runner.device,
-                     dtype=torch.int32,
-                     non_blocking=True),
+                 torch.from_numpy(virt_k_seqlens_np).to(device=self.device,
+                                                        dtype=torch.int32,
+                                                        non_blocking=True),
                dim=0)
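
The local-attention branch above splits each request into chunk-sized "virtual batches" before scheduling. A rough standalone sketch of the underlying chunk partitioning (illustrative only; it does not call `make_local_attention_virtual_batches`, and the lengths are made up):

```python
import numpy as np

# One request whose KV history is 10 tokens long, with chunked local attention
# and attention_chunk_size = 4: keys are partitioned into chunks of at most 4
# tokens, so the per-chunk ("virtual batch") key lengths are [4, 4, 2].
attention_chunk_size = 4
seq_len = 10

num_chunks = -(-seq_len // attention_chunk_size)  # ceil(10 / 4) == 3
virt_k_seqlens = np.full(num_chunks, attention_chunk_size, dtype=np.int32)
virt_k_seqlens[-1] = seq_len - attention_chunk_size * (num_chunks - 1)

# Same leading-zero prefix-sum layout as local_cu_seq_lens in the hunk above.
local_cu_seq_lens = np.concatenate(([0], np.cumsum(virt_k_seqlens)))

print(virt_k_seqlens.tolist())     # [4, 4, 2]
print(local_cu_seq_lens.tolist())  # [0, 4, 8, 10]
```
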