Commit 89cc061

noooop authored and huydhn committed
[Model][2/N] Automatic conversion of CrossEncoding model (vllm-project#19978)
Signed-off-by: wang.yuqi <noooop@126.com>
1 parent 1407e7d commit 89cc061

File tree

16 files changed: +200, -93 lines changed


docs/models/supported_models.md

Lines changed: 1 addition & 1 deletion
@@ -471,7 +471,7 @@ Specified using `--task classify`.
 | `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | |
 | `GPT2ForSequenceClassification` | GPT2 | `nie3e/sentiment-polish-gpt2-small` | | | ✅︎ |
 If your model is not in the above list, we will try to automatically convert the model using
-[as_classification_model][vllm.model_executor.models.adapters.as_classification_model]. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token.
+[as_seq_cls_model][vllm.model_executor.models.adapters.as_seq_cls_model]. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token.
 
 #### Sentence Pair Scoring
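
For context, a minimal offline sketch of the classification path this doc describes (the model id and the `probs` output field are illustrative assumptions, not taken from this commit):

```python
from vllm import LLM

# Loading with task="classify" lets vLLM wrap the model via as_seq_cls_model
# when no native *ForSequenceClassification implementation is registered.
llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", task="classify")
(output,) = llm.classify("vLLM makes serving classifiers easy.")
print(output.outputs.probs)  # softmaxed per-class probabilities
```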

docs/serving/openai_compatible_server.md

Lines changed: 1 addition & 1 deletion
@@ -426,7 +426,7 @@ Code example: <gh-file:examples/online_serving/openai_pooling_client.py>
 
 Our Classification API directly supports Hugging Face sequence-classification models such as [ai21labs/Jamba-tiny-reward-dev](https://huggingface.co/ai21labs/Jamba-tiny-reward-dev) and [jason9693/Qwen2.5-1.5B-apeach](https://huggingface.co/jason9693/Qwen2.5-1.5B-apeach).
 
-We automatically wrap any other transformer via `as_classification_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities.
+We automatically wrap any other transformer via `as_seq_cls_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities.
 
 Code example: <gh-file:examples/online_serving/openai_classification_client.py>
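
A rough usage sketch of the Classification API referenced above (the launch command, endpoint path, and response field names are assumptions about the existing server, not something this diff adds):

```python
# Assumes a server started with:
#   vllm serve jason9693/Qwen2.5-1.5B-apeach --task classify
import requests

response = requests.post(
    "http://localhost:8000/classify",
    json={"model": "jason9693/Qwen2.5-1.5B-apeach",
          "input": ["vLLM is wonderful!"]},
)
# Each item carries the per-class probabilities from the softmax head.
print(response.json()["data"][0]["probs"])
```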

tests/entrypoints/openai/correctness/test_mteb_score.py

Lines changed: 14 additions & 11 deletions
@@ -6,19 +6,16 @@
 
 # yapf conflicts with isort for this block
 # yapf: disable
-from tests.models.language.pooling.mteb_utils import (MTEB_RERANK_LANGS,
-                                                      MTEB_RERANK_TASKS,
-                                                      MTEB_RERANK_TOL,
-                                                      RerankClientMtebEncoder,
-                                                      ScoreClientMtebEncoder,
-                                                      run_mteb_rerank)
+from tests.models.language.pooling.mteb_utils import (
+    MTEB_RERANK_LANGS, MTEB_RERANK_TASKS, MTEB_RERANK_TOL,
+    RerankClientMtebEncoder, ScoreClientMtebEncoder,
+    mteb_test_rerank_models_hf, run_mteb_rerank)
 # yapf: enable
 from tests.utils import RemoteOpenAIServer
 
 os.environ["VLLM_LOGGING_LEVEL"] = "WARNING"
 
 MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"
-MAIN_SCORE = 0.33437
 
 
 @pytest.fixture(scope="module")
@@ -31,12 +28,19 @@ def server():
         yield remote_server
 
 
-def test_mteb_score(server):
+@pytest.fixture(scope="module")
+def st_main_score(hf_runner):
+    # The main score related to the version of the dependency.
+    # So we need to recalculate every time.
+    main_score, st_dtype = mteb_test_rerank_models_hf(hf_runner, MODEL_NAME)
+    return main_score
+
+
+def test_mteb_score(server, st_main_score):
     url = server.url_for("score")
     encoder = ScoreClientMtebEncoder(MODEL_NAME, url)
     vllm_main_score = run_mteb_rerank(encoder, MTEB_RERANK_TASKS,
                                       MTEB_RERANK_LANGS)
-    st_main_score = MAIN_SCORE
 
     print("VLLM main score: ", vllm_main_score)
     print("SentenceTransformer main score: ", st_main_score)
@@ -45,12 +49,11 @@ def test_mteb_score(server):
     assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_RERANK_TOL)
 
 
-def test_mteb_rerank(server):
+def test_mteb_rerank(server, st_main_score):
     url = server.url_for("rerank")
     encoder = RerankClientMtebEncoder(MODEL_NAME, url)
     vllm_main_score = run_mteb_rerank(encoder, MTEB_RERANK_TASKS,
                                       MTEB_RERANK_LANGS)
-    st_main_score = MAIN_SCORE
 
     print("VLLM main score: ", vllm_main_score)
     print("SentenceTransformer main score: ", st_main_score)

tests/models/language/pooling/mteb_utils.py

Lines changed: 31 additions & 25 deletions
@@ -234,6 +234,35 @@ def run_mteb_rerank(cross_encoder, tasks, languages):
     return main_score
 
 
+def mteb_test_rerank_models_hf(hf_runner, model_name, hf_model_callback=None):
+    with hf_runner(model_name, is_cross_encoder=True,
+                   dtype="float32") as hf_model:
+
+        original_predict = hf_model.predict
+
+        def _predict(
+            sentences: list[tuple[str, str,
+                                  Optional[str]]],  # query, corpus, prompt
+            *args,
+            **kwargs,
+        ):
+            # vllm and st both remove the prompt, fair comparison.
+            prompts = [(s[0], s[1]) for s in sentences]
+            return original_predict(prompts, *args, **kwargs, batch_size=8)
+
+        hf_model.predict = _predict
+        hf_model.original_predict = original_predict
+
+        if hf_model_callback is not None:
+            hf_model_callback(hf_model)
+
+        st_main_score = run_mteb_rerank(hf_model,
+                                        tasks=MTEB_RERANK_TASKS,
+                                        languages=MTEB_RERANK_LANGS)
+        st_dtype = next(hf_model.model.model.parameters()).dtype
+        return st_main_score, st_dtype
+
+
 def mteb_test_rerank_models(hf_runner,
                             vllm_runner,
                             model_info: RerankModelInfo,
@@ -264,31 +293,8 @@ def mteb_test_rerank_models(hf_runner,
                                          languages=MTEB_RERANK_LANGS)
         vllm_dtype = model_config.dtype
 
-    with hf_runner(model_info.name, is_cross_encoder=True,
-                   dtype="float32") as hf_model:
-
-        original_predict = hf_model.predict
-
-        def _predict(
-            sentences: list[tuple[str, str,
-                                  Optional[str]]],  # query, corpus, prompt
-            *args,
-            **kwargs,
-        ):
-            # vllm and st both remove the prompt, fair comparison.
-            prompts = [(s[0], s[1]) for s in sentences]
-            return original_predict(prompts, *args, **kwargs, batch_size=8)
-
-        hf_model.predict = _predict
-        hf_model.original_predict = original_predict
-
-        if hf_model_callback is not None:
-            hf_model_callback(hf_model)
-
-        st_main_score = run_mteb_rerank(hf_model,
-                                        tasks=MTEB_RERANK_TASKS,
-                                        languages=MTEB_RERANK_LANGS)
-        st_dtype = next(hf_model.model.model.parameters()).dtype
+    st_main_score, st_dtype = mteb_test_rerank_models_hf(
+        hf_runner, model_info.name, hf_model_callback)
 
     print("VLLM:", vllm_dtype, vllm_main_score)
     print("SentenceTransformers:", st_dtype, st_main_score)

tests/models/test_registry.py

Lines changed: 4 additions & 4 deletions
@@ -9,9 +9,9 @@
 from vllm.model_executor.models import (is_pooling_model,
                                         is_text_generation_model,
                                         supports_multimodal)
-from vllm.model_executor.models.adapters import (as_classification_model,
-                                                 as_embedding_model,
-                                                 as_reward_model)
+from vllm.model_executor.models.adapters import (as_embedding_model,
+                                                 as_reward_model,
+                                                 as_seq_cls_model)
 from vllm.model_executor.models.registry import (_MULTIMODAL_MODELS,
                                                  _SPECULATIVE_DECODING_MODELS,
                                                  _TEXT_GENERATION_MODELS,
@@ -38,7 +38,7 @@ def test_registry_imports(model_arch):
         assert is_text_generation_model(model_cls)
 
     # All vLLM models should be convertible to a pooling model
-    assert is_pooling_model(as_classification_model(model_cls))
+    assert is_pooling_model(as_seq_cls_model(model_cls))
     assert is_pooling_model(as_embedding_model(model_cls))
     assert is_pooling_model(as_reward_model(model_cls))
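
For reference, a hedged sketch of what the renamed adapter does at the class level (the concrete model class is an illustrative pick; only the `as_seq_cls_model(model_cls)` call pattern is taken from the test above):

```python
from vllm.model_executor.models import is_pooling_model
from vllm.model_executor.models.adapters import as_seq_cls_model
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM  # illustrative

# Wrap a generative model class into a sequence-classification (pooling) class.
Qwen2SeqClsModel = as_seq_cls_model(Qwen2ForCausalLM)
assert is_pooling_model(Qwen2SeqClsModel)
```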

tests/test_config.py

Lines changed: 27 additions & 1 deletion
@@ -52,7 +52,7 @@ def test_get_field():
         ("distilbert/distilgpt2", "generate", "generate"),
         ("intfloat/multilingual-e5-small", "pooling", "embed"),
         ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
-        ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "score"),
+        ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify"),
         ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"),
         ("openai/whisper-small", "transcription", "transcription"),
     ],
@@ -72,6 +72,32 @@ def test_auto_task(model_id, expected_runner_type, expected_task):
     assert config.task == expected_task
 
 
+@pytest.mark.parametrize(
+    ("model_id", "expected_runner_type", "expected_task"),
+    [
+        ("distilbert/distilgpt2", "pooling", "embed"),
+        ("intfloat/multilingual-e5-small", "pooling", "embed"),
+        ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
+        ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify"),
+        ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "embed"),
+        ("openai/whisper-small", "pooling", "embed"),
+    ],
+)
+def test_score_task(model_id, expected_runner_type, expected_task):
+    config = ModelConfig(
+        model_id,
+        task="score",
+        tokenizer=model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+    )
+
+    assert config.runner_type == expected_runner_type
+    assert config.task == expected_task
+
+
 @pytest.mark.parametrize(("model_id", "bad_task"), [
     ("Qwen/Qwen2.5-Math-RM-72B", "generate"),
 ])

vllm/config.py

Lines changed: 22 additions & 12 deletions
@@ -93,14 +93,14 @@
 TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
                      "score", "reward", "transcription"]
 
-_ResolvedTask = Literal["generate", "embed", "classify", "score", "reward",
-                        "draft", "transcription"]
+_ResolvedTask = Literal["generate", "embed", "classify", "reward", "draft",
+                        "transcription"]
 
 RunnerType = Literal["generate", "pooling", "draft", "transcription"]
 
 _RUNNER_TASKS: dict[RunnerType, list[_ResolvedTask]] = {
     "generate": ["generate"],
-    "pooling": ["embed", "classify", "score", "reward"],
+    "pooling": ["embed", "classify", "reward"],
     "draft": ["draft"],
     "transcription": ["transcription"],
 }
@@ -777,7 +777,7 @@ def _get_preferred_task(
         if get_pooling_config(model_id, self.revision):
             return "embed"
         if self.registry.is_cross_encoder_model(architectures):
-            return "score"
+            return "classify"
         if self.registry.is_transcription_model(architectures):
             return "transcription"
 
@@ -841,14 +841,24 @@ def _resolve_task(
                     "This model supports multiple tasks: %s. "
                     "Defaulting to '%s'.", supported_tasks, selected_task)
             else:
-                # Aliases
-                if task_option == "embedding":
-                    msg = ("The 'embedding' task has been renamed to "
-                           "'embed', please use the new name. The old name "
-                           "will be removed in v1.0.")
-                    warnings.warn(msg, DeprecationWarning, stacklevel=2)
-
-                    task_option = "embed"
+                if task_option == "score":
+                    if not runner_support["pooling"]:
+                        msg = (f"This model does not support the '{task_option}' "
+                               f"task. Supported tasks: {supported_tasks}")
+                        raise ValueError(msg)
+                    if self.registry.is_cross_encoder_model(architectures):
+                        task_option = "classify"
+                    else:
+                        task_option = "embed"
+                else:
+                    # Aliases
+                    if task_option == "embedding":
+                        msg = ("The 'embedding' task has been renamed to "
+                               "'embed', please use the new name. The old name "
+                               "will be removed in v1.0.")
+                        warnings.warn(msg, DeprecationWarning, stacklevel=2)
+
+                        task_option = "embed"
 
             if task_option not in supported_tasks:
                 msg = (
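
A small sketch of the new resolution rule, mirroring the `test_score_task` cases added above (constructor arguments copied from that test):

```python
from vllm.config import ModelConfig


def resolved_task(model_id: str) -> str:
    config = ModelConfig(model_id,
                         task="score",
                         tokenizer=model_id,
                         tokenizer_mode="auto",
                         trust_remote_code=False,
                         seed=0,
                         dtype="float16")
    return config.task


# Cross-encoders now resolve "score" to "classify"; other pooling models fall
# back to "embed".
print(resolved_task("cross-encoder/ms-marco-MiniLM-L-6-v2"))  # classify
print(resolved_task("intfloat/multilingual-e5-small"))        # embed
```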

vllm/entrypoints/llm.py

Lines changed: 7 additions & 3 deletions
@@ -1289,9 +1289,13 @@ def score(
 
             raise ValueError(" ".join(messages))
 
-        if self.llm_engine.model_config.task not in ("embed", "score"):
-            raise ValueError(
-                "Score API is only enabled for `--task embed or --task score`")
+        if self.llm_engine.model_config.task not in ("embed", "classify"):
+            raise ValueError("Score API is only enabled for "
+                             "`--task embed or --task classify`.")
+
+        if (self.llm_engine.model_config.task == "classify"
+                and self.llm_engine.model_config.hf_config.num_labels != 1):
+            raise ValueError("Score API is only enabled for num_labels == 1.")
 
         # the tokenizer for models such as
         # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing

vllm/entrypoints/openai/api_server.py

Lines changed: 11 additions & 8 deletions
@@ -1311,24 +1311,27 @@ async def init_app_state(
         chat_template=resolved_chat_template,
         chat_template_content_format=args.chat_template_content_format,
     ) if model_config.task == "embed" else None
-    state.openai_serving_scores = ServingScores(
-        engine_client,
-        model_config,
-        state.openai_serving_models,
-        request_logger=request_logger) if model_config.task in (
-            "score", "embed", "pooling") else None
     state.openai_serving_classification = ServingClassification(
         engine_client,
         model_config,
         state.openai_serving_models,
         request_logger=request_logger,
     ) if model_config.task == "classify" else None
+
+    enable_serving_reranking = (model_config.task == "classify" and getattr(
+        model_config.hf_config, "num_labels", 0) == 1)
     state.jinaai_serving_reranking = ServingScores(
         engine_client,
         model_config,
         state.openai_serving_models,
-        request_logger=request_logger
-    ) if model_config.task == "score" else None
+        request_logger=request_logger) if enable_serving_reranking else None
+    state.openai_serving_scores = ServingScores(
+        engine_client,
+        model_config,
+        state.openai_serving_models,
+        request_logger=request_logger) if (
+            model_config.task == "embed" or enable_serving_reranking) else None
+
     state.openai_serving_tokenization = OpenAIServingTokenization(
         engine_client,
         model_config,
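
A usage sketch of the Jina-style rerank route that `enable_serving_reranking` now gates (request and response field names are assumptions about the existing endpoint, not part of this diff):

```python
# Assumes: vllm serve cross-encoder/ms-marco-MiniLM-L-6-v2
import requests

response = requests.post(
    "http://localhost:8000/rerank",
    json={"model": "cross-encoder/ms-marco-MiniLM-L-6-v2",
          "query": "What is the capital of France?",
          "documents": ["Paris is the capital of France.",
                        "Berlin is the capital of Germany."]},
)
# Documents come back ranked by relevance score.
print(response.json()["results"])
```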

vllm/entrypoints/openai/run_batch.py

Lines changed: 5 additions & 1 deletion
@@ -357,12 +357,16 @@ async def main(args):
         chat_template=None,
         chat_template_content_format="auto",
     ) if model_config.task == "embed" else None
+
+    enable_serving_reranking = (model_config.task == "classify" and getattr(
+        model_config.hf_config, "num_labels", 0) == 1)
+
     openai_serving_scores = (ServingScores(
         engine,
         model_config,
         openai_serving_models,
         request_logger=request_logger,
-    ) if model_config.task == "score" else None)
+    ) if (model_config.task == "embed" or enable_serving_reranking) else None)
 
     tracker = BatchProgressTracker()
     logger.info("Reading batch from %s...", args.input_file)
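
With scoring enabled for these task settings, a batch input line for the score route might look like the following (a hedged sketch; the URL and body fields mirror the online Score API and are not defined in this diff):

```python
import json

# One request line for the run_batch entrypoint (written to a .jsonl file).
request_line = {
    "custom_id": "score-1",
    "method": "POST",
    "url": "/v1/score",
    "body": {
        "model": "cross-encoder/ms-marco-MiniLM-L-6-v2",
        "text_1": "What is the capital of France?",
        "text_2": "Paris is the capital of France.",
    },
}
with open("batch_input.jsonl", "w") as f:
    f.write(json.dumps(request_line) + "\n")
```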
