Skip to content

Commit c072386

Browse files
committed
+test
Signed-off-by: wang.yuqi <noooop@126.com>
1 parent 462b269 commit c072386

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
import pytest
4+
import torch.nn.functional as F
5+
6+
from vllm.config import PoolerConfig
7+
from vllm.multimodal.inputs import torch
8+
9+
10+
@pytest.mark.parametrize(
11+
"model",
12+
[
13+
"jason9693/Qwen2.5-1.5B-apeach",
14+
"papluca/xlm-roberta-base-language-detection"
15+
],
16+
)
17+
@pytest.mark.parametrize("dtype", ["half"])
18+
def test_models(
19+
hf_runner,
20+
vllm_runner,
21+
example_prompts,
22+
model: str,
23+
dtype: str,
24+
) -> None:
25+
26+
with vllm_runner(
27+
model,
28+
max_model_len=512,
29+
dtype=dtype,
30+
override_pooler_config=PoolerConfig(softmax=False)) as vllm_model:
31+
wo_softmax_out = vllm_model.classify(example_prompts)
32+
33+
with vllm_runner(
34+
model,
35+
max_model_len=512,
36+
dtype=dtype,
37+
override_pooler_config=PoolerConfig(softmax=True)) as vllm_model:
38+
w_softmax_out = vllm_model.classify(example_prompts)
39+
40+
for wo_softmax, w_softmax in zip(wo_softmax_out, w_softmax_out):
41+
wo_softmax = torch.tensor(wo_softmax)
42+
w_softmax = torch.tensor(w_softmax)
43+
44+
assert not torch.allclose(
45+
wo_softmax, w_softmax,
46+
atol=1e-2), "override_pooler_config is not working"
47+
assert torch.allclose(F.softmax(wo_softmax, dim=-1), w_softmax,
48+
1e-3 if dtype == "float" else 1e-2)

0 commit comments

Comments
 (0)