Commit cbd14ed

[Bugfix] Refactor /invocations to be task-agnostic (vllm-project#20764)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent 7bd4c37 commit cbd14ed

9 files changed: +352 -75 lines changed
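
The test files below all follow the same pattern: each task-specific endpoint gains a test asserting that POSTing the identical request body to the task-agnostic /invocations route produces the same response. As a rough sketch of that contract (not part of this commit; assumes a vLLM OpenAI-compatible server at http://localhost:8000 and a hypothetical model name):

    import requests

    base = "http://localhost:8000"
    body = {
        "model": "my-model",  # hypothetical model name
        "messages": [{"role": "user", "content": "what is 1+1?"}],
        "max_completion_tokens": 5,
    }

    # The same payload is accepted by the task-specific endpoint...
    chat = requests.post(f"{base}/v1/chat/completions", json=body)
    chat.raise_for_status()

    # ...and by /invocations, which infers the task from the request body
    # instead of being bound to a single task.
    inv = requests.post(f"{base}/invocations", json=body)
    inv.raise_for_status()

    assert chat.json()["choices"] == inv.json()["choices"]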

tests/entrypoints/openai/test_chat.py (33 additions, 4 deletions)
@@ -1113,10 +1113,7 @@ async def test_http_chat_no_model_name_with_curl(server: RemoteOpenAIServer):
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize("model_name", [MODEL_NAME, ""])
-async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer,
-                                                   model_name: str):
-
+async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer):
     openai_api_key = "EMPTY"
     openai_api_base = f"http://localhost:{server.port}/v1"
 
@@ -1135,3 +1132,35 @@ async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer,
         messages=messages,
     )
     assert response.model == MODEL_NAME
+
+
+@pytest.mark.asyncio
+async def test_invocations(server: RemoteOpenAIServer,
+                           client: openai.AsyncOpenAI):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "what is 1+1?"
+    }]
+
+    request_args = {
+        "model": MODEL_NAME,
+        "messages": messages,
+        "max_completion_tokens": 5,
+        "temperature": 0.0,
+        "logprobs": False,
+    }
+
+    chat_completion = await client.chat.completions.create(**request_args)
+
+    invocation_response = requests.post(server.url_for("invocations"),
+                                        json=request_args)
+    invocation_response.raise_for_status()
+
+    chat_output = chat_completion.model_dump()
+    invocation_output = invocation_response.json()
+
+    assert chat_output.keys() == invocation_output.keys()
+    assert chat_output["choices"] == invocation_output["choices"]

tests/entrypoints/openai/test_classification.py (22 additions, 0 deletions)
@@ -155,3 +155,25 @@ def test_batch_classification_empty_list(server: RemoteOpenAIServer,
     assert output.object == "list"
     assert isinstance(output.data, list)
     assert len(output.data) == 0
+
+
+@pytest.mark.asyncio
+async def test_invocations(server: RemoteOpenAIServer):
+    request_args = {
+        "model": MODEL_NAME,
+        "input": "This product was excellent and exceeded my expectations"
+    }
+
+    classification_response = requests.post(server.url_for("classify"),
+                                            json=request_args)
+    classification_response.raise_for_status()
+
+    invocation_response = requests.post(server.url_for("invocations"),
+                                        json=request_args)
+    invocation_response.raise_for_status()
+
+    classification_output = classification_response.json()
+    invocation_output = invocation_response.json()
+
+    assert classification_output.keys() == invocation_output.keys()
+    assert classification_output["data"] == invocation_output["data"]

tests/entrypoints/openai/test_completion.py (25 additions, 0 deletions)
@@ -11,6 +11,7 @@
 import pytest
 import pytest_asyncio
 import regex as re
+import requests
 # downloading lora to test lora requests
 from huggingface_hub import snapshot_download
 from openai import BadRequestError
@@ -833,3 +834,27 @@ async def test_echo_stream_completion(client: openai.AsyncOpenAI,
         assert content is not None and saying in content
     else:
         assert content is not None and saying not in content
+
+
+@pytest.mark.asyncio
+async def test_invocations(server: RemoteOpenAIServer,
+                           client: openai.AsyncOpenAI):
+    request_args = {
+        "model": MODEL_NAME,
+        "prompt": "Hello, my name is",
+        "max_tokens": 5,
+        "temperature": 0.0,
+        "logprobs": None,
+    }
+
+    completion = await client.completions.create(**request_args)
+
+    invocation_response = requests.post(server.url_for("invocations"),
+                                        json=request_args)
+    invocation_response.raise_for_status()
+
+    completion_output = completion.model_dump()
+    invocation_output = invocation_response.json()
+
+    assert completion_output.keys() == invocation_output.keys()
+    assert completion_output["choices"] == invocation_output["choices"]

tests/entrypoints/openai/test_embedding.py (60 additions, 0 deletions)
@@ -296,3 +296,63 @@ async def test_single_embedding_truncation_invalid(client: openai.AsyncOpenAI,
     assert "error" in response.object
     assert "truncate_prompt_tokens value is greater than max_model_len. "\
         "Please, select a smaller truncation size." in response.message
+
+
+@pytest.mark.asyncio
+async def test_invocations(server: RemoteOpenAIServer,
+                           client: openai.AsyncOpenAI):
+    input_texts = [
+        "The chef prepared a delicious meal.",
+    ]
+
+    request_args = {
+        "model": MODEL_NAME,
+        "input": input_texts,
+        "encoding_format": "float",
+    }
+
+    completion_response = await client.embeddings.create(**request_args)
+
+    invocation_response = requests.post(server.url_for("invocations"),
+                                        json=request_args)
+    invocation_response.raise_for_status()
+
+    completion_output = completion_response.model_dump()
+    invocation_output = invocation_response.json()
+
+    assert completion_output.keys() == invocation_output.keys()
+    assert completion_output["data"] == invocation_output["data"]
+
+
+@pytest.mark.asyncio
+async def test_invocations_conversation(server: RemoteOpenAIServer):
+    messages = [{
+        "role": "user",
+        "content": "The cat sat on the mat.",
+    }, {
+        "role": "assistant",
+        "content": "A feline was resting on a rug.",
+    }, {
+        "role": "user",
+        "content": "Stars twinkle brightly in the night sky.",
+    }]
+
+    request_args = {
+        "model": MODEL_NAME,
+        "messages": messages,
+        "encoding_format": "float",
+    }
+
+    chat_response = requests.post(server.url_for("v1/embeddings"),
+                                  json=request_args)
+    chat_response.raise_for_status()
+
+    invocation_response = requests.post(server.url_for("invocations"),
+                                        json=request_args)
+    invocation_response.raise_for_status()
+
+    chat_output = chat_response.json()
+    invocation_output = invocation_response.json()
+
+    assert chat_output.keys() == invocation_output.keys()
+    assert chat_output["data"] == invocation_output["data"]

tests/entrypoints/openai/test_pooling.py (91 additions, 22 deletions)
@@ -13,23 +13,24 @@
 
 from ...utils import RemoteOpenAIServer
 
-MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach"
+MODEL_NAME = "internlm/internlm2-1_8b-reward"
 DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}"""  # noqa: E501
 
 
 @pytest.fixture(scope="module")
 def server():
     args = [
         "--task",
-        "classify",
+        "reward",
         # use half precision for speed and memory savings in CI environment
         "--dtype",
         "bfloat16",
         "--enforce-eager",
         "--max-model-len",
-        "8192",
+        "512",
         "--chat-template",
         DUMMY_CHAT_TEMPLATE,
+        "--trust-remote-code",
     ]
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -57,10 +58,10 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str):
 
     assert poolings.id is not None
     assert len(poolings.data) == 1
-    assert len(poolings.data[0].data) == 2
+    assert len(poolings.data[0].data) == 8
     assert poolings.usage.completion_tokens == 0
-    assert poolings.usage.prompt_tokens == 7
-    assert poolings.usage.total_tokens == 7
+    assert poolings.usage.prompt_tokens == 8
+    assert poolings.usage.total_tokens == 8
 
     # test using token IDs
     input_tokens = [1, 1, 1, 1, 1]
@@ -77,7 +78,7 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str):
 
     assert poolings.id is not None
     assert len(poolings.data) == 1
-    assert len(poolings.data[0].data) == 2
+    assert len(poolings.data[0].data) == 5
     assert poolings.usage.completion_tokens == 0
     assert poolings.usage.prompt_tokens == 5
     assert poolings.usage.total_tokens == 5
@@ -104,10 +105,10 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str):
 
     assert poolings.id is not None
     assert len(poolings.data) == 3
-    assert len(poolings.data[0].data) == 2
+    assert len(poolings.data[0].data) == 8
     assert poolings.usage.completion_tokens == 0
-    assert poolings.usage.prompt_tokens == 25
-    assert poolings.usage.total_tokens == 25
+    assert poolings.usage.prompt_tokens == 29
+    assert poolings.usage.total_tokens == 29
 
     # test list[list[int]]
     input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
@@ -125,7 +126,7 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str):
 
     assert poolings.id is not None
     assert len(poolings.data) == 4
-    assert len(poolings.data[0].data) == 2
+    assert len(poolings.data[0].data) == 5
     assert poolings.usage.completion_tokens == 0
     assert poolings.usage.prompt_tokens == 17
     assert poolings.usage.total_tokens == 17
@@ -157,7 +158,11 @@ async def test_conversation_pooling(server: RemoteOpenAIServer,
     chat_response.raise_for_status()
     chat_poolings = PoolingResponse.model_validate(chat_response.json())
 
-    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
+    tokenizer = get_tokenizer(
+        tokenizer_name=model_name,
+        tokenizer_mode="fast",
+        trust_remote_code=True,
+    )
     prompt = tokenizer.apply_chat_template(
         messages,
         chat_template=DUMMY_CHAT_TEMPLATE,
@@ -206,6 +211,9 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
     )
     float_response.raise_for_status()
     responses_float = PoolingResponse.model_validate(float_response.json())
+    float_data = [
+        np.array(d.data).squeeze(-1).tolist() for d in responses_float.data
+    ]
 
     base64_response = requests.post(
         server.url_for("pooling"),
@@ -224,11 +232,10 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
             np.frombuffer(base64.b64decode(data.data),
                           dtype="float32").tolist())
 
-    check_embeddings_close(
-        embeddings_0_lst=[d.data for d in responses_float.data],
-        embeddings_1_lst=decoded_responses_base64_data,
-        name_0="float32",
-        name_1="base64")
+    check_embeddings_close(embeddings_0_lst=float_data,
+                           embeddings_1_lst=decoded_responses_base64_data,
+                           name_0="float32",
+                           name_1="base64")
 
     # Default response is float32 decoded from base64 by OpenAI Client
     default_response = requests.post(
@@ -240,9 +247,71 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
     )
     default_response.raise_for_status()
    responses_default = PoolingResponse.model_validate(default_response.json())
+    default_data = [
+        np.array(d.data).squeeze(-1).tolist() for d in responses_default.data
+    ]
+
+    check_embeddings_close(embeddings_0_lst=float_data,
+                           embeddings_1_lst=default_data,
+                           name_0="float32",
+                           name_1="default")
+
+
+@pytest.mark.asyncio
+async def test_invocations(server: RemoteOpenAIServer):
+    input_texts = [
+        "The chef prepared a delicious meal.",
+    ]
+
+    request_args = {
+        "model": MODEL_NAME,
+        "input": input_texts,
+        "encoding_format": "float",
+    }
+
+    completion_response = requests.post(server.url_for("pooling"),
+                                        json=request_args)
+    completion_response.raise_for_status()
+
+    invocation_response = requests.post(server.url_for("invocations"),
+                                        json=request_args)
+    invocation_response.raise_for_status()
+
+    completion_output = completion_response.json()
+    invocation_output = invocation_response.json()
+
+    assert completion_output.keys() == invocation_output.keys()
+    assert completion_output["data"] == invocation_output["data"]
+
+
+@pytest.mark.asyncio
+async def test_invocations_conversation(server: RemoteOpenAIServer):
+    messages = [{
+        "role": "user",
+        "content": "The cat sat on the mat.",
+    }, {
+        "role": "assistant",
+        "content": "A feline was resting on a rug.",
+    }, {
+        "role": "user",
+        "content": "Stars twinkle brightly in the night sky.",
+    }]
+
+    request_args = {
+        "model": MODEL_NAME,
+        "messages": messages,
+        "encoding_format": "float",
+    }
+
+    chat_response = requests.post(server.url_for("pooling"), json=request_args)
+    chat_response.raise_for_status()
+
+    invocation_response = requests.post(server.url_for("invocations"),
+                                        json=request_args)
+    invocation_response.raise_for_status()
+
+    chat_output = chat_response.json()
+    invocation_output = invocation_response.json()
 
-    check_embeddings_close(
-        embeddings_0_lst=[d.data for d in responses_default.data],
-        embeddings_1_lst=[d.data for d in responses_default.data],
-        name_0="float32",
-        name_1="base64")
+    assert chat_output.keys() == invocation_output.keys()
+    assert chat_output["data"] == invocation_output["data"]

tests/entrypoints/openai/test_rerank.py (27 additions, 0 deletions)
@@ -94,3 +94,30 @@ def test_rerank_max_model_len(server: RemoteOpenAIServer, model_name: str):
     # Assert just a small fragments of the response
     assert "Please reduce the length of the input." in \
         rerank_response.text
+
+
+def test_invocations(server: RemoteOpenAIServer):
+    query = "What is the capital of France?"
+    documents = [
+        "The capital of Brazil is Brasilia.", "The capital of France is Paris."
+    ]
+
+    request_args = {
+        "model": MODEL_NAME,
+        "query": query,
+        "documents": documents,
+    }
+
+    rerank_response = requests.post(server.url_for("rerank"),
+                                    json=request_args)
+    rerank_response.raise_for_status()
+
+    invocation_response = requests.post(server.url_for("invocations"),
+                                        json=request_args)
+    invocation_response.raise_for_status()
+
+    rerank_output = rerank_response.json()
+    invocation_output = invocation_response.json()
+
+    assert rerank_output.keys() == invocation_output.keys()
+    assert rerank_output["results"] == invocation_output["results"]
