
Commit 418d2f8

ekagra-ranjan, root, and WoosukKwon authored
[V1][Spec Decode] Share input embedding of target model with EAGLE draft model to free ~1GB for llama 3 model (#17326)
Co-authored-by: root <root@ekagra-8xh100.us-east5-a.c.serving-efficiency-poc.internal>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
1 parent 964472b commit 418d2f8

File tree: 4 files changed, +59 −19 lines
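
For context, the "~1GB" in the commit title is roughly the size of a single Llama 3 input embedding table, which the draft model no longer needs to duplicate. A back-of-the-envelope check (assuming the Llama 3 8B shapes, vocab_size=128256 and hidden_size=4096, with bf16 weights):

# Approximate size of the embedding table the EAGLE draft no longer duplicates.
# Assumed Llama 3 8B shapes: vocab_size=128256, hidden_size=4096, bf16 (2 bytes/param).
vocab_size = 128_256
hidden_size = 4_096
bytes_per_param = 2

embedding_bytes = vocab_size * hidden_size * bytes_per_param
print(f"~{embedding_bytes / 1e9:.2f} GB")  # prints ~1.05 GB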

examples/offline_inference/eagle.py

Lines changed: 7 additions & 0 deletions
@@ -105,6 +105,13 @@ def main():
     outputs = llm.generate(prompt_token_ids=prompt_ids,
                            sampling_params=sampling_params)
 
+    # print the generated text
+    for output in outputs:
+        print("-" * 50)
+        print(f"prompt: {output.prompt}")
+        print(f"generated text: {output.outputs[0].text}")
+        print("-" * 50)
+
     if not hasattr(outputs, "metrics") or outputs.metrics is None:
         return
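
The added loop simply prints each RequestOutput produced by the example. For readers who have not opened the full script, a condensed sketch of how such an EAGLE run is set up and printed; the speculative_config keys and the draft checkpoint below are assumptions for illustration and may differ from the example's exact flags:

from vllm import LLM, SamplingParams

# Target model plus an EAGLE draft head (config keys assumed; check your vLLM version).
llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    speculative_config={
        "method": "eagle",
        "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",  # assumed draft checkpoint
        "num_speculative_tokens": 2,
    },
)
sampling_params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["The capital of France is"], sampling_params)

# Same printing pattern as the lines added above.
for output in outputs:
    print("-" * 50)
    print(f"prompt: {output.prompt}")
    print(f"generated text: {output.outputs[0].text}")
    print("-" * 50)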

vllm/model_executor/models/llama_eagle.py

Lines changed: 18 additions & 9 deletions
@@ -8,6 +8,7 @@
 
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
+from vllm.distributed.parallel_state import get_pp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -52,11 +53,15 @@ def __init__(
         self.config = vllm_config. \
             speculative_config.draft_model_config.hf_config
         self.vocab_size = self.config.vocab_size
-        self.embed_tokens = VocabParallelEmbedding(
-            self.config.vocab_size,
-            self.config.hidden_size,
-            prefix=maybe_prefix(prefix, "embed_tokens"),
-        )
+
+        # if PP disabled then draft will share embed with target
+        if get_pp_group().world_size > 1:
+            self.embed_tokens = VocabParallelEmbedding(
+                self.config.vocab_size,
+                self.config.hidden_size,
+                prefix=maybe_prefix(prefix, "embed_tokens"),
+            )
+
         self.layers = nn.ModuleList([
             LlamaDecoderLayer(
                 self.config,
@@ -109,6 +114,12 @@ def load_weights(self, weights: Iterable[Tuple[str,
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+
+                # if PP disabled then draft will share embed with target
+                if get_pp_group().world_size == 1 and \
+                        "embed_tokens." in name:
+                    continue
+
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
@@ -142,14 +153,12 @@ def forward(
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(
             self,
-            skip_prefixes=(["lm_head."]
-                           if self.config.tie_word_embeddings else None),
+            skip_prefixes=None,
         )
 
         model_weights = {}
         for name, loaded_weight in weights:
             if "lm_head" not in name:
                 name = "model." + name
             model_weights[name] = loaded_weight
-
-        loader.load_weights(model_weights.items())
+        return loader.load_weights(model_weights.items())
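
The construction-time guard and the matching skip in load_weights go together: when pipeline parallelism is off, the draft never allocates embed_tokens, so checkpoint entries for it must be ignored (the shared module is attached later in vllm/v1/spec_decode/eagle.py). A minimal PyTorch sketch of that pattern with toy modules; the class and parameter names here are illustrative, not vLLM APIs:

import torch
import torch.nn as nn

PP_WORLD_SIZE = 1  # illustrative stand-in for get_pp_group().world_size


class DraftModel(nn.Module):
    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        # Only allocate a private embedding when pipeline parallelism is on;
        # otherwise the target model's embedding is shared in later.
        if PP_WORLD_SIZE > 1:
            self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.fc = nn.Linear(hidden_size, hidden_size)

    def load_weights(self, weights: dict) -> None:
        params = dict(self.named_parameters())
        for name, tensor in weights.items():
            # Mirror of the new guard: with PP disabled there is no local
            # embed_tokens parameter, so its checkpoint entries are skipped.
            if PP_WORLD_SIZE == 1 and "embed_tokens." in name:
                continue
            params[name].data.copy_(tensor)


ckpt = {
    "embed_tokens.weight": torch.randn(128, 16),
    "fc.weight": torch.randn(16, 16),
    "fc.bias": torch.randn(16),
}
draft = DraftModel(vocab_size=128, hidden_size=16)
draft.load_weights(ckpt)  # embed_tokens.weight is ignored, not loaded
print(sorted(name for name, _ in draft.named_parameters()))  # no embed_tokens entry

The same guard is applied to llama_eagle3.py in the next file.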

vllm/model_executor/models/llama_eagle3.py

Lines changed: 10 additions & 5 deletions
@@ -8,6 +8,7 @@
 
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
+from vllm.distributed.parallel_state import get_pp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import QKVParallelLinear
@@ -91,11 +92,15 @@ def __init__(
         self.config = vllm_config. \
             speculative_config.draft_model_config.hf_config
         self.vocab_size = self.config.vocab_size
-        self.embed_tokens = VocabParallelEmbedding(
-            self.config.vocab_size,
-            self.config.hidden_size,
-            prefix=maybe_prefix(prefix, "embed_tokens"),
-        )
+
+        # if PP disabled then draft will share embed with target
+        if get_pp_group().world_size > 1:
+            self.embed_tokens = VocabParallelEmbedding(
+                self.config.vocab_size,
+                self.config.hidden_size,
+                prefix=maybe_prefix(prefix, "embed_tokens"),
+            )
+
         self.layers = nn.ModuleList([
             LlamaDecoderLayer(
                 self.config,

vllm/v1/spec_decode/eagle.py

Lines changed: 24 additions & 5 deletions
@@ -5,6 +5,7 @@
 from vllm.attention.layer import Attention
 from vllm.config import (CompilationLevel, VllmConfig,
                          get_layers_from_vllm_config, set_current_vllm_config)
+from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model_loader
@@ -306,12 +307,30 @@ def load_model(self, target_model: nn.Module) -> None:
         self.attn_layer_name = next(iter(draft_attn_layer_names))
         loaded_weights = self.model.load_weights(
             loader.get_all_weights(draft_model_config, self.model))
-        if self.vllm_config.speculative_config.method == "eagle3":
-            if "model.embed_tokens.weight" not in loaded_weights:
-                logger.info(
-                    "Loading EAGLE embedding weights from the target model.")
-                self.model.model.embed_tokens = target_model.model.embed_tokens
+
+        # share embed_tokens with the target model if needed
+        if get_pp_group().world_size == 1:
+            assert "model.embed_tokens.weight" not in loaded_weights, \
+                "For PP = 1, Eagle draft should share embed with target model"
+            logger.info(
+                "The EAGLE head shares the same vocab embedding"
+                " with the target model."
+            )
+            self.model.model.embed_tokens = target_model.model.embed_tokens
         else:
+            assert "model.embed_tokens.weight" in loaded_weights, \
+                "For PP > 1, Eagle draft checkpoint should have its own copy " \
+                "of the model.embed_tokens.weight"
+            logger.info(
+                "Since PP > 1, the EAGLE head loaded its own vocab embedding"
+                " weights instead of sharing them with the target model."
+            )
+
+        # share lm_head with the target model if needed
+        # some model definitions do not define lm_head explicitly
+        # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
+        if self.vllm_config.speculative_config.method != "eagle3" and \
+                hasattr(target_model, "lm_head"):
             logger.info("Loading EAGLE LM head weights from the target model.")
             self.model.lm_head = target_model.lm_head
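
The net effect of the assignments above (embed_tokens whenever PP = 1, plus lm_head for non-eagle3 methods) is plain module aliasing: the draft ends up holding references to the target's modules rather than its own copies, which is where the memory saving comes from. A standalone PyTorch illustration with toy modules, not the actual vLLM classes:

import torch.nn as nn

hidden, vocab = 16, 128

# Toy stand-ins for the target model and the EAGLE draft head.
target = nn.ModuleDict({
    "embed_tokens": nn.Embedding(vocab, hidden),
    "lm_head": nn.Linear(hidden, vocab, bias=False),
})
draft = nn.ModuleDict({
    "lm_head": nn.Linear(hidden, vocab, bias=False),
})

# Share the target's modules instead of keeping separate copies.
draft["embed_tokens"] = target["embed_tokens"]
draft["lm_head"] = target["lm_head"]

# The parameters alias the same storage, so no extra memory is held and any
# update to the target's weights is immediately visible to the draft.
assert draft["embed_tokens"].weight.data_ptr() == target["embed_tokens"].weight.data_ptr()
assert draft["lm_head"].weight.data_ptr() == target["lm_head"].weight.data_ptr()
print("modules shared:", draft["embed_tokens"].weight.shape, draft["lm_head"].weight.shape)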
