Skip to content

Commit 0642536

Browse files
authored
minor fixes to make classifier heads more usable (#10)
Signed-off-by: Rohin Garg <rohin@character.ai>
1 parent dbd9ca6 commit 0642536

File tree

3 files changed

+48
-3
lines changed

3 files changed

+48
-3
lines changed

vllm/entrypoints/openai/protocol.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1292,6 +1292,47 @@ class PoolingResponse(OpenAIBaseModel):
12921292
usage: UsageInfo
12931293

12941294

class ClassificationRequest(OpenAIBaseModel):
    # Served model name to route the request to; None selects the default.
    model: Optional[str] = None
    # One prompt or a batch of prompts to classify.
    input: Union[list[str], str]
    # If set, prompts are truncated to this many tokens before pooling.
    truncate_prompt_tokens: Optional[int] = None
    # Opaque end-user identifier forwarded by the caller.
    user: Optional[str] = None

    # --8<-- [start:classification-pooling-params]
    additional_data: Optional[Any] = None
    # --8<-- [end:classification-pooling-params]

    # --8<-- [start:classification-extra-params]
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."),
    )

    # --8<-- [end:classification-extra-params]

    def to_pooling_params(self):
        # Only the pooling-specific payload is forwarded to the engine.
        params = PoolingParams(additional_data=self.additional_data)
        return params
class ClassificationData(OpenAIBaseModel):
1321+
index: int
1322+
label: Optional[str]
1323+
probs: list[float]
1324+
num_classes: int
1325+
1326+
1327+
class ClassificationResponse(OpenAIBaseModel):
1328+
id: str = Field(default_factory=lambda: f"classify-{random_uuid()}")
1329+
object: str = "list"
1330+
created: int = Field(default_factory=lambda: int(time.time()))
1331+
model: str
1332+
data: list[ClassificationData]
1333+
usage: UsageInfo
1334+
1335+
12951336
class ScoreResponseData(OpenAIBaseModel):
12961337
index: int
12971338
object: str = "score"

vllm/model_executor/model_loader/loader.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -465,9 +465,13 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
465465
if model_config.quantization is None and loaded_weights is not None:
466466
weights_not_loaded = weights_to_load - loaded_weights
467467
if weights_not_loaded:
468-
raise ValueError(
468+
logger.error(
469469
"Following weights were not initialized from "
470-
f"checkpoint: {weights_not_loaded}")
470+
"checkpoint: %s", weights_not_loaded)
471+
472+
# raise ValueError(
473+
# "Following weights were not initialized from "
474+
# f"checkpoint: {weights_not_loaded}")
471475

472476
_process_weights_after_loading(model, model_config, target_device)
473477

vllm/model_executor/models/transformers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ def create_attention_instances(self) -> dict[int, Attention]:
267267
self.config.global_attention_layers, list):
268268
global_attention_layers = self.config.global_attention_layers
269269
else:
270-
global_attention_layers = None
270+
global_attention_layers = []
271271

272272
for i in range(start, end):
273273
sliding_window = None

0 commit comments

Comments
 (0)