
Commit ea1b17d

LucasWilkinson authored and WorldExplored committed
[Attention] Make local attention backend agnostic (vllm-project#21093)
1 parent 85cac3a commit ea1b17d

8 files changed: +94 -242 lines changed


vllm/v1/attention/backends/flash_attn.py

Lines changed: 10 additions & 74 deletions
@@ -25,9 +25,9 @@
 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.utils import cdiv
-from vllm.v1.attention.backends.utils import (
-    AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
-    make_local_attention_virtual_batches)
+from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+                                              CommonAttentionMetadata,
+                                              get_kv_cache_layout)
 from vllm.v1.kv_cache_interface import AttentionSpec
 
 logger = init_logger(__name__)
@@ -130,18 +130,6 @@ class FlashAttentionMetadata:
     prefix_scheduler_metadata: Optional[torch.Tensor] = None
     max_num_splits: int = 0
 
-    # for local attention
-    @dataclass
-    class LocalAttentionMetadata:
-        local_query_start_loc: torch.Tensor
-        local_seqused_k: torch.Tensor
-        local_block_table: torch.Tensor
-        local_max_query_len: int
-        local_max_seq_len: int
-        local_scheduler_metadata: Optional[torch.Tensor]
-
-    local_attn_metadata: Optional[LocalAttentionMetadata] = None
-
 
 def _get_sliding_window_configs(
         vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
@@ -221,7 +209,6 @@ def build(self,
         max_query_len = common_attn_metadata.max_query_len
         max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
         query_start_loc = common_attn_metadata.query_start_loc
-        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
         seq_lens = common_attn_metadata.seq_lens
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
         block_table_tensor = common_attn_metadata.block_table_tensor
@@ -266,40 +253,6 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                 )
             return None
 
-        # for local attention
-        local_attn_metadata = None
-        if self.model_config.attention_chunk_size is not None:
-            seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
-                virt_block_table_tensor = make_local_attention_virtual_batches(
-                    self.model_config.attention_chunk_size,
-                    query_start_loc_cpu.numpy(),
-                    seq_lens_cpu.numpy(),
-                    block_table_tensor,
-                    self.block_size,
-                )
-            local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
-                self.device, non_blocking=True)
-            local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
-                self.device, non_blocking=True)
-            local_max_query_len = seqlens_q_local_np.max()
-            local_max_seq_len = virt_k_seqlens_np.max()
-            local_scheduler_metadata = schedule(
-                batch_size=local_query_start_loc.shape[0] - 1,
-                cu_query_lens=local_query_start_loc,
-                max_query_len=local_max_query_len,
-                seqlens=local_seqused_k,
-                max_seq_len=local_max_seq_len,
-                causal=True)
-
-            local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
-                local_query_start_loc=local_query_start_loc,
-                local_seqused_k=local_seqused_k,
-                local_block_table=virt_block_table_tensor,
-                local_max_query_len=local_max_query_len,
-                local_max_seq_len=local_max_seq_len,
-                local_scheduler_metadata=local_scheduler_metadata,
-            )
-
         use_cascade = common_prefix_len > 0
 
         if use_cascade:
@@ -371,7 +324,6 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             cu_prefix_query_lens=cu_prefix_query_lens,
             prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
-            local_attn_metadata=local_attn_metadata,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
            max_num_splits=max_num_splits,
        )
@@ -517,27 +469,13 @@ def forward(
                                 layer._q_scale)
             query = query.reshape((num_tokens, num_heads, head_size))
 
-        # Compute attention and update output up to `num_actual_tokens`.
-        use_local_attn = \
-            (self.use_irope and attn_metadata.local_attn_metadata is not None)
-
-        if not attn_metadata.use_cascade or use_local_attn:
-            if use_local_attn:
-                assert attn_metadata.local_attn_metadata is not None
-                local_metadata = attn_metadata.local_attn_metadata
-                cu_seqlens_q = local_metadata.local_query_start_loc
-                seqused_k = local_metadata.local_seqused_k
-                max_seqlen_q = local_metadata.local_max_query_len
-                max_seqlen_k = local_metadata.local_max_seq_len
-                block_table = local_metadata.local_block_table
-                scheduler_metadata = local_metadata.local_scheduler_metadata
-            else:
-                cu_seqlens_q = attn_metadata.query_start_loc
-                seqused_k = attn_metadata.seq_lens
-                max_seqlen_q = attn_metadata.max_query_len
-                max_seqlen_k = attn_metadata.max_seq_len
-                block_table = attn_metadata.block_table
-                scheduler_metadata = attn_metadata.scheduler_metadata
+        if not attn_metadata.use_cascade:
+            cu_seqlens_q = attn_metadata.query_start_loc
+            seqused_k = attn_metadata.seq_lens
+            max_seqlen_q = attn_metadata.max_query_len
+            max_seqlen_k = attn_metadata.max_seq_len
+            block_table = attn_metadata.block_table
+            scheduler_metadata = attn_metadata.scheduler_metadata
 
             descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
 
@@ -565,8 +503,6 @@ def forward(
             )
             return output
 
-        assert not use_local_attn, (
-            "Cascade attention does not support local attention.")
         # Cascade attention (rare case).
         cascade_attention(
             output[:num_actual_tokens],

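Context for the deleted builder code above: when `attention_chunk_size` is set, each request used to be split into per-chunk "virtual batches" by `make_local_attention_virtual_batches`, and the ordinary varlen FlashAttention kernel was run over those batches; this commit removes that FlashAttention-only copy of the logic. Below is a minimal, self-contained sketch of the virtual-batch idea for a single request. It is an illustration only, not vLLM's helper (the real function also remaps the block table and works on whole batches of NumPy arrays).

# Illustration only (not vLLM's make_local_attention_virtual_batches):
# chunked "local" attention lets token i attend only to keys in its own
# chunk [i // chunk * chunk, i // chunk * chunk + chunk), so one request can
# be rewritten as several independent causal virtual batches that any
# regular varlen attention kernel can execute.
def split_into_virtual_batches(q_len: int, k_len: int, chunk: int):
    """Per-virtual-batch (num_queries, num_keys) pairs for one request."""
    batches = []
    pos = k_len - q_len  # absolute position of the first query token
    while pos < k_len:
        chunk_start = (pos // chunk) * chunk
        chunk_end = min(chunk_start + chunk, k_len)
        q_in_chunk = chunk_end - pos                   # queries falling in this chunk
        k_in_chunk = (pos + q_in_chunk) - chunk_start  # keys visible to them (causal)
        batches.append((q_in_chunk, k_in_chunk))
        pos += q_in_chunk
    return batches

# A 10-token prefill with chunk size 4 becomes three virtual batches:
print(split_into_virtual_batches(q_len=10, k_len=10, chunk=4))  # [(4, 4), (4, 4), (2, 2)]
# A single decode step only sees the keys of its current chunk:
print(split_into_virtual_batches(q_len=1, k_len=10, chunk=4))   # [(1, 2)]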
vllm/v1/attention/backends/flashinfer.py

Lines changed: 1 addition & 4 deletions
@@ -496,10 +496,6 @@ def __init__(
         kv_sharing_target_layer_name: Optional[int] = None,
         use_irope: bool = False,
     ) -> None:
-        if use_irope:
-            logger.warning_once(
-                "Using irope in FlashInfer is not supported yet, it will fall"
-                " back to global attention for long context.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
@@ -514,6 +510,7 @@ def __init__(
         self.kv_cache_dtype = kv_cache_dtype
         self.logits_soft_cap = logits_soft_cap
         self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
+        self.use_irope = use_irope
 
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 

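With the FlashInfer warning removed and `use_irope` simply recorded on the impl, the decision of whether a layer runs as chunked (local) attention no longer needs to live inside each backend. A hypothetical sketch of that per-layer check, mirroring the condition the deleted FlashAttention code used (`self.use_irope` combined with `attention_chunk_size` being set); the `LayerInfo` type and `needs_local_attention` helper below are illustrative, not part of this PR:

# Hypothetical sketch (not this PR's actual wiring): once every impl records
# use_irope, a shared layer can decide per layer whether to apply the
# local-attention transform before any backend builds its metadata, instead
# of each backend warning about it or special-casing it.
from dataclasses import dataclass
from typing import Optional

@dataclass
class LayerInfo:
    use_irope: bool                      # recorded by the backend impl, as above
    attention_chunk_size: Optional[int]  # from the model config

def needs_local_attention(layer: LayerInfo) -> bool:
    """True if this layer should run as chunked (local) attention."""
    return layer.use_irope and layer.attention_chunk_size is not None

print(needs_local_attention(LayerInfo(use_irope=True, attention_chunk_size=8192)))   # True
print(needs_local_attention(LayerInfo(use_irope=False, attention_chunk_size=8192)))  # False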
vllm/v1/attention/backends/rocm_aiter_fa.py

Lines changed: 7 additions & 90 deletions
@@ -13,8 +13,6 @@
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.v1.attention.backends.flash_attn import (
-    make_local_attention_virtual_batches)
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import AttentionSpec
 
@@ -201,9 +199,7 @@ def build(self,
         max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
         total_tokens = int(common_attn_metadata.seq_lens_cpu.sum())
         query_start_loc = common_attn_metadata.query_start_loc
-        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
         seq_lens = common_attn_metadata.seq_lens
-        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
         block_table_tensor = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping
 
@@ -215,56 +211,6 @@ def build(self,
                      dtype=cu_seq_lens.dtype,
                      out=cu_seq_lens[1:])
 
-        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
-                     max_seq_len, causal):
-            return None
-
-        # for local attention
-        local_attn_metadata = None
-        if self.model_config.attention_chunk_size is not None:
-            seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
-                virt_block_table_tensor = make_local_attention_virtual_batches(
-                    self.model_config.attention_chunk_size,
-                    query_start_loc_cpu.numpy(),
-                    seq_lens_cpu.numpy(),
-                    block_table_tensor,
-                    self.block_size,
-                )
-            local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
-                self.device, non_blocking=True)
-            local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
-                self.device, non_blocking=True)
-            local_max_query_len = seqlens_q_local_np.max().item()
-            local_max_seq_len = virt_k_seqlens_np.max().item()
-            local_scheduler_metadata = schedule(
-                batch_size=local_query_start_loc.shape[0] - 1,
-                cu_query_lens=local_query_start_loc,
-                max_query_len=local_max_query_len,
-                seqlens=local_seqused_k,
-                max_seq_len=local_max_seq_len,
-                causal=True)
-
-            local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1,
-                                            dtype=torch.int32,
-                                            device=self.device)
-            local_cu_seq_lens[1:] = torch.cumsum(
-                torch.from_numpy(virt_k_seqlens_np).to(device=self.device,
-                                                       dtype=torch.int32,
-                                                       non_blocking=True),
-                dim=0)
-
-
-            local_attn_metadata = \
-                AiterFlashAttentionMetadata.LocalAttentionMetadata(
-                    local_query_start_loc=local_query_start_loc,
-                    local_seqused_k=local_seqused_k,
-                    local_block_table=virt_block_table_tensor,
-                    local_max_query_len=local_max_query_len,
-                    local_max_seq_len=local_max_seq_len,
-                    local_cu_seq_lens=local_cu_seq_lens,
-                    local_scheduler_metadata=local_scheduler_metadata,
-                )
-
         use_cascade = common_prefix_len > 0
 
         cu_prefix_query_lens = None
@@ -286,7 +232,6 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             cu_prefix_query_lens=cu_prefix_query_lens,
             prefix_kv_lens=prefix_kv_lens,
             suffix_kv_lens=suffix_kv_lens,
-            local_attn_metadata=local_attn_metadata,
         )
         return attn_metadata
 
@@ -377,19 +322,6 @@ class AiterFlashAttentionMetadata:
     prefix_kv_lens: Optional[torch.Tensor]
     suffix_kv_lens: Optional[torch.Tensor]
 
-    # for local attention
-    @dataclass
-    class LocalAttentionMetadata:
-        local_query_start_loc: torch.Tensor
-        local_seqused_k: torch.Tensor
-        local_block_table: torch.Tensor
-        local_max_query_len: int
-        local_max_seq_len: int
-        local_cu_seq_lens: torch.Tensor
-        local_scheduler_metadata: Optional[torch.Tensor]
-
-    local_attn_metadata: Optional[LocalAttentionMetadata] = None
-
 
 class AiterFlashAttentionImpl(AttentionImpl):
 
@@ -521,25 +453,12 @@ def forward(
                                 layer._q_scale)
             query = query.reshape((num_tokens, num_heads, head_size))
 
-        # Compute attention and update output up to `num_actual_tokens`.
-        use_local_attn = \
-            (self.use_irope and attn_metadata.local_attn_metadata is not None)
-
-        if not attn_metadata.use_cascade or use_local_attn:
-            if use_local_attn:
-                assert attn_metadata.local_attn_metadata is not None
-                local_metadata = attn_metadata.local_attn_metadata
-                cu_seqlens_q = local_metadata.local_query_start_loc
-                seqused_k = local_metadata.local_seqused_k
-                max_seqlen_q = local_metadata.local_max_query_len
-                max_seqlen_k = local_metadata.local_max_seq_len
-                block_table = local_metadata.local_block_table
-            else:
-                cu_seqlens_q = attn_metadata.query_start_loc
-                seqused_k = attn_metadata.seq_lens
-                max_seqlen_q = attn_metadata.max_query_len
-                max_seqlen_k = attn_metadata.max_seq_len
-                block_table = attn_metadata.block_table
+        if not attn_metadata.use_cascade:
+            cu_seqlens_q = attn_metadata.query_start_loc
+            seqused_k = attn_metadata.seq_lens
+            max_seqlen_q = attn_metadata.max_query_len
+            max_seqlen_k = attn_metadata.max_seq_len
+            block_table = attn_metadata.block_table
 
             if max_seqlen_q > 1:
                 cu_seq_lens = attn_metadata.cu_seq_lens
@@ -557,9 +476,7 @@ def forward(
                     alibi_slopes=self.alibi_slopes,
                     window_size=self.sliding_window,
                     block_table=block_table,
-                    cu_seqlens_k=(cu_seq_lens if not use_local_attn else
-                                  local_metadata.local_cu_seq_lens),
-                )
+                    cu_seqlens_k=cu_seq_lens)
 
         _, num_heads, head_size = query.shape
         _PARTITION_SIZE_ROCM = 256

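One detail specific to the deleted AITER path above: the kernel also wanted cumulative key lengths (`cu_seqlens_k`) across the virtual batches, built as a prefix sum with a leading zero. A standalone rendering of that step, kept on CPU for illustration (the original moved the tensors to `self.device` with `non_blocking=True`):

import numpy as np
import torch

# Per-virtual-batch key lengths, e.g. as produced by the virtual-batch split.
virt_k_seqlens_np = np.array([4, 4, 2], dtype=np.int32)

# Prefix sum with a leading zero, matching the deleted local_cu_seq_lens logic.
local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1, dtype=torch.int32)
local_cu_seq_lens[1:] = torch.cumsum(
    torch.from_numpy(virt_k_seqlens_np).to(dtype=torch.int32), dim=0)

print(local_cu_seq_lens)  # tensor([ 0,  4,  8, 10], dtype=torch.int32)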