
Commit 5a37c78

[spec decoding] add tree attention backend selection
Signed-off-by: Giancarlo Delfin <gdelfin@meta.com>
Parent: 677c149

9 files changed (+26, -3 lines):

- vllm/attention/layer.py
- vllm/attention/selector.py
- vllm/engine/arg_utils.py
- vllm/model_executor/models/llama.py
- vllm/model_executor/models/llama_eagle.py
- vllm/platforms/cuda.py
- vllm/platforms/interface.py
- vllm/v1/attention/backends/flash_attn.py
- vllm/v1/attention/backends/tree_attn.py

vllm/attention/layer.py

Lines changed: 3 additions & 1 deletion
@@ -52,6 +52,7 @@ def __init__(
         prefix: str = "",
         attn_type: str = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[str] = None,
+        is_draft: bool = False,
         **extra_impl_args,
     ) -> None:
         """
@@ -135,7 +136,8 @@ def __init__(
             block_size,
             is_attention_free,
             blocksparse_params is not None,
-            use_mla=use_mla)
+            use_mla=use_mla,
+            is_draft=is_draft)
         impl_cls = attn_backend.get_impl_cls()
         self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                              alibi_slopes, sliding_window, kv_cache_dtype,

vllm/attention/selector.py

Lines changed: 10 additions & 0 deletions
@@ -87,6 +87,7 @@ def get_attn_backend(
     is_attention_free: bool,
     is_blocksparse: bool = False,
     use_mla: bool = False,
+    is_draft: bool = False,
 ) -> Type[AttentionBackend]:
     """Selects which attention backend to use and lazily imports it."""
     # Accessing envs.* behind an @lru_cache decorator can cause the wrong
@@ -102,6 +103,7 @@ def get_attn_backend(
         is_blocksparse=is_blocksparse,
         use_v1=envs.VLLM_USE_V1,
         use_mla=use_mla,
+        is_draft=is_draft,
     )


@@ -115,7 +117,15 @@ def _cached_get_attn_backend(
     is_blocksparse: bool = False,
     use_v1: bool = False,
     use_mla: bool = False,
+    is_draft: bool = False,
 ) -> Type[AttentionBackend]:
+    # TODO(gdelfin): Allow selection of draft model backend for EAGLE. Currently,
+    # it is forced to FlashAttentionBackend to be consistent with EagleProposer.
+    if use_v1 and is_draft:
+        from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
+
+        return FlashAttentionBackend
+
     if is_blocksparse:
         logger.info("Using BlocksparseFlashAttention backend.")
         from vllm.attention.backends.blocksparse_attn import (
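For context, a hedged sketch of how the new flag behaves, assuming vLLM at this commit with the V1 engine enabled: draft layers resolve to FlashAttention regardless of the configured backend. The leading parameters (head_size, dtype, kv_cache_dtype, block_size) are assumed from the unchanged part of the signature above this hunk, and the values are illustrative.

import torch

from vllm.attention.selector import get_attn_backend

backend = get_attn_backend(
    head_size=128,          # illustrative values
    dtype=torch.bfloat16,
    kv_cache_dtype=None,
    block_size=16,
    is_attention_free=False,
    is_draft=True,          # new in this commit
)
print(backend.__name__)  # FlashAttentionBackend when VLLM_USE_V1 is enabled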

vllm/engine/arg_utils.py

Lines changed: 1 addition & 0 deletions
@@ -1414,6 +1414,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
             "ROCM_AITER_MLA",
             "TORCH_SDPA_VLLM_V1",
             "FLEX_ATTENTION",
+            "TREE_ATTN",
         ]
         if (envs.is_set("VLLM_ATTENTION_BACKEND")
                 and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
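A hedged usage sketch tying this hunk to the rest of the commit: the new name can be requested through the existing VLLM_ATTENTION_BACKEND environment variable, which the V1 support oracle above now accepts. The model name is only an example.

import os

os.environ["VLLM_USE_V1"] = "1"                     # assumption: request the V1 engine explicitly
os.environ["VLLM_ATTENTION_BACKEND"] = "TREE_ATTN"  # must match the entry added to V1_BACKENDS

from vllm import LLM

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")  # illustrative model choice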

vllm/model_executor/models/llama.py

Lines changed: 4 additions & 0 deletions
@@ -113,6 +113,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         prefix: str = "",
         attn_type: str = AttentionType.DECODER,
+        is_draft: bool = False,
     ) -> None:
         super().__init__()
         layer_idx = extract_layer_index(prefix)
@@ -190,6 +191,7 @@ def __init__(
             per_layer_sliding_window=sliding_window,
             attn_type=attn_type,
             prefix=f"{prefix}.attn",
+            is_draft=is_draft,
         )

     def forward(
@@ -231,6 +233,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        is_draft: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -275,6 +278,7 @@ def __init__(
             cache_config=cache_config,
             prefix=f"{prefix}.self_attn",
             attn_type=attn_type,
+            is_draft=is_draft,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
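The four hunks above are plumbing: the flag is threaded from LlamaDecoderLayer through LlamaAttention into Attention, defaulting to False so the verifier (target) model is unchanged. A toy reproduction of the pattern, not vLLM code:

def select_backend(is_draft: bool = False) -> str:
    # Stand-in for get_attn_backend(): draft layers are pinned to FlashAttention.
    return "FLASH_ATTN" if is_draft else "<configured backend>"

class ToyAttention:
    def __init__(self, is_draft: bool = False):
        self.backend = select_backend(is_draft=is_draft)

class ToyDecoderLayer:
    def __init__(self, is_draft: bool = False):
        self.self_attn = ToyAttention(is_draft=is_draft)

print(ToyDecoderLayer().self_attn.backend)               # <configured backend>
print(ToyDecoderLayer(is_draft=True).self_attn.backend)  # FLASH_ATTN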

vllm/model_executor/models/llama_eagle.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ def __init__(
         disable_input_layernorm: bool,
         prefix: str = "",
     ) -> None:
-        super().__init__(config, prefix=prefix)
+        super().__init__(config, prefix=prefix, is_draft=True)

         # Skip the input_layernorm
         # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
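The EAGLE draft decoder layer is the one call site that opts in, so its attention always resolves to FlashAttention per the selector TODO above. A hedged sketch of how this path gets exercised, following vLLM's EAGLE speculative-decoding examples; the config keys and the draft model name are illustrative and may differ across versions:

from vllm import LLM

llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    speculative_config={
        "method": "eagle",
        "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",  # example EAGLE head for this target
        "num_speculative_tokens": 2,
    },
)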

vllm/platforms/cuda.py

Lines changed: 4 additions & 0 deletions
@@ -248,6 +248,10 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
            logger.info_once("Using Flash Attention backend on V1 engine.")
            return ("vllm.v1.attention.backends."
                    "flash_attn.FlashAttentionBackend")
+        elif selected_backend == _Backend.TREE_ATTN:
+            logger.info_once("Using Tree Attention backend on V1 engine.")
+            return ("vllm.v1.attention.backends."
+                    "tree_attn.TreeAttentionBackend")

         # Default backends for V1 engine
         # Prefer FlashInfer for Blackwell GPUs if installed
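get_attn_backend_cls returns a dotted path rather than a class. As an illustration only (this is not vLLM's own loader), such a path can be resolved with importlib, assuming vLLM at this commit is installed:

import importlib

def resolve_backend_cls(qualname: str):
    # Split "pkg.module.ClassName" into module path and class name, then import.
    module_name, _, class_name = qualname.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)

backend_cls = resolve_backend_cls(
    "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend")
print(backend_cls.__name__)  # TreeAttentionBackend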

vllm/platforms/interface.py

Lines changed: 1 addition & 0 deletions
@@ -61,6 +61,7 @@ class _Backend(enum.Enum):
     DUAL_CHUNK_FLASH_ATTN = enum.auto()
     NO_ATTENTION = enum.auto()
     FLEX_ATTENTION = enum.auto()
+    TREE_ATTN = enum.auto()


 class PlatformEnum(enum.Enum):
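Illustrative only: the backend name passed via VLLM_ATTENTION_BACKEND appears to be matched against this enum's member names, so the "TREE_ATTN" spelling must stay in sync across arg_utils.py, cuda.py, and here.

from vllm.platforms.interface import _Backend

assert _Backend["TREE_ATTN"] is _Backend.TREE_ATTN  # standard Enum lookup by name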

vllm/v1/attention/backends/flash_attn.py

Lines changed: 1 addition & 1 deletion
@@ -132,7 +132,7 @@ def _get_sliding_window_configs(
     sliding_window_configs: set[Optional[tuple[int, int]]] = set()
     layers = get_layers_from_vllm_config(vllm_config, Attention)
     for layer in layers.values():
-        assert isinstance(layer.impl, FlashAttentionImpl)
+        assert hasattr(layer.impl, "sliding_window")
         sliding_window_configs.add(layer.impl.sliding_window)
     return sliding_window_configs
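The isinstance assertion is relaxed to duck typing so that impls other than FlashAttentionImpl, in particular TreeAttentionImpl, which only exposes a sliding_window attribute (see the tree_attn.py change below), still pass _get_sliding_window_configs. A toy illustration, not vLLM code:

from typing import Optional

class ToyFlashImpl:
    sliding_window = (4096, 0)

class ToyTreeImpl:
    def __init__(self, prefill_impl):
        # Mirrors the tree_attn.py change below: surface the wrapped prefill impl's window.
        self.sliding_window = prefill_impl.sliding_window

configs: set[Optional[tuple[int, int]]] = set()
for impl in (ToyFlashImpl(), ToyTreeImpl(ToyFlashImpl())):
    assert hasattr(impl, "sliding_window")  # the relaxed check
    configs.add(impl.sliding_window)
print(configs)  # {(4096, 0)}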

vllm/v1/attention/backends/tree_attn.py

Lines changed: 1 addition & 0 deletions
@@ -380,6 +380,7 @@ def __init__(
            None,  # Skip KV reshape and cache. This class handles it.
            use_irope=use_irope,
        )
+        self.sliding_window = self.prefill_attention_impl.sliding_window

     def forward(
         self,
