
Commit bfa3a19

clean up
Signed-off-by: Congcong Chen <congcongchen@microsoft.com>
Parent: 531f610


vllm/attention/backends/flash_attn.py

Lines changed: 6 additions & 42 deletions
@@ -76,7 +76,6 @@ def get_kv_cache_shape(
         if block_size % 16 != 0:
             raise ValueError("Block size must be a multiple of 16.")
         return (2, num_blocks, block_size, num_kv_heads, head_size)
-        # return (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size)
 
     @staticmethod
     def swap_blocks(
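
Note: the shape kept by the hunk above packs the key and value planes of one layer into a single tensor. A minimal standalone sketch (a hypothetical helper, not the vLLM backend class) of how a cache with that layout could be allocated and split:

    import torch
    from typing import Tuple

    def allocate_kv_cache(num_blocks: int, block_size: int, num_kv_heads: int,
                          head_size: int,
                          dtype: torch.dtype = torch.float16
                          ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Mirrors the (2, num_blocks, block_size, num_kv_heads, head_size)
        # layout returned by get_kv_cache_shape; dim 0 separates keys/values.
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        kv_cache = torch.zeros(
            (2, num_blocks, block_size, num_kv_heads, head_size), dtype=dtype)
        key_cache, value_cache = kv_cache.unbind(dim=0)  # views, no copy
        return key_cache, value_cache
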
@@ -186,9 +185,6 @@ class FlashAttentionMetadata(AttentionMetadata):
     cross_slot_mapping: Optional[torch.Tensor] = None
     cross_block_tables: Optional[torch.Tensor] = None
 
-    # Cross-layer shared attention block tables
-    cross_layer_shared_block_tables: Optional[torch.Tensor] = None
-
     @property
     def is_all_encoder_attn_metadata_set(self):
         '''
@@ -233,9 +229,7 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
                                self.context_lens_tensor[:self.num_prefills])
         block_tables = (None if self.block_tables is None else
                         self.block_tables[:self.num_prefills])
-        cross_layer_shared_block_tables = (None if self.cross_layer_shared_block_tables is None else
-                        self.cross_layer_shared_block_tables[:self.num_prefills])
-
+
         self._cached_prefill_metadata = FlashAttentionMetadata(
             num_prefills=self.num_prefills,
             num_prefill_tokens=self.num_prefill_tokens,
@@ -254,7 +248,6 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
             seq_start_loc=seq_start_loc,
             context_lens_tensor=context_lens_tensor,
             block_tables=block_tables,
-            cross_layer_shared_block_tables=cross_layer_shared_block_tables,
             use_cuda_graph=False,
             # Begin encoder & cross attn fields below...
             encoder_seq_lens=self.encoder_seq_lens,
@@ -282,8 +275,7 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
                            self.seq_lens_tensor[self.num_prefills:])
         block_tables = (None if self.block_tables is None else
                         self.block_tables[self.num_prefills:])
-        cross_layer_shared_block_tables = (None if self.cross_layer_shared_block_tables is None else
-                        self.cross_layer_shared_block_tables[self.num_prefills:])
+
         self._cached_decode_metadata = FlashAttentionMetadata(
             num_prefills=0,
             num_prefill_tokens=0,
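
Note: the prefill_metadata and decode_metadata hunks rely on the batch being ordered with all prefill sequences first, followed by all decode sequences, so the cached views are plain slices. A tiny illustration of that convention (values made up for the example):

    import torch

    seq_lens = torch.tensor([512, 300, 1, 1, 1])  # 2 prefills, then 3 decodes
    num_prefills = 2

    prefill_seq_lens = seq_lens[:num_prefills]   # tensor([512, 300])
    decode_seq_lens = seq_lens[num_prefills:]    # tensor([1, 1, 1])
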
@@ -307,7 +299,6 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
             if self.seq_start_loc is not None else None,
             context_lens_tensor=None,
             block_tables=block_tables,
-            cross_layer_shared_block_tables=cross_layer_shared_block_tables,
             use_cuda_graph=self.use_cuda_graph,
             # Begin encoder & cross attn fields below...
             encoder_seq_lens=self.encoder_seq_lens,
@@ -406,7 +397,6 @@ def prepare(self):
         self.prefill_seq_lens: List[int] = []
         self.context_lens: List[int] = []
         self.block_tables: List[List[int]] = []
-        self.cross_layer_shared_block_tables: List[List[int]] = []
         self.curr_seq_lens: List[int] = []
         self.multimodal_placeholder_maps: Dict[
             str,
@@ -467,17 +457,6 @@ def _add_seq_group(
                         -curr_sliding_window_block:]
             self.block_tables.append(block_table)
 
-            cross_layer_shared_block_table = []
-            if prefix_cache_hit:
-                cross_layer_shared_block_table = block_tables[seq_id]
-            elif block_tables is not None:
-                if curr_sliding_window_block == 0:
-                    cross_layer_shared_block_table = block_tables[seq_id]
-                else:
-                    cross_layer_shared_block_table = block_tables[seq_id][
-                        -curr_sliding_window_block:]
-            self.cross_layer_shared_block_tables.append(cross_layer_shared_block_table)
-
             # Compute slot mapping.
             is_profile_run = is_block_tables_empty(block_tables)
             start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
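
Note: the deleted block duplicated, for a second cross-layer-shared table, the block-table selection that remains just above it. A hypothetical standalone restatement of that remaining logic, for readers following the hunk:

    from typing import Dict, List, Optional

    def select_block_table(block_tables: Optional[Dict[int, List[int]]],
                           seq_id: int, prefix_cache_hit: bool,
                           curr_sliding_window_block: int) -> List[int]:
        # On a prefix-cache hit the full table is used; otherwise, with a
        # sliding window of N blocks, only the last N block IDs of the
        # sequence's table are relevant.
        block_table: List[int] = []
        if prefix_cache_hit:
            block_table = block_tables[seq_id]
        elif block_tables is not None:
            if curr_sliding_window_block == 0:
                block_table = block_tables[seq_id]
            else:
                block_table = block_tables[seq_id][-curr_sliding_window_block:]
        return block_table

    # A 6-block sequence with a 4-block sliding window keeps only blocks 2..5:
    # select_block_table({0: [0, 1, 2, 3, 4, 5]}, 0, False, 4) -> [2, 3, 4, 5]
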
@@ -489,16 +468,13 @@ def _add_seq_group(
 
     def _get_graph_runner_block_tables(
             self, num_seqs: int,
-            block_tables: List[List[int]],
-            graph_block_tables) -> torch.Tensor:
+            block_tables: List[List[int]]) -> torch.Tensor:
         # The shape of graph_block_tables is
         # [max batch size, max context len // block size].
-        # max_batch_size, max_blocks = self.runner.graph_block_tables.shape
-        max_batch_size, max_blocks = graph_block_tables.shape
+        max_batch_size, max_blocks = self.runner.graph_block_tables.shape
         assert max_batch_size >= num_seqs
 
-        # graph_block_tables = self.runner.graph_block_tables[:num_seqs]
-        graph_block_tables = graph_block_tables[:num_seqs]
+        graph_block_tables = self.runner.graph_block_tables[:num_seqs]
         for i, block_table in enumerate(block_tables):
             if block_table:
                 num_blocks = len(block_table)
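
Note: after this change, _get_graph_runner_block_tables again reads the runner's preallocated table directly instead of taking it as a parameter. A standalone sketch (hypothetical names, not the builder method itself) of the packing it performs for CUDA graph capture:

    from typing import List
    import numpy as np
    import torch

    def pack_block_tables(num_seqs: int, block_tables: List[List[int]],
                          graph_block_tables: np.ndarray) -> torch.Tensor:
        # graph_block_tables stands in for the preallocated
        # self.runner.graph_block_tables array of shape
        # [max_batch_size, max_context_len // block_size]; each ragged block
        # table is copied into its row so the result has a fixed shape that a
        # captured CUDA graph can replay.
        max_batch_size, max_blocks = graph_block_tables.shape
        assert max_batch_size >= num_seqs

        packed = graph_block_tables[:num_seqs]
        for i, block_table in enumerate(block_tables):
            if block_table:
                num_blocks = min(len(block_table), max_blocks)
                packed[i, :num_blocks] = block_table[:num_blocks]
        return torch.from_numpy(packed)

    # Example (CPU-only):
    # pack_block_tables(2, [[3, 4], [7]], np.zeros((8, 4), dtype=np.int32))
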
@@ -553,27 +529,16 @@ def build(self, seq_lens: List[int], query_lens: List[int],
         if use_captured_graph:
             self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
             self.block_tables.extend([] * cuda_graph_pad_size)
-
-            self.cross_layer_shared_block_tables.extend([] * cuda_graph_pad_size)
-
             num_decode_tokens = batch_size - self.num_prefill_tokens
             block_tables = self._get_graph_runner_block_tables(
-                num_seqs, self.block_tables, self.runner.graph_block_tables)
-            cross_layer_shared_block_tables = self._get_graph_runner_block_tables(
-                num_seqs, self.cross_layer_shared_block_tables, self.runner.cross_layer_shared_graph_block_tables)
+                num_seqs, self.block_tables)
         else:
             block_tables = make_tensor_with_pad(
                 self.block_tables,
                 pad=0,
                 dtype=torch.int,
                 device=device,
             )
-            cross_layer_shared_block_tables = make_tensor_with_pad(
-                self.cross_layer_shared_block_tables,
-                pad=0,
-                dtype=torch.int,
-                device=device,
-            )
         assert max_query_len > 0, ("query_lens: {}".format(query_lens))
 
         assert device is not None
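
Note: in the non-captured-graph branch above, make_tensor_with_pad (from vllm.utils) turns the ragged per-sequence block tables into one rectangular tensor. A rough standalone approximation of that behaviour, assuming right-padding with the given pad value:

    from typing import List
    import torch

    def make_tensor_with_pad_sketch(data: List[List[int]], pad: int,
                                    dtype: torch.dtype,
                                    device=None) -> torch.Tensor:
        # Right-pad every row with `pad` to the length of the longest row,
        # then materialize a [num_rows, max_len] tensor on `device`.
        max_len = max((len(row) for row in data), default=0)
        padded = [row + [pad] * (max_len - len(row)) for row in data]
        return torch.tensor(padded, dtype=dtype, device=device)

    # e.g. [[7, 8, 9], [4]] with pad=0 -> tensor([[7, 8, 9], [4, 0, 0]])
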
@@ -611,7 +576,6 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             seq_start_loc=seq_start_loc_tensor,
             context_lens_tensor=context_lens_tensor,
             block_tables=block_tables,
-            cross_layer_shared_block_tables=cross_layer_shared_block_tables,
             use_cuda_graph=use_captured_graph,
         )