Commit 0e691d5

[spec decoding] implement proposing tree drafts
Signed-off-by: Giancarlo Delfin <gdelfin@meta.com>
Parent commit: da6c40b

11 files changed (+455, −366 lines)

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ ignore_patterns = [

 [tool.ruff]
 # Allow lines to be as long as 80.
-line-length = 80
+line-length = 90

 [tool.ruff.lint.per-file-ignores]
 "vllm/third_party/**" = ["ALL"]

vllm/attention/layer.py

Lines changed: 1 addition & 3 deletions
@@ -52,7 +52,6 @@ def __init__(
         prefix: str = "",
         attn_type: str = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[str] = None,
-        is_draft: bool = False,
         **extra_impl_args,
     ) -> None:
         """
@@ -136,8 +135,7 @@ def __init__(
                                         block_size,
                                         is_attention_free,
                                         blocksparse_params is not None,
-                                        use_mla=use_mla,
-                                        is_draft=is_draft)
+                                        use_mla=use_mla)
         impl_cls = attn_backend.get_impl_cls()
         self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                              alibi_slopes, sliding_window, kv_cache_dtype,

vllm/attention/selector.py

Lines changed: 15 additions & 26 deletions
@@ -27,22 +27,23 @@ def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
     loaded.
     """
     assert backend_name is not None
-    return _Backend[
-        backend_name] if backend_name in _Backend.__members__ else None
+    return _Backend[backend_name] if backend_name in _Backend.__members__ else \
+        None


 def get_env_variable_attn_backend() -> Optional[_Backend]:
-    """
+    '''
     Get the backend override specified by the vLLM attention
     backend environment variable, if one is specified.

     Returns:

     * _Backend enum value if an override is specified
     * None otherwise
-    """
+    '''
     backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
-    return None if backend_name is None else backend_name_to_enum(backend_name)
+    return (None
+            if backend_name is None else backend_name_to_enum(backend_name))


 # Global state allows a particular choice of backend
@@ -56,7 +57,7 @@ def get_env_variable_attn_backend() -> Optional[_Backend]:


 def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
-    """
+    '''
     Force all attention operations to use a specified backend.

     Passing `None` for the argument re-enables automatic
@@ -65,16 +66,16 @@ def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
     Arguments:

     * attn_backend: backend selection (None to revert to auto)
-    """
+    '''
     global forced_attn_backend
     forced_attn_backend = attn_backend


 def get_global_forced_attn_backend() -> Optional[_Backend]:
-    """
+    '''
     Get the currently-forced choice of attention backend,
     or None if auto-selection is currently enabled.
-    """
+    '''
     return forced_attn_backend


@@ -86,7 +87,6 @@ def get_attn_backend(
     is_attention_free: bool,
     is_blocksparse: bool = False,
     use_mla: bool = False,
-    is_draft: bool = False,
 ) -> Type[AttentionBackend]:
     """Selects which attention backend to use and lazily imports it."""
     # Accessing envs.* behind an @lru_cache decorator can cause the wrong
@@ -102,7 +102,6 @@ def get_attn_backend(
         is_blocksparse=is_blocksparse,
         use_v1=envs.VLLM_USE_V1,
         use_mla=use_mla,
-        is_draft=is_draft,
     )


@@ -116,28 +115,18 @@ def _cached_get_attn_backend(
     is_blocksparse: bool = False,
     use_v1: bool = False,
     use_mla: bool = False,
-    is_draft: bool = False,
 ) -> Type[AttentionBackend]:
-    # Draft model backend is currently forced to FlashAttentionBackend for
-    # consistency with EagleProposer using FlashAttentionMetadata.
-    if use_v1 and is_draft:
-        from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
-
-        return FlashAttentionBackend
-
     if is_blocksparse:
         logger.info("Using BlocksparseFlashAttention backend.")
         from vllm.attention.backends.blocksparse_attn import (
             BlocksparseFlashAttentionBackend)
-
         return BlocksparseFlashAttentionBackend

     # If there are no attention layers (e.g. we are running Mamba),
     # use the placeholder NO_ATTENTION
     if is_attention_free:
         from vllm.attention.backends.placeholder_attn import (
             PlaceholderAttentionBackend)
-
         return PlaceholderAttentionBackend

     # Check whether a particular choice of backend was
@@ -146,8 +135,8 @@ def _cached_get_attn_backend(
     # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
     # ENVIRONMENT VARIABLE.
     selected_backend = None
-    backend_by_global_setting: Optional[
-        _Backend] = get_global_forced_attn_backend()
+    backend_by_global_setting: Optional[_Backend] = (
+        get_global_forced_attn_backend())
     if backend_by_global_setting is not None:
         selected_backend = backend_by_global_setting
     else:
@@ -168,8 +157,8 @@ def _cached_get_attn_backend(

 @contextmanager
 def global_force_attn_backend_context_manager(
-        attn_backend: _Backend, ) -> Generator[None, None, None]:
-    """
+        attn_backend: _Backend) -> Generator[None, None, None]:
+    '''
     Globally force a vLLM attention backend override within a
     context manager, reverting the global attention backend
     override to its prior state upon exiting the context
@@ -182,7 +171,7 @@ def global_force_attn_backend_context_manager(
     Returns:

     * Generator
-    """
+    '''

     # Save the current state of the global backend override (if any)
     original_value = get_global_forced_attn_backend()
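For context, a minimal usage sketch of the context manager touched above. The import path and the FLASH_ATTN member are assumptions based on names referenced elsewhere in this commit, not something this diff guarantees:

from vllm.attention.selector import (
    _Backend, global_force_attn_backend_context_manager)

# Every attention backend selection inside the block is forced to the given
# backend; the previous global override (if any) is restored on exit.
with global_force_attn_backend_context_manager(_Backend.FLASH_ATTN):
    pass  # backend-selecting code would run here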

vllm/config.py

Lines changed: 7 additions & 0 deletions
@@ -2724,6 +2724,13 @@ def __post_init__(self):
                     f"num_speculative_tokens:{self.num_speculative_tokens}"
                     f" must be divisible by {n_predict=}")

+        if self.speculative_token_tree is None:
+            # Generate chain of tokens.
+            self.speculative_token_tree = str([[
+                (i + 1) * (0, )
+                for i in range(self.num_speculative_tokens)
+            ]])
+
         self.draft_tensor_parallel_size = \
             SpeculativeConfig._verify_and_get_draft_tp(
                 self.target_parallel_config,
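For reference, a minimal sketch of what the new default evaluates to, assuming num_speculative_tokens = 3 (the value is illustrative, and reading each tuple as a path of child indices from the root is an assumption about the tree format, not stated in this diff):

num_speculative_tokens = 3  # illustrative value

# Same expression as the new default above: draft token i of the chain is
# encoded as a tuple of i + 1 zeros, i.e. "take child 0 at every depth",
# which describes a single chain rather than a branching tree.
tree = str([[(i + 1) * (0, ) for i in range(num_speculative_tokens)]])
print(tree)  # [[(0,), (0, 0), (0, 0, 0)]]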

vllm/engine/arg_utils.py

Lines changed: 0 additions & 1 deletion
@@ -1427,7 +1427,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                              recommend_to_remove=False)
             return False

-        # No XFormers so far.
         V1_BACKENDS = [
             "FLASH_ATTN_VLLM_V1",
             "FLASH_ATTN",

vllm/model_executor/models/llama.py

Lines changed: 0 additions & 4 deletions
@@ -113,7 +113,6 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         prefix: str = "",
         attn_type: str = AttentionType.DECODER,
-        is_draft: bool = False,
     ) -> None:
         super().__init__()
         layer_idx = extract_layer_index(prefix)
@@ -191,7 +190,6 @@ def __init__(
            per_layer_sliding_window=sliding_window,
            attn_type=attn_type,
            prefix=f"{prefix}.attn",
-           is_draft=is_draft,
        )

    def forward(
@@ -233,7 +231,6 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
-        is_draft: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -278,7 +275,6 @@ def __init__(
             cache_config=cache_config,
             prefix=f"{prefix}.self_attn",
             attn_type=attn_type,
-            is_draft=is_draft,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,

vllm/model_executor/models/llama_eagle.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ def __init__(
         disable_input_layernorm: bool,
         prefix: str = "",
     ) -> None:
-        super().__init__(config, prefix=prefix, is_draft=True)
+        super().__init__(config, prefix=prefix)

         # Skip the input_layernorm
         # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427

vllm/v1/attention/backends/flash_attn.py

Lines changed: 1 addition & 1 deletion
@@ -134,7 +134,7 @@ def _get_sliding_window_configs(
     sliding_window_configs: set[Optional[tuple[int, int]]] = set()
     layers = get_layers_from_vllm_config(vllm_config, Attention)
     for layer in layers.values():
-        assert hasattr(layer.impl, "sliding_window")
+        assert isinstance(layer.impl, FlashAttentionImpl)
        sliding_window_configs.add(layer.impl.sliding_window)
    return sliding_window_configs
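A self-contained toy sketch of the assertion swap above; the class below is a stand-in, not the real FlashAttentionImpl. isinstance() both checks the runtime type and narrows layer.impl for static type checkers, whereas hasattr() only proves the attribute exists:

from typing import Optional


class ToyFlashAttentionImpl:
    """Stand-in for an attention impl exposing a sliding_window attribute."""

    def __init__(self, sliding_window: Optional[tuple[int, int]] = None):
        self.sliding_window = sliding_window


def collect_sliding_windows(impls) -> set:
    configs: set = set()
    for impl in impls:
        # isinstance narrows impl, so checkers know .sliding_window exists.
        assert isinstance(impl, ToyFlashAttentionImpl)
        configs.add(impl.sliding_window)
    return configs


print(collect_sliding_windows(
    [ToyFlashAttentionImpl((128, 0)), ToyFlashAttentionImpl()]))
# e.g. {(128, 0), None}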
