
Commit ca9e6f5

review comments
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent 948a555 commit ca9e6f5

2 files changed: 31 additions & 83 deletions


vllm/v1/attention/backends/mamba_attn.py

Lines changed: 25 additions & 77 deletions
@@ -8,8 +8,9 @@
 
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.config import VllmConfig
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata)
+from vllm.v1.attention.backends.utils import (
+    AttentionMetadataBuilder, CommonAttentionMetadata,
+    reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
 
 if TYPE_CHECKING:
@@ -96,65 +97,9 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
 
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
-        # NOTE (Chen): Copied from MLACommonMetadataBuilder and
-        # FlashInferMetadataBuilder. Should be refactored later to avoid code
-        # duplication of these 3 functions.
-        # We now want to reorder the batch so that the "decode" requests are and
-        # the front and the "prefill" requests are at the using the least amount
-        # swaps possible. (NOTE for now we loosely use "decode" to mean requests
-        # where attention is likely memory-bound and "prefill" to mean requests
-        # where attention is likely compute-bound, TODO(lucas): figure out a
-        # better naming here)
-        decodes = []
-        prefills = []
-        num_decode_tokens = 0
-        num_prefill_tokens = 0
-
-        for i, req_id in enumerate(input_batch.req_ids):
-            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
-            # for now treat 1 scheduled token as "decode" even if its not,
-            # we should update this to something like < 8 in the future but
-            # currently the decode run only supports num_tokens = 1
-            if num_tokens == 1:
-                decodes.append(i)
-                num_decode_tokens += num_tokens
-            else:
-                prefills.append(i)
-                num_prefill_tokens += num_tokens
-
-        # We hope that this is fairly minimal since decodes
-        # should be around for a number of iterations so hopefully they are
-        # relatively stationary (and new request are generally appended to the
-        # persistent batch so already should be at the back)
-        # To achieve this we loop over the decodes in descending order and
-        # the prefills in ascending order. We swap decodes from the "back"
-        # i.e. past where the last decode should be in the reodorered with
-        # prefills from the front of the batch.
-        # `decodes` and `prefills` are already in ascending order just based on
-        # the above loop
-        num_decodes = len(decodes)
-        num_prefills = len(prefills)
-        modified_batch = False
-
-        for i in range(1, min(num_decodes, num_prefills) + 1):
-            # If the decode is at the "back" of the batch, i, we can swap it
-            # with the prefill closest to the front of the batch
-            decode_idx = decodes[num_decodes - i]
-            if decode_idx < num_decodes:
-                break
-
-            input_batch.swap_states(prefills[i - 1], decode_idx)
-            modified_batch = True
-
-        # Save for next `build` call
-        # TODO(lucas): this is a bit of a hack, we should probably have a
-        # better way of doing this
-        self._num_decodes = num_decodes
-        self._num_prefills = num_prefills
-        self._num_decode_tokens = num_decode_tokens
-        self._num_prefill_tokens = num_prefill_tokens
-
-        return modified_batch
+        return reorder_batch_to_split_decodes_and_prefills(input_batch,
+                                                           scheduler_output,
+                                                           decode_threshold=1)
 
     def build(self,
               common_prefix_len: int,
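
Note: the inlined logic above is delegated to the shared reorder_batch_to_split_decodes_and_prefills helper, with the old hard-coded num_tokens == 1 check expressed as decode_threshold=1. A minimal sketch of just the classification step (toy data and a hypothetical classify helper, not the vLLM API; the real helper also reorders the batch in place via InputBatch.swap_states):

# Toy sketch: requests scheduling at most `decode_threshold` tokens are
# treated as decodes; everything else is a prefill.
def classify(num_scheduled_tokens, decode_threshold=1):
    decodes, prefills = [], []
    for i, n in enumerate(num_scheduled_tokens):
        (decodes if n <= decode_threshold else prefills).append(i)
    return decodes, prefills

# One-token requests land in `decodes`, multi-token requests in `prefills`.
assert classify([1, 7, 1, 3]) == ([0, 2], [1, 3])
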
@@ -173,26 +118,29 @@ def build(self,
 
         state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
 
+        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+            split_decodes_and_prefills(common_attn_metadata,
+                                       decode_threshold=1))
+
         # Compute seq_idx, chunk_indices and chunk_offsets for prefill only
-        if self._num_prefills > 0:
+        if num_prefills > 0:
             #[batch,]
             has_initial_states_cpu = (
                 common_attn_metadata.
-                num_computed_tokens_cpu[num_reqs - self._num_prefills:num_reqs]
-                > 0)
+                num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
             prep_initial_states = torch.any(has_initial_states_cpu).item()
             has_initial_states = has_initial_states_cpu.to(
                 query_start_loc.device)
 
             query_start_loc_p = common_attn_metadata.query_start_loc[
-                -self._num_prefills - 1:] - self._num_decode_tokens
-
-            seq_idx = torch.repeat_interleave(
-                torch.arange(self._num_prefills,
-                             dtype=torch.int32,
-                             device=query_start_loc_p.device),
-                query_start_loc_p.diff(),
-                output_size=self._num_prefill_tokens)
+                -num_prefills - 1:] - num_decode_tokens
+
+            seq_idx = torch.repeat_interleave(torch.arange(
+                num_prefills,
+                dtype=torch.int32,
+                device=query_start_loc_p.device),
+                                              query_start_loc_p.diff(),
+                                              output_size=num_prefill_tokens)
             seq_idx.unsqueeze_(0)
 
             # We compute metadata for chunked prefill once at the top level
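
Note: the reshuffled seq_idx construction is behavior-preserving: torch.repeat_interleave repeats each prefill's request index once per scheduled token, with the per-request counts coming from query_start_loc_p.diff(). A small self-contained example of the pattern (toy offsets, CPU tensors):

import torch

# Cumulative query offsets for 3 prefill requests with 2, 3 and 1 tokens.
query_start_loc_p = torch.tensor([0, 2, 5, 6], dtype=torch.int32)
num_prefills = query_start_loc_p.numel() - 1
num_prefill_tokens = int(query_start_loc_p[-1])

# Map each token position to the index of the request it belongs to.
seq_idx = torch.repeat_interleave(
    torch.arange(num_prefills, dtype=torch.int32),
    query_start_loc_p.diff(),
    output_size=num_prefill_tokens)

print(seq_idx)  # tensor([0, 0, 1, 1, 1, 2], dtype=torch.int32)
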
@@ -202,13 +150,13 @@ def build(self,
             chunk_indices, chunk_offsets = (
                 _query_start_loc_to_chunk_indices_offsets(
                     query_start_loc_p, self.chunk_size,
-                    self._num_prefill_tokens))
+                    num_prefill_tokens))
 
         attn_metadata = Mamba2AttentionMetadata(
-            num_prefills=self._num_prefills,
-            num_prefill_tokens=self._num_prefill_tokens,
-            num_decodes=self._num_decodes,
-            num_decode_tokens=self._num_decode_tokens,
+            num_prefills=num_prefills,
+            num_prefill_tokens=num_prefill_tokens,
+            num_decodes=num_decodes,
+            num_decode_tokens=num_decode_tokens,
             query_start_loc=query_start_loc,
             seq_lens=seq_lens,
             has_initial_states=has_initial_states,
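
Note: with the counts no longer cached on the builder by reorder_batch, build recomputes them from common_attn_metadata on every call. A hypothetical toy version of the (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) contract, assuming decodes already precede prefills and using a plain list in place of CommonAttentionMetadata:

# Hypothetical stand-in for split_decodes_and_prefills: given per-request
# scheduled-token counts with decodes sorted to the front, return request
# and token counts for each segment.
def split_counts(tokens_per_req, decode_threshold=1):
    num_decodes = sum(1 for n in tokens_per_req if n <= decode_threshold)
    num_prefills = len(tokens_per_req) - num_decodes
    num_decode_tokens = sum(tokens_per_req[:num_decodes])
    num_prefill_tokens = sum(tokens_per_req[num_decodes:])
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens

# Two decodes (1 token each) followed by two prefills (4 and 9 tokens).
assert split_counts([1, 1, 4, 9]) == (2, 2, 2, 13)
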

vllm/v1/attention/backends/utils.py

Lines changed: 6 additions & 6 deletions
@@ -438,12 +438,12 @@ def reorder_batch_to_split_decodes_and_prefills(
     Returns:
         True if the batch was modified, False otherwise.
     """
-    # We now want to reorder the batch so that the "decode" requests are and
-    # the front and the "prefill" requests are at the using the least amount
-    # swaps possible. (NOTE for now we loosely use "decode" to mean requests
-    # where attention is likely memory-bound and "prefill" to mean requests
-    # where attention is likely compute-bound, TODO(lucas): figure out a
-    # better naming here)
+    # We now want to reorder the batch so that the "decode" requests are at
+    # the front and the "prefill" requests are at the back using the least
+    # amount of swaps possible. (NOTE for now we loosely use "decode" to mean
+    # requests where attention is likely memory-bound and "prefill" to mean
+    # requests where attention is likely compute-bound, TODO(lucas): figure out
+    # a better naming here)
     decodes = []
     prefills = []
     num_decode_tokens = 0
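
Note: for a concrete picture of the corrected comment, here is a toy trace of the swap loop in this function, with a list of labels standing in for InputBatch state:

# Decodes should end up at the front, prefills at the back, using as few
# swaps as possible. Indices mirror the `decodes`/`prefills` lists built
# by the classification loop.
batch = ["p0", "d0", "p1", "d1"]
decodes, prefills = [1, 3], [0, 2]
num_decodes = len(decodes)

for i in range(1, min(num_decodes, len(prefills)) + 1):
    # Walk decodes back to front; stop once they already sit inside the
    # decode region [0, num_decodes).
    decode_idx = decodes[num_decodes - i]
    if decode_idx < num_decodes:
        break
    # Swap with the frontmost prefill (stands in for swap_states).
    j = prefills[i - 1]
    batch[j], batch[decode_idx] = batch[decode_idx], batch[j]

print(batch)  # ['d1', 'd0', 'p1', 'p0'] (a single swap suffices)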
