
Commit 879d615

test: add openai request regression tests

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>

1 parent 52594bc · commit 879d615

File tree: 1 file changed

Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
# SPDX-License-Identifier: Apache-2.0

# imports for guided decoding tests
from itertools import chain

import openai  # use the official client for correctness check
import pytest
import pytest_asyncio
# downloading lora to test lora requests
from openai.types import Completion

from ..utils import RemoteOpenAIServer

MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"


@pytest.fixture(scope="module")
def default_server_args():
    return [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--max-num-seqs",
        "128",
        "--enforce-eager",
    ]


@pytest.fixture(scope="module")
def server(default_server_args):
    with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture()
async def client(server):
    async with server.get_async_client() as async_client:
        yield async_client


@pytest.mark.asyncio
async def test_multiseq_logprobs_streaming(client: openai.AsyncOpenAI):
    """Edge case request combining multiple functionalities

    https://github.com/vllm-project/vllm/pull/15259
    https://github.com/vllm-project/vllm/pull/16805
    """

    # completions
    stream = await client.completions.create(
        model=MODEL_NAME,
        prompt="1 2 3 4 5",
        max_tokens=3,
        # include usage chunk to make sure the stream is complete
        stream_options={"include_usage": True},
        stream=True,
        n=2,
        logprobs=0,  # include 1-top logprob per generated token
        temperature=1.0)

    n0_chunks: list[Completion] = []
    n1_chunks: list[Completion] = []
    usage_chunk: Completion = None
    async for chunk in stream:
        print(chunk)
        if choices := chunk.choices:
            assert len(choices) == 1, \
                (f"Streamed chunk had {len(choices)} choices, when only 1 was"
                 " expected")
            choice = choices[0]
            if choice.index == 0:
                n0_chunks.append(chunk)
            elif choice.index == 1:
                n1_chunks.append(chunk)
            else:
                raise AssertionError(f"Unexpected choice index {choice.index}")

        elif chunk.usage is not None:
            usage_chunk = chunk

        else:
            raise AssertionError(f"Unexpected chunk {chunk}")

    # check that we got the requested number of tokens
    assert sum(
        len(chunk.choices[0].logprobs.tokens) for chunk in n0_chunks
        if chunk.choices[0].logprobs
    ) == 3, "Streamed response did not have the expected number of tokens."
    assert sum(
        len(chunk.choices[0].logprobs.tokens) for chunk in n1_chunks
        if chunk.choices[0].logprobs
    ) == 3, "Streamed response did not have the expected number of tokens."

    # check 1 logprob per token/chunk
    for chunk in chain(n0_chunks, n1_chunks):
        # a finish chunk may not have any text/logprobs
        # V0 does not
        # V1 does
        choice = chunk.choices[0]
        if choice.logprobs is None:
            assert choice.finish_reason
            assert choice.text == ''
            continue

        assert choice.logprobs.top_logprobs
        for top_logprobs in choice.logprobs.top_logprobs:
            assert len(top_logprobs) == 1

    # requested usage
    assert usage_chunk is not None
    assert usage_chunk.usage.completion_tokens == 6
    assert usage_chunk.usage.prompt_tokens == 9
    assert usage_chunk.usage.total_tokens == 15
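
For readers unfamiliar with the streamed objects the assertions above inspect, below is a rough, illustrative sketch (not part of the commit) of what one streamed chunk looks like when built from the openai client's Completion types. All field values are made-up assumptions; only the structure the test relies on (choices[0].logprobs.tokens / top_logprobs, plus a separate usage-only final chunk) is meant to be accurate.

# Illustrative sketch only -- not part of this commit. It approximates one
# streamed chunk for the n=2, logprobs=0 request above; all values are assumed.
from openai.types import Completion, CompletionChoice
from openai.types.completion_choice import Logprobs

example_chunk = Completion(
    id="cmpl-example",            # assumed id
    created=0,
    model="Qwen/Qwen2.5-1.5B-Instruct",
    object="text_completion",
    choices=[
        CompletionChoice(
            index=1,              # would be routed to n1_chunks by the test
            text=" 6",
            finish_reason="length",
            logprobs=Logprobs(
                tokens=[" 6"],
                token_logprobs=[-0.05],
                top_logprobs=[{" 6": -0.05}],  # one entry since logprobs=0
                text_offset=[0],
            ),
        )
    ],
    usage=None,  # usage arrives only in the final chunk of the stream
)

# The test sums len(chunk.choices[0].logprobs.tokens) per choice index and
# checks that each top_logprobs dict holds exactly one entry.
assert len(example_chunk.choices[0].logprobs.tokens) == 1
assert all(len(tl) == 1 for tl in example_chunk.choices[0].logprobs.top_logprobs)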
