
Commit e1d13fc

[0.7.3] patch from_seq_group to clear finished seq in seq_id_to_seq_group (#691)

### What this PR does / why we need it?

Fix the CPU memory leak issue described in vllm-project/vllm#16472.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

With CI

Signed-off-by: Shuqiao Li <celestialli@outlook.com>
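To make the leak concrete: with parallel sampling (n > 1), every sub-request registers an entry in `seq_id_to_seq_group`, and before this patch nothing deleted those entries once the whole group had finished, so the mapping grew for the lifetime of the engine. Below is a minimal sketch of that pattern using simplified stand-ins (`ParallelSampleGroup`, `register`, `on_seq_finished` are illustrative names, not the real vllm classes or API):

```python
from typing import Dict


class ParallelSampleGroup:
    """Stand-in for the bookkeeping done by vllm's SequenceGroupBase."""

    def __init__(self, request_id: str, n: int):
        # one sub-request per sampled sequence (hypothetical id scheme)
        self.seq_id_to_index = {f"{request_id}_{i}": i for i in range(n)}
        # drained as sequences finish, mirroring group.to_be_finished
        self.to_be_finished = dict(self.seq_id_to_index)


seq_id_to_seq_group: Dict[str, ParallelSampleGroup] = {}


def register(group: ParallelSampleGroup) -> None:
    for sub_id in group.seq_id_to_index:
        seq_id_to_seq_group[sub_id] = group


def on_seq_finished(group: ParallelSampleGroup, sub_id: str) -> None:
    group.to_be_finished.pop(sub_id, None)
    # The patched from_seq_group adds exactly this kind of cleanup: once the
    # whole group is finished, drop all of its entries from the mapping.
    if len(group.to_be_finished) == 0:
        for sub_request_id in list(group.seq_id_to_index.keys()):
            seq_id_to_seq_group.pop(sub_request_id, None)


group = ParallelSampleGroup("req-42", n=2)
register(group)
for sub_id in list(group.seq_id_to_index):
    on_seq_finished(group, sub_id)
assert len(seq_id_to_seq_group) == 0  # before the fix, entries lingered here
```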
1 parent 15314cc commit e1d13fc

File tree

2 files changed: +147 −0 lines changed

vllm_ascend/patch_outputs.py

Lines changed: 146 additions & 0 deletions
@@ -0,0 +1,146 @@
```python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Dict, Optional

from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import RequestOutputKind
from vllm.sequence import SequenceGroup, SequenceGroupBase, SequenceStatus


@classmethod  # type: ignore
def from_seq_group(
    cls, seq_group: SequenceGroup, use_cache: bool,
    seq_id_to_seq_group: Dict[str, SequenceGroupBase]
) -> Optional["RequestOutput"]:
    finished = seq_group.is_finished()

    if seq_group.request_id in seq_id_to_seq_group:
        group: SequenceGroupBase = seq_id_to_seq_group[seq_group.request_id]
        assembled_seq_group = group.maybe_assemble_group(seq_group)
        if finished:
            group.finish_seq(seq_group)
        if assembled_seq_group is None:
            return None

        # clear finished seq in seq_id_to_seq_group
        if len(group.to_be_finished) == 0:
            for sub_request_id in list(group.seq_id_to_index.keys()):
                if sub_request_id in seq_id_to_seq_group:
                    del seq_id_to_seq_group[sub_request_id]
        return cls.from_seq_group(assembled_seq_group, use_cache,
                                  seq_id_to_seq_group)

    sampling_params = seq_group.sampling_params
    if sampling_params is None:
        raise ValueError(
            "Sampling parameters are missing for a CompletionRequest.")

    if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and (
            not finished):
        return None

    # Init cache (if needed)
    if use_cache and seq_group.cached_request_output is None:
        seq_group.cached_request_output = RequestOutput(  # type: ignore
            request_id="",
            prompt=None,
            prompt_token_ids=[],
            prompt_logprobs=None,
            outputs=[],
            finished=False)

    top_n_seqs = seq_group.get_seqs()

    # Create the outputs.
    # NOTE: We need omit logprobs here explicitly because the sequence
    # always has the logprobs of the sampled tokens even if the
    # logprobs are not requested.
    include_logprobs = sampling_params.logprobs is not None
    text_buffer_length = sampling_params.output_text_buffer_length
    delta = sampling_params.output_kind == RequestOutputKind.DELTA

    outputs = []
    include_prompt = True
    # num_cached_tokens should be the same for all the sequences
    num_cached_tokens = None
    for i, seq in enumerate(top_n_seqs):
        output_text = seq.get_output_text_to_return(text_buffer_length, delta)

        output_token_ids = seq.get_output_token_ids_to_return(delta)
        num_output_tokens = 1 if isinstance(output_token_ids,
                                            int) else len(output_token_ids)
        num_cached_tokens = seq.data.get_num_cached_tokens()  # noqa

        output_logprobs = seq.output_logprobs if include_logprobs else None

        if delta:
            # Slice logprobs delta if applicable
            if output_logprobs:
                output_logprobs = output_logprobs[-num_output_tokens:]
            # Don't include prompt if this is after the first output
            # containing decode token ids
            if include_prompt and seq.get_output_len() > num_output_tokens:
                include_prompt = False

        if use_cache:
            # Get cached output object
            cached_outputs = seq_group.cached_request_output.outputs  # type: ignore
            if i >= len(cached_outputs):
                cached_outputs.append(
                    CompletionOutput(index=i,
                                     text="",
                                     token_ids=[],
                                     cumulative_logprob=None,
                                     logprobs=None,
                                     finish_reason=None,
                                     stop_reason=None))
            output = cached_outputs[i]

            # Init cached output object
            assert output.index == i
            output.text = output_text

            if isinstance(output_token_ids, int):
                output.token_ids.clear()
                output.token_ids.append(output_token_ids)
            else:
                output.token_ids = output_token_ids

            output.cumulative_logprob = seq.get_cumulative_logprob() \
                if include_logprobs else None
            output.logprobs = output_logprobs
            output.finish_reason = SequenceStatus.get_finished_reason(
                seq.status)
            output.stop_reason = seq.stop_reason

        else:
            output = CompletionOutput(
                top_n_seqs.index(seq), output_text, [output_token_ids]
                if isinstance(output_token_ids, int) else output_token_ids,
                seq.get_cumulative_logprob() if include_logprobs else None,
                output_logprobs,
                SequenceStatus.get_finished_reason(seq.status),
                seq.stop_reason)

        outputs.append(output)

    return None


# Add code to clear finished seq in seq_id_to_seq_group
RequestOutput.from_seq_group = from_seq_group
```
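The patch takes effect purely as an import side effect: assigning the new `classmethod` onto `RequestOutput` rebinds `from_seq_group` for every caller that looks the method up on the class. A quick smoke check (hypothetical, not part of the commit) to confirm the patch is active:

```python
# Hypothetical check, assuming vllm and vllm-ascend are installed.
import vllm_ascend.patch_outputs  # noqa: F401  -- importing applies the patch

from vllm.outputs import RequestOutput

# The classmethod now resolves to the function defined in the patch module.
assert RequestOutput.from_seq_group.__func__.__module__ == \
    "vllm_ascend.patch_outputs"
```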

vllm_ascend/platform.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -64,6 +64,7 @@ def pre_register_and_update(cls,
                                 parser: Optional[FlexibleArgumentParser] = None
                                 ) -> None:
         import vllm_ascend.patch_config  # noqa: F401
+        import vllm_ascend.patch_outputs  # noqa: F401
         from vllm_ascend.quantization.quant_config import \
             AscendQuantConfig  # noqa: F401
```
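Wiring the import into `pre_register_and_update` ties the patch to Ascend platform initialization, so it is presumably in place before the engine assembles its first outputs; the `# noqa: F401` markers tell the linter these imports exist only for their side effects.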
