diff --git a/examples/offline_inference/logits_processor.py b/examples/offline_inference/logits_processor.py new file mode 100644 index 00000000000..84de136c69d --- /dev/null +++ b/examples/offline_inference/logits_processor.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""This example demonstrates instantiating vLLM with a custom logits processor +class object. + +For testing purposes, a dummy logits processor is employed which, if +`target_token` is passed as a keyword argument to `SamplingParams.extra_args`, +will mask out all tokens except `target_token`. + +A batch is constructed with `temperature=0.0` and 50% of requests specifying +`target_token`, and for these requests - and *only* these requests - we +expect the `target_token` to be decoded in each step, yielding an output +similar to that shown below: + +Generated Outputs: +------------------------------------------------------------ +Prompt: 'Hello, my name is' +Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '" +------------------------------------------------------------ +Prompt: 'The president of the United States is' +Output: " not a racist. He is a racist.\nHe's a racist because he" +------------------------------------------------------------ +Prompt: 'The capital of France is' +Output: ' also also also also also also also also also also also also also + also also also' +------------------------------------------------------------ +Prompt: 'The future of AI is' +Output: ' in the hands of the people.\n\nThe future of AI is in the' +------------------------------------------------------------ +""" + +from vllm import LLM, SamplingParams +from vllm.test_utils import DummyLogitsProcessor + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a mixture of requests which do and don't utilize the dummy logitproc +sampling_params_list = [ + SamplingParams(temperature=0.0, extra_args={"target_token": 128}), + SamplingParams(temperature=0.0), + SamplingParams(temperature=0.0, extra_args={"target_token": 67}), + SamplingParams(temperature=0.0), +] + + +def main(): + # Create an LLM. + llm = LLM( + model="facebook/opt-125m", + logits_processors=[DummyLogitsProcessor], + ) + # Generate texts from the prompts. + # The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params_list) + # Print the outputs. + print("\nGenerated Outputs:\n" + "-" * 60) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}") + print(f"Output: {generated_text!r}") + print("-" * 60) + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/openai_completion_client_logits_processor.py b/examples/online_serving/openai_completion_client_logits_processor.py new file mode 100644 index 00000000000..df6e4e94296 --- /dev/null +++ b/examples/online_serving/openai_completion_client_logits_processor.py @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse + +from openai import OpenAI + +# Modify OpenAI's API key and API base to use vLLM's API server. 
+openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" + + +def parse_args(): + parser = argparse.ArgumentParser(description="Client for vLLM API server") + parser.add_argument( + "--stream", action="store_true", help="Enable streaming response" + ) + return parser.parse_args() + + +def main(args): + client = OpenAI( + # defaults to os.environ.get("OPENAI_API_KEY") + api_key=openai_api_key, + base_url=openai_api_base, + ) + + models = client.models.list() + model = models.data[0].id + + # Completion API + completion = client.completions.create( + model=model, + prompt="A robot may not injure a human being", + echo=False, + n=2, + stream=args.stream, + logprobs=3, + ) + + print("-" * 50) + print("Completion results:") + if args.stream: + for c in completion: + print(c) + else: + print(completion) + print("-" * 50) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/tests/v1/entrypoints/openai/test_multi_api_servers.py b/tests/v1/entrypoints/openai/test_multi_api_servers.py index e84b5e3095d..d2da53c6856 100644 --- a/tests/v1/entrypoints/openai/test_multi_api_servers.py +++ b/tests/v1/entrypoints/openai/test_multi_api_servers.py @@ -2,11 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio import os -import re import openai # use the official client for correctness check import pytest import pytest_asyncio +import regex as re import requests from tests.utils import RemoteOpenAIServer diff --git a/tests/v1/sample/test_logits_processors.py b/tests/v1/sample/logits_processors/test_correctness.py similarity index 97% rename from tests/v1/sample/test_logits_processors.py rename to tests/v1/sample/logits_processors/test_correctness.py index 84ee3b0392b..71e4ac74f5f 100644 --- a/tests/v1/sample/test_logits_processors.py +++ b/tests/v1/sample/logits_processors/test_correctness.py @@ -14,18 +14,20 @@ create_prompt_tokens_tensor, fake_apply_logitsprocs, fake_update_logitsprocs_state) +from vllm.config import VllmConfig from vllm.platforms import current_platform from vllm.sampling_params import SamplingParams from vllm.utils import is_pin_memory_available # yapf: disable -from vllm.v1.sample.logits_processor import (BatchUpdate, BatchUpdateBuilder, +from vllm.v1.sample.logits_processor import (BatchUpdateBuilder, LogitBiasLogitsProcessor, LogitsProcessor, MinPLogitsProcessor, MinTokensLogitsProcessor, - MoveDirectionality, - init_builtin_logitsprocs) + build_logitsprocs) # yapf: enable +from vllm.v1.sample.logits_processor.interface import (BatchUpdate, + MoveDirectionality) from vllm.v1.sample.metadata import SamplingMetadata PIN_MEMORY_AVAILABLE = is_pin_memory_available() @@ -53,6 +55,7 @@ class LogitsProcsRequestParams: workload_index: int logitproc_type: LogitprocType # Logitproc enabled, specified by str id out_tokens: list[int] # Output tokens required for min tokens test + prompt_tokens: list[int] # Dummy prompt tokens placeholder params: SamplingParams # Settings customized for logitproc def __init__(self, workload_index: int, logitproc_type: LogitprocType): @@ -63,6 +66,7 @@ def __init__(self, workload_index: int, logitproc_type: LogitprocType): # don't matter *for these tests* so use 0 as a dummy value self.out_tokens = ([0] * (MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2))) + self.prompt_tokens = [] self.params = _sampling_params_from_logitproc(logitproc_type) def __str__(self): @@ -88,11 +92,11 @@ def _generate_fake_sampling_metadata( vocab_size, size=np.random.randint( 1, 
MAX_NUM_PROMPT_TOKENS)).tolist()) - logitsprocs = init_builtin_logitsprocs( - pin_memory_available=PIN_MEMORY_AVAILABLE, - max_num_reqs=MAX_NUM_REQS + 1, - device=device) - + logitsprocs = build_logitsprocs( + vllm_config=VllmConfig(), + device=device, + is_pin_memory=PIN_MEMORY_AVAILABLE, + ) fake_sampling_metadata = SamplingMetadata( temperature=torch.full((batch_size, ), 0.0), all_greedy=True, @@ -462,7 +466,8 @@ def _generate_fake_step_update( # Replace as many removed requests as possible with added requests add_remove_idx = batch_update_builder.pop_removed() batch_update_builder.added.append( - (add_remove_idx, add_req_params.params, add_req_params.out_tokens)) + (add_remove_idx, add_req_params.params, add_req_params.out_tokens, + add_req_params.prompt_tokens)) persistent_batch[add_remove_idx] = add_req_params # Append remaining added requests to end of batch @@ -470,7 +475,8 @@ def _generate_fake_step_update( num_step_add_replace):(wdx + num_step_add)] batch_update_builder.added.extend([ - (adx + batch_size, add_req_params.params, add_req_params.out_tokens) + (adx + batch_size, add_req_params.params, add_req_params.out_tokens, + add_req_params.prompt_tokens) for adx, add_req_params in enumerate(add_reqs_append) ]) persistent_batch.extend(add_reqs_append) diff --git a/tests/v1/sample/logits_processors/test_custom_cli.py b/tests/v1/sample/logits_processors/test_custom_cli.py new file mode 100644 index 00000000000..a58815f7455 --- /dev/null +++ b/tests/v1/sample/logits_processors/test_custom_cli.py @@ -0,0 +1,118 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import random + +import openai # use the official client for correctness check +import pytest +import pytest_asyncio + +from tests.utils import RemoteOpenAIServer +from tests.v1.sample.logits_processors.utils import (DUMMY_LOGITPROC_ARG, + MAX_TOKENS, MODEL_NAME, + TEMP_GREEDY, prompts) +from vllm.test_utils import DUMMY_LOGITPROC_FQCN + + +@pytest.fixture(scope="module") +def default_server_args(): + return [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--max-num-seqs", + "128", + "--enforce-eager" + ] + + +@pytest.fixture(scope="function", + params=[[], ["--logits-processors", DUMMY_LOGITPROC_FQCN]]) +def server(default_server_args, request, monkeypatch): + """Server cli arg list is parameterized by logitproc source: either fully- + qualified class name (FQCN) specified by `--logits-processors`, or + entrypoint. 
+ + Entrypoint requires no CLI argument, but for testing purposes an + environment variable must be set to mock a dummy logit processor entrypoint. + """ + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1") + if request.param: + # Append FQCN argument + default_server_args = default_server_args + request.param + else: + # Enable mock logit processor entrypoint + monkeypatch.setenv("VLLM_MOCK_LP_ENTRYPOINT", "1") + + with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +api_kwargs = { + "temperature": TEMP_GREEDY, + "max_tokens": MAX_TOKENS, + "logprobs": 0, +} + +extra_body_kwargs = {"vllm_xargs": {DUMMY_LOGITPROC_ARG: 128}} + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_custom_logitsprocs_cli(client: openai.AsyncOpenAI, + model_name: str): + """Test CLI interface for passing custom logitsprocs + + Launch vLLM OpenAI-compatible server with a CLI argument to load a custom + logitproc that has a well-defined behavior (mask out all tokens except one + `target_token`). The logitproc is specified by fully-qualified class name (FQCN). + + Pass in requests, 50% of which pass a `target_token` value + in through `extra_body["vllm_xargs"]`, 50% of which do not. + + Validate that requests which activate the custom logitproc only output + `target_token` + """ + use_dummy_logitproc = True + for prompt in prompts: + # Send vLLM API request; for some requests, activate dummy logitproc + kwargs = { + **api_kwargs, + } + if use_dummy_logitproc: + target_token = random.choice([128, 67]) + # For requests which activate the dummy logitproc, choose one of + # two `target_token` values which are known not to be EOS tokens + kwargs["extra_body"] = { + "vllm_xargs": { + DUMMY_LOGITPROC_ARG: target_token + } + } + batch = await client.completions.create( + model=model_name, + prompt=prompt, + **kwargs, + ) + + if use_dummy_logitproc: + # Only for requests which activate dummy logitproc - validate that + # only `target_token` is generated + choices: openai.types.CompletionChoice = batch.choices + toks = choices[0].logprobs.tokens + if not all([x == toks[0] for x in toks]): + raise AssertionError( + f"Generated {toks} should all be {toks[0]}") + + # Alternate whether to activate dummy logitproc for each request + use_dummy_logitproc = not use_dummy_logitproc diff --git a/tests/v1/sample/logits_processors/test_custom_py.py b/tests/v1/sample/logits_processors/test_custom_py.py new file mode 100644 index 00000000000..b059f23b6a7 --- /dev/null +++ b/tests/v1/sample/logits_processors/test_custom_py.py @@ -0,0 +1,113 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import random + +import pytest + +from tests.v1.sample.logits_processors.utils import (DUMMY_LOGITPROC_ARG, + MAX_TOKENS, MODEL_NAME, + TEMP_GREEDY, + CustomLogitprocSource, + prompts) +from vllm import LLM, SamplingParams +from vllm.test_utils import DUMMY_LOGITPROC_FQCN, DummyLogitsProcessor + +# Create a mixture of requests which do and don't utilize the dummy logitproc +sampling_params_list = [ + SamplingParams(temperature=TEMP_GREEDY, + max_tokens=MAX_TOKENS, + extra_args={DUMMY_LOGITPROC_ARG: 128}), + SamplingParams(temperature=TEMP_GREEDY, max_tokens=MAX_TOKENS), + SamplingParams(temperature=TEMP_GREEDY, + max_tokens=MAX_TOKENS, +
extra_args={DUMMY_LOGITPROC_ARG: 67}), + SamplingParams(temperature=TEMP_GREEDY, max_tokens=MAX_TOKENS), +] + + +def _run_test(kwargs: dict, logitproc_loaded: bool): + # Create a vLLM instance and load custom logitproc + llm_logitproc = LLM( + model=MODEL_NAME, + gpu_memory_utilization=0.1, + **kwargs, + ) + + # Create a reference vLLM instance without custom logitproc + llm_ref = LLM(model=MODEL_NAME, gpu_memory_utilization=0.1) + + # Run inference with logitproc loaded + outputs_logitproc = llm_logitproc.generate(prompts, sampling_params_list) + + # Reference run + outputs_ref = llm_ref.generate(prompts, sampling_params_list) + + # Validate outputs + for bdx, (out_lp, out_ref, params) in enumerate( + zip(outputs_logitproc, outputs_ref, sampling_params_list)): + lp_toks = out_lp.outputs[0].token_ids + if logitproc_loaded and params.extra_args: + # This request exercises custom logitproc; validate that logitproc + # forces `target_token` to be decoded in each step + target_token = params.extra_args[DUMMY_LOGITPROC_ARG] + if not all(x == target_token for x in lp_toks): + raise AssertionError( + f"Request {bdx} generated {lp_toks}, should all be " + f"{target_token}") + else: + # This request does not exercise custom logitproc (or custom + # logitproc is not enabled on this server); validate against + # reference result + ref_toks = out_ref.outputs[0].token_ids + if lp_toks != ref_toks: + raise AssertionError( + f"Request {bdx} generated {lp_toks}, should match " + f"{ref_toks}") + + +@pytest.mark.parametrize("logitproc_source", list(CustomLogitprocSource)) +def test_custom_logitsprocs_py(monkeypatch, + logitproc_source: CustomLogitprocSource): + """Test Python interface for passing custom logitsprocs + + Construct an `LLM` instance which loads a custom logitproc that has a + well-defined behavior (mask out all tokens except one `target_token`) + + Construct a reference `LLM` instance with no custom logitproc + + Pass in a batch of requests, 50% of which pass a `target_token` value + in through `SamplingParams.extra_args`, 50% of which do not.
+ + Validate that + * Requests which do not activate the custom logitproc, yield the same + results for both `LLM` instances + * Requests which activate the custom logitproc, only output `target_token` + + Args: + logitproc_source: what source (entrypoint, fully-qualified class name + (FQCN), or class object) the user pulls the + logitproc from + """ + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1") + random.seed(40) + + # Choose LLM args based on logitproc source + kwargs = {} + if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_NONE: + # Scenario: the server does not load any custom logitproc + # Every other scenario is a different way of loading a custom logitproc + _run_test(kwargs, logitproc_loaded=False) + return + elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT: + # Scenario: vLLM loads a logitproc from a preconfigured entrypoint + # To that end, mock a dummy logitproc entrypoint + monkeypatch.setenv("VLLM_MOCK_LP_ENTRYPOINT", "1") + elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN: + # Scenario: load logitproc based on fully-qualified class name (FQCN) + kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN] + elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_CLASS: + # Scenario: load logitproc from provided class object + kwargs["logits_processors"] = [DummyLogitsProcessor] + + # Test one of the above scenarios where the server loads a custom logitproc + _run_test(kwargs, logitproc_loaded=True) diff --git a/tests/v1/sample/logits_processors/utils.py b/tests/v1/sample/logits_processors/utils.py new file mode 100644 index 00000000000..4d0ce1a2e07 --- /dev/null +++ b/tests/v1/sample/logits_processors/utils.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from enum import Enum, auto + +MODEL_NAME = "facebook/opt-125m" +DUMMY_LOGITPROC_ARG = "target_token" +TEMP_GREEDY = 0.0 +MAX_TOKENS = 20 + + +class CustomLogitprocSource(Enum): + """How to source a logitproc for testing purposes""" + LOGITPROC_SOURCE_NONE = auto() # No custom logitproc + LOGITPROC_SOURCE_ENTRYPOINT = auto() # Via entrypoint + LOGITPROC_SOURCE_FQCN = auto() # Via fully-qualified class name (FQCN) + LOGITPROC_SOURCE_CLASS = auto() # Via provided class object + + +# Sample prompts. 
+prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index 3a4d48afc9d..4e912f98f37 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from vllm.platforms import current_platform -from vllm.v1.sample.logits_processor import LogitsProcessorManager +from vllm.v1.sample.logits_processor import LogitsProcessors from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID, RejectionSampler) @@ -69,7 +69,7 @@ def create_sampling_metadata( output_token_ids=[], allowed_token_ids_mask=None, bad_words_token_ids={}, - logitsprocs=LogitsProcessorManager(), + logitsprocs=LogitsProcessors(), ) diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py index ea10661ea11..d584ee98cb5 100644 --- a/tests/v1/sample/test_sampler.py +++ b/tests/v1/sample/test_sampler.py @@ -9,7 +9,7 @@ from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available, make_tensor_with_pad -from vllm.v1.sample.logits_processor import LogitsProcessorManager +from vllm.v1.sample.logits_processor import LogitsProcessors from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.sampler import Sampler @@ -147,7 +147,7 @@ def _create_default_sampling_metadata( no_penalties=True, allowed_token_ids_mask=None, bad_words_token_ids={}, - logitsprocs=LogitsProcessorManager(), + logitsprocs=LogitsProcessors(), ) return fake_sampling_metadata diff --git a/tests/v1/sample/utils.py b/tests/v1/sample/utils.py index e33efb413d0..1db4bb1a591 100644 --- a/tests/v1/sample/utils.py +++ b/tests/v1/sample/utils.py @@ -10,7 +10,8 @@ from vllm import CompletionOutput from vllm.utils import make_tensor_with_pad -from vllm.v1.sample.logits_processor import BatchUpdate, LogitsProcessor +from vllm.v1.sample.logits_processor import LogitsProcessor +from vllm.v1.sample.logits_processor.interface import BatchUpdate from vllm.v1.sample.metadata import SamplingMetadata diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 943a13debad..d4814beb9b7 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -13,7 +13,7 @@ from vllm.sampling_params import SamplingParams from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.v1.pool.metadata import PoolingMetadata -from vllm.v1.sample.logits_processor import LogitsProcessorManager +from vllm.v1.sample.logits_processor import LogitsProcessors from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch @@ -169,7 +169,7 @@ def _construct_expected_sampling_metadata( and all(x == 1 for x in repetition_penalties)), allowed_token_ids_mask=allowed_token_ids_mask, bad_words_token_ids=bad_words_token_ids, - logitsprocs=LogitsProcessorManager(), + logitsprocs=LogitsProcessors(), ) diff --git a/vllm/config.py b/vllm/config.py index 22f74017136..5d3afdb9815 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -51,6 +51,7 @@ cuda_device_count_stateless, get_cpu_memory, get_open_port, is_torch_equal_or_newer, random_uuid, resolve_obj_by_qualname) +from 
vllm.v1.sample.logits_processor.interface import LogitsProcessor # yapf: enable @@ -4418,6 +4419,8 @@ class VllmConfig: you are using. Contents must be hashable.""" instance_id: str = "" """The ID of the vLLM instance.""" + logits_processors: Optional[list[type[LogitsProcessor]]] = None + """A list of logitproc types to construct for this vLLM instance""" def compute_hash(self) -> str: """ diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ae5eb46fa96..71236f7fea4 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -41,6 +41,8 @@ from vllm.transformers_utils.utils import check_gguf_file from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser, GiB_bytes, get_ip, is_in_ray_actor) +from vllm.v1.sample.logits_processor import (LogitsProcessor, + load_custom_logitsprocs) # yapf: enable @@ -439,6 +441,9 @@ class EngineArgs: enable_multimodal_encoder_data_parallel: bool = \ ParallelConfig.enable_multimodal_encoder_data_parallel + logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = None + """Custom logitproc types""" + async_scheduling: bool = SchedulerConfig.async_scheduling def __post_init__(self): @@ -451,6 +456,14 @@ def __post_init__(self): # Setup plugins from vllm.plugins import load_general_plugins load_general_plugins() + if envs.VLLM_USE_V1: + # Setup V1 custom logitsprocs. Load plugins & any logitsprocs + # specified by FQCN + self.logits_processors = load_custom_logitsprocs( + self.logits_processors) + elif self.logits_processors is not None: + raise ValueError( + "vLLM V0 does not support logits_processors engine args") @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: @@ -542,6 +555,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: **model_kwargs["model_impl"]) model_group.add_argument("--override-attention-dtype", **model_kwargs["override_attention_dtype"]) + model_group.add_argument( + "--logits-processors", + nargs='+', + help="One or more logits processors' fully-qualified class names.") # Model loading arguments load_kwargs = get_kwargs(LoadConfig) @@ -1293,6 +1310,7 @@ def create_engine_config( kv_transfer_config=self.kv_transfer_config, kv_events_config=self.kv_events_config, additional_config=self.additional_config, + logits_processors=self.logits_processors, ) return config diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index e7398ecc23c..db947a0846e 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -52,6 +52,7 @@ get_cached_tokenizer) from vllm.usage.usage_lib import UsageContext from vllm.utils import Counter, Device, deprecate_kwargs, is_list_of +from vllm.v1.sample.logits_processor import LogitsProcessor if TYPE_CHECKING: from vllm.v1.metrics.reader import Metric @@ -194,6 +195,8 @@ def __init__( override_pooler_config: Optional[PoolerConfig] = None, compilation_config: Optional[Union[int, dict[str, Any], CompilationConfig]] = None, + logits_processors: Optional[list[Union[str, + type[LogitsProcessor]]]] = None, **kwargs, ) -> None: """LLM constructor.""" @@ -267,6 +270,7 @@ def __init__( mm_processor_kwargs=mm_processor_kwargs, override_pooler_config=override_pooler_config, compilation_config=compilation_config_instance, + logits_processors=logits_processors, **kwargs, ) diff --git a/vllm/envs.py b/vllm/envs.py index 502978c7685..ebd157b9744 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -141,6 +141,7 @@ VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 120 VLLM_USE_CUDNN_PREFILL: bool = False 
VLLM_LOOPBACK_IP: str = "" + VLLM_MOCK_LP_ENTRYPOINT: bool = False def get_default_cache_root(): @@ -974,6 +975,10 @@ def get_vllm_port() -> Optional[int]: # Used to force set up loopback IP "VLLM_LOOPBACK_IP": lambda: os.getenv("VLLM_LOOPBACK_IP", ""), + + # Controls whether or not to use cudnn prefill + "VLLM_MOCK_LP_ENTRYPOINT": + lambda: bool(int(os.getenv("VLLM_MOCK_LP_ENTRYPOINT", "0"))), } # --8<-- [end:env-vars-definition] diff --git a/vllm/test_utils.py b/vllm/test_utils.py index c6b126d002b..c50aa9a69fe 100644 --- a/vllm/test_utils.py +++ b/vllm/test_utils.py @@ -1,5 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch + +from vllm.config import VllmConfig +from vllm.sampling_params import SamplingParams +from vllm.v1.sample.logits_processor import LogitsProcessor +from vllm.v1.sample.logits_processor.interface import (BatchUpdate, + MoveDirectionality) + MODELS_ON_S3 = [ "adept/fuyu-8b", "ai21labs/AI21-Jamba-1.5-Mini", @@ -128,3 +138,68 @@ ] MODEL_WEIGHTS_S3_BUCKET = "s3://vllm-ci-model-weights" + +DUMMY_LOGITPROC_ENTRYPOINT = "dummy_logitproc" + +DUMMY_LOGITPROC_FQCN = "vllm.test_utils:DummyLogitsProcessor" + + +class DummyLogitsProcessor(LogitsProcessor): + """Fake logit processor to support unit testing and examples""" + + def __init__(self, vllm_config: "VllmConfig", device: torch.device, + is_pin_memory: bool): + self.req_info = {} + + def is_argmax_invariant(self) -> bool: + """Never impacts greedy sampling""" + return False + + def update_state(self, batch_update: Optional[BatchUpdate]): + if not batch_update: + return + + # Process added requests. + for index, params, _, _ in batch_update.added: + if isinstance(params, SamplingParams) and params.extra_args: + target_token = params.extra_args.get("target_token", None) + else: + target_token = None + self.req_info[index] = target_token + + if self.req_info: + # Process removed requests. + for index in batch_update.removed: + self.req_info.pop(index, None) + + # Process moved requests, unidirectional (a->b) and swap (a<->b) + for adx, bdx, direct in batch_update.moved: + if direct == MoveDirectionality.SWAP: + (self.req_info[adx], + self.req_info[bdx]) = (self.req_info[bdx], + self.req_info[adx]) + else: + self.req_info[bdx] = self.req_info[adx] + + def apply(self, logits: torch.Tensor) -> torch.Tensor: + for bdx in range(logits.shape[0]): + if (target_token := self.req_info[bdx]) is not None: + mask = torch.ones_like(logits[bdx, :], dtype=torch.bool) + mask[target_token] = False + logits[bdx, mask] = float('-inf') + + return logits + + +class EntryPoint: + """Fake entrypoint class""" + + def __init__(self): + self.name = DUMMY_LOGITPROC_ENTRYPOINT + self.value = DUMMY_LOGITPROC_FQCN + + def load(self): + return DummyLogitsProcessor + + +entry_points = lambda group: [EntryPoint()] diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index bbcc2a523dc..c7004c5c945 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -2539,7 +2539,7 @@ def direct_register_custom_op( def resolve_obj_by_qualname(qualname: str) -> Any: """ - Resolve an object by its fully qualified name. + Resolve an object by its fully-qualified class name. 
""" module_name, obj_name = qualname.rsplit(".", 1) module = importlib.import_module(module_name) diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py new file mode 100644 index 00000000000..ff144ee0a84 --- /dev/null +++ b/vllm/v1/sample/logits_processor/__init__.py @@ -0,0 +1,177 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import importlib +import itertools +from typing import TYPE_CHECKING, Optional, Union + +import torch + +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.v1.sample.logits_processor.builtin import (LogitBiasLogitsProcessor, + MinPLogitsProcessor, + MinTokensLogitsProcessor) +from vllm.v1.sample.logits_processor.interface import (BatchUpdate, + LogitsProcessor, + MoveDirectionality) +from vllm.v1.sample.logits_processor.state import (BatchUpdateBuilder, + LogitsProcessors) + +if TYPE_CHECKING: + from vllm.config import VllmConfig + +logger = init_logger(__name__) + +LOGITSPROCS_GROUP = 'vllm.logits_processors' + +BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [ + MinTokensLogitsProcessor, + LogitBiasLogitsProcessor, + MinPLogitsProcessor, +] + + +def _load_logitsprocs_plugins() -> list[type[LogitsProcessor]]: + """Load all installed logit processor plugins""" + + import sys + + if envs.VLLM_MOCK_LP_ENTRYPOINT: + from vllm.test_utils import entry_points + elif sys.version_info < (3, 10): + from importlib_metadata import entry_points + else: + from importlib.metadata import entry_points + + installed_logitsprocs_plugins = entry_points(group=LOGITSPROCS_GROUP) + if len(installed_logitsprocs_plugins) == 0: + logger.debug("No logitsprocs plugins installed (group %s).", + LOGITSPROCS_GROUP) + return [] + + # Load logitsprocs plugins + logger.debug("Loading installed logitsprocs plugins (group %s):", + LOGITSPROCS_GROUP) + classes: list[type[LogitsProcessor]] = [] + for entrypoint in installed_logitsprocs_plugins: + try: + logger.debug("- Loading logitproc plugin entrypoint=%s target=%s", + entrypoint.name, entrypoint.value) + classes.append(entrypoint.load()) + except Exception as e: + raise RuntimeError( + f"Failed to load LogitsProcessor plugin {entrypoint}") from e + return classes + + +def _load_logitsprocs_by_fqcns( + logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] +) -> list[type[LogitsProcessor]]: + """Load logit processor types, identifying them by fully-qualified class + names (FQCNs). + + Effectively, a mixed list of logitproc types and FQCN strings is converted + into a list of entirely logitproc types, by loading from the FQCNs. + + FQCN syntax is : i.e. 
x.y.z:CustomLogitProc + + Already-loaded logitproc types must be subclasses of LogitsProcessor + + Args: + logits_processors: Potentially mixed list of logitsprocs types and FQCN + strings for logitproc types + + Returns: + List of logitproc types + + """ + if not logits_processors: + return [] + + logger.debug( + "%s additional custom logits processors specified, checking whether " + "they need to be loaded.", len(logits_processors)) + + classes: list[type[LogitsProcessor]] = [] + for ldx, logitproc in enumerate(logits_processors): + if isinstance(logitproc, type): + logger.debug(" - Already-loaded logit processor: %s", + logitproc.__name__) + if not issubclass(logitproc, LogitsProcessor): + raise ValueError( + f"{logitproc.__name__} is not a subclass of LogitsProcessor" + ) + classes.append(logitproc) + continue + + logger.debug("- Loading logits processor %s", logitproc) + module_path, qualname = logitproc.split(":") + + try: + # Load module + module = importlib.import_module(module_path) + except Exception as e: + raise RuntimeError( + f"Failed to load {ldx}th LogitsProcessor plugin {logitproc}" + ) from e + + # Walk down dotted name to get logitproc class + obj = module + for attr in qualname.split("."): + obj = getattr(obj, attr) + if not isinstance(obj, type): + raise ValueError("Loaded logit processor must be a type.") + if not issubclass(obj, LogitsProcessor): + raise ValueError( + f"{obj.__name__} must be a subclass of LogitsProcessor") + classes.append(obj) + + return classes + + +def load_custom_logitsprocs( + logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]], +) -> list[type[LogitsProcessor]]: + """Load all custom logits processors. + + * First load all installed logitproc plugins + * Second load custom logitsprocs passed in by the user at initialization time + + Args: + logits_processors: potentially mixed list of logitproc types and + logitproc type fully-qualified names (FQCNs) + which need to be loaded + + Returns: + A list of all loaded logitproc types + """ + from vllm.platforms import current_platform + if current_platform.is_tpu(): + # No logitsprocs specified by caller + # TODO(andy) - vLLM V1 on TPU does not support custom logitsprocs + return [] + + return (_load_logitsprocs_plugins() + + _load_logitsprocs_by_fqcns(logits_processors)) + + +def build_logitsprocs(vllm_config: "VllmConfig", device: torch.device, + is_pin_memory: bool) -> LogitsProcessors: + custom_logitsprocs_classes = vllm_config.logits_processors or [] + return LogitsProcessors( + ctor(vllm_config, device, is_pin_memory) for ctor in itertools.chain( + BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes)) + + +__all__ = [ + "LogitsProcessor", + "LogitBiasLogitsProcessor", + "MinPLogitsProcessor", + "MinTokensLogitsProcessor", + "BatchUpdate", + "BatchUpdateBuilder", + "MoveDirectionality", + "LogitsProcessors", + "build_logitsprocs", + "load_custom_logitsprocs", +] diff --git a/vllm/v1/sample/logits_processor.py b/vllm/v1/sample/logits_processor/builtin.py similarity index 55% rename from vllm/v1/sample/logits_processor.py rename to vllm/v1/sample/logits_processor/builtin.py index 3a4c25964e7..5bbfc63320a 100644 --- a/vllm/v1/sample/logits_processor.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -1,238 +1,30 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import dataclasses -from abc import ABC, abstractmethod -from collections.abc import Iterator, Sequence -from dataclasses import dataclass, field -from enum import
Enum -from itertools import chain -from typing import Optional, Union +from collections.abc import Sequence +from typing import TYPE_CHECKING, Optional import torch -from torch._prims_common import DeviceLikeType - -from vllm import PoolingParams, SamplingParams -from vllm.logger import init_logger - -logger = init_logger(__name__) - - -class MoveDirectionality(Enum): - # One-way i1->i2 req move within batch - UNIDIRECTIONAL = 0 - # Two-way i1<->i2 req swap within batch - SWAP = 1 - - -# (index, params, output_tok_ids) tuples for new -# requests added to the batch. -AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int]] -# (index 1, index 2, directionality) tuples representing -# one-way moves or two-way swaps of requests in batch -MovedRequest = tuple[int, int, MoveDirectionality] -# Batch indices of any removed requests. -RemovedRequest = int - - -@dataclasses.dataclass(frozen=True) -class BatchUpdate: - """Persistent batch state change info for logitsprocs""" - batch_size: int # Current num reqs in batch - - # Metadata for requests added to, removed from, and moved - # within the persistent batch. - # - # Note: each added request is represented as - # (index, params, output_tok_ids) - # Key assumption: output_tok_ids is a reference to the - # request's running output tokens list; in this way - # the logits processors always see the latest list of - # generated tokens - removed: Sequence[RemovedRequest] - moved: Sequence[MovedRequest] - added: Sequence[AddedRequest] - - -class BatchUpdateBuilder: - """Helps track persistent batch state changes and build - a batch update data structure for logitsprocs - - Assumptions: - * All information about requests removed from persistent batch - during a step is aggregated in self._removed through calls to - self.removed_append() at the beginning of a step. This must happen - before the first time that self.removed, self.pop_removed() - or self.peek_removed() are invoked in a given step - * After the first time that self.removed, self.pop_removed() - or self.peek_removed() are read in a step, no new removals - are registered using self.removed_append() - * Elements of self._removed are never directly modified, added or - removed (i.e. modification is only via self.removed_append() and - self.pop_removed()) - - Guarantees under above assumptions: - * self.removed is always sorted in descending order - * self.pop_removed() and self.peek_removed() both return - the lowest removed request index in the current step - """ - - _removed: list[RemovedRequest] - _is_removed_sorted: bool - moved: list[MovedRequest] - added: list[AddedRequest] - - def __init__( - self, - removed: Optional[list[RemovedRequest]] = None, - moved: Optional[list[MovedRequest]] = None, - added: Optional[list[AddedRequest]] = None, - ) -> None: - self._removed = removed or [] - self.moved = moved or [] - self.added = added or [] - self._is_removed_sorted = False - - def _ensure_removed_sorted(self) -> None: - """Sort removed request indices in - descending order. - - Idempotent after first call in a - given step, until reset. - """ - if not self._is_removed_sorted: - self._removed.sort(reverse=True) - self._is_removed_sorted = True - - @property - def removed(self) -> list[RemovedRequest]: - """Removed request indices sorted in - descending order""" - self._ensure_removed_sorted() - return self._removed - - def removed_append(self, index: int) -> None: - """Register the removal of a request from - the persistent batch. 
- - Must not be called after the first time - self.removed, self.pop_removed() or - self.peek_removed() are invoked. - - Args: - index: request index - """ - if self._is_removed_sorted: - raise RuntimeError("Cannot register new removed request after" - " self.removed has been read.") - self._removed.append(index) - - def has_removed(self) -> bool: - return bool(self._removed) - - def peek_removed(self) -> Optional[int]: - """Return lowest removed request index""" - if self.has_removed(): - self._ensure_removed_sorted() - return self._removed[-1] - return None - - def pop_removed(self) -> Optional[int]: - """Pop lowest removed request index""" - if self.has_removed(): - self._ensure_removed_sorted() - return self._removed.pop() - return None - - def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]: - """Generate a logitsprocs batch update data structure - and reset internal batch update builder state. - - Args: - batch_size: current persistent batch size - - Returns: - Frozen logitsprocs batch update instance; `None` if no updates - """ - # Reset removal-sorting logic - self._is_removed_sorted = False - if not any((self._removed, self.moved, self.added)): - # No update; short-circuit - return None - # Build batch state update - batch_update = BatchUpdate( - batch_size=batch_size, - removed=self._removed, - moved=self.moved, - added=self.added, - ) - # Reset removed/moved/added update lists - self._removed = [] - self.moved = [] - self.added = [] - return batch_update - - -class LogitsProcessor(ABC): - - @abstractmethod - def apply(self, logits: torch.Tensor) -> torch.Tensor: - raise NotImplementedError - @abstractmethod - def is_argmax_invariant(self) -> bool: - """True if logits processor has no impact on the - argmax computation in greedy sampling. - NOTE: may or may not have the same value for all - instances of a given LogitsProcessor subclass, - depending on subclass implementation. - TODO(andy): won't be utilized until logits - processors are user-extensible - """ - raise NotImplementedError - - @abstractmethod - def update_state( - self, - batch_update: Optional[BatchUpdate], - ) -> None: - """Called when there are new output tokens, prior - to each forward pass. - - Args: - batch_update is non-None iff there have been - changes to the batch makeup. 
- """ - raise NotImplementedError - - -@dataclass -class LogitsProcessorManager: - """Encapsulates initialized logitsproc objects.""" - argmax_invariant: list[LogitsProcessor] = field( - default_factory=list) # argmax-invariant logitsprocs - non_argmax_invariant: list[LogitsProcessor] = field( - default_factory=list) # non-argmax-invariant logitsprocs - - @property - def all(self) -> Iterator[LogitsProcessor]: - """Iterator over all logits processors.""" - return chain(self.argmax_invariant, self.non_argmax_invariant) - - -###### ----- Built-in LogitsProcessor impls below here +from vllm import SamplingParams +from vllm.v1.sample.logits_processor.interface import (BatchUpdate, + LogitsProcessor, + MoveDirectionality) + +if TYPE_CHECKING: + from vllm.config import VllmConfig class MinPLogitsProcessor(LogitsProcessor): - def __init__(self, max_num_reqs: int, pin_memory: bool, - device: DeviceLikeType): - super().__init__() + def __init__(self, vllm_config: "VllmConfig", device: torch.device, + is_pin_memory: bool): + max_num_reqs = vllm_config.scheduler_config.max_num_seqs self.min_p_count: int = 0 self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ), dtype=torch.float32, device="cpu", - pin_memory=pin_memory) + pin_memory=is_pin_memory) self.min_p_cpu = self.min_p_cpu_tensor.numpy() self.use_double_tensor = torch.device("cpu") != torch.device(device) @@ -260,7 +52,7 @@ def update_state(self, batch_update: Optional[BatchUpdate]): needs_update = False # Process added requests. - for index, params, _ in batch_update.added: + for index, params, _, _ in batch_update.added: min_p = params.min_p if isinstance(params, SamplingParams) else 0.0 if self.min_p_cpu[index] != min_p: needs_update = True @@ -316,11 +108,10 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: class LogitBiasLogitsProcessor(LogitsProcessor): - def __init__(self, pin_memory: bool, device: torch.device): - super().__init__() - self.biases: dict[int, dict[int, float]] = {} + def __init__(self, _, device: torch.device, is_pin_memory: bool): self.device = device - self.pin_memory = pin_memory + self.pin_memory = is_pin_memory + self.biases: dict[int, dict[int, float]] = {} self.bias_tensor: torch.Tensor = torch.tensor(()) self.logits_slice = (self._device_tensor([], torch.int32), @@ -337,7 +128,7 @@ def update_state(self, batch_update: Optional[BatchUpdate]): # Process added requests. needs_update = bool(batch_update.added) - for index, params, _ in batch_update.added: + for index, params, _, _ in batch_update.added: if isinstance(params, SamplingParams) and (lb := params.logit_bias): self.biases[index] = lb @@ -395,12 +186,12 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: class MinTokensLogitsProcessor(LogitsProcessor): - def __init__(self, pin_memory: bool, device: torch.device): + def __init__(self, vllm_config: "VllmConfig", device: torch.device, + is_pin_memory: bool): # index -> (min_toks, output_token_ids, stop_token_ids) - super().__init__() - self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {} self.device = device - self.pin_memory = pin_memory + self.pin_memory = is_pin_memory + self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {} # (req_idx_tensor,eos_tok_id_tensor) self.logits_slice: tuple[torch.Tensor, @@ -420,7 +211,7 @@ def update_state(self, batch_update: Optional[BatchUpdate]): if batch_update: # Process added requests. 
needs_update |= bool(batch_update.added) - for index, params, output_tok_ids in batch_update.added: + for index, params, output_tok_ids, _ in batch_update.added: if (isinstance(params, SamplingParams) and (min_tokens := params.min_tokens) and len(output_tok_ids) < min_tokens): @@ -491,35 +282,3 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: # Inhibit EOS token for requests which have not reached min length logits[self.logits_slice] = -float("inf") return logits - - -def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int, - device: torch.device) -> LogitsProcessorManager: - """Construct 'builtin' vLLM logitsprocs which the engine - loads by default. - - Args: - pin_memory_available: pinned memory is available for use - for use by logitsproc - max_num_reqs: ceiling on request count in persistent batch - device: inference device - - Returns: - Data structure encapsulating loaded logitsprocs - """ - min_tokens_logitproc = MinTokensLogitsProcessor( - pin_memory=pin_memory_available, device=device) - logit_bias_logitproc = LogitBiasLogitsProcessor( - pin_memory=pin_memory_available, device=device) - min_p_logitproc = MinPLogitsProcessor( - pin_memory=pin_memory_available, - device=device, - # +1 for temporary swap space - max_num_reqs=max_num_reqs + 1) - return LogitsProcessorManager( - non_argmax_invariant=[ - min_tokens_logitproc, - logit_bias_logitproc, - ], - argmax_invariant=[min_p_logitproc], - ) diff --git a/vllm/v1/sample/logits_processor/interface.py b/vllm/v1/sample/logits_processor/interface.py new file mode 100644 index 00000000000..528304b37b7 --- /dev/null +++ b/vllm/v1/sample/logits_processor/interface.py @@ -0,0 +1,89 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod +from collections.abc import Sequence +from dataclasses import dataclass +from enum import Enum +from typing import TYPE_CHECKING, Optional, Union + +import torch + +from vllm import PoolingParams, SamplingParams + +if TYPE_CHECKING: + from vllm.config import VllmConfig + + +class MoveDirectionality(Enum): + # One-way i1->i2 req move within batch + UNIDIRECTIONAL = 0 + # Two-way i1<->i2 req swap within batch + SWAP = 1 + + +# (index, params, output_tok_ids, prompt_tok_ids) tuples for new +# requests added to the batch. +AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int], + list[int]] + +# (index 1, index 2, directionality) tuples representing +# one-way moves or two-way swaps of requests in batch +MovedRequest = tuple[int, int, MoveDirectionality] + +# Batch indices of any removed requests. +RemovedRequest = int + + +@dataclass(frozen=True) +class BatchUpdate: + """Persistent batch state change info for logitsprocs""" + batch_size: int # Current num reqs in batch + + # Metadata for requests added to, removed from, and moved + # within the persistent batch. 
+ # + # Note: each added request is represented as + # (index, params, output_tok_ids, prompt_tok_ids) + # Key assumption: output_tok_ids is a reference to the + # request's running output tokens list; in this way + # the logits processors always see the latest list of + # generated tokens + removed: Sequence[RemovedRequest] + moved: Sequence[MovedRequest] + added: Sequence[AddedRequest] + + +class LogitsProcessor(ABC): + + @abstractmethod + def __init__(self, vllm_config: "VllmConfig", device: torch.device, + is_pin_memory: bool) -> None: + raise NotImplementedError + + @abstractmethod + def apply(self, logits: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + @abstractmethod + def is_argmax_invariant(self) -> bool: + """True if logits processor has no impact on the + argmax computation in greedy sampling. + NOTE: may or may not have the same value for all + instances of a given LogitsProcessor subclass, + depending on subclass implementation. + """ + raise NotImplementedError + + @abstractmethod + def update_state( + self, + batch_update: Optional["BatchUpdate"], + ) -> None: + """Called when there are new output tokens, prior + to each forward pass. + + Args: + batch_update: non-None iff there have been + changes to the batch makeup. + """ + raise NotImplementedError diff --git a/vllm/v1/sample/logits_processor/state.py b/vllm/v1/sample/logits_processor/state.py new file mode 100644 index 00000000000..abb4d296b6e --- /dev/null +++ b/vllm/v1/sample/logits_processor/state.py @@ -0,0 +1,151 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterator +from dataclasses import field +from itertools import chain +from typing import TYPE_CHECKING, Optional + +from vllm.v1.sample.logits_processor.interface import (AddedRequest, + BatchUpdate, + MovedRequest, + RemovedRequest) + +if TYPE_CHECKING: + from vllm.v1.sample.logits_processor.interface import LogitsProcessor + + +class BatchUpdateBuilder: + """Helps track persistent batch state changes and build + a batch update data structure for logitsprocs + Assumptions: + * All information about requests removed from persistent batch + during a step is aggregated in self._removed through calls to + self.removed_append() at the beginning of a step. This must happen + before the first time that self.removed, self.pop_removed() + or self.peek_removed() are invoked in a given step + * After the first time that self.removed, self.pop_removed() + or self.peek_removed() are read in a step, no new removals + are registered using self.removed_append() + * Elements of self._removed are never directly modified, added or + removed (i.e. modification is only via self.removed_append() and + self.pop_removed()) + Guarantees under above assumptions: + * self.removed is always sorted in descending order + * self.pop_removed() and self.peek_removed() both return + the lowest removed request index in the current step + """ + + _removed: list[RemovedRequest] + _is_removed_sorted: bool + moved: list[MovedRequest] + added: list[AddedRequest] + + def __init__( + self, + removed: Optional[list[RemovedRequest]] = None, + moved: Optional[list[MovedRequest]] = None, + added: Optional[list[AddedRequest]] = None, + ) -> None: + self._removed = removed or [] + self.moved = moved or [] + self.added = added or [] + self._is_removed_sorted = False + + def _ensure_removed_sorted(self) -> None: + """Sort removed request indices in + descending order.
+ Idempotent after first call in a + given step, until reset. + """ + if not self._is_removed_sorted: + self._removed.sort(reverse=True) + self._is_removed_sorted = True + + @property + def removed(self) -> list[RemovedRequest]: + """Removed request indices sorted in + descending order""" + self._ensure_removed_sorted() + return self._removed + + def removed_append(self, index: int) -> None: + """Register the removal of a request from + the persistent batch. + + Must not be called after the first time + self.removed, self.pop_removed() or + self.peek_removed() are invoked. + Args: + index: request index + """ + if self._is_removed_sorted: + raise RuntimeError("Cannot register new removed request after" + " self.removed has been read.") + self._removed.append(index) + + def has_removed(self) -> bool: + return bool(self._removed) + + def peek_removed(self) -> Optional[int]: + """Return lowest removed request index""" + if self.has_removed(): + self._ensure_removed_sorted() + return self._removed[-1] + return None + + def pop_removed(self) -> Optional[int]: + """Pop lowest removed request index""" + if self.has_removed(): + self._ensure_removed_sorted() + return self._removed.pop() + return None + + def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]: + """Generate a logitsprocs batch update data structure + and reset internal batch update builder state. + Args: + batch_size: current persistent batch size + + Returns: + Frozen logitsprocs batch update instance; `None` if no updates + """ + # Reset removal-sorting logic + self._is_removed_sorted = False + if not any((self._removed, self.moved, self.added)): + # No update; short-circuit + return None + # Build batch state update + batch_update = BatchUpdate( + batch_size=batch_size, + removed=self._removed, + moved=self.moved, + added=self.added, + ) + # Reset removed/moved/added update lists + self._removed = [] + self.moved = [] + self.added = [] + return batch_update + + +class LogitsProcessors: + """Encapsulates initialized logitsproc objects.""" + argmax_invariant: list["LogitsProcessor"] = field( + default_factory=list, init=False) # argmax-invariant logitsprocs + non_argmax_invariant: list["LogitsProcessor"] = field( + default_factory=list, init=False) # non-argmax-invariant logitsprocs + + def __init__( + self, + logitsprocs: Optional[Iterator["LogitsProcessor"]] = None) -> None: + self.argmax_invariant = [] + self.non_argmax_invariant = [] + if logitsprocs: + for logitproc in logitsprocs: + (self.argmax_invariant if logitproc.is_argmax_invariant() else + self.non_argmax_invariant).append(logitproc) + + @property + def all(self) -> Iterator["LogitsProcessor"]: + """Iterator over all logits processors.""" + return chain(self.argmax_invariant, self.non_argmax_invariant) diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index 1189b12f307..9d6a87cea3d 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -6,7 +6,7 @@ import torch -from vllm.v1.sample.logits_processor import LogitsProcessorManager +from vllm.v1.sample.logits_processor import LogitsProcessors @dataclass @@ -40,4 +40,4 @@ class SamplingMetadata: bad_words_token_ids: dict[int, list[list[int]]] # Loaded logits processors - logitsprocs: LogitsProcessorManager + logitsprocs: LogitsProcessors diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 1a79d72be0a..3f9cbe89243 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -16,8 +16,8 @@ from vllm.v1.outputs 
import LogprobsTensors from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.logits_processor import (BatchUpdateBuilder, - MoveDirectionality, - init_builtin_logitsprocs) + LogitsProcessors) +from vllm.v1.sample.logits_processor.interface import MoveDirectionality from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.utils import is_spec_decode_unsupported from vllm.v1.utils import copy_slice @@ -69,6 +69,7 @@ def __init__( pin_memory: bool, vocab_size: int, block_sizes: list[int], # The block_size of each kv cache group + logitsprocs: Optional[LogitsProcessors] = None, is_spec_decode: bool = False, logits_processing_needs_token_ids: bool = False, ): @@ -215,14 +216,6 @@ def __init__( # updates. Should reset each step. self.batch_update_builder = BatchUpdateBuilder() - # Define logits processors. - # TODO(andy): logits processor list should be extensible via engine - # constructor argument; for now the list is fixed. - self.logitsprocs = init_builtin_logitsprocs( - pin_memory_available=pin_memory, - max_num_reqs=max_num_reqs + 1, - device=device) - # TODO convert this to LogitsProcessor self.has_allowed_token_ids: set[str] = set() # NOTE(lufang): In the mask tensor, if the corresponding token allowed, @@ -235,6 +228,10 @@ def __init__( self.req_output_token_ids: list[Optional[list[int]]] = [] + # Build logits processors. If specified by user, load custom + # logitsprocs constructors. + self.logitsprocs = logitsprocs or LogitsProcessors(None) + # This is updated each time the batch constituents change. self.sampling_metadata = self._make_sampling_metadata() @@ -260,7 +257,8 @@ def _register_add_request(self, request: "CachedRequestState") -> int: params = (request.sampling_params if request.sampling_params else request.pooling_params) self.batch_update_builder.added.append( - (req_index, params, request.output_token_ids)) + (req_index, params, request.output_token_ids, + request.prompt_token_ids)) return req_index def add_request( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 29f519393e4..2a1500009f1 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -53,6 +53,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ModelRunnerOutput) from vllm.v1.pool.metadata import PoolingMetadata +from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.sampler import Sampler @@ -63,7 +64,6 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin -from ..sample.logits_processor import LogitsProcessorManager from .utils import (bind_kv_cache, gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, sanity_check_mm_encoder_outputs, scatter_mm_placeholders) @@ -209,6 +209,8 @@ def __init__( vocab_size=self.model_config.get_vocab_size(), block_sizes=[self.cache_config.block_size], is_spec_decode=bool(self.vllm_config.speculative_config), + logitsprocs=build_logitsprocs(self.vllm_config, self.device, + self.pin_memory), ) self.use_cuda_graph = ( @@ -2067,7 +2069,7 @@ def _dummy_sampler_run( output_token_ids=[[] for _ in range(num_reqs)], allowed_token_ids_mask=None, bad_words_token_ids={}, - logitsprocs=LogitsProcessorManager(), + logitsprocs=LogitsProcessors(), ) try: sampler_output = self.sampler(logits=logits, 
@@ -2363,6 +2365,7 @@ def may_reinitialize_input_batch(self, vocab_size=self.model_config.get_vocab_size(), block_sizes=block_sizes, is_spec_decode=bool(self.vllm_config.speculative_config), + logitsprocs=self.input_batch.logitsprocs, ) def _allocate_kv_cache_tensors(
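Below is a minimal sketch, not part of the diff above, of how a user-defined logits processor could be written against the new `LogitsProcessor` interface and wired into an `LLM` instance. It illustrates one possible use of the `prompt_tok_ids` element added to `AddedRequest`; the class name `BanPromptTokensLogitsProcessor` and its masking behavior are hypothetical, not anything shipped with vLLM.

# Hypothetical sketch assuming the interfaces introduced in the diff above.
from typing import Optional

import torch

from vllm import LLM, SamplingParams
from vllm.v1.sample.logits_processor import LogitsProcessor
from vllm.v1.sample.logits_processor.interface import (BatchUpdate,
                                                       MoveDirectionality)


class BanPromptTokensLogitsProcessor(LogitsProcessor):
    """Illustrative logitproc: never re-emit a token that was in the prompt."""

    def __init__(self, vllm_config, device: torch.device,
                 is_pin_memory: bool):
        self.device = device
        # Per-request set of banned (prompt) token ids, keyed by batch index
        self.banned: dict[int, set[int]] = {}

    def is_argmax_invariant(self) -> bool:
        # Masking tokens can change the argmax, so this is not invariant
        return False

    def update_state(self, batch_update: Optional[BatchUpdate]) -> None:
        if not batch_update:
            return
        # Added requests carry (index, params, output_tok_ids, prompt_tok_ids)
        for index, params, _, prompt_tok_ids in batch_update.added:
            if isinstance(params, SamplingParams):
                self.banned[index] = set(prompt_tok_ids)
            else:
                self.banned[index] = set()
        if self.banned:
            for index in batch_update.removed:
                self.banned.pop(index, None)
            for adx, bdx, direct in batch_update.moved:
                if direct == MoveDirectionality.SWAP:
                    self.banned[adx], self.banned[bdx] = (self.banned[bdx],
                                                          self.banned[adx])
                else:
                    self.banned[bdx] = self.banned[adx]

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        for bdx, banned in self.banned.items():
            if banned and bdx < logits.shape[0]:
                logits[bdx, list(banned)] = float("-inf")
        return logits


if __name__ == "__main__":
    # Pass the class object directly; an FQCN string such as
    # "my_pkg.procs:BanPromptTokensLogitsProcessor" (hypothetical) or a
    # 'vllm.logits_processors' entry point would also work.
    llm = LLM(model="facebook/opt-125m",
              logits_processors=[BanPromptTokensLogitsProcessor])
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(temperature=0.0, max_tokens=16))
    print(outputs[0].outputs[0].text)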