
Commit 2e3520e

[Bugfix] Fix output tensor shape in vanilla_chunked_prefill and update import paths for model_loader (#773)
### What this PR does / why we need it?

Fix the output tensor shape in the `vanilla_chunked_prefill` function, and update the import paths for `model_loader` so they resolve correctly across vLLM versions.

### Does this PR introduce _any_ user-facing change?

None.

### How was this patch tested?

Run offline inference on DeepSeek models.

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
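For reference, a minimal sketch of the kind of offline-inference smoke test described above. The model name, prompt, and sampling settings are illustrative assumptions, not the exact reproduction steps from the PR:

```python
# Hedged sketch of an offline-inference smoke test; model name, prompt,
# and sampling settings are illustrative assumptions.
from vllm import LLM, SamplingParams

llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True)
params = SamplingParams(temperature=0.0, max_tokens=64)

# With the fix, chunked prefill completes without a shape mismatch and
# generation produces text as usual.
outputs = llm.generate(["What does chunked prefill do?"], params)
for out in outputs:
    print(out.outputs[0].text)
```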
1 parent ec27af3 · commit 2e3520e

File tree: 3 files changed, +20 −3 lines

tests/utils.py

Lines changed: 7 additions & 1 deletion

```diff
@@ -42,13 +42,19 @@
                               init_distributed_environment)
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.entrypoints.openai.cli_args import make_arg_parser
-from vllm.model_executor.model_loader.loader import get_model_loader
 from vllm.platforms import current_platform
 from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.utils import FlexibleArgumentParser, GB_bytes, get_open_port
 
+from vllm_ascend.utils import vllm_version_is
+
 from .model_utils import TextTextLogprobs
 
+if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+    from vllm.model_executor.model_loader.loader import get_model_loader  # type: ignore[import]  # isort: skip
+else:
+    from vllm.model_executor.model_loader import get_model_loader
+
 VLLM_PATH = Path(__file__).parent.parent
 """Path to root of the vLLM repository."""
 
```

vllm_ascend/ops/attention.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -131,6 +131,7 @@ def vanilla_chunked_prefill(
 
     attn_output = (attn_output[q_mask].view([-1, num_query_heads,
                                              head_dim]).to(output.dtype))
+    output = output.view_as(attn_output)
     output.copy_(attn_output)
     return attn_output
 
```
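Why the added line is needed: `Tensor.copy_` requires the source to be broadcastable to the destination, so a flat `output` buffer cannot receive the three-dimensional `attn_output` directly; `view_as` reinterprets the same storage with the matching shape first. A minimal sketch with illustrative sizes (the real shapes come from the attention kernel):

```python
import torch

# Illustrative sizes only; the real values come from the attention kernel.
num_tokens, num_query_heads, head_dim = 4, 8, 128

# Suppose the caller allocated the output buffer flat.
output = torch.empty(num_tokens, num_query_heads * head_dim)
# The computed attention result is three-dimensional.
attn_output = torch.randn(num_tokens, num_query_heads, head_dim)

# output.copy_(attn_output) would raise here: [4, 8, 128] does not
# broadcast to [4, 1024]. view_as reinterprets the same storage with the
# source's shape, so the in-place copy writes into the caller's buffer.
output = output.view_as(attn_output)
output.copy_(attn_output)
assert output.shape == attn_output.shape
```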

vllm_ascend/worker/model_runner.py

Lines changed: 12 additions & 2 deletions

```diff
@@ -64,6 +64,8 @@
     _init_attn_metadata_from_tensor_dict,
     _init_sampling_metadata_from_tensor_dict)
 
+from vllm_ascend.utils import vllm_version_is
+
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionBackend
 
@@ -1007,7 +1009,10 @@ def save_sharded_state(
         pattern: Optional[str] = None,
         max_size: Optional[int] = None,
     ) -> None:
-        from vllm.model_executor.model_loader.loader import ShardedStateLoader
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            from vllm.model_executor.model_loader.loader import ShardedStateLoader  # type: ignore[import]  # isort: skip  # noqa
+        else:
+            from vllm.model_executor.model_loader import ShardedStateLoader
         ShardedStateLoader.save_model(
             self.model,
             path,
@@ -1019,7 +1024,12 @@ def save_tensorized_model(
         self,
         tensorizer_config: TensorizerConfig,
     ) -> None:
-        from vllm.model_executor.model_loader.loader import TensorizerLoader
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            from vllm.model_executor.model_loader.loader import \
+                TensorizerLoader  # type: ignore  # noqa
+        else:
+            from vllm.model_executor.model_loader import \
+                TensorizerLoader  # type: ignore  # noqa
         TensorizerLoader.save_model(
             self.model,
             tensorizer_config=tensorizer_config,
```
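All three files apply the same guard: `get_model_loader`, `ShardedStateLoader`, and `TensorizerLoader` are imported from `vllm.model_executor.model_loader.loader` on vLLM 0.8.5/0.8.5.post1 and from `vllm.model_executor.model_loader` on later versions. A standalone sketch of the pattern follows; the real `vllm_version_is` helper lives in `vllm_ascend.utils`, so this reimplementation is an illustrative assumption:

```python
# Standalone sketch of the version-gated import; the real helper lives in
# vllm_ascend.utils, so this reimplementation is an illustrative assumption.
import vllm


def vllm_version_is(target: str) -> bool:
    # True when the installed vLLM release matches the target exactly.
    return vllm.__version__ == target


# Pick the import path that matches the installed vLLM release: the loader
# symbols moved out of model_loader.loader after 0.8.5.
if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
    from vllm.model_executor.model_loader.loader import get_model_loader
else:
    from vllm.model_executor.model_loader import get_model_loader
```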
