 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.config import VllmConfig
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata)
+from vllm.v1.attention.backends.utils import (
+    AttentionMetadataBuilder, CommonAttentionMetadata,
+    reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
 
 if TYPE_CHECKING:
@@ -96,65 +97,9 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
-        # NOTE (Chen): Copied from MLACommonMetadataBuilder and
-        # FlashInferMetadataBuilder. Should be refactored later to avoid code
-        # duplication of these 3 functions.
-        # We now want to reorder the batch so that the "decode" requests are at
-        # the front and the "prefill" requests are at the back, using the least
-        # amount of swaps possible. (NOTE: for now we loosely use "decode" to
-        # mean requests where attention is likely memory-bound and "prefill" to
-        # mean requests where attention is likely compute-bound, TODO(lucas):
-        # figure out a better naming here)
-        decodes = []
-        prefills = []
-        num_decode_tokens = 0
-        num_prefill_tokens = 0
-
-        for i, req_id in enumerate(input_batch.req_ids):
-            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
-            # For now treat 1 scheduled token as "decode" even if it's not;
-            # we should update this to something like < 8 in the future, but
-            # currently the decode run only supports num_tokens = 1.
-            if num_tokens == 1:
-                decodes.append(i)
-                num_decode_tokens += num_tokens
-            else:
-                prefills.append(i)
-                num_prefill_tokens += num_tokens
-
-        # We hope that this is fairly minimal since decodes should be around
-        # for a number of iterations, so hopefully they are relatively
-        # stationary (and new requests are generally appended to the
-        # persistent batch, so they should already be at the back).
-        # To achieve this we loop over the decodes in descending order and
-        # the prefills in ascending order. We swap decodes from the "back",
-        # i.e. past where the last decode should be in the reordered batch,
-        # with prefills from the front of the batch.
-        # `decodes` and `prefills` are already in ascending order just based
-        # on the above loop.
-        num_decodes = len(decodes)
-        num_prefills = len(prefills)
-        modified_batch = False
-
-        for i in range(1, min(num_decodes, num_prefills) + 1):
-            # If the decode is at the "back" of the batch, i, we can swap it
-            # with the prefill closest to the front of the batch.
-            decode_idx = decodes[num_decodes - i]
-            if decode_idx < num_decodes:
-                break
-
-            input_batch.swap_states(prefills[i - 1], decode_idx)
-            modified_batch = True
-
-        # Save for the next `build` call.
-        # TODO(lucas): this is a bit of a hack, we should probably have a
-        # better way of doing this.
-        self._num_decodes = num_decodes
-        self._num_prefills = num_prefills
-        self._num_decode_tokens = num_decode_tokens
-        self._num_prefill_tokens = num_prefill_tokens
-
-        return modified_batch
+        return reorder_batch_to_split_decodes_and_prefills(input_batch,
+                                                           scheduler_output,
+                                                           decode_threshold=1)
 
     def build(self,
               common_prefix_len: int,
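
The hand-rolled reordering above is replaced by the shared reorder_batch_to_split_decodes_and_prefills utility. Below is a minimal sketch of the behavior the call site relies on, inferred from the removed body in this hunk; the decode_threshold parameter generalizes the old `num_tokens == 1` check, and the actual helper in vllm.v1.attention.backends.utils may be implemented differently.

# Sketch only: mirrors the removed reorder_batch body above, not the real
# vLLM utility. input_batch is assumed to expose req_ids and swap_states(i, j),
# and scheduler_output to expose num_scheduled_tokens, as at the call site.
def reorder_batch_to_split_decodes_and_prefills(input_batch,
                                                scheduler_output,
                                                decode_threshold: int = 1
                                                ) -> bool:
    """Reorder so decode requests sit at the front, prefills at the back.

    Returns True if any requests were swapped.
    """
    decodes, prefills = [], []
    for i, req_id in enumerate(input_batch.req_ids):
        num_tokens = scheduler_output.num_scheduled_tokens[req_id]
        # Requests scheduling <= decode_threshold tokens are treated as decodes.
        (decodes if num_tokens <= decode_threshold else prefills).append(i)

    num_decodes = len(decodes)
    modified_batch = False
    # Walk the decode list from the back and the prefill list from the front,
    # swapping decodes that sit past the decode region with leading prefills.
    for i in range(1, min(num_decodes, len(prefills)) + 1):
        decode_idx = decodes[num_decodes - i]
        if decode_idx < num_decodes:
            break  # remaining decodes are already inside the decode region
        input_batch.swap_states(prefills[i - 1], decode_idx)
        modified_batch = True
    return modified_batch
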
@@ -173,26 +118,29 @@ def build(self,
         state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
 
+        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+            split_decodes_and_prefills(common_attn_metadata,
+                                       decode_threshold=1))
+
         # Compute seq_idx, chunk_indices and chunk_offsets for prefill only
-        if self._num_prefills > 0:
+        if num_prefills > 0:
             # [batch,]
             has_initial_states_cpu = (
                 common_attn_metadata.
-                num_computed_tokens_cpu[num_reqs - self._num_prefills:num_reqs]
-                > 0)
+                num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
             prep_initial_states = torch.any(has_initial_states_cpu).item()
             has_initial_states = has_initial_states_cpu.to(
                 query_start_loc.device)
 
             query_start_loc_p = common_attn_metadata.query_start_loc[
-                -self._num_prefills - 1:] - self._num_decode_tokens
-
-            seq_idx = torch.repeat_interleave(
-                torch.arange(self._num_prefills,
-                             dtype=torch.int32,
-                             device=query_start_loc_p.device),
-                query_start_loc_p.diff(),
-                output_size=self._num_prefill_tokens)
+                -num_prefills - 1:] - num_decode_tokens
+
+            seq_idx = torch.repeat_interleave(torch.arange(
+                num_prefills,
+                dtype=torch.int32,
+                device=query_start_loc_p.device),
+                                              query_start_loc_p.diff(),
+                                              output_size=num_prefill_tokens)
             seq_idx.unsqueeze_(0)
 
             # We compute metadata for chunked prefill once at the top level
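
For the prefill-only seq_idx, torch.repeat_interleave expands each prefill request's index by its query length (the diffs of the prefill-local query_start_loc, with decode tokens already subtracted out). A small self-contained illustration with made-up lengths of 4, 2 and 3 tokens:

# Hypothetical prefill-local query_start_loc for 3 prefill requests.
import torch

query_start_loc_p = torch.tensor([0, 4, 6, 9])
num_prefills = query_start_loc_p.numel() - 1          # 3
num_prefill_tokens = int(query_start_loc_p[-1])       # 9

seq_idx = torch.repeat_interleave(
    torch.arange(num_prefills, dtype=torch.int32),
    query_start_loc_p.diff(),          # per-request query lengths: [4, 2, 3]
    output_size=num_prefill_tokens)
print(seq_idx)  # tensor([0, 0, 0, 0, 1, 1, 2, 2, 2], dtype=torch.int32)
seq_idx = seq_idx.unsqueeze(0)         # add the leading batch dim, as in the diff
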
@@ -202,13 +150,13 @@ def build(self,
             chunk_indices, chunk_offsets = (
                 _query_start_loc_to_chunk_indices_offsets(
                     query_start_loc_p, self.chunk_size,
-                    self._num_prefill_tokens))
+                    num_prefill_tokens))
 
         attn_metadata = Mamba2AttentionMetadata(
-            num_prefills=self._num_prefills,
-            num_prefill_tokens=self._num_prefill_tokens,
-            num_decodes=self._num_decodes,
-            num_decode_tokens=self._num_decode_tokens,
+            num_prefills=num_prefills,
+            num_prefill_tokens=num_prefill_tokens,
+            num_decodes=num_decodes,
+            num_decode_tokens=num_decode_tokens,
             query_start_loc=query_start_loc,
             seq_lens=seq_lens,
             has_initial_states=has_initial_states,
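
split_decodes_and_prefills replaces the per-builder bookkeeping (_num_decodes, _num_prefill_tokens, ...) with counts computed directly from the batch metadata at build time. A rough sketch of the assumed behavior, written against a bare query_start_loc tensor rather than the real CommonAttentionMetadata object, so the signature here is simplified:

# Sketch only: the real vLLM helper takes a CommonAttentionMetadata and may
# compute these counts differently.
import torch

def split_decodes_and_prefills_sketch(query_start_loc: torch.Tensor,
                                      decode_threshold: int = 1):
    """Return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens).

    Relies on the batch already being reordered so that all decode requests
    (query length <= decode_threshold) come before all prefill requests.
    """
    query_lens = query_start_loc.diff()
    num_decodes = int((query_lens <= decode_threshold).sum())
    num_prefills = query_lens.numel() - num_decodes
    num_decode_tokens = int(query_lens[:num_decodes].sum())
    num_prefill_tokens = int(query_lens[num_decodes:].sum())
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens

# Example: two decodes (1 token each) followed by two prefills (5 and 3 tokens).
print(split_decodes_and_prefills_sketch(torch.tensor([0, 1, 2, 7, 10])))
# (2, 2, 2, 8)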