
Commit 3ff7ebe

[spec decoding] add tree attention backend selection
Signed-off-by: Giancarlo Delfin <gdelfin@meta.com>
1 parent: 3bf0a2e

9 files changed: +42 -18 lines changed


vllm/attention/layer.py

Lines changed: 3 additions & 1 deletion
@@ -52,6 +52,7 @@ def __init__(
         prefix: str = "",
         attn_type: str = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[str] = None,
+        is_draft: bool = False,
         **extra_impl_args,
     ) -> None:
         """
@@ -135,7 +136,8 @@ def __init__(
                                         block_size,
                                         is_attention_free,
                                         blocksparse_params is not None,
-                                        use_mla=use_mla)
+                                        use_mla=use_mla,
+                                        is_draft=is_draft)
         impl_cls = attn_backend.get_impl_cls()
         self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                              alibi_slopes, sliding_window, kv_cache_dtype,
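
The flag is simply threaded from the `Attention` constructor into `get_attn_backend`. For illustration, a draft-model layer could opt in like this (a hypothetical helper, not part of the commit; the keyword arguments other than `is_draft` follow the existing `Attention` signature):

from typing import Optional

from vllm.attention.layer import Attention
from vllm.config import CacheConfig


def build_draft_attention(num_heads: int,
                          head_size: int,
                          scale: float,
                          num_kv_heads: int,
                          cache_config: Optional[CacheConfig],
                          prefix: str) -> Attention:
    # Hypothetical helper: builds an attention layer that opts in to the
    # draft-specific backend selection added in this commit.
    return Attention(
        num_heads=num_heads,
        head_size=head_size,
        scale=scale,
        num_kv_heads=num_kv_heads,
        cache_config=cache_config,
        prefix=f"{prefix}.attn",
        is_draft=True,  # forwarded to get_attn_backend(..., is_draft=True)
    )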

vllm/attention/selector.py

Lines changed: 26 additions & 15 deletions
@@ -27,23 +27,22 @@ def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
     loaded.
     """
     assert backend_name is not None
-    return _Backend[backend_name] if backend_name in _Backend.__members__ else \
-        None
+    return _Backend[
+        backend_name] if backend_name in _Backend.__members__ else None


 def get_env_variable_attn_backend() -> Optional[_Backend]:
-    '''
+    """
     Get the backend override specified by the vLLM attention
     backend environment variable, if one is specified.

     Returns:

     * _Backend enum value if an override is specified
     * None otherwise
-    '''
+    """
     backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
-    return (None
-            if backend_name is None else backend_name_to_enum(backend_name))
+    return None if backend_name is None else backend_name_to_enum(backend_name)


 # Global state allows a particular choice of backend
@@ -57,7 +56,7 @@ def get_env_variable_attn_backend() -> Optional[_Backend]:


 def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
-    '''
+    """
     Force all attention operations to use a specified backend.

     Passing `None` for the argument re-enables automatic
@@ -66,16 +65,16 @@ def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
     Arguments:

     * attn_backend: backend selection (None to revert to auto)
-    '''
+    """
     global forced_attn_backend
     forced_attn_backend = attn_backend


 def get_global_forced_attn_backend() -> Optional[_Backend]:
-    '''
+    """
     Get the currently-forced choice of attention backend,
     or None if auto-selection is currently enabled.
-    '''
+    """
     return forced_attn_backend


@@ -87,6 +86,7 @@ def get_attn_backend(
     is_attention_free: bool,
     is_blocksparse: bool = False,
     use_mla: bool = False,
+    is_draft: bool = False,
 ) -> Type[AttentionBackend]:
     """Selects which attention backend to use and lazily imports it."""
     # Accessing envs.* behind an @lru_cache decorator can cause the wrong
@@ -102,6 +102,7 @@ def get_attn_backend(
         is_blocksparse=is_blocksparse,
         use_v1=envs.VLLM_USE_V1,
         use_mla=use_mla,
+        is_draft=is_draft,
     )


@@ -115,18 +116,28 @@ def _cached_get_attn_backend(
     is_blocksparse: bool = False,
     use_v1: bool = False,
     use_mla: bool = False,
+    is_draft: bool = False,
 ) -> Type[AttentionBackend]:
+    # Draft model backend is currently forced to FlashAttentionBackend for
+    # consistency with EagleProposer using FlashAttentionMetadata.
+    if use_v1 and is_draft:
+        from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
+
+        return FlashAttentionBackend
+
     if is_blocksparse:
         logger.info("Using BlocksparseFlashAttention backend.")
         from vllm.attention.backends.blocksparse_attn import (
             BlocksparseFlashAttentionBackend)
+
         return BlocksparseFlashAttentionBackend

     # If there are no attention layers (e.g. we are running Mamba),
     # use the placeholder NO_ATTENTION
     if is_attention_free:
         from vllm.attention.backends.placeholder_attn import (
             PlaceholderAttentionBackend)
+
         return PlaceholderAttentionBackend

     # Check whether a particular choice of backend was
@@ -135,8 +146,8 @@ def _cached_get_attn_backend(
     # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
     # ENVIRONMENT VARIABLE.
     selected_backend = None
-    backend_by_global_setting: Optional[_Backend] = (
-        get_global_forced_attn_backend())
+    backend_by_global_setting: Optional[
+        _Backend] = get_global_forced_attn_backend()
     if backend_by_global_setting is not None:
         selected_backend = backend_by_global_setting
     else:
@@ -157,8 +168,8 @@ def _cached_get_attn_backend(

 @contextmanager
 def global_force_attn_backend_context_manager(
-        attn_backend: _Backend) -> Generator[None, None, None]:
-    '''
+        attn_backend: _Backend, ) -> Generator[None, None, None]:
+    """
     Globally force a vLLM attention backend override within a
     context manager, reverting the global attention backend
     override to its prior state upon exiting the context
@@ -171,7 +182,7 @@ def global_force_attn_backend_context_manager(
     Returns:

     * Generator
-    '''
+    """

     # Save the current state of the global backend override (if any)
     original_value = get_global_forced_attn_backend()
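
Taken together, the selection order in `_cached_get_attn_backend` is now: V1 draft layers are pinned to FlashAttention, then the blocksparse and attention-free special cases, then any globally forced backend, then the `VLLM_ATTENTION_BACKEND` environment variable, and finally the platform default. A self-contained toy sketch of that priority (strings stand in for the real backend classes; this is an illustration, not code from the commit):

from typing import Optional

def choose_backend(use_v1: bool,
                   is_draft: bool,
                   is_blocksparse: bool,
                   is_attention_free: bool,
                   forced_backend: Optional[str],
                   env_backend: Optional[str]) -> str:
    if use_v1 and is_draft:
        # Draft-model layers are pinned for EagleProposer compatibility.
        return "FlashAttentionBackend"
    if is_blocksparse:
        return "BlocksparseFlashAttentionBackend"
    if is_attention_free:
        # No attention layers at all (e.g. a pure Mamba model).
        return "PlaceholderAttentionBackend"
    # The global override (context manager) wins over the environment variable.
    selected = forced_backend if forced_backend is not None else env_backend
    return selected or "platform default (FlashAttention / FlashInfer / ...)"

# The draft check fires before any user-selected backend is considered.
assert choose_backend(True, True, False, False, None, "TREE_ATTN") == "FlashAttentionBackend"
assert choose_backend(True, False, False, False, None, "TREE_ATTN") == "TREE_ATTN"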

vllm/engine/arg_utils.py

Lines changed: 1 addition & 0 deletions
@@ -1414,6 +1414,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
             "ROCM_AITER_MLA",
             "TORCH_SDPA_VLLM_V1",
             "FLEX_ATTENTION",
+            "TREE_ATTN",
         ]
         if (envs.is_set("VLLM_ATTENTION_BACKEND")
                 and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
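
With `TREE_ATTN` added to `V1_BACKENDS`, the V1 support oracle now accepts the backend when it is requested through `VLLM_ATTENTION_BACKEND`. A rough usage sketch (the model name is an arbitrary example, and the environment variables must be set before vLLM reads them):

import os

# Request the tree attention backend and the V1 engine before importing vLLM,
# so the values are visible when the engine is constructed.
os.environ["VLLM_ATTENTION_BACKEND"] = "TREE_ATTN"
os.environ["VLLM_USE_V1"] = "1"

from vllm import LLM

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")  # example model, not from this commit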

vllm/model_executor/models/llama.py

Lines changed: 4 additions & 0 deletions
@@ -113,6 +113,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         prefix: str = "",
         attn_type: str = AttentionType.DECODER,
+        is_draft: bool = False,
     ) -> None:
         super().__init__()
         layer_idx = extract_layer_index(prefix)
@@ -190,6 +191,7 @@ def __init__(
             per_layer_sliding_window=sliding_window,
             attn_type=attn_type,
             prefix=f"{prefix}.attn",
+            is_draft=is_draft,
         )

     def forward(
@@ -231,6 +233,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        is_draft: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -275,6 +278,7 @@ def __init__(
             cache_config=cache_config,
             prefix=f"{prefix}.self_attn",
             attn_type=attn_type,
+            is_draft=is_draft,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,

vllm/model_executor/models/llama_eagle.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ def __init__(
         disable_input_layernorm: bool,
         prefix: str = "",
     ) -> None:
-        super().__init__(config, prefix=prefix)
+        super().__init__(config, prefix=prefix, is_draft=True)

         # Skip the input_layernorm
         # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427

vllm/platforms/cuda.py

Lines changed: 4 additions & 0 deletions
@@ -248,6 +248,10 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                 logger.info_once("Using Flash Attention backend on V1 engine.")
                 return ("vllm.v1.attention.backends."
                         "flash_attn.FlashAttentionBackend")
+            elif selected_backend == _Backend.TREE_ATTN:
+                logger.info_once("Using Tree Attention backend on V1 engine.")
+                return ("vllm.v1.attention.backends."
+                        "tree_attn.TreeAttentionBackend")

             # Default backends for V1 engine
             # Prefer FlashInfer for Blackwell GPUs if installed
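
Besides the environment variable, the backend can be forced programmatically through the existing selector override, which takes precedence over `VLLM_ATTENTION_BACKEND` (see `_cached_get_attn_backend` above). A small sketch using the APIs touched by this commit:

from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.platforms.interface import _Backend

# Temporarily force tree attention; the previous override is restored on exit.
with global_force_attn_backend_context_manager(_Backend.TREE_ATTN):
    ...  # construct the engine / run speculative decoding here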

vllm/platforms/interface.py

Lines changed: 1 addition & 0 deletions
@@ -61,6 +61,7 @@ class _Backend(enum.Enum):
     DUAL_CHUNK_FLASH_ATTN = enum.auto()
     NO_ATTENTION = enum.auto()
     FLEX_ATTENTION = enum.auto()
+    TREE_ATTN = enum.auto()


 class PlatformEnum(enum.Enum):
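
The new enum member is what lets the string value from `VLLM_ATTENTION_BACKEND` round-trip through the selector's name lookup, for example:

from vllm.attention.selector import backend_name_to_enum
from vllm.platforms.interface import _Backend

# "TREE_ATTN" now resolves to a real enum member instead of None.
assert backend_name_to_enum("TREE_ATTN") is _Backend.TREE_ATTN
assert backend_name_to_enum("NOT_A_BACKEND") is None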

vllm/v1/attention/backends/flash_attn.py

Lines changed: 1 addition & 1 deletion
@@ -132,7 +132,7 @@ def _get_sliding_window_configs(
     sliding_window_configs: set[Optional[tuple[int, int]]] = set()
     layers = get_layers_from_vllm_config(vllm_config, Attention)
     for layer in layers.values():
-        assert isinstance(layer.impl, FlashAttentionImpl)
+        assert hasattr(layer.impl, "sliding_window")
         sliding_window_configs.add(layer.impl.sliding_window)
     return sliding_window_configs

vllm/v1/attention/backends/tree_attn.py

Lines changed: 1 addition & 0 deletions
@@ -380,6 +380,7 @@ def __init__(
             None,  # Skip KV reshape and cache. This class handles it.
             use_irope=use_irope,
         )
+        self.sliding_window = self.prefill_attention_impl.sliding_window

     def forward(
         self,
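
Exposing `sliding_window` on the tree attention impl is what allows the relaxed `hasattr` check in `flash_attn.py` above to treat it interchangeably with `FlashAttentionImpl` when collecting sliding-window configs. A purely illustrative duck-typing sketch (the classes and values below are stand-ins, not vLLM code):

from typing import Optional

class FakeFlashImpl:
    # Stand-in for FlashAttentionImpl: full attention, no sliding window.
    sliding_window: Optional[tuple] = None

class FakeTreeImpl:
    # Stand-in for TreeAttentionImpl, which now forwards the attribute from
    # its wrapped prefill FlashAttentionImpl.
    sliding_window: Optional[tuple] = (4096, 0)

configs = set()
for impl in (FakeFlashImpl(), FakeTreeImpl()):
    assert hasattr(impl, "sliding_window")  # the relaxed check in flash_attn.py
    configs.add(impl.sliding_window)
print(configs)  # contains None and (4096, 0)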
