[Attention] Make local attention backend agnostic #21093

Merged
84 changes: 10 additions & 74 deletions vllm/v1/attention/backends/flash_attn.py
@@ -25,9 +25,9 @@
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
make_local_attention_virtual_batches)
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata,
get_kv_cache_layout)
from vllm.v1.kv_cache_interface import AttentionSpec

logger = init_logger(__name__)
@@ -130,18 +130,6 @@ class FlashAttentionMetadata:
prefix_scheduler_metadata: Optional[torch.Tensor] = None
max_num_splits: int = 0

# for local attention
@dataclass
class LocalAttentionMetadata:
local_query_start_loc: torch.Tensor
local_seqused_k: torch.Tensor
local_block_table: torch.Tensor
local_max_query_len: int
local_max_seq_len: int
local_scheduler_metadata: Optional[torch.Tensor]

local_attn_metadata: Optional[LocalAttentionMetadata] = None


def _get_sliding_window_configs(
vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
@@ -221,7 +209,6 @@ def build(self,
max_query_len = common_attn_metadata.max_query_len
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
block_table_tensor = common_attn_metadata.block_table_tensor
@@ -266,40 +253,6 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
)
return None

# for local attention
local_attn_metadata = None
if self.model_config.attention_chunk_size is not None:
seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
virt_block_table_tensor = make_local_attention_virtual_batches(
self.model_config.attention_chunk_size,
query_start_loc_cpu.numpy(),
seq_lens_cpu.numpy(),
block_table_tensor,
self.block_size,
)
local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
self.device, non_blocking=True)
local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
self.device, non_blocking=True)
local_max_query_len = seqlens_q_local_np.max()
local_max_seq_len = virt_k_seqlens_np.max()
local_scheduler_metadata = schedule(
batch_size=local_query_start_loc.shape[0] - 1,
cu_query_lens=local_query_start_loc,
max_query_len=local_max_query_len,
seqlens=local_seqused_k,
max_seq_len=local_max_seq_len,
causal=True)

local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
local_query_start_loc=local_query_start_loc,
local_seqused_k=local_seqused_k,
local_block_table=virt_block_table_tensor,
local_max_query_len=local_max_query_len,
local_max_seq_len=local_max_seq_len,
local_scheduler_metadata=local_scheduler_metadata,
)

use_cascade = common_prefix_len > 0

if use_cascade:
@@ -371,7 +324,6 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
cu_prefix_query_lens=cu_prefix_query_lens,
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
local_attn_metadata=local_attn_metadata,
prefix_scheduler_metadata=prefix_scheduler_metadata,
max_num_splits=max_num_splits,
)
@@ -517,27 +469,13 @@ def forward(
layer._q_scale)
query = query.reshape((num_tokens, num_heads, head_size))

# Compute attention and update output up to `num_actual_tokens`.
use_local_attn = \
(self.use_irope and attn_metadata.local_attn_metadata is not None)

if not attn_metadata.use_cascade or use_local_attn:
if use_local_attn:
assert attn_metadata.local_attn_metadata is not None
local_metadata = attn_metadata.local_attn_metadata
cu_seqlens_q = local_metadata.local_query_start_loc
seqused_k = local_metadata.local_seqused_k
max_seqlen_q = local_metadata.local_max_query_len
max_seqlen_k = local_metadata.local_max_seq_len
block_table = local_metadata.local_block_table
scheduler_metadata = local_metadata.local_scheduler_metadata
else:
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
scheduler_metadata = attn_metadata.scheduler_metadata
if not attn_metadata.use_cascade:
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
scheduler_metadata = attn_metadata.scheduler_metadata

descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

@@ -565,8 +503,6 @@ def forward(
)
return output

assert not use_local_attn, (
"Cascade attention does not support local attention.")
# Cascade attention (rare case).
cascade_attention(
output[:num_actual_tokens],
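The block deleted above duplicated, inside the FlashAttention backend, the chunking that `make_local_attention_virtual_batches` performs: each request's query range is split into chunk-aligned "virtual batches" so a token only attends to keys inside its own chunk, and the virtual per-batch lengths are what became `local_query_start_loc`, `local_seqused_k`, and the remapped block table. The sketch below is a simplified, single-request illustration of that idea only; it is not the vLLM helper, and `chunk_size`, `seq_len`, and `query_len` are assumed inputs.

```python
# Simplified single-request sketch of chunked ("local") attention virtual
# batching -- NOT vLLM's make_local_attention_virtual_batches helper.
# A token at absolute position p only attends to keys in
# [p // chunk_size * chunk_size, p], so the query range is split into one
# virtual batch per chunk it touches.
import numpy as np

def local_virtual_batches(chunk_size: int, seq_len: int, query_len: int):
    """Return (q_lens_local, k_lens_local) for one request whose query
    tokens sit at absolute positions [seq_len - query_len, seq_len)."""
    q_start = seq_len - query_len
    first_chunk = q_start // chunk_size
    last_chunk = (seq_len - 1) // chunk_size
    q_lens, k_lens = [], []
    for c in range(first_chunk, last_chunk + 1):
        chunk_lo = c * chunk_size
        chunk_hi = min((c + 1) * chunk_size, seq_len)     # exclusive
        q_lens.append(chunk_hi - max(q_start, chunk_lo))  # queries in chunk c
        k_lens.append(chunk_hi - chunk_lo)                # keys visible in chunk c
    return np.array(q_lens, dtype=np.int32), np.array(k_lens, dtype=np.int32)

# chunk_size=4, seq_len=10, query_len=7 -> q_lens [1, 4, 2], k_lens [4, 4, 2]
print(local_virtual_batches(4, 10, 7))
```

With this PR the construction no longer lives in each backend; per the PR title, it is handled once in a backend-agnostic layer rather than being rebuilt by every metadata builder.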
5 changes: 1 addition & 4 deletions vllm/v1/attention/backends/flashinfer.py
@@ -495,10 +495,6 @@ def __init__(
kv_sharing_target_layer_name: Optional[int] = None,
use_irope: bool = False,
) -> None:
if use_irope:
logger.warning_once(
"Using irope in FlashInfer is not supported yet, it will fall"
" back to global attention for long context.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
@@ -513,6 +509,7 @@ def __init__(
self.kv_cache_dtype = kv_cache_dtype
self.logits_soft_cap = logits_soft_cap
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
self.use_irope = use_irope

self.num_queries_per_kv = self.num_heads // self.num_kv_heads

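The FlashInfer change is the mirror image: instead of warning that iRoPE is unsupported and silently falling back to global attention, the constructor now just records the flag, with local-attention handling left to the shared layer. A minimal sketch of the resulting constructor pattern (hypothetical class and parameter set, not the vLLM code):

```python
# Hypothetical sketch: after this change, a backend impl only stores
# use_irope at construction time; it neither warns nor falls back, because
# chunked local attention is applied outside the backend.
class ExampleAttentionImpl:
    def __init__(self, num_heads: int, num_kv_heads: int,
                 use_irope: bool = False) -> None:
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.use_irope = use_irope  # kept for later use, no backend-local fallback
        self.num_queries_per_kv = num_heads // num_kv_heads
```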
97 changes: 7 additions & 90 deletions vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -13,8 +13,6 @@
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import (
make_local_attention_virtual_batches)
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import AttentionSpec

@@ -201,9 +199,7 @@ def build(self,
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
total_tokens = int(common_attn_metadata.seq_lens_cpu.sum())
query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping

@@ -215,56 +211,6 @@
dtype=cu_seq_lens.dtype,
out=cu_seq_lens[1:])

def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
max_seq_len, causal):
return None

# for local attention
local_attn_metadata = None
if self.model_config.attention_chunk_size is not None:
seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
virt_block_table_tensor = make_local_attention_virtual_batches(
self.model_config.attention_chunk_size,
query_start_loc_cpu.numpy(),
seq_lens_cpu.numpy(),
block_table_tensor,
self.block_size,
)
local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
self.device, non_blocking=True)
local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
self.device, non_blocking=True)
local_max_query_len = seqlens_q_local_np.max().item()
local_max_seq_len = virt_k_seqlens_np.max().item()
local_scheduler_metadata = schedule(
batch_size=local_query_start_loc.shape[0] - 1,
cu_query_lens=local_query_start_loc,
max_query_len=local_max_query_len,
seqlens=local_seqused_k,
max_seq_len=local_max_seq_len,
causal=True)

local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1,
dtype=torch.int32,
device=self.device)
local_cu_seq_lens[1:] = torch.cumsum(
torch.from_numpy(virt_k_seqlens_np).to(device=self.device,
dtype=torch.int32,
non_blocking=True),
dim=0)


local_attn_metadata = \
AiterFlashAttentionMetadata.LocalAttentionMetadata(
local_query_start_loc=local_query_start_loc,
local_seqused_k=local_seqused_k,
local_block_table=virt_block_table_tensor,
local_max_query_len=local_max_query_len,
local_max_seq_len=local_max_seq_len,
local_cu_seq_lens=local_cu_seq_lens,
local_scheduler_metadata=local_scheduler_metadata,
)

use_cascade = common_prefix_len > 0

cu_prefix_query_lens = None
@@ -286,7 +232,6 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
cu_prefix_query_lens=cu_prefix_query_lens,
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
local_attn_metadata=local_attn_metadata,
)
return attn_metadata

@@ -377,19 +322,6 @@ class AiterFlashAttentionMetadata:
prefix_kv_lens: Optional[torch.Tensor]
suffix_kv_lens: Optional[torch.Tensor]

# for local attention
@dataclass
class LocalAttentionMetadata:
local_query_start_loc: torch.Tensor
local_seqused_k: torch.Tensor
local_block_table: torch.Tensor
local_max_query_len: int
local_max_seq_len: int
local_cu_seq_lens: torch.Tensor
local_scheduler_metadata: Optional[torch.Tensor]

local_attn_metadata: Optional[LocalAttentionMetadata] = None


class AiterFlashAttentionImpl(AttentionImpl):

@@ -521,25 +453,12 @@ def forward(
layer._q_scale)
query = query.reshape((num_tokens, num_heads, head_size))

# Compute attention and update output up to `num_actual_tokens`.
use_local_attn = \
(self.use_irope and attn_metadata.local_attn_metadata is not None)

if not attn_metadata.use_cascade or use_local_attn:
if use_local_attn:
assert attn_metadata.local_attn_metadata is not None
local_metadata = attn_metadata.local_attn_metadata
cu_seqlens_q = local_metadata.local_query_start_loc
seqused_k = local_metadata.local_seqused_k
max_seqlen_q = local_metadata.local_max_query_len
max_seqlen_k = local_metadata.local_max_seq_len
block_table = local_metadata.local_block_table
else:
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
if not attn_metadata.use_cascade:
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table

if max_seqlen_q > 1:
cu_seq_lens = attn_metadata.cu_seq_lens
@@ -557,9 +476,7 @@ def forward(
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
cu_seqlens_k=(cu_seq_lens if not use_local_attn else
local_metadata.local_cu_seq_lens),
)
cu_seqlens_k=cu_seq_lens)

_, num_heads, head_size = query.shape
_PARTITION_SIZE_ROCM = 256
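The AITER builder keeps the cumulative-sequence-length bookkeeping shown in the retained hunk above: a zero-initialized tensor of length batch + 1 whose tail receives the running sum of per-request key counts, written in place via `out=`. A self-contained illustration with made-up values:

```python
# Illustration of the cu_seq_lens pattern used in the builder above:
# cumulative key counts with a leading zero, written into a slice via out=.
import torch

seq_lens = torch.tensor([5, 3, 8], dtype=torch.int32)      # per-request key counts
cu_seq_lens = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32)
torch.cumsum(seq_lens, dim=0, dtype=cu_seq_lens.dtype, out=cu_seq_lens[1:])
print(cu_seq_lens)  # tensor([ 0,  5,  8, 16], dtype=torch.int32)
```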