@@ -76,7 +76,6 @@ def get_kv_cache_shape(
         if block_size % 16 != 0:
             raise ValueError("Block size must be a multiple of 16.")
         return (2, num_blocks, block_size, num_kv_heads, head_size)
-        # return (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size)
 
     @staticmethod
     def swap_blocks(
@@ -186,9 +185,6 @@ class FlashAttentionMetadata(AttentionMetadata):
     cross_slot_mapping: Optional[torch.Tensor] = None
     cross_block_tables: Optional[torch.Tensor] = None
 
-    # Cross-layer shared attention block tables
-    cross_layer_shared_block_tables: Optional[torch.Tensor] = None
-
     @property
     def is_all_encoder_attn_metadata_set(self):
         '''
@@ -233,9 +229,7 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
                                self.context_lens_tensor[:self.num_prefills])
         block_tables = (None if self.block_tables is None else
                         self.block_tables[:self.num_prefills])
-        cross_layer_shared_block_tables = (None if self.cross_layer_shared_block_tables is None else
-                        self.cross_layer_shared_block_tables[:self.num_prefills])
-
+
         self._cached_prefill_metadata = FlashAttentionMetadata(
             num_prefills=self.num_prefills,
             num_prefill_tokens=self.num_prefill_tokens,
@@ -254,7 +248,6 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
             seq_start_loc=seq_start_loc,
             context_lens_tensor=context_lens_tensor,
             block_tables=block_tables,
-            cross_layer_shared_block_tables=cross_layer_shared_block_tables,
             use_cuda_graph=False,
             # Begin encoder & cross attn fields below...
             encoder_seq_lens=self.encoder_seq_lens,
@@ -282,8 +275,7 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
                            self.seq_lens_tensor[self.num_prefills:])
         block_tables = (None if self.block_tables is None else
                         self.block_tables[self.num_prefills:])
-        cross_layer_shared_block_tables = (None if self.cross_layer_shared_block_tables is None else
-                        self.cross_layer_shared_block_tables[self.num_prefills:])
+
         self._cached_decode_metadata = FlashAttentionMetadata(
             num_prefills=0,
             num_prefill_tokens=0,
@@ -307,7 +299,6 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
             if self.seq_start_loc is not None else None,
             context_lens_tensor=None,
             block_tables=block_tables,
-            cross_layer_shared_block_tables=cross_layer_shared_block_tables,
             use_cuda_graph=self.use_cuda_graph,
             # Begin encoder & cross attn fields below...
             encoder_seq_lens=self.encoder_seq_lens,
@@ -406,7 +397,6 @@ def prepare(self):
         self.prefill_seq_lens: List[int] = []
         self.context_lens: List[int] = []
         self.block_tables: List[List[int]] = []
-        self.cross_layer_shared_block_tables: List[List[int]] = []
         self.curr_seq_lens: List[int] = []
         self.multimodal_placeholder_maps: Dict[
             str,
@@ -467,17 +457,6 @@ def _add_seq_group(
                         -curr_sliding_window_block:]
             self.block_tables.append(block_table)
 
-            cross_layer_shared_block_table = []
-            if prefix_cache_hit:
-                cross_layer_shared_block_table = block_tables[seq_id]
-            elif block_tables is not None:
-                if curr_sliding_window_block == 0:
-                    cross_layer_shared_block_table = block_tables[seq_id]
-                else:
-                    cross_layer_shared_block_table = block_tables[seq_id][
-                        -curr_sliding_window_block:]
-            self.cross_layer_shared_block_tables.append(cross_layer_shared_block_table)
-
             # Compute slot mapping.
             is_profile_run = is_block_tables_empty(block_tables)
             start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
@@ -489,16 +468,13 @@ def _add_seq_group(
 
     def _get_graph_runner_block_tables(
             self, num_seqs: int,
-            block_tables: List[List[int]],
-            graph_block_tables) -> torch.Tensor:
+            block_tables: List[List[int]]) -> torch.Tensor:
         # The shape of graph_block_tables is
         # [max batch size, max context len // block size].
-        # max_batch_size, max_blocks = self.runner.graph_block_tables.shape
-        max_batch_size, max_blocks = graph_block_tables.shape
+        max_batch_size, max_blocks = self.runner.graph_block_tables.shape
         assert max_batch_size >= num_seqs
 
-        # graph_block_tables = self.runner.graph_block_tables[:num_seqs]
-        graph_block_tables = graph_block_tables[:num_seqs]
+        graph_block_tables = self.runner.graph_block_tables[:num_seqs]
         for i, block_table in enumerate(block_tables):
             if block_table:
                 num_blocks = len(block_table)
@@ -553,27 +529,16 @@ def build(self, seq_lens: List[int], query_lens: List[int],
         if use_captured_graph:
             self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
             self.block_tables.extend([] * cuda_graph_pad_size)
-
-            self.cross_layer_shared_block_tables.extend([] * cuda_graph_pad_size)
-
             num_decode_tokens = batch_size - self.num_prefill_tokens
             block_tables = self._get_graph_runner_block_tables(
-                num_seqs, self.block_tables, self.runner.graph_block_tables)
-            cross_layer_shared_block_tables = self._get_graph_runner_block_tables(
-                num_seqs, self.cross_layer_shared_block_tables, self.runner.cross_layer_shared_graph_block_tables)
+                num_seqs, self.block_tables)
         else:
             block_tables = make_tensor_with_pad(
                 self.block_tables,
                 pad=0,
                 dtype=torch.int,
                 device=device,
             )
-            cross_layer_shared_block_tables = make_tensor_with_pad(
-                self.cross_layer_shared_block_tables,
-                pad=0,
-                dtype=torch.int,
-                device=device,
-            )
         assert max_query_len > 0, ("query_lens: {}".format(query_lens))
 
         assert device is not None
@@ -611,7 +576,6 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             seq_start_loc=seq_start_loc_tensor,
             context_lens_tensor=context_lens_tensor,
             block_tables=block_tables,
-            cross_layer_shared_block_tables=cross_layer_shared_block_tables,
             use_cuda_graph=use_captured_graph,
         )
 
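For context on the CUDA-graph path changed above: after this diff, `_get_graph_runner_block_tables` takes only `num_seqs` and the per-sequence `block_tables` and reads the preallocated `self.runner.graph_block_tables` directly. The sketch below illustrates that padding pattern as a standalone function; the function name, the explicit `graph_block_tables`/`device` parameters, and the row-copy line are assumptions for illustration, not the exact vLLM implementation.

```python
# Minimal sketch of the padding pattern behind _get_graph_runner_block_tables.
# Standalone-function form and names are illustrative; only the shape handling
# visible in the diff is taken from the source.
from typing import List

import numpy as np
import torch


def get_graph_runner_block_tables_sketch(
        num_seqs: int,
        block_tables: List[List[int]],
        graph_block_tables: np.ndarray,  # [max batch size, max context len // block size]
        device: str = "cpu") -> torch.Tensor:
    max_batch_size, max_blocks = graph_block_tables.shape
    assert max_batch_size >= num_seqs

    # Use only the first num_seqs rows of the preallocated, fixed-shape array
    # so graph capture always sees a constant-shape tensor.
    graph_block_tables = graph_block_tables[:num_seqs]
    for i, block_table in enumerate(block_tables):
        if block_table:
            num_blocks = min(len(block_table), max_blocks)
            # Copy this sequence's block ids into row i; the rest stays padded.
            graph_block_tables[i, :num_blocks] = block_table[:num_blocks]

    return torch.from_numpy(graph_block_tables).to(device=device, non_blocking=True)
```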