[v1][attention] Support Hybrid Allocator + FlashInfer #21412

Open
wants to merge 8 commits into base: main

2 changes: 1 addition & 1 deletion examples/offline_inference/prithvi_geospatial_mae.py
@@ -3,12 +3,12 @@
import argparse
import datetime
import os
import re
from typing import Union

import albumentations
import numpy as np
import rasterio
import regex as re
import torch
from einops import rearrange
from terratorch.datamodules import Sen1Floods11NonGeoDataModule
2 changes: 1 addition & 1 deletion setup.py
@@ -6,12 +6,12 @@
import json
import logging
import os
import re
import subprocess
import sys
from pathlib import Path
from shutil import which

import regex as re
import torch
from packaging.version import Version, parse
from setuptools import Extension, setup
19 changes: 11 additions & 8 deletions tests/v1/attention/test_attention_backends.py
@@ -198,7 +198,8 @@ def __init__(self, device: torch.device):


def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
vllm_config, device: torch.device,
layer_names: list[str], vllm_config,
device: torch.device,
common_attn_metadata: CommonAttentionMetadata,
query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor,
@@ -211,31 +212,33 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
if backend == _Backend.FLASHINFER_VLLM_V1:
import unittest.mock

from vllm.v1.attention.backends.flashinfer import PerLayerParameters
from vllm.v1.attention.backends.utils import PerLayerParameters

def mock_get_per_layer_parameters(vllm_config, impl_cls):
def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls):
# Return mock parameters for a single layer
head_size = vllm_config.model_config.get_head_size()
return {
"mock_layer":
layer_name:
PerLayerParameters(
window_left=-1, # No sliding window
logits_soft_cap=0.0, # No soft cap
sm_scale=1.0 / (head_size**0.5) # Standard scale
)
for layer_name in layer_names
}

with unittest.mock.patch(
'vllm.v1.attention.backends.flashinfer.get_per_layer_parameters',
mock_get_per_layer_parameters):
builder = builder_cls(kv_cache_spec, vllm_config, device)
builder = builder_cls(kv_cache_spec, layer_names, vllm_config,
device)
attn_metadata = builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
else:
# Build metadata
builder = builder_cls(kv_cache_spec, vllm_config, device)
builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
attn_metadata = builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
@@ -427,8 +430,8 @@ def test_backend_correctness(batch_spec_name: str, model: str):
set_kv_cache_layout("HND")

backend_output = run_attention_backend(backend_name, kv_cache_spec,
vllm_config, device,
common_attn_metadata,
["placeholder"], vllm_config,
device, common_attn_metadata,
query_vllm, key_vllm,
value_vllm,
kv_cache_for_backend)
3 changes: 1 addition & 2 deletions tests/v1/test_utils.py
@@ -1,9 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import re

import pytest
import regex as re
import requests
import torch

32 changes: 24 additions & 8 deletions vllm/config.py
@@ -624,8 +624,8 @@ def __post_init__(self) -> None:
isinstance(sliding_window, list))

if not self.disable_sliding_window and has_interleaved_attention:
if (backend :=
envs.VLLM_ATTENTION_BACKEND) in ("XFORMERS", "FLASHINFER"):
if not envs.VLLM_USE_V1 and (backend := envs.VLLM_ATTENTION_BACKEND
) in ("XFORMERS", "FLASHINFER"):
sliding_window_len_min = get_min_sliding_window(
self.hf_text_config.sliding_window)

@@ -4922,13 +4922,29 @@ def assert_hashable(text):
T = TypeVar("T")


def get_layers_from_vllm_config(vllm_config: VllmConfig,
layer_type: type[T]) -> dict[str, T]:
def get_layers_from_vllm_config(
vllm_config: VllmConfig,
layer_type: type[T],
layer_names: Optional[list[str]] = None) -> dict[str, T]:
"""
Get layers from the vLLM config.

Args:
vllm_config: The vLLM config.
layer_type: The type of the layer to get.
layer_names: The names of the layers to get. If None, return all layers.
"""

if layer_names is None:
layer_names = list(
vllm_config.compilation_config.static_forward_context.keys())

forward_context = vllm_config.compilation_config.static_forward_context

return {
layer_name: layer
for layer_name, layer in
vllm_config.compilation_config.static_forward_context.items()
if isinstance(layer, layer_type)
layer_name: forward_context[layer_name]
for layer_name in layer_names
if isinstance(forward_context[layer_name], layer_type)
}
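
Usage note: the optional layer_names argument lets a caller restrict the lookup to the layers owned by a single KV-cache group, which is what the hybrid allocator needs. Below is a minimal sketch of the same filtering logic, assuming stand-in classes and hypothetical layer names rather than the real vLLM objects:

    # Minimal, self-contained sketch of the layer_names filtering; the classes
    # and layer names are stand-ins, not real vLLM objects.
    from typing import Optional, TypeVar

    T = TypeVar("T")


    class DummyAttention:
        pass


    class DummyMamba:
        pass


    def get_layers(forward_context: dict[str, object],
                   layer_type: type[T],
                   layer_names: Optional[list[str]] = None) -> dict[str, T]:
        """Return layers of `layer_type`, optionally restricted to `layer_names`."""
        if layer_names is None:
            layer_names = list(forward_context.keys())
        return {
            name: forward_context[name]
            for name in layer_names
            if isinstance(forward_context[name], layer_type)
        }


    ctx = {
        "model.layers.0.attn": DummyAttention(),
        "model.layers.1.mamba": DummyMamba(),
        "model.layers.2.attn": DummyAttention(),
    }
    # A hybrid-allocator KV-cache group only asks for the layers it owns.
    print(get_layers(ctx, DummyAttention,
                     ["model.layers.0.attn", "model.layers.2.attn"]))
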


4 changes: 2 additions & 2 deletions vllm/v1/attention/backends/cpu_attn.py
@@ -315,8 +315,8 @@ def get_seq_len_block_table_args(

class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device) -> None:
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device) -> None:
self.kv_cache_spec = kv_cache_spec
self.vllm_config = vllm_config
self.scheduler_config = vllm_config.scheduler_config
4 changes: 2 additions & 2 deletions vllm/v1/attention/backends/flash_attn.py
@@ -146,8 +146,8 @@ class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]):
full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.parallel_config = vllm_config.parallel_config
18 changes: 7 additions & 11 deletions vllm/v1/attention/backends/flashinfer.py
@@ -21,10 +21,9 @@
from vllm.utils import cdiv
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters,
get_kv_cache_layout, get_per_layer_parameters,
infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills,
split_decodes_and_prefills)
AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
get_per_layer_parameters, infer_global_hyperparameters,
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec

if TYPE_CHECKING:
@@ -220,16 +219,17 @@ def __post_init__(self):

class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
Review comment (Collaborator):

might play nicer with: #21588

if we do:

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):

instead

Reply (Author):

Thanks. I've fixed it.

self.device = device
self._workspace_buffer = None
self._prefill_wrapper = None # Wrapper for prefill/append
self._decode_wrapper = None # Wrapper for decode
self._cascade_wrapper = None # Wrapper for cascade attention

# Global hyperparameters shared by all attention layers
self.global_hyperparameters: Optional[PerLayerParameters] = None
self.global_hyperparameters = infer_global_hyperparameters(
get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl))

self.vllm_config = vllm_config
self.cache_config = vllm_config.cache_config
@@ -284,10 +284,6 @@ def _get_cascade_wrapper(self):

def _plan(self, num_prefills: int, num_decodes: int,
attn_metadata: FlashInferMetadata):
if self.global_hyperparameters is None:
self.global_hyperparameters = infer_global_hyperparameters(
get_per_layer_parameters(self.vllm_config, FlashInferImpl))

if attn_metadata.use_cascade:
attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
attn_metadata.cascade_wrapper.plan(
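
Note on the flashinfer.py change above: the global FlashInfer hyperparameters are now resolved once in __init__, scoped to the builder's layer_names, so _plan no longer needs a lazy fallback. A self-contained sketch of that pattern, with stub helpers and hypothetical layer names standing in for the real implementation:

    # Stub sketch of eager, per-group hyperparameter resolution at construction
    # time; none of these names are the real vLLM classes or helpers.
    from dataclasses import dataclass


    @dataclass
    class PerLayerParameters:
        window_left: int
        logits_soft_cap: float
        sm_scale: float


    def get_per_layer_parameters(layer_names):
        # Stub: vLLM derives these from the Attention layers in the group.
        return {name: PerLayerParameters(-1, 0.0, 0.125) for name in layer_names}


    def infer_global_hyperparameters(per_layer):
        params = list(per_layer.values())
        assert all(p == params[0] for p in params)
        return params[0]


    class BuilderSketch:
        def __init__(self, layer_names):
            # Resolved once here; plan() can rely on it being non-None.
            self.global_hyperparameters = infer_global_hyperparameters(
                get_per_layer_parameters(layer_names))

        def plan(self):
            return self.global_hyperparameters


    print(BuilderSketch(["model.layers.0.attn", "model.layers.1.attn"]).plan())
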
4 changes: 2 additions & 2 deletions vllm/v1/attention/backends/flex_attention.py
@@ -258,8 +258,8 @@ def __post_init__(self):
class FlexAttentionMetadataBuilder(
AttentionMetadataBuilder[FlexAttentionMetadata]):

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
self.model_config = vllm_config.model_config
self.parallel_config = vllm_config.parallel_config
self.cache_config = vllm_config.cache_config
4 changes: 2 additions & 2 deletions vllm/v1/attention/backends/mamba_attn.py
@@ -87,8 +87,8 @@ class Mamba2AttentionMetadata:
class Mamba2AttentionMetadataBuilder(
AttentionMetadataBuilder[Mamba2AttentionMetadata]):

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
assert isinstance(kv_cache_spec, MambaSpec)
self.kv_cache_spec = kv_cache_spec
self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
4 changes: 3 additions & 1 deletion vllm/v1/attention/backends/mla/common.py
@@ -406,6 +406,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):

def __init__(self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
metadata_cls: Optional[type[M]] = None):
@@ -471,7 +472,8 @@ def __init__(self,
BatchPrefillWithRaggedKVCacheWrapper] = []

self._global_hyperparameters = infer_global_hyperparameters(
get_per_layer_parameters(vllm_config, MLACommonImpl))
get_per_layer_parameters(vllm_config, layer_names,
MLACommonImpl))

if self._use_cudnn_prefill:
self.cudnn_workspace = torch.empty(
7 changes: 4 additions & 3 deletions vllm/v1/attention/backends/mla/flashmla.py
@@ -56,9 +56,10 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
full_cudagraph_supported: ClassVar[bool] = True # Decode-only

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
super().__init__(kv_cache_spec, vllm_config, device, FlashMLAMetadata)
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
super().__init__(kv_cache_spec, layer_names, vllm_config, device,
FlashMLAMetadata)

self.compilation_config = vllm_config.compilation_config
self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
7 changes: 4 additions & 3 deletions vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -66,9 +66,10 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
full_cudagraph_supported: ClassVar[bool] = True # decode only

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
super().__init__(kv_cache_spec, vllm_config, device, AiterMLAMetadata)
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
super().__init__(kv_cache_spec, layer_names, vllm_config, device,
AiterMLAMetadata)
assert self.kv_cache_spec.block_size == 1, "AITER MLA" \
"only supports block size 1."

4 changes: 2 additions & 2 deletions vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -231,8 +231,8 @@ class AiterFlashAttentionMetadataBuilder(
AttentionMetadataBuilder[AiterFlashAttentionMetadata]):
full_cudagraph_supported: ClassVar[bool] = True

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.parallel_config = vllm_config.parallel_config
4 changes: 2 additions & 2 deletions vllm/v1/attention/backends/triton_attn.py
@@ -59,8 +59,8 @@ class TritonAttentionMetadataBuilder(
AttentionMetadataBuilder[TritonAttentionMetadata]):
full_cudagraph_supported: ClassVar[bool] = True

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
self.device = device
self.block_size = kv_cache_spec.block_size
self.kv_cache_spec = kv_cache_spec
14 changes: 9 additions & 5 deletions vllm/v1/attention/backends/utils.py
@@ -68,8 +68,8 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
full_cudagraph_supported: ClassVar[bool] = False

@abstractmethod
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
self.kv_cache_spec = kv_cache_spec

@abstractmethod
@@ -162,14 +162,14 @@ class PerLayerParameters:


def get_per_layer_parameters(
vllm_config: VllmConfig,
vllm_config: VllmConfig, layer_names: list[str],
cls_: type['AttentionImpl']) -> dict[str, PerLayerParameters]:
"""
Scan all attention layers and determine some hyperparameters
Scan layers in `layer_names` and determine some hyperparameters
to use during `plan`.
"""

layers = get_layers_from_vllm_config(vllm_config, Attention)
layers = get_layers_from_vllm_config(vllm_config, Attention, layer_names)
per_layer_params: dict[str, PerLayerParameters] = {}

for key, layer in layers.items():
@@ -206,6 +206,10 @@ def infer_global_hyperparameters(
param_sets = list(per_layer_params.values())
global_params = param_sets[0]
for params in param_sets:
if params.window_left != global_params.window_left:
raise ValueError(
"Window left is not the same for all layers. One potential fix "
"is to set disable_sliding_window=True")
assert params == global_params, (
"FlashInfer backend currently only supports models in which all "
"layers share the same values for the following hyperparameters: "
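
The added check turns a bare assertion into an actionable error for models that mix sliding-window and full-attention layers inside one group. A simplified sketch of the behaviour, using a stand-in dataclass and hypothetical layer names:

    # Sketch of the mixed-sliding-window error path added to
    # infer_global_hyperparameters (simplified, not the real vLLM function).
    from dataclasses import dataclass


    @dataclass
    class PerLayerParameters:
        window_left: int
        logits_soft_cap: float
        sm_scale: float


    def infer_global_hyperparameters(per_layer_params):
        param_sets = list(per_layer_params.values())
        global_params = param_sets[0]
        for params in param_sets:
            if params.window_left != global_params.window_left:
                raise ValueError(
                    "Window left is not the same for all layers. One potential "
                    "fix is to set disable_sliding_window=True")
            assert params == global_params
        return global_params


    try:
        infer_global_hyperparameters({
            "model.layers.0.attn": PerLayerParameters(-1, 0.0, 0.125),    # full
            "model.layers.1.attn": PerLayerParameters(4096, 0.0, 0.125),  # sliding
        })
    except ValueError as err:
        print(err)
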
1 change: 1 addition & 0 deletions vllm/v1/worker/gpu_model_runner.py
@@ -2504,6 +2504,7 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:

attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
kv_cache_spec,
kv_cache_group_spec.layer_names,
self.vllm_config,
self.device,
)
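
For context on this one-line change: each KV-cache group now hands its own layer_names to the metadata builder it constructs. A hedged sketch of that wiring, using stand-in classes rather than the real KVCacheGroupSpec and builder types:

    # Stand-in sketch of per-group builder construction with the new
    # (kv_cache_spec, layer_names, vllm_config, device) signature.
    from dataclasses import dataclass


    @dataclass
    class KVCacheGroupSketch:
        kv_cache_spec: object
        layer_names: list


    class SketchBuilder:
        def __init__(self, kv_cache_spec, layer_names, vllm_config, device):
            self.kv_cache_spec = kv_cache_spec
            self.layer_names = layer_names
            self.vllm_config = vllm_config
            self.device = device


    def initialize_builders(kv_cache_groups, vllm_config, device):
        # One builder per KV-cache group, each scoped to that group's layers.
        return [
            SketchBuilder(group.kv_cache_spec, group.layer_names, vllm_config,
                          device) for group in kv_cache_groups
        ]


    groups = [
        KVCacheGroupSketch("full_attention_spec", ["model.layers.0.attn"]),
        KVCacheGroupSketch("sliding_window_spec", ["model.layers.1.attn"]),
    ]
    print([b.layer_names for b in initialize_builders(groups, None, "cpu")])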