|
24 | 24 | import weakref
|
25 | 25 | from contextlib import contextmanager, nullcontext
|
26 | 26 | from dataclasses import dataclass
|
27 |
| -from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast, get_args |
| 27 | +from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast |
28 | 28 |
|
29 | 29 | import numpy as np
|
30 | 30 | import numpy.typing as npt
|
|
93 | 93 | from vllm.model_executor.models.interfaces import has_step_pooler
|
94 | 94 | from vllm.v1.utils import bind_kv_cache
|
95 | 95 | else:
|
96 |
| - from vllm.pooling_params import PoolingTask |
97 | 96 | from vllm.v1.worker.utils import bind_kv_cache
|
98 | 97 |
|
99 | 98 | if TYPE_CHECKING:
|
@@ -408,13 +407,10 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
408 | 407 | generator = None
|
409 | 408 |
|
410 | 409 | if not vllm_version_is("0.9.2") and pooling_params:
|
411 |
| - assert pooling_params.task is not None, ( |
| 410 | + assert (task := pooling_params.task) is not None, ( |
412 | 411 | "You did not set `task` in the API")
|
413 | 412 | model = cast(VllmModelForPooling, self.model)
|
414 |
| - to_update = (model.pooler.get_pooling_updates( |
415 |
| - pooling_params.task)) |
416 |
| - assert to_update is not None, ( |
417 |
| - f"{pooling_params.task=} is not supported by the model") |
| 413 | + to_update = model.pooler.get_pooling_updates(task) |
418 | 414 | to_update.apply(pooling_params)
|
419 | 415 |
|
420 | 416 | self.requests[req_id] = CachedRequestState(
|
@@ -1772,7 +1768,6 @@ def _dummy_pooler_run(
|
1772 | 1768 | dummy_pooling_params = PoolingParams(task=dummy_task)
|
1773 | 1769 |
|
1774 | 1770 | to_update = model.pooler.get_pooling_updates(dummy_task)
|
1775 |
| - assert to_update is not None |
1776 | 1771 | to_update.apply(dummy_pooling_params)
|
1777 | 1772 |
|
1778 | 1773 | dummy_metadata = PoolingMetadata(
|
@@ -2434,7 +2429,4 @@ def get_supported_pooling_tasks(self):
|
2434 | 2429 | if not is_pooling_model(model):
|
2435 | 2430 | return []
|
2436 | 2431 |
|
2437 |
| - return [ |
2438 |
| - task for task in get_args(PoolingTask) |
2439 |
| - if model.pooler.get_pooling_updates(task) |
2440 |
| - ] |
| 2432 | + return list(model.pooler.get_supported_tasks()) |
0 commit comments