
Commit 45badd0

[Core] Set pooling params based on task and model (#21128)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent 4adc66f commit 45badd0

24 files changed: +498 -230 lines

tests/models/language/pooling/test_gritlm.py

Lines changed: 11 additions & 15 deletions

@@ -2,9 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from __future__ import annotations

-import importlib.util
-from array import array
-
+import numpy as np
 import openai
 import pytest
 from scipy.spatial.distance import cosine
@@ -14,10 +12,6 @@

 from ....utils import RemoteOpenAIServer

-# GritLM embedding implementation is only supported by XFormers backend.
-pytestmark = pytest.mark.skipif(not importlib.util.find_spec("xformers"),
-                                reason="GritLM requires XFormers")
-
 MODEL_NAME = "parasail-ai/GritLM-7B-vllm"
 MAX_MODEL_LEN = 4000

@@ -26,11 +20,11 @@ def _arr(arr):
     """
     Convert a list of integers to an array of integers.
     """
-    return array("i", arr)
+    return np.array(arr)


 def test_find_array():
-    from vllm.model_executor.models.gritlm import GritLMPooler
+    from vllm.model_executor.models.gritlm import GritLMMeanPool

     model_config = ModelConfig(
         MODEL_NAME,
@@ -41,17 +35,19 @@ def test_find_array():
         dtype="bfloat16",
         seed=0,
     )
-    pooler = GritLMPooler(model_config=model_config)
+    pooling = GritLMMeanPool(model_config=model_config)

     arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

-    assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3
-    assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3
-    assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1
-    assert pooler._find_array(arr, _arr([3, 5]), start_idx=0) == -1
+    assert pooling._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3
+    assert pooling._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3
+    assert pooling._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1
+    assert pooling._find_array(arr, _arr([3, 4, 5]), end_idx=3) == -1
+    assert pooling._find_array(arr, _arr([3, 4, 5]), end_idx=4) == 3
+    assert pooling._find_array(arr, _arr([3, 5]), start_idx=0) == -1

     with pytest.raises(ValueError):
-        pooler._find_array(arr, _arr([3, 4, 5]), start_idx=-1)
+        pooling._find_array(arr, _arr([3, 4, 5]), start_idx=-1)


 def run_llm_encode(

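For reference, the updated test pins down the semantics of _find_array: it returns the first index at which the target subarray occurs in arr, scanning match start positions from start_idx up to (but not including) end_idx, returning -1 when there is no match and raising ValueError for a negative start_idx. Below is a minimal sketch consistent with those assertions — an illustration inferred from the test, not vLLM's actual GritLMMeanPool implementation:

import numpy as np

def find_array(arr: np.ndarray, target: np.ndarray,
               start_idx: int = 0, end_idx: int | None = None) -> int:
    # Return the first index where `target` occurs in `arr`, or -1.
    if start_idx < 0:
        raise ValueError("start_idx must be non-negative")
    if end_idx is None:
        end_idx = len(arr)
    window = len(target)
    # A match starting at i must fit entirely within arr.
    for i in range(start_idx, min(end_idx, len(arr) - window + 1)):
        if np.array_equal(arr[i:i + window], target):
            return i
    return -1

assert find_array(np.arange(10), np.array([3, 4, 5]), end_idx=3) == -1
assert find_array(np.arange(10), np.array([3, 4, 5]), end_idx=4) == 3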
vllm/entrypoints/llm.py

Lines changed: 33 additions & 16 deletions

@@ -44,7 +44,7 @@
 from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                           PoolingRequestOutput, RequestOutput,
                           ScoringRequestOutput)
-from vllm.pooling_params import PoolingParams
+from vllm.pooling_params import PoolingParams, PoolingTask
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                   RequestOutputKind, SamplingParams)
@@ -964,6 +964,7 @@ def encode(
         use_tqdm: Union[bool, Callable[..., tqdm]] = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        pooling_task: PoolingTask = "encode",
     ) -> list[PoolingRequestOutput]:
         ...
@@ -979,6 +980,7 @@ def encode(
         use_tqdm: Union[bool, Callable[..., tqdm]] = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        pooling_task: PoolingTask = "encode",
     ) -> list[PoolingRequestOutput]:
         ...
@@ -994,6 +996,7 @@ def encode(
         use_tqdm: Union[bool, Callable[..., tqdm]] = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        pooling_task: PoolingTask = "encode",
     ) -> list[PoolingRequestOutput]:
         ...
@@ -1010,6 +1013,7 @@ def encode(
         use_tqdm: Union[bool, Callable[..., tqdm]] = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        pooling_task: PoolingTask = "encode",
     ) -> list[PoolingRequestOutput]:
         ...
@@ -1026,6 +1030,7 @@ def encode(
         use_tqdm: Union[bool, Callable[..., tqdm]] = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        pooling_task: PoolingTask = "encode",
     ) -> list[PoolingRequestOutput]:
         ...
@@ -1040,6 +1045,7 @@ def encode(
         use_tqdm: Union[bool, Callable[..., tqdm]] = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        pooling_task: PoolingTask = "encode",
     ) -> list[PoolingRequestOutput]:
         ...
@@ -1059,6 +1065,7 @@ def encode(
         use_tqdm: Union[bool, Callable[..., tqdm]] = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        pooling_task: PoolingTask = "encode",
     ) -> list[PoolingRequestOutput]:
         """Apply pooling to the hidden states corresponding to the input
         prompts.
@@ -1080,6 +1087,7 @@ def encode(
             lora_request: LoRA request to use for generation, if any.
             prompt_adapter_request: Prompt Adapter request to use for
                 generation, if any.
+            pooling_task: Override the pooling task to use.

         Returns:
             A list of `PoolingRequestOutput` objects containing the
@@ -1116,11 +1124,12 @@ def encode(
         if pooling_params is None:
             # Use default pooling params.
             pooling_params = PoolingParams()
-        elif isinstance(pooling_params, PoolingParams):
-            pooling_params.verify(model_config)
+
+        if isinstance(pooling_params, PoolingParams):
+            pooling_params.verify(pooling_task, model_config)
         else:
             for pooling_param in pooling_params:
-                pooling_param.verify(model_config)
+                pooling_param.verify(pooling_task, model_config)

         tokenization_kwargs = dict[str, Any]()
         _validate_truncation_size(model_config.max_model_len,
@@ -1181,12 +1190,15 @@ def embed(
             raise ValueError("Embedding API is not supported by this model. "
                              "Please set `--task embed`.")

-        items = self.encode(prompts,
-                            truncate_prompt_tokens=truncate_prompt_tokens,
-                            use_tqdm=use_tqdm,
-                            pooling_params=pooling_params,
-                            lora_request=lora_request,
-                            prompt_adapter_request=prompt_adapter_request)
+        items = self.encode(
+            prompts,
+            truncate_prompt_tokens=truncate_prompt_tokens,
+            use_tqdm=use_tqdm,
+            pooling_params=pooling_params,
+            lora_request=lora_request,
+            prompt_adapter_request=prompt_adapter_request,
+            pooling_task="embed",
+        )

         return [EmbeddingRequestOutput.from_base(item) for item in items]
@@ -1228,10 +1240,13 @@ def classify(
                 "Classification API is not supported by this model. "
                 "Please set `--task classify`.")

-        items = self.encode(prompts,
-                            use_tqdm=use_tqdm,
-                            lora_request=lora_request,
-                            prompt_adapter_request=prompt_adapter_request)
+        items = self.encode(
+            prompts,
+            use_tqdm=use_tqdm,
+            lora_request=lora_request,
+            prompt_adapter_request=prompt_adapter_request,
+            pooling_task="classify",
+        )

         return [ClassificationRequestOutput.from_base(item) for item in items]
@@ -1251,7 +1266,9 @@ def _embedding_score(
             truncate_prompt_tokens=truncate_prompt_tokens,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
-            prompt_adapter_request=prompt_adapter_request)
+            prompt_adapter_request=prompt_adapter_request,
+            pooling_task="embed",
+        )

         encoded_output_1: list[PoolingRequestOutput] = encoded_output[
             0:len(text_1)]
@@ -1287,7 +1304,7 @@ def _cross_encoding_score(
         if len(data_1) == 1:
             data_1 = data_1 * len(data_2)

-        pooling_params = PoolingParams(use_cross_encoder=True)
+        pooling_params = PoolingParams(task="score")
         tokenization_kwargs: dict[str, Any] = {}
         _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                   truncate_prompt_tokens, tokenization_kwargs)

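Taken together, the llm.py changes thread an explicit task through every pooling entry point: encode() gains a pooling_task argument (defaulting to "encode"), the embed()/classify()/score helpers pin it to their own task, and PoolingParams.verify() now receives the task alongside the model config. A hedged usage sketch — the model name is an arbitrary example, and embed() remains the higher-level convenience wrapper:

from vllm import LLM

llm = LLM(model="BAAI/bge-base-en-v1.5", task="embed")

# Explicitly request embedding-style pooling; this mirrors what llm.embed()
# now passes internally when it delegates to encode().
outputs = llm.encode(["Hello, world!"], pooling_task="embed")
print(outputs[0].outputs)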
vllm/entrypoints/openai/protocol.py

Lines changed: 4 additions & 4 deletions

@@ -1347,8 +1347,8 @@ class ScoreRequest(OpenAIBaseModel):

     # --8<-- [end:score-extra-params]

-    def to_pooling_params(self, *, use_cross_encoder: bool = False):
-        return PoolingParams(use_cross_encoder=use_cross_encoder)
+    def to_pooling_params(self):
+        return PoolingParams()


 class RerankRequest(OpenAIBaseModel):
@@ -1375,8 +1375,8 @@ class RerankRequest(OpenAIBaseModel):

     # --8<-- [end:rerank-extra-params]

-    def to_pooling_params(self, *, use_cross_encoder: bool = False):
-        return PoolingParams(use_cross_encoder=use_cross_encoder)
+    def to_pooling_params(self):
+        return PoolingParams()


 class RerankDocument(BaseModel):

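The net effect of the protocol change: score and rerank requests no longer decide between cross-encoder and embedding pooling themselves. to_pooling_params() returns a plain PoolingParams, and cross-encoding is expressed through the task instead, as _cross_encoding_score above now does:

from vllm.pooling_params import PoolingParams

# Replaces the removed PoolingParams(use_cross_encoder=True)
pooling_params = PoolingParams(task="score")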
vllm/entrypoints/openai/serving_classification.py

Lines changed: 32 additions & 0 deletions

@@ -6,6 +6,7 @@

 import numpy as np
 from fastapi import Request
+from typing_extensions import override

 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
@@ -21,12 +22,14 @@
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.logger import init_logger
 from vllm.outputs import ClassificationOutput, PoolingRequestOutput
+from vllm.pooling_params import PoolingParams

 logger = init_logger(__name__)


 class ClassificationMixin(OpenAIServing):

+    @override
     async def _preprocess(
         self,
         ctx: ServeContext,
@@ -75,6 +78,7 @@ async def _preprocess(
             logger.exception("Error in preprocessing prompt inputs")
             return self.create_error_response(str(e))

+    @override
     def _build_response(
         self,
         ctx: ServeContext,
@@ -158,3 +162,31 @@ async def create_classify(
         )

         return await super().handle(ctx)  # type: ignore
+
+    @override
+    def _validate_request(
+        self,
+        ctx: ClassificationServeContext,
+    ) -> Optional[ErrorResponse]:
+        if error := super()._validate_request(ctx):
+            return error
+
+        ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens
+
+        return None
+
+    @override
+    def _create_pooling_params(
+        self,
+        ctx: ClassificationServeContext,
+    ) -> Union[PoolingParams, ErrorResponse]:
+        pooling_params = super()._create_pooling_params(ctx)
+        if isinstance(pooling_params, ErrorResponse):
+            return pooling_params
+
+        try:
+            pooling_params.verify("classify", self.model_config)
+        except ValueError as e:
+            return self.create_error_response(str(e))
+
+        return pooling_params

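Each overridden hook in this file is now marked with typing_extensions.override (PEP 698), so a static type checker flags the override if the base-class method is ever renamed or removed. A minimal illustration of the decorator in isolation:

from typing_extensions import override

class Base:
    def _create_pooling_params(self):
        return None

class Child(Base):
    @override  # type checker errors if Base stops defining this method
    def _create_pooling_params(self):
        return super()._create_pooling_params()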
vllm/entrypoints/openai/serving_embedding.py

Lines changed: 15 additions & 3 deletions

@@ -24,6 +24,7 @@
 from vllm.logger import init_logger
 from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
                           PoolingRequestOutput)
+from vllm.pooling_params import PoolingParams

 logger = init_logger(__name__)

@@ -45,6 +46,7 @@ def _get_embedding(

 class EmbeddingMixin(OpenAIServing):

+    @override
     async def _preprocess(
         self,
         ctx: ServeContext,
@@ -97,6 +99,7 @@ async def _preprocess(
             logger.exception("Error in preprocessing prompt inputs")
             return self.create_error_response(str(e))

+    @override
     def _build_response(
         self,
         ctx: ServeContext,
@@ -191,11 +194,20 @@ def _validate_request(

         ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens

-        pooling_params = ctx.request.to_pooling_params()
+        return None
+
+    @override
+    def _create_pooling_params(
+        self,
+        ctx: ServeContext[EmbeddingRequest],
+    ) -> Union[PoolingParams, ErrorResponse]:
+        pooling_params = super()._create_pooling_params(ctx)
+        if isinstance(pooling_params, ErrorResponse):
+            return pooling_params

         try:
-            pooling_params.verify(self.model_config)
+            pooling_params.verify("embed", self.model_config)
         except ValueError as e:
             return self.create_error_response(str(e))

-        return None
+        return pooling_params

vllm/entrypoints/openai/serving_engine.py

Lines changed: 13 additions & 5 deletions

@@ -305,6 +305,16 @@ def _validate_request(self, ctx: ServeContext) -> Optional[ErrorResponse]:
                 " Please, select a smaller truncation size.")
         return None

+    def _create_pooling_params(
+        self,
+        ctx: ServeContext,
+    ) -> Union[PoolingParams, ErrorResponse]:
+        if not hasattr(ctx.request, "to_pooling_params"):
+            return self.create_error_response(
+                "Request type does not support pooling parameters")
+
+        return ctx.request.to_pooling_params()
+
     async def _prepare_generators(
         self,
         ctx: ServeContext,
@@ -318,11 +328,9 @@ async def _prepare_generators(
         trace_headers = (None if ctx.raw_request is None else await
                          self._get_trace_headers(ctx.raw_request.headers))

-        if not hasattr(ctx.request, "to_pooling_params"):
-            return self.create_error_response(
-                "Request type does not support pooling parameters")
-
-        pooling_params = ctx.request.to_pooling_params()
+        pooling_params = self._create_pooling_params(ctx)
+        if isinstance(pooling_params, ErrorResponse):
+            return pooling_params

         if ctx.engine_prompts is None:
             return self.create_error_response(

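The serving_engine.py refactor is a template-method split: the pooling-params construction that previously lived inline in _prepare_generators becomes an overridable _create_pooling_params hook, which the classification and embedding mixins above extend with a task-specific verify() call. A self-contained sketch of the pattern with stand-in names (not vLLM's real classes):

class PoolingParamsStub:
    def verify(self, task, model_config):
        # Stand-in for PoolingParams.verify(task, model_config).
        if task not in ("encode", "embed", "classify", "score"):
            raise ValueError(f"unsupported pooling task: {task!r}")

class RequestStub:
    def to_pooling_params(self):
        return PoolingParamsStub()

class BaseServing:
    def _create_pooling_params(self, request):
        if not hasattr(request, "to_pooling_params"):
            raise TypeError("request type does not support pooling parameters")
        return request.to_pooling_params()

class EmbeddingServing(BaseServing):
    def _create_pooling_params(self, request):
        params = super()._create_pooling_params(request)
        params.verify("embed", model_config=None)  # task-aware check
        return params

params = EmbeddingServing()._create_pooling_params(RequestStub())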
vllm/entrypoints/openai/serving_pooling.py

Lines changed: 5 additions & 0 deletions

@@ -142,6 +142,11 @@ async def create_pooling(
         try:
             pooling_params = request.to_pooling_params()

+            try:
+                pooling_params.verify("encode", self.model_config)
+            except ValueError as e:
+                return self.create_error_response(str(e))
+
             for i, engine_prompt in enumerate(engine_prompts):
                 request_id_item = f"{request_id}-{i}"