
Commit 61f4fc5

[Bugfix][v1] Fix step pooler implementation and step pooling usage in v1 (#19956)
Signed-off-by: Isotr0py <2037008807@qq.com>
1 parent 68aaeb3 commit 61f4fc5

14 files changed, +164 −40 lines changed


tests/conftest.py

Lines changed: 11 additions & 7 deletions
@@ -1027,13 +1027,13 @@ def classify(self, prompts: list[str]) -> list[list[float]]:
         req_outputs = self.model.classify(prompts)
         return [req_output.outputs.probs for req_output in req_outputs]
 
-    def encode(self,
-               prompts: list[str],
-               images: Optional[PromptImageInput] = None,
-               videos: Optional[PromptVideoInput] = None,
-               audios: Optional[PromptAudioInput] = None,
-               *args,
-               **kwargs) -> list[list[float]]:
+    def embed(self,
+              prompts: list[str],
+              images: Optional[PromptImageInput] = None,
+              videos: Optional[PromptVideoInput] = None,
+              audios: Optional[PromptAudioInput] = None,
+              *args,
+              **kwargs) -> list[list[float]]:
         inputs = self.get_inputs(prompts,
                                  images=images,
                                  videos=videos,
@@ -1042,6 +1042,10 @@ def encode(self,
         req_outputs = self.model.embed(inputs, *args, **kwargs)
         return [req_output.outputs.embedding for req_output in req_outputs]
 
+    def encode(self, prompts: list[str]) -> list[list[float]]:
+        req_outputs = self.model.encode(prompts)
+        return [req_output.outputs.data for req_output in req_outputs]
+
     def score(
         self,
         text_1: Union[str, list[str]],
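
For context on the rename: `embed()` wraps `LLM.embed()` and returns one embedding vector per prompt (`req_output.outputs.embedding`), while the new `encode()` wraps `LLM.encode()` and returns the raw pooled data (`req_output.outputs.data`), which is what a step/reward pooler such as the PRM tested below produces. A minimal sketch of the underlying `vllm.LLM` calls; the embedding model name is chosen only for illustration and is not part of this commit:

    from vllm import LLM

    # embed(): one dense vector per prompt, read from outputs.embedding
    embedder = LLM(model="BAAI/bge-base-en-v1.5", task="embed")  # example model
    vectors = [o.outputs.embedding for o in embedder.embed(["what is vllm?"])]

    # encode(): raw pooler output per prompt, read from outputs.data
    # (for a process reward model, one score pair per <extra_0> step marker)
    prm = LLM(model="Qwen/Qwen2.5-Math-PRM-7B", max_model_len=1024)
    step_scores = [o.outputs.data for o in prm.encode(["step 1<extra_0>step 2<extra_0>"])]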

tests/model_executor/test_model_load_with_params.py

Lines changed: 6 additions & 6 deletions
@@ -29,8 +29,8 @@ def test_model_loading_with_params(vllm_runner):
                      revision=REVISION,
                      dtype="float16",
                      max_model_len=MAX_MODEL_LEN) as vllm_model:
-        output = vllm_model.encode("Write a short story about a robot that"
-                                   " dreams for the first time.\n")
+        output = vllm_model.embed("Write a short story about a robot that"
+                                  " dreams for the first time.\n")
 
         model_config = vllm_model.model.llm_engine.model_config
         model_tokenizer = vllm_model.model.llm_engine.tokenizer
@@ -67,8 +67,8 @@ def test_roberta_model_loading_with_params(vllm_runner):
                      revision=REVISION_ROBERTA,
                      dtype="float16",
                      max_model_len=MAX_MODEL_LEN) as vllm_model:
-        output = vllm_model.encode("Write a short story about a robot that"
-                                   " dreams for the first time.\n")
+        output = vllm_model.embed("Write a short story about a robot that"
+                                  " dreams for the first time.\n")
 
         model_config = vllm_model.model.llm_engine.model_config
         model_tokenizer = vllm_model.model.llm_engine.tokenizer
@@ -105,8 +105,8 @@ def test_facebook_roberta_model_loading_with_params(vllm_runner):
     with vllm_runner(model_name=model_name,
                      dtype="float16",
                      max_model_len=MAX_MODEL_LEN) as vllm_model:
-        output = vllm_model.encode("Write a short story about a robot that"
-                                   " dreams for the first time.\n")
+        output = vllm_model.embed("Write a short story about a robot that"
+                                  " dreams for the first time.\n")
 
         model_tokenizer = vllm_model.model.llm_engine.tokenizer
         assert model_tokenizer.tokenizer_id == model_name

tests/models/language/pooling/embed_utils.py

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@ def correctness_test_embed_models(hf_runner,
                      task="embed",
                      max_model_len=None,
                      **vllm_extra_kwargs) as vllm_model:
-        vllm_outputs = vllm_model.encode(example_prompts)
+        vllm_outputs = vllm_model.embed(example_prompts)
 
     with hf_runner(
             model_info.name,

tests/models/language/pooling/test_embedding.py

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ def test_models(
                      task="embed",
                      max_model_len=512,
                      **vllm_extra_kwargs) as vllm_model:
-        vllm_outputs = vllm_model.encode(example_prompts)
+        vllm_outputs = vllm_model.embed(example_prompts)
 
     check_embeddings_close(
         embeddings_0_lst=hf_outputs,

tests/models/language/pooling/test_jina.py

Lines changed: 2 additions & 2 deletions
@@ -98,11 +98,11 @@ def test_matryoshka(
 
         if dimensions not in matryoshka_dimensions:
             with pytest.raises(ValueError):
-                vllm_model.encode(
+                vllm_model.embed(
                     example_prompts,
                     pooling_params=PoolingParams(dimensions=dimensions))
         else:
-            vllm_outputs = vllm_model.encode(
+            vllm_outputs = vllm_model.embed(
                 example_prompts,
                 pooling_params=PoolingParams(dimensions=dimensions))
 
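
The matryoshka test above drives per-request output dimensions through `PoolingParams(dimensions=...)`. A minimal sketch of the same call shape against the plain `vllm.LLM` API, assuming a matryoshka-capable embedding model (the model name and flags here are illustrative, not taken from this diff):

    from vllm import LLM, PoolingParams

    llm = LLM(model="jinaai/jina-embeddings-v3", task="embed",
              trust_remote_code=True)  # example model
    out = llm.embed(["a sentence"], pooling_params=PoolingParams(dimensions=64))
    vec = out[0].outputs.embedding  # truncated to 64 dims when 64 is supported
    # requesting an unsupported dimension raises ValueError, as the test asserts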

Lines changed: 104 additions & 0 deletions (new file)
@@ -0,0 +1,104 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+import torch
+import torch.nn.functional as F
+from transformers import AutoModel
+
+from vllm.platforms import current_platform
+
+from ....conftest import HfRunner
+
+
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
+@pytest.fixture
+def math_step_prompts():
+    # ruff: noqa: E501
+    data = {
+        "system":
+        "Please reason step by step, and put your final answer within \\boxed{}. ",
+        "query":
+        "Sue lives in a fun neighborhood. One weekend, the neighbors decided to play a prank on Sue. On Friday morning, the neighbors placed 18 pink plastic flamingos out on Sue's front yard. On Saturday morning, the neighbors took back one third of the flamingos, painted them white, and put these newly painted white flamingos back out on Sue's front yard. Then, on Sunday morning, they added another 18 pink plastic flamingos to the collection. At noon on Sunday, how many more pink plastic flamingos were out than white plastic flamingos?",
+        "response": [
+            "To find out how many more pink plastic flamingos were out than white plastic flamingos at noon on Sunday, we can break down the problem into steps. First, on Friday, the neighbors start with 18 pink plastic flamingos.",
+            "On Saturday, they take back one third of the flamingos. Since there were 18 flamingos, (1/3 \\times 18 = 6) flamingos are taken back. So, they have (18 - 6 = 12) flamingos left in their possession. Then, they paint these 6 flamingos white and put them back out on Sue's front yard. Now, Sue has the original 12 pink flamingos plus the 6 new white ones. Thus, by the end of Saturday, Sue has (12 + 6 = 18) pink flamingos and 6 white flamingos.",
+            "On Sunday, the neighbors add another 18 pink plastic flamingos to Sue's front yard. By the end of Sunday morning, Sue has (18 + 18 = 36) pink flamingos and still 6 white flamingos.",
+            "To find the difference, subtract the number of white flamingos from the number of pink flamingos: (36 - 6 = 30). Therefore, at noon on Sunday, there were 30 more pink plastic flamingos out than white plastic flamingos. The answer is (\\boxed{30}).",
+        ],
+    }
+    answer = "<extra_0>".join(data['response']) + "<extra_0>"
+    prompt = f"<im_start>system\n{data['system']}<im_end>\n<im_start>user\n{data['query']}<im_end>\n<im_start>assistant\n{answer}<im_end><|endoftext|>"
+    return [prompt]
+
+
+def step_reward_patch_hf_model(hf_model: HfRunner):
+
+    # Patch the hf_runner to use the step reward function
+    def make_step_rewards(logits: torch.Tensor,
+                          token_masks: torch.Tensor) -> list[list[float]]:
+        probabilities = F.softmax(logits, dim=-1)
+        probabilities = probabilities * token_masks.unsqueeze(-1)
+
+        all_scores_res: list[list[float]] = []
+        for i in range(probabilities.size(0)):
+            sample = probabilities[i]  # seq_len, num_labels
+            positive_probs = sample[sample != 0].view(-1, 2)
+            non_zero_elements_list = positive_probs.cpu().tolist()
+            all_scores_res.append(non_zero_elements_list)
+        return all_scores_res
+
+    def reward(prompts: list[str]) -> list[list[float]]:
+        input_ids = hf_model.tokenizer(prompts, return_tensors="pt").input_ids
+        input_ids = hf_model.wrap_device(input_ids)
+        outputs = hf_model.model(input_ids=input_ids)
+
+        step_sep_id = hf_model.tokenizer.encode("<extra_0>")[0]
+        token_masks = (input_ids == step_sep_id)
+        return make_step_rewards(outputs[0], token_masks)
+
+    hf_model.reward = reward  # type: ignore[attr-defined]
+
+    return hf_model
+
+
+@pytest.mark.parametrize(
+    "model",
+    [
+        pytest.param("Qwen/Qwen2.5-Math-PRM-7B",
+                     marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
+    ],
+)
+@pytest.mark.parametrize("dtype", ["half"])
+def test_prm_models(
+    hf_runner,
+    vllm_runner,
+    math_step_prompts,
+    model: str,
+    dtype: str,
+    monkeypatch,
+) -> None:
+    if current_platform.is_rocm():
+        # ROCm Triton FA does not currently support sliding window attention
+        # switch to use ROCm CK FA backend
+        monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
+
+    with vllm_runner(model, max_model_len=1024, dtype=dtype) as vllm_model:
+        vllm_outputs = vllm_model.encode(math_step_prompts)
+
+    with hf_runner(model, dtype=dtype, auto_cls=AutoModel) as hf_model:
+        hf_model = step_reward_patch_hf_model(hf_model)
+        hf_outputs = hf_model.reward(math_step_prompts)
+
+    # check logits difference
+    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
+        hf_output = torch.tensor(hf_output)
+        vllm_output = torch.tensor(vllm_output)
+
+        assert torch.allclose(hf_output, vllm_output, 1e-2)
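
A note on the reference helper above: `make_step_rewards` works because only the `<extra_0>` positions survive the mask, so filtering out zeros leaves exactly one `[p_label0, p_label1]` pair per reasoning step. A standalone toy illustration of that masking logic on synthetic tensors (not part of the commit):

    import torch
    import torch.nn.functional as F

    # two reward labels per token; only the <extra_0> positions carry step scores
    logits = torch.tensor([[[0.0, 0.0],     # ordinary token
                            [2.0, 1.0],     # <extra_0> closing step 1
                            [0.0, 0.0],     # ordinary token
                            [-1.0, 3.0]]])  # <extra_0> closing step 2
    token_masks = torch.tensor([[False, True, False, True]])

    probs = F.softmax(logits, dim=-1) * token_masks.unsqueeze(-1)
    per_step = probs[0][probs[0] != 0].view(-1, 2)  # shape: (num_steps, 2)
    print(per_step)  # one probability pair per step; Qwen's PRM usage typically reads column 1 as the step score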

tests/models/multimodal/pooling/test_dse_qwen2_vl.py

Lines changed: 2 additions & 2 deletions
@@ -98,7 +98,7 @@ def _run_test(
                      max_model_len=8192) as vllm_model:
         tokenizer = vllm_model.model.get_tokenizer()
         texts = [
-            # this is necessary because vllm_model.encode will not apply any
+            # this is necessary because vllm_model.embed will not apply any
             # templating to the prompt, and therefore lacks an image_pad
             # token unless one is inserted beforehand (the (28,28) image
             # above is converted to an image pad token by the chat template).
@@ -109,7 +109,7 @@ def _run_test(
             # vllm will replace the pad token with the actual image,
             # which may be a placeholder image, later.
         ]
-        vllm_outputs = vllm_model.encode(texts, images=input_images)
+        vllm_outputs = vllm_model.embed(texts, images=input_images)
 
     hf_outputs = []
     with hf_runner(model,
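
The comment in this hunk is the key detail: `embed()` forwards prompts verbatim, with no chat templating, so a multimodal prompt must already contain the image placeholder tokens. A rough, hypothetical illustration of what such a hand-built prompt can look like for a Qwen2-VL-based embedding model; the token strings follow Qwen2-VL's tokenizer, and the actual text built in this test is not reproduced in the diff:

    # hypothetical prompt text; the test derives its texts from the model's chat template
    text = (
        "<|im_start|>user\n"
        "<|vision_start|><|image_pad|><|vision_end|>"
        "What is shown in this image?<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    # vllm_outputs = vllm_model.embed([text], images=[image])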

tests/models/multimodal/pooling/test_llava_next.py

Lines changed: 1 addition & 1 deletion
@@ -68,7 +68,7 @@ def _run_test(
                      dtype=dtype,
                      max_model_len=4096,
                      enforce_eager=True) as vllm_model:
-        vllm_outputs = vllm_model.encode(input_texts, images=input_images)
+        vllm_outputs = vllm_model.embed(input_texts, images=input_images)
 
     with hf_runner(model, dtype=dtype,
                    auto_cls=AutoModelForImageTextToText) as hf_model:

tests/models/multimodal/pooling/test_phi3v.py

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ def _run_test(
     # will hurt multiprocessing backend with fork method (the default method).
     with vllm_runner(model, task="embed", dtype=dtype,
                      enforce_eager=True) as vllm_model:
-        vllm_outputs = vllm_model.encode(input_texts, images=input_images)
+        vllm_outputs = vllm_model.embed(input_texts, images=input_images)
 
     # use eager mode for hf runner, since phi3_v didn't work with flash_attn
     hf_model_kwargs = {"_attn_implementation": "eager"}

tests/quantization/test_bitsandbytes.py

Lines changed: 1 addition & 1 deletion
@@ -161,7 +161,7 @@ def test_4bit_bnb_embedding_model(
                      dtype=dtype,
                      gpu_memory_utilization=0.5,
                      quantization="bitsandbytes") as vllm_model:
-        vllm_outputs = vllm_model.encode(example_prompts)
+        vllm_outputs = vllm_model.embed(example_prompts)
     check_embeddings_close(
         embeddings_0_lst=hf_outputs,
         embeddings_1_lst=vllm_outputs,
