# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import random
from typing import Any

import pytest

from vllm import LLM, SamplingParams


@pytest.fixture
def test_prompts():
    prompt_types = ["repeat", "sentence"]
    num_prompts = 100
    prompts = []

    random.seed(0)
    random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)

    # Generate a mixed batch of prompts, some of which can be easily
    # predicted by n-gram matching and some which likely cannot.
    for kind in random_prompt_type_choices:
        word_choices = ["test", "temp", "hello", "where"]
        word = random.choice(word_choices)
        if kind == "repeat":
            prompt = f"""
            please repeat the word '{word}' 10 times.
            give no other output than the word at least ten times in a row,
            in lowercase with spaces between each word and without quotes.
            """
        elif kind == "sentence":
            prompt = f"""
            please give a ten-word sentence that
            uses the word {word} at least once.
            give no other output than that simple sentence without quotes.
            """
        else:
            raise ValueError(f"Unknown prompt type: {kind}")
        prompts.append([{"role": "user", "content": prompt}])

    return prompts


@pytest.fixture
def sampling_config():
    return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)


@pytest.fixture
def model_name():
    return "meta-llama/Llama-3.1-8B-Instruct"


def eagle_model_name():
    return "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"


def eagle3_model_name():
    return "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"


def test_ngram_correctness(
    monkeypatch: pytest.MonkeyPatch,
    test_prompts: list[list[dict[str, Any]]],
    sampling_config: SamplingParams,
    model_name: str,
):
    '''
    Compare the outputs of the original LLM and the speculative LLM;
    they should be the same when using ngram speculative decoding.
    '''
    with monkeypatch.context() as m:
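        # VLLM_USE_V1 forces vLLM's V1 engine; this test exercises the V1
        # speculative-decoding path.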
        m.setenv("VLLM_USE_V1", "1")

        ref_llm = LLM(model=model_name, max_model_len=1024)
        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
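        # Drop the reference engine before building the speculative one so
        # that both models do not need to be resident on the GPU at once.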
        del ref_llm

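        # The ngram (prompt-lookup) drafter searches the existing context for
        # matching n-grams of between `prompt_lookup_min` and
        # `prompt_lookup_max` tokens and, on a hit, proposes
        # `num_speculative_tokens` draft tokens for the target model to verify.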
        spec_llm = LLM(
            model=model_name,
            speculative_config={
                "method": "ngram",
                "prompt_lookup_max": 5,
                "prompt_lookup_min": 3,
                "num_speculative_tokens": 3,
            },
            max_model_len=1024,
        )
        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
        matches = 0
        misses = 0
        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
            if ref_output.outputs[0].text == spec_output.outputs[0].text:
                matches += 1
            else:
                misses += 1
                print(f"ref_output: {ref_output.outputs[0].text}")
                print(f"spec_output: {spec_output.outputs[0].text}")

        # Heuristic: expect at least 70% of the prompts to match exactly.
        # Upon failure, inspect the outputs to check for inaccuracy.
        assert matches > int(0.7 * len(ref_outputs))
        del spec_llm


@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
def test_eagle_correctness(
    monkeypatch: pytest.MonkeyPatch,
    test_prompts: list[list[dict[str, Any]]],
    sampling_config: SamplingParams,
    model_name: str,
    use_eagle3: bool,
):
    '''
    Compare the outputs of the original LLM and the speculative LLM;
    they should be the same when using eagle speculative decoding.
    '''
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        ref_llm = LLM(model=model_name, max_model_len=2048)
        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
        del ref_llm

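        # Select the EAGLE or EAGLE-3 draft checkpoint based on the
        # parametrized flag. In speculative_config below, "model" is the
        # draft model, while the outer `model` argument stays the verifying
        # target model.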
        spec_model_name = (eagle3_model_name()
                           if use_eagle3 else eagle_model_name())
        spec_llm = LLM(
            model=model_name,
            trust_remote_code=True,
            speculative_config={
                "method": "eagle3" if use_eagle3 else "eagle",
                "model": spec_model_name,
                "num_speculative_tokens": 3,
                "max_model_len": 2048,
            },
            max_model_len=2048,
        )
        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
        matches = 0
        misses = 0
        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
            if ref_output.outputs[0].text == spec_output.outputs[0].text:
                matches += 1
            else:
                misses += 1
                print(f"ref_output: {ref_output.outputs[0].text}")
                print(f"spec_output: {spec_output.outputs[0].text}")

        # Heuristic: expect at least 66% of the prompts to match exactly.
        # Upon failure, inspect the outputs to check for inaccuracy.
        assert matches > int(0.66 * len(ref_outputs))
        del spec_llm
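
# These tests assume a GPU with enough memory for Llama-3.1-8B-Instruct and
# access to the Hugging Face checkpoints referenced above. Assuming this file
# is saved as test_spec_decode.py, both tests can be run with, e.g.:
#
#   pytest -v -s test_spec_decode.py -k "ngram or eagle"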