Skip to content

Commit 5546b9f

Browse files
committed
initial commit for non-shifting prefill in eagle, prepare for kv sharing
rope change rope change rebase
1 parent d64bf91 commit 5546b9f

File tree

6 files changed

+379
-45
lines changed

6 files changed

+379
-45
lines changed

examples/offline_inference/spec_decode.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ def parse_args():
6868
parser.add_argument("--model-dir", type=str, default=None)
6969
parser.add_argument("--eagle-dir", type=str, default=None)
7070
parser.add_argument("--custom-mm-prompts", action="store_true")
71+
parser.add_argument("--no-prefill-token-shift", dest="prefill_token_shift",
72+
action="store_false", help="Disable prefill token shift (default: enabled)")
7173
return parser.parse_args()
7274

7375

@@ -103,6 +105,7 @@ def main():
103105
"method": args.method,
104106
"model": eagle_dir,
105107
"num_speculative_tokens": args.num_spec_tokens,
108+
"prefill_token_shift": args.prefill_token_shift,
106109
}
107110
elif args.method == "ngram":
108111
speculative_config = {

vllm/config.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2551,6 +2551,14 @@ class SpeculativeConfig:
25512551
ParallelConfig] = None # type: ignore
25522552
"""The parallel configuration for the draft model initialized internal."""
25532553

2554+
# Shift prefill tokens for draft, only used in eagle
2555+
prefill_token_shift: bool = True
2556+
"""Shift tokens during draft prefill or not"""
2557+
2558+
# Config for kv sharing, map from base model layer to draft layer
2559+
kv_sharing_mapping: SkipValidation[dict[str, str]] = None
2560+
"""KV copy mapping for prefill stage from base to draft"""
2561+
25542562
def compute_hash(self) -> str:
25552563
"""
25562564
WARNING: Whenever a new field is added to this config,
@@ -2937,6 +2945,11 @@ def num_lookahead_slots(self) -> int:
29372945
def use_eagle(self) -> bool:
29382946
return self.method in ("eagle", "eagle3", "deepseek_mtp")
29392947

2948+
def eagle_shift_prefill_token(self) -> bool:
2949+
if self.use_eagle():
2950+
return self.prefill_token_shift
2951+
return False
2952+
29402953
def __repr__(self) -> str:
29412954
method = self.method
29422955
model = None if method == "ngram" else self.draft_model_config.model

vllm/model_executor/models/llama4.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,11 @@ def __init__(self,
183183
is_gguf = quant_config and quant_config.get_name() == "gguf"
184184
if is_gguf and config.model_type == "llama":
185185
is_neox_style = False
186+
elif config.model_type == "eagle":
187+
# EAGLE draft model does not use neox style RoPE
188+
is_neox_style = False
189+
else:
190+
is_neox_style = True
186191

187192
self.rotary_emb = get_rope(
188193
self.head_dim,

0 commit comments

Comments
 (0)