Skip to content

Commit 33e1ea4

Browse files
authored
[CI] Fix broken CI (#1915)
### What this PR does / why we need it? Fix [#21227](vllm-project/vllm#21227) to make ci happy - vLLM version: v0.9.2 - vLLM main: vllm-project/vllm@6b46c4b --------- Signed-off-by: wangli <wangli858794774@gmail.com>
1 parent 7265dc0 commit 33e1ea4

File tree

1 file changed

+4
-12
lines changed

1 file changed

+4
-12
lines changed

vllm_ascend/worker/model_runner_v1.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import weakref
2525
from contextlib import contextmanager, nullcontext
2626
from dataclasses import dataclass
27-
from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast, get_args
27+
from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast
2828

2929
import numpy as np
3030
import numpy.typing as npt
@@ -93,7 +93,6 @@
9393
from vllm.model_executor.models.interfaces import has_step_pooler
9494
from vllm.v1.utils import bind_kv_cache
9595
else:
96-
from vllm.pooling_params import PoolingTask
9796
from vllm.v1.worker.utils import bind_kv_cache
9897

9998
if TYPE_CHECKING:
@@ -408,13 +407,10 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
408407
generator = None
409408

410409
if not vllm_version_is("0.9.2") and pooling_params:
411-
assert pooling_params.task is not None, (
410+
assert (task := pooling_params.task) is not None, (
412411
"You did not set `task` in the API")
413412
model = cast(VllmModelForPooling, self.model)
414-
to_update = (model.pooler.get_pooling_updates(
415-
pooling_params.task))
416-
assert to_update is not None, (
417-
f"{pooling_params.task=} is not supported by the model")
413+
to_update = model.pooler.get_pooling_updates(task)
418414
to_update.apply(pooling_params)
419415

420416
self.requests[req_id] = CachedRequestState(
@@ -1772,7 +1768,6 @@ def _dummy_pooler_run(
17721768
dummy_pooling_params = PoolingParams(task=dummy_task)
17731769

17741770
to_update = model.pooler.get_pooling_updates(dummy_task)
1775-
assert to_update is not None
17761771
to_update.apply(dummy_pooling_params)
17771772

17781773
dummy_metadata = PoolingMetadata(
@@ -2434,7 +2429,4 @@ def get_supported_pooling_tasks(self):
24342429
if not is_pooling_model(model):
24352430
return []
24362431

2437-
return [
2438-
task for task in get_args(PoolingTask)
2439-
if model.pooler.get_pooling_updates(task)
2440-
]
2432+
return list(model.pooler.get_supported_tasks())

0 commit comments

Comments
 (0)