2 files changed, +6 −6 lines changed

```diff
@@ -59,11 +59,6 @@ class CommonAttentionMetadata:
     block_table_tensor: torch.Tensor
     slot_mapping: torch.Tensor
 
-    def __post_init__(self):
-        # Fill unused with -1. Needed for reshape_and_cache in full cuda graph
-        # mode.
-        self.slot_mapping[self.num_actual_tokens:].fill_(-1)
-
 
 M = TypeVar("M")
 
```
```diff
@@ -684,7 +684,7 @@ def _prepare_inputs(
         self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
                                        non_blocking=True)
 
-        # Fill unused with -1. Needed for reshape_and_cache
+        # Fill unused with 0 for full cuda graph mode.
         self.seq_lens[num_reqs:].fill_(0)
         # Note: pad query_start_loc to be non-decreasing, as kernels
         # like FlashAttention requires that
@@ -704,6 +704,11 @@ def _prepare_inputs(
             blk_table = self.input_batch.block_table[kv_cache_group_id]
             blk_table_tensor = blk_table.get_device_tensor()[:num_reqs]
             slot_mapping = blk_table.slot_mapping[:total_num_scheduled_tokens]
+
+            # Fill unused with -1. Needed for reshape_and_cache in full cuda
+            # graph mode.
+            blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
+
             common_attn_metadata = CommonAttentionMetadata(
                 query_start_loc=self.query_start_loc[:num_reqs + 1],
                 query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
```
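Net effect: the slot-mapping padding moves out of `CommonAttentionMetadata.__post_init__` and into `_prepare_inputs`, so the fill is applied once to the persistent per-group buffer when inputs are prepared, rather than every time a metadata object is constructed. The sketch below is a minimal, hypothetical illustration of why the tail must be -1 in full CUDA graph mode; the buffer size, toy `kv_cache`, and Python write loop are stand-ins for the fixed capture-time buffers and the `reshape_and_cache` kernel, which skips negative slot indices.

```python
# Minimal sketch (not vLLM's actual code): in full cuda graph mode the
# slot_mapping buffer has a fixed capture-time size, so entries past the
# number of actually scheduled tokens hold stale values from earlier steps.
# Filling the tail with -1 lets a reshape_and_cache-style kernel skip the
# padded tokens instead of overwriting live KV-cache slots.
import torch

MAX_NUM_TOKENS = 8           # assumed fixed size captured into the graph
total_num_scheduled_tokens = 5

# Persistent buffer reused across steps; the tail is stale, not garbage-free.
slot_mapping = torch.tensor([3, 7, 2, 9, 4, 11, 12, 13])

# What the diff does in _prepare_inputs: neutralize the unused tail.
slot_mapping[total_num_scheduled_tokens:].fill_(-1)

# Toy stand-in for reshape_and_cache's skip-on-negative behavior.
kv_cache = torch.zeros(16)
values = torch.arange(1.0, MAX_NUM_TOKENS + 1)
for i, slot in enumerate(slot_mapping.tolist()):
    if slot < 0:             # padded token: nothing is written
        continue
    kv_cache[slot] = values[i]

print(slot_mapping)          # tensor([ 3,  7,  2,  9,  4, -1, -1, -1])
```

Under the same assumption, unused `seq_lens` entries are zeroed rather than set to -1, since a padded request should simply contribute a zero-length sequence.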