@@ -13,8 +13,6 @@
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.v1.attention.backends.flash_attn import (
-    make_local_attention_virtual_batches)
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import AttentionSpec

@@ -201,9 +199,7 @@ def build(self,
         max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
         total_tokens = int(common_attn_metadata.seq_lens_cpu.sum())
         query_start_loc = common_attn_metadata.query_start_loc
-        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
         seq_lens = common_attn_metadata.seq_lens
-        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
         block_table_tensor = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping

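The fields read off `CommonAttentionMetadata` above follow the usual device/CPU-mirror split: the `*_cpu` tensors are host-side copies used for cheap scalar reductions such as `.max()` and `.sum()` without forcing a device sync, while the device tensors feed the kernels. A minimal mock showing just the fields this hunk touches (the real dataclass in `vllm.v1.attention.backends.utils` carries more):

```python
from dataclasses import dataclass
import torch

@dataclass
class CommonAttentionMetadataSketch:
    # Device tensors consumed directly by the attention kernels.
    query_start_loc: torch.Tensor
    seq_lens: torch.Tensor
    block_table_tensor: torch.Tensor
    slot_mapping: torch.Tensor
    # Host mirrors: scalar reductions stay on the CPU.
    query_start_loc_cpu: torch.Tensor
    seq_lens_cpu: torch.Tensor
```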
@@ -215,56 +211,6 @@ def build(self,
                      dtype=cu_seq_lens.dtype,
                      out=cu_seq_lens[1:])

-        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
-                     max_seq_len, causal):
-            return None
-
-        # for local attention
-        local_attn_metadata = None
-        if self.model_config.attention_chunk_size is not None:
-            seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
-                virt_block_table_tensor = make_local_attention_virtual_batches(
-                    self.model_config.attention_chunk_size,
-                    query_start_loc_cpu.numpy(),
-                    seq_lens_cpu.numpy(),
-                    block_table_tensor,
-                    self.block_size,
-                )
-            local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
-                self.device, non_blocking=True)
-            local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
-                self.device, non_blocking=True)
-            local_max_query_len = seqlens_q_local_np.max().item()
-            local_max_seq_len = virt_k_seqlens_np.max().item()
-            local_scheduler_metadata = schedule(
-                batch_size=local_query_start_loc.shape[0] - 1,
-                cu_query_lens=local_query_start_loc,
-                max_query_len=local_max_query_len,
-                seqlens=local_seqused_k,
-                max_seq_len=local_max_seq_len,
-                causal=True)
-
-            local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1,
-                                            dtype=torch.int32,
-                                            device=self.device)
-            local_cu_seq_lens[1:] = torch.cumsum(
-                torch.from_numpy(virt_k_seqlens_np).to(device=self.device,
-                                                       dtype=torch.int32,
-                                                       non_blocking=True),
-                dim=0)
-
-
-            local_attn_metadata = \
-                AiterFlashAttentionMetadata.LocalAttentionMetadata(
-                    local_query_start_loc=local_query_start_loc,
-                    local_seqused_k=local_seqused_k,
-                    local_block_table=virt_block_table_tensor,
-                    local_max_query_len=local_max_query_len,
-                    local_max_seq_len=local_max_seq_len,
-                    local_cu_seq_lens=local_cu_seq_lens,
-                    local_scheduler_metadata=local_scheduler_metadata,
-                )
-
         use_cascade = common_prefix_len > 0

         cu_prefix_query_lens = None
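For context, the removed `make_local_attention_virtual_batches` call splits every request into chunk-aligned "virtual batches" so that chunked local attention can be expressed as ordinary varlen attention over smaller sequences. Below is a rough, self-contained sketch of that splitting; the simplified signature and the omission of the block-table remapping are assumptions for illustration, not the real helper's API:

```python
import numpy as np

def make_local_virtual_batches(chunk: int, query_start_loc: np.ndarray,
                               seq_lens: np.ndarray):
    """Split each request into chunk-aligned virtual batches for local
    (chunked) attention. Returns per-virtual-batch query and key lengths;
    the real helper also remaps the paged-KV block table."""
    q_lens_out, k_lens_out = [], []
    for i, k_len in enumerate(seq_lens):
        q_len = query_start_loc[i + 1] - query_start_loc[i]
        q_start = k_len - q_len              # absolute index of first query token
        for c in range(q_start // chunk, (k_len - 1) // chunk + 1):
            lo, hi = c * chunk, min((c + 1) * chunk, k_len)
            q_lens_out.append(hi - max(lo, q_start))  # queries inside this chunk
            k_lens_out.append(hi - lo)                # keys visible to them
    return np.asarray(q_lens_out), np.asarray(k_lens_out)

# A decode step at position 9 with chunk=4 yields one virtual batch that
# sees only the 2 cached tokens of its own chunk (positions 8 and 9).
q, k = make_local_virtual_batches(4, np.array([0, 1]), np.array([10]))
assert q.tolist() == [1] and k.tolist() == [2]
```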
@@ -286,7 +232,6 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             cu_prefix_query_lens=cu_prefix_query_lens,
             prefix_kv_lens=prefix_kv_lens,
             suffix_kv_lens=suffix_kv_lens,
-            local_attn_metadata=local_attn_metadata,
         )
         return attn_metadata

@@ -377,19 +322,6 @@ class AiterFlashAttentionMetadata:
     prefix_kv_lens: Optional[torch.Tensor]
     suffix_kv_lens: Optional[torch.Tensor]

-    # for local attention
-    @dataclass
-    class LocalAttentionMetadata:
-        local_query_start_loc: torch.Tensor
-        local_seqused_k: torch.Tensor
-        local_block_table: torch.Tensor
-        local_max_query_len: int
-        local_max_seq_len: int
-        local_cu_seq_lens: torch.Tensor
-        local_scheduler_metadata: Optional[torch.Tensor]
-
-    local_attn_metadata: Optional[LocalAttentionMetadata] = None
-


 class AiterFlashAttentionImpl(AttentionImpl):
@@ -521,25 +453,12 @@ def forward(
                 layer._q_scale)
             query = query.reshape((num_tokens, num_heads, head_size))

-        # Compute attention and update output up to `num_actual_tokens`.
-        use_local_attn = \
-            (self.use_irope and attn_metadata.local_attn_metadata is not None)
-
-        if not attn_metadata.use_cascade or use_local_attn:
-            if use_local_attn:
-                assert attn_metadata.local_attn_metadata is not None
-                local_metadata = attn_metadata.local_attn_metadata
-                cu_seqlens_q = local_metadata.local_query_start_loc
-                seqused_k = local_metadata.local_seqused_k
-                max_seqlen_q = local_metadata.local_max_query_len
-                max_seqlen_k = local_metadata.local_max_seq_len
-                block_table = local_metadata.local_block_table
-            else:
-                cu_seqlens_q = attn_metadata.query_start_loc
-                seqused_k = attn_metadata.seq_lens
-                max_seqlen_q = attn_metadata.max_query_len
-                max_seqlen_k = attn_metadata.max_seq_len
-                block_table = attn_metadata.block_table
+        if not attn_metadata.use_cascade:
+            cu_seqlens_q = attn_metadata.query_start_loc
+            seqused_k = attn_metadata.seq_lens
+            max_seqlen_q = attn_metadata.max_query_len
+            max_seqlen_k = attn_metadata.max_seq_len
+            block_table = attn_metadata.block_table

             if max_seqlen_q > 1:
                 cu_seq_lens = attn_metadata.cu_seq_lens
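For reference, `cu_seqlens_q` and `cu_seq_lens` follow the usual varlen convention: an exclusive prefix sum of per-request lengths, so request `i` occupies the packed-token range `[cu[i], cu[i+1])`. A toy illustration of the layout and of why `max_seqlen_q > 1` selects the prefill path (tensor values here are made up):

```python
import torch

query_lens = torch.tensor([3, 1, 5], dtype=torch.int32)  # tokens per request
cu_seqlens_q = torch.zeros(query_lens.numel() + 1, dtype=torch.int32)
torch.cumsum(query_lens, dim=0, dtype=torch.int32, out=cu_seqlens_q[1:])
# cu_seqlens_q == [0, 3, 4, 9]: request 1's lone decode token is index 3
max_seqlen_q = int(query_lens.max())  # 5 > 1, so the prefill branch runs
```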
@@ -557,9 +476,7 @@ def forward(
                     alibi_slopes=self.alibi_slopes,
                     window_size=self.sliding_window,
                     block_table=block_table,
-                    cu_seqlens_k=(cu_seq_lens if not use_local_attn else
-                                  local_metadata.local_cu_seq_lens),
-                )
+                    cu_seqlens_k=cu_seq_lens)

         _, num_heads, head_size = query.shape
         _PARTITION_SIZE_ROCM = 256
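The trailing context line, `_PARTITION_SIZE_ROCM = 256`, sizes the split-KV decode path: each sequence's KV range is cut into 256-token partitions whose partial results are reduced in a second pass. A hedged sketch of the typical workspace sizing (the helper name and shapes are assumptions following the common paged-attention pattern, not code from this file):

```python
import math

_PARTITION_SIZE_ROCM = 256

def decode_workspace_shape(num_seqs: int, num_heads: int, head_size: int,
                           max_seq_len: int) -> tuple:
    # One partial output per (sequence, head, partition); a second kernel
    # pass reduces the partials into the final attention output.
    max_num_partitions = math.ceil(max_seq_len / _PARTITION_SIZE_ROCM)
    return (num_seqs, num_heads, max_num_partitions, head_size)

# e.g. a 1000-token sequence needs ceil(1000 / 256) = 4 partitions
assert decode_workspace_shape(2, 8, 128, 1000) == (2, 8, 4, 128)
```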