 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata,
-                                              get_kv_cache_layout)
+from vllm.v1.attention.backends.utils import (
+    AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
+    make_local_attention_virtual_batches)
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.block_table import BlockTable

@@ -126,172 +126,6 @@ class LocalAttentionMetadata:
     local_attn_metadata: Optional[LocalAttentionMetadata] = None


-#
-# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
-# local attention blocks, where each block is passed to the attention kernel
-# as an independent local ("virtual") batch item.
-#
-# For example, if we are performing a chunked prefill on a batch of 3 sequences:
-#   q_seqlens  = [4, 10, 5]
-#   kv_seqlens = [6, 17, 9]
-# Then normally for regular attention we would compute with an attention mask
-#  for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
-#   batch idx: 0  (q_seqlens = 4, kv_seqlens = 6)
-#        k_toks >   0 1 2 3 4 5
-#        q_toks v  _____________
-#               0 | 1 1 1
-#               1 | 1 1 1 1
-#               2 | 1 1 1 1 1
-#               3 | 1 1 1 1 1 1
-#
-# for local attention (with attn_chunk_size = 4) we would compute with an
-#  attention mask like:
-#   batch idx: 0  (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
-#        k_toks >   0 1 2 3 4 5
-#        q_toks v  _____________
-#               0 | 1 1 1
-#               1 | 1 1 1 1
-#               2 |         1
-#               3 |         1 1
-#
-# We can simulate this mask using standard flash-attention by breaking the
-#  sequences into local ("virtual") batches, where each local batch item is a
-#  local attention block, so in this case batch idx 0 would be broken up into:
-#
-#   local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4)  (batch 0)
-#        k_toks >   0 1 2 3
-#        q_toks v  _____________
-#               0 | 1 1 1
-#               1 | 1 1 1 1
-#   local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2)  (batch 0)
-#        k_toks >   4 5
-#        q_toks v  _____________
-#               2 | 1
-#               3 | 1 1
-#
-# e.g. if we have:
-#   attn_chunk_size = 4
-#   query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
-# Then this function would return:
-#                           __b0__  ______b1______  __b2__  < orig batch indices
-#   q_seqlens_local    = [   2,  2,  1,  4,  4,  1,  4,  1]
-#   cu_seqlens_q_local = [0,  2,  4,  5,  9, 13, 14, 18, 19]
-#   seqlens_k_local    = [   4,  2,  4,  4,  4,  1,  4,  1]
-#   block_table_local  : shape[local_virtual_batches, pages_per_local_batch]
-def make_local_attention_virtual_batches(
-    attn_chunk_size: int,
-    query_start_loc_np: np.ndarray,
-    seq_lens_np: np.ndarray,
-    block_table: torch.Tensor,
-    block_size: int = 0,
-) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
-    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
-    actual_batch_size = seq_lens_np.shape[0]
-
-    # Handle if we are starting in the middle of a local attention block,
-    #  we assume q_seqlens > 0 (for all elements); for each batch idx we compute
-    #  the number of tokens that are not in the first local attention block and
-    #  then we can simply use a cdiv for the rest.
-    # For example if we have:
-    #   attn_chunk_size = 4
-    #   q_seqlens = [4, 10, 5]
-    #   k_seqlens = [6, 17, 9]
-    # Then we would get:
-    #   new_tokens_in_first_block = [2, 1, 4]
-    #   local_blocks = [2, 4, 2]
-    q_tokens_in_first_block = np.minimum(
-        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size),
-        q_seqlens).astype(np.int32)
-    tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
-    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block,
-                            attn_chunk_size)
-
-    # Once we know the number of local blocks for each batch idx we can
-    #  compute the request spans and figure out the number of "virtual"
-    #  requests we have to make.
-    # For the above example we would get:
-    #   seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
-    #
-    # First get a batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
-    #   (TODO: make a utility to share this code with _prepare_inputs)
-    # arange step 1. [2, 4, 2] -> [2, 6, 8]
-    cu_num_blocks = np.cumsum(local_blocks)
-    virtual_batches = cu_num_blocks[-1]
-    # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
-    block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
-    # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
-    arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
-    # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
-    rarange = np.repeat(local_blocks, local_blocks) - arange - 1
-    # Then we can compute the seqlens_q_local, handling the fact that the
-    #  first and last blocks could be partial
-    seqlens_q_local = \
-        np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
-    # set the first block since this may be a partial block
-    seqlens_q_local[arange == 0] = q_tokens_in_first_block
-    # set the remaining blocks
-    seqlens_q_local[arange > 0] = np.minimum(
-        seqlens_q_local - attn_chunk_size * (arange - 1),
-        attn_chunk_size)[arange > 0]
-
-    # convert from q_seqlens to cu_seqlens_q
-    cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0))\
-        .astype(np.int32)
-
-    # compute the seqlens_k_local,
-    #  basically a full local attention block for all but the last block in
-    #  each batch
-    # For our example this will be:
-    #   seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
-    seqlens_k_local = np.full(cu_num_blocks[-1],
-                              attn_chunk_size,
-                              dtype=np.int32)
-    seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
-
-    k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \
-        (rarange * attn_chunk_size + \
-            np.repeat(tokens_in_last_block, local_blocks))
-    # For the example the local attention blocks start at:
-    #                           _b0_  _____b1_____  _b2_
-    #   k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
-    block_starts = k_seqstarts_absolute // block_size
-    assert attn_chunk_size % block_size == 0, \
-        f"attn_chunk_size {attn_chunk_size} is not " \
-        f"divisible by block_size {block_size}"
-    pages_per_local_batch = attn_chunk_size // block_size
-
-    # Create a block_table for the local attention blocks
-    # For our example if we have a block-table like (assuming block_size=2):
-    #   block_table = [
-    #     [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],  < batch 0
-    #     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],  < batch 1
-    #     [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],  < batch 2
-    #   ]
-    # Then for the local batches we would want a block-table like
-    #   block_table_local = [
-    #     [ 0,  1 ],  < local-batch 0, (batch 0, starting from k[0])
-    #     [ 2,  3 ],  < local-batch 1, (batch 0, starting from k[4])
-    #     [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
-    #     [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
-    #     [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
-    #     [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
-    #     [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
-    #     [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
-    #   ]
-    block_indices = np.broadcast_to(
-        np.arange(pages_per_local_batch, dtype=np.int32),
-        (virtual_batches, pages_per_local_batch)) \
-            + np.expand_dims(block_starts, axis=1)
-    block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1)
-    batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32),
-                              local_blocks * pages_per_local_batch)
-    block_table_local = block_table[batch_indices, block_indices]\
-        .view(virtual_batches, -1)
-
-    return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \
-        block_table_local
-
-
 def _get_sliding_window_configs(
         vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
     """Get the set of all sliding window configs used in the model."""
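
The mask pictures in the relocated helper's comment block are easy to reproduce. The sketch below is not vLLM code; it rebuilds the batch-idx-0 example (q_seqlen = 4, kv_seqlen = 6, attn_chunk_size = 4) by combining a causal mask with a same-chunk constraint, which is exactly the masking the virtual-batch trick emulates.

```python
# Standalone sketch (not from the vLLM codebase): rebuild the local-attention
# mask for batch idx 0 of the docstring example.
import numpy as np

q_len, kv_len, chunk = 4, 6, 4
q_pos = np.arange(kv_len - q_len, kv_len)[:, None]  # absolute q positions 2..5
k_pos = np.arange(kv_len)[None, :]                  # absolute k positions 0..5
causal = k_pos <= q_pos                             # standard causal mask
same_chunk = (k_pos // chunk) == (q_pos // chunk)   # restrict to the local chunk
print((causal & same_chunk).astype(int))            # rows are q_toks 0..3
# [[1 1 1 0 0 0]
#  [1 1 1 1 0 0]
#  [0 0 0 0 1 0]
#  [0 0 0 0 1 1]]
```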
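The virtual-batch split itself can be checked the same way. This standalone sketch mirrors the arithmetic of `make_local_attention_virtual_batches` for the documented example (q_seqlens = [4, 10, 5], kv_seqlens = [6, 17, 9], attn_chunk_size = 4); `cdiv` is re-implemented locally instead of being imported from `vllm.utils`, and the variable names are shortened, so treat it as an illustration rather than the implementation.

```python
# Standalone sketch (not from the vLLM codebase): recompute the per-local-block
# query/key lengths for the docstring example.
import numpy as np


def cdiv(a, b):
    # ceiling division, matching vllm.utils.cdiv
    return -(a // -b)


attn_chunk_size = 4
query_start_loc = np.array([0, 4, 14, 19])  # q_seqlens  = [4, 10, 5]
seq_lens = np.array([6, 17, 9])             # kv_seqlens = [6, 17, 9]

q_seqlens = query_start_loc[1:] - query_start_loc[:-1]
q_first = np.minimum(
    attn_chunk_size - ((seq_lens - q_seqlens) % attn_chunk_size), q_seqlens)
k_last = attn_chunk_size + (seq_lens % -attn_chunk_size)
local_blocks = 1 + cdiv(q_seqlens - q_first, attn_chunk_size)  # [2, 4, 2]

# batched arange: [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1]
cu_blocks = np.cumsum(local_blocks)
arange = np.arange(cu_blocks[-1]) - np.repeat(cu_blocks - local_blocks,
                                              local_blocks)

seqlens_q_local = np.repeat(q_seqlens - q_first, local_blocks)
seqlens_q_local[arange == 0] = q_first                # partial first blocks
seqlens_q_local[arange > 0] = np.minimum(             # remaining blocks
    seqlens_q_local - attn_chunk_size * (arange - 1),
    attn_chunk_size)[arange > 0]

seqlens_k_local = np.full(cu_blocks[-1], attn_chunk_size)
seqlens_k_local[cu_blocks - 1] = k_last               # partial last blocks

print(seqlens_q_local.tolist())  # [2, 2, 1, 4, 4, 1, 4, 1]
print(seqlens_k_local.tolist())  # [4, 2, 4, 4, 4, 1, 4, 1]
print(np.pad(np.cumsum(seqlens_q_local), (1, 0)).tolist())
# [0, 2, 4, 5, 9, 13, 14, 18, 19]
```

The last print is the cumulative form returned as `cu_seqlens_q_local`.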
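Finally, the block-table gather. Assuming block_size = 2 and the 3 x 10 block table from the comment, the sketch below reproduces `block_table_local` with the same broadcast-and-gather indexing; `local_blocks` and `k_last` are hard-coded from the previous sketch rather than recomputed.

```python
# Standalone sketch (not from the vLLM codebase): reproduce block_table_local
# for the docstring example (block_size = 2, attn_chunk_size = 4).
import numpy as np
import torch

attn_chunk_size, block_size = 4, 2
block_table = torch.arange(30, dtype=torch.int32).view(3, 10)

seq_lens = np.array([6, 17, 9])     # kv_seqlens
local_blocks = np.array([2, 4, 2])  # from the previous sketch
k_last = np.array([2, 1, 1])        # tokens in each sequence's last block
cu_blocks = np.cumsum(local_blocks)
arange = np.arange(cu_blocks[-1]) - np.repeat(cu_blocks - local_blocks,
                                              local_blocks)
rarange = np.repeat(local_blocks, local_blocks) - arange - 1

# absolute k-token offset at which each local block starts, then its first page
k_starts = np.repeat(seq_lens, local_blocks) - (
    rarange * attn_chunk_size + np.repeat(k_last, local_blocks))
block_starts = k_starts // block_size
pages_per_local_batch = attn_chunk_size // block_size  # 2 pages per local block

block_indices = (block_starts[:, None] +
                 np.arange(pages_per_local_batch)).reshape(-1)
block_indices = block_indices.clip(max=block_table.shape[1] - 1)
batch_indices = np.repeat(np.arange(3), local_blocks * pages_per_local_batch)

block_table_local = block_table[batch_indices, block_indices].view(
    int(cu_blocks[-1]), -1)
print(block_table_local.tolist())
# [[0, 1], [2, 3], [12, 13], [14, 15], [16, 17], [18, 19], [22, 23], [24, 25]]
```

Each local batch receives `pages_per_local_batch` consecutive pages taken from its parent request's row of the block table, which matches the expected table in the helper's comment.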