Commit 2c1c8e3

Merge branch 'main' of github.com:character-tech/vllm
2 parents: 7e1c150 + da77215

12 files changed (+160, -7 lines)

vllm/model_executor/models/transformers.py

Lines changed: 17 additions & 0 deletions
@@ -369,6 +369,15 @@ def forward(
 
         return hidden_states
 
+    def compute_additional_head(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> Optional[torch.Tensor]:
+        if get_pp_group().is_last_rank and hasattr(self.model,
+                                                   "compute_additional_head"):
+            return self.model.compute_additional_head(hidden_states)
+        return None
+
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         params_dict = dict(self.named_parameters())
@@ -463,6 +472,14 @@ def compute_logits(
                                        sampling_metadata)
         return logits
 
+    def compute_additional_head(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> Optional[torch.Tensor]:
+        if hasattr(self.model, "compute_additional_head"):
+            return self.model.compute_additional_head(hidden_states)
+        return None
+
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(
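As a rough illustration (not part of this commit): the new wrapper methods only forward to self.model.compute_additional_head when the wrapped module defines it. A minimal sketch of such a module; the class name, sizes, and head are hypothetical.

from typing import Optional

import torch
import torch.nn as nn


class ToyBackboneWithHead(nn.Module):
    """Hypothetical wrapped model exposing the optional hook that the
    vLLM wrapper probes for via hasattr(..., "compute_additional_head")."""

    def __init__(self, hidden_size: int = 16, num_labels: int = 2):
        super().__init__()
        # Extra classifier head applied on top of the decoder hidden states.
        self.extra_head = nn.Linear(hidden_size, num_labels)

    def compute_additional_head(
            self, hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
        # [num_tokens, hidden_size] -> [num_tokens, num_labels]
        return self.extra_head(hidden_states)


# Usage sketch:
model = ToyBackboneWithHead()
print(model.compute_additional_head(torch.randn(3, 16)).shape)  # torch.Size([3, 2])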

vllm/outputs.py

Lines changed: 7 additions & 2 deletions
@@ -12,8 +12,9 @@
 from vllm.lora.request import LoRARequest
 from vllm.multimodal.inputs import MultiModalPlaceholderDict
 from vllm.sampling_params import RequestOutputKind
-from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
-                           SequenceGroup, SequenceGroupBase, SequenceStatus)
+from vllm.sequence import (AdditionalHeads, PromptLogprobs, RequestMetrics,
+                           SampleLogprobs, SequenceGroup, SequenceGroupBase,
+                           SequenceStatus)
 
 
 @dataclass
@@ -28,6 +29,8 @@ class CompletionOutput:
             output text.
         logprobs: The log probabilities of the top probability words at each
            position if the logprobs are requested.
+        additional_heads: The additional head outputs of the generated output
+            text.
         finish_reason: The reason why the sequence is finished.
         stop_reason: The stop string or token id that caused the completion
             to stop, None if the completion finished for some other reason
@@ -43,6 +46,7 @@ class CompletionOutput:
     finish_reason: Optional[str] = None
     stop_reason: Union[int, str, None] = None
     lora_request: Optional[LoRARequest] = None
+    additional_heads: Optional[AdditionalHeads] = None
 
     def finished(self) -> bool:
         return self.finish_reason is not None
@@ -53,6 +57,7 @@ def __repr__(self) -> str:
                 f"token_ids={self.token_ids}, "
                 f"cumulative_logprob={self.cumulative_logprob}, "
                 f"logprobs={self.logprobs}, "
+                f"additional_heads={self.additional_heads}, "
                 f"finish_reason={self.finish_reason}, "
                 f"stop_reason={self.stop_reason})")

vllm/sampling_params.py

Lines changed: 3 additions & 0 deletions
@@ -248,6 +248,9 @@ class SamplingParams(
     bad_words: Optional[list[str]] = None
     _bad_words_token_ids: Optional[list[list[int]]] = None
 
+    # Fields used for additional heads (e.g. classifiers)
+    additional_heads: Optional[bool] = None
+
     @staticmethod
     def from_optional(
         n: Optional[int] = 1,
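Taken together with the new CompletionOutput.additional_heads field above, the flag is presumably used roughly as follows. This is a hedged end-to-end sketch only: it assumes the served model defines compute_additional_head, that this fork's v1 engine propagates the values through the LLM entrypoint, and the model path is a placeholder.

from vllm import LLM, SamplingParams

# Placeholder path; assumes the model implements compute_additional_head.
llm = LLM(model="path/to/model-with-additional-head")

params = SamplingParams(max_tokens=8, additional_heads=True)
outputs = llm.generate(["Hello, world"], params)

completion = outputs[0].outputs[0]
# list[list[float]]: one vector of additional head outputs per generated token.
print(completion.additional_heads)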

vllm/sequence.py

Lines changed: 1 addition & 0 deletions
@@ -53,6 +53,7 @@ class Logprob:
 PromptLogprobs = list[Optional[dict[int, Logprob]]]
 # {token_id -> logprob} for each sequence group.
 SampleLogprobs = list[dict[int, Logprob]]
+AdditionalHeads = list[list[float]]
 
 
 class SequenceStatus(enum.IntEnum):

vllm/v1/core/sched/scheduler.py

Lines changed: 13 additions & 1 deletion
@@ -647,6 +647,8 @@ def update_from_output(
         logprobs = model_runner_output.logprobs
         prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
         num_scheduled_tokens = scheduler_output.num_scheduled_tokens
+        new_additional_head_outputs = \
+            model_runner_output.additional_head_outputs
 
         new_running: list[Request] = []
         outputs: list[EngineCoreOutput] = []
@@ -665,6 +667,13 @@
 
             req_index = model_runner_output.req_id_to_index[req_id]
             generated_token_ids = sampled_token_ids[req_index]
+            if new_additional_head_outputs:
+                head_outputs_list = \
+                    new_additional_head_outputs.additional_head_outputs
+                additional_head_outputs_per_request = \
+                    head_outputs_list[req_index]
+            else:
+                additional_head_outputs_per_request = None
 
             scheduled_spec_token_ids = (
                 scheduler_output.scheduled_spec_decode_tokens.get(req_id))
@@ -751,7 +760,10 @@
                         new_logprobs=new_logprobs,
                         new_prompt_logprobs_tensors=prompt_logprobs_tensors,
                         stop_reason=request.stop_reason,
-                        events=request.take_events()))
+                        events=request.take_events(),
+                        new_additional_head_outputs=
+                        additional_head_outputs_per_request,
+                    ))
             else:
                 # Invariant: EngineCore returns no partial prefill outputs.
                 assert not prompt_logprobs_tensors
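For clarity, the per-request selection above can be pictured with the NamedTuples added in vllm/v1/outputs.py in this same commit; request ids and values below are illustrative.

from vllm.v1.outputs import (AdditionalHeadOutputs,
                             AdditionalHeadOutputsPerRequest)

# One entry per scheduled request, in model-runner batch order.
step_outputs = AdditionalHeadOutputs(additional_head_outputs=[
    AdditionalHeadOutputsPerRequest(additional_head_outputs=[0.9, 0.1]),
    None,  # a request whose head produced nothing this step
])
req_id_to_index = {"req-a": 0, "req-b": 1}

# Same lookup the scheduler performs before building EngineCoreOutput.
per_request = step_outputs.additional_head_outputs[req_id_to_index["req-a"]]
print(per_request.additional_head_outputs)  # [0.9, 0.1]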

vllm/v1/engine/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -12,7 +12,7 @@
 from vllm.multimodal.inputs import PlaceholderRange
 from vllm.sampling_params import SamplingParams
 from vllm.v1.metrics.stats import SchedulerStats
-from vllm.v1.outputs import LogprobsLists, LogprobsTensors
+from vllm.v1.outputs import LogprobsLists, LogprobsTensors, AdditionalHeadOutputsPerRequest
 
 # These are possible values of RequestOutput.finish_reason,
 # so form part of the external API.
@@ -101,6 +101,8 @@ class EngineCoreOutput(
 
     new_logprobs: Optional[LogprobsLists] = None
     new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None
+    new_additional_head_outputs: Optional[
+        AdditionalHeadOutputsPerRequest] = None
 
     finish_reason: Optional[FinishReason] = None
     stop_reason: Union[int, str, None] = None

vllm/v1/engine/additional_heads.py

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+# SPDX-License-Identifier: Apache-2.0
+from dataclasses import dataclass
+
+from vllm.logger import init_logger
+from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest
+
+logger = init_logger(__name__)
+
+
+@dataclass
+class AdditionalHeadsProcessor:
+    """Processor for additional head outputs from the model.
+
+    This class handles storing and managing additional head outputs
+    for generated tokens, similar to how LogprobsProcessor handles logprobs.
+    """
+
+    # Additional head outputs for this request
+    additional_head_outputs: list[list[float]]
+
+    @classmethod
+    def from_new_request(
+        cls,
+        request: EngineCoreRequest,
+    ) -> "AdditionalHeadsProcessor":
+        """Create a new AdditionalHeadsProcessor for a request.
+
+        Args:
+            request: The engine core request to process additional heads for.
+        """
+        return cls(additional_head_outputs=[], )
+
+    def update_from_output(self, output: EngineCoreOutput) -> None:
+        """Update with additional head outputs from EngineCore.
+
+        Args:
+            output: The engine core output containing new additional
+                head outputs.
+        """
+        if output.new_additional_head_outputs is not None:
+            self.additional_head_outputs.append(
+                output.new_additional_head_outputs.additional_head_outputs)
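A small sketch of what update_from_output accumulates for one request; values are made up, and the per-step payload normally arrives wrapped in an EngineCoreOutput rather than being appended by hand.

from vllm.v1.engine.additional_heads import AdditionalHeadsProcessor
from vllm.v1.outputs import AdditionalHeadOutputsPerRequest

processor = AdditionalHeadsProcessor(additional_head_outputs=[])

# Equivalent of one update_from_output() call for a step that produced
# a new token with two additional head values.
new = AdditionalHeadOutputsPerRequest(additional_head_outputs=[0.73, 0.27])
processor.additional_head_outputs.append(new.additional_head_outputs)

print(processor.additional_head_outputs)  # [[0.73, 0.27]]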

vllm/v1/engine/logprobs.py

Lines changed: 1 addition & 1 deletion
@@ -195,4 +195,4 @@ def update_from_output(self, output: EngineCoreOutput) -> None:
         if output.new_logprobs is not None:
             self._update_sample_logprobs(output.new_logprobs)
         if output.new_prompt_logprobs_tensors is not None:
-            self._update_prompt_logprobs(output.new_prompt_logprobs_tensors)
+            self._update_prompt_logprobs(output.new_prompt_logprobs_tensors)

(The removed and added lines are textually identical; this hunk appears to change only trailing whitespace or the final newline of the file.)

vllm/v1/engine/output_processor.py

Lines changed: 16 additions & 1 deletion
@@ -10,6 +10,7 @@
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
+from vllm.v1.engine.additional_heads import AdditionalHeadsProcessor
 from vllm.v1.engine.detokenizer import IncrementalDetokenizer
 from vllm.v1.engine.logprobs import LogprobsProcessor
 from vllm.v1.engine.parallel_sampling import ParentRequest
@@ -81,6 +82,7 @@ def __init__(
         prompt: Optional[str],
         prompt_token_ids: list[int],
         logprobs_processor: LogprobsProcessor,
+        additional_heads_processor: AdditionalHeadsProcessor,
         detokenizer: IncrementalDetokenizer,
         max_tokens_param: Optional[int],
         arrival_time: float,
@@ -96,6 +98,7 @@ def __init__(
         self.prompt_token_ids = prompt_token_ids
         self.prompt_len = len(prompt_token_ids)
         self.logprobs_processor = logprobs_processor
+        self.additional_heads_processor = additional_heads_processor
         self.detokenizer = detokenizer
         self.max_tokens_param = max_tokens_param
         self.is_prefilling = True
@@ -130,6 +133,8 @@ def from_new_request(
                 tokenizer=tokenizer,
                 request=request,
             ),
+            additional_heads_processor=AdditionalHeadsProcessor.
+            from_new_request(request=request, ),
             detokenizer=IncrementalDetokenizer.from_new_request(
                 tokenizer=tokenizer,
                 request=request,
@@ -211,11 +216,18 @@ def _new_completion_output(
         if delta and logprobs:
             logprobs = logprobs[-len(token_ids):]
 
+        # Prepare additional heads, based on delta mode
+        additional_heads = (
+            self.additional_heads_processor.additional_head_outputs or None)
+        if delta and additional_heads:
+            additional_heads = additional_heads[-len(token_ids):]
+
         return CompletionOutput(
             index=self.request_index,
             text=text,
             token_ids=token_ids,
            logprobs=logprobs,
+            additional_heads=additional_heads,
             cumulative_logprob=self.logprobs_processor.cumulative_logprob,
             finish_reason=str(finish_reason) if finished else None,
             stop_reason=stop_reason if finished else None)
@@ -345,8 +357,11 @@ def process_outputs(
                 finish_reason = FinishReason.STOP
                 stop_reason = stop_string
 
-            # 3) Compute sample and prompt logprobs for request, if required.
+            # 3) Compute sample and prompt logprobs as well as additional heads
+            # for request, if required.
             req_state.logprobs_processor.update_from_output(engine_core_output)
+            req_state.additional_heads_processor.update_from_output(
+                engine_core_output)
 
             # 4) Create and handle RequestOutput objects.
             if request_output := req_state.make_request_output(
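The delta handling mirrors the logprobs path: when the request streams deltas, only the head outputs for the newly returned tokens are surfaced. A plain-Python sketch with made-up values:

# Per-token head outputs accumulated for a request so far.
additional_heads = [[0.1, 0.9], [0.3, 0.7], [0.6, 0.4]]

# In delta mode the CompletionOutput carries only the new tokens, so the
# head outputs are sliced to match them.
token_ids = [42]  # one newly generated token in this delta
delta_heads = additional_heads[-len(token_ids):]
print(delta_heads)  # [[0.6, 0.4]]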

vllm/v1/outputs.py

Lines changed: 13 additions & 0 deletions
@@ -6,6 +6,16 @@
 import torch
 
 
+class AdditionalHeadOutputsPerRequest(NamedTuple):
+    # num_additional_head_outputs
+    additional_head_outputs: list[float]
+
+
+class AdditionalHeadOutputs(NamedTuple):
+    # num_generated_tokens x num_additional_head_outputs
+    additional_head_outputs: list[Optional[AdditionalHeadOutputsPerRequest]]
+
+
 class LogprobsLists(NamedTuple):
 
     # [num_reqs, max_num_logprobs + 1]
@@ -100,6 +110,9 @@ class ModelRunnerOutput:
     # [prompt_len]
     prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]]
 
+    # num_reqs x num_generated_tokens x num_additional_head_outputs
+    additional_head_outputs: Optional[AdditionalHeadOutputs] = None
+
 
 EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
     req_ids=[],
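Because the new ModelRunnerOutput field defaults to None, model runners that never produce extra head outputs need no changes; for instance, the module's existing empty sentinel keeps working as before. A tiny check, assuming the module is imported from this branch:

from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT

# The new field is optional and defaults to None.
assert EMPTY_MODEL_RUNNER_OUTPUT.additional_head_outputs is None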
