Commit c1c7dbb

[Bugfix][Core] Prevent token lengths exceeding max_model_len in V0 (#19348)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
1 parent: 5cf2dae

2 files changed: +23 -1 lines


tests/entrypoints/llm/test_generate.py

Lines changed: 22 additions & 0 deletions
@@ -25,6 +25,12 @@
 ]
 
 
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines):
+    """We can run both engines for this test."""
+    pass
+
+
 @pytest.fixture(scope="module")
 def llm():
     # pytest caches the fixture so we use weakref.proxy to
@@ -104,3 +110,19 @@ def test_multiple_sampling_params(llm: LLM):
     # sampling_params is None, default params should be applied
     outputs = llm.generate(PROMPTS, sampling_params=None)
     assert len(PROMPTS) == len(outputs)
+
+
+def test_max_model_len():
+    max_model_len = 20
+    llm = LLM(
+        model=MODEL_NAME,
+        max_model_len=max_model_len,
+        gpu_memory_utilization=0.10,
+        enforce_eager=True,  # reduce test time
+    )
+    sampling_params = SamplingParams(max_tokens=max_model_len + 10)
+    outputs = llm.generate(PROMPTS, sampling_params)
+    for output in outputs:
+        num_total_tokens = len(output.prompt_token_ids) + len(
+            output.outputs[0].token_ids)
+        assert num_total_tokens == max_model_len
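
For context on the new test's final assertion: the total per-request token count (prompt plus completion) is expected to stay at or below max_model_len, even when max_tokens alone is larger. Below is a minimal standalone sketch of that arithmetic in plain Python; it is not vLLM code, the helper name is hypothetical, and it assumes generation is not stopped early by an EOS token or a stop string.

def expected_total_tokens(prompt_len: int, max_tokens: int,
                          max_model_len: int) -> int:
    # The engine caps prompt + generated tokens at max_model_len, so the
    # total stays bounded even when max_tokens alone would exceed the limit.
    return min(prompt_len + max_tokens, max_model_len)


# With max_model_len = 20 and max_tokens = 30, any non-empty prompt still
# yields exactly 20 total tokens, which is what the test asserts.
assert expected_total_tokens(prompt_len=5, max_tokens=30, max_model_len=20) == 20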

vllm/engine/output_processor/stop_checker.py

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@ def maybe_stop_sequence(
             return
 
         # Check if the sequence has reached max_model_len.
-        if seq.get_len() > self._get_max_model_len(lora_req):
+        if seq.get_len() >= self._get_max_model_len(lora_req):
             seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
             return
 
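
The one-character change above fixes an off-by-one at the length boundary: with the old strict '>' comparison, a sequence sitting at exactly max_model_len was not marked finished, so one more decode step could push it to max_model_len + 1 tokens. The standalone sketch below (plain Python with a hypothetical helper, not the real StopChecker or Sequence API) spells out the difference at the boundary.

def is_length_capped(seq_len: int, max_model_len: int,
                     inclusive: bool) -> bool:
    # Old V0 behavior used a strict '>' comparison; the fix uses '>='.
    return seq_len >= max_model_len if inclusive else seq_len > max_model_len


max_model_len = 20
# Old check ('>'): a sequence of exactly 20 tokens is not marked finished,
# so the scheduler can run one more step and emit a 21st token.
assert is_length_capped(20, max_model_len, inclusive=False) is False
# New check ('>='): the sequence is finished as soon as it reaches 20 tokens.
assert is_length_capped(20, max_model_len, inclusive=True) is True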
