
Commit 26f2b7f

fix bug
Signed-off-by: skylee-01 <497627264@qq.com>
Parent: 4711051

3 files changed: 5 additions & 9 deletions


vllm/engine/arg_utils.py

Lines changed: 1 addition & 1 deletion
@@ -1477,7 +1477,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
         speculative_model = self.speculative_config.get("model")
         if speculative_model in ("ngram", "[ngram]"):
             is_ngram_enabled = True
-        if not (is_ngram_enabled or is_eagle_enabled or is_medusa_enabled):
+        if not (is_ngram_enabled or is_eagle_enabled or is_medusa_enabled or is_mlp_speculator_enabled):
             # Other speculative decoding methods are not supported yet.
             _raise_or_fallback(feature_name="Speculative Decoding",
                                recommend_to_remove=False)
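
This hunk adds is_mlp_speculator_enabled to the allow-list so the V1 engine no longer falls back to V0 when MLP-speculator decoding is requested. A minimal, self-contained sketch of the allow-list pattern; how each flag is derived earlier in _is_v1_supported_oracle is an assumption here, not vLLM's actual derivation:

    # Sketch only, not vLLM's implementation: the oracle's boolean
    # allow-list once the MLP-speculator flag is included.
    def spec_decode_supported(method: str) -> bool:
        is_ngram_enabled = method in ("ngram", "[ngram]")
        is_eagle_enabled = method == "eagle"                    # assumed derivation
        is_medusa_enabled = method == "medusa"                  # assumed derivation
        is_mlp_speculator_enabled = method == "mlp_speculator"  # assumed derivation
        return (is_ngram_enabled or is_eagle_enabled
                or is_medusa_enabled or is_mlp_speculator_enabled)

    # Before this commit, "mlp_speculator" fell through to _raise_or_fallback.
    assert spec_decode_supported("mlp_speculator")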

vllm/v1/spec_decode/mlp_speculator.py

Lines changed: 3 additions & 3 deletions
@@ -31,7 +31,8 @@ def __init__(
         self.hidden_size = vllm_config.speculative_config.\
             draft_model_config.get_hidden_size(
             )
-        self.num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
+        self.num_speculative_tokens = vllm_config.speculative_config.\
+            num_speculative_tokens
         self.dtype = vllm_config.model_config.dtype
 
     def propose(
@@ -43,8 +44,7 @@ def propose(
     ) -> torch.Tensor:
         # Generate blocks and compute logits
         draft_tokens = self.model.generate_proposals(input_ids, previous_hidden_states, num_predict_tokens,sampling_metadata)
-        draft_tokens = list(map(lambda x: x[0], zip(*[i.sampled_token_ids.tolist() for i in draft_tokens])))
-        return draft_tokens
+        return list(map(lambda x: x[0], zip(*[i.sampled_token_ids.tolist() for i in draft_tokens])))
 
     def load_model(self, target_model: nn.Module) -> None:
         self.model = get_model(vllm_config=self.vllm_config,
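
The second hunk folds the intermediate draft_tokens variable into the return expression. The zip(*...) / map combination transposes the step-major list of sampler outputs into sequence-major order and keeps the first element of each resulting tuple. A toy run of the idiom with stand-in data (plain lists, not vLLM sampler objects):

    # One inner list per speculative step, one entry per sequence,
    # standing in for i.sampled_token_ids.tolist().
    per_step = [[101, 201], [102, 202], [103, 203]]  # 3 steps, 2 sequences

    transposed = zip(*per_step)                      # sequence-major tuples
    result = list(map(lambda x: x[0], transposed))   # first step per sequence
    print(result)                                    # [101, 201]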

vllm/v1/worker/gpu_model_runner.py

Lines changed: 1 addition & 5 deletions
@@ -1607,16 +1607,13 @@ def propose_draft_token_ids(
             )
         elif self.speculative_config.method == "mlp_speculator":
             assert isinstance(self.drafter, MLPSpeculatorProposer)
-
             is_sample_match = sample_hidden_states.shape[0] == len(
                 sampled_token_ids)
             # Get last token from each sequence
             draft_input_ids = torch.tensor(
-                sampled_token_ids[0] if is_sample_match else
                 [tokens[-1] for tokens in sampled_token_ids],
                 device=sample_hidden_states.device)
-
-            if is_sample_match:
+            if not is_sample_match:
                 # Calculate indices for hidden states
                 indices = []
                 offset = 0
@@ -1629,7 +1626,6 @@ def propose_draft_token_ids(
                 hidden_states = sample_hidden_states[indices]
             else:
                 hidden_states = sample_hidden_states
-
             spec_token_ids = self.drafter.propose(
                 input_ids=draft_input_ids,
                 previous_hidden_states=hidden_states,
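
This is the actual bug fix: the branch condition was inverted. Indices into sample_hidden_states only need to be computed when the sampler produced more rows than there are sequences (is_sample_match is False); when the shapes already match, the hidden states are used as-is. A hedged sketch of the gather the fixed branch performs; the loop body between the two hunks is not shown in the diff, so the offset bookkeeping below is an assumption:

    import torch

    sampled_token_ids = [[5, 7], [9]]           # 2 sequences, 3 sampled rows
    sample_hidden_states = torch.randn(3, 8)    # one hidden row per token

    # Assumed bookkeeping: pick the last token's hidden state per sequence.
    indices, offset = [], 0
    for tokens in sampled_token_ids:
        offset += len(tokens)
        indices.append(offset - 1)
    hidden_states = sample_hidden_states[indices]  # shape [2, 8]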
