Commit 2e9541f

add kv copy logic and offline tests
Signed-off-by: morgendave <morgendave@gmail.com>
Parent: 5546b9f

7 files changed: +288 -55 lines changed

examples/offline_inference/spec_decode.py

Lines changed: 26 additions & 2 deletions
@@ -68,8 +68,19 @@ def parse_args():
     parser.add_argument("--model-dir", type=str, default=None)
     parser.add_argument("--eagle-dir", type=str, default=None)
     parser.add_argument("--custom-mm-prompts", action="store_true")
-    parser.add_argument("--no-prefill-token-shift", dest="prefill_token_shift",
-                        action="store_false", help="Disable prefill token shift (default: enabled)")
+    parser.add_argument(
+        "--no-prefill-token-shift",
+        dest="prefill_token_shift",
+        action="store_false",
+        help="Disable prefill token shift (default: enabled)",
+    )
+    parser.add_argument("--target_kv_layer_copy_from", type=int, default=-1)
+    parser.add_argument(
+        "--draft_kv_layer_copy_to",
+        type=str,
+        default="",
+        help="comma separated list of layer indices to copy to",
+    )
     return parser.parse_args()

@@ -101,11 +112,24 @@ def main():
 
         elif args.method == "eagle3" and eagle_dir is None:
             eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
+        target_kv_layer_copy_from = args.target_kv_layer_copy_from
+        draft_kv_layers_copy_to = (
+            [int(layer) for layer in args.draft_kv_layer_copy_to.split(",")]
+            if args.draft_kv_layer_copy_to
+            else None
+        )
+        kv_sharing_mapping = None
+        if args.target_kv_layer_copy_from >= 0 and draft_kv_layers_copy_to:
+            kv_sharing_mapping = {
+                f"{layer}": f"{target_kv_layer_copy_from}"
+                for layer in draft_kv_layers_copy_to
+            }
         speculative_config = {
             "method": args.method,
             "model": eagle_dir,
             "num_speculative_tokens": args.num_spec_tokens,
             "prefill_token_shift": args.prefill_token_shift,
+            "kv_sharing_mapping": kv_sharing_mapping,
         }
     elif args.method == "ngram":
         speculative_config = {
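
A minimal sketch of what the new flags above produce, with hypothetical layer indices (not taken from this commit): passing --target_kv_layer_copy_from 31 --draft_kv_layer_copy_to 0,1 builds the following mapping.

    # Copy the target model's prefill KV from layer 31 into draft layers 0 and 1.
    target_kv_layer_copy_from = 31
    draft_kv_layers_copy_to = [0, 1]
    kv_sharing_mapping = {
        f"{layer}": f"{target_kv_layer_copy_from}"
        for layer in draft_kv_layers_copy_to
    }
    # -> {"0": "31", "1": "31"}  (keys: draft layers, values: base layer)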

vllm/config.py

Lines changed: 1 addition & 0 deletions
@@ -2556,6 +2556,7 @@ class SpeculativeConfig:
     """Shift tokens during draft prefill or not"""
 
     # Config for kv sharing, map from base model layer to draft layer
+    # Key is draft layer, value is base layer
     kv_sharing_mapping: SkipValidation[dict[str, str]] = None
     """KV copy mapping for prefill stage from base to draft"""
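
For reference, a hedged sketch of passing this field through an offline speculative_config; the target model path and layer indices are placeholders, and the snippet simply mirrors the example script above rather than a canonical setup from this commit.

    from vllm import LLM

    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder target model
        speculative_config={
            "method": "eagle3",
            "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
            "num_speculative_tokens": 3,
            # Hypothetical mapping: draft layer "0" copies prefill KV from base layer "31".
            "kv_sharing_mapping": {"0": "31"},
        },
    )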

vllm/envs.py

Lines changed: 4 additions & 1 deletion
@@ -138,6 +138,7 @@
     VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: str = "NONE"
     VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
     VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
+    VLLM_DECODE_ONLY_ATTN: bool = False
 
 
 def get_default_cache_root():
@@ -953,7 +954,9 @@ def get_vllm_port() -> Optional[int]:
     # generations on machines < 100 for compressed-tensors
     # models
     "VLLM_USE_NVFP4_CT_EMULATIONS":
-    lambda: bool(int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0")))
+    lambda: bool(int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0"))),
+    "VLLM_DECODE_ONLY_ATTN":
+    lambda: os.environ.get("VLLM_DECODE_ONLY_ATTN", "0") == "1"
 }
 
 # --8<-- [end:env-vars-definition]
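
A brief sketch of how the new flag is read at runtime; vllm.envs resolves entries lazily through the lambda table above, so the attribute access below follows the module's existing pattern rather than anything added by this commit.

    import os

    # "1" enables decode-only attention; any other value (or unset) leaves it False.
    os.environ["VLLM_DECODE_ONLY_ATTN"] = "1"

    import vllm.envs as envs

    print(envs.VLLM_DECODE_ONLY_ATTN)  # True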

vllm/model_executor/models/llama4.py

Lines changed: 0 additions & 5 deletions
@@ -183,11 +183,6 @@ def __init__(self,
         is_gguf = quant_config and quant_config.get_name() == "gguf"
         if is_gguf and config.model_type == "llama":
             is_neox_style = False
-        elif config.model_type == "eagle":
-            # EAGLE draft model does not use neox style RoPE
-            is_neox_style = False
-        else:
-            is_neox_style = True
 
         self.rotary_emb = get_rope(
             self.head_dim,
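
A hedged reading of the removal above, assuming is_neox_style is initialized to True before this block (the initialization is not shown in this hunk): the EAGLE draft path no longer forces non-neox RoPE and instead falls through to the base model's style.

    # Sketch of the surviving logic after this commit.
    is_neox_style = True  # assumption: set earlier in __init__, not shown in the hunk
    is_gguf = quant_config and quant_config.get_name() == "gguf"
    if is_gguf and config.model_type == "llama":
        is_neox_style = False
    # is_neox_style then flows into the get_rope(...) call unchanged.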
