
Commit 5975e98

Integrate upstream logits processors (#290)
At first it wasn't obvious whether it would be easy to integrate the changes of PR vllm-project/vllm#16728, so initially I opened a PR that copies into vllm-spyre the sampler files as they were before that PR. It turns out to be easier than expected, because the sampler code is not compiled to the AIU; only the model forward pass is. Currently the MinP processor keeps one tensor for the CPU and one for the device. Since only the model forward runs on the AIU, both tensors end up on the CPU, which means there is an unnecessary copy from one to the other, but the result is still correct. A future upstream PR, vllm-project/vllm#19912, will generalize the logits processors to other sampling parameters.

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: Joe Runde <joe@joerun.de>
1 parent 24c503a commit 5975e98
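
Note: the following is a minimal sketch (not part of the commit) of how the upstream built-in logits processors are driven. It uses the same names that appear in the test diff below (BatchUpdateBuilder, init_builtin_logitsprocs, logitsprocs.all, update_state); exact signatures may differ across vllm versions.

import torch
from vllm.sampling_params import SamplingParams
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
                                             init_builtin_logitsprocs)

device = torch.device("cpu")  # on Spyre the sampling step stays on the CPU

# One set of built-in processors (min_p, logit_bias, min_tokens, ...) per batch.
logitsprocs = init_builtin_logitsprocs(pin_memory_available=False,
                                       max_num_reqs=8,
                                       device=device)

# Record added requests as (dense batch index, params, output token ids) ...
builder = BatchUpdateBuilder()
builder.added.append((0, SamplingParams(min_p=0.1), []))

# ... then push the accumulated batch update into every processor.
batch_update = builder.get_and_reset(1)
for proc in logitsprocs.all:
    proc.update_state(batch_update)

SamplingMetadata then carries the logitsprocs container instead of per-parameter fields such as min_p, min_tokens, and logit_bias, as the test diff below shows.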

11 files changed: 76 additions, 806 deletions

tests/v1/worker/test_spyre_input_batch.py

Lines changed: 18 additions & 15 deletions
@@ -7,8 +7,10 @@
 import torch
 from vllm.sampling_params import SamplingParams
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
+from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
+                                             init_builtin_logitsprocs)
+from vllm.v1.sample.metadata import SamplingMetadata
 
-from vllm_spyre.v1.sample.metadata import SamplingMetadata
 from vllm_spyre.v1.worker.spyre_input_batch import (CachedRequestState,
                                                     InputBatch)
 
@@ -56,20 +58,27 @@ def _construct_expected_sampling_metadata(
     repetition_penalties = [1.0 for _ in range(num_reqs)]
     top_k = [0 for _ in range(num_reqs)]
     top_p = [0.0 for _ in range(num_reqs)]
-    min_p = [0.0 for _ in range(num_reqs)]
     temperature = [0.0 for _ in range(num_reqs)]
-    min_tokens = {}
-    logit_bias = [None] * num_reqs
     allowed_token_ids_mask = torch.zeros(num_reqs,
                                          VOCAB_SIZE,
                                          dtype=torch.bool,
                                          device=device)
 
+    batch_update_builder = BatchUpdateBuilder()
+    logitsprocs = init_builtin_logitsprocs(pin_memory_available=False,
+                                           max_num_reqs=len(reqs) + 1,
+                                           device=device)
+
     bad_words_token_ids = {}
     for req in reqs:
         if req.req_id not in req_ids_retained:
             continue
         index_in_input_batch = input_batch.req_id_to_dense_index(req.req_id)
+
+        params = req.sampling_params
+        batch_update_builder.added.append(
+            (index_in_input_batch, params, req.output_token_ids))
+
         output_token_ids[index_in_input_batch] = req.output_token_ids
         prompt_token_ids[index_in_input_batch] = req.prompt_token_ids
         presence_penalties[
@@ -80,19 +89,18 @@ def _construct_expected_sampling_metadata(
             req.sampling_params.repetition_penalty)
         top_k[index_in_input_batch] = req.sampling_params.top_k
         top_p[index_in_input_batch] = req.sampling_params.top_p
-        min_p[index_in_input_batch] = req.sampling_params.min_p
         temperature[index_in_input_batch] = req.sampling_params.temperature
-        min_tokens[index_in_input_batch] = (
-            req.sampling_params.min_tokens,
-            req.sampling_params.all_stop_token_ids)
-        logit_bias[index_in_input_batch] = req.sampling_params.logit_bias
         if req.sampling_params.allowed_token_ids:
             allowed_token_ids_mask[index_in_input_batch][
                 req.sampling_params.allowed_token_ids] = True
         if req.sampling_params.bad_words_token_ids:
             bad_words_token_ids[
                 index_in_input_batch] = req.sampling_params.bad_words_token_ids
 
+    batch_update = batch_update_builder.get_and_reset(num_reqs)
+    for logit_proc in logitsprocs.all:
+        logit_proc.update_state(batch_update)
+
     return SamplingMetadata(
         temperature=torch.tensor(temperature, dtype=torch.float,
                                  device=device),
@@ -102,8 +110,6 @@ def _construct_expected_sampling_metadata(
             top_p, dtype=torch.float, device=device),
         top_k=None if all(x == 0 for x in top_k) else torch.tensor(
             top_k, dtype=torch.int, device=device),
-        min_p=None if all(x == 0.0 for x in min_p) else torch.tensor(
-            min_p, dtype=torch.float, device=device),
         generators={},
         max_num_logprobs=0,
         prompt_token_ids=make_tensor_with_pad(
@@ -122,13 +128,12 @@ def _construct_expected_sampling_metadata(
             dtype=torch.float,
             device=device),
         output_token_ids=output_token_ids,
-        min_tokens=min_tokens,
         no_penalties=(all(x == 0 for x in presence_penalties)
                       and all(x == 0 for x in frequency_penalties)
                       and all(x == 1 for x in repetition_penalties)),
-        logit_bias=logit_bias,
         allowed_token_ids_mask=allowed_token_ids_mask,
         bad_words_token_ids=bad_words_token_ids,
+        logitsprocs=logitsprocs,
     )
 
 
@@ -196,10 +201,8 @@ def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool:
         sampling_metadata.prompt_token_ids)
     assert (expected_sampling_metadata.output_token_ids ==
             sampling_metadata.output_token_ids)
-    assert expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens
     assert expected_sampling_metadata.no_penalties == \
         sampling_metadata.no_penalties
-    assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias
     if sampling_metadata.allowed_token_ids_mask:
         assert torch.allclose(
             expected_sampling_metadata.allowed_token_ids_mask,

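To illustrate the commit message's remark about the MinP processor keeping one tensor for the CPU and one for the device, here is a hypothetical sketch of that pattern (not the upstream implementation): per-request min_p values are staged in a CPU tensor and copied into a device tensor before filtering the logits. On Spyre both tensors live on the CPU, so the copy is redundant but harmless.

import torch

class MinPLikeProcessor:
    """Hypothetical sketch of the CPU/device dual-tensor pattern."""

    def __init__(self, max_num_reqs: int, device: torch.device):
        self.min_p_cpu = torch.zeros(max_num_reqs, device="cpu")
        self.min_p_device = torch.zeros(max_num_reqs, device=device)

    def set_min_p(self, index: int, min_p: float) -> None:
        self.min_p_cpu[index] = min_p

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        num_reqs = logits.shape[0]
        # On Spyre device == "cpu", so this copy is the unnecessary one
        # mentioned in the commit message; the result is still correct.
        self.min_p_device[:num_reqs].copy_(self.min_p_cpu[:num_reqs])
        probs = torch.softmax(logits, dim=-1)
        top_prob = probs.max(dim=-1, keepdim=True).values
        threshold = top_prob * self.min_p_device[:num_reqs].unsqueeze(-1)
        return logits.masked_fill(probs < threshold, float("-inf"))
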
vllm_spyre/model_executor/model_loader/spyre.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,13 @@
 import torch._inductor.config
 import torch.distributed as dist
 import torch.nn as nn
-import vllm.envs as envs
 from fms.models import get_model
 from transformers import PretrainedConfig
 from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
 from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.model_loader.weight_utils import (
     download_weights_from_hf)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -32,14 +31,6 @@
 logger = init_logger(__name__)
 
 
-def get_sampler() -> torch.nn.Module:
-    if envs.VLLM_USE_V1:
-        # Lazy import: the v1 package isn't distributed
-        from vllm_spyre.v1.sample.sampler import Sampler as V1Sampler
-        return V1Sampler()
-    return Sampler()
-
-
 @dataclass
 class SpyreAttentionMetadata:
     slot_mapping: torch.Tensor = None

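With the local helper removed, the model loader presumably obtains the sampler from vllm directly; a minimal sketch of the replacement, assuming the call site is unchanged apart from the import:

from vllm.model_executor.layers.sampler import get_sampler

# The upstream helper is assumed to cover the VLLM_USE_V1 branch that the
# deleted local get_sampler() handled, so vllm-spyre no longer needs it.
sampler = get_sampler()
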
vllm_spyre/v1/sample/__init__.py

Whitespace-only changes.

vllm_spyre/v1/sample/metadata.py

Lines changed: 0 additions & 51 deletions
This file was deleted.

vllm_spyre/v1/sample/ops/__init__.py

Whitespace-only changes.

vllm_spyre/v1/sample/ops/bad_words.py

Lines changed: 0 additions & 40 deletions
This file was deleted.

vllm_spyre/v1/sample/ops/penalties.py

Lines changed: 0 additions & 59 deletions
This file was deleted.
