# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import deque

import pytest

from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import RequestStatus

from .utils import create_requests, create_scheduler


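# Build a fake ModelRunnerOutput for one scheduled step: every scheduled
# request gets a single dummy sampled token, with no spec tokens or logprobs.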
def _make_model_runner_output(
        scheduler_output: SchedulerOutput, ) -> ModelRunnerOutput:
    req_ids = list(scheduler_output.num_scheduled_tokens.keys())
    return ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index={
            req_id: i
            for i, req_id in enumerate(req_ids)
        },
        sampled_token_ids=[[i] for i in range(len(req_ids))],
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )

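
# For reference, a minimal sketch of the drive loop that each test below
# inlines (this helper is not used by the tests, and its name is ours): keep
# up to two SchedulerOutputs in flight to mimic async scheduling, and apply
# the fake model output of the oldest step before scheduling the next one.
# It relies only on the scheduler APIs exercised in this file.
def _run_until_idle(scheduler) -> None:
    sched_outputs: deque[SchedulerOutput] = deque()
    sched_outputs.append(scheduler.schedule())
    sched_outputs.append(scheduler.schedule())
    while sched_outputs:
        sched_output = sched_outputs.popleft()
        model_runner_output = _make_model_runner_output(sched_output)
        scheduler.update_from_output(sched_output, model_runner_output)
        sched_output = scheduler.schedule()
        if sched_output.num_scheduled_tokens:
            sched_outputs.append(sched_output)
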

@pytest.mark.parametrize("max_tokens", [1, 2, 3, 5])
def test_stop_by_max_tokens(max_tokens: int):
    scheduler = create_scheduler(async_scheduling=True)
    requests = create_requests(num_requests=2, max_tokens=max_tokens)
    req0, req1 = requests

    sched_outputs: deque[SchedulerOutput] = deque()
    scheduler.add_request(req0)
    sched_outputs.append(scheduler.schedule())

    scheduler.add_request(req1)
    sched_outputs.append(scheduler.schedule())

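    # Drain the two in-flight steps: apply the fake model output of the oldest
    # step, then schedule a new step as long as anything is left to run.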
    while sched_outputs:
        sched_output = sched_outputs.popleft()
        model_runner_output = _make_model_runner_output(sched_output)
        scheduler.update_from_output(sched_output, model_runner_output)

        sched_output = scheduler.schedule()
        if sched_output.num_scheduled_tokens:
            sched_outputs.append(sched_output)

    assert scheduler.get_num_unfinished_requests() == 0
    assert req0.num_output_tokens == max_tokens
    assert req1.num_output_tokens == max_tokens


def test_abort():
    scheduler = create_scheduler(async_scheduling=True)
    requests = create_requests(num_requests=10, max_tokens=20)

    for req in requests:
        scheduler.add_request(req)

    sched_outputs: deque[SchedulerOutput] = deque()
    sched_outputs.append(scheduler.schedule())
    sched_outputs.append(scheduler.schedule())

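    # Abort one request per scheduler step, in a scrambled order, while the
    # outputs of earlier steps are still in flight.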
    abort_order = [0, 8, 3, 1, 6, 4, 2, 5, 7, 9]
    abort_order_copy = abort_order.copy()

    def abort_request():
        if not abort_order:
            return
        req = requests[abort_order.pop(0)]
        scheduler.finish_requests(req.request_id,
                                  RequestStatus.FINISHED_ABORTED)

    while sched_outputs:
        # Abort a scheduled request.
        abort_request()
        sched_output = sched_outputs.popleft()
        model_runner_output = _make_model_runner_output(sched_output)
        scheduler.update_from_output(sched_output, model_runner_output)

        sched_output = scheduler.schedule()
        if sched_output.num_scheduled_tokens:
            sched_outputs.append(sched_output)

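    # Each request receives one output token per step until it is aborted, so
    # the request aborted at position k in the abort order ends with exactly
    # k output tokens.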
    for i, req in enumerate(requests):
        assert req.status == RequestStatus.FINISHED_ABORTED
        assert req.num_output_tokens == abort_order_copy.index(i)


def test_preempt():
    scheduler = create_scheduler(async_scheduling=True)
    requests = create_requests(num_requests=10, max_tokens=20)

    for req in requests:
        scheduler.add_request(req)

    sched_outputs: deque[SchedulerOutput] = deque()
    sched_outputs.append(scheduler.schedule())
    sched_outputs.append(scheduler.schedule())

    abort_order = [0, 8, 3, 1, 6, 4, 2, 5, 7, 9]
    abort_order_copy = abort_order.copy()

    def abort_request():
        if not abort_order:
            return
        req = requests[abort_order.pop(0)]
        scheduler.finish_requests(req.request_id,
                                  RequestStatus.FINISHED_ABORTED)

    while sched_outputs:
        # Abort a scheduled request.
        abort_request()
        sched_output = sched_outputs.popleft()
        model_runner_output = _make_model_runner_output(sched_output)
        scheduler.update_from_output(sched_output, model_runner_output)

        sched_output = scheduler.schedule()
        if sched_output.num_scheduled_tokens:
            sched_outputs.append(sched_output)

    for i, req in enumerate(requests):
        assert req.status == RequestStatus.FINISHED_ABORTED
        assert req.num_output_tokens == abort_order_copy.index(i)


def test_prefix_caching_for_prefill_dedup():
    CHUNK_SIZE = 1000
    BLOCK_SIZE = 16
    num_prompt_tokens = 100
    scheduler = create_scheduler(async_scheduling=True,
                                 max_num_batched_tokens=CHUNK_SIZE,
                                 enable_prefix_caching=True,
                                 block_size=BLOCK_SIZE)
    requests = create_requests(num_requests=5,
                               num_tokens=num_prompt_tokens,
                               max_tokens=3,
                               same_prompt=True)
    requests_copy = requests.copy()

    # Two requests with the same prompt.
    req0 = requests.pop(0)
    req1 = requests.pop(0)
    scheduler.add_request(req0)
    scheduler.add_request(req1)

    sched_outputs: deque[SchedulerOutput] = deque()
    sched_output = scheduler.schedule()
    sched_outputs.append(sched_output)
    # Make sure prefix caching de-duplicates the prompts in the same step,
    # so all the blocks except the last are shared between the two requests.
    assert len(sched_output.num_scheduled_tokens) == 2
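    # With 100 prompt tokens and a block size of 16, the shared prefix covers
    # 6 full blocks, i.e. req1 should have at least 6 * 16 = 96 cached tokens.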
    num_blocks = num_prompt_tokens // BLOCK_SIZE
    assert req0.num_cached_tokens == 0
    assert req1.num_cached_tokens >= num_blocks * BLOCK_SIZE

    sched_outputs.append(scheduler.schedule())
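    # Feed the remaining requests (all with the same prompt) into the running
    # batch, one per step, while draining the in-flight outputs.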
    while sched_outputs:
        if requests:
            scheduler.add_request(requests.pop(0))
        sched_output = sched_outputs.popleft()
        model_runner_output = _make_model_runner_output(sched_output)
        scheduler.update_from_output(sched_output, model_runner_output)
        sched_output = scheduler.schedule()
        if sched_output.num_scheduled_tokens:
            sched_outputs.append(sched_output)

    # Other requests scheduled after the two requests should also get
    # a prefix cache hit.
    assert scheduler.get_num_unfinished_requests() == 0
    for req in requests_copy[1:]:
        assert req.num_cached_tokens >= num_blocks * BLOCK_SIZE


def test_prefix_caching_for_multi_turn():
    CHUNK_SIZE = 1000
    BLOCK_SIZE = 16
    num_prompt_tokens = 100
    num_output_tokens = 200
    scheduler = create_scheduler(async_scheduling=True,
                                 max_num_batched_tokens=CHUNK_SIZE,
                                 enable_prefix_caching=True,
                                 block_size=BLOCK_SIZE)
    requests = create_requests(num_requests=5,
                               num_tokens=num_prompt_tokens,
                               max_tokens=num_output_tokens)

    for req in requests:
        scheduler.add_request(req)
    sched_outputs: deque[SchedulerOutput] = deque()
    sched_outputs.append(scheduler.schedule())
    sched_outputs.append(scheduler.schedule())

    # Process the requests.
    while sched_outputs:
        sched_output = sched_outputs.popleft()
        model_runner_output = _make_model_runner_output(sched_output)
        scheduler.update_from_output(sched_output, model_runner_output)
        sched_output = scheduler.schedule()
        if sched_output.num_scheduled_tokens:
            sched_outputs.append(sched_output)
    assert scheduler.get_num_unfinished_requests() == 0

    # Create next-turn requests whose prompts are the previous prompt followed
    # by the full output of the previous turn.
    next_turn_requests = create_requests(
        num_requests=5,
        num_tokens=num_prompt_tokens + num_output_tokens,
        max_tokens=num_output_tokens,
    )
    for i, req in enumerate(next_turn_requests):
        req.prompt_token_ids = (requests[i].prompt_token_ids +
                                list(requests[i].output_token_ids))
    # Schedule the next-turn requests.
    for req in next_turn_requests:
        scheduler.add_request(req)
    sched_outputs.append(scheduler.schedule())

    # Make sure the next-turn requests hit the prefix cache populated by the
    # previous turn.
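    # Each next-turn prompt has 100 + 200 = 300 tokens, so 300 // 16 * 16 =
    # 288 of them should be cached.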
    for req in next_turn_requests:
        assert (req.num_cached_tokens == req.num_prompt_tokens // BLOCK_SIZE *
                BLOCK_SIZE)