#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/e2e/conftest.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from itertools import cycle
from typing import List, Optional, Sequence, Tuple, Union

import pytest
import torch
from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import PromptLogprobs, SampleLogprobs

from ...model_utils import (TokensTextLogprobs,
                            TokensTextLogprobsPromptLogprobs,
                            check_logprobs_close, check_outputs_equal)
from ...utils import RemoteOpenAIServer

PROMPTS = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
| 40 | + "San Francisco is know for its", |
| 41 | + "Facebook was created in 2004 by", |
| 42 | + "Curious George is a", |
| 43 | + "Python 3.11 brings improvements to its", |
| 44 | +] |
| 45 | + |
| 46 | + |
| 47 | +@pytest.fixture |
| 48 | +def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, |
| 49 | + test_llm_kwargs, seed): |
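    """Yield a factory that builds an ``LLM`` from the layered kwargs.

    The inner ``generate()`` generator merges ``common_llm_kwargs``,
    ``per_test_common_llm_kwargs`` and ``test_llm_kwargs``, constructs the
    engine, optionally re-seeds the RNG, yields the ``LLM`` to the test, and
    finally tears the engine and distributed state down.
    """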

    def generate():
        kwargs = {
            **common_llm_kwargs,
            **per_test_common_llm_kwargs,
            **test_llm_kwargs,
        }

        llm = LLM(**kwargs)

        if seed is not None:
            set_random_seed(seed)

        yield llm

        del llm
        cleanup_dist_env_and_memory()

    return generate


def maybe_assert_ngram_worker(llm):
    # Verify the proposer worker is ngram if ngram is specified.
    if (llm.llm_engine.speculative_config is not None
            and llm.llm_engine.speculative_config.ngram_prompt_lookup_max > 0):
        from vllm.spec_decode.ngram_worker import NGramWorker
        assert isinstance(
            llm.llm_engine.model_executor.driver_worker.proposer_worker,
            NGramWorker)


def get_output_from_llm_generator(
        llm_generator, prompts,
        sampling_params) -> Tuple[List[str], List[List[int]], float]:
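    """Run ``prompts`` through each ``LLM`` yielded by ``llm_generator``.

    Returns the generated texts, the generated token ids, and the draft
    acceptance rate read from the prometheus stat logger (-1.0 when stat
    logging is unavailable).
    """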
    tokens: List[str] = []
    token_ids: List[List[int]] = []
    acceptance_rate: float = -1.0
    for llm in llm_generator():
        maybe_assert_ngram_worker(llm)

        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)

        token_ids = [output.outputs[0].token_ids for output in outputs]
        tokens = [output.outputs[0].text for output in outputs]

        # Fetch acceptance rate if logging is enabled.
        if stat_loggers := getattr(llm.llm_engine, "stat_loggers", None):
            stat_logger = stat_loggers["prometheus"]
            acceptance_rate = (stat_logger.metrics.
                               gauge_spec_decode_draft_acceptance_rate.labels(
                                   **stat_logger.labels)._value.get())
        del llm

    return tokens, token_ids, acceptance_rate


def check_logprobs_correctness(
    spec_outputs: Sequence[Union[TokensTextLogprobs,
                                 TokensTextLogprobsPromptLogprobs]],
    baseline_outputs: Sequence[Union[TokensTextLogprobs,
                                     TokensTextLogprobsPromptLogprobs]],
    disable_logprobs: bool = False,
):
    """Compare sampled and prompt logprobs between baseline and spec decoding
    """
    if not disable_logprobs:
        return check_logprobs_close(
            outputs_0_lst=baseline_outputs,
            outputs_1_lst=spec_outputs,
            name_0="org",
            name_1="sd",
        )

    # Check correctness when disable_logprobs == True
    for spec_output, baseline_output in zip(spec_outputs, baseline_outputs):
        # Check generated token logprobs.
        spec_logprobs = spec_output[2]
        baseline_logprobs = baseline_output[2]
        _check_logprobs_when_output_disabled(spec_logprobs,
                                             baseline_logprobs,
                                             is_prompt_logprobs=False)

        # Check prompt logprobs too, if they exist
        if len(baseline_output) == 4:
            assert len(spec_output) == 4
            spec_prompt_logprobs = spec_output[3]
            baseline_prompt_logprobs = baseline_output[3]
            _check_logprobs_when_output_disabled(spec_prompt_logprobs,
                                                 baseline_prompt_logprobs,
                                                 is_prompt_logprobs=True)


def _check_logprobs_when_output_disabled(
        spec_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
        baseline_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
        is_prompt_logprobs: bool = False,
):
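    """Check the degenerate logprobs emitted when logprob output is disabled.

    With logprobs disabled, spec decoding still returns exactly one logprob
    per position, carrying dummy rank/score values; only the token id is
    required to appear in the corresponding baseline logprobs.
    """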
    # Prompt logprobs are optional
    if is_prompt_logprobs and baseline_logprobs is None:
        assert spec_logprobs is None
        return

    assert spec_logprobs is not None
    assert baseline_logprobs is not None
    assert len(spec_logprobs) == len(baseline_logprobs)

    # For each generated position of the sequence.
    for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
            zip(spec_logprobs, baseline_logprobs)):

        # First prompt logprob is expected to be None
        if is_prompt_logprobs and baseline_pos_logprobs is None:
            assert spec_pos_logprobs is None
            assert pos == 0
            continue

        assert spec_pos_logprobs is not None
        assert baseline_pos_logprobs is not None

        # When logprobs are disabled, a single logprob is still returned with
        # dummy values for the score and rank, but its token id should match
        # the baseline model's output.
        assert len(spec_pos_logprobs) == 1
        (spec_pos_logprob_token_id,
         spec_pos_logprob) = next(iter(spec_pos_logprobs.items()))
        assert spec_pos_logprob.rank == -1
        assert spec_pos_logprob.logprob == 0.0
        if isinstance(spec_pos_logprob_token_id, torch.Tensor):
            spec_pos_logprob_token_id = spec_pos_logprob_token_id.item()
        assert spec_pos_logprob_token_id in baseline_pos_logprobs


def run_equality_correctness_test(
        vllm_runner,
        common_llm_kwargs,
        per_test_common_llm_kwargs,
        baseline_llm_kwargs,
        test_llm_kwargs,
        batch_size: int,
        max_output_len: int,
        seed: Optional[int] = 0,
        temperature: float = 0.0,
        disable_seed: bool = False,
        ignore_eos: bool = True,
        ensure_all_accepted: bool = False,
        expected_acceptance_rate: Optional[float] = None,
        logprobs: Optional[int] = None,
        prompt_logprobs: Optional[int] = None,
        disable_logprobs: bool = False):
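    """Assert that the baseline and speculative-decoding configs agree.

    Runs the same prompts through both configurations with identical sampling
    parameters, asserts that the generated tokens and text match exactly
    (greedy equality when ``temperature`` is 0.0), compares logprobs when
    requested, and optionally checks the reported draft acceptance rate.
    """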

    org_args = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **baseline_llm_kwargs,
    }

    sd_args = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **test_llm_kwargs,
    }

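    # Cycle through the canned PROMPTS until batch_size prompts are collected.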
    prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]

    if disable_seed:
        seed = None

    sampling_params = SamplingParams(temperature=temperature,
                                     max_tokens=max_output_len,
                                     seed=seed,
                                     ignore_eos=ignore_eos,
                                     logprobs=logprobs,
                                     prompt_logprobs=prompt_logprobs)

    with vllm_runner(**org_args) as vllm_model:
        org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

    with vllm_runner(**sd_args) as vllm_model:
        if ensure_all_accepted or expected_acceptance_rate is not None:
            # Force the local log interval to be negative so that every
            # metric update is captured.
            stat_logger = vllm_model.model.llm_engine.stat_loggers[
                'prometheus']
            stat_logger.local_interval = -100

        sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

    if ensure_all_accepted or expected_acceptance_rate is not None:
        acceptance_rate = (stat_logger.metrics.
                           gauge_spec_decode_draft_acceptance_rate.labels(
                               **stat_logger.labels)._value.get())

    if ensure_all_accepted:
        # FIXME: CI fails to log the acceptance rate even though it works
        # locally, so the strict check is disabled for now.
        # assert acceptance_rate == 1.0
        pass

    if expected_acceptance_rate is not None:
        assert acceptance_rate >= expected_acceptance_rate - 1e-2

    # Only pass token entries, not the logprobs
    check_outputs_equal(outputs_0_lst=[out[0:2] for out in org_outputs],
                        outputs_1_lst=[out[0:2] for out in sd_outputs],
                        name_0="org",
                        name_1="sd")

    # Check logprobs if requested
    if logprobs is not None or prompt_logprobs is not None:
        check_logprobs_correctness(sd_outputs, org_outputs, disable_logprobs)
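

# Illustrative sketch only (never called by the test suite): one way a
# spec-decode test could drive run_equality_correctness_test. The model name
# and every kwarg below are placeholders, and the exact speculative-decoding
# kwargs depend on the vLLM version in use.
def _example_run_equality_correctness_usage(vllm_runner):
    common_llm_kwargs = {
        # Placeholder small model; any lightweight model works for a sketch.
        "model_name": "JackFram/llama-68m",
        # Skip graph capture so the engine starts quickly in a smoke test.
        "enforce_eager": True,
    }
    per_test_common_llm_kwargs = {}
    baseline_llm_kwargs = {}
    # Hypothetical ngram speculative-decoding settings, mirroring the
    # ngram_prompt_lookup_max attribute checked in maybe_assert_ngram_worker.
    test_llm_kwargs = {
        "speculative_model": "[ngram]",
        "num_speculative_tokens": 5,
        "ngram_prompt_lookup_max": 3,
    }
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size=4,
                                  max_output_len=32,
                                  temperature=0.0)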


def run_equality_correctness_test_tp(model,
                                     common_llm_kwargs,
                                     per_test_common_llm_kwargs,
                                     baseline_llm_kwargs,
                                     test_llm_kwargs,
                                     batch_size: int,
                                     max_output_len: int,
                                     seed: int = 0,
                                     temperature: float = 0.0,
                                     logprobs: Optional[int] = None):
    """Helper that compares the outputs of the baseline LLM and the test LLM,
    each served through a RemoteOpenAIServer. It asserts greedy equality,
    i.e. that the outputs are exactly the same when temperature is zero.
    """
    arg1 = common_llm_kwargs + per_test_common_llm_kwargs + baseline_llm_kwargs
    arg2 = common_llm_kwargs + per_test_common_llm_kwargs + test_llm_kwargs
    env1 = env2 = None

    max_wait_seconds = 240
    results = []

    prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]
    for args, env in ((arg1, env1), (arg2, env2)):
        with RemoteOpenAIServer(model,
                                args,
                                env_dict=env,
                                max_wait_seconds=max_wait_seconds) as server:
            client = server.get_client()

            completion = client.completions.create(model=model,
                                                   prompt=prompts,
                                                   max_tokens=max_output_len,
                                                   seed=seed,
                                                   temperature=temperature,
                                                   logprobs=logprobs)

            results.append({
                "test":
                "seeded_sampling",
                "text": [choice.text for choice in completion.choices],
                "logprobs": [choice.logprobs for choice in completion.choices],
                "finish_reason":
                [choice.finish_reason for choice in completion.choices],
                "usage":
                completion.usage,
            })

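    # The loop above appended one result per configuration: the first half of
    # results comes from the baseline args, the second half from the test args.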
    n = len(results) // 2
    arg1_results = results[:n]
    arg2_results = results[n:]
    # Separate logprobs to avoid asserting exact equality.
    arg1_logprobs = [r.pop("logprobs") for r in arg1_results]
    arg2_logprobs = [r.pop("logprobs") for r in arg2_results]

    for arg1_result, arg2_result in zip(arg1_results, arg2_results):
        assert arg1_result == arg2_result, (
            f"Results for {model=} are not the same with {arg1=} and {arg2=}. "
            f"{arg1_result=} != {arg2_result=}")
    if logprobs:
        for logs1, logs2 in zip(arg1_logprobs, arg2_logprobs):
            for l1, l2 in zip(logs1, logs2):
                assert l1.tokens == l2.tokens